Skip to content

Commit

Permalink
compiler: Remove PESTcCallable class and add liveness to PETScArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Jan 15, 2024
1 parent 738bc43 commit ce3854f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 53 deletions.
4 changes: 1 addition & 3 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from devito.symbolics import IntDiv, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin
from devito.types.petsc import Action, Solution
from devito.types.petsc import Action

__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax']

Expand Down Expand Up @@ -99,7 +99,6 @@ def detect(cls, expr):
ReduceMax: OpMax,
ReduceMin: OpMin,
Action: OpAction,
Solution: OpSolution,
}
try:
return reduction_mapper[type(expr)]
Expand All @@ -117,7 +116,6 @@ def detect(cls, expr):
OpMax = Operation('max')
OpMin = Operation('min')
OpAction = Operation('action')
OpSolution = Operation('solution')


class LoweredEq(IREq):
Expand Down
50 changes: 19 additions & 31 deletions devito/passes/iet/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from devito.ir.iet import (Expression, FindNodes, Section, List,
Callable, Call, Transformer, Callback,
Definition, Uxreplace, FindSymbols)
from devito.ir.equations.equation import OpAction, OpSolution
from devito.ir.equations.equation import OpAction
from devito.types.petsc import Mat, Vec, DM, PetscErrorCode, PETScStruct, PETScArray
from devito.symbolics import FieldFromPointer
from devito.symbolics import FieldFromPointer, Byref


__all__ = ['lower_petsc']
Expand All @@ -13,28 +13,25 @@
@iet_pass
def lower_petsc(iet, **kwargs):

# Find the Section containing the 'action' and the Section
# containing the 'solution'
# Find the Section containing the 'action'.
sections = FindNodes(Section).visit(iet)
secs_with_action = []
secs_with_solution = []
sections_with_action = []
for section in sections:
section_exprs = FindNodes(Expression).visit(section)
if any(expr.operation is OpAction for expr in section_exprs):
secs_with_action.append(section)
if any(expr.operation is OpSolution for expr in section_exprs):
secs_with_solution.append(section)
sections_with_action.append(section)

# TODO: Extend to multiple targets but for now I am assuming we
# are only solving 1 equation via PETSc.
target = FindNodes(Expression).visit(secs_with_solution[0])
target = FindNodes(Expression).visit(sections_with_action[0])
target = [i for i in target[0].functions if not isinstance(i, PETScArray)][0]

# Build PETSc objects required for the solve.
petsc_objs = build_petsc_objects(target)

# Replace target with a PETScArray inside 'action'.
mapper = {target.indexed: petsc_objs['xvec_tmp'].indexed}
updated_action = Uxreplace(mapper).visit(secs_with_action[0])
updated_action = Uxreplace(mapper).visit(sections_with_action[0])

struct = build_struct(updated_action)

Expand All @@ -45,7 +42,7 @@ def lower_petsc(iet, **kwargs):

# TODO: Eventually, this will be extended to deal with multiple
# 'actions' associated with different equations to solve.
iet = Transformer({secs_with_action[0]: solve_body}).visit(iet)
iet = Transformer({sections_with_action[0]: solve_body}).visit(iet)

return iet, {'efuncs': [matvec_callback]}

Expand All @@ -65,19 +62,6 @@ def build_petsc_objects(target):
shape=target.shape, liveness='eager')}


class PETScCallable(Callable):
"""
Special Callable class that does not update the arguments
of the function based on the parameters found inside the
body.
TODO: Pretty sure this is hacky but it works?
"""

@property
def all_parameters(self):
return ()


def build_struct(action):

# Build the struct
Expand All @@ -90,7 +74,11 @@ def build_struct(action):

def build_matvec_body(action, objs, struct):

get_context = Call('PetscCall', [Call('MatShellGetContext',
arguments=[objs['A_matfree'],
Byref(struct.name)])])
body = List(body=[Definition(struct),
get_context,
action])
# Replace all symbols in the body that appear in the struct
# with a pointer to the struct
Expand All @@ -102,12 +90,12 @@ def build_matvec_body(action, objs, struct):

def build_solve(matvec_body, petsc_objs, struct):

