diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index e232a45a..1096e385 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -1078,6 +1078,8 @@ def restore(self): def index(self) -> LoopIndex: from pyop3.itree import LoopIndex + # TODO + # return LoopIndex(self.owned) return LoopIndex(self) @property @@ -1306,6 +1308,8 @@ def __getitem__(self, indices) -> ContextSensitiveAxisTree: def index(self) -> LoopIndex: from pyop3.itree import LoopIndex + # TODO + # return LoopIndex(self.owned) return LoopIndex(self) @cached_property diff --git a/pyop3/itree/tree.py b/pyop3/itree/tree.py index e721c7cf..8f483ee5 100644 --- a/pyop3/itree/tree.py +++ b/pyop3/itree/tree.py @@ -1339,6 +1339,18 @@ def iter_axis_tree( # yield path_, indices_ +class ArrayPointLabel(enum.IntEnum): + CORE = 0 + ROOT = 1 + LEAF = 2 + + +class IterationPointType(enum.IntEnum): + CORE = 0 + ROOT = 1 + LEAF = 2 + + # TODO This should work for multiple loop indices. One should really pass a loop expression. def partition_iterset(index: LoopIndex, arrays): """Split an iteration set into core, root and leaf index sets. @@ -1381,23 +1393,24 @@ def partition_iterset(index: LoopIndex, arrays): continue # take first - array_paraxes = [ - axis for axis in array.orig_array.axes.nodes if axis.sf is not None - ] - - array_paraxis = array_paraxes[0] - sf = array_paraxis.sf + # array_paraxes = [ + # axis for axis in array.orig_array.axes.nodes if axis.sf is not None + # ] + # + # array_paraxis = array_paraxes[0] + # sf = array_paraxis.sf + sf = array.orig_array.axes.sf # the dof sf # mark leaves and roots - is_root_or_leaf = np.full(sf.size, False, dtype=bool) - is_root_or_leaf[sf.iroot] = True - is_root_or_leaf[sf.ileaf] = True + is_root_or_leaf = np.full(sf.size, ArrayPointLabel.CORE, dtype=np.uint8) + is_root_or_leaf[sf.iroot] = ArrayPointLabel.ROOT + is_root_or_leaf[sf.ileaf] = ArrayPointLabel.LEAF # do this because we need to think of the indices here as a selector # rather than a map. We need to transform to the new numbering, hence we # need to apply the map default -> reordered, but the indexing semantics # are the opposite of this - is_root_or_leaf = is_root_or_leaf[list(array_paraxis.numbering)] + # is_root_or_leaf = is_root_or_leaf[array_paraxis.numbering] # this is equivalent to: # new_labels = np.empty_like(labels) # for i, l in enumerate(labels): @@ -1407,12 +1420,11 @@ def partition_iterset(index: LoopIndex, arrays): is_root_or_leaf_per_array[array.name] = is_root_or_leaf - is_core = np.full(paraxis.size, True, dtype=bool) + labels = np.full(paraxis.size, IterationPointType.CORE, dtype=np.uint8) for path, target_path, indices, target_indices in index.iter(): parindex = indices[paraxis.label] assert isinstance(parindex, numbers.Integral) - # replace_map = freeze({(index.id, axis): i for axis, i in indices.items()}) replace_map = freeze( {(index.id, axis): i for axis, i in target_indices.items()} ) @@ -1421,7 +1433,7 @@ def partition_iterset(index: LoopIndex, arrays): # skip purely local arrays if not array.orig_array.array.is_distributed: continue - if not is_core[parindex]: + if labels[parindex] == IterationPointType.LEAF: continue # loop over stencil @@ -1432,39 +1444,56 @@ def partition_iterset(index: LoopIndex, arrays): array_indices, array_target_indices, ) in array.axes.index().iter(replace_map): - allexprs = dict(array.axes.index_exprs.get(None, {})) - if not array.axes.is_empty: - for myaxis, mycpt in array.axes.path_with_nodes( - *array.axes._node_from_path(array_path) - ).items(): - allexprs.update(array.axes.index_exprs[myaxis.id, mycpt]) + # allexprs = dict(array.axes.index_exprs.get(None, {})) + # if not array.axes.is_empty: + # for myaxis, mycpt in array.axes.path_with_nodes( + # *array.axes._node_from_path(array_path) + # ).items(): + # allexprs.update(array.axes.index_exprs[myaxis.id, mycpt]) + # + offset = array.axes.offset(array_path, array_indices | replace_map) # allexprs is indexed with the "source" labels but we want a particular # "target" label, need to go backwards... or something - if len(target_path) != 1: - raise NotImplementedError - target_parallel_axis_label = just_one(target_path.keys()) - the_expr_i_want = allexprs[target_parallel_axis_label] - - pt_index = pym.evaluate( - the_expr_i_want, - replace_map | array_indices, - ExpressionEvaluator, - ) - assert isinstance(pt_index, numbers.Integral) - - if is_root_or_leaf_per_array[array.name][pt_index]: - is_core[parindex] = False - # no point doing more analysis - break + # if len(target_path) != 1: + # raise NotImplementedError + # target_parallel_axis_label = just_one(target_path.keys()) + # the_expr_i_want = allexprs[target_parallel_axis_label] + # + # # but this is for a particular component!! need to map component index to + # # "full" one, how? or just do offset? + # pt_index = pym.evaluate( + # the_expr_i_want, + # replace_map | array_indices, + # ExpressionEvaluator, + # ) + # print_if_rank(1, "ptindex", pt_index) + # assert isinstance(pt_index, numbers.Integral) + + # point_label = is_root_or_leaf_per_array[array.name][pt_index] + point_label = is_root_or_leaf_per_array[array.name][offset] + print_if_rank(1, "ptlabel", point_label) + if point_label == ArrayPointLabel.LEAF: + labels[parindex] = IterationPointType.LEAF + break # no point doing more analysis + elif point_label == ArrayPointLabel.ROOT: + assert labels[parindex] != IterationPointType.LEAF + labels[parindex] = IterationPointType.ROOT + else: + assert point_label == ArrayPointLabel.CORE + pass parcpt = just_one(paraxis.components) # for now - core = just_one(np.nonzero(is_core)) - noncore = just_one(np.nonzero(np.logical_not(is_core))) + print_with_rank("arrayper", is_root_or_leaf_per_array) + print_with_rank("labels", labels) + + core = just_one(np.nonzero(labels == IterationPointType.CORE)) + root = just_one(np.nonzero(labels == IterationPointType.ROOT)) + leaf = just_one(np.nonzero(labels == IterationPointType.LEAF)) subsets = [] - for data in [core, noncore]: + for data in [core, root, leaf]: # Constant? size = Dat(AxisTree(), data=np.asarray([len(data)]), dtype=IntType) subset = Dat( diff --git a/pyop3/lang.py b/pyop3/lang.py index 9d6b7fc7..8b98d463 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -141,10 +141,14 @@ def __call__(self, **kwargs): if self.is_parallel: # interleave computation and communication - new_index, (icore, inoncore) = partition_iterset( + new_index, (icore, iroot, ileaf) = partition_iterset( self.index, [a for a, _ in self.all_function_arguments] ) + print_with_rank("icore", icore.data) + print_with_rank("iroot", iroot.data) + print_with_rank("ileaf", ileaf.data) + assert self.index.id == new_index.id # substitute subsets into loopexpr, should maybe be done in partition_iterset @@ -152,7 +156,7 @@ def __call__(self, **kwargs): code = compile(parallel_loop) # interleave communication and computation - with self._updates_in_flight(): + with self._updates_in_flight(0): # replace the parallel axis subset with one for the specific indices here extent = just_one(icore.axes.root.components).count core_kwargs = merge_dicts( @@ -160,12 +164,21 @@ def __call__(self, **kwargs): ) code(**core_kwargs) - # noncore - noncore_extent = just_one(inoncore.axes.root.components).count - noncore_kwargs = merge_dicts( - [kwargs, {icore.name: inoncore, extent.name: noncore_extent}] + # roots + with self._updates_in_flight(1): + # replace the parallel axis subset with one for the specific indices here + root_extent = just_one(iroot.axes.root.components).count + root_kwargs = merge_dicts( + [kwargs, {icore.name: iroot, extent.name: root_extent}] + ) + code(**root_kwargs) + + # leaves + leaf_extent = just_one(ileaf.axes.root.components).count + leaf_kwargs = merge_dicts( + [kwargs, {icore.name: ileaf, extent.name: leaf_extent}] ) - code(**noncore_kwargs) + code(**leaf_kwargs) # also may need to eagerly assemble Mats, or be clever? else: @@ -230,6 +243,9 @@ def _array_updates(self): # core entities (in the iterset) are defined as being those that do # not overlap with any points in the star forest. + # NOTE: The following is now slightly out-of-date. We now distinguish + # *per array* and also split each generation into starts and finalizers. + # TODO update this comment to account for different threading models # Since we sometimes have to do a reduce and then a broadcast the messages @@ -253,23 +269,33 @@ def _array_updates(self): if intent in {READ, RW}: if touches_ghost_points: if not array._roots_valid: - messages[array][0].append(array._reduce_leaves_to_roots_begin) - messages[array][1].extend( + # 2-tuple of inits (bad name) and finalizers + messages[array][0] = ( + # init + [array._reduce_leaves_to_roots_begin], + # finalizer [ array._reduce_leaves_to_roots_end, array._broadcast_roots_to_leaves_begin, - ] + ], + ) + messages[array][1] = ( + [], + [array._broadcast_roots_to_leaves_end], ) - messages[array][-1].append(array._broadcast_roots_to_leaves_end) else: - messages[array][0].append( - array._broadcast_roots_to_leaves_begin + messages[array][0] = ( + [array._broadcast_roots_to_leaves_begin], + [], + ) + messages[array][1] = ( + [], + [array._broadcast_roots_to_leaves_end], ) - messages[array][-1].append(array._broadcast_roots_to_leaves_end) else: if not array._roots_valid: - messages[array][0].append(array.reduce_leaves_to_roots_begin) - messages[array][-1].append(array.reduce_leaves_to_roots_end) + messages[array][0] = ([array.reduce_leaves_to_roots_begin], []) + messages[array][1] = ([], [array.reduce_leaves_to_roots_end]) elif intent == WRITE: # Assumes that all points are written to (i.e. not a subset). If @@ -292,8 +318,8 @@ def _array_updates(self): # explained in the documentation. if intent in {INC, MIN_RW, MAX_RW}: assert array._pending_reduction is not None - messages[array][0].append(array.reduce_leaves_to_roots_begin) - messages[array][-1].append(array.reduce_leaves_to_roots_end) + messages[array][0] = ([array.reduce_leaves_to_roots_begin], []) + messages[array][1] = ([], [array.reduce_leaves_to_roots_end]) # We are modifying owned values so the leaves must now be wrong array._leaves_valid = False @@ -320,11 +346,13 @@ def _array_updates(self): return messages @contextlib.contextmanager - def _updates_in_flight(self): + def _updates_in_flight(self, generation): """Context manager for interleaving computation and communication.""" sendrecvs = self._array_updates() if config["thread_model"] in {"SINGLE", "SERIALIZED"}: + raise NotImplementedError("Untested") + # now out-of-date # cannot do a thread per-array, compress the sendrecvs new_sendrecvs = defaultdict(list) for sendrecvs_per_array in sendrecvs.values(): @@ -334,6 +362,7 @@ def _updates_in_flight(self): # TODO make an enum if config["thread_model"] == "SINGLE": + raise NotImplementedError("Untested") # multithreading is not supported, do all of the sendrecvs apart # from the final generation eagerly ngenerations = len(sendrecvs) - 1 @@ -342,6 +371,7 @@ def _updates_in_flight(self): sendrecv() finalizers = messages[-1] elif config["thread_model"] == "SERIALIZED": + raise NotImplementedError("Untested") # ghost exchanges can only be done on a single thread thread = threading.Thread( target=self.__class__._sendrecv, args=(sendrecvs,) @@ -352,8 +382,9 @@ def _updates_in_flight(self): # different thread per array finalizers = [] for sendrecvs_per_array in sendrecvs.values(): + inits, finalizers_ = sendrecvs_per_array[generation] thread = threading.Thread( - target=self.__class__._sendrecv, args=(sendrecvs_per_array,) + target=self.__class__._sendrecv, args=(inits, finalizers_) ) thread.start() finalizers.append(thread.join) @@ -367,12 +398,11 @@ def _updates_in_flight(self): f() @staticmethod - def _sendrecv(messages): - # loop over generations starting from 0 and ending with -1 - ngenerations = len(messages) - 1 - for gen in [*range(ngenerations), -1]: - for msg in messages[gen]: - msg() + def _sendrecv(inits, finalizers): + for msg in inits: + msg() + for fin in finalizers: + fin() # TODO singledispatch diff --git a/tests/integration/test_parallel_loops.py b/tests/integration/test_parallel_loops.py index 33974c56..2e88b150 100644 --- a/tests/integration/test_parallel_loops.py +++ b/tests/integration/test_parallel_loops.py @@ -7,6 +7,7 @@ import pyop3 as op3 from pyop3.extras.debug import print_with_rank from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET +from pyop3.utils import just_one def set_kernel(size, intent): @@ -212,16 +213,29 @@ def test_parallel_loop_with_map(comm, mesh_axis, cone_map, scalar_copy_kernel): knl = set_kernel(2, intent) + # since we don't unpick "owned" for indexed axes yet + loop_index = mesh_axis.axes.freeze().owned["cells"].index() + op3.do_loop( - c := mesh_axis["cells"].index(), + # c := mesh_axis["cells"].index(), + c := loop_index, knl(rank_dat, dat[cone_map(c)]), ) - # for the cone of an interval mesh (and before reductions) we expect interior - # vertices to be touched twice and exterior vertices to be touched once - nverts = mesh_axis.components[1].count - assert np.count_nonzero(dat.array._data == write_value * 2) == nverts - 2 - assert np.count_nonzero(dat.array._data == write_value) == 2 + # we now expect the (renumbered) values to look like + # 1 0 2 0 1 * 0 0 + # [rank 0] x-----x-----x * -----x + # * + # [rank 1] x * -----x-----x-----x-----x + # 2 * 0 4 0 4 0 4 0 2 + if comm.rank == 0: + assert np.count_nonzero(dat.array._data == 0) == 4 + assert np.count_nonzero(dat.array._data == 1) == 2 + assert np.count_nonzero(dat.array._data == 2) == 1 + else: + assert np.count_nonzero(dat.array._data == 0) == 4 + assert np.count_nonzero(dat.array._data == 2) == 2 + assert np.count_nonzero(dat.array._data == 4) == 3 # there should be a pending reduction assert dat.array._pending_reduction == intent @@ -229,28 +243,49 @@ def test_parallel_loop_with_map(comm, mesh_axis, cone_map, scalar_copy_kernel): assert not dat.array._leaves_valid # now do the reduction - # dat.data_ro - print_with_rank("before", dat.array._data) dat.array._reduce_leaves_to_roots() - - print_with_rank("after", dat.array._data) - assert dat.array._pending_reduction is None assert dat.array._roots_valid # leaves are still not up-to-date, requires a broadcast assert not dat.array._leaves_valid - # NOTE: This demonstrates an issue with my current implementation. We really - # want an SF for each loop that we do because if we only modify a subset of values - # (here just the vertices) then we still do a reduction on the cells even if they - # weren't changed. - - # Both ranks have a single exterior ghost vertex that touches an interior vertex - # on the other rank. Therefore we expect one owned value to have a value of - # interior_value_on_current_rank + exterior_value_on_other_rank. - assert np.count_nonzero(dat.array._data == write_value * 2 + other_write_value) == 1 - assert np.count_nonzero(dat.array._data == write_value) == 2 - assert np.count_nonzero(dat.array._data == write_value * 2) == nverts - 3 + # we now expect the (renumbered) values to look like + # 1 0 2 0 3 * 0 0 + # [rank 0] x-----x-----x * -----x + # * + # [rank 1] x * -----x-----x-----x-----x + # 2 * 0 4 0 4 0 4 0 2 + if comm.rank == 0: + assert np.count_nonzero(dat.array._data == 0) == 4 + assert np.count_nonzero(dat.array._data == 1) == 1 + assert np.count_nonzero(dat.array._data == 2) == 1 + assert np.count_nonzero(dat.array._data == 3) == 1 + else: + assert np.count_nonzero(dat.array._data == 0) == 4 + assert np.count_nonzero(dat.array._data == 2) == 2 + assert np.count_nonzero(dat.array._data == 4) == 3 + + # now broadcast to leaves + dat.array._broadcast_roots_to_leaves() + assert dat.array._leaves_valid + + # we now expect the (renumbered) values to look like + # 1 0 2 0 3 * 0 4 + # [rank 0] x-----x-----x * -----x + # * + # [rank 1] x * -----x-----x-----x-----x + # 3 * 0 4 0 4 0 4 0 2 + if comm.rank == 0: + assert np.count_nonzero(dat.array._data == 0) == 3 + assert np.count_nonzero(dat.array._data == 1) == 1 + assert np.count_nonzero(dat.array._data == 2) == 1 + assert np.count_nonzero(dat.array._data == 3) == 1 + assert np.count_nonzero(dat.array._data == 4) == 1 + else: + assert np.count_nonzero(dat.array._data == 0) == 4 + assert np.count_nonzero(dat.array._data == 2) == 1 + assert np.count_nonzero(dat.array._data == 3) == 1 + assert np.count_nonzero(dat.array._data == 4) == 3 @pytest.mark.parallel(nprocs=2) diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py index 7a9669f8..fbf1c841 100644 --- a/tests/unit/test_parallel.py +++ b/tests/unit/test_parallel.py @@ -175,27 +175,20 @@ def test_partition_iterset_scalar(comm, paxis, with_ghosts): p = paxis.index() tmp = array[p] - _, (icore, inoncore) = partition_iterset(p, [tmp]) + _, (icore, iroot, ileaf) = partition_iterset(p, [tmp]) if comm.rank == 0: - # from [0, 1, 3, 2, 4, 5] and knowing that ... - # this is so confusing - # basically for this case the numbering is such that the root entities - # come before the core ones. Ghost will always be the final entries because that - # is how we do the numbering in the first place. expected_icore = [2, 3] - expected_inoncore = [0, 1] - if with_ghosts: - expected_inoncore += [4, 5] + expected_iroot = [0, 1] + expected_ileaf = [4, 5] if with_ghosts else [] else: assert comm.rank == 1 - # numbering = [0, 4, 1, 2, 5, 3] expected_icore = [0, 1] - expected_inoncore = [2, 3] - if with_ghosts: - expected_inoncore += [4, 5] + expected_iroot = [2, 3] + expected_ileaf = [4, 5] if with_ghosts else [] assert np.equal(icore.data_ro, expected_icore).all() - assert np.equal(inoncore.data_ro, expected_inoncore).all() + assert np.equal(iroot.data_ro, expected_iroot).all() + assert np.equal(ileaf.data_ro, expected_ileaf).all() @pytest.mark.parallel(nprocs=2) @@ -237,18 +230,17 @@ def test_partition_iterset_with_map(comm, paxis, with_ghosts): else: p = paxis.index() tmp = array[map0(p)] - _, (icore, inoncore) = partition_iterset(p, [tmp]) + _, (icore, iroot, ileaf) = partition_iterset(p, [tmp]) if comm.rank == 0: expected_icore = [3] - expected_inoncore = [0, 1, 2] - if with_ghosts: - expected_inoncore += [4, 5] + expected_iroot = [1, 2] + expected_ileaf = [0, 4, 5] if with_ghosts else [0] else: assert comm.rank == 1 expected_icore = [0] - expected_inoncore = [1, 2, 3] - if with_ghosts: - expected_inoncore += [4, 5] + expected_iroot = [1, 2] + expected_ileaf = [3, 4, 5] if with_ghosts else [3] assert np.equal(icore.data_ro, expected_icore).all() - assert np.equal(inoncore.data_ro, expected_inoncore).all() + assert np.equal(iroot.data_ro, expected_iroot).all() + assert np.equal(ileaf.data_ro, expected_ileaf).all()