diff --git a/tetraku/tetraku/networks/naqs/__init__.py b/tetraku/tetraku/networks/naqs/__init__.py index 11599ddb..54cd0e6c 100644 --- a/tetraku/tetraku/networks/naqs/__init__.py +++ b/tetraku/tetraku/networks/naqs/__init__.py @@ -200,7 +200,7 @@ def generate(self, batch_size, alpha=1): real_amplitude = amplitude_phase.exp() real_probability = (real_amplitude.conj() * real_amplitude).real x = torch.index_select(x, 1, self.ordering) - return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, real_probability, multiplicity + return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, torch.ones_like(real_probability), torch.ones_like(multiplicity) def network(state, spin_up, spin_down, hidden_size, ordering=+1):