Skip to content

Commit

Permalink
Add a SubrangeMapper helper class.
Browse files Browse the repository at this point in the history
Its functionality
===
Equipped with a `src` and a `dst` range of equal volumes but possibly different shapes or even dimensions (i.e.,
there is an 1-to-1 correspondence between the elements of `src` and `dst`), maps a subrange of `src` to its
counterpart in `dst`, if possible.

Note that such subrange-to-subrange mapping may not always exist.
  • Loading branch information
pratyai committed Oct 23, 2024
1 parent beb44a2 commit 4911a01
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 30 deletions.
143 changes: 116 additions & 27 deletions dace/subsets.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import operator

import dace.serialize
from dace import data, symbolic, dtypes
import re
import sympy as sp
import warnings
from functools import reduce
import sympy.core.sympify
from typing import List, Optional, Sequence, Set, Union
import warnings

import sympy as sp
import sympy.core.sympify
from sympy import ceiling

import dace.serialize
from dace import symbolic
from dace.config import Config


Expand All @@ -22,6 +24,7 @@ def nng(expr):
except AttributeError: # No free_symbols in expr
return expr


def bounding_box_cover_exact(subset_a, subset_b) -> bool:
min_elements_a = subset_a.min_element()
max_elements_a = subset_a.max_element()
Expand All @@ -31,16 +34,17 @@ def bounding_box_cover_exact(subset_a, subset_b) -> bool:
# Covering only make sense if the two subsets have the same number of dimensions.
if len(min_elements_a) != len(min_elements_b):
return ValueError(
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
)

return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True
and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True
for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
min_elements_b, max_elements_b)])

def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool:

def bounding_box_symbolic_positive(subset_a, subset_b, approximation=False) -> bool:
min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element()
max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element()
min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element()
Expand All @@ -49,8 +53,8 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)->
# Covering only make sense if the two subsets have the same number of dimensions.
if len(min_elements_a) != len(min_elements_b):
return ValueError(
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
)

for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
Expand All @@ -72,6 +76,7 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)->
return False
return True


class Subset(object):
""" Defines a subset of a data descriptor. """

Expand All @@ -82,7 +87,7 @@ def covers(self, other):
# Subsets of different dimensionality can never cover each other.
if self.dims() != other.dims():
return ValueError(
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
)

if not Config.get('optimizer', 'symbolic_positive'):
Expand All @@ -101,20 +106,22 @@ def covers(self, other):
return False

return True

def covers_precise(self, other):
""" Returns True if self contains all the elements in other. """

# Subsets of different dimensionality can never cover each other.
if self.dims() != other.dims():
return ValueError(
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
)

# If self does not cover other with a bounding box union, return false.
symbolic_positive = Config.get('optimizer', 'symbolic_positive')
try:
bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other)
bounding_box_cover = bounding_box_cover_exact(self,
other) if symbolic_positive else bounding_box_symbolic_positive(
self, other)
if not bounding_box_cover:
return False
except TypeError:
Expand Down Expand Up @@ -153,14 +160,13 @@ def covers_precise(self, other):
except:
return False
return True
# unknown type
# unknown type
else:
raise TypeError

except TypeError:
return False


def __repr__(self):
return '%s (%s)' % (type(self).__name__, self.__str__())

Expand Down Expand Up @@ -231,6 +237,7 @@ def _tuple_to_symexpr(val):
@dace.serialize.serializable
class Range(Subset):
""" Subset defined in terms of a fixed range. """

def __init__(self, ranges):
parsed_ranges = []
parsed_tiles = []
Expand Down Expand Up @@ -584,7 +591,7 @@ def from_string(string):
value = symbolic.pystr_to_symbolic(uni_dim_tokens[0].strip())
ranges.append((value, value, 1))
continue
#return Range(ranges)
# return Range(ranges)
# If dimension has more than 4 tokens, the range is invalid
if len(uni_dim_tokens) > 4:
raise SyntaxError("Invalid range: {}".format(multi_dim_tokens))
Expand Down Expand Up @@ -854,6 +861,7 @@ def intersects(self, other: 'Range'):
class Indices(Subset):
""" A subset of one element representing a single index in an
N-dimensional data descriptor. """

def __init__(self, indices):
if indices is None or len(indices) == 0:
raise TypeError('Expected an array of index expressions: got empty' ' array or None')
Expand All @@ -880,7 +888,7 @@ def from_json(obj, context=None):
raise TypeError("from_json of class \"Indices\" called on json "
"with type %s (expected 'Indices')" % obj['type'])

#return Indices(symbolic.SymExpr(obj['indices']))
# return Indices(symbolic.SymExpr(obj['indices']))
return Indices([*map(symbolic.pystr_to_symbolic, obj['indices'])])

def __hash__(self):
Expand Down Expand Up @@ -1091,6 +1099,7 @@ def intersection(self, other: 'Indices'):
return self
return None


