Skip to content

Commit

Permalink
compiler: Adjust pass so that the same function object can be solved …
Browse files Browse the repository at this point in the history
…more than once within the same time loop
  • Loading branch information
ZoeLeibowitz committed Sep 30, 2024
1 parent 26f6760 commit 3b0a88b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 79 deletions.
95 changes: 41 additions & 54 deletions devito/petsc/iet/passes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import cgen as c

from devito.passes.iet.engine import iet_pass
from devito.ir.iet import (FindNodes, Transformer,
MapNodes, Iteration, List, BlankLine,
from devito.ir.iet import (Transformer, MapNodes, Iteration, List, BlankLine,
Callable, CallableBody, DummyExpr, Call)
from devito.symbolics import Byref, Macro, FieldFromPointer
from devito.tools import filter_ordered
Expand All @@ -18,71 +17,57 @@

@iet_pass
def lower_petsc(iet, **kwargs):
injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy,
'groupby').visit(iet)

# Check if PETScSolve was used
petsc_nodes = FindNodes(InjectSolveDummy).visit(iet)

if not petsc_nodes:
if not injectsolve_mapper:
return iet, {}

unique_targets = list({i.expr.rhs.target for i in petsc_nodes})
targets = [i.expr.rhs.target for (i,) in injectsolve_mapper.values()]
init = init_petsc(**kwargs)

# Assumption is that all targets have the same grid so can use any target here
objs = build_core_objects(unique_targets[-1], **kwargs)
objs = build_core_objects(targets[-1], **kwargs)

# Create core PETSc calls (not specific to each PETScSolve)
core = make_core_petsc_calls(objs, **kwargs)

injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy,
'groupby').visit(iet)

setup = []
subs = {}

# Create a different DMDA for each target with a unique space order
unique_dmdas = create_dmda_objs(unique_targets)
unique_dmdas = create_dmda_objs(targets)
objs.update(unique_dmdas)
for dmda in unique_dmdas.values():
setup.extend(create_dmda_calls(dmda, objs))

builder = PETScCallbackBuilder(**kwargs)

# Create the PETSc calls which are specific to each target
for target in unique_targets:
solver_objs = build_solver_objs(target)

# Generate the solver setup for target. This is required only
# once per target
for (injectsolve,) in injectsolve_mapper.values():
# Skip if not associated with the target
if injectsolve.expr.rhs.target != target:
continue
solver_setup = generate_solver_setup(solver_objs, objs, injectsolve, target)
setup.extend(solver_setup)
break
for iters, (injectsolve,) in injectsolve_mapper.items():
target = injectsolve.expr.rhs.target
solver_objs = build_solver_objs(target, **kwargs)

# Generate the solver setup for each target
solver_setup = generate_solver_setup(solver_objs, objs, injectsolve)
setup.extend(solver_setup)

# Retrieve the modulo dimensions and map them to their origins based
# on the initial lowering
solver_objs['mod_dims'] = retrieve_mod_dims(iters)
space_iter, = spatial_injectsolve_iter(iters, injectsolve)
# Generate all PETSc callback functions for the target via recursive compilation
for iters, (injectsolve,) in injectsolve_mapper.items():
if injectsolve.expr.rhs.target != target:
continue
# Retrieve the modulo dimensions and map them to their origins based
# on the initial lowering
solver_objs['mod_dims'] = retrieve_mod_dims(iters)
space_iter, = spatial_injectsolve_iter(iters, injectsolve)
matvec_op, formfunc_op, runsolve = builder.make(injectsolve,
objs, solver_objs)
setup.extend([matvec_op, formfunc_op])
subs.update({space_iter: List(body=runsolve)})
break
matvec_op, formfunc_op, runsolve = builder.make(injectsolve,
objs, solver_objs)
setup.extend([matvec_op, formfunc_op, BlankLine])
subs.update({space_iter: List(body=runsolve)})

