Skip to content

Commit

Permalink
compiler: Correct the modulo dims in all callbacks to match initial l…
Browse files Browse the repository at this point in the history
…owering mapper with mod dims and origin
  • Loading branch information
ZoeLeibowitz committed Sep 26, 2024
1 parent 6aada57 commit 165662b
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 79 deletions.
2 changes: 1 addition & 1 deletion devito/petsc/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def callback_form(self):


class PETScCall(Call):
pass
pass
18 changes: 10 additions & 8 deletions devito/petsc/iet/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from devito.petsc.utils import solver_mapper, core_metadata
from devito.petsc.iet.routines import PETScCallbackBuilder
from devito.petsc.iet.utils import (petsc_call, petsc_call_mpi, petsc_struct,
spatial_iteration_loops, assign_time_iters)
spatial_injectsolve_iter, assign_time_iters,
retrieve_mod_dims)


@iet_pass
Expand All @@ -33,11 +34,8 @@ def lower_petsc(iet, **kwargs):
# Create core PETSc calls (not specific to each PETScSolve)
core = make_core_petsc_calls(objs, **kwargs)

# Create injectsolve mapper from the spatial iteration loops
# (exclude time loop if present)
spatial_body = spatial_iteration_loops(iet)
injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy,
'groupby').visit(List(body=spatial_body))
'groupby').visit(iet)

setup = []
subs = {}
Expand Down Expand Up @@ -65,13 +63,17 @@ def lower_petsc(iet, **kwargs):
break

# Generate all PETSc callback functions for the target via recusive compilation
for iter, (injectsolve,) in injectsolve_mapper.items():
for iters, (injectsolve,) in injectsolve_mapper.items():
if injectsolve.expr.rhs.target != target:
continue
# Retrieve the modulo dimensions and map them to their origins based
# on the initial lowering
solver_objs['mod_dims'] = retrieve_mod_dims(iters)
space_iter, = spatial_injectsolve_iter(iters, injectsolve)
matvec_op, formfunc_op, runsolve = builder.make(injectsolve,
objs, solver_objs)
setup.extend([matvec_op, formfunc_op])
subs.update({iter[0]: List(body=runsolve)})
subs.update({space_iter: List(body=runsolve)})
break

# Generate callback to populate main struct object
Expand All @@ -83,7 +85,7 @@ def lower_petsc(iet, **kwargs):
setup.extend([BlankLine, call_struct_callback] + calls_set_app_ctx)

iet = Transformer(subs).visit(iet)

iet = assign_time_iters(iet, struct_main)

body = core + tuple(setup) + (BlankLine,) + iet.body.body
Expand Down
13 changes: 13 additions & 0 deletions devito/petsc/iet/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs):

dmda = objs['da_so_%s' % linsolveexpr.target.space_order]

body = correct_mod_dims(body, solver_objs['mod_dims'])

struct = build_petsc_struct(body, 'matvec', liveness='eager')

y_matvec = linsolveexpr.arrays['y_matvec']
Expand Down Expand Up @@ -217,6 +219,8 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs):

dmda = objs['da_so_%s' % linsolveexpr.target.space_order]

body = correct_mod_dims(body, solver_objs['mod_dims'])

struct = build_petsc_struct(body, 'formfunc', liveness='eager')

y_formfunc = linsolveexpr.arrays['y_formfunc']
Expand Down Expand Up @@ -348,6 +352,8 @@ def create_formrhs_body(self, injectsolve, body, solver_objs, objs):
'DMDAGetLocalInfo', [dmda, Byref(dmda.info)]
)

body = correct_mod_dims(body, solver_objs['mod_dims'])

struct = build_petsc_struct(body, 'formrhs', liveness='eager')

dm_get_app_context = petsc_call(
Expand Down Expand Up @@ -481,6 +487,13 @@ class StartPtr(LocalObject):
return (vec_get_size, expr, vec_replace_array)


def correct_mod_dims(body, mod_dims):
old_mod_dims = [
i for i in FindSymbols('dimensions').visit(body) if isinstance(i, ModuloDimension)
]
return Uxreplace({i: mod_dims[i.origin] for i in old_mod_dims}).visit(body)


Null = Macro('NULL')
void = 'void'

Expand Down
28 changes: 19 additions & 9 deletions devito/petsc/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ def petsc_struct(name, fields, liveness='lazy'):
fields=fields, liveness=liveness)


def spatial_iteration_loops(iet):
def spatial_injectsolve_iter(iter, injectsolve):
spatial_body = []
for tree in retrieve_iteration_tree(iet):
for tree in retrieve_iteration_tree(iter[0]):
root = filter_iterations(tree, key=lambda i: i.dim.is_Space)[0]
spatial_body.append(root)
if injectsolve in FindNodes(InjectSolveDummy).visit(root):
spatial_body.append(root)
return spatial_body


Expand Down Expand Up @@ -64,10 +65,19 @@ def assign_time_iters(iet, struct):

mapper = {}
for iter in time_iters:
common_dimensions = [dim for dim in iter.dimensions if dim in struct.fields]
common_dimensions = [DummyExpr(FieldFromComposite(dim, struct), dim)
for dim in common_dimensions]
iter_new = iter._rebuild(nodes=List(body=tuple(common_dimensions)+iter.nodes))
common_dims = [dim for dim in iter.dimensions if dim in struct.fields]
common_dims = [
DummyExpr(FieldFromComposite(dim, struct), dim) for dim in common_dims
]
iter_new = iter._rebuild(nodes=List(body=tuple(common_dims)+iter.nodes))
mapper.update({iter: iter_new})
from IPython import embed; embed()
return Transformer(mapper).visit(iet)

