Skip to content

Commit

Permalink
tests: Fix tests and add apply one back
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Oct 3, 2024
1 parent ea17d4f commit e12261a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 33 deletions.
1 change: 0 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from devito.petsc.iet.passes import lower_petsc
from devito.petsc.clusters import petsc_lift


__all__ = ['Operator']


Expand Down
17 changes: 9 additions & 8 deletions devito/petsc/iet/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,13 @@ def lower_petsc(iet, **kwargs):
builder = PETScCallbackBuilder(**kwargs)

for iters, (injectsolve,) in injectsolve_mapper.items():
target = injectsolve.expr.rhs.target
solver_objs = build_solver_objs(target, **kwargs)
# target = injectsolve.expr.rhs.target
solver_objs = build_solver_objs(injectsolve, iters, **kwargs)

# Generate the solver setup for each InjectSolveDummy
solver_setup = generate_solver_setup(solver_objs, objs, injectsolve)
setup.extend(solver_setup)

solver_objs['true_dims'] = retrieve_time_dims(iters)
solver_objs['time_mapper'] = injectsolve.expr.rhs.time_mapper
solver_objs['target'] = target
# Generate all PETSc callback functions for the target via recursive compilation
matvec_op, formfunc_op, runsolve = builder.make(injectsolve,
objs, solver_objs)
Expand Down Expand Up @@ -171,7 +168,8 @@ def create_dmda(dmda, objs):
return dmda


def build_solver_objs(target, **kwargs):
def build_solver_objs(injectsolve, iters, **kwargs):
target = injectsolve.expr.rhs.target
sreg = kwargs['sregistry']
return {
'Jac': Mat(sreg.make_name(prefix='J_')),
Expand All @@ -188,12 +186,15 @@ def build_solver_objs(target, **kwargs):
'Y_local': LocalVec(sreg.make_name(prefix='Y_local_'), liveness='eager'),
'dummy': DummyArg(sreg.make_name(prefix='dummy_')),
'localsize': PetscInt(sreg.make_name(prefix='localsize_')),
'start_ptr': StartPtr(sreg.make_name(prefix='start_ptr_'), target.dtype)
'start_ptr': StartPtr(sreg.make_name(prefix='start_ptr_'), target.dtype),
'true_dims': retrieve_time_dims(iters),
'target': target,
'time_mapper': injectsolve.expr.rhs.time_mapper,
}


def generate_solver_setup(solver_objs, objs, injectsolve):
target = injectsolve.expr.rhs.target
target = solver_objs['target']

dmda = objs['da_so_%s' % target.space_order]

Expand Down
4 changes: 2 additions & 2 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def xreplace_indices(exprs, mapper, key=None):
handle = [i for i in handle if i.base.label in key]
elif callable(key):
handle = [i for i in handle if key(i)]
mapper_new = dict(zip(handle, [i.xreplace(mapper) for i in handle]))
replaced = [uxreplace(i, mapper_new) for i in as_tuple(exprs)]
mapper = dict(zip(handle, [i.xreplace(mapper) for i in handle]))
replaced = [uxreplace(i, mapper) for i in as_tuple(exprs)]
return replaced if isinstance(exprs, Iterable) else replaced[0]


Expand Down
48 changes: 26 additions & 22 deletions tests/test_petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,31 +482,31 @@ def test_petsc_struct():
assert all(not isinstance(i, CCompositeObject) for i in op.parameters)


# @skipif('petsc')
# @pytest.mark.parallel(mode=[2, 4, 8])
# def test_apply(mode):
@skipif('petsc')
@pytest.mark.parallel(mode=[2, 4, 8])
def test_apply(mode):

# grid = Grid(shape=(13, 13), dtype=np.float64)
grid = Grid(shape=(13, 13), dtype=np.float64)

# pn = Function(name='pn', grid=grid, space_order=2, dtype=np.float64)
# rhs = Function(name='rhs', grid=grid, space_order=2, dtype=np.float64)
# mu = Constant(name='mu', value=2.0)
pn = Function(name='pn', grid=grid, space_order=2, dtype=np.float64)
rhs = Function(name='rhs', grid=grid, space_order=2, dtype=np.float64)
mu = Constant(name='mu', value=2.0)

# eqn = Eq(pn.laplace*mu, rhs, subdomain=grid.interior)
eqn = Eq(pn.laplace*mu, rhs, subdomain=grid.interior)

# petsc = PETScSolve(eqn, pn)
petsc = PETScSolve(eqn, pn)

# # Build the op
# with switchconfig(openmp=False, mpi=True):
# op = Operator(petsc)
# Build the op
with switchconfig(openmp=False, mpi=True):
op = Operator(petsc)

# # Check the Operator runs without errors. Not verifying output for
# # now. Need to consolidate BC implementation
# op.apply()
# Check the Operator runs without errors. Not verifying output for
# now. Need to consolidate BC implementation
op.apply()

# # Verify that users can override `mu`
# mu_new = Constant(name='mu_new', value=4.0)
# op.apply(mu=mu_new)
# Verify that users can override `mu`
mu_new = Constant(name='mu_new', value=4.0)
op.apply(mu=mu_new)


@skipif('petsc')
Expand Down Expand Up @@ -602,7 +602,8 @@ def test_time_loop():
v1 = Function(name='v1', grid=grid, space_order=2)
eq1 = Eq(v1.laplace, u1)
petsc1 = PETScSolve(eq1, v1)
op1 = Operator(petsc1)
with switchconfig(openmp=False):
op1 = Operator(petsc1)
body1 = str(op1.body)
rhs1 = str(op1._func_table['FormRHS_0'].root.ccode)

Expand All @@ -616,7 +617,8 @@ def test_time_loop():
v2 = Function(name='v2', grid=grid, space_order=2, save=5)
eq2 = Eq(v2.laplace, u2)
petsc2 = PETScSolve(eq2, v2)
op2 = Operator(petsc2)
with switchconfig(openmp=False):
op2 = Operator(petsc2)
body2 = str(op2.body)
rhs2 = str(op2._func_table['FormRHS_0'].root.ccode)

Expand All @@ -627,7 +629,8 @@ def test_time_loop():
# used in one of the callback functions
eq3 = Eq(v1.laplace, u1 + u1.forward)
petsc3 = PETScSolve(eq3, v1)
op3 = Operator(petsc3)
with switchconfig(openmp=False):
op3 = Operator(petsc3)
body3 = str(op3.body)
rhs3 = str(op3._func_table['FormRHS_0'].root.ccode)

Expand All @@ -642,7 +645,8 @@ def test_time_loop():
petsc4 = PETScSolve(eq4, v1)
eq5 = Eq(v2.laplace, u1)
petsc5 = PETScSolve(eq5, v2)
op4 = Operator(petsc4 + petsc5)
with switchconfig(openmp=False):
op4 = Operator(petsc4 + petsc5)
body4 = str(op4.body)

assert 'ctx.t0 = t0' in body4
Expand Down

0 comments on commit e12261a

Please sign in to comment.