# Generate callback to populate main struct object
struct_main = petsc_struct('ctx', filter_ordered(builder.struct_params))
struct_callback = generate_struct_callback(struct_main)
call_struct_callback = petsc_call(struct_callback.name, [Byref(struct_main)])
calls_set_app_ctx = [petsc_call('DMSetApplicationContext', [i, Byref(struct_main)])
for i in unique_dmdas]
setup.extend([BlankLine, call_struct_callback] + calls_set_app_ctx)
setup.extend([call_struct_callback] + calls_set_app_ctx)

iet = Transformer(subs).visit(iet)

Expand Down Expand Up @@ -184,26 +169,28 @@ def create_dmda(dmda, objs):
return dmda


def build_solver_objs(target):
name = target.name
def build_solver_objs(target, **kwargs):
sreg = kwargs['sregistry']
return {
'Jac': Mat(name='J_%s' % name),
'x_global': GlobalVec(name='x_global_%s' % name),
'x_local': LocalVec(name='x_local_%s' % name, liveness='eager'),
'b_global': GlobalVec(name='b_global_%s' % name),
'b_local': LocalVec(name='b_local_%s' % name),
'ksp': KSP(name='ksp_%s' % name),
'pc': PC(name='pc_%s' % name),
'snes': SNES(name='snes_%s' % name),
'X_global': GlobalVec(name='X_global_%s' % name),
'Y_global': GlobalVec(name='Y_global_%s' % name),
'X_local': LocalVec(name='X_local_%s' % name, liveness='eager'),
'Y_local': LocalVec(name='Y_local_%s' % name, liveness='eager'),
'dummy': DummyArg(name='dummy_%s' % name)
'Jac': Mat(sreg.make_name(prefix='J_')),
'x_global': GlobalVec(sreg.make_name(prefix='x_global_')),
'x_local': LocalVec(sreg.make_name(prefix='x_local_'), liveness='eager'),
'b_global': GlobalVec(sreg.make_name(prefix='b_global_')),
'b_local': LocalVec(sreg.make_name(prefix='b_local_')),
'ksp': KSP(sreg.make_name(prefix='ksp_')),
'pc': PC(sreg.make_name(prefix='pc_')),
'snes': SNES(sreg.make_name(prefix='snes_')),
'X_global': GlobalVec(sreg.make_name(prefix='X_global_')),
'Y_global': GlobalVec(sreg.make_name(prefix='Y_global_')),
'X_local': LocalVec(sreg.make_name(prefix='X_local_'), liveness='eager'),
'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'),
'dummy': DummyArg(sreg.make_name(prefix='dummy_'))
}


def generate_solver_setup(solver_objs, objs, injectsolve, target):
def generate_solver_setup(solver_objs, objs, injectsolve):
target = injectsolve.expr.rhs.target

dmda = objs['da_so_%s' % target.space_order]

solver_params = injectsolve.expr.rhs.solver_parameters
Expand Down
19 changes: 10 additions & 9 deletions devito/petsc/iet/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def make_all(self, injectsolve, objs, solver_objs):
return matvec_callback, formfunc_callback, formrhs_callback

def make_matvec(self, injectsolve, objs, solver_objs):
target = injectsolve.expr.rhs.target
# Compile matvec `eqns` into an IET via recursive compilation
irs_matvec, _ = self.rcompile(injectsolve.expr.rhs.matvecs,
options={'mpi': False}, sregistry=SymbolRegistry())
Expand All @@ -76,7 +75,8 @@ def make_matvec(self, injectsolve, objs, solver_objs):
solver_objs, objs)

matvec_callback = PETScCallable(
'MyMatShellMult_%s' % target.name, body_matvec, retval=objs['err'],
self.sregistry.make_name(prefix='MyMatShellMult_'), body_matvec,
retval=objs['err'],
parameters=(
solver_objs['Jac'], solver_objs['X_global'], solver_objs['Y_global']
)
Expand Down Expand Up @@ -195,7 +195,6 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs):
return matvec_body

