Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Minor tweaks for elastic code gen #2453

Merged
merged 6 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,10 @@ def __init_finalize__(self, **kwargs):
if not configuration['safe-math']:
self.cflags.append('--use_fast_math')

# Optionally print out per-kernel shared memory and register usage
if configuration['profiling'] == 'advanced2':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool!

self.cflags.append('--ptxas-options=-v')

self.src_ext = 'cu'

# NOTE: not sure where we should place this. It definitely needs
Expand Down
2 changes: 1 addition & 1 deletion devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,6 @@ def _(expr, x0, **kwargs):
if x0_expr:
dims = tuple((d, 0) for d in x0_expr)
fd_o = tuple([2]*len(dims))
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)._evaluate(**kwargs)
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)
else:
return expr
2 changes: 1 addition & 1 deletion devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def make_stencil_dimension(expr, _min, _max):
Create a StencilDimension for `expr` with unique name.
"""
n = len(expr.find(StencilDimension))
return StencilDimension(name='i%d' % n, _min=_min, _max=_max)
return StencilDimension('i%d' % n, _min, _max)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: could these just be min and max now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are Python built-in names (not keywords, strictly speaking), so we tend not to shadow them, as it's not recommended.



@cacheit
Expand Down
26 changes: 15 additions & 11 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from devito.mpi.reduction_scheme import DistReduce
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, infer_dtype
from devito.types import WeakFence, CriticalRegion
from devito.types import Fence, WeakFence, CriticalRegion

__all__ = ["Cluster", "ClusterGroup"]

Expand Down Expand Up @@ -239,42 +239,46 @@ def is_sparse(self):
"""
return any(a.is_irregular for a in self.scope.accesses)

@property
@cached_property
def is_wild(self):
"""
True if encoding a non-mathematical operation, False otherwise.
"""
return self.is_halo_touch or self.is_dist_reduce or self.is_fence
return (self.is_halo_touch or
self.is_dist_reduce or
self.is_weak_fence or
self.is_critical_region)

@property
@cached_property
def is_halo_touch(self):
return self.exprs and all(isinstance(e.rhs, HaloTouch) for e in self.exprs)

@property
@cached_property
def is_dist_reduce(self):
return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs)

@property
@cached_property
def is_fence(self):
return self.is_weak_fence or self.is_critical_region
return (self.exprs and all(isinstance(e.rhs, Fence) for e in self.exprs) or
self.is_critical_region)

@property
@cached_property
def is_weak_fence(self):
return self.exprs and all(isinstance(e.rhs, WeakFence) for e in self.exprs)

@property
@cached_property
def is_critical_region(self):
return self.exprs and all(isinstance(e.rhs, CriticalRegion) for e in self.exprs)

@property
@cached_property
def is_async(self):
"""
True if an asynchronous Cluster, False otherwise.
"""
return any(isinstance(s, (WithLock, PrefetchUpdate))
for s in flatten(self.syncs.values()))

@property
@cached_property
def is_wait(self):
"""
True if a Cluster waiting on a lock (that is a special synchronization
Expand Down
12 changes: 8 additions & 4 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def is_cross(source, sink):
return t0 < v <= t1 or t1 < v <= t0

for cg1 in cgroups[n+1:]:
n1 = cgroups.index(cg1)

# A Scope to compute all cross-ClusterGroup anti-dependences
scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross)

Expand All @@ -355,14 +357,16 @@ def is_cross(source, sink):
break

# Any anti- and iaw-dependences impose that `cg1` follows `cg0`
# and forbid any sort of fusion
elif any(scope.d_anti_gen()) or\
any(i.is_iaw for i in scope.d_output_gen()):
# and forbid any sort of fusion. Fences have the same effect
elif (any(scope.d_anti_gen()) or
any(i.is_iaw for i in scope.d_output_gen()) or
any(c.is_fence for c in flatten(cgroups[n:n1+1]))):
dag.add_edge(cg0, cg1)

# Any flow-dependences along an inner Dimension (i.e., a Dimension
# that doesn't appear in `prefix`) impose that `cg1` follows `cg0`
elif any(not (i.cause and i.cause & prefix) for i in scope.d_flow_gen()):
elif any(not (i.cause and i.cause & prefix)
for i in scope.d_flow_gen()):
dag.add_edge(cg0, cg1)

# Clearly, output dependences must be honored
Expand Down
24 changes: 20 additions & 4 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort

from devito.finite_differences.differentiable import EvalDerivative
from devito.finite_differences.differentiable import (
EvalDerivative, IndexDerivative
)
from devito.symbolics.extended_sympy import DefFunction, rfunc
from devito.symbolics.queries import q_leaf
from devito.symbolics.search import retrieve_indexed, retrieve_functions
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
from devito.types.basic import Basic
from devito.types.basic import Basic, Indexed
from devito.types.array import ComponentAccess
from devito.types.equation import Eq
from devito.types.relational import Le, Lt, Gt, Ge

__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched',
'evalrel', 'flatten_args']
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
'reuse_if_untouched', 'evalrel', 'flatten_args']


def uxreplace(expr, rule):
Expand Down Expand Up @@ -246,6 +248,20 @@ def add(self, expr, make, terms=None):
self[base] = self.extracted[base] = make()


def subs_if_composite(expr, subs):
"""
Call `expr.subs(subs)` if `subs` contain composite expressions, that is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: "contains"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noted

