Iterset is now partitioned into CORE, ROOT and LEAF
connorjward committed Nov 21, 2023
1 parent 513d894 commit 0a37498
Showing 5 changed files with 198 additions and 108 deletions.
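
For context, the CORE/ROOT/LEAF partition exists so that a parallel loop can overlap computation with halo exchange: CORE points touch no shared dofs and can be visited while the leaves-to-roots reduction is in flight, ROOT points become safe once that reduction has finished (and while the roots-to-leaves broadcast is in flight), and LEAF points only after the broadcast completes. A minimal sketch of that schedule, using plain numpy and hypothetical compute/start_*/finish_* callables standing in for pyop3's generated kernel and star-forest calls:

    import numpy as np

    CORE, ROOT, LEAF = 0, 1, 2  # mirror the IntEnum values introduced in this commit

    def run_loop(labels, compute,
                 start_reduction, finish_reduction,
                 start_broadcast, finish_broadcast):
        # labels[i] classifies iteration point i as CORE, ROOT or LEAF
        core = np.flatnonzero(labels == CORE)
        root = np.flatnonzero(labels == ROOT)
        leaf = np.flatnonzero(labels == LEAF)

        start_reduction()            # leaves -> roots, non-blocking
        for p in core:               # touches no shared dofs, always safe
            compute(p)
        finish_reduction()
        start_broadcast()            # roots -> leaves, non-blocking
        for p in root:               # owned shared dofs are now up to date
            compute(p)
        finish_broadcast()
        for p in leaf:               # ghost dofs are now up to date
            compute(p)

The diff below introduces these labels twice, as ArrayPointLabel for array dofs and IterationPointType for iteration points, and threads the resulting core/root/leaf subsets through Loop.__call__.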
4 changes: 4 additions & 0 deletions pyop3/axtree/tree.py
@@ -1078,6 +1078,8 @@ def restore(self):
def index(self) -> LoopIndex:
from pyop3.itree import LoopIndex

# TODO
# return LoopIndex(self.owned)
return LoopIndex(self)

@property
@@ -1306,6 +1308,8 @@ def __getitem__(self, indices) -> ContextSensitiveAxisTree:
def index(self) -> LoopIndex:
from pyop3.itree import LoopIndex

# TODO
# return LoopIndex(self.owned)
return LoopIndex(self)

@cached_property
105 changes: 67 additions & 38 deletions pyop3/itree/tree.py
@@ -1339,6 +1339,18 @@ def iter_axis_tree(
# yield path_, indices_


class ArrayPointLabel(enum.IntEnum):
CORE = 0
ROOT = 1
LEAF = 2


class IterationPointType(enum.IntEnum):
CORE = 0
ROOT = 1
LEAF = 2


# TODO This should work for multiple loop indices. One should really pass a loop expression.
def partition_iterset(index: LoopIndex, arrays):
"""Split an iteration set into core, root and leaf index sets.
@@ -1381,23 +1393,24 @@ def partition_iterset(index: LoopIndex, arrays):
continue

# take first
array_paraxes = [
axis for axis in array.orig_array.axes.nodes if axis.sf is not None
]

array_paraxis = array_paraxes[0]
sf = array_paraxis.sf
# array_paraxes = [
# axis for axis in array.orig_array.axes.nodes if axis.sf is not None
# ]
#
# array_paraxis = array_paraxes[0]
# sf = array_paraxis.sf
sf = array.orig_array.axes.sf # the dof sf

# mark leaves and roots
is_root_or_leaf = np.full(sf.size, False, dtype=bool)
is_root_or_leaf[sf.iroot] = True
is_root_or_leaf[sf.ileaf] = True
is_root_or_leaf = np.full(sf.size, ArrayPointLabel.CORE, dtype=np.uint8)
is_root_or_leaf[sf.iroot] = ArrayPointLabel.ROOT
is_root_or_leaf[sf.ileaf] = ArrayPointLabel.LEAF

# do this because we need to think of the indices here as a selector
# rather than a map. We need to transform to the new numbering, hence we
# need to apply the map default -> reordered, but the indexing semantics
# are the opposite of this
is_root_or_leaf = is_root_or_leaf[list(array_paraxis.numbering)]
# is_root_or_leaf = is_root_or_leaf[array_paraxis.numbering]
# this is equivalent to:
# new_labels = np.empty_like(labels)
# for i, l in enumerate(labels):
@@ -1407,12 +1420,11 @@ def partition_iterset(index: LoopIndex, arrays):

is_root_or_leaf_per_array[array.name] = is_root_or_leaf

is_core = np.full(paraxis.size, True, dtype=bool)
labels = np.full(paraxis.size, IterationPointType.CORE, dtype=np.uint8)
for path, target_path, indices, target_indices in index.iter():
parindex = indices[paraxis.label]
assert isinstance(parindex, numbers.Integral)

# replace_map = freeze({(index.id, axis): i for axis, i in indices.items()})
replace_map = freeze(
{(index.id, axis): i for axis, i in target_indices.items()}
)
@@ -1421,7 +1433,7 @@ def partition_iterset(index: LoopIndex, arrays):
# skip purely local arrays
if not array.orig_array.array.is_distributed:
continue
if not is_core[parindex]:
if labels[parindex] == IterationPointType.LEAF:
continue

# loop over stencil
@@ -1432,39 +1444,56 @@ def partition_iterset(index: LoopIndex, arrays):
array_indices,
array_target_indices,
) in array.axes.index().iter(replace_map):
allexprs = dict(array.axes.index_exprs.get(None, {}))
if not array.axes.is_empty:
for myaxis, mycpt in array.axes.path_with_nodes(
*array.axes._node_from_path(array_path)
).items():
allexprs.update(array.axes.index_exprs[myaxis.id, mycpt])
# allexprs = dict(array.axes.index_exprs.get(None, {}))
# if not array.axes.is_empty:
# for myaxis, mycpt in array.axes.path_with_nodes(
# *array.axes._node_from_path(array_path)
# ).items():
# allexprs.update(array.axes.index_exprs[myaxis.id, mycpt])
#
offset = array.axes.offset(array_path, array_indices | replace_map)

# allexprs is indexed with the "source" labels but we want a particular
# "target" label, need to go backwards... or something
if len(target_path) != 1:
raise NotImplementedError
target_parallel_axis_label = just_one(target_path.keys())
the_expr_i_want = allexprs[target_parallel_axis_label]

pt_index = pym.evaluate(
the_expr_i_want,
replace_map | array_indices,
ExpressionEvaluator,
)
assert isinstance(pt_index, numbers.Integral)

if is_root_or_leaf_per_array[array.name][pt_index]:
is_core[parindex] = False
# no point doing more analysis
break
# if len(target_path) != 1:
# raise NotImplementedError
# target_parallel_axis_label = just_one(target_path.keys())
# the_expr_i_want = allexprs[target_parallel_axis_label]
#
# # but this is for a particular component!! need to map component index to
# # "full" one, how? or just do offset?
# pt_index = pym.evaluate(
# the_expr_i_want,
# replace_map | array_indices,
# ExpressionEvaluator,
# )
# print_if_rank(1, "ptindex", pt_index)
# assert isinstance(pt_index, numbers.Integral)

# point_label = is_root_or_leaf_per_array[array.name][pt_index]
point_label = is_root_or_leaf_per_array[array.name][offset]
print_if_rank(1, "ptlabel", point_label)
if point_label == ArrayPointLabel.LEAF:
labels[parindex] = IterationPointType.LEAF
break # no point doing more analysis
elif point_label == ArrayPointLabel.ROOT:
assert labels[parindex] != IterationPointType.LEAF
labels[parindex] = IterationPointType.ROOT
else:
assert point_label == ArrayPointLabel.CORE
pass

parcpt = just_one(paraxis.components) # for now

core = just_one(np.nonzero(is_core))
noncore = just_one(np.nonzero(np.logical_not(is_core)))
print_with_rank("arrayper", is_root_or_leaf_per_array)
print_with_rank("labels", labels)

core = just_one(np.nonzero(labels == IterationPointType.CORE))
root = just_one(np.nonzero(labels == IterationPointType.ROOT))
leaf = just_one(np.nonzero(labels == IterationPointType.LEAF))

subsets = []
for data in [core, noncore]:
for data in [core, root, leaf]:
# Constant?
size = Dat(AxisTree(), data=np.asarray([len(data)]), dtype=IntType)
subset = Dat(
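
The labelling pass in partition_iterset boils down to taking, for every iteration point, the worst label over all dofs of all distributed arrays it touches: LEAF (a ghost dof is accessed) dominates ROOT (an owned-but-shared dof is accessed), which dominates CORE. A rough standalone equivalent, in which label_points, dof_labels and accessed_dofs are hypothetical stand-ins for the indexing and offset machinery above:

    import numpy as np

    CORE, ROOT, LEAF = 0, 1, 2

    def label_points(npoints, arrays):
        """arrays: list of (dof_labels, accessed_dofs) pairs, where dof_labels[d]
        is the CORE/ROOT/LEAF label of dof d and accessed_dofs(p) yields the dofs
        touched at iteration point p."""
        labels = np.full(npoints, CORE, dtype=np.uint8)
        for p in range(npoints):
            for dof_labels, accessed_dofs in arrays:
                for d in accessed_dofs(p):
                    labels[p] = max(labels[p], dof_labels[d])
                    if labels[p] == LEAF:
                        break  # no point doing more analysis for this point
                if labels[p] == LEAF:
                    break
        return labels

Turning the labels into index sets is then just np.flatnonzero(labels == CORE) and friends, much as the diff does with np.nonzero and just_one.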
82 changes: 56 additions & 26 deletions pyop3/lang.py
@@ -141,31 +141,44 @@ def __call__(self, **kwargs):

if self.is_parallel:
# interleave computation and communication
new_index, (icore, inoncore) = partition_iterset(
new_index, (icore, iroot, ileaf) = partition_iterset(
self.index, [a for a, _ in self.all_function_arguments]
)

print_with_rank("icore", icore.data)
print_with_rank("iroot", iroot.data)
print_with_rank("ileaf", ileaf.data)

assert self.index.id == new_index.id

# substitute subsets into loopexpr, should maybe be done in partition_iterset
parallel_loop = self.copy(index=new_index)
code = compile(parallel_loop)

# interleave communication and computation
with self._updates_in_flight():
with self._updates_in_flight(0):
# replace the parallel axis subset with one for the specific indices here
extent = just_one(icore.axes.root.components).count
core_kwargs = merge_dicts(
[kwargs, {icore.name: icore, extent.name: extent}]
)
code(**core_kwargs)

# noncore
noncore_extent = just_one(inoncore.axes.root.components).count
noncore_kwargs = merge_dicts(
[kwargs, {icore.name: inoncore, extent.name: noncore_extent}]
# roots
with self._updates_in_flight(1):
# replace the parallel axis subset with one for the specific indices here
root_extent = just_one(iroot.axes.root.components).count
root_kwargs = merge_dicts(
[kwargs, {icore.name: iroot, extent.name: root_extent}]
)
code(**root_kwargs)

# leaves
leaf_extent = just_one(ileaf.axes.root.components).count
leaf_kwargs = merge_dicts(
[kwargs, {icore.name: ileaf, extent.name: leaf_extent}]
)
code(**noncore_kwargs)
code(**leaf_kwargs)

# also may need to eagerly assemble Mats, or be clever?
else:
@@ -230,6 +243,9 @@ def _array_updates(self):
# core entities (in the iterset) are defined as being those that do
# not overlap with any points in the star forest.

# NOTE: The following is now slightly out-of-date. We now distinguish
# *per array* and also split each generation into starts and finalizers.

# TODO update this comment to account for different threading models

# Since we sometimes have to do a reduce and then a broadcast the messages
@@ -253,23 +269,33 @@ def _array_updates(self):
if intent in {READ, RW}:
if touches_ghost_points:
if not array._roots_valid:
messages[array][0].append(array._reduce_leaves_to_roots_begin)
messages[array][1].extend(
# 2-tuple of inits (bad name) and finalizers
messages[array][0] = (
# init
[array._reduce_leaves_to_roots_begin],
# finalizer
[
array._reduce_leaves_to_roots_end,
array._broadcast_roots_to_leaves_begin,
]
],
)
messages[array][1] = (
[],
[array._broadcast_roots_to_leaves_end],
)
messages[array][-1].append(array._broadcast_roots_to_leaves_end)
else:
messages[array][0].append(
array._broadcast_roots_to_leaves_begin
messages[array][0] = (
[array._broadcast_roots_to_leaves_begin],
[],
)
messages[array][1] = (
[],
[array._broadcast_roots_to_leaves_end],
)
messages[array][-1].append(array._broadcast_roots_to_leaves_end)
else:
if not array._roots_valid:
messages[array][0].append(array.reduce_leaves_to_roots_begin)
messages[array][-1].append(array.reduce_leaves_to_roots_end)
messages[array][0] = ([array.reduce_leaves_to_roots_begin], [])
messages[array][1] = ([], [array.reduce_leaves_to_roots_end])

elif intent == WRITE:
# Assumes that all points are written to (i.e. not a subset). If
@@ -292,8 +318,8 @@ def _array_updates(self):
# explained in the documentation.
if intent in {INC, MIN_RW, MAX_RW}:
assert array._pending_reduction is not None
messages[array][0].append(array.reduce_leaves_to_roots_begin)
messages[array][-1].append(array.reduce_leaves_to_roots_end)
messages[array][0] = ([array.reduce_leaves_to_roots_begin], [])
messages[array][1] = ([], [array.reduce_leaves_to_roots_end])

# We are modifying owned values so the leaves must now be wrong
array._leaves_valid = False
@@ -320,11 +346,13 @@ def _updates_in_flight(self):
return messages

@contextlib.contextmanager
def _updates_in_flight(self):
def _updates_in_flight(self, generation):
"""Context manager for interleaving computation and communication."""
sendrecvs = self._array_updates()

if config["thread_model"] in {"SINGLE", "SERIALIZED"}:
raise NotImplementedError("Untested")
# now out-of-date
# cannot do a thread per-array, compress the sendrecvs
new_sendrecvs = defaultdict(list)
for sendrecvs_per_array in sendrecvs.values():
@@ -334,6 +362,7 @@ def _updates_in_flight(self):

# TODO make an enum
if config["thread_model"] == "SINGLE":
raise NotImplementedError("Untested")
# multithreading is not supported, do all of the sendrecvs apart
# from the final generation eagerly
ngenerations = len(sendrecvs) - 1
@@ -342,6 +371,7 @@ def _updates_in_flight(self):
sendrecv()
finalizers = messages[-1]
elif config["thread_model"] == "SERIALIZED":
raise NotImplementedError("Untested")
# ghost exchanges can only be done on a single thread
thread = threading.Thread(
target=self.__class__._sendrecv, args=(sendrecvs,)
@@ -352,8 +382,9 @@ def _updates_in_flight(self):
# different thread per array
finalizers = []
for sendrecvs_per_array in sendrecvs.values():
inits, finalizers_ = sendrecvs_per_array[generation]
thread = threading.Thread(
target=self.__class__._sendrecv, args=(sendrecvs_per_array,)
target=self.__class__._sendrecv, args=(inits, finalizers_)
)
thread.start()
finalizers.append(thread.join)
@@ -367,12 +398,11 @@ def _updates_in_flight(self):
f()

@staticmethod
def _sendrecv(messages):
# loop over generations starting from 0 and ending with -1
ngenerations = len(messages) - 1
for gen in [*range(ngenerations), -1]:
for msg in messages[gen]:
msg()
def _sendrecv(inits, finalizers):
for msg in inits:
msg()
for fin in finalizers:
fin()


# TODO singledispatch
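
The reworked _array_updates groups the exchanges for each array into generations, each stored as a 2-tuple of begin calls ("inits") and finalizers, and _updates_in_flight(generation) launches one thread per array for the requested generation so that the enclosed computation overlaps those exchanges. A rough sketch of that shape under the per-array threading model, with sendrecv, updates_in_flight and run_computation as hypothetical stand-ins rather than pyop3's actual methods:

    import threading

    def sendrecv(inits, finalizers):
        # mirrors Loop._sendrecv: kick off the non-blocking exchanges, then wait
        for begin in inits:
            begin()
        for end in finalizers:
            end()

    def updates_in_flight(messages, generation, run_computation):
        """messages: {array_name: [(inits, finalizers), ...]}, one tuple per
        generation; run_computation() is the work to overlap with the exchange."""
        threads = []
        for per_array in messages.values():
            inits, finalizers = per_array[generation]
            t = threading.Thread(target=sendrecv, args=(inits, finalizers))
            t.start()
            threads.append(t)
        run_computation()    # e.g. the core, root or leaf sub-loop for this stage
        for t in threads:
            t.join()         # communication for this generation is now complete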