Skip to content

Commit

Permalink
compiler: Remove time iter for petsc callbacks after specialize clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Aug 30, 2024
1 parent 552649b commit 0d7a9c3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
4 changes: 3 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from devito.types import (Buffer, Grid, Evaluable, host_layer, device_layer,
disk_layer)
from devito.petsc.iet.passes import lower_petsc, sort_frees
from devito.petsc.clusters import petsc_lift
from devito.petsc.clusters import petsc_lift, petsc_project

__all__ = ['Operator']

Expand Down Expand Up @@ -383,6 +383,8 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs):

clusters = cls._specialize_clusters(clusters, **kwargs)

clusters = petsc_project(clusters)

# Operation count after specialization
final_ops = sum(estimate_cost(c.exprs) for c in clusters if c.is_dense)
try:
Expand Down
17 changes: 13 additions & 4 deletions devito/petsc/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,30 @@ def petsc_lift(clusters):
"""
- Lift the iteration space surrounding each PETSc equation to create
distinct iteration loops.
- Drop time-loop for expressions which appear in PETSc callback functions.
"""
processed = []
for c in clusters:

if isinstance(c.exprs[0].rhs, LinearSolveExpr):
ispace = c.ispace.lift(c.exprs[0].rhs.target.space_dimensions)
processed.append(c.rebuild(ispace=ispace))
else:
processed.append(c)

return processed


@timed_pass()
def petsc_project(clusters):
"""
- Drop time-loop for expressions which appear in PETSc callback functions.
"""
processed = []
for c in clusters:
# Drop time-loop for expressions that appear in PETSc callback functions
elif isinstance(c.exprs[0].rhs, CallbackExpr):
if isinstance(c.exprs[0].rhs, CallbackExpr):
time_dims = [d for d in c.ispace.intervals.dimensions if d.is_Time]
ispace = c.ispace.project(lambda d: d not in time_dims)
processed.append(c.rebuild(ispace=ispace))

else:
processed.append(c)

Expand Down
4 changes: 2 additions & 2 deletions devito/petsc/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def PETScSolve(eq, target, bcs=None, solver_parameters=None, **kwargs):
inject_solve = InjectSolveEq(target, LinearSolveExpr(
dummy_expr, target=target, solver_parameters=solver_parameters,
matvecs=[matvecaction]+bcs_for_matvec,
formfuncs=[formfunction]+bcs_for_formfunc,
formrhs=[rhs]+bcs_for_rhs,
formfuncs=[formfunction],
formrhs=[rhs],
arrays=arrays,
), subdomain=eq.subdomain)

Expand Down

0 comments on commit 0d7a9c3

Please sign in to comment.