Skip to content

Commit

Permalink
change FunctionPointer class and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Aug 8, 2023
1 parent 69a8791 commit b9768ea
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 55 deletions.
36 changes: 24 additions & 12 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast',
'DefFunction', 'InlineIf', 'Keyword', 'String', 'Macro', 'MacroArgument',
'CustomType', 'Deref', 'INT', 'FLOAT', 'DOUBLE', 'VOID',
'Null', 'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin', 'FunctionPtr']
'Null', 'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin',
'FunctionPointer']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -398,26 +399,37 @@ def typ(self):
@property
def _op(self):
return '(%s)' % self.typ


class FunctionPtr(UnaryOp):

class FunctionPointer(sympy.Expr, Pickable):

"""
Symbolic representation of C's function pointers.
Symbolic representation of C's function pointers to be used as an
argument to a function.
"""

_return_typ = ''
_parameter_typ = ''
__rargs__ = ('func_name', 'return_type', 'parameter_type',)

def __new__(cls, func_name, return_type, parameter_type, **kwargs):

obj = sympy.Expr.__new__(cls, func_name, return_type, parameter_type)
obj.func_name = func_name
obj.return_type = return_type
obj.parameter_type = parameter_type

def __new__(cls, base, **kwargs):
obj = super().__new__(cls, base)
return obj

func = Pickable._rebuild
def __str__(self):
return "(%s (%s)(%s))%s" % (self.return_type, '*',
self.parameter_type, self.func_name)

@property
def _op(self):
return '(%s (%s)(%s))' % (self._return_typ, '*', self._parameter_typ)
__repr__ = __str__

def _hashable_content(self):
return (self.func_name, self.return_type, self.parameter_type)

# Pickling support
__reduce_ex__ = Pickable.__reduce_ex__


class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin):
Expand Down
50 changes: 29 additions & 21 deletions tests/test_iet.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,32 +446,32 @@ def test_petsc_expressions():
def test_petsc_iterations():
dims = {'x': Dimension(name='x'),
'y': Dimension(name='y')}

symbs = {'left': PetscObject(name='left', petsc_type='PetscScalar'),
'right': PetscObject(name='right', petsc_type='PetscScalar'),
'x_m': PetscObject(name='x_m', petsc_type='PetscInt', is_const=True),
'x_M': PetscObject(name='x_M', petsc_type='PetscInt', is_const=True),
'y_m': PetscObject(name='y_m', petsc_type='PetscInt', is_const=True),
'y_M': PetscObject(name='y_M', petsc_type='PetscInt', is_const=True)}
'right': PetscObject(name='right', petsc_type='PetscScalar'),
'x_m': PetscObject(name='x_m', petsc_type='PetscInt', is_const=True),
'x_M': PetscObject(name='x_M', petsc_type='PetscInt', is_const=True),
'y_m': PetscObject(name='y_m', petsc_type='PetscInt', is_const=True),
'y_M': PetscObject(name='y_M', petsc_type='PetscInt', is_const=True)}

def get_exprs(left, right):
return [Expression(DummyEq(left, 0.)),
Expression(DummyEq(right, 0.))]

exprs = get_exprs(symbs['left'],
symbs['right'])

def get_iters(dims, symbs):
return [lambda ex: Iteration(ex, dims['x'], (symbs['x_m'], symbs['x_M'], 1)),
lambda ex: Iteration(ex, dims['y'], (symbs['y_m'], symbs['y_M'], 1))]

iters = get_iters(dims, symbs)

def get_block(exprs, iters):
return iters[0](iters[1]([exprs[0],exprs[1]]))
return iters[0](iters[1]([exprs[0], exprs[1]]))

block1 = get_block(exprs, iters)

kernel = Callable('foo', block1, 'void', ())

assert str(kernel) == """\
Expand All @@ -487,11 +487,12 @@ def get_block(exprs, iters):
}
}"""


def test_petsc_dummy():
"""
"""
#### create an 'operator' manually
# create an 'operator' manually
dims_op = {'x': SpaceDimension(name='x'),
'y': SpaceDimension(name='y')}

Expand Down Expand Up @@ -520,22 +521,29 @@ def get_block1(exprs_op, iters_op):
'A_matfree': PetscObject(name='A_matfree', petsc_type='Mat'),
'xvec': PetscObject(name='xvec', petsc_type='Vec'),
'yvec': PetscObject(name='yvec', petsc_type='Vec'),
'xarr': PetscObject(name='xarr', petsc_type='PetscScalar', grid=Grid((2,)), is_const=True),
'yarr': PetscObject(name='yarr', petsc_type='PetscScalar', grid=Grid((2,)))}
'xarr': PetscObject(name='xarr', petsc_type='PetscScalar',
grid=Grid((2,)), is_const=True),
'yarr': PetscObject(name='yarr', petsc_type='PetscScalar',
grid=Grid((2,)))}