def make_formfunc(self, injectsolve, objs, solver_objs):
target = injectsolve.expr.rhs.target
# Compile formfunc `eqns` into an IET via recursive compilation
irs_formfunc, _ = self.rcompile(
injectsolve.expr.rhs.formfuncs,
Expand All @@ -206,7 +205,8 @@ def make_formfunc(self, injectsolve, objs, solver_objs):
solver_objs, objs)

formfunc_callback = PETScCallable(
'FormFunction_%s' % target.name, body_formfunc, retval=objs['err'],
self.sregistry.make_name(prefix='FormFunction_'), body_formfunc,
retval=objs['err'],
parameters=(solver_objs['snes'], solver_objs['X_global'],
solver_objs['Y_global'], solver_objs['dummy'])
)
Expand Down Expand Up @@ -316,7 +316,6 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs):
return formfunc_body

def make_formrhs(self, injectsolve, objs, solver_objs):
target = injectsolve.expr.rhs.target
# Compile formrhs `eqns` into an IET via recursive compilation
irs_formrhs, _ = self.rcompile(injectsolve.expr.rhs.formrhs,
options={'mpi': False}, sregistry=SymbolRegistry())
Expand All @@ -325,7 +324,7 @@ def make_formrhs(self, injectsolve, objs, solver_objs):
solver_objs, objs)

formrhs_callback = PETScCallable(
'FormRHS_%s' % target.name, body_formrhs, retval=objs['err'],
self.sregistry.make_name(prefix='FormRHS_'), body_formrhs, retval=objs['err'],
parameters=(
solver_objs['snes'], solver_objs['b_local']
)
Expand Down Expand Up @@ -404,7 +403,9 @@ def runsolve(self, solver_objs, objs, rhs_callback, injectsolve):
[dmda, Byref(solver_objs['x_local'])])

