Skip to content

Commit

Permalink
Implement naqs ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Aug 18, 2024
1 parent 13e1317 commit e33275a
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tetraku/tetraku/networks/naqs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
spin_up,
spin_down,
hidden_size,
ordering,
):
super().__init__()
self.L1 = L1
Expand All @@ -65,6 +66,15 @@ def __init__(
self.amplitude = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
self.phase = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])

if isinstance(ordering, int) and ordering == +1:
ordering = list(range(self.sites))
if isinstance(ordering, int) and ordering == -1:
ordering = list(reversed(range(self.sites)))
self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64), persistent=True)
ordering_bak = torch.zeros(self.sites, dtype=torch.int64)
ordering_bak.scatter_(0, self.ordering, torch.arange(self.sites))
self.register_buffer('ordering_bak', ordering_bak, persistent=True)

def mask(self, x):
# x : batch * i * 2
i = x.size(1)
Expand Down Expand Up @@ -97,6 +107,7 @@ def forward(self, x):

batch_size = x.size(0)
x = x.reshape([batch_size, self.sites, 2])
x = torch.index_select(x, 1, self.ordering_bak)

xf = x.to(dtype=dtype)
arange = torch.arange(batch_size, device=device)
Expand Down Expand Up @@ -167,10 +178,11 @@ def generate(self, batch_size, alpha=1):

real_amplitude = amplitude_phase.exp()
real_probability = (real_amplitude.conj() * real_amplitude).real
x = torch.index_select(x, 1, self.ordering)
return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, real_probability, multiplicity


def network(state, spin_up, spin_down, hidden_size):
def network(state, spin_up, spin_down, hidden_size, ordering=+1):
max_orbit_index = max(orbit for [l1, l2, orbit], edge in state.physics_edges)
max_physical_dim = max(edge.dimension for [l1, l2, orbit], edge in state.physics_edges)
network = WaveFunction(
Expand All @@ -182,5 +194,6 @@ def network(state, spin_up, spin_down, hidden_size):
spin_up=spin_up,
spin_down=spin_down,
hidden_size=hidden_size,
ordering=ordering,
).double()
return network

0 comments on commit e33275a

Please sign in to comment.