MyMatShellMult = Callable('MyMatShellMult', kernel_op.body, retval=symbs_petsc['retval'],
parameters=(symbs_petsc['A_matfree'], symbs_petsc['xvec'], symbs_petsc['yvec']))
MyMatShellMult = Callable('MyMatShellMult', kernel_op.body,
retval=symbs_petsc['retval'],
parameters=(symbs_petsc['A_matfree'],
symbs_petsc['xvec'], symbs_petsc['yvec']))

call = Call(MyMatShellMult.name)
transformer = Transformer({block1: call})
main_block = transformer.visit(block1)
new_op_block = [Call('PetscCall', [Call('VecGetArrayRead', arguments=[symbs_petsc['xvec'], Byref(symbs_petsc['xarr'])])]),
new_op_block = [Call('PetscCall', [Call('VecGetArrayRead',
arguments=[symbs_petsc['xvec'],
Byref(symbs_petsc['xarr'])])]),
main_block]
main = Callable('main', new_op_block, 'int', ())

assert('Original kernel:\n' + str(kernel_op) + '\n' + \
'MyMatShellMult with body of original kernel:\n' + str(MyMatShellMult) + '\n' + \
'New kernel with a call to the MyMatShellMult function:\n' + str(main)) == """\
assert('Original kernel:\n' + str(kernel_op) + '\n' +
'MyMatShellMult with body of original kernel:\n' + str(MyMatShellMult) +
'\n' + 'New kernel with a call to the MyMatShellMult function:\n' +
str(main)) == """\
Original kernel:
int kernel()
{
Expand Down
7 changes: 6 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from devito.types.basic import BoundSymbol
from devito.tools import EnrichedTuple
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
CallFromPointer, DefFunction)
CallFromPointer, DefFunction, FunctionPointer)
from examples.seismic import (demo_model, AcquisitionGeometry,
TimeAxis, RickerSource, Receiver)

Expand Down Expand Up @@ -310,6 +310,11 @@ def test_symbolics(self, pickle):
assert df == new_df
assert df.arguments == new_df.arguments

fp = FunctionPointer('foo', 'void', 'void')
pkl_fp = pickle.dumps(fp)
new_fp = pickle.loads(pkl_fp)
assert fp == new_fp

def test_timers(self, pickle):
"""Pickling for Timers used in Operators for C-level profiling."""
timer = Timer('timer', ['sec0', 'sec1'])
Expand Down
41 changes: 20 additions & 21 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from devito.ir import Expression, FindNodes
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
CallFromPointer, Cast, DefFunction, FieldFromPointer,
INT, FieldFromComposite, IntDiv, ccode, uxreplace, FunctionPtr)
INT, FieldFromComposite, IntDiv, ccode, uxreplace,
FunctionPointer)
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
Symbol as dSymbol)
from devito.passes.iet.petsc import PetscObject
from devito.ir.iet import Callable, Definition


def test_float_indices():
Expand Down Expand Up @@ -219,6 +218,24 @@ def test_field_from_composite():
assert ffc1.free_symbols == {s}


def test_function_pointer():

# Test construction
fp0 = FunctionPointer('foo0', 'void', 'void')
assert str(fp0) == '(void (*)(void))foo0'
fp1 = FunctionPointer('foo0', 'int', 'void')
fp2 = FunctionPointer('foo0', 'void', 'float')
fp3 = FunctionPointer('foo1', 'void', 'void')
assert fp0 != fp1
assert fp0 != fp2
assert fp0 != fp3

# Test hashing
assert hash(fp0) != hash(fp1)
assert hash(fp0) != hash(fp2)
assert hash(fp0) != hash(fp3)


def test_extended_sympy_arithmetic():
# NOTE: `s` gets turned into a devito.Symbol, whose dtype
# defaults to np.int32
Expand Down Expand Up @@ -293,24 +310,6 @@ class BarCast(Cast):
assert v != v1


def test_function_ptr():
s = Symbol(name='s', dtype=np.float32)
iet = Definition(s)

A_matfree = PetscObject(name='A_matfree', petsc_type='Mat')
xvec = PetscObject(name='xvec', petsc_type='Vec')
yvec = PetscObject(name='yvec', petsc_type='Vec')

MyMatShellMult = Callable('MyMatShellMult', iet, retval='PetscErrorCode', parameters=(A_matfree, xvec, yvec))

class voidFunctionPtr(FunctionPtr):
_return_typ = 'void'
_parameter_typ = 'void'

tmp = voidFunctionPtr(MyMatShellMult.name)
assert ccode(tmp) == '(void (*)(void))MyMatShellMult'


def test_symbolic_printing():
b = Symbol('b')

Expand Down

0 comments on commit b9768ea

Please sign in to comment.