return Transformer(mapper).visit(iet)


def retrieve_mod_dims(iters):
outer_iter_dims = iters[0].dimensions
if any(dim.is_Time for dim in outer_iter_dims):
mod_dims = [dim for dim in outer_iter_dims if dim.is_Modulo]
return {dim.origin: dim for dim in mod_dims}
return {}
93 changes: 32 additions & 61 deletions devito/petsc/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from devito.types.equation import InjectSolveEq
from devito.operations.solve import eval_time_derivatives
from devito.symbolics import retrieve_functions
from devito.tools import filter_sorted
from devito.petsc.types import LinearSolveExpr, PETScArray, CallbackExpr


__all__ = ['PETScSolve']


def PETScSolve(eq, target, bcs=None, solver_parameters=None, **kwargs):
def PETScSolve(eqns, target, solver_parameters=None, **kwargs):
prefixes = ['y_matvec', 'x_matvec', 'y_formfunc', 'x_formfunc', 'b_tmp']

arrays = {
Expand All @@ -28,74 +29,44 @@ def PETScSolve(eq, target, bcs=None, solver_parameters=None, **kwargs):
for p in prefixes
}

b, F_target = separate_eqn(eq, target)

eqns = eqns if isinstance(eqns, (list, tuple)) else [eqns]
# Passed through main kernel and removed at iet level, used to generate
# correct time loop etc
dummy = list(set(retrieve_functions(F_target - b)))
dummy_expr = sum(dummy)

# TODO: Current assumption is that problem is linear and user has not provided
# a jacobian. Hence, we can use F_target to form the jac-vec product

matvecaction = Eq(
arrays['y_matvec'],
CallbackExpr(F_target.subs({target: arrays['x_matvec']}), *dummy),
subdomain=eq.subdomain
)

formfunction = Eq(
arrays['y_formfunc'],
CallbackExpr(F_target.subs({target: arrays['x_formfunc']}), *dummy),
subdomain=eq.subdomain
)

rhs = Eq(
arrays['b_tmp'],
CallbackExpr(b, *dummy),
subdomain=eq.subdomain
)
dummy_expr = sum(filter_sorted(retrieve_functions(eqns)))

# Placeholder equation for inserting calls to the solver
inject_solve = InjectSolveEq(target, LinearSolveExpr(
dummy_expr, target=target, solver_parameters=solver_parameters,
matvecs=[matvecaction], formfuncs=[formfunction],
formrhs=[rhs], arrays=arrays,
), subdomain=eq.subdomain)
matvecs = []
formfuncs = []
formrhs = []

for eq in eqns:
b, F_target = separate_eqn(eq, target)

if not bcs:
return [inject_solve]

# NOTE: BELOW IS NOT FULLY TESTED/IMPLEMENTED YET
bcs_for_matvec = []
bcs_for_formfunc = []
bcs_for_rhs = []
for bc in bcs:
# TODO: Insert code to distiguish between essential and natural
# boundary conditions since these are treated differently within
# the solver
# NOTE: May eventually remove the essential bcs from the solve
# (and move to rhs) but for now, they are included since this
# is not trivial to implement when using DMDA
# NOTE: Below is temporary -> Just using this as a palceholder for
# the actual BC implementation
centre = centre_stencil(F_target, target)
bcs_for_matvec.append(Eq(
# TODO: Current assumption is that problem is linear and user has not provided
# a jacobian. Hence, we can use F_target to form the jac-vec product

matvecs.append(Eq(
arrays['y_matvec'],
CallbackExpr(centre.subs({target: arrays['x_matvec']}), *dummy),
subdomain=bc.subdomain
CallbackExpr(F_target.subs({target: arrays['x_matvec']})),
subdomain=eq.subdomain
))

formfuncs.append(Eq(
arrays['y_formfunc'],
CallbackExpr(F_target.subs({target: arrays['x_formfunc']})),
subdomain=eq.subdomain
))
# NOTE: Temporary
bcs_for_formfunc.append(Eq(arrays['y_formfunc'],
0., subdomain=bc.subdomain))
bcs_for_rhs.append(Eq(arrays['b_tmp'], 0., subdomain=bc.subdomain))

formrhs.append(Eq(
arrays['b_tmp'],
CallbackExpr(b),
subdomain=eq.subdomain
))

# Placeholder equation for inserting calls to the solver
inject_solve = InjectSolveEq(target, LinearSolveExpr(
dummy_expr, target=target, solver_parameters=solver_parameters,
matvecs=[matvecaction]+bcs_for_matvec,
formfuncs=[formfunction],
formrhs=[rhs],
arrays=arrays,
matvecs=matvecs, formfuncs=formfuncs,
formrhs=formrhs, arrays=arrays,
), subdomain=eq.subdomain)

return [inject_solve]
Expand Down Expand Up @@ -180,4 +151,4 @@ def _(expr, target):
if not expr.has(target):
return 0
args = [centre_stencil(a, target) for a in expr.evaluate.args]
return expr.evaluate.func(*args)
return expr.evaluate.func(*args)

0 comments on commit 165662b

Please sign in to comment.