Skip to content

Commit

Permalink
misc: Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ZoeLeibowitz committed Aug 29, 2024
1 parent 77ef9a7 commit 258646e
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 54 deletions.
1 change: 0 additions & 1 deletion devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from devito.types import Array, Eq, Symbol
from devito.types.dimension import BOTTOM, ModuloDimension


__all__ = ['clusterize']


Expand Down
3 changes: 3 additions & 0 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def cire(clusters, mode, sregistry, options, platform):
transformer = cls(sregistry, options, platform)

clusters = transformer.process(clusters)

return clusters


Expand All @@ -116,6 +117,7 @@ def __init__(self, sregistry, options, platform):

def _aliases_from_clusters(self, clusters, exclude, meta):
exprs = flatten([c.exprs for c in clusters])

# [Clusters]_n -> [Schedule]_m
variants = []
for mapper in self._generate(exprs, exclude):
Expand All @@ -127,6 +129,7 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
schedule = lower_aliases(aliases, meta, self.opt_maxpar)

variants.append(Variant(schedule, pexprs))

if not variants:
return []

Expand Down
1 change: 0 additions & 1 deletion devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def cse(cluster, sregistry, options, *args):
Common sub-expressions elimination (CSE).
"""
make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)

exprs = _cse(cluster, make, min_cost=options['cse-min-cost'])

return cluster.rebuild(exprs=exprs)
Expand Down
1 change: 0 additions & 1 deletion devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def rule1(dep, candidates, loc_dims):
# TODO: Investigate
scopes = {i: Scope([e.expr for e in v if not isinstance(e, Call)])
for i, v in MapNodes().visit(iet).items()}
# scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()}

# Analysis
hsmapper = {}
Expand Down
1 change: 1 addition & 0 deletions devito/petsc/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def petsc_lift(clusters):
"""
processed = []
for c in clusters:

if isinstance(c.exprs[0].rhs, LinearSolveExpr):
ispace = c.ispace.lift(c.exprs[0].rhs.target.space_dimensions)
processed.append(c.rebuild(ispace=ispace))
Expand Down
8 changes: 2 additions & 6 deletions devito/petsc/iet/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def lower_petsc(iet, **kwargs):
break

# Generate callback to populate main struct object
struct_main = objs['ctx']._rebuild(fields=filter_ordered(builder.struct_params))
struct_main = petsc_struct('ctx', filter_ordered(builder.struct_params))
struct_callback = generate_struct_callback(struct_main)
call_struct_callback = petsc_call(struct_callback.name, [Byref(struct_main)])
calls_set_app_ctx = [petsc_call('DMSetApplicationContext', [i, Byref(struct_main)])
Expand All @@ -93,9 +93,6 @@ def lower_petsc(iet, **kwargs):
)
iet = iet._rebuild(body=body)
metadata = core_metadata()

# Remove temporary objects in efuncs that were previously used to store metadata
# efuncs = transform_efuncs(builder.efuncs)
efuncs = tuple(builder.efuncs.values())+(struct_callback,)
metadata.update({'efuncs': efuncs})

Expand Down Expand Up @@ -130,8 +127,7 @@ def build_core_objects(target, **kwargs):
'comm': communicator,
'err': PetscErrorCode(name='err'),
'grid': target.grid,
'localsize': PetscInt(name='localsize'),
'ctx': petsc_struct('ctx', [])
'localsize': PetscInt(name='localsize')
}


Expand Down
7 changes: 2 additions & 5 deletions devito/petsc/types/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,8 @@ def fields(self):

@property
def time_dim_fields(self):
time_dims = []
for f in self.fields:
if isinstance(f, (ModuloDimension, TimeDimension)):
time_dims.append(f)
return time_dims
return [f for f in self.fields
if isinstance(f, (ModuloDimension, TimeDimension))]

@property
def _C_ctype(self):
Expand Down
72 changes: 33 additions & 39 deletions devito/petsc/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,39 @@
from devito.tools import Reconstructable, sympy_mutex


class BaseInfoExpr(sympy.Function, Reconstructable):
class LinearSolveExpr(sympy.Function, Reconstructable):

__rargs__ = ('expr',)
__rkwargs__ = ('target', 'solver_parameters', 'matvecs',
'formfuncs', 'formrhs', 'arrays')

defaults = {
'ksp_type': 'gmres',
'pc_type': 'jacobi',
'ksp_rtol': 1e-7, # Relative tolerance
'ksp_atol': 1e-50, # Absolute tolerance
'ksp_divtol': 1e4, # Divergence tolerance
'ksp_max_it': 10000 # Maximum iterations
}

def __new__(cls, expr, target=None, solver_parameters=None,
matvecs=None, formfuncs=None, formrhs=None, arrays=None, **kwargs):

if solver_parameters is None:
solver_parameters = cls.defaults
else:
for key, val in cls.defaults.items():
solver_parameters[key] = solver_parameters.get(key, val)

def __new__(cls, expr, **kwargs):
with sympy_mutex:
obj = sympy.Basic.__new__(cls, expr)
obj._expr = expr

for key, value in kwargs.items():
setattr(obj, "_%s" % key, value)

obj._target = target
obj._solver_parameters = solver_parameters
obj._matvecs = matvecs
obj._formfuncs = formfuncs
obj._formrhs = formrhs
obj._arrays = arrays
return obj

def __repr__(self):
Expand All @@ -28,44 +50,14 @@ def __hash__(self):
return hash(self.expr)

def __eq__(self, other):
return isinstance(other, self.__class__) and self.expr == other.expr
return (isinstance(other, LinearSolveExpr) and
self.expr == other.expr and
self.target == other.target)

@property
def expr(self):
return self._expr

func = Reconstructable._rebuild


class LinearSolveExpr(BaseInfoExpr):

__rkwargs__ = ('target', 'solver_parameters', 'matvecs',
'formfuncs', 'formrhs', 'arrays')

defaults = {
'ksp_type': 'gmres',
'pc_type': 'jacobi',
'ksp_rtol': 1e-7,
'ksp_atol': 1e-50,
'ksp_divtol': 1e4,
'ksp_max_it': 10000
}

def __new__(cls, expr, target=None, solver_parameters=None,
matvecs=None, formfuncs=None, formrhs=None, arrays=None, **kwargs):

if solver_parameters is None:
solver_parameters = cls.defaults
else:
for key, val in cls.defaults.items():
solver_parameters[key] = solver_parameters.get(key, val)

return super().__new__(
cls, expr, target=target, solver_parameters=solver_parameters,
matvecs=matvecs, formfuncs=formfuncs, formrhs=formrhs,
arrays=arrays, **kwargs
)

@property
def target(self):
return self._target
Expand All @@ -90,6 +82,8 @@ def formrhs(self):
def arrays(self):
return self._arrays

func = Reconstructable._rebuild


class CallbackExpr(sympy.Function):
@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/test_petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,4 +572,4 @@ def test_start_prt():
op = Operator(petsc)

# Verify the case with modulo time stepping
assert 'float * start_ptr = t1*localsize + (float *)(u_vec->data);' in str(op)
assert 'float * start_ptr_u = t1*localsize + (float *)(u_vec->data);' in str(op)

0 comments on commit 258646e

Please sign in to comment.