expressions that can be part of larger expressions of the same type (e.g.,
`a*b` could be part of `a*b*c`, while `a[1]` cannot be part of a "larger
Indexed"). Instead, if `subs` consists of just "primitive" expressions, then
resort to the much faster `uxreplace`.
"""
if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So why can't this just be moved inside uxreplace?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because it'd contradict the API -- uxreplace performs no re-simplifications.

return uxreplace(expr, subs)
else:
return expr.subs(subs)


def xreplace_indices(exprs, mapper, key=None):
"""
Replace array indices in expressions.
Expand Down
6 changes: 6 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import inspect
from collections import namedtuple
from ctypes import POINTER, _Pointer, c_char_p, c_char
from functools import reduce, cached_property
Expand Down Expand Up @@ -490,6 +491,11 @@ def _cache_key(cls, *args, **kwargs):
# From the kwargs
key.update(kwargs)

# Any missing __rkwargs__ along with their default values
params = inspect.signature(cls.__init_finalize__).parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouch how can we end up in such a weird spot

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try the newly added tests without this patch 😬 I don't remember the details, but basically caching is bypassed because a different cache key gets computed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but I don't get why this needs this elaborate `inspect` logic instead of just having `StencilDimension` implement `_cache_key` and add step/spacing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what makes you think it's due to StencilDimension — the test, maybe? But it was not just that. Maybe it emerged from there, but the problem is way more general. In fact, IIRC the issue was the presence/absence of the `is_const` flag, which pops up after reconstruction but is not part of the key (without this patch) at first instantiation.

missing = [i for i in cls.__rkwargs__ if i in set(params).difference(key)]
key.update({i: params[i].default for i in missing})

return frozendict(key)

def __new__(cls, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,9 +1534,9 @@ class StencilDimension(BasicDimension):
__rargs__ = BasicDimension.__rargs__ + ('_min', '_max')
__rkwargs__ = BasicDimension.__rkwargs__ + ('step',)

def __init_finalize__(self, name, _min, _max, spacing=None, step=1,
def __init_finalize__(self, name, _min, _max, spacing=1, step=1,
**kwargs):
self._spacing = sympy.sympify(spacing) or sympy.S.One
self._spacing = sympy.sympify(spacing)

if not is_integer(_min):
raise ValueError("Expected integer `min` (got %s)" % _min)
Expand Down
33 changes: 18 additions & 15 deletions tests/test_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,34 @@ def test_interp():
a = Function(name="a", grid=grid, staggered=NODE)
sa = Function(name="as", grid=grid, staggered=x)

sp_diff = lambda a, b: sympy.simplify(a - b) == 0
def sp_diff(a, b):
a = getattr(a, 'evaluate', a)
b = getattr(b, 'evaluate', b)
return sympy.simplify(a - b) == 0

# Base case, no interp
assert interp_for_fd(a, {}, expand=True) == a
assert interp_for_fd(a, {x: x}, expand=True) == a
assert interp_for_fd(sa, {}, expand=True) == sa
assert interp_for_fd(sa, {x: x + x.spacing/2}, expand=True) == sa
assert interp_for_fd(a, {}) == a
assert interp_for_fd(a, {x: x}) == a
assert interp_for_fd(sa, {}) == sa
assert interp_for_fd(sa, {x: x + x.spacing/2}) == sa

# Base case, interp
assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}, expand=True),
assert sp_diff(interp_for_fd(a, {x: x + x.spacing/2}),
.5*a + .5*a.shift(x, x.spacing))
assert sp_diff(interp_for_fd(sa, {x: x}, expand=True),
assert sp_diff(interp_for_fd(sa, {x: x}),
.5*sa + .5*sa.shift(x, -x.spacing))

# Mul case, split interp
assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}, expand=True),
sa * interp_for_fd(a, {x: x + x.spacing/2}, expand=True))
assert sp_diff(interp_for_fd(a*sa, {x: x}, expand=True),
a * interp_for_fd(sa, {x: x}, expand=True))
assert sp_diff(interp_for_fd(a*sa, {x: x + x.spacing/2}),
sa * interp_for_fd(a, {x: x + x.spacing/2}))
assert sp_diff(interp_for_fd(a*sa, {x: x}),
a * interp_for_fd(sa, {x: x}))

# Add case, split interp
assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}, expand=True),
sa + interp_for_fd(a, {x: x + x.spacing/2}, expand=True))
assert sp_diff(interp_for_fd(a + sa, {x: x}, expand=True),
a + interp_for_fd(sa, {x: x}, expand=True))
assert sp_diff(interp_for_fd(a + sa, {x: x + x.spacing/2}),
sa + interp_for_fd(a, {x: x + x.spacing/2}))
assert sp_diff(interp_for_fd(a + sa, {x: x}),
a + interp_for_fd(sa, {x: x}))


@pytest.mark.parametrize('ndim', [1, 2, 3])
Expand Down
25 changes: 25 additions & 0 deletions tests/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest

from devito import Dimension, Function
from devito.types import StencilDimension
from devito.data.allocators import DataReference


Expand Down Expand Up @@ -40,3 +42,26 @@ def test_w_new_dims(self):
assert f3.function is f3
assert f3.dimensions == dims0
assert np.all(f3.data[:] == 1)


class TestDimension:
    """Caching and reconstruction checks on Dimension objects."""

    def test_stencil_dimension(self):
        """Equal positional arguments must yield the very same object."""
        first = StencilDimension('i', 0, 1)
        second = StencilDimension('i', 0, 1)

        # devito caches StencilDimensions, hence a given set of args/kwargs
        # is expected to map to one unique object
        assert first is second

        # Reconstruction is expected to hit the same cache entry
        rebuilt = first._rebuild()
        assert first is rebuilt

    @pytest.mark.xfail(reason="Borked caching when supplying a kwarg for an arg")
    def test_stencil_dimension_borked(self):
        """Known failure: `_max` is passed by keyword rather than by position."""
        original = StencilDimension('i', 0, _max=1)
        rebuilt = original._rebuild()

        # TODO: Look into Symbol._cache_key and the way the key is generated
        assert original is rebuilt
Loading