class SubsetUnion(Subset):
"""
Wrapper subset type that stores multiple Subsets in a list.
Expand Down Expand Up @@ -1128,7 +1137,7 @@ def covers(self, other):
return False
else:
return any(s.covers(other) for s in self.subset_list)

def covers_precise(self, other):
"""
Returns True if this SubsetUnion covers another
Expand All @@ -1154,7 +1163,7 @@ def __str__(self):
string += " "
string += subset.__str__()
return string

def dims(self):
if not self.subset_list:
return 0
Expand All @@ -1178,7 +1187,7 @@ def free_symbols(self) -> Set[str]:
for subset in self.subset_list:
result |= subset.free_symbols
return result

def replace(self, repl_dict):
for subset in self.subset_list:
subset.replace(repl_dict)
Expand All @@ -1188,13 +1197,12 @@ def num_elements(self):
min = 0
for subset in self.subset_list:
try:
if subset.num_elements() < min or min ==0:
if subset.num_elements() < min or min == 0:
min = subset.num_elements()
except:
continue

return min

return min


def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType,
Expand Down Expand Up @@ -1261,8 +1269,6 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
return Range(result)




def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
Expand Down Expand Up @@ -1331,6 +1337,7 @@ def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
except TypeError:
return None


def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
"""
Returns True if two subsets intersect, False if they do not, or
Expand All @@ -1352,3 +1359,85 @@ def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
return None
except TypeError: # cannot determine truth value of Relational
return None


class SubrangeMapper:
"""
Equipped with a `src` and a `dst` range of equal volumes but possibly different shapes or even dimensions (i.e.,
there is an 1-to-1 correspondence between the elements of `src` and `dst`), maps a subrange of `src` to its
counterpart in `dst`, if possible.
Note that such subrange-to-subrange mapping may not always exist.
"""

def __init__(self, src: Range, dst: Range):
src, dst = self.canonical(src), self.canonical(dst)
assert src.volume_exact() == dst.volume_exact()
self.src, self.dst = src, dst

@staticmethod
def canonical(r: Range) -> Range:
"""
Extends the (excluded) upper bound of each component of the ranges as much as possible, without affecting the
volume of the range.
"""
return Range([(b, b + s * ceiling((e - b + 1) / s) - 1, s)
for b, e, s in r.ndrange()])

def map(self, r: Range) -> Optional[Range]:
r = self.canonical(r)
# Ideally we also have `assert self.src.covers_precise(r)`. However, we cannot determine that for symbols.
assert self.src.dims() == r.dims()
out = []
src_i, dst_i = 0, 0
while src_i < self.src.dims():
assert dst_i < self.dst.dims()

src_j, dst_j = None, None
for sj in range(src_i + 1, self.src.dims() + 1):
for dj in range(dst_i + 1, self.dst.dims() + 1):
if Range(self.src.ranges[src_i:sj]).volume_exact() == Range(
self.dst.ranges[dst_i:dj]).volume_exact():
src_j, dst_j = sj, dj
break
else:
continue
break
if src_j is None:
return None

if Range(r.ranges[src_i: src_j]).volume_exact() == 1:
# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
src_segment, dst_segment = Range(self.src.ranges[src_i: src_j]), Range(self.dst.ranges[dst_i: dst_j])
# Compute the local 1D coordinate of the point on `src`.
loc = 0
for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges),
reversed(r.ranges[src_i: src_j]),
reversed(src_segment.size())):
loc = loc * s + (ridx - idx)
# Translate that local 1D coordinate onto `dst`.
dst_coord = []
for (idx, _, _), s in zip(dst_segment.ranges, dst_segment.size()):
dst_coord.append(loc % s + idx)
loc = loc // s
out.extend([(idx, idx, 1) for idx in dst_coord])
elif self.src.ranges[src_i: src_j] == r.ranges[src_i: src_j]:
# If we are selecting the entirety of this segment, we can just pick the corresponding mapped segment in
# its entirety too.
out.extend(self.dst.ranges[dst_i:dst_j])
elif src_j - src_i == 1 and dst_j - dst_i == 1:
# If the segment lengths on both sides are just 1, the mapping is easy to compute.
sb, se, ss = self.src.ranges[src_i]
db, de, ds = self.dst.ranges[dst_i]
b, e, s = r.ranges[src_i]
lb, le, ls = (b - sb) // ss, (e - se) // ss - 1, s // ss
tb, te, ts = db + lb * ds, de + (le + 1) * ds, ds * ls
out.append((tb, te, ts))
else:
# TODO: Can we narrow down this case even more? That would be number theoretic problem.
# E.g., If we are reshaping [6, 5] to [2, 15], we are demanding that these dimensions must be wholly
# selected for now.
return None

src_i, dst_i = src_j, dst_j
return Range(out)
Loading

0 comments on commit 4911a01

Please sign in to comment.