Skip to content

Commit

Permalink
Only calculate mask for fermion tensor im reverse edge.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 22, 2023
1 parent ddd3b95 commit 25abd46
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions tat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,18 +786,20 @@ def reverse_edge(
parity_exclude_names = set()
assert all(name in self.names for name in reversed_names)
assert all(name in reversed_names for name in parity_exclude_names)
# Parity is xor of all valid reverse parity
parity = functools.reduce(
torch.logical_xor,
(
_utility.unsqueeze(edge.parity, current_index, self.rank)
# Loop over all edge
for current_index, [name, edge] in enumerate(zip(self.names, self.edges))
# Check if this edge is reversed and if this edge will be applied parity
if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))),
torch.zeros([], dtype=torch.bool),
)
data = torch.where(parity, -self.data, +self.data)
data = self.data
if any(self.fermion):
# Parity is xor of all valid reverse parity
parity = functools.reduce(
torch.logical_xor,
(
_utility.unsqueeze(edge.parity, current_index, self.rank)
# Loop over all edge
for current_index, [name, edge] in enumerate(zip(self.names, self.edges))
# Check if this edge is reversed and if this edge will be applied parity
if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))),
torch.zeros([], dtype=torch.bool),
)
data = torch.where(parity, -data, +data)
return Tensor(
names=self.names,
edges=tuple(
Expand Down

0 comments on commit 25abd46

Please sign in to comment.