Skip to content

Commit

Permalink
Fix pytorch warning when apply logical_xor on different shape tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 22, 2023
1 parent c071053 commit 901298d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tat/_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor:

def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
if tensor_1.dtype is torch.bool:
return torch.logical_xor(tensor_1, tensor_2)
return tensor_1 ^ tensor_2
else:
return torch.add(tensor_1, tensor_2)
return tensor_1 + tensor_2


def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit 901298d

Please sign in to comment.