Skip to content

Commit

Permalink
compiler: Move localsize to solverobjs
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Sep 30, 2024
1 parent 3b0a88b commit ff8d580
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 18 deletions.
1 change: 0 additions & 1 deletion devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def rule1(dep, candidates, loc_dims):
rules = [rule0, rule1]

# Precompute scopes to save time
# TODO: investigate
scopes = {i: Scope([e.expr for e in v if not isinstance(e, Call)])
for i, v in MapNodes().visit(iet).items()}

Expand Down
13 changes: 7 additions & 6 deletions devito/petsc/iet/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

@iet_pass
def lower_petsc(iet, **kwargs):
# Check if PETScSolve was used
injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy,
'groupby').visit(iet)

Expand Down Expand Up @@ -51,14 +52,14 @@ def lower_petsc(iet, **kwargs):
solver_setup = generate_solver_setup(solver_objs, objs, injectsolve)
setup.extend(solver_setup)

# Retrieve the modulo dimensions and map them to their origins based
# on the initial lowering
# Retrieve modulo dimensions to use in callback functions
solver_objs['mod_dims'] = retrieve_mod_dims(iters)
space_iter, = spatial_injectsolve_iter(iters, injectsolve)
# Generate all PETSc callback functions for the target via recusive compilation
matvec_op, formfunc_op, runsolve = builder.make(injectsolve,
objs, solver_objs)
setup.extend([matvec_op, formfunc_op, BlankLine])
# Only want to Transform the spatial iteration loop
space_iter, = spatial_injectsolve_iter(iters, injectsolve)
subs.update({space_iter: List(body=runsolve)})

# Generate callback to populate main struct object
Expand Down Expand Up @@ -113,8 +114,7 @@ def build_core_objects(target, **kwargs):
'size': PetscMPIInt(name='size'),
'comm': communicator,
'err': PetscErrorCode(name='err'),
'grid': target.grid,
'localsize': PetscInt(name='localsize')
'grid': target.grid
}


Expand Down Expand Up @@ -184,7 +184,8 @@ def build_solver_objs(target, **kwargs):
'Y_global': GlobalVec(sreg.make_name(prefix='Y_global_')),
'X_local': LocalVec(sreg.make_name(prefix='X_local_'), liveness='eager'),
'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'),
'dummy': DummyArg(sreg.make_name(prefix='dummy_'))
'dummy': DummyArg(sreg.make_name(prefix='dummy_')),
'localsize': PetscInt(sreg.make_name(prefix='localsize_'))
}


Expand Down
20 changes: 14 additions & 6 deletions devito/petsc/iet/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs):

dmda = objs['da_so_%s' % linsolveexpr.target.space_order]

body = correct_mod_dims(body, solver_objs['mod_dims'])
body = uxreplace_mod_dims(body, solver_objs['mod_dims'])

struct = build_petsc_struct(body, 'matvec', liveness='eager')

Expand Down Expand Up @@ -219,7 +219,7 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs):

dmda = objs['da_so_%s' % linsolveexpr.target.space_order]

body = correct_mod_dims(body, solver_objs['mod_dims'])
body = uxreplace_mod_dims(body, solver_objs['mod_dims'])

struct = build_petsc_struct(body, 'formfunc', liveness='eager')

Expand Down Expand Up @@ -351,7 +351,7 @@ def create_formrhs_body(self, injectsolve, body, solver_objs, objs):
'DMDAGetLocalInfo', [dmda, Byref(dmda.info)]
)

body = correct_mod_dims(body, solver_objs['mod_dims'])
body = uxreplace_mod_dims(body, solver_objs['mod_dims'])

struct = build_petsc_struct(body, 'formrhs', liveness='eager')

Expand Down Expand Up @@ -474,21 +474,29 @@ class StartPtr(LocalObject):
start_ptr = StartPtr(sregistry.make_name(prefix='start_ptr_'))

vec_get_size = petsc_call(
'VecGetSize', [solver_objs['x_local'], Byref(objs['localsize'])]
'VecGetSize', [solver_objs['x_local'], Byref(solver_objs['localsize'])]
)

# TODO: What is the correct way to use Mul here? Devito Mul? Sympy Mul?
field_from_ptr = FieldFromPointer(target._C_field_data, target._C_symbol)
expr = DummyExpr(
start_ptr, BarCast(field_from_ptr, ' *') +
Mul(target_time, objs['localsize']), init=True
Mul(target_time, solver_objs['localsize']), init=True
)

vec_replace_array = petsc_call('VecReplaceArray', [solver_objs['x_local'], start_ptr])
return (vec_get_size, expr, vec_replace_array)


def correct_mod_dims(body, mod_dims):
def uxreplace_mod_dims(body, mod_dims):
"""
Replace ModuloDimensions in callback functions to align/match with the
initial lowering. The ModuloDimensions in callback functions must match
those generated during the initial lowering, as they are assigned and
updated in the struct at each time step. This is a valid uxreplace because
all functions appearing in the callback functions are passed through
the initial lowering.
"""
old_mod_dims = [
i for i in FindSymbols('dimensions').visit(body) if isinstance(i, ModuloDimension)
]
Expand Down
21 changes: 16 additions & 5 deletions tests/test_petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,15 +564,26 @@ def test_start_ptr():
that the correct memory location is accessed and modified during each time step.
"""
grid = Grid((11, 11))
u = TimeFunction(name='u', grid=grid, space_order=2, dtype=np.float32)
eq = Eq(u.dt, u.laplace, subdomain=grid.interior)
petsc = PETScSolve(eq, u.forward)
u1 = TimeFunction(name='u1', grid=grid, space_order=2, dtype=np.float32)
eq1 = Eq(u1.dt, u1.laplace, subdomain=grid.interior)
petsc1 = PETScSolve(eq1, u1.forward)

with switchconfig(openmp=False):
op = Operator(petsc)
op1 = Operator(petsc1)

# Verify the case with modulo time stepping
assert 'float * start_ptr_0 = t1*localsize + (float *)(u_vec->data);' in str(op)
assert 'float * start_ptr_0 = t1*localsize + (float *)(u1_vec->data);' in str(op1)

# Verify the case with no modulo time stepping
u2 = TimeFunction(name='u2', grid=grid, space_order=2, dtype=np.float32, save=5)
eq2 = Eq(u2.dt, u2.laplace, subdomain=grid.interior)
petsc2 = PETScSolve(eq2, u2.forward)

with switchconfig(openmp=False):
op2 = Operator(petsc2)

assert 'float * start_ptr_0 = (time + 1)*localsize + ' + \
'(float *)(u2_vec->data);' in str(op2)


@skipif('petsc')
Expand Down

0 comments on commit ff8d580

Please sign in to comment.