Skip to content

Commit

Permalink
types: Add coefficients arg to PETScArray
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Mar 15, 2024
1 parent b9a1f74 commit 0806d43
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
21 changes: 20 additions & 1 deletion devito/types/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cached_property import cached_property
from devito.finite_differences import Differentiable
from devito.types.basic import AbstractFunction
from devito.finite_differences.tools import fd_weights_registry


class DM(LocalObject):
Expand Down Expand Up @@ -72,13 +73,31 @@ class PETScArray(ArrayBasic, Differentiable):

_data_alignment = False

# Default method for the finite difference approximation weights computation.
_default_fd = 'taylor'

__rkwargs__ = (AbstractFunction.__rkwargs__ +
('dimensions', 'liveness'))
('dimensions', 'liveness', 'coefficients'))

def __init_finalize__(self, *args, **kwargs):

super().__init_finalize__(*args, **kwargs)

# Symbolic (finite difference) coefficients
self._coefficients = kwargs.get('coefficients', self._default_fd)
if self._coefficients not in fd_weights_registry:
raise ValueError("coefficients must be one of %s"
" not %s" % (str(fd_weights_registry), self._coefficients))

@classmethod
def __dtype_setup__(cls, **kwargs):
return kwargs.get('dtype', np.float32)

@property
def coefficients(self):
"""Form of the coefficients of the function."""
return self._coefficients

@cached_property
def _C_ctype(self):
petsc_type = dtype_to_petsctype(self.dtype)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,7 @@ def test_petsc_subs():

assert str(eqn_subs) == 'Eq(f1(x, y), Derivative(arr(x, y), (x, 2))' + \
' + Derivative(arr(x, y), (y, 2)))'

assert str(eqn_subs.rhs.evaluate) == '-2.0*arr(x, y)/h_x**2' + \
' + arr(x - h_x, y)/h_x**2 + arr(x + h_x, y)/h_x**2 - 2.0*arr(x, y)/h_y**2' + \
' + arr(x, y - h_y)/h_y**2 + arr(x, y + h_y)/h_y**2'

0 comments on commit 0806d43

Please sign in to comment.