add reweight es naqs. not completed yet
When sampling, it uses self.es, and when getting ws, it uses self.psi.
self.psi is updated from outside, but self.es needs to be optimized too, which is
not implemented yet.
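
A minimal sketch of the intended split, for orientation only (not part of this commit): network.generate, network.psi, network.es and the Adam learning rate for es come from the code below, while the driver loop, the psi optimizer and its learning rate are placeholder assumptions.

# Hypothetical outer loop around a ReweightWaveFunction
# (see tetraku/tetraku/networks/naqs/reweight.py below).
import torch

network = ...  # a ReweightWaveFunction instance built by the network() factory below
psi_opt = torch.optim.Adam(network.psi.parameters(), lr=1e-3)  # assumed; psi is updated by the existing machinery
es_opt = torch.optim.Adam(network.es.parameters(), lr=1e-2)    # es only drives sampling, so it needs its own updates

# configurations are sampled from es, while the amplitudes are evaluated with psi
configurations, amplitudes, weights, multiplicities = network.generate(4000)
# ... compute the variational loss from amplitudes and local energies, then step psi_opt ...
# es_opt has to be stepped separately, e.g. by fitting es to the sampled local
# energies as the new code in Observer.__exit__ below does.
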
hzhangxyz committed Sep 3, 2024
1 parent 85d01a2 commit 8458a19
Showing 3 changed files with 315 additions and 6 deletions.
57 changes: 54 additions & 3 deletions tetragono/tetragono/sampling_neural_state/observer.py
@@ -18,9 +18,39 @@

import numpy as np
import torch
from ..utility import allreduce_buffer, allreduce_number, show, showln
from ..utility import allreduce_buffer, allreduce_number, show, showln, mpi_comm
from .state import Configuration, index_tensor_element

opt = None


def torch_tensor_allgather(tensor):
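"""Allgather a torch tensor along dim 0 across all MPI ranks; ranks may contribute different numbers of rows."""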
from mpi4py import MPI
# Get the device of the input tensor
device = tensor.device

# Convert torch tensor to numpy array
np_array = tensor.cpu().detach().numpy()

# Get the shared MPI communicator
comm = mpi_comm
rank = comm.Get_rank()
size = comm.Get_size()
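
# Each rank may contribute a different number of rows: gather per-rank element
# counts for Allgatherv and per-rank row counts to size the receive buffer.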

counts = comm.allgather(np_array.size)
first = comm.allgather(np_array.shape[0])
total_length = sum(first)
# Create a buffer to hold all gathered numpy arrays
gathered_np_arrays = np.empty((total_length, *np_array.shape[1:]), dtype=np_array.dtype)

# Perform allgather
comm.Allgatherv(np_array, [gathered_np_arrays, counts])

# Convert gathered numpy arrays back to torch tensor
gathered_tensor = torch.from_numpy(gathered_np_arrays).to(device)

return gathered_tensor


class Observer():
"""
@@ -66,8 +96,7 @@ def __enter__(self):
if self._enable_gradient:
self._Delta = None
self._EDelta = None
if self._enable_natural:
self._Deltas = []
self._Deltas = []  # temporarily repurpose this list for something else

def __exit__(self, exc_type, exc_val, exc_tb):
"""
@@ -114,6 +143,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
allreduce_buffer(self._Delta)
allreduce_buffer(self._EDelta)

cs = torch.stack([c for c, e in self._Deltas])
es = torch.tensor([e for c, e in self._Deltas], dtype=torch.complex128, device=cs.device)
cs = torch_tensor_allgather(cs)
es = torch.view_as_complex(torch_tensor_allgather(torch.view_as_real(es)))
es = es - es.mean()  # in short, this is what sampling is driven by; later other factors, e.g. Delta, may also be multiplied in
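# Fit self.es so that, on the gathered configurations, its normalized output matches these centered local energies.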
with torch.enable_grad():
global opt
if opt is None:
opt = torch.optim.Adam(self.owner.network.es.parameters(), 1e-2)
for _ in range(100):
hes = self.owner.network.es(cs)
error = hes / hes.norm() - es / es.norm()
error = (error.abs()**2).mean()
show(error.item())
opt.zero_grad()
error.backward()
opt.step()
showln("es error", error.item())

def __init__(
self,
owner,
@@ -395,6 +443,9 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
name].imag * reweight
if name == "energy" and self._enable_gradient:
Es = whole_result[batch_index][name]
# collect configurations and local energies here to train self.es (optimized in __exit__)
self._Deltas.append((configurations[batch_index], Es))
if self.owner.Tensor.is_real:
Es = Es.real

18 changes: 15 additions & 3 deletions tetragono/tetragono/sampling_neural_state/state.py
@@ -348,16 +348,28 @@ def holes(self, value):
if self.Tensor.is_complex:
with torch_grad(True):
value.real.backward(retain_graph=True)
real = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
real = torch.cat([
param.grad.reshape([-1])
for param in self.network.parameters()
if param.requires_grad and param.grad is not None
])
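# (param.grad can be None for parameters not touched by this backward pass; those are skipped)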
self.network.zero_grad()
with torch_grad(True):
value.imag.backward()
imag = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
imag = torch.cat([
param.grad.reshape([-1])
for param in self.network.parameters()
if param.requires_grad and param.grad is not None
])
self.network.zero_grad()
result = (real + 1j * imag)
else:
value.backward()
result = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
result = torch.cat([
param.grad.reshape([-1])
for param in self.network.parameters()
if param.requires_grad and param.grad is not None
])
self.network.zero_grad()
result = result / value
return result.detach_()
246 changes: 246 additions & 0 deletions tetraku/tetraku/networks/naqs/reweight.py
@@ -0,0 +1,246 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2024 Hao Zhang<[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

import torch


class FakeLinear(torch.nn.Module):

def __init__(self, dim_in, dim_out):
super().__init__()
self.bias = torch.nn.Parameter(torch.zeros([dim_out]))

def forward(self, x):
shape = x.shape[:-1]
prod = torch.tensor(shape).prod()
return self.bias.view([1, -1]).expand([prod, -1]).view([*shape, -1])


def Linear(dim_in, dim_out):
if dim_in == 0:
return FakeLinear(dim_in, dim_out)
else:
return torch.nn.Linear(dim_in, dim_out)


class MLP(torch.nn.Module):

def __init__(self, dim_input, dim_output, hidden_size):
super().__init__()
self.dim_input = dim_input
self.dim_output = dim_output
self.hidden_size = hidden_size
self.depth = len(hidden_size)

self.model = torch.nn.Sequential(*(Linear(
dim_input if i == 0 else hidden_size[i - 1],
dim_output if i == self.depth else hidden_size[i],
) if j == 0 else torch.nn.SiLU() for i in range(self.depth + 1) for j in range(2) if i != self.depth or j != 1))
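# Layers alternate Linear and SiLU; the final Linear (i == self.depth) is not followed by an activation.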

def forward(self, x):
return self.model(x)


class WaveFunction(torch.nn.Module):

def __init__(
self,
*,
L1,
L2,
orbit_num,
physical_dim,
is_complex,
spin_up,
spin_down,
hidden_size,
ordering,
):
super().__init__()
self.L1 = L1
self.L2 = L2
self.orbit_num = orbit_num
self.sites = L1 * L2 * orbit_num // 2
assert physical_dim == 2
assert is_complex is True
self.spin_up = spin_up
self.spin_down = spin_down
self.hidden_size = tuple(hidden_size)

self.amplitude = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
self.phase = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])

if isinstance(ordering, int) and ordering == +1:
ordering = list(range(self.sites))
if isinstance(ordering, int) and ordering == -1:
ordering = list(reversed(range(self.sites)))
self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64), persistent=True)
ordering_bak = torch.zeros(self.sites, dtype=torch.int64)
ordering_bak.scatter_(0, self.ordering, torch.arange(self.sites))
self.register_buffer('ordering_bak', ordering_bak, persistent=True)

def mask(self, x):
# x : batch * i * 2
i = x.size(1)
# number : batch * 2
number = x.sum(dim=1)

up_electron = number[:, 0]
down_electron = number[:, 1]
up_hole = i - up_electron
down_hole = i - down_electron

add_up_electron = up_electron < self.spin_up
add_down_electron = down_electron < self.spin_down
add_up_hole = up_hole < self.sites - self.spin_up
add_down_hole = down_hole < self.sites - self.spin_down

add_up = torch.stack([add_up_hole, add_up_electron], dim=-1).unsqueeze(-1)
add_down = torch.stack([add_down_hole, add_down_electron], dim=-1).unsqueeze(-2)
add = torch.logical_and(add_up, add_down)
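# add[b, u, d] is True iff occupying the next site with (up=u, down=d) is still compatible with the target spin_up/spin_down particle numbers.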
return add

def normalize_amplitude(self, x):
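# Shift the log-amplitudes so that the four conditional probabilities exp(2 * x) sum to one.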
param = -(2 * x).exp().sum(dim=[1, 2]).log() / 2
x = x + param.unsqueeze(-1).unsqueeze(-1)
return x

def forward(self, x):
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype

batch_size = x.size(0)
x = x.reshape([batch_size, self.sites, 2])
x = torch.index_select(x, 1, self.ordering_bak)

xf = x.to(dtype=dtype)
arange = torch.arange(batch_size, device=device)
total_amplitude = 0
total_phase = 0
for i in range(self.sites):
amplitude = self.amplitude[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
phase = self.phase[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
amplitude = amplitude + torch.where(self.mask(x[:, :i]), 0, -torch.inf)
amplitude = self.normalize_amplitude(amplitude)
amplitude = amplitude[arange, x[:, i, 0], x[:, i, 1]]
phase = phase[arange, x[:, i, 0], x[:, i, 1]]
total_amplitude = total_amplitude + amplitude
total_phase = total_phase + phase
return (total_amplitude + 1j * total_phase).exp()

def binomial(self, count, possibility):
possibility = torch.clamp(possibility, min=0, max=1)
possibility = torch.where(count == 0, 0, possibility)
dist = torch.distributions.binomial.Binomial(count, possibility)
result = dist.sample()
result = result.to(dtype=torch.int64)
# Numerical error since result was cast to float.
return torch.clamp(result, min=torch.zeros_like(count), max=count)

def generate(self, batch_size, alpha=1):
# https://arxiv.org/pdf/2109.12606
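# Batched autoregressive sampling: at each site the multiplicity of every partial configuration is split binomially over the four local occupations, and only branches with nonzero count are kept.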
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
assert alpha == 1

x = torch.empty([1, 0, 2], device=device, dtype=torch.int64)
multiplicity = torch.tensor([batch_size], dtype=torch.int64, device=device)
amplitude_phase = torch.tensor([0], dtype=dtype.to_complex(), device=device)
for i in range(self.sites):
local_batch_size = x.size(0)

xf = x.to(dtype=dtype)
amplitude = self.amplitude[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
phase = self.phase[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
amplitude = amplitude + torch.where(self.mask(x), 0, -torch.inf)
amplitude = self.normalize_amplitude(amplitude)
delta_amplitude_phase = (amplitude + 1j * phase).reshape([local_batch_size, 4])
probability = (2 * amplitude).exp().reshape([local_batch_size, 4])
probability = probability / probability.sum(dim=-1).unsqueeze(-1)

sample0123 = multiplicity
prob23 = probability[:, 2] + probability[:, 3]
prob01 = probability[:, 0] + probability[:, 1]
sample23 = self.binomial(sample0123, prob23)
sample3 = self.binomial(sample23, probability[:, 3] / prob23)
sample2 = sample23 - sample3
sample01 = sample0123 - sample23
sample1 = self.binomial(sample01, probability[:, 1] / prob01)
sample0 = sample01 - sample1

x0 = torch.cat([x, torch.tensor([[0, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
x1 = torch.cat([x, torch.tensor([[0, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
x2 = torch.cat([x, torch.tensor([[1, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
x3 = torch.cat([x, torch.tensor([[1, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)

new_x = torch.cat([x0, x1, x2, x3])
new_multiplicity = torch.cat([sample0, sample1, sample2, sample3])
new_amplitude_phase = (amplitude_phase.unsqueeze(0) + delta_amplitude_phase.permute(1, 0)).reshape([-1])

selected = new_multiplicity != 0
x = new_x[selected]
multiplicity = new_multiplicity[selected]
amplitude_phase = new_amplitude_phase[selected]

real_amplitude = amplitude_phase.exp()
real_probability = (real_amplitude.conj() * real_amplitude).real
x = torch.index_select(x, 1, self.ordering)
return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, torch.ones_like(real_probability), torch.ones_like(multiplicity)


class ReweightWaveFunction(torch.nn.Module):

def __init__(
self,
*args,
**kwargs,
):
super().__init__()
self.psi = WaveFunction(*args, **kwargs)
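# Storing es inside a tuple keeps it out of this Module's registered submodules, so only psi's parameters are exposed; es is meant to be optimized separately (see observer.py above).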
self._es = WaveFunction(*args, **kwargs).cuda(),
self.es.load_state_dict(self.psi.state_dict())
self.es.cuda()

@property
def es(self):
return self._es[0]

def forward(self, x):
return self.psi(x)

def generate(self, batch_size, alpha=1):
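# Configurations are sampled with self.es, while the returned amplitudes are evaluated with self.psi.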
configurations, _, weights, multiplicities = self.es.generate(batch_size, alpha)
amplitudes = self(configurations)
return configurations, amplitudes, weights, multiplicities


def network(state, spin_up, spin_down, hidden_size, ordering=+1):
max_orbit_index = max(orbit for [l1, l2, orbit], edge in state.physics_edges)
max_physical_dim = max(edge.dimension for [l1, l2, orbit], edge in state.physics_edges)
network = ReweightWaveFunction(
L1=state.L1,
L2=state.L2,
orbit_num=max_orbit_index + 1,
physical_dim=max_physical_dim,
is_complex=state.Tensor.is_complex,
spin_up=spin_up,
spin_down=spin_down,
hidden_size=hidden_size,
ordering=ordering,
).double()
return network
