Skip to content

Commit

Permalink
Raise error when set item to masked position.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 22, 2023
1 parent 0496837 commit ddd3b95
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ def __setitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | di
indices = self._prepare_position(position)
if self.mask[indices]:
self.data[indices] = value
else:
raise IndexError("The indices specified are masked, so it is invalid to set item here.")

def clear_symmetry(self: Tensor) -> Tensor:
"""
Expand Down
12 changes: 10 additions & 2 deletions tests/test_create_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def test_tensor_get_set_item() -> None:
assert a[0, 0] == 1
a["i":2, "j":3] = 2 # type: ignore[misc]
assert a[{"i": 2, "j": 3}] == 2
a[2, 0] = 3
try:
a[2, 0] = 3
assert False
except IndexError:
pass
assert a["i":2, "j":0] == 0 # type: ignore[misc]

b = tat.Tensor(
Expand All @@ -61,7 +65,11 @@ def test_tensor_get_set_item() -> None:
assert b[0, 0] == 1
b["i":2, "j":3] = 2 # type: ignore[misc]
assert b[{"i": 2, "j": 3}] == 2
b[2, 0] = 3
try:
b[2, 0] = 3
assert False
except IndexError:
pass
assert b["i":2, "j":0] == 0 # type: ignore[misc]


Expand Down

0 comments on commit ddd3b95

Please sign in to comment.