diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index 5095794..6c8bac3 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -498,14 +498,13 @@ def copy(self, other, subset=Ellipsis): else: self[subset].assign(other[subset]) - # symbolic - def zero(self, *, subset=Ellipsis): - return ReplaceAssignment(self[subset], 0) - - def eager_zero(self, *, subset=Ellipsis): + def zero(self, *, subset=Ellipsis, eager=True): + # old Firedrake code may hit this, should probably raise a warning if subset is None: subset = Ellipsis - self.zero(subset=subset)() + + expr = ReplaceAssignment(self[subset], 0) + return expr() if eager else expr # Needs to be subclass for isinstance checks to work diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index f3c4fae..141c1e7 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -230,6 +230,9 @@ def assemble(self): self.mat.assemble() def assign(self, other, *, eager=True): + if eager: + raise NotImplementedError("Cannot eagerly assign to Mats") + if isinstance(other, HierarchicalArray): # TODO: Check axes match between self and other expr = PetscMatStore(self, other) @@ -243,7 +246,7 @@ def assign(self, other, *, eager=True): else: raise NotImplementedError - return expr() if eager else expr + return expr @property def nested(self): @@ -621,8 +624,11 @@ def from_sparsity(cls, sparsity, *, name=None): mat = sparsity.materialize() return cls(sparsity.raxes, sparsity.caxes, sparsity.mat_type, mat, name=name, block_shape=sparsity.block_shape) - def eager_zero(self): - self.mat.zeroEntries() + def zero(self, *, eager=True): + if eager: + self.mat.zeroEntries() + else: + raise NotImplementedError @property def values(self): @@ -724,7 +730,7 @@ def is_column_matrix(self): # return self.dat.data_ro.reshape(*shape)[key] def zeroEntries(self, mat): - self.dat.eager_zero() + self.dat.zero() def mult(self, A, x, y): """Set y = A * x (where A is self).""" diff --git a/tests/unit/test_array.py b/tests/unit/test_array.py index 8f41b81..dc2c7cd 100644 --- a/tests/unit/test_array.py +++ b/tests/unit/test_array.py @@ -3,11 +3,11 @@ import pyop3 as op3 -def test_eager_zero(): +def test_zero(): axes = op3.Axis(5) array = op3.HierarchicalArray(axes, dtype=op3.IntType) assert (array.buffer._data == 0).all() array.buffer._data[...] = 666 - array.eager_zero() + array.zero() assert (array.buffer._data == 0).all()