Skip to content

Commit

Permalink
Merge pull request devitocodes#2272 from devitocodes/halo-inner-dim-f…
Browse files Browse the repository at this point in the history
…abio

mpi: Fix haloupdate with inner dim [v2]
  • Loading branch information
mloubout authored Nov 21, 2023
2 parents 7a86b36 + d6d8a6e commit 25fa68b
Show file tree
Hide file tree
Showing 13 changed files with 214 additions and 58 deletions.
59 changes: 29 additions & 30 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,54 +374,53 @@ class Communications(Queue):

B = Symbol(name='⊥')

@timed_pass(name='schedule')
@timed_pass(name='communications')
def process(self, clusters):
return self._process_fatd(clusters, 1, seen=set())

def callback(self, clusters, prefix, seen=None):
if seen.issuperset(clusters):
if not prefix:
return clusters

d = prefix[-1].dim

# Construct the mock exprs representing the halo accesses
exprs = []
# Construct a representation of the halo accesses
processed = []
for c in clusters:
if c.properties.is_sequential(d):
if c.properties.is_sequential(d) or \
c in seen:
continue

halo_scheme = HaloScheme(c.exprs, c.ispace)
hs = HaloScheme(c.exprs, c.ispace)
if hs.is_void or \
not d._defines & hs.distributed_aindices:
continue

if not halo_scheme.is_void and \
c.properties.is_parallel_relaxed(d):
points = set()
for f in halo_scheme.fmapper:
for a in c.scope.getreads(f):
points.add(a.access)
points = set()
for f in hs.fmapper:
for a in c.scope.getreads(f):
points.add(a.access)

# We also add all written symbols to ultimately create mock WARs
# with `c`, which will prevent the newly created HaloTouch to ever
# be rescheduled after `c` upon topological sorting
points.update(a.access for a in c.scope.accesses if a.is_write)
# We also add all written symbols to ultimately create mock WARs
# with `c`, which will prevent the newly created HaloTouch to ever
# be rescheduled after `c` upon topological sorting
points.update(a.access for a in c.scope.accesses if a.is_write)

# Sort for determinism
# NOTE: not sorting might impact code generation. The order of
# the args is important because that's what search functions honor!
points = sorted(points, key=str)
# Sort for determinism
# NOTE: not sorting might impact code generation. The order of
# the args is important because that's what search functions honor!
points = sorted(points, key=str)

rhs = HaloTouch(*points, halo_scheme=halo_scheme)
# Construct the HaloTouch Cluster
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))

# Insert only if not redundant, to avoid useless pollution
if not any(rhs == e.rhs for e in exprs):
exprs.append(Eq(self.B, rhs))
key = lambda i: i in prefix[:-1] or i in hs.loc_indices
ispace = c.ispace.project(key)

processed = []
if exprs:
ispace = prefix[:prefix.index(d)]
properties = prefix.properties.drop(d)
halo_touch = c.rebuild(exprs=expr, ispace=ispace)

processed.append(Cluster(exprs, ispace, c.guards, properties))
seen.update(clusters)
processed.append(halo_touch)
seen.update({halo_touch, c})

processed.extend(clusters)

Expand Down
5 changes: 4 additions & 1 deletion devito/ir/iet/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def iet_build(stree):
nsections += 1

elif i.is_Halo:
body = HaloSpot(queues.pop(i), i.halo_scheme)
try:
body = HaloSpot(queues.pop(i), i.halo_scheme)
except KeyError:
body = HaloSpot(None, i.halo_scheme)

elif i.is_Sync:
body = SyncSpot(i.sync_ops, body=queues.pop(i, None))
Expand Down
30 changes: 26 additions & 4 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from devito.ir.support import (SEQUENTIAL, Any, Interval, IterationInterval,
IterationSpace, normalize_properties, normalize_syncs)
from devito.mpi.halo_scheme import HaloScheme
from devito.tools import Bunch, DefaultOrderedDict
from devito.tools import Bunch, DefaultOrderedDict, as_mapper

__all__ = ['stree_build']

Expand Down Expand Up @@ -85,6 +85,10 @@ def stree_build(clusters, profiler=None, **kwargs):
if needs_nodehalo(it.dim, c.halo_scheme):
v.bottom.parent = NodeHalo(c.halo_scheme, v.bottom.parent)
break
else:
if c.halo_scheme:
assert not c.exprs # See preprocess() -- we rarely end up here!
tip = NodeHalo(c.halo_scheme, v.bottom)

# Add in NodeExprs
exprs = []
Expand Down Expand Up @@ -150,11 +154,14 @@ def preprocess(clusters, options=None, **kwargs):
for c in clusters:
if c.is_halo_touch:
hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs)
queue.append(c.rebuild(halo_scheme=hs))
queue.append(c.rebuild(exprs=[], halo_scheme=hs))