matvec_callback = PETScCallable('MyMatShellMult',
matvec_body,
retval=petsc_objs['err'],
parameters=(petsc_objs['A_matfree'],
petsc_objs['xvec'],
petsc_objs['yvec']))
matvec_callback = Callable('MyMatShellMult',
matvec_body,
retval=petsc_objs['err'],
parameters=(petsc_objs['A_matfree'],
petsc_objs['xvec'],
petsc_objs['yvec']))

matvec_operation = Call('PetscCall', [Call('MatShellSetOperation',
arguments=[petsc_objs['A_matfree'],
Expand Down
29 changes: 22 additions & 7 deletions devito/types/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class PETScArray(ArrayBasic):

_data_alignment = False

__rkwargs__ = (ArrayBasic.__rkwargs__ +
('liveness',))

def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

self._liveness = kwargs.get('liveness', 'lazy')
assert self._liveness in ['eager', 'lazy']

@classmethod
def __dtype_setup__(cls, **kwargs):
return kwargs.get('dtype', np.float32)
Expand All @@ -92,6 +101,18 @@ def _C_ctype(self):
def _C_name(self):
return self.name

@property
def liveness(self):
return self._liveness

@property
def _mem_internal_eager(self):
return self._liveness == 'eager'

@property
def _mem_internal_lazy(self):
return self._liveness == 'lazy'


def dtype_to_petsctype(dtype):
"""
Expand Down Expand Up @@ -130,17 +151,11 @@ def PETScSolve(eq, target, **kwargs):
dimensions=target.dimensions,
shape=target.shape, liveness='eager')

solution_tmp = PETScArray(name='solution_tmp', dtype=target.dtype,
dimensions=target.dimensions,
shape=target.shape, liveness='eager')

# For now, assume the application of the linear operator on
# a vector is eqn.lhs
action = Action(yvec_tmp, eq.lhs.evaluate)

solution = Solution(target, solution_tmp)

return [action] + [solution]
return [action]


class PETScStruct(CompositeObject):
Expand Down
14 changes: 2 additions & 12 deletions examples/petsc/tmp_for_illustration/petsc_solve.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,12 @@ int Kernel(struct dataobj *restrict pn_vec, struct dataobj *restrict u_vec, stru
{
Mat A_matfree;

PetscScalar**restrict solution_tmp;

float (*restrict pn)[pn_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[pn_vec->size[1]]) pn_vec->data;
float (*restrict u)[u_vec->size[1]][u_vec->size[2]] __attribute__ ((aligned (64))) = (float (*)[u_vec->size[1]][u_vec->size[2]]) u_vec->data;
float (*restrict v)[v_vec->size[1]][v_vec->size[2]] __attribute__ ((aligned (64))) = (float (*)[v_vec->size[1]][v_vec->size[2]]) v_vec->data;

PetscCall(MatShellSetContext(A_matfree,ctx));
PetscCall(MatShellSetOperation(A_matfree,MATOP_MULT,(void (*)(void))MyMatShellMult));

for (int x = x_m; x <= x_M; x += 1)
{
for (int y = y_m; y <= y_M; y += 1)
{
pn[x + 2][y + 2] = solution_tmp[x][y];
}
}

for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))
{
for (int x = x_m; x <= x_M; x += 1)
Expand All @@ -71,6 +60,7 @@ PetscErrorCode MyMatShellMult(Mat A_matfree, Vec xvec, Vec yvec)
PetscScalar**restrict yvec_tmp;

struct MatContext * ctx;
PetscCall(MatShellGetContext(A_matfree,&ctx));

for (int x = ctx->x_m; x <= ctx->x_M; x += 1)
{
Expand All @@ -79,4 +69,4 @@ PetscErrorCode MyMatShellMult(Mat A_matfree, Vec xvec, Vec yvec)
yvec_tmp[x][y] = pow(ctx->h_x, -2)*xvec_tmp[x + 1][y + 2] - 2.0F*pow(ctx->h_x, -2)*xvec_tmp[x + 2][y + 2] + pow(ctx->h_x, -2)*xvec_tmp[x + 3][y + 2] + pow(ctx->h_y, -2)*xvec_tmp[x + 2][y + 1] - 2.0F*pow(ctx->h_y, -2)*xvec_tmp[x + 2][y + 2] + pow(ctx->h_y, -2)*xvec_tmp[x + 2][y + 3];
}
}
}
}

0 comments on commit ce3854f

Please sign in to comment.