Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expunge eager_zero and fix Mat.assign #29

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions pyop3/array/harray.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,14 +498,13 @@ def copy(self, other, subset=Ellipsis):
else:
self[subset].assign(other[subset])

# symbolic
def zero(self, *, subset=Ellipsis):
return ReplaceAssignment(self[subset], 0)

def eager_zero(self, *, subset=Ellipsis):
def zero(self, *, subset=Ellipsis, eager=True):
# old Firedrake code may hit this, should probably raise a warning
if subset is None:
subset = Ellipsis
self.zero(subset=subset)()

expr = ReplaceAssignment(self[subset], 0)
return expr() if eager else expr


# Needs to be subclass for isinstance checks to work
Expand Down
14 changes: 10 additions & 4 deletions pyop3/array/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ def assemble(self):
self.mat.assemble()

def assign(self, other, *, eager=True):
if eager:
raise NotImplementedError("Cannot eagerly assign to Mats")

if isinstance(other, HierarchicalArray):
# TODO: Check axes match between self and other
expr = PetscMatStore(self, other)
Expand All @@ -243,7 +246,7 @@ def assign(self, other, *, eager=True):
else:
raise NotImplementedError

return expr() if eager else expr
return expr

@property
def nested(self):
Expand Down Expand Up @@ -621,8 +624,11 @@ def from_sparsity(cls, sparsity, *, name=None):
mat = sparsity.materialize()
return cls(sparsity.raxes, sparsity.caxes, sparsity.mat_type, mat, name=name, block_shape=sparsity.block_shape)

def eager_zero(self):
self.mat.zeroEntries()
def zero(self, *, eager=True):
if eager:
self.mat.zeroEntries()
else:
raise NotImplementedError

@property
def values(self):
Expand Down Expand Up @@ -724,7 +730,7 @@ def is_column_matrix(self):
# return self.dat.data_ro.reshape(*shape)[key]

def zeroEntries(self, mat):
self.dat.eager_zero()
self.dat.zero()

def mult(self, A, x, y):
"""Set y = A * x (where A is self)."""
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import pyop3 as op3


def test_eager_zero():
def test_zero():
axes = op3.Axis(5)
array = op3.HierarchicalArray(axes, dtype=op3.IntType)
assert (array.buffer._data == 0).all()

array.buffer._data[...] = 666
array.eager_zero()
array.zero()
assert (array.buffer._data == 0).all()
Loading