diff --git a/tat/tensor.py b/tat/tensor.py index 69af96c36..c5caa7eb3 100644 --- a/tat/tensor.py +++ b/tat/tensor.py @@ -583,6 +583,8 @@ def __setitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | di indices = self._prepare_position(position) if self.mask[indices]: self.data[indices] = value + else: + raise IndexError("The indices specified are masked, so it is invalid to set item here.") def clear_symmetry(self: Tensor) -> Tensor: """ diff --git a/tests/test_create_tensor.py b/tests/test_create_tensor.py index 001384d78..f784539c8 100644 --- a/tests/test_create_tensor.py +++ b/tests/test_create_tensor.py @@ -44,7 +44,11 @@ def test_tensor_get_set_item() -> None: assert a[0, 0] == 1 a["i":2, "j":3] = 2 # type: ignore[misc] assert a[{"i": 2, "j": 3}] == 2 - a[2, 0] = 3 + try: + a[2, 0] = 3 + assert False + except IndexError: + pass assert a["i":2, "j":0] == 0 # type: ignore[misc] b = tat.Tensor( @@ -61,7 +65,11 @@ def test_tensor_get_set_item() -> None: assert b[0, 0] == 1 b["i":2, "j":3] = 2 # type: ignore[misc] assert b[{"i": 2, "j": 3}] == 2 - b[2, 0] = 3 + try: + b[2, 0] = 3 + assert False + except IndexError: + pass assert b["i":2, "j":0] == 0 # type: ignore[misc]