Skip to content

Commit

Permalink
Expunge eager_zero and fix Mat.assign (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward authored Apr 26, 2024
1 parent cfc270f commit 16e2b62
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
11 changes: 5 additions & 6 deletions pyop3/array/harray.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,14 +498,13 @@ def copy(self, other, subset=Ellipsis):
else:
self[subset].assign(other[subset])

# symbolic
def zero(self, *, subset=Ellipsis):
return ReplaceAssignment(self[subset], 0)

def eager_zero(self, *, subset=Ellipsis):
def zero(self, *, subset=Ellipsis, eager=True):
# old Firedrake code may hit this, should probably raise a warning
if subset is None:
subset = Ellipsis
self.zero(subset=subset)()

expr = ReplaceAssignment(self[subset], 0)
return expr() if eager else expr


# Needs to be subclass for isinstance checks to work
Expand Down
14 changes: 10 additions & 4 deletions pyop3/array/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ def assemble(self):
self.mat.assemble()

def assign(self, other, *, eager=True):
if eager:
raise NotImplementedError("Cannot eagerly assign to Mats")

if isinstance(other, HierarchicalArray):
# TODO: Check axes match between self and other
expr = PetscMatStore(self, other)
Expand All @@ -243,7 +246,7 @@ def assign(self, other, *, eager=True):
else:
raise NotImplementedError

return expr() if eager else expr
return expr

@property
def nested(self):
Expand Down Expand Up @@ -621,8 +624,11 @@ def from_sparsity(cls, sparsity, *, name=None):
mat = sparsity.materialize()
return cls(sparsity.raxes, sparsity.caxes, sparsity.mat_type, mat, name=name, block_shape=sparsity.block_shape)

def eager_zero(self):
self.mat.zeroEntries()
def zero(self, *, eager=True):
if eager:
self.mat.zeroEntries()
else:
raise NotImplementedError

@property
def values(self):
Expand Down Expand Up @@ -724,7 +730,7 @@ def is_column_matrix(self):
# return self.dat.data_ro.reshape(*shape)[key]

def zeroEntries(self, mat):
self.dat.eager_zero()
self.dat.zero()

def mult(self, A, x, y):
"""Set y = A * x (where A is self)."""
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import pyop3 as op3


def test_eager_zero():
def test_zero():
axes = op3.Axis(5)
array = op3.HierarchicalArray(axes, dtype=op3.IntType)
assert (array.buffer._data == 0).all()

array.buffer._data[...] = 666
array.eager_zero()
array.zero()
assert (array.buffer._data == 0).all()

0 comments on commit 16e2b62

Please sign in to comment.