diff --git a/devito/petsc/iet/nodes.py b/devito/petsc/iet/nodes.py index 0ac6a0e57a..2137495487 100644 --- a/devito/petsc/iet/nodes.py +++ b/devito/petsc/iet/nodes.py @@ -36,4 +36,4 @@ def callback_form(self): class PETScCall(Call): - pass \ No newline at end of file + pass diff --git a/devito/petsc/iet/passes.py b/devito/petsc/iet/passes.py index 69f2dfb199..eafd5cd832 100644 --- a/devito/petsc/iet/passes.py +++ b/devito/petsc/iet/passes.py @@ -12,7 +12,8 @@ from devito.petsc.utils import solver_mapper, core_metadata from devito.petsc.iet.routines import PETScCallbackBuilder from devito.petsc.iet.utils import (petsc_call, petsc_call_mpi, petsc_struct, - spatial_iteration_loops, assign_time_iters) + spatial_injectsolve_iter, assign_time_iters, + retrieve_mod_dims) @iet_pass @@ -33,11 +34,8 @@ def lower_petsc(iet, **kwargs): # Create core PETSc calls (not specific to each PETScSolve) core = make_core_petsc_calls(objs, **kwargs) - # Create injectsolve mapper from the spatial iteration loops - # (exclude time loop if present) - spatial_body = spatial_iteration_loops(iet) injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy, - 'groupby').visit(List(body=spatial_body)) + 'groupby').visit(iet) setup = [] subs = {} @@ -65,13 +63,17 @@ def lower_petsc(iet, **kwargs): break # Generate all PETSc callback functions for the target via recusive compilation - for iter, (injectsolve,) in injectsolve_mapper.items(): + for iters, (injectsolve,) in injectsolve_mapper.items(): if injectsolve.expr.rhs.target != target: continue + # Retrieve the modulo dimensions and map them to their origins based + # on the initial lowering + solver_objs['mod_dims'] = retrieve_mod_dims(iters) + space_iter, = spatial_injectsolve_iter(iters, injectsolve) matvec_op, formfunc_op, runsolve = builder.make(injectsolve, objs, solver_objs) setup.extend([matvec_op, formfunc_op]) - subs.update({iter[0]: List(body=runsolve)}) + subs.update({space_iter: List(body=runsolve)}) break # Generate callback to populate main struct object @@ -83,7 +85,7 @@ def lower_petsc(iet, **kwargs): setup.extend([BlankLine, call_struct_callback] + calls_set_app_ctx) iet = Transformer(subs).visit(iet) - + iet = assign_time_iters(iet, struct_main) body = core + tuple(setup) + (BlankLine,) + iet.body.body diff --git a/devito/petsc/iet/routines.py b/devito/petsc/iet/routines.py index 7f54975ab1..75e3ceb595 100644 --- a/devito/petsc/iet/routines.py +++ b/devito/petsc/iet/routines.py @@ -90,6 +90,8 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + body = correct_mod_dims(body, solver_objs['mod_dims']) + struct = build_petsc_struct(body, 'matvec', liveness='eager') y_matvec = linsolveexpr.arrays['y_matvec'] @@ -217,6 +219,8 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): dmda = objs['da_so_%s' % linsolveexpr.target.space_order] + body = correct_mod_dims(body, solver_objs['mod_dims']) + struct = build_petsc_struct(body, 'formfunc', liveness='eager') y_formfunc = linsolveexpr.arrays['y_formfunc'] @@ -348,6 +352,8 @@ def create_formrhs_body(self, injectsolve, body, solver_objs, objs): 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] ) + body = correct_mod_dims(body, solver_objs['mod_dims']) + struct = build_petsc_struct(body, 'formrhs', liveness='eager') dm_get_app_context = petsc_call( @@ -481,6 +487,13 @@ class StartPtr(LocalObject): return (vec_get_size, expr, vec_replace_array) +def correct_mod_dims(body, mod_dims): + old_mod_dims = [ + i for i in FindSymbols('dimensions').visit(body) if isinstance(i, ModuloDimension) + ] + return Uxreplace({i: mod_dims[i.origin] for i in old_mod_dims}).visit(body) + + Null = Macro('NULL') void = 'void' diff --git a/devito/petsc/iet/utils.py b/devito/petsc/iet/utils.py index 9b5e5258dc..8f1b0a34aa 100644 --- a/devito/petsc/iet/utils.py +++ b/devito/petsc/iet/utils.py @@ -22,11 +22,12 @@ def petsc_struct(name, fields, liveness='lazy'): fields=fields, liveness=liveness) -def spatial_iteration_loops(iet): +def spatial_injectsolve_iter(iter, injectsolve): spatial_body = [] - for tree in retrieve_iteration_tree(iet): + for tree in retrieve_iteration_tree(iter[0]): root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0] - spatial_body.append(root) + if injectsolve in FindNodes(InjectSolveDummy).visit(root): + spatial_body.append(root) return spatial_body @@ -64,10 +65,19 @@ def assign_time_iters(iet, struct): mapper = {} for iter in time_iters: - common_dimensions = [dim for dim in iter.dimensions if dim in struct.fields] - common_dimensions = [DummyExpr(FieldFromComposite(dim, struct), dim) - for dim in common_dimensions] - iter_new = iter._rebuild(nodes=List(body=tuple(common_dimensions)+iter.nodes)) + common_dims = [dim for dim in iter.dimensions if dim in struct.fields] + common_dims = [ + DummyExpr(FieldFromComposite(dim, struct), dim) for dim in common_dims + ] + iter_new = iter._rebuild(nodes=List(body=tuple(common_dims)+iter.nodes)) mapper.update({iter: iter_new}) - from IPython import embed; embed() - return Transformer(mapper).visit(iet) \ No newline at end of file + + return Transformer(mapper).visit(iet) + + +def retrieve_mod_dims(iters): + outer_iter_dims = iters[0].dimensions + if any(dim.is_Time for dim in outer_iter_dims): + mod_dims = [dim for dim in outer_iter_dims if dim.is_Modulo] + return {dim.origin: dim for dim in mod_dims} + return {} diff --git a/devito/petsc/solve.py b/devito/petsc/solve.py index 20487b4037..005f9363aa 100644 --- a/devito/petsc/solve.py +++ b/devito/petsc/solve.py @@ -8,13 +8,14 @@ from devito.types.equation import InjectSolveEq from devito.operations.solve import eval_time_derivatives from devito.symbolics import retrieve_functions +from devito.tools import filter_sorted from devito.petsc.types import LinearSolveExpr, PETScArray, CallbackExpr __all__ = ['PETScSolve'] -def PETScSolve(eq, target, bcs=None, solver_parameters=None, **kwargs): +def PETScSolve(eqns, target, solver_parameters=None, **kwargs): prefixes = ['y_matvec', 'x_matvec', 'y_formfunc', 'x_formfunc', 'b_tmp'] arrays = { @@ -28,74 +29,44 @@ def PETScSolve(eq, target, bcs=None, solver_parameters=None, **kwargs): for p in prefixes } - b, F_target = separate_eqn(eq, target) - + eqns = eqns if isinstance(eqns, (list, tuple)) else [eqns] # Passed through main kernel and removed at iet level, used to generate # correct time loop etc - dummy = list(set(retrieve_functions(F_target - b))) - dummy_expr = sum(dummy) - - # TODO: Current assumption is that problem is linear and user has not provided - # a jacobian. Hence, we can use F_target to form the jac-vec product - - matvecaction = Eq( - arrays['y_matvec'], - CallbackExpr(F_target.subs({target: arrays['x_matvec']}), *dummy), - subdomain=eq.subdomain - ) - - formfunction = Eq( - arrays['y_formfunc'], - CallbackExpr(F_target.subs({target: arrays['x_formfunc']}), *dummy), - subdomain=eq.subdomain - ) - - rhs = Eq( - arrays['b_tmp'], - CallbackExpr(b, *dummy), - subdomain=eq.subdomain - ) + dummy_expr = sum(filter_sorted(retrieve_functions(eqns))) - # Placeholder equation for inserting calls to the solver - inject_solve = InjectSolveEq(target, LinearSolveExpr( - dummy_expr, target=target, solver_parameters=solver_parameters, - matvecs=[matvecaction], formfuncs=[formfunction], - formrhs=[rhs], arrays=arrays, - ), subdomain=eq.subdomain) + matvecs = [] + formfuncs = [] + formrhs = [] + + for eq in eqns: + b, F_target = separate_eqn(eq, target) - if not bcs: - return [inject_solve] - - # NOTE: BELOW IS NOT FULLY TESTED/IMPLEMENTED YET - bcs_for_matvec = [] - bcs_for_formfunc = [] - bcs_for_rhs = [] - for bc in bcs: - # TODO: Insert code to distiguish between essential and natural - # boundary conditions since these are treated differently within - # the solver - # NOTE: May eventually remove the essential bcs from the solve - # (and move to rhs) but for now, they are included since this - # is not trivial to implement when using DMDA - # NOTE: Below is temporary -> Just using this as a palceholder for - # the actual BC implementation - centre = centre_stencil(F_target, target) - bcs_for_matvec.append(Eq( + # TODO: Current assumption is that problem is linear and user has not provided + # a jacobian. Hence, we can use F_target to form the jac-vec product + + matvecs.append(Eq( arrays['y_matvec'], - CallbackExpr(centre.subs({target: arrays['x_matvec']}), *dummy), - subdomain=bc.subdomain + CallbackExpr(F_target.subs({target: arrays['x_matvec']})), + subdomain=eq.subdomain + )) + + formfuncs.append(Eq( + arrays['y_formfunc'], + CallbackExpr(F_target.subs({target: arrays['x_formfunc']})), + subdomain=eq.subdomain )) - # NOTE: Temporary - bcs_for_formfunc.append(Eq(arrays['y_formfunc'], - 0., subdomain=bc.subdomain)) - bcs_for_rhs.append(Eq(arrays['b_tmp'], 0., subdomain=bc.subdomain)) + formrhs.append(Eq( + arrays['b_tmp'], + CallbackExpr(b), + subdomain=eq.subdomain + )) + + # Placeholder equation for inserting calls to the solver inject_solve = InjectSolveEq(target, LinearSolveExpr( dummy_expr, target=target, solver_parameters=solver_parameters, - matvecs=[matvecaction]+bcs_for_matvec, - formfuncs=[formfunction], - formrhs=[rhs], - arrays=arrays, + matvecs=matvecs, formfuncs=formfuncs, + formrhs=formrhs, arrays=arrays, ), subdomain=eq.subdomain) return [inject_solve] @@ -180,4 +151,4 @@ def _(expr, target): if not expr.has(target): return 0 args = [centre_stencil(a, target) for a in expr.evaluate.args] - return expr.evaluate.func(*args) \ No newline at end of file + return expr.evaluate.func(*args)