elif c.is_critical_region and c.syncs:
processed.append(c.rebuild(exprs=None, guards=c.guards, syncs=c.syncs))

elif c.is_wild:
continue

else:
dims = set(c.ispace.promote(lambda d: d.is_Block).itdims)

Expand All @@ -181,8 +188,23 @@ def preprocess(clusters, options=None, **kwargs):
ispace = c.ispace.project(syncs)
processed.append(c.rebuild(exprs=[], ispace=ispace, syncs=syncs))

halo_scheme = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=halo_scheme))
if all(c1.ispace.is_subset(c.ispace) for c1 in found):
# 99% of the cases we end up here
hs = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=hs))
elif options['mpi']:
# We end up here with e.g. `t,x,y,z,f` where `f` is a sequential
# dimension requiring a loc-index in the HaloScheme. The compiler
# will generate the non-perfect loop nest `t,f ; t,x,y,z,f`, with
# the first nest triggering all necessary halo exchanges along `f`
mapper = as_mapper(found, lambda c1: c1.ispace)
for k, v in mapper.items():
hs = HaloScheme.union([c1.halo_scheme for c1 in v])
processed.append(c.rebuild(exprs=[], ispace=k, halo_scheme=hs))
processed.append(c)
else:
# Avoid ugly empty loops
processed.append(c)

# Sanity check!
try:
Expand Down
13 changes: 13 additions & 0 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,19 @@ def reorder(self, relations=None, mode=None):

return IterationSpace(intervals, self.sub_iterators, self.directions)

def is_subset(self, other):
"""
True if `self` is included within `other`, False otherwise.
"""
if not self:
return True

d = self[-1].dim
try:
return self == other[:other.index(d) + 1]
except ValueError:
return False

