From ce3854f4228a8b5d32ade0cb9cc15cc296e00c4c Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Mon, 15 Jan 2024 11:53:18 +0000 Subject: [PATCH] compiler: Remove PESTcCallable class and add liveness to PETScArrays --- devito/ir/equations/equation.py | 4 +- devito/passes/iet/petsc.py | 50 +++++++------------ devito/types/petsc.py | 29 ++++++++--- .../petsc/tmp_for_illustration/petsc_solve.c | 14 +----- 4 files changed, 44 insertions(+), 53 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 34d1c4ea31e..2f67a4b9f8d 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -8,7 +8,7 @@ from devito.symbolics import IntDiv, uxreplace from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin -from devito.types.petsc import Action, Solution +from devito.types.petsc import Action __all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax'] @@ -99,7 +99,6 @@ def detect(cls, expr): ReduceMax: OpMax, ReduceMin: OpMin, Action: OpAction, - Solution: OpSolution, } try: return reduction_mapper[type(expr)] @@ -117,7 +116,6 @@ def detect(cls, expr): OpMax = Operation('max') OpMin = Operation('min') OpAction = Operation('action') -OpSolution = Operation('solution') class LoweredEq(IREq): diff --git a/devito/passes/iet/petsc.py b/devito/passes/iet/petsc.py index 1f69ea21942..fea9e32fce0 100644 --- a/devito/passes/iet/petsc.py +++ b/devito/passes/iet/petsc.py @@ -2,9 +2,9 @@ from devito.ir.iet import (Expression, FindNodes, Section, List, Callable, Call, Transformer, Callback, Definition, Uxreplace, FindSymbols) -from devito.ir.equations.equation import OpAction, OpSolution +from devito.ir.equations.equation import OpAction from devito.types.petsc import Mat, Vec, DM, PetscErrorCode, PETScStruct, PETScArray -from devito.symbolics import FieldFromPointer +from devito.symbolics import FieldFromPointer, Byref __all__ = ['lower_petsc'] @@ -13,28 +13,25 @@ @iet_pass def lower_petsc(iet, **kwargs): - # Find the Section containing the 'action' and the Section - # containing the 'solution' + # Find the Section containing the 'action'. sections = FindNodes(Section).visit(iet) - secs_with_action = [] - secs_with_solution = [] + sections_with_action = [] for section in sections: section_exprs = FindNodes(Expression).visit(section) if any(expr.operation is OpAction for expr in section_exprs): - secs_with_action.append(section) - if any(expr.operation is OpSolution for expr in section_exprs): - secs_with_solution.append(section) + sections_with_action.append(section) # TODO: Extend to multiple targets but for now I am assuming we # are only solving 1 equation via PETSc. - target = FindNodes(Expression).visit(secs_with_solution[0]) + target = FindNodes(Expression).visit(sections_with_action[0]) target = [i for i in target[0].functions if not isinstance(i, PETScArray)][0] # Build PETSc objects required for the solve. petsc_objs = build_petsc_objects(target) + # Replace target with a PETScArray inside 'action'. mapper = {target.indexed: petsc_objs['xvec_tmp'].indexed} - updated_action = Uxreplace(mapper).visit(secs_with_action[0]) + updated_action = Uxreplace(mapper).visit(sections_with_action[0]) struct = build_struct(updated_action) @@ -45,7 +42,7 @@ def lower_petsc(iet, **kwargs): # TODO: Eventually, this will be extended to deal with multiple # 'actions' associated with different equations to solve. - iet = Transformer({secs_with_action[0]: solve_body}).visit(iet) + iet = Transformer({sections_with_action[0]: solve_body}).visit(iet) return iet, {'efuncs': [matvec_callback]} @@ -65,19 +62,6 @@ def build_petsc_objects(target): shape=target.shape, liveness='eager')} -class PETScCallable(Callable): - """ - Special Callable class that does not update the arguments - of the function based on the parameters found inside the - body. - TODO: Pretty sure this is hacky but it works? - """ - - @property - def all_parameters(self): - return () - - def build_struct(action): # Build the struct @@ -90,7 +74,11 @@ def build_struct(action): def build_matvec_body(action, objs, struct): + get_context = Call('PetscCall', [Call('MatShellGetContext', + arguments=[objs['A_matfree'], + Byref(struct.name)])]) body = List(body=[Definition(struct), + get_context, action]) # Replace all symbols in the body that appear in the struct # with a pointer to the struct @@ -102,12 +90,12 @@ def build_matvec_body(action, objs, struct): def build_solve(matvec_body, petsc_objs, struct): - matvec_callback = PETScCallable('MyMatShellMult', - matvec_body, - retval=petsc_objs['err'], - parameters=(petsc_objs['A_matfree'], - petsc_objs['xvec'], - petsc_objs['yvec'])) + matvec_callback = Callable('MyMatShellMult', + matvec_body, + retval=petsc_objs['err'], + parameters=(petsc_objs['A_matfree'], + petsc_objs['xvec'], + petsc_objs['yvec'])) matvec_operation = Call('PetscCall', [Call('MatShellSetOperation', arguments=[petsc_objs['A_matfree'], diff --git a/devito/types/petsc.py b/devito/types/petsc.py index 662b2fd7b31..de45cbe979b 100644 --- a/devito/types/petsc.py +++ b/devito/types/petsc.py @@ -78,6 +78,15 @@ class PETScArray(ArrayBasic): _data_alignment = False + __rkwargs__ = (ArrayBasic.__rkwargs__ + + ('liveness',)) + + def __init_finalize__(self, *args, **kwargs): + super().__init_finalize__(*args, **kwargs) + + self._liveness = kwargs.get('liveness', 'lazy') + assert self._liveness in ['eager', 'lazy'] + @classmethod def __dtype_setup__(cls, **kwargs): return kwargs.get('dtype', np.float32) @@ -92,6 +101,18 @@ def _C_ctype(self): def _C_name(self): return self.name + @property + def liveness(self): + return self._liveness + + @property + def _mem_internal_eager(self): + return self._liveness == 'eager' + + @property + def _mem_internal_lazy(self): + return self._liveness == 'lazy' + def dtype_to_petsctype(dtype): """ @@ -130,17 +151,11 @@ def PETScSolve(eq, target, **kwargs): dimensions=target.dimensions, shape=target.shape, liveness='eager') - solution_tmp = PETScArray(name='solution_tmp', dtype=target.dtype, - dimensions=target.dimensions, - shape=target.shape, liveness='eager') - # For now, assume the application of the linear operator on # a vector is eqn.lhs action = Action(yvec_tmp, eq.lhs.evaluate) - solution = Solution(target, solution_tmp) - - return [action] + [solution] + return [action] class PETScStruct(CompositeObject): diff --git a/examples/petsc/tmp_for_illustration/petsc_solve.c b/examples/petsc/tmp_for_illustration/petsc_solve.c index bff2b70d5fb..bf130a1b51e 100644 --- a/examples/petsc/tmp_for_illustration/petsc_solve.c +++ b/examples/petsc/tmp_for_illustration/petsc_solve.c @@ -32,23 +32,12 @@ int Kernel(struct dataobj *restrict pn_vec, struct dataobj *restrict u_vec, stru { Mat A_matfree; - PetscScalar**restrict solution_tmp; - float (*restrict pn)[pn_vec->size[1]] __attribute__ ((aligned (64))) = (float (*)[pn_vec->size[1]]) pn_vec->data; float (*restrict u)[u_vec->size[1]][u_vec->size[2]] __attribute__ ((aligned (64))) = (float (*)[u_vec->size[1]][u_vec->size[2]]) u_vec->data; float (*restrict v)[v_vec->size[1]][v_vec->size[2]] __attribute__ ((aligned (64))) = (float (*)[v_vec->size[1]][v_vec->size[2]]) v_vec->data; PetscCall(MatShellSetContext(A_matfree,ctx)); PetscCall(MatShellSetOperation(A_matfree,MATOP_MULT,(void (*)(void))MyMatShellMult)); - - for (int x = x_m; x <= x_M; x += 1) - { - for (int y = y_m; y <= y_M; y += 1) - { - pn[x + 2][y + 2] = solution_tmp[x][y]; - } - } - for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2)) { for (int x = x_m; x <= x_M; x += 1) @@ -71,6 +60,7 @@ PetscErrorCode MyMatShellMult(Mat A_matfree, Vec xvec, Vec yvec) PetscScalar**restrict yvec_tmp; struct MatContext * ctx; + PetscCall(MatShellGetContext(A_matfree,&ctx)); for (int x = ctx->x_m; x <= ctx->x_M; x += 1) { @@ -79,4 +69,4 @@ PetscErrorCode MyMatShellMult(Mat A_matfree, Vec xvec, Vec yvec) yvec_tmp[x][y] = pow(ctx->h_x, -2)*xvec_tmp[x + 1][y + 2] - 2.0F*pow(ctx->h_x, -2)*xvec_tmp[x + 2][y + 2] + pow(ctx->h_x, -2)*xvec_tmp[x + 3][y + 2] + pow(ctx->h_y, -2)*xvec_tmp[x + 2][y + 1] - 2.0F*pow(ctx->h_y, -2)*xvec_tmp[x + 2][y + 2] + pow(ctx->h_y, -2)*xvec_tmp[x + 2][y + 3]; } } -} \ No newline at end of file +}