diff --git a/tat/tensor.py b/tat/tensor.py index c5caa7eb3..8aee3a462 100644 --- a/tat/tensor.py +++ b/tat/tensor.py @@ -786,18 +786,20 @@ def reverse_edge( parity_exclude_names = set() assert all(name in self.names for name in reversed_names) assert all(name in reversed_names for name in parity_exclude_names) - # Parity is xor of all valid reverse parity - parity = functools.reduce( - torch.logical_xor, - ( - _utility.unsqueeze(edge.parity, current_index, self.rank) - # Loop over all edge - for current_index, [name, edge] in enumerate(zip(self.names, self.edges)) - # Check if this edge is reversed and if this edge will be applied parity - if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))), - torch.zeros([], dtype=torch.bool), - ) - data = torch.where(parity, -self.data, +self.data) + data = self.data + if any(self.fermion): + # Parity is xor of all valid reverse parity + parity = functools.reduce( + torch.logical_xor, + ( + _utility.unsqueeze(edge.parity, current_index, self.rank) + # Loop over all edge + for current_index, [name, edge] in enumerate(zip(self.names, self.edges)) + # Check if this edge is reversed and if this edge will be applied parity + if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))), + torch.zeros([], dtype=torch.bool), + ) + data = torch.where(parity, -data, +data) return Tensor( names=self.names, edges=tuple(