diff --git a/devito/petsc/solve.py b/devito/petsc/solve.py index e5e6f3fbd3..c4abf96985 100644 --- a/devito/petsc/solve.py +++ b/devito/petsc/solve.py @@ -88,22 +88,34 @@ def generate_field_data(self, eqns, target, time_mapper): def build_callback_eqns(self, eq, target, arrays, time_mapper): b, F_target, targets = separate_eqn(eq, target) name = target.name + + matvec = self.make_matvec(eq, F_target, arrays, name, targets) + formfunc = self.make_formfunc(eq, F_target, arrays, name, targets) + formrhs = self.make_rhs(eq, b, arrays, name) + + return tuple(expr.subs(time_mapper) for expr in (matvec, formfunc, formrhs)) + + def make_matvec(self, eq, F_target, arrays, name, targets): matvec = Eq( arrays['y_matvec_%s' % name], F_target.subs(targets_to_arrays(arrays['x_matvec_%s' % name], targets)), subdomain=eq.subdomain ) - matvec = matvec.subs(time_mapper) + return matvec + + def make_formfunc(self, eq, F_target, arrays, name, targets): formfunc = Eq( arrays['y_formfunc_%s' % name], F_target.subs(targets_to_arrays(arrays['x_formfunc_%s' % name], targets)), subdomain=eq.subdomain ) - formfunc = formfunc.subs(time_mapper) - formrhs = Eq( - arrays['b_tmp_%s' % target.name], b.subs(time_mapper), subdomain=eq.subdomain + return formfunc + + def make_rhs(self, eq, b, arrays, name): + rhs = Eq( + arrays['b_tmp_%s' % name], b, subdomain=eq.subdomain ) - return matvec, formfunc, formrhs + return rhs class InjectSolveNested(InjectSolve):