def is_compatible(self, other):
"""
A relaxed version of ``__eq__``, in which only non-derived dimensions
Expand Down
2 changes: 2 additions & 0 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from devito.types.utils import DimensionTuple


__all__ = ['CustomTopology']

# Do not prematurely initialize MPI
# This allows launching a Devito program from within another Python program
# that has *already* initialized MPI
Expand Down
12 changes: 10 additions & 2 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __len__(self):
return len(self._mapper)

def __hash__(self):
return (self._mapper.__hash__(), self.honored.__hash__())
return hash((self._mapper.__hash__(), self.honored.__hash__()))

@classmethod
def build(cls, fmapper, honored):
Expand Down Expand Up @@ -582,13 +582,21 @@ def _sympystr(self, printer):
return str(self)

def __hash__(self):
return id(self)
return hash(self.halo_scheme)

def __eq__(self, other):
return isinstance(other, HaloTouch) and self.halo_scheme == other.halo_scheme

func = Reconstructable._rebuild

@property
def fmapper(self):
return self.halo_scheme.fmapper

@property
def dims(self):
return frozenset().union(*[v.dims for v in self.fmapper.values()])


def _uxreplace_dispatch_haloscheme(hs0, rule):
changed = False
Expand Down
15 changes: 9 additions & 6 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten
from devito.tools import as_tuple, flatten, filter_ordered
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension)
from devito.types.utils import DimensionTuple
Expand Down Expand Up @@ -163,7 +163,7 @@ def r(self):

@cached_property
def _rdim(self):
parent = self.sfunction.dimensions[-1]
parent = self.sfunction._sparse_dim
dims = [CustomDimension("r%s%s" % (self.sfunction.name, d.name),
-self.r+1, self.r, 2*self.r, parent)
for d in self._gdims]
Expand All @@ -184,15 +184,18 @@ def _rdim(self):

def _augment_implicit_dims(self, implicit_dims, extras=None):
if extras is not None:
extra = set([i for v in extras for i in v.dimensions]) - set(self._gdims)
extra = filter_ordered([i for v in extras for i in v.dimensions
if i not in self._gdims and
i not in self.sfunction.dimensions])
extra = tuple(extra)
else:
extra = tuple()

if self.sfunction._sparse_position == -1:
return self.sfunction.dimensions + as_tuple(implicit_dims) + extra
idims = self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
return as_tuple(implicit_dims) + self.sfunction.dimensions + extra
idims = extra + as_tuple(implicit_dims) + self.sfunction.dimensions
return tuple(idims)

def _coeff_temps(self, implicit_dims):
return []
Expand Down Expand Up @@ -283,7 +286,7 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
variables = list(retrieve_function_carriers(_expr))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims)
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims)
Expand Down
10 changes: 8 additions & 2 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def rule1(dep, candidates, loc_dims):
for q in d._defines])

for n, i in enumerate(iters):
if i not in scopes:
continue

candidates = [i.dim._defines for i in iters[n:]]

all_candidates = set().union(*candidates)
Expand Down Expand Up @@ -251,9 +254,10 @@ def _mark_overlappable(iet):
found = []
for hs in FindNodes(HaloSpot).visit(iet):
expressions = FindNodes(Expression).visit(hs)
scope = Scope([i.expr for i in expressions])
if not expressions:
continue

test = True
scope = Scope([i.expr for i in expressions])

# Comp/comm overlaps is legal only if the OWNED regions can grow
# arbitrarly, which means all of the dependences must be carried
Expand All @@ -270,6 +274,8 @@ def _mark_overlappable(iet):
# f[x, y] = ...
test = False
break
else:
test = True

# Heuristic: avoid comp/comm overlap for sparse Iteration nests
if test:
Expand Down
4 changes: 4 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,10 @@ def _hashable_content(self):
def indices(self):
return DimensionTuple(*super().indices, getters=self.function.dimensions)

@cached_property
def dimensions(self):
return self.function.dimensions

@property
def function(self):
return self.base.function
Expand Down
5 changes: 3 additions & 2 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ class SparseTimeFunction(AbstractSparseTimeFunction, SparseFunction):
__rkwargs__ = tuple(filter_ordered(AbstractSparseTimeFunction.__rkwargs__ +
SparseFunction.__rkwargs__))

def interpolate(self, expr, u_t=None, p_t=None, increment=False):
def interpolate(self, expr, u_t=None, p_t=None, increment=False, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Expand All @@ -921,7 +921,8 @@ def interpolate(self, expr, u_t=None, p_t=None, increment=False):
if p_t is not None:
subs = {self.time_dim: p_t}

return super().interpolate(expr, increment=increment, self_subs=subs)
return super().interpolate(expr, increment=increment, self_subs=subs,
implicit_dims=implicit_dims)

def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):
"""
Expand Down
12 changes: 4 additions & 8 deletions examples/seismic/tti/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,7 @@ def ForwardOperator(model, geometry, space_order=4,

# Source and receivers
expr = src * dt / m if kernel == 'staggered' else src * dt**2 / m
stencils += src.inject(field=u.forward, expr=expr)
stencils += src.inject(field=v.forward, expr=expr)
stencils += src.inject(field=(u.forward, v.forward), expr=expr)
stencils += rec.interpolate(expr=u + v)

# Substitute spacing terms to reduce flops
Expand Down Expand Up @@ -601,8 +600,7 @@ def AdjointOperator(model, geometry, space_order=4,

# Construct expression to inject receiver values
expr = rec * dt / m if kernel == 'staggered' else rec * dt**2 / m
stencils += rec.inject(field=p.backward, expr=expr)
stencils += rec.inject(field=r.backward, expr=expr)
stencils += rec.inject(field=(p.backward, r.backward), expr=expr)

# Create interpolation expression for the adjoint-source
stencils += srca.interpolate(expr=p + r)
Expand Down Expand Up @@ -661,8 +659,7 @@ def JacobianOperator(model, geometry, space_order=4,
eqn2 = FD_kernel(model, du, dv, space_order, qu=lin_usrc, qv=lin_vsrc)

# Construct expression to inject source values, injecting at u0(t+dt)/v0(t+dt)
src_term = src.inject(field=u0.forward, expr=src * dt**2 / m)
src_term += src.inject(field=v0.forward, expr=src * dt**2 / m)
src_term = src.inject(field=(u0.forward, v0.forward), expr=src * dt**2 / m)

# Create interpolation expression for receivers, extracting at du(t)+dv(t)
rec_term = rec.interpolate(expr=du + dv)
Expand Down Expand Up @@ -716,8 +713,7 @@ def JacobianAdjOperator(model, geometry, space_order=4,
dm_update = Inc(dm, - (u0 * du.dt2 + v0 * dv.dt2))

# Add expression for receiver injection
rec_term = rec.inject(field=du.backward, expr=rec * dt**2 / m)
rec_term += rec.inject(field=dv.backward, expr=rec * dt**2 / m)
rec_term = rec.inject(field=(du.backward, dv.backward), expr=rec * dt**2 / m)

# Substitute spacing terms to reduce flops
return Operator(eqn + rec_term + [dm_update], subs=model.spacing_map,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,8 @@ class SparseFirst(SparseFunction):
ds = DefaultDimension("ps", default_value=3)
grid = Grid((11, 11))
dims = grid.dimensions
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3))
s.coordinates.data[:] = [[.5, .5], [.2, .2]]
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3),
coordinates=[[.5, .5], [.2, .2]])

# Check dimensions and shape are correctly initialized
assert s.indices[s._sparse_position] == dr
Expand Down
Loading

0 comments on commit 25fa68b

Please sign in to comment.