diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index ac800606ee..b7f6572a26 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -79,7 +79,6 @@ def rule1(dep, candidates, loc_dims): rules = [rule0, rule1] # Precompute scopes to save time - # TODO: investigate scopes = {i: Scope([e.expr for e in v if not isinstance(e, Call)]) for i, v in MapNodes().visit(iet).items()} diff --git a/devito/petsc/iet/passes.py b/devito/petsc/iet/passes.py index 7c2f8da5ee..895e1e8233 100644 --- a/devito/petsc/iet/passes.py +++ b/devito/petsc/iet/passes.py @@ -17,6 +17,7 @@ @iet_pass def lower_petsc(iet, **kwargs): + # Check if PETScSolve was used injectsolve_mapper = MapNodes(Iteration, InjectSolveDummy, 'groupby').visit(iet) @@ -51,14 +52,14 @@ def lower_petsc(iet, **kwargs): solver_setup = generate_solver_setup(solver_objs, objs, injectsolve) setup.extend(solver_setup) - # Retrieve the modulo dimensions and map them to their origins based - # on the initial lowering + # Retrieve modulo dimensions to use in callback functions solver_objs['mod_dims'] = retrieve_mod_dims(iters) - space_iter, = spatial_injectsolve_iter(iters, injectsolve) # Generate all PETSc callback functions for the target via recusive compilation matvec_op, formfunc_op, runsolve = builder.make(injectsolve, objs, solver_objs) setup.extend([matvec_op, formfunc_op, BlankLine]) + # Only want to Transform the spatial iteration loop + space_iter, = spatial_injectsolve_iter(iters, injectsolve) subs.update({space_iter: List(body=runsolve)}) # Generate callback to populate main struct object @@ -113,8 +114,7 @@ def build_core_objects(target, **kwargs): 'size': PetscMPIInt(name='size'), 'comm': communicator, 'err': PetscErrorCode(name='err'), - 'grid': target.grid, - 'localsize': PetscInt(name='localsize') + 'grid': target.grid } @@ -184,7 +184,8 @@ def build_solver_objs(target, **kwargs): 'Y_global': GlobalVec(sreg.make_name(prefix='Y_global_')), 'X_local': LocalVec(sreg.make_name(prefix='X_local_'), liveness='eager'), 'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'), - 'dummy': DummyArg(sreg.make_name(prefix='dummy_')) + 'dummy': DummyArg(sreg.make_name(prefix='dummy_')), + 'localsize': PetscInt(sreg.make_name(prefix='localsize_')) } diff --git a/devito/petsc/iet/routines.py b/devito/petsc/iet/routines.py index 4c8671e622..204e853baa 100644 --- a/devito/petsc/iet/routines.py +++ b/devito/petsc/iet/routines.py @@ -90,7 +90,7 @@ def create_matvec_body(self, injectsolve, body, solver_objs, objs): dmda = objs['da_so_%s' % linsolveexpr.target.space_order] - body = correct_mod_dims(body, solver_objs['mod_dims']) + body = uxreplace_mod_dims(body, solver_objs['mod_dims']) struct = build_petsc_struct(body, 'matvec', liveness='eager') @@ -219,7 +219,7 @@ def create_formfunc_body(self, injectsolve, body, solver_objs, objs): dmda = objs['da_so_%s' % linsolveexpr.target.space_order] - body = correct_mod_dims(body, solver_objs['mod_dims']) + body = uxreplace_mod_dims(body, solver_objs['mod_dims']) struct = build_petsc_struct(body, 'formfunc', liveness='eager') @@ -351,7 +351,7 @@ def create_formrhs_body(self, injectsolve, body, solver_objs, objs): 'DMDAGetLocalInfo', [dmda, Byref(dmda.info)] ) - body = correct_mod_dims(body, solver_objs['mod_dims']) + body = uxreplace_mod_dims(body, solver_objs['mod_dims']) struct = build_petsc_struct(body, 'formrhs', liveness='eager') @@ -474,21 +474,29 @@ class StartPtr(LocalObject): start_ptr = StartPtr(sregistry.make_name(prefix='start_ptr_')) vec_get_size = petsc_call( - 'VecGetSize', [solver_objs['x_local'], Byref(objs['localsize'])] + 'VecGetSize', [solver_objs['x_local'], Byref(solver_objs['localsize'])] ) # TODO: What is the correct way to use Mul here? Devito Mul? Sympy Mul? field_from_ptr = FieldFromPointer(target._C_field_data, target._C_symbol) expr = DummyExpr( start_ptr, BarCast(field_from_ptr, ' *') + - Mul(target_time, objs['localsize']), init=True + Mul(target_time, solver_objs['localsize']), init=True ) vec_replace_array = petsc_call('VecReplaceArray', [solver_objs['x_local'], start_ptr]) return (vec_get_size, expr, vec_replace_array) -def correct_mod_dims(body, mod_dims): +def uxreplace_mod_dims(body, mod_dims): + """ + Replace ModuloDimensions in callback functions to align/match with the + initial lowering. The ModuloDimensions in callback functions must match + those generated during the initial lowering, as they are assigned and + updated in the struct at each time step. This is a valid uxreplace because + all functions appearing in the callback functions are passed through + the initial lowering. + """ old_mod_dims = [ i for i in FindSymbols('dimensions').visit(body) if isinstance(i, ModuloDimension) ] diff --git a/tests/test_petsc.py b/tests/test_petsc.py index 2d7aa3ad0d..ddd30b90b4 100644 --- a/tests/test_petsc.py +++ b/tests/test_petsc.py @@ -564,15 +564,26 @@ def test_start_ptr(): that the correct memory location is accessed and modified during each time step. """ grid = Grid((11, 11)) - u = TimeFunction(name='u', grid=grid, space_order=2, dtype=np.float32) - eq = Eq(u.dt, u.laplace, subdomain=grid.interior) - petsc = PETScSolve(eq, u.forward) + u1 = TimeFunction(name='u1', grid=grid, space_order=2, dtype=np.float32) + eq1 = Eq(u1.dt, u1.laplace, subdomain=grid.interior) + petsc1 = PETScSolve(eq1, u1.forward) with switchconfig(openmp=False): - op = Operator(petsc) + op1 = Operator(petsc1) # Verify the case with modulo time stepping - assert 'float * start_ptr_0 = t1*localsize + (float *)(u_vec->data);' in str(op) + assert 'float * start_ptr_0 = t1*localsize + (float *)(u1_vec->data);' in str(op1) + + # Verify the case with no modulo time stepping + u2 = TimeFunction(name='u2', grid=grid, space_order=2, dtype=np.float32, save=5) + eq2 = Eq(u2.dt, u2.laplace, subdomain=grid.interior) + petsc2 = PETScSolve(eq2, u2.forward) + + with switchconfig(openmp=False): + op2 = Operator(petsc2) + + assert 'float * start_ptr_0 = (time + 1)*localsize + ' + \ + '(float *)(u2_vec->data);' in str(op2) @skipif('petsc')