if any(i.is_Time for i in target.dimensions):
vec_replace_array = time_dep_replace(injectsolve, target, solver_objs, objs)
vec_replace_array = time_dep_replace(
injectsolve, target, solver_objs, objs, self.sregistry
)
else:
field_from_ptr = FieldFromPointer(target._C_field_data, target._C_symbol)
vec_replace_array = (petsc_call(
Expand Down Expand Up @@ -455,7 +456,7 @@ def build_petsc_struct(iet, name, liveness):
return petsc_struct(name, fields, liveness)


def time_dep_replace(injectsolve, target, solver_objs, objs):
def time_dep_replace(injectsolve, target, solver_objs, objs, sregistry):
target_time = injectsolve.expr.lhs
target_time = [i for i, d in zip(target_time.indices,
target_time.dimensions) if d.is_Time]
Expand All @@ -470,7 +471,7 @@ class BarCast(Cast):
class StartPtr(LocalObject):
dtype = CustomDtype(ctype_str, modifier=' *')

start_ptr = StartPtr(name='start_ptr_%s' % target.name)
start_ptr = StartPtr(sregistry.make_name(prefix='start_ptr_'))

vec_get_size = petsc_call(
'VecGetSize', [solver_objs['x_local'], Byref(objs['localsize'])]
Expand Down
33 changes: 17 additions & 16 deletions tests/test_petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def test_petsc_solve():

callable_roots = [meta_call.root for meta_call in op._func_table.values()]

matvec_callback = [root for root in callable_roots if root.name == 'MyMatShellMult_f']
matvec_callback = [root for root in callable_roots if root.name == 'MyMatShellMult_0']

formrhs_callback = [root for root in callable_roots if root.name == 'FormRHS_f']
formrhs_callback = [root for root in callable_roots if root.name == 'FormRHS_0']

action_expr = FindNodes(Expression).visit(matvec_callback[0])
rhs_expr = FindNodes(Expression).visit(formrhs_callback[0])
Expand Down Expand Up @@ -439,14 +439,14 @@ def test_callback_arguments():
with switchconfig(openmp=False):
op = Operator(petsc1)

mv = op._func_table['MyMatShellMult_f1'].root
ff = op._func_table['FormFunction_f1'].root
mv = op._func_table['MyMatShellMult_0'].root
ff = op._func_table['FormFunction_0'].root

assert len(mv.parameters) == 3
assert len(ff.parameters) == 4

assert str(mv.parameters) == '(J_f1, X_global_f1, Y_global_f1)'
assert str(ff.parameters) == '(snes_f1, X_global_f1, Y_global_f1, dummy_f1)'
assert str(mv.parameters) == '(J_0, X_global_0, Y_global_0)'
assert str(ff.parameters) == '(snes_0, X_global_0, Y_global_0, dummy_0)'


@skipif('petsc')
Expand Down Expand Up @@ -526,10 +526,10 @@ def test_petsc_frees():
frees = op.body.frees

# Check the frees appear in the following order
assert str(frees[0]) == 'PetscCall(VecDestroy(&(b_global_f)));'
assert str(frees[1]) == 'PetscCall(VecDestroy(&(x_global_f)));'
assert str(frees[2]) == 'PetscCall(MatDestroy(&(J_f)));'
assert str(frees[3]) == 'PetscCall(SNESDestroy(&(snes_f)));'
assert str(frees[0]) == 'PetscCall(VecDestroy(&(b_global_0)));'
assert str(frees[1]) == 'PetscCall(VecDestroy(&(x_global_0)));'
assert str(frees[2]) == 'PetscCall(MatDestroy(&(J_0)));'
assert str(frees[3]) == 'PetscCall(SNESDestroy(&(snes_0)));'
assert str(frees[4]) == 'PetscCall(DMDestroy(&(da_so_2)));'


Expand All @@ -549,8 +549,8 @@ def test_calls_to_callbacks():

ccode = str(op.ccode)

assert '(void (*)(void))MyMatShellMult_f' in ccode
assert 'PetscCall(SNESSetFunction(snes_f,NULL,FormFunction_f,NULL));' in ccode
assert '(void (*)(void))MyMatShellMult_0' in ccode
assert 'PetscCall(SNESSetFunction(snes_0,NULL,FormFunction_0,NULL));' in ccode


@skipif('petsc')
Expand All @@ -572,7 +572,7 @@ def test_start_ptr():
op = Operator(petsc)

# Verify the case with modulo time stepping
assert 'float * start_ptr_u = t1*localsize + (float *)(u_vec->data);' in str(op)
assert 'float * start_ptr_0 = t1*localsize + (float *)(u_vec->data);' in str(op)


@skipif('petsc')
Expand All @@ -593,7 +593,7 @@ def test_time_loop():
petsc1 = PETScSolve(eq1, v1)
op1 = Operator(petsc1)
body1 = str(op1.body)
rhs1 = str(op1._func_table['FormRHS_v1'].root.ccode)
rhs1 = str(op1._func_table['FormRHS_0'].root.ccode)

assert 'ctx.t0 = t0' in body1
assert 'ctx.t1 = t1' not in body1
Expand All @@ -607,7 +607,7 @@ def test_time_loop():
petsc2 = PETScSolve(eq2, v2)
op2 = Operator(petsc2)
body2 = str(op2.body)
rhs2 = str(op2._func_table['FormRHS_v2'].root.ccode)
rhs2 = str(op2._func_table['FormRHS_0'].root.ccode)

assert 'ctx.time = time' in body2
assert 'formrhs->time' in rhs2
Expand All @@ -618,7 +618,7 @@ def test_time_loop():
petsc3 = PETScSolve(eq3, v1)
op3 = Operator(petsc3)
body3 = str(op3.body)
rhs3 = str(op3._func_table['FormRHS_v1'].root.ccode)
rhs3 = str(op3._func_table['FormRHS_0'].root.ccode)

assert 'ctx.t0 = t0' in body3
assert 'ctx.t1 = t1' in body3
Expand All @@ -635,3 +635,4 @@ def test_time_loop():
body4 = str(op4.body)

assert 'ctx.t0 = t0' in body4
assert body4.count('ctx.t0 = t0') == 1

0 comments on commit 3b0a88b

Please sign in to comment.