From 0d7a9c30805d3ac61cebee9c461fa98975ee6417 Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Fri, 30 Aug 2024 13:23:26 +0100 Subject: [PATCH] compiler: Remove time iter for petsc callbacks after specialize clusters --- devito/operator/operator.py | 4 +++- devito/petsc/clusters.py | 17 +++++++++++++---- devito/petsc/solve.py | 4 ++-- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 69f34ee3b0..fa342dd33a 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -32,7 +32,7 @@ from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer, disk_layer) from devito.petsc.iet.passes import lower_petsc, sort_frees -from devito.petsc.clusters import petsc_lift +from devito.petsc.clusters import petsc_lift, petsc_project __all__ = ['Operator'] @@ -383,6 +383,8 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs): clusters = cls._specialize_clusters(clusters, **kwargs) + clusters = petsc_project(clusters) + # Operation count after specialization final_ops = sum(estimate_cost(c.exprs) for c in clusters if c.is_dense) try: diff --git a/devito/petsc/clusters.py b/devito/petsc/clusters.py index e0a3639116..5719cfc721 100644 --- a/devito/petsc/clusters.py +++ b/devito/petsc/clusters.py @@ -7,21 +7,30 @@ def petsc_lift(clusters): """ - Lift the iteration space surrounding each PETSc equation to create distinct iteration loops. - - Drop time-loop for expressions which appear in PETSc callback functions. """ processed = [] for c in clusters: - if isinstance(c.exprs[0].rhs, LinearSolveExpr): ispace = c.ispace.lift(c.exprs[0].rhs.target.space_dimensions) processed.append(c.rebuild(ispace=ispace)) + else: + processed.append(c) + + return processed + +@timed_pass() +def petsc_project(clusters): + """ + - Drop time-loop for expressions which appear in PETSc callback functions. + """ + processed = [] + for c in clusters: # Drop time-loop for expressions that appear in PETSc callback functions - elif isinstance(c.exprs[0].rhs, CallbackExpr): + if isinstance(c.exprs[0].rhs, CallbackExpr): time_dims = [d for d in c.ispace.intervals.dimensions if d.is_Time] ispace = c.ispace.project(lambda d: d not in time_dims) processed.append(c.rebuild(ispace=ispace)) - else: processed.append(c) diff --git a/devito/petsc/solve.py b/devito/petsc/solve.py index 93b83697a2..f811e00d4c 100644 --- a/devito/petsc/solve.py +++ b/devito/petsc/solve.py @@ -93,8 +93,8 @@ def PETScSolve(eq, target, bcs=None, solver_parameters=None, **kwargs): inject_solve = InjectSolveEq(target, LinearSolveExpr( dummy_expr, target=target, solver_parameters=solver_parameters, matvecs=[matvecaction]+bcs_for_matvec, - formfuncs=[formfunction]+bcs_for_formfunc, - formrhs=[rhs]+bcs_for_rhs, + formfuncs=[formfunction], + formrhs=[rhs], arrays=arrays, ), subdomain=eq.subdomain)