From 2cab1e2ea7d9dd909238189a327e3f34b079ed04 Mon Sep 17 00:00:00 2001
From: Hao Zhang
Date: Wed, 28 Aug 2024 09:48:21 +0800
Subject: [PATCH] add reweight es naqs, not completed yet

When sampling, it uses self.es, and when computing ws, it uses self.psi.
self.psi is updated from outside, but self.es also needs to be optimized,
which is not implemented yet.
---
 .../sampling_neural_state/observer.py         |  57 +++-
 .../tetragono/sampling_neural_state/state.py  |  18 +-
 tetraku/tetraku/networks/naqs/reweight.py     | 246 ++++++++++++++++++
 3 files changed, 315 insertions(+), 6 deletions(-)
 create mode 100644 tetraku/tetraku/networks/naqs/reweight.py

diff --git a/tetragono/tetragono/sampling_neural_state/observer.py b/tetragono/tetragono/sampling_neural_state/observer.py
index 51fab1de..ac7d772c 100644
--- a/tetragono/tetragono/sampling_neural_state/observer.py
+++ b/tetragono/tetragono/sampling_neural_state/observer.py
@@ -18,9 +18,39 @@
 
 import numpy as np
 import torch
-from ..utility import allreduce_buffer, allreduce_number, show, showln
+from ..utility import allreduce_buffer, allreduce_number, show, showln, mpi_comm
 from .state import Configuration, index_tensor_element
 
+opt = None
+
+
+def torch_tensor_allgather(tensor):
+    from mpi4py import MPI
+    # Get the device of the input tensor
+    device = tensor.device
+
+    # Convert torch tensor to numpy array
+    np_array = tensor.cpu().detach().numpy()
+
+    # Initialize MPI
+    comm = mpi_comm
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+
+    counts = comm.allgather(np_array.size)
+    first = comm.allgather(np_array.shape[0])
+    total_length = sum(first)
+    # Create a buffer to hold all gathered numpy arrays
+    gathered_np_arrays = np.empty((total_length, *np_array.shape[1:]), dtype=np_array.dtype)
+
+    # Perform allgather
+    comm.Allgatherv(np_array, [gathered_np_arrays, counts])
+
+    # Convert gathered numpy arrays back to torch tensor
+    gathered_tensor = torch.from_numpy(gathered_np_arrays).to(device)
+
+    return gathered_tensor
+
 
 class Observer():
     """
@@ -66,8 +96,7 @@ def __enter__(self):
         if self._enable_gradient:
             self._Delta = None
             self._EDelta = None
-            if self._enable_natural:
-                self._Deltas = []
+            self._Deltas = []  # temporarily reuse this list for a different purpose
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """
@@ -114,6 +143,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
                 allreduce_buffer(self._Delta)
                 allreduce_buffer(self._EDelta)
 
+        cs = torch.stack([c for c, e in self._Deltas])
+        es = torch.tensor([e for c, e in self._Deltas], dtype=torch.complex128, device=cs.device)
+        cs = torch_tensor_allgather(cs)
+        es = torch.view_as_complex(torch_tensor_allgather(torch.view_as_real(es)))
+        es = es - es.mean()  # in short, this is the quantity used for sampling; later other factors, e.g. Delta, may be multiplied in as well
+        with torch.enable_grad():
+            global opt
+            if opt is None:
+                opt = torch.optim.Adam(self.owner.network.es.parameters(), 1e-2)
+            for _ in range(100):
+                hes = self.owner.network.es(cs)
+                error = hes / hes.norm() - es / es.norm()
+                error = (error.abs()**2).mean()
+                show(error.item())
+                opt.zero_grad()
+                error.backward()
+                opt.step()
+            showln("es error", error.item())
+
     def __init__(
         self,
         owner,
@@ -395,6 +443,9 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
                             name].imag * reweight
                 if name == "energy" and self._enable_gradient:
                     Es = whole_result[batch_index][name]
+                    # train self.es
+                    # collect and optimize self.es
+                    self._Deltas.append((configurations[batch_index], Es))
                     if self.owner.Tensor.is_real:
                         Es = Es.real
diff --git a/tetragono/tetragono/sampling_neural_state/state.py b/tetragono/tetragono/sampling_neural_state/state.py
index 342e8bee..9f17f6b2 100644
--- a/tetragono/tetragono/sampling_neural_state/state.py
+++ b/tetragono/tetragono/sampling_neural_state/state.py
@@ -348,16 +348,28 @@ def holes(self, value):
         if self.Tensor.is_complex:
             with torch_grad(True):
                 value.real.backward(retain_graph=True)
-            real = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            real = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
             with torch_grad(True):
                 value.imag.backward()
-            imag = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            imag = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
             result = (real + 1j * imag)
         else:
             value.backward()
-            result = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            result = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
         result = result / value
         return result.detach_()
diff --git a/tetraku/tetraku/networks/naqs/reweight.py b/tetraku/tetraku/networks/naqs/reweight.py
new file mode 100644
index 00000000..589fb2cf
--- /dev/null
+++ b/tetraku/tetraku/networks/naqs/reweight.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (C) 2024 Hao Zhang
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+
+import torch
+
+
+class FakeLinear(torch.nn.Module):
+
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.bias = torch.nn.Parameter(torch.zeros([dim_out]))
+
+    def forward(self, x):
+        shape = x.shape[:-1]
+        prod = torch.tensor(shape).prod()
+        return self.bias.view([1, -1]).expand([prod, -1]).view([*shape, -1])
+
+
+def Linear(dim_in, dim_out):
+    if dim_in == 0:
+        return FakeLinear(dim_in, dim_out)
+    else:
+        return torch.nn.Linear(dim_in, dim_out)
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim_input, dim_output, hidden_size):
+        super().__init__()
+        self.dim_input = dim_input
+        self.dim_output = dim_output
+        self.hidden_size = hidden_size
+        self.depth = len(hidden_size)
+
+        self.model = torch.nn.Sequential(*(Linear(
+            dim_input if i == 0 else hidden_size[i - 1],
+            dim_output if i == self.depth else hidden_size[i],
+        ) if j == 0 else torch.nn.SiLU() for i in range(self.depth + 1) for j in range(2) if i != self.depth or j != 1))
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class WaveFunction(torch.nn.Module):
+
+    def __init__(
+        self,
+        *,
+        L1,
+        L2,
+        orbit_num,
+        physical_dim,
+        is_complex,
+        spin_up,
+        spin_down,
+        hidden_size,
+        ordering,
+    ):
+        super().__init__()
+        self.L1 = L1
+        self.L2 = L2
+        self.orbit_num = orbit_num
+        self.sites = L1 * L2 * orbit_num // 2
+        assert physical_dim == 2
+        assert is_complex == True
+        self.spin_up = spin_up
+        self.spin_down = spin_down
+        self.hidden_size = tuple(hidden_size)
+
+        self.amplitude = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
+        self.phase = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
+
+        if isinstance(ordering, int) and ordering == +1:
+            ordering = list(range(self.sites))
+        if isinstance(ordering, int) and ordering == -1:
+            ordering = list(reversed(range(self.sites)))
+        self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64), persistent=True)
+        ordering_bak = torch.zeros(self.sites, dtype=torch.int64)
+        ordering_bak.scatter_(0, self.ordering, torch.arange(self.sites))
+        self.register_buffer('ordering_bak', ordering_bak, persistent=True)
+
+    def mask(self, x):
+        # x : batch * i * 2
+        i = x.size(1)
+        # number : batch * 2
+        number = x.sum(dim=1)
+
+        up_electron = number[:, 0]
+        down_electron = number[:, 1]
+        up_hole = i - up_electron
+        down_hole = i - down_electron
+
+        add_up_electron = up_electron < self.spin_up
+        add_down_electron = down_electron < self.spin_down
+        add_up_hole = up_hole < self.sites - self.spin_up
+        add_down_hole = down_hole < self.sites - self.spin_down
+
+        add_up = torch.stack([add_up_hole, add_up_electron], dim=-1).unsqueeze(-1)
+        add_down = torch.stack([add_down_hole, add_down_electron], dim=-1).unsqueeze(-2)
+        add = torch.logical_and(add_up, add_down)
+        return add
+
+    def normalize_amplitude(self, x):
+        param = -(2 * x).exp().sum(dim=[1, 2]).log() / 2
+        x = x + param.unsqueeze(-1).unsqueeze(-1)
+        return x
+
+    def forward(self, x):
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+
+        batch_size = x.size(0)
+        x = x.reshape([batch_size, self.sites, 2])
+        x = torch.index_select(x, 1, self.ordering_bak)
+
+        xf = x.to(dtype=dtype)
+        arange = torch.arange(batch_size, device=device)
+        total_amplitude = 0
+        total_phase = 0
+        for i in range(self.sites):
+            amplitude = self.amplitude[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
+            phase = self.phase[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
+            amplitude = amplitude + torch.where(self.mask(x[:, :i]), 0, -torch.inf)
+            amplitude = self.normalize_amplitude(amplitude)
+            amplitude = amplitude[arange, x[:, i, 0], x[:, i, 1]]
+            phase = phase[arange, x[:, i, 0], x[:, i, 1]]
+            total_amplitude = total_amplitude + amplitude
+            total_phase = total_phase + phase
+        return (total_amplitude + 1j * total_phase).exp()
+
+    def binomial(self, count, possibility):
+        possibility = torch.clamp(possibility, min=0, max=1)
+        possibility = torch.where(count == 0, 0, possibility)
+        dist = torch.distributions.binomial.Binomial(count, possibility)
+        result = dist.sample()
+        result = result.to(dtype=torch.int64)
+        # Numerical error since result was cast to float.
+        return torch.clamp(result, min=torch.zeros_like(count), max=count)
+
+    def generate(self, batch_size, alpha=1):
+        # https://arxiv.org/pdf/2109.12606
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+        assert alpha == 1
+
+        x = torch.empty([1, 0, 2], device=device, dtype=torch.int64)
+        multiplicity = torch.tensor([batch_size], dtype=torch.int64, device=device)
+        amplitude_phase = torch.tensor([0], dtype=dtype.to_complex(), device=device)
+        for i in range(self.sites):
+            local_batch_size = x.size(0)
+
+            xf = x.to(dtype=dtype)
+            amplitude = self.amplitude[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
+            phase = self.phase[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
+            amplitude = amplitude + torch.where(self.mask(x), 0, -torch.inf)
+            amplitude = self.normalize_amplitude(amplitude)
+            delta_amplitude_phase = (amplitude + 1j * phase).reshape([local_batch_size, 4])
+            probability = (2 * amplitude).exp().reshape([local_batch_size, 4])
+            probability = probability / probability.sum(dim=-1).unsqueeze(-1)
+
+            sample0123 = multiplicity
+            prob23 = probability[:, 2] + probability[:, 3]
+            prob01 = probability[:, 0] + probability[:, 1]
+            sample23 = self.binomial(sample0123, prob23)
+            sample3 = self.binomial(sample23, probability[:, 3] / prob23)
+            sample2 = sample23 - sample3
+            sample01 = sample0123 - sample23
+            sample1 = self.binomial(sample01, probability[:, 1] / prob01)
+            sample0 = sample01 - sample1
+
+            x0 = torch.cat([x, torch.tensor([[0, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x1 = torch.cat([x, torch.tensor([[0, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x2 = torch.cat([x, torch.tensor([[1, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x3 = torch.cat([x, torch.tensor([[1, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+
+            new_x = torch.cat([x0, x1, x2, x3])
+            new_multiplicity = torch.cat([sample0, sample1, sample2, sample3])
+            new_amplitude_phase = (amplitude_phase.unsqueeze(0) + delta_amplitude_phase.permute(1, 0)).reshape([-1])
+
+            selected = new_multiplicity != 0
+            x = new_x[selected]
+            multiplicity = new_multiplicity[selected]
+            amplitude_phase = new_amplitude_phase[selected]
+
+        real_amplitude = amplitude_phase.exp()
+        real_probability = (real_amplitude.conj() * real_amplitude).real
+        x = torch.index_select(x, 1, self.ordering)
+        return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, torch.ones_like(real_probability), torch.ones_like(multiplicity)
+
+
+class ReweightWaveFunction(torch.nn.Module):
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.psi = WaveFunction(*args, **kwargs)
+        self._es = WaveFunction(*args, **kwargs).cuda(),  # kept in a 1-tuple so es is not registered as a submodule; it is trained separately in the observer
+        self.es.load_state_dict(self.psi.state_dict())
+        self.es.cuda()
+
+    @property
+    def es(self):
+        return self._es[0]
+
+    def forward(self, x):
+        return self.psi(x)
+
+    def generate(self, batch_size, alpha=1):
+        configurations, _, weights, multiplicities = self.es.generate(batch_size, alpha)
+        amplitudes = self(configurations)
+        return configurations, amplitudes, weights, multiplicities
+
+
+def network(state, spin_up, spin_down, hidden_size, ordering=+1):
+    max_orbit_index = max(orbit for [l1, l2, orbit], edge in state.physics_edges)
+    max_physical_dim = max(edge.dimension for [l1, l2, orbit], edge in state.physics_edges)
+    network = ReweightWaveFunction(
+        L1=state.L1,
+        L2=state.L2,
+        orbit_num=max_orbit_index + 1,
+        physical_dim=max_physical_dim,
+        is_complex=state.Tensor.is_complex,
+        spin_up=spin_up,
+        spin_down=spin_down,
+        hidden_size=hidden_size,
+        ordering=ordering,
+    ).double()
+    return network
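
Usage sketch (not part of the patch). A minimal, hypothetical example of the split that
ReweightWaveFunction introduces: configurations are drawn autoregressively from the internal
es copy, while the returned amplitudes always come from psi, the state that is updated from
outside. The import path, lattice sizes, and hyperparameters below are illustrative
assumptions, and a CUDA device is assumed because __init__ moves the es copy to the GPU.

    # illustrative only; assumes the new module is importable as tetraku.networks.naqs.reweight
    import torch
    from tetraku.networks.naqs.reweight import ReweightWaveFunction

    net = ReweightWaveFunction(
        L1=2, L2=2, orbit_num=2, physical_dim=2, is_complex=True,
        spin_up=2, spin_down=2, hidden_size=[16, 16], ordering=+1,
    ).cuda()

    # Sampling runs through net.es, the network that Observer.__exit__ optimizes ...
    configurations, amplitudes, weights, multiplicities = net.generate(1024)
    # ... while the amplitudes are evaluated by net.psi via net.forward.
    print(configurations.shape, amplitudes.shape, multiplicities.sum().item())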