From 052fba6d0f5705d3bc20e197ae6c79a689eaccd0 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 8 Dec 2023 13:52:02 +0000 Subject: [PATCH 1/2] Remove old code, tests passing --- pyop3/array/harray.py | 204 +--------------------------- pyop3/axtree/layout.py | 32 ----- pyop3/axtree/parallel.py | 19 ++- pyop3/axtree/tree.py | 286 --------------------------------------- pyop3/ir/lower.py | 4 +- 5 files changed, 21 insertions(+), 524 deletions(-) diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index 60c4e2fd..bfe8dc64 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -106,14 +106,8 @@ def __init__( ): super().__init__(name=name, prefix=prefix) - # TODO This is ugly - # temporary_axes = as_axis_tree(axes).freeze() # used for the temporary - # previously layout_axes - # drop index_exprs... axes = as_axis_tree(axes) - # axes = as_layout_axes(axes) - if isinstance(data, Buffer): # disable for now, temporaries hit this in an annoying way # if data.sf is not axes.sf: @@ -139,9 +133,7 @@ def __init__( self.buffer = data # instead implement "materialize" - # self.temporary_axes = temporary_axes self.axes = axes - self.layout_axes = axes # used? likely don't need all these self.max_value = max_value @@ -167,7 +159,7 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: loop_contexts = collect_loop_contexts(indices) if not loop_contexts: - index_tree = just_one(as_index_forest(indices, axes=self.layout_axes)) + index_tree = just_one(as_index_forest(indices, axes=self.axes)) ( indexed_axes, target_path_per_indexed_cpt, @@ -530,197 +522,3 @@ def datamap(self): def _shared_attr(self, attr: str): return single_valued(getattr(a, attr) for a in self.context_map.values()) - - -def replace_layout(orig_layout, replace_map): - return IndexExpressionReplacer(replace_map)(orig_layout) - - -def as_layout_axes(axes: AxisTree) -> AxisTree: - # drop index exprs, everything else drops out - return AxisTree( - axes.parent_to_children, - axes.target_paths, - axes._default_index_exprs(), - axes.layout_exprs, - axes.layouts, - sf=axes.sf, - ) - - -def make_sparsity( - iterindex, - lmap, - rmap, - llabels=PrettyTuple(), - rlabels=PrettyTuple(), - lindices=PrettyTuple(), - rindices=PrettyTuple(), -): - if iterindex: - if iterindex.children: - raise NotImplementedError( - "Need to think about what to do when we have more complicated " - "iteration sets that have multiple indices (e.g. 
extruded cells)" - ) - - if not isinstance(iterindex, Range): - raise NotImplementedError( - "Need to think about whether maps are reasonable here" - ) - - if not is_single_valued(idx.id for idx in [iterindex, lmap, rmap]): - raise ValueError("Indices must share common roots") - - sparsity = collections.defaultdict(set) - for i in range(iterindex.size): - subsparsity = make_sparsity( - None, - lmap.child, - rmap.child, - llabels | iterindex.label, - rlabels | iterindex.label, - lindices | i, - rindices | i, - ) - for labels, indices in subsparsity.items(): - sparsity[labels].update(indices) - return sparsity - elif lmap: - if not isinstance(lmap, TabulatedMap): - raise NotImplementedError("Need to think about other index types") - if len(lmap.children) not in [0, 1]: - raise NotImplementedError("Need to think about maps forking") - - new_labels = list(llabels) - # first pop the old things - for lbl in lmap.from_labels: - if lbl != new_labels[-1]: - raise ValueError("from_labels must match existing labels") - new_labels.pop() - # then append the new ones - only do the labels here, indices are - # done inside the loop - new_labels.extend(lmap.to_labels) - new_labels = PrettyTuple(new_labels) - - sparsity = collections.defaultdict(set) - for i in range(lmap.size): - new_indices = PrettyTuple([lmap.data.get_value(lindices | i)]) - subsparsity = make_sparsity( - None, lmap.child, rmap, new_labels, rlabels, new_indices, rindices - ) - for labels, indices in subsparsity.items(): - sparsity[labels].update(indices) - return sparsity - elif rmap: - if not isinstance(rmap, TabulatedMap): - raise NotImplementedError("Need to think about other index types") - if len(rmap.children) not in [0, 1]: - raise NotImplementedError("Need to think about maps forking") - - new_labels = list(rlabels) - # first pop the old labels - for lbl in rmap.from_labels: - if lbl != new_labels[-1]: - raise ValueError("from_labels must match existing labels") - new_labels.pop() - # then append the new ones - new_labels.extend(rmap.to_labels) - new_labels = PrettyTuple(new_labels) - - sparsity = collections.defaultdict(set) - for i in range(rmap.size): - new_indices = PrettyTuple([rmap.data.get_value(rindices | i)]) - subsparsity = make_sparsity( - None, lmap, rmap.child, llabels, new_labels, lindices, new_indices - ) - for labels, indices in subsparsity.items(): - sparsity[labels].update(indices) - return sparsity - else: - # at the bottom, record an entry - # return {(llabels, rlabels): {(lindices, rindices)}} - # TODO: For now assume single values for each of these - llabel, rlabel = map(single_valued, [llabels, rlabels]) - lindex, rindex = map(single_valued, [lindices, rindices]) - return {(llabel, rlabel): {(lindex, rindex)}} - - -def distribute_sparsity(sparsity, ax1, ax2, owner="row"): - if any(ax.nparts > 1 for ax in [ax1, ax2]): - raise NotImplementedError("Only dealing with single-part multi-axes for now") - - # how many points need to get sent to other processes? - # how many points do I get from other processes? 
- new_sparsity = collections.defaultdict(set) - points_to_send = collections.defaultdict(set) - for lindex, rindex in sparsity[ax1.part.label, ax2.part.label]: - if owner == "row": - olabel = ax1.part.overlap[lindex] - if is_owned_by_process(olabel): - new_sparsity[ax1.part.label, ax2.part.label].add((lindex, rindex)) - else: - points_to_send[olabel.root.rank].add( - (ax1.part.lgmap[lindex], ax2.part.lgmap[rindex]) - ) - else: - raise NotImplementedError - - # send points - - # first determine how many new points we are getting from each rank - comm = single_valued([ax1.sf.comm, ax2.sf.comm]).tompi4py() - npoints_to_send = np.array( - [len(points_to_send[rank]) for rank in range(comm.size)], dtype=IntType - ) - npoints_to_recv = np.empty_like(npoints_to_send) - comm.Alltoall(npoints_to_send, npoints_to_recv) - - # communicate the offsets back - from_offsets = np.cumsum(npoints_to_recv) - to_offsets = np.empty_like(from_offsets) - comm.Alltoall(from_offsets, to_offsets) - - # now send the globally numbered row, col values for each point that - # needs to be sent. This is easiest with an SF. - - # nroots is the number of points to send - nroots = sum(npoints_to_send) - local_points = None # contiguous storage - - idx = 0 - remote_points = [] - for rank in range(comm.size): - for i in range(npoints_to_recv[rank]): - remote_points.extend([rank, to_offsets[idx]]) - idx += 1 - - sf = PETSc.SF().create(comm) - sf.setGraph(nroots, local_points, remote_points) - - # create a buffer to hold the new values - # x2 since we are sending row and column numbers - new_points = np.empty(sum(npoints_to_recv) * 2, dtype=IntType) - rootdata = np.array( - [ - num - for rank in range(comm.size) - for lnum, rnum in points_to_send[rank] - for num in [lnum, rnum] - ], - dtype=new_points.dtype, - ) - - mpi_dtype, _ = get_mpi_dtype(np.dtype(IntType)) - mpi_op = MPI.REPLACE - args = (mpi_dtype, rootdata, new_points, mpi_op) - sf.bcastBegin(*args) - sf.bcastEnd(*args) - - for i in range(sum(npoints_to_recv)): - new_sparsity[ax1.part.label, ax2.part.label].add( - (new_points[2 * i], new_points[2 * i + 1]) - ) - - # import pdb; pdb.set_trace() - return new_sparsity diff --git a/pyop3/axtree/layout.py b/pyop3/axtree/layout.py index 7641f4e5..1fda7011 100644 --- a/pyop3/axtree/layout.py +++ b/pyop3/axtree/layout.py @@ -114,38 +114,6 @@ def step_size( return 1 -def make_star_forest_per_axis_part(part, comm): - if part.is_distributed: - # we have a root if a point is shared but doesn't point to another rank - nroots = len( - [pt for pt in part.overlap if isinstance(pt, Shared) and not pt.root] - ) - - # which local points are leaves? - local_points = [ - i for i, pt in enumerate(part.overlap) if not is_owned_by_process(pt) - ] - - # roots of other processes (rank, index) - remote_points = utils.flatten( - [pt.root.as_tuple() for pt in part.overlap if not is_owned_by_process(pt)] - ) - - # import pdb; pdb.set_trace() - - sf = PETSc.SF().create(comm) - sf.setGraph(nroots, local_points, remote_points) - return sf - else: - raise NotImplementedError( - "Need to think about concatenating star forests. This will happen if mixed." 
- ) - - -def attach_owned_star_forest(axis): - raise NotImplementedError - - def has_halo(axes, axis): if axis.sf is not None: return True diff --git a/pyop3/axtree/parallel.py b/pyop3/axtree/parallel.py index 850b9172..68542326 100644 --- a/pyop3/axtree/parallel.py +++ b/pyop3/axtree/parallel.py @@ -1,16 +1,33 @@ from __future__ import annotations +import functools + import numpy as np from mpi4py import MPI from petsc4py import PETSc from pyrsistent import pmap from pyop3.axtree.layout import _as_int, _axis_component_size, step_size -from pyop3.dtypes import IntType, get_mpi_dtype +from pyop3.dtypes import IntType, as_numpy_dtype, get_mpi_dtype from pyop3.extras.debug import print_with_rank from pyop3.utils import checked_zip, just_one, strict_int +def reduction_op(op, invec, inoutvec, datatype): + dtype = as_numpy_dtype(datatype) + invec = np.frombuffer(invec, dtype=dtype) + inoutvec = np.frombuffer(inoutvec, dtype=dtype) + inoutvec[:] = op(invec, inoutvec) + + +_contig_min_op = MPI.Op.Create( + functools.partial(reduction_op, np.minimum), commute=True +) +_contig_max_op = MPI.Op.Create( + functools.partial(reduction_op, np.maximum), commute=True +) + + def partition_ghost_points(axis, sf): npoints = sf.size is_owned = np.full(npoints, True, dtype=bool) diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index e7582de5..326f2b1a 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -199,254 +199,6 @@ def map_called_map(self, expr): return array.get_value(path, indices) -def get_bottom_part(axis): - # must be linear - return just_one(axis.leaves) - - -def as_multiaxis(axis): - if isinstance(axis, MultiAxis): - return axis - elif isinstance(axis, AxisPart): - return MultiAxis(axis) - else: - raise TypeError - - -# def is_set_up(axtree, axis=None): -# """Return ``True`` if all parts (recursively) of the multi-axis have an associated -# layout function. -# """ -# axis = axis or axtree.root -# return all( -# part_is_set_up(axtree, axis, cpt, cidx) -# for cidx, cpt in enumerate(axis.components) -# ) - - -# # this would be an easy place to start with writing a tree visitor instead -# def part_is_set_up(axtree, axis, cpt): -# if (subaxis := axtree.child(axis, cpt)) and not is_set_up( -# axtree, subaxis -# ): -# return False -# if (axis.id, component_index) not in axtree._layouts: -# return False -# return True - - -def has_halo(axes, axis): - if axis.sf is not None: - return True - else: - for component in axis.components: - subaxis = axes.component_child(axis, component) - if subaxis and has_halo(axes, subaxis): - return True - return False - return axis.sf is not None or has_halo(axes, subaxis) - - -def has_independently_indexed_subaxis_parts(axes, axis, cpt): - """ - subaxis parts are independently indexed if they don't depend on the index from - ``part``. - - if one sub-part needs this index to determine its extent then we need to create - a layout function as the step sizes will differ. 
- - Note that we need to consider both ragged sizes and permutations here - """ - if subaxis := axes.component_child(axis, cpt): - return not any( - requires_external_index(axes, subaxis, c) for c in subaxis.components - ) - else: - return True - - -def only_linear(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - if not self.is_linear: - raise RuntimeError(f"{func.__name__} only admits linear multi-axes") - return func(self, *args, **kwargs) - - return wrapper - - -def can_be_affine(axtree, axis, component, component_index): - return ( - has_independently_indexed_subaxis_parts( - axtree, axis, component, component_index - ) - and component.permutation is None - ) - - -def has_constant_start( - axtree, axis, component, component_index, outer_axes_are_all_indexed: bool -): - """ - We will have an affine layout with a constant start (usually zero) if either we are not - ragged or if we are ragged but everything above is indexed (i.e. a temporary). - """ - assert can_be_affine(axtree, axis, component, component_index) - return isinstance(component.count, numbers.Integral) or outer_axes_are_all_indexed - - -def has_fixed_size(axes, axis, component): - return not size_requires_external_index(axes, axis, component) - - -def step_size( - axes: AxisTree, - axis: Axis, - component: AxisComponent, - path=pmap(), - indices=PrettyTuple(), -): - """Return the size of step required to stride over a multi-axis component. - - Non-constant strides will raise an exception. - """ - if not has_constant_step(axes, axis, component) and not indices: - raise ValueError - if subaxis := axes.component_child(axis, component): - return _axis_size(axes, subaxis, path, indices) - else: - return 1 - - -def make_star_forest_per_axis_part(part, comm): - if part.is_distributed: - # we have a root if a point is shared but doesn't point to another rank - nroots = len( - [pt for pt in part.overlap if isinstance(pt, Shared) and not pt.root] - ) - - # which local points are leaves? - local_points = [ - i for i, pt in enumerate(part.overlap) if not is_owned_by_process(pt) - ] - - # roots of other processes (rank, index) - remote_points = utils.flatten( - [pt.root.as_tuple() for pt in part.overlap if not is_owned_by_process(pt)] - ) - - # import pdb; pdb.set_trace() - - sf = PETSc.SF().create(comm) - sf.setGraph(nroots, local_points, remote_points) - return sf - else: - raise NotImplementedError( - "Need to think about concatenating star forests. This will happen if mixed." 
- ) - - -def attach_owned_star_forest(axis): - raise NotImplementedError - - -@dataclasses.dataclass -class RemotePoint: - rank: numbers.Integral - index: numbers.Integral - - def as_tuple(self): - return (self.rank, self.index) - - -@dataclasses.dataclass -class PointOverlapLabel(abc.ABC): - pass - - -@dataclasses.dataclass -class Owned(PointOverlapLabel): - pass - - -@dataclasses.dataclass -class Shared(PointOverlapLabel): - root: Optional[RemotePoint] = None - - -@dataclasses.dataclass -class Halo(PointOverlapLabel): - root: RemotePoint - - -def is_owned_by_process(olabel): - return isinstance(olabel, Owned) or isinstance(olabel, Shared) and not olabel.root - - -# --------------------- \/ lifted from halo.py \/ ------------------------- - - -from pyop3.dtypes import as_numpy_dtype - - -def reduction_op(op, invec, inoutvec, datatype): - dtype = as_numpy_dtype(datatype) - invec = np.frombuffer(invec, dtype=dtype) - inoutvec = np.frombuffer(inoutvec, dtype=dtype) - inoutvec[:] = op(invec, inoutvec) - - -_contig_min_op = MPI.Op.Create( - functools.partial(reduction_op, np.minimum), commute=True -) -_contig_max_op = MPI.Op.Create( - functools.partial(reduction_op, np.maximum), commute=True -) - -# --------------------- ^ lifted from halo.py ^ ------------------------- - - -class PointLabel(abc.ABC): - """Container associating points in an :class:`AxisPart` with a enumerated label.""" - - -# TODO: Maybe could make this a little more descriptive a la star forest so we could -# then automatically generate an SF for the multi-axis. -class PointOwnershipLabel(PointLabel): - """Label indicating parallel point ownership semantics (i.e. owned or halo).""" - - # TODO: Write a factory function/constructor that takes advantage of the fact that - # the majority of the points are OWNED and there are only two options so a set is - # an efficient choice of data structure. - def __init__(self, owned_points, halo_points): - owned_set = set(owned_points) - halo_set = set(halo_points) - - if len(owned_set) != len(owned_points) or len(halo_set) != len(halo_points): - raise ValueError("Labels cannot contain duplicate values") - if owned_set.intersection(halo_set): - raise ValueError("Points cannot appear with different values") - - self._owned_points = owned_points - self._halo_points = halo_points - - def __len__(self): - return len(self._owned_points) + len(self._halo_points) - - -# this isn't really a thing I should be caring about - it's just a multi-axis! -class Sparsity: - def __init__(self, maps): - if isinstance(maps, collections.abc.Sequence): - rmap, cmap = maps - else: - rmap, cmap = maps, maps - - ... 
- - raise NotImplementedError - - def _collect_datamap(axis, *subdatamaps, axes): from pyop3.array import HierarchicalArray @@ -487,10 +239,6 @@ class AxisComponent(LabelledNodeComponent): fields = LabelledNodeComponent.fields | { "count", - "overlap", - "indexed", - "indices", - "lgmap", } def __init__( @@ -499,37 +247,11 @@ def __init__( label=None, *, indices=None, - overlap=None, indexed=False, lgmap=None, ): super().__init__(label=label) self.count = count - self.indices = indices - self.overlap = overlap - self.indexed = indexed - self.lgmap = lgmap - """ - this property is required because we can hit situations like the following: - - sizes = 3 -> [2, 1, 2] -> [[2, 1], [1], [3, 2]] - - this yields a layout that looks like - - [[0, 2], [3], [4, 7]] - - however, if we have a temporary where we only want the inner two dimensions - then we need a layout that looks like the following: - - [[0, 2], [0], [0, 3]] - - This effectively means that we need to zero the offset as we traverse the - tree to produce the layout. This is why we need this ``indexed`` flag. - """ - - @property - def is_distributed(self): - return self.overlap is not None @property def has_integer_count(self): @@ -552,14 +274,6 @@ def alloc_size(self, axtree, axis): else: return npoints - @property - def has_partitioned_halo(self): - if self.overlap is None: - return True - - remaining = itertools.dropwhile(lambda o: is_owned_by_process(o), self.overlap) - return all(isinstance(o, Halo) for o in remaining) - class Axis(MultiComponentLabelledNode, LoopIterable): fields = MultiComponentLabelledNode.fields | { diff --git a/pyop3/ir/lower.py b/pyop3/ir/lower.py index 1d2d7f00..8e8c41d3 100644 --- a/pyop3/ir/lower.py +++ b/pyop3/ir/lower.py @@ -901,7 +901,7 @@ def make_array_expr(array, layouts, path, jnames, ctx): def make_temp_expr(temporary, shape, path, jnames, ctx): - layout = temporary.layout_axes.layouts[path] + layout = temporary.axes.layouts[path] temp_offset = make_offset_expr( layout, jnames, @@ -1138,7 +1138,7 @@ def _scalar_assignment( ctx.add_argument(array) offset_expr = make_offset_expr( - array.layout_axes.layouts[path], + array.layouts[path], array_labels_to_jnames, ctx, ) From 328973ec3f300efaf5c4497031abb61b0b3b07b3 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 8 Dec 2023 14:09:42 +0000 Subject: [PATCH 2/2] More cleanup, tests passing --- pyop3/array/harray.py | 5 - pyop3/axtree/parallel.py | 8 - pyop3/axtree/tree.py | 93 +------- pyop3/buffer.py | 1 - pyop3/ir/lower.py | 45 ---- pyop3/itree/tree.py | 114 --------- pyop3/lang.py | 1 - pyop3/sf.py | 5 - pyop3/space.py | 280 ----------------------- tests/integration/test_parallel_loops.py | 1 - tests/unit/test_axis_ordering_old.py | 119 ---------- tests/unit/test_distarray.py | 1 - tests/unit/test_sparsity_old.py | 271 ---------------------- 13 files changed, 1 insertion(+), 943 deletions(-) delete mode 100644 pyop3/space.py delete mode 100644 tests/unit/test_axis_ordering_old.py delete mode 100644 tests/unit/test_sparsity_old.py diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index bfe8dc64..110262e9 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -36,7 +36,6 @@ ) from pyop3.buffer import Buffer, DistributedBuffer from pyop3.dtypes import IntType, ScalarType, get_mpi_dtype -from pyop3.extras.debug import print_if_rank, print_with_rank from pyop3.itree import IndexTree, as_index_forest, index_axes from pyop3.itree.tree import CalledMapVariable, collect_loop_indices, iter_axis_tree from pyop3.lang import KernelArgument @@ 
-342,14 +341,10 @@ def offset(self, *args, allow_unused=False, insert_zeros=False): return strict_int(offset) def simple_offset(self, path, indices): - print_if_rank(0, "self.layouts", self.layouts) - print_if_rank(0, "path", path) - print_if_rank(0, "indices", indices) offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator) return strict_int(offset) def iter_indices(self, outer_map): - print_with_rank(0, "myiexpr!!!!!!!!!!!!!!!!!!", self.index_exprs) return iter_axis_tree(self.axes, self.target_paths, self.index_exprs, outer_map) def _with_axes(self, axes): diff --git a/pyop3/axtree/parallel.py b/pyop3/axtree/parallel.py index 68542326..f4bce59a 100644 --- a/pyop3/axtree/parallel.py +++ b/pyop3/axtree/parallel.py @@ -9,7 +9,6 @@ from pyop3.axtree.layout import _as_int, _axis_component_size, step_size from pyop3.dtypes import IntType, as_numpy_dtype, get_mpi_dtype -from pyop3.extras.debug import print_with_rank from pyop3.utils import checked_zip, just_one, strict_int @@ -131,8 +130,6 @@ def grow_dof_sf(axes, axis, path, indices): ) root_offsets[pt] = offset - print_with_rank("root offsets before", root_offsets) - point_sf.broadcast(root_offsets, MPI.REPLACE) # for sanity reasons remove the original root values from the buffer @@ -175,9 +172,4 @@ def grow_dof_sf(axes, axis, path, indices): remote_leaf_dof_offsets[counter] = [rank, root_offsets[pos] + d] counter += 1 - print_with_rank("root offsets: ", root_offsets) - print_with_rank("local leaf offsets", local_leaf_offsets) - print_with_rank("local dof offsets: ", local_leaf_dof_offsets) - print_with_rank("remote offsets: ", remote_leaf_dof_offsets) - return (nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets) diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index 326f2b1a..1784a13f 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -26,7 +26,6 @@ from pyop3 import utils from pyop3.dtypes import IntType, PointerType, get_mpi_dtype -from pyop3.extras.debug import print_if_rank, print_with_rank from pyop3.sf import StarForest from pyop3.tree import ( LabelledNodeComponent, @@ -187,15 +186,9 @@ def map_called_map(self, expr): # the inner_expr tells us the right mapping for the temporary, however, # for maps that are arrays the innermost axis label does not always match # the label used by the temporary. Therefore we need to do a swap here. - # I don't like this. - # print_if_rank(0, repr(array.axes)) - # print_if_rank(0, "before: ",indices) inner_axis = array.axes.leaf_axis indices[inner_axis.label] = indices.pop(expr.function.full_map.name) - # print_if_rank(0, "after:",indices) - # print_if_rank(0, repr(expr)) - # print_if_rank(0, self.context) return array.get_value(path, indices) @@ -580,55 +573,6 @@ def add_node( parent_cpt_label = _as_axis_component_label(parent_component) return super().add_node(axis, parent, parent_cpt_label, **kwargs) - # alias - add_subaxis = add_node - - # currently untested but should keep - @classmethod - def from_layout(cls, layout: Sequence[ConstrainedMultiAxis]) -> Any: # TODO - return order_axes(layout) - - # TODO this is just a regular tree search - @deprecated(internal=True) # I think? 
- def get_part_from_path(self, path, axis=None): - axis = axis or self.root - - label, *sublabels = path - - (component, component_index) = just_one( - [ - (cpt, cidx) - for cidx, cpt in enumerate(axis.components) - if (axis.label, cidx) == label - ] - ) - if sublabels: - return self.get_part_from_path( - sublabels, self.component_child(axis, component) - ) - else: - return axis, component - - @deprecated(internal=True) - def drop_last(self): - """Remove the last subaxis""" - if not self.part.subaxis: - return None - else: - return self.copy( - parts=[self.part.copy(subaxis=self.part.subaxis.drop_last())] - ) - - @property - @deprecated(internal=True) - def is_linear(self): - """Return ``True`` if the multi-axis contains no branches at any level.""" - if self.nparts == 1: - return self.part.subaxis.is_linear if self.part.subaxis else True - else: - return False - - @deprecated() def add_subaxis(self, subaxis, *loc): return self.add_node(subaxis, *loc) @@ -657,8 +601,6 @@ class AxisTree(PartialAxisTree, Indexed, ContextFreeLoopIterable): "target_paths", "index_exprs", "layout_exprs", - "layouts", - "sf", } def __init__( @@ -667,7 +609,6 @@ def __init__( target_paths=None, index_exprs=None, layout_exprs=None, - sf=None, ): if some_but_not_all( arg is None for arg in [target_paths, index_exprs, layout_exprs] @@ -678,7 +619,6 @@ def __init__( self._target_paths = target_paths or self._default_target_paths() self._index_exprs = index_exprs or self._default_index_exprs() self.layout_exprs = layout_exprs or self._default_layout_exprs() - self.sf = sf or self._default_sf() def __getitem__(self, indices): from pyop3.itree.tree import as_index_forest, collect_loop_contexts, index_axes @@ -762,7 +702,7 @@ def layouts(self): @cached_property def sf(self): - return cls._default_sf(tree) + return self._default_sf() @cached_property def datamap(self): @@ -771,17 +711,6 @@ def datamap(self): else: dmap = postvisit(self, _collect_datamap, axes=self) - # for cleverdict in [self.layouts, self.orig_layout_fn]: - # for layout in cleverdict.values(): - # for layout_expr in layout.values(): - # # catch invalid layouts - # if isinstance(layout_expr, pym.primitives.NaN): - # continue - # for array in MultiArrayCollector()(layout_expr): - # dmap.update(array.datamap) - - # TODO - # for cleverdict in [self.index_exprs, self.layout_exprs]: for cleverdict in [self.index_exprs]: for exprs in cleverdict.values(): for expr in exprs.values(): @@ -939,26 +868,6 @@ def datamap(self): return merge_dicts(axes.datamap for axes in self.context_map.values()) -@dataclasses.dataclass(frozen=True) -class Path: - # TODO Make a persistent dict? 
- from_axes: Tuple[Any] # axis part IDs I guess (or labels) - to_axess: Tuple[Any] # axis part IDs I guess (or labels) - arity: int - selector: Optional[Any] = None - """The thing that chooses between the different possible output axes at runtime.""" - - @property - def degree(self): - return len(self.to_axess) - - @property - def to_axes(self): - if self.degree != 1: - raise RuntimeError("Only for degree 1 paths") - return self.to_axess[0] - - @functools.singledispatch def as_axis_tree(arg: Any): from pyop3.array import HierarchicalArray # cyclic import diff --git a/pyop3/buffer.py b/pyop3/buffer.py index 32a91d6e..2bc741c4 100644 --- a/pyop3/buffer.py +++ b/pyop3/buffer.py @@ -8,7 +8,6 @@ from mpi4py import MPI from pyop3.dtypes import ScalarType -from pyop3.extras.debug import print_if_rank from pyop3.lang import KernelArgument from pyop3.utils import UniqueNameGenerator, as_tuple, deprecated, readonly diff --git a/pyop3/ir/lower.py b/pyop3/ir/lower.py index 8e8c41d3..a388dda9 100644 --- a/pyop3/ir/lower.py +++ b/pyop3/ir/lower.py @@ -29,7 +29,6 @@ from pyop3.axtree.tree import ContextSensitiveAxisTree from pyop3.buffer import DistributedBuffer, PackedBuffer from pyop3.dtypes import IntType, PointerType -from pyop3.extras.debug import print_with_rank from pyop3.itree import ( AffineSliceComponent, CalledMap, @@ -442,11 +441,6 @@ def parse_loop_properly_this_time( # these aren't jnames! my_index_exprs = axes.index_exprs.get((axis.id, component.label), {}) - print_with_rank("myindexexprs", my_index_exprs) - print_with_rank("new_iname_rplacemap", new_iname_replace_map) - print_with_rank("jname_replace_map", jname_replace_map) - print_with_rank("outerreplac", outer_replace_map) - jname_extras = {} for axis_label, index_expr in my_index_exprs.items(): jname_expr = JnameSubstitutor( @@ -854,10 +848,6 @@ def array_expr(): array_ = array.with_context(context) return make_array_expr( array, - # I think... - # not calling substitute layouts from above so loop indices not - # present in the layout... - # subst_layout(axes, source_path, target_path), array_.layouts[target_path], target_path, iname_replace_map | jname_replace_map, @@ -919,14 +909,6 @@ def make_temp_expr(temporary, shape, path, jnames, ctx): return pym.subscript(pym.var(temporary.name), extra_indices + (temp_offset_var,)) -def subst_layout(axes, source_path, target_path): - replace_map = {} - for axis, cpt in axes.detailed_path(source_path).items(): - replace_map.update(axes.layout_exprs[axis.id, cpt]) - - return IndexExpressionReplacer(replace_map)(axes.layouts[target_path]) - - class JnameSubstitutor(pym.mapper.IdentityMapper): def __init__(self, replace_map, codegen_context): self._labels_to_jnames = replace_map @@ -1117,17 +1099,6 @@ def map_variable(self, expr): return self._replace_map.get(expr.name, expr) -def collect_arrays(expr: pym.primitives.Expr): - collector = MultiArrayCollector() - return collector(expr) - - -def replace_variables( - expr: pym.primitives.Expr, replace_map: dict[str, pym.primitives.Variable] -): - return VariableReplacer(replace_map)(expr) - - def _scalar_assignment( array, path, @@ -1146,22 +1117,6 @@ def _scalar_assignment( return rexpr -def find_axis(axes, path, target, current_axis=None): - """Return the axis matching ``target`` along ``path``. - - ``path`` is a mapping between axis labels and the selected component indices. 
- """ - current_axis = current_axis or axes.root - - if current_axis.label == target: - return current_axis - else: - subaxis = axes.child(current_axis, path[current_axis.label]) - if not subaxis: - assert False, "oops" - return find_axis(axes, path, target, subaxis) - - def context_from_indices(loop_indices): loop_context = {} for loop_index, (path, _) in loop_indices.items(): diff --git a/pyop3/itree/tree.py b/pyop3/itree/tree.py index de77f957..240a5d2f 100644 --- a/pyop3/itree/tree.py +++ b/pyop3/itree/tree.py @@ -35,7 +35,6 @@ PartialAxisTree, ) from pyop3.dtypes import IntType, get_mpi_dtype -from pyop3.extras.debug import print_if_rank, print_with_rank from pyop3.tree import LabelledTree, Node, Tree, postvisit from pyop3.utils import ( Identified, @@ -56,33 +55,21 @@ def __init__(self, replace_map): self._replace_map = replace_map def map_axis_variable(self, expr): - # print_if_rank(0, "replace map ", self._replace_map) - # return self._replace_map[expr.axis_label] return self._replace_map.get(expr.axis_label, expr) def map_multi_array(self, expr): from pyop3.array.harray import MultiArrayVariable - # print_if_rank(0, self._replace_map) - # print_if_rank(0, expr.indices) indices = {axis: self.rec(index) for axis, index in expr.indices.items()} return MultiArrayVariable(expr.array, indices) def map_called_map(self, expr): array = expr.function.map_component.array - # should the following only exist at eval time? - # the inner_expr tells us the right mapping for the temporary, however, # for maps that are arrays the innermost axis label does not always match # the label used by the temporary. Therefore we need to do a swap here. - # I don't like this. - # inner_axis = array.axes.leaf_axis - # print_if_rank(0, self._replace_map) - # print_if_rank(0, expr.parameters) indices = {axis: self.rec(idx) for axis, idx in expr.parameters.items()} - # indices[inner_axis.label] = indices.pop(expr.function.full_map.name) - return CalledMapVariable(expr.function, indices) def map_loop_index(self, expr): @@ -90,13 +77,6 @@ def map_loop_index(self, expr): return self._replace_map.get((expr.name, expr.axis), expr) -# just use a pmap for this -# class IndexForest: -# def __init__(self, trees: Mapping[Mapping, IndexTree]): -# # per loop context -# self.trees = trees - - # index trees are different to axis trees because we know less about # the possible attaching components. 
In particular a CalledMap can # have different "attaching components"/output components depending on @@ -148,9 +128,6 @@ def parse_parent_to_children(parent_to_children, loop_context, parent=None): return pmap() -IndexLabel = collections.namedtuple("IndexLabel", ["axis", "component"]) - - class DatamapCollector(pym.mapper.CombineMapper): def combine(self, values): return merge_dicts(values) @@ -366,7 +343,6 @@ def index(self) -> LoopIndex: target_paths, index_exprs, layout_exprs, - axes.layouts, ) context_sensitive_axes = ContextSensitiveAxisTree({context: axes}) @@ -1029,7 +1005,6 @@ def _(called_map: CalledMap, **kwargs): ) = _make_leaf_axis_from_called_map( called_map, prior_target_path, prior_index_exprs ) - # axes = axes.add_node(subaxis, prior_leaf_axis, prior_leaf_cpt) axes = axes.add_subtree( PartialAxisTree(subaxis), prior_leaf_axis, prior_leaf_cpt ) @@ -1219,10 +1194,6 @@ def _compose_bits( if iaxis is None: target_path |= itarget_paths.get(None, {}) partial_index_exprs |= iindex_exprs.get(None, {}) - # partial_layout_exprs |= ilayout_exprs.get(None, {}) - - # no idea why I put this line here - # visited_target_axes = visited_target_axes.union(target_path.keys()) iaxis = indexed_axes.root target_path_per_cpt = collections.defaultdict(dict) @@ -1247,8 +1218,6 @@ def _compose_bits( ] # if target_path is "complete" then do stuff, else pass responsibility to next func down - # index_exprs[iaxis.id, icpt.label] = {} - # layout_exprs[iaxis.id, icpt.label] = {} new_visited_target_axes = visited_target_axes if axes.is_valid_path(new_target_path): detailed_path = axes.detailed_path(new_target_path) @@ -1287,7 +1256,6 @@ def _compose_bits( new_index_exprs_acc = new_index_exprs_acc | { axis_label: new_index_expr } - # new_partial_index_exprs = pmap() # now do the layout expressions, this is simpler since target path magic isnt needed # compose layout expressions, this does an *outside* substitution @@ -1344,11 +1312,7 @@ def _compose_bits( else: pass - # assert not skip huh? 
- # assert not new_partial_index_exprs - # assert not new_partial_layout_exprs - # breakpoint() return ( freeze(dict(target_path_per_cpt)), freeze(dict(index_exprs)), @@ -1375,10 +1339,6 @@ def iter_axis_tree( new_exprs = {} for axlabel, index_expr in myindex_exprs.items(): new_index = ExpressionEvaluator(outermap)(index_expr) - print_with_rank("initialrepr", repr(index_expr)) - print_with_rank("replacedrepr", repr(new_index)) - print_with_rank("initial", index_expr) - print_with_rank("replaced", new_index) assert new_index != index_expr new_exprs[axlabel] = new_index index_exprs_acc = freeze(new_exprs) @@ -1400,14 +1360,9 @@ def iter_axis_tree( new_index = ExpressionEvaluator(outermap | indices | {axis.label: pt})( index_expr ) - print_with_rank("initialrepr", repr(index_expr)) - print_with_rank("replacedrepr", repr(new_index)) - print_with_rank("initial", index_expr) - print_with_rank("replaced", new_index) assert new_index != index_expr new_exprs[axlabel] = new_index index_exprs_ = index_exprs_acc | new_exprs - # index_exprs_ = index_exprs | myindex_exprs indices_ = indices | {axis.label: pt} if subaxis: yield from iter_axis_tree( @@ -1422,9 +1377,7 @@ def iter_axis_tree( index_exprs_, ) else: - # yield path_, index_exprs_, indices_ yield path_, target_path_, indices_, index_exprs_ - # yield path_, indices_ class ArrayPointLabel(enum.IntEnum): @@ -1464,7 +1417,6 @@ def partition_iterset(index: LoopIndex, arrays): from pyop3.array import HierarchicalArray # take first - # paraxis = [axis for axis in index.iterset.nodes if axis.sf is not None][0] if index.iterset.depth > 1: raise NotImplementedError("Need a good way to sniff the parallel axis") paraxis = index.iterset.root @@ -1480,13 +1432,6 @@ def partition_iterset(index: LoopIndex, arrays): if not array.array.is_distributed: continue - # take first - # array_paraxes = [ - # axis for axis in array.orig_array.axes.nodes if axis.sf is not None - # ] - # - # array_paraxis = array_paraxes[0] - # sf = array_paraxis.sf sf = array.array.sf # the dof sf # mark leaves and roots @@ -1494,18 +1439,6 @@ def partition_iterset(index: LoopIndex, arrays): is_root_or_leaf[sf.iroot] = ArrayPointLabel.ROOT is_root_or_leaf[sf.ileaf] = ArrayPointLabel.LEAF - # do this because we need to think of the indices here as a selector - # rather than a map. 
We need to transform to the new numbering, hence we - # need to apply the map default -> reordered, but the indexing semantics - # are the opposite of this - # is_root_or_leaf = is_root_or_leaf[array_paraxis.numbering] - # this is equivalent to: - # new_labels = np.empty_like(labels) - # for i, l in enumerate(labels): - # j = array_paraxis._inverse_numbering[i] - # new_labels[j] = l - # labels = new_labels - is_root_or_leaf_per_array[array.name] = is_root_or_leaf labels = np.full(paraxis.size, IterationPointType.CORE, dtype=np.uint8) @@ -1527,59 +1460,15 @@ def partition_iterset(index: LoopIndex, arrays): # loop over stencil array = array.with_context({index.id: target_path}) - print_with_rank("array axes", array.axes) - print_with_rank("array targetpaths", array.target_paths) - print_with_rank("array idxsexprs", array.index_exprs) for ( array_path, array_target_path, array_indices, array_target_indices, ) in array.iter_indices(replace_map): - # allexprs = dict(array.axes.index_exprs.get(None, {})) - # if not array.axes.is_empty: - # for myaxis, mycpt in array.axes.path_with_nodes( - # *array.axes._node_from_path(array_path) - # ).items(): - # allexprs.update(array.axes.index_exprs[myaxis.id, mycpt]) - # - # offset = array.axes.offset(array_path, array_indices | replace_map) - # offset = array.offset( - # array_target_path, array_target_indices | replace_map - # ) - # offset = array.simple_offset(array_path, array_indices | replace_map) - print_with_rank("array path", array_path) - print_with_rank("array idxs", array_indices) - print_with_rank("array target path", array_target_path) - print_with_rank("array target idxs", array_target_indices) - # offset = array.simple_offset(array_target_path, array_target_indices | replace_map) - - print_with_rank("myindices", array_indices | replace_map) - # offset = array.simple_offset(array_target_path, array_indices | replace_map) - # offset = array.simple_offset(array_target_path, array_target_indices | replace_map) - offset = array.simple_offset(array_target_path, array_target_indices) - # allexprs is indexed with the "source" labels but we want a particular - # "target" label, need to go backwards... or something - # if len(target_path) != 1: - # raise NotImplementedError - # target_parallel_axis_label = just_one(target_path.keys()) - # the_expr_i_want = allexprs[target_parallel_axis_label] - # - # # but this is for a particular component!! need to map component index to - # # "full" one, how? or just do offset? 
- # pt_index = pym.evaluate( - # the_expr_i_want, - # replace_map | array_indices, - # ExpressionEvaluator, - # ) - # print_if_rank(1, "ptindex", pt_index) - # assert isinstance(pt_index, numbers.Integral) - - # point_label = is_root_or_leaf_per_array[array.name][pt_index] point_label = is_root_or_leaf_per_array[array.name][offset] - print_if_rank(1, "ptlabel", point_label) if point_label == ArrayPointLabel.LEAF: labels[parindex] = IterationPointType.LEAF break # no point doing more analysis @@ -1592,9 +1481,6 @@ def partition_iterset(index: LoopIndex, arrays): parcpt = just_one(paraxis.components) # for now - print_with_rank("arrayper", is_root_or_leaf_per_array) - print_with_rank("labels", labels) - core = just_one(np.nonzero(labels == IterationPointType.CORE)) root = just_one(np.nonzero(labels == IterationPointType.ROOT)) leaf = just_one(np.nonzero(labels == IterationPointType.LEAF)) diff --git a/pyop3/lang.py b/pyop3/lang.py index b335efbb..156e99ad 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -22,7 +22,6 @@ from pyop3.axtree.tree import ContextFree, ContextSensitive, MultiArrayCollector from pyop3.config import config from pyop3.dtypes import IntType, dtype_limits -from pyop3.extras.debug import print_with_rank from pyop3.utils import as_tuple, checked_zip, just_one, merge_dicts, unique diff --git a/pyop3/sf.py b/pyop3/sf.py index 74288021..292a9bb6 100644 --- a/pyop3/sf.py +++ b/pyop3/sf.py @@ -5,7 +5,6 @@ from petsc4py import PETSc from pyop3.dtypes import get_mpi_dtype -from pyop3.extras.debug import print_with_rank from pyop3.utils import just_one @@ -85,14 +84,10 @@ def reduce(self, *args): def reduce_begin(self, *args): reduce_args = self._prepare_args(*args) self.sf.reduceBegin(*reduce_args) - print_with_rank(reduce_args) - print_with_rank("reduce begin") def reduce_end(self, *args): reduce_args = self._prepare_args(*args) self.sf.reduceEnd(*reduce_args) - print_with_rank(reduce_args) - print_with_rank("reduce end") @cached_property def _graph(self): diff --git a/pyop3/space.py b/pyop3/space.py deleted file mode 100644 index 1379d3cb..00000000 --- a/pyop3/space.py +++ /dev/null @@ -1,280 +0,0 @@ -"""This is an old file that isn't used any more.""" - -import functools -import numbers -from typing import Any, Dict, FrozenSet, Hashable, Optional - -import numpy as np -import pytools -from petsc4py import PETSc - -from pyop3.axes import Axis, AxisTree -from pyop3.axes.tree import has_independently_indexed_subaxis_parts -from pyop3.utils import as_tuple - -DEFAULT_AXIS_PRIORITY = 100 - - -class InvalidConstraintsException(Exception): - pass - - -class ConstrainedAxis(pytools.ImmutableRecord): - fields = {"axis", "priority", "within_labels"} - # TODO We could use 'label' to set the priority - # via commandline options - - def __init__( - self, - axis: Axis, - *, - priority: int = DEFAULT_AXIS_PRIORITY, - within_labels: FrozenSet[Hashable] = frozenset(), - ): - self.axis = axis - self.priority = priority - self.within_labels = frozenset(within_labels) - super().__init__() - - def __str__(self) -> str: - return f"{self.__class__.__name__}(axis=({', '.join(str(axis_cpt) for axis_cpt in self.axis)}), priority={self.priority}, within_labels={self.within_labels})" - - -class Space: - def __init__(self, mesh, layout): - # TODO mesh.axes is an iterable (could be extruded for example) - if mesh.axes.depth > 1: - raise NotImplementedError("need to unpack here somehow") - meshaxis = ConstrainedAxis(mesh.axes.root, priority=10) - axes = order_axes([meshaxis] + layout) - - self.mesh = mesh - 
self.layout = layout - self.axes = axes - - @property - def comm(self): - return self.mesh.comm - - # I don't like that this is an underscored property. I think internal_comm might be better - @property - def _comm(self): - return self.mesh._comm - - # TODO I think that this could be replaced with callbacks or something. - # DMShell supports passing a callback - # https://petsc.org/release/manualpages/DM/DMShellSetCreateGlobalVector/ - # so we could avoid allocating something here. - @functools.cached_property - def layout_vec(self): - """A PETSc Vec compatible with the dof layout of this DataSet.""" - vec = PETSc.Vec().create(comm=self.comm) - vec.setSizes((self.axes.calc_size(self.axes.root), None), bsize=1) - vec.setUp() - return vec - - @functools.cached_property - def dm(self): - from firedrake import dmhooks - from firedrake.mg.utils import get_level - - dm = PETSc.DMShell().create(comm=self._comm) - dm.setGlobalVector(self.layout_vec) - _, level = get_level(self.mesh) - - # We need to pass sf and section for preconditioners that are not - # implemented in Python (e.g. PCPATCH). For Python-level preconditioning - # we can extract all of this information from the function space. - # Since pyop3 spaces are more dynamic than a classical PETSc Vec we can only - # emit sections for "simple" structures where we have points and DoFs/point. - # TODO for "mixed" problems we could use fields in this section as well. - # it is still very fragile compared with pyop3 - # Extruded meshes are currently outlawed (axis tree depth must be 2) and to - # get them to work the numbering would need to be flattened. - # Ephemeral meshes would probably be rather helpful here. - - # this algorithm is basically equivalent to get_global_numbering - # we are allowed to do this if the inner size is always fixed. This is - # regardless of any depth considerations from vector shape etc. - if self.axes.root.label == "mesh" and all( - has_independently_indexed_subaxis_parts( - self.axes, self.axes.root, cpt, cidx - ) - for cidx, cpt in enumerate(self.axes.root.components) - ): - section = PETSc.Section().create(comm=self._comm) - section.setChart(0, self.axes.root.count) - for d in range(self.mesh.dimension + 1): - # in pyop3 points per tdim are counted from zero - # i.e. cell0, vertex0 instead of all cells being < all vertices - subaxis = self.axes.find_node((self.axes.root.id, d)) - ndofs = self.axes.calc_size(subaxis) - if ndofs > 0: - for idx, pt in enumerate(range(*self.mesh.plex.getDepthStratum(d))): - # get the offset for the zeroth sub axis element - indices = [(d, idx)] + [0] * (self.axes.depth - 1) - offset = self.axes.offset(indices) - section.setOffset(pt, offset) - section.setDof(pt, ndofs) - else: - for pt in range(*self.mesh.plex.getDepthStratum(d)): - section.setDof(pt, 0) - section.setOffset(pt, 0) - - sf = self.mesh.plex.getPointSF() - else: - section = None - sf = None - - dmhooks.attach_hooks(dm, level=level, section=section, sf=sf) - return dm - - """note the following code reproduces the same thing as the current - "global_numbering" section from Firedrake. But, I don't see why this should - be correct. I already handle the permutation when I compute the offsets. - - if self.axes.depth == 2 and self.axes.root.label == "mesh": - section = PETSc.Section().create(comm=self._comm) - section.setChart(0, self.axes.root.count) - for d in range(self.mesh.dimension + 1): - # in pyop3 points per tdim are counted from zero - # i.e. 
cell0, vertex0 instead of all cells being < all vertices - subaxis = self.axes.find_node((self.axes.root.id, d)) - dof = self.axes.calc_size(subaxis) - for pt in range(*self.mesh.plex.getDepthStratum(d)): - section.setDof(pt, dof) - else: - for pt in range(*self.mesh.plex.getDepthStratum(d)): - section.setDof(pt, 0) - - #TODO could make this a property - iset = PETSc.IS().createGeneral(self.mesh.axes.root.permutation, comm=self._comm) - section.setPermutation(iset) - section.setUp() - - sf = self.mesh.plex.getPointSF() - else: - section = None - sf = None - - dmhooks.attach_hooks(dm, level=level, section=section, sf=sf) - return dm - - """ - - -def order_axes(layout): - axes = AxisTree() - layout = list(map(as_constrained_axis, as_tuple(layout))) - axis_to_constraint = {caxis.axis.label: caxis for caxis in layout} - history = set() - while layout: - if tuple(layout) in history: - raise ValueError("Seen this before, cyclic") - history.add(tuple(layout)) - - constrained_axis = layout.pop(0) - axes, inserted = _insert_axis( - axes, constrained_axis, axes.root, axis_to_constraint - ) - if not inserted: - layout.append(constrained_axis) - return axes - - -def _insert_axis( - axes: AxisTree, - new_caxis: ConstrainedAxis, - current_axis: Axis, - axis_to_caxis: Dict[Axis, ConstrainedAxis], - path: Optional[Dict[Hashable, Dict]] = None, -): - path = path or {} - - within_labels = set(path.items()) - - # alias - remove - axis_to_constraint = axis_to_caxis - - if not axes.root: - if not new_caxis.within_labels: - axes = axes.put_node(new_caxis.axis) - return axes, True - else: - return axes, False - - # current_axis = current_axis or axes.root - current_caxis = axis_to_constraint[current_axis.label] - - if new_caxis.priority < current_caxis.priority: - raise NotImplementedError("TODO subtrees") - if new_caxis.within_labels <= within_labels: - # diagram or something? 
- parent_axis = axes.parent(current_axis) - subtree = axes.pop_subtree(current_axis) - betterid = new_caxis.axis.copy(id=next(Axis._id_generator)) - if not parent_axis: - axes.add_node(betterid) - else: - axes.add_node(betterid, path) - - # must already obey the constraints - so stick back in for all sub components - for comp in betterid.components: - stree = subtree.copy() - # stree.replace_node(stree.root.copy(id=next(MultiAxis._id_generator))) - mypath = (axes._node_to_path[betterid.id] or {}) | { - betterid.label: comp.label - } - axes.add_subtree(stree, mypath, uniquify=True) - axes._parent_and_label_to_child[(betterid, comp.label)] = stree.root.id - # need to register the right parent label - return True - else: - # The priority is less so the axes should definitely - # not be inserted below here - do not recurse - return False - else: - inserted = False - for cidx, cpt in enumerate(current_axis.components): - subaxis = axes.children(current_axis)[cidx] - if subaxis: - # axes can be unchanged - axes, now_inserted = _insert_axis( - axes, - new_caxis, - subaxis, - axis_to_constraint, - path | {current_axis.label: cidx}, - ) - else: - assert new_caxis.priority >= current_caxis.priority - if new_caxis.within_labels <= within_labels | { - (current_axis.label, cidx) - }: - # bad uniquify - betterid = new_caxis.axis.copy(id=next(Axis._id_generator)) - axes = axes.put_node(betterid, current_axis.id, cidx) - now_inserted = True - - inserted = inserted or now_inserted - return axes, inserted - - -@functools.singledispatch -def as_constrained_axis(arg: Any): - raise TypeError - - -@as_constrained_axis.register -def _(arg: ConstrainedAxis): - return arg - - -@as_constrained_axis.register -def _(arg: Axis): - return ConstrainedAxis(arg) - - -@as_constrained_axis.register -def _(arg: numbers.Integral): - return ConstrainedAxis(Axis([arg])) diff --git a/tests/integration/test_parallel_loops.py b/tests/integration/test_parallel_loops.py index 168aa4f8..cff1ea43 100644 --- a/tests/integration/test_parallel_loops.py +++ b/tests/integration/test_parallel_loops.py @@ -5,7 +5,6 @@ from pyrsistent import freeze import pyop3 as op3 -from pyop3.extras.debug import print_with_rank from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET from pyop3.utils import just_one diff --git a/tests/unit/test_axis_ordering_old.py b/tests/unit/test_axis_ordering_old.py deleted file mode 100644 index b8701caa..00000000 --- a/tests/unit/test_axis_ordering_old.py +++ /dev/null @@ -1,119 +0,0 @@ -import pytest - -from pyop3 import * - -# from pyop3.multiaxis import ConstrainedMultiAxis - - -# This file is pretty outdated -pytest.skip(allow_module_level=True) - - -def id_tuple(nodes): - return tuple(node.id for node in nodes) - - -def label_tuple(nodes): - return tuple(node.label for node in nodes) - - -def test_axis_ordering(): - axis0 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax0") - axis1 = MultiAxis([MultiAxisComponent(1, "cpt0")], "ax1") - - layout = [ConstrainedMultiAxis(axis0), ConstrainedMultiAxis(axis1)] - axes = MultiAxisTree.from_layout(layout) - - assert axes.depth == 2 - assert axes.root == axis0 - assert just_one(axes.children({"ax0": "cpt0"})).label == "ax1" - assert not axes.children({"ax0": "cpt0", "ax1": "cpt0"}) - - ### - - layout = [ConstrainedMultiAxis(axis0), ConstrainedMultiAxis(axis1, priority=0)] - axes = MultiAxisTree.from_layout(layout) - assert axes.depth == 2 - assert axes.root.label == "ax1" - assert just_one(axes.children({"ax1": "cpt0"})).label == "ax0" - assert not axes.children({"ax0": 
"cpt0", "ax1": "cpt0"}) - - -def test_multicomponent_constraints(): - axis0 = MultiAxis( - [MultiAxisComponent(3, "cpt0"), MultiAxisComponent(3, "cpt1")], "ax0" - ) - axis1 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax1") - - ### - - layout = [ - ConstrainedMultiAxis(axis0, within_labels={("ax1", "cpt0")}), - ConstrainedMultiAxis(axis1), - ] - axes = order_axes(layout) - - assert axes.depth == 2 - assert axes.root.label == "ax1" - assert just_one(axes.children({"ax1": "cpt0"})).label == "ax0" - assert not axes.children({"ax1": "cpt0", "ax0": "cpt0"}) - assert not axes.children({"ax1": "cpt0", "ax0": "cpt1"}) - - ### - - with pytest.raises(ValueError): - layout = [ - ConstrainedMultiAxis(axis0, within_labels={("ax1", "cpt0")}), - ConstrainedMultiAxis(axis1, within_labels={("ax0", "cpt0")}), - ] - order_axes(layout) - - ### - - layout = [ - ConstrainedMultiAxis(axis0), - ConstrainedMultiAxis(axis1, within_labels={("ax0", "cpt1")}), - ] - axes = order_axes(layout) - - assert axes.depth == 2 - assert axes.root.label == "ax0" - assert not axes.children({"ax0": "cpt0"}) - assert just_one(axes.children({"ax0": "cpt1"})).label == "ax1" - assert not axes.children({"ax0": "cpt1", "ax1": "cpt0"}) - - ### - - -def test_multicomponent_constraints_more(): - # ax0 - # ├──➤ cpt0 : ax1 - # │ └──➤ cpt0 - # └──➤ cpt1 : ax2 - # └──➤ cpt0 : ax1 - # └──➤ cpt0 - - axis0 = MultiAxis( - [ - MultiAxisComponent(3, "cpt0"), - MultiAxisComponent(3, "cpt1"), - ], - "ax0", - ) - axis1 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax1") - axis2 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax2") - - layout = [ - ConstrainedMultiAxis(axis0, priority=0), - ConstrainedMultiAxis(axis1, priority=20), - ConstrainedMultiAxis(axis2, within_labels={("ax0", "cpt1")}, priority=10), - ] - axes = order_axes(layout) - - assert axes.depth == 3 - assert axes.root.label == "ax0" - assert just_one(axes.children({"ax0": "cpt0"})).label == "ax1" - assert just_one(axes.children({"ax0": "cpt1"})).label == "ax2" - assert not axes.children({"ax0": "cpt0", "ax1": "cpt0"}) - assert just_one(axes.children({"ax0": "cpt1", "ax2": "cpt0"})).label == "ax1" - assert not axes.children({"ax0": "cpt1", "ax2": "cpt0", "ax1": "cpt0"}) diff --git a/tests/unit/test_distarray.py b/tests/unit/test_distarray.py index b191be10..4969ec2c 100644 --- a/tests/unit/test_distarray.py +++ b/tests/unit/test_distarray.py @@ -7,7 +7,6 @@ from petsc4py import PETSc import pyop3 as op3 -from pyop3.extras.debug import print_with_rank @pytest.fixture diff --git a/tests/unit/test_sparsity_old.py b/tests/unit/test_sparsity_old.py deleted file mode 100644 index 5c8b9f90..00000000 --- a/tests/unit/test_sparsity_old.py +++ /dev/null @@ -1,271 +0,0 @@ -import pytest - -# This file is pretty outdated -pytest.skip(allow_module_level=True) - -from pyop3 import * -from pyop3.distarray.petsc import * -from pyop3.extras.debug import print_with_rank - - -def test_read_sparse_matrix(): - """Read values from a matrix that looks like: - - 0 10 20 - 0 30 40 - 50 0 0 - - """ - nnzaxes = MultiAxis([AxisPart(3, id="p1")]).set_up() - nnz = MultiArray(nnzaxes, name="nnz", data=np.array([2, 2, 1], dtype=np.uint64)) - - indices = MultiArray.from_list( - [[1, 2], [1, 2], [0]], labels=["p1", "any"], name="indices", dtype=np.uint64 - ) - - mataxes = ( - nnzaxes.copy().add_subaxis("p1", [AxisPart(nnz, indices=indices)]).set_up() - ) - mat = MultiArray(mataxes, name="mat", data=np.arange(10, 51, 10)) - - assert mat.get_value([0, 1]) == 10 - assert mat.get_value([0, 2]) == 20 - assert mat.get_value([1, 
1]) == 30 - assert mat.get_value([1, 2]) == 40 - assert mat.get_value([2, 0]) == 50 - - -def test_read_sparse_rank_3_tensor(): - """Read values from a matrix that looks like: - - 0 A - B C - - with: - A : [10 0 20] - B : [30 40 0] - C : [50 0 60] - - - """ - ax1 = MultiAxis([AxisPart(2, id="p1")]).set_up() - nnz = MultiArray(ax1, name="nnz", data=np.array([1, 2], dtype=np.uint64)) - - indices1 = MultiArray.from_list( - [[1], [0, 1]], labels=["p1", "any"], name="indices", dtype=np.uint64 - ) - - indices2 = MultiArray.from_list( - [[[0, 2]], [[0, 1], [0, 2]]], - labels=["p1", "any1", "any2"], - name="indices", - dtype=np.uint64, - ) - - ax2 = ( - ax1.copy() - .add_subaxis("p1", [AxisPart(nnz, indices=indices1, id="p2")]) - .add_subaxis("p2", [AxisPart(2, indices=indices2)]) - ).set_up() - tensor = MultiArray(ax2, name="tensor", data=np.arange(10, 61, 10)) - - assert tensor.get_value([0, 1, 0]) == 10 - assert tensor.get_value([0, 1, 2]) == 20 - assert tensor.get_value([1, 0, 0]) == 30 - assert tensor.get_value([1, 0, 1]) == 40 - assert tensor.get_value([1, 1, 0]) == 50 - assert tensor.get_value([1, 1, 2]) == 60 - - -@pytest.fixture -def sparsity1dp1(): - """ - - The cone sparsity of the following mesh: - - v0 v1 v2 v3 - x---x---x---x - c0 c1 c2 - - should look like: - - v0 v1 v2 v3 - v0 x x - v1 x x x - v2 x x x - v3 x x - - """ - mapaxes = ( - MultiAxis([AxisPart(3, label="cells", id="cells")]).add_subaxis( - "cells", [AxisPart(2, label="any")] - ) - ).set_up() - mapdata = MultiArray( - mapaxes, name="map0", data=np.array([0, 1, 1, 2, 2, 3], dtype=IntType) - ) - - iterindex = RangeNode("cells", 3, id="i0") - lmap = rmap = iterindex.add_child( - "i0", TabulatedMapNode(["cells"], ["nodes"], arity=2, data=mapdata[[iterindex]]) - ) - - return make_sparsity(iterindex, lmap, rmap) - - -def test_make_sparsity(sparsity1dp1): - expected = { - (("nodes",), ("nodes",)): { - ((0,), (0,)), - ((0,), (1,)), - ((1,), (0,)), - ((1,), (1,)), - ((1,), (2,)), - ((2,), (1,)), - ((2,), (2,)), - ((2,), (3,)), - ((3,), (2,)), - ((3,), (3,)), - }, - } - - assert sparsity1dp1 == expected - - -def test_make_matrix(sparsity1dp1): - raxes = MultiAxis([AxisPart(4, label="nodes")]) - caxes = raxes.copy() - - mat = PetscMatAIJ(raxes, caxes, sparsity1dp1) - - import pdb - - pdb.set_trace() - - assert False - - -@pytest.mark.parallel(nprocs=2) -def test_make_parallel_matrix(): - """TODO - - Construct a P1 matrix for the following 1D mesh: - - v0 v1 v2 v3 - x-----x-----x-----x - c0 c1 c2 - - The mesh is distributed between 2 processes so the local meshes are: - - v0 v1 v2 - proc 1: x-----x-----o - c0 c1 - - v1 v2 v3 - proc 2: o~~~~~x-----x - c1 c2 - - Where o and ~ (instead of x and -) denote that points are halo, not owned. - - It is essential that all owned points fully store all points in their - adjacency. For FEM the adjacency is given by cl(support(pt)). For process 1 - this means that v2 must be stored, but for process 2, owning v2 requires - that c1 and v1 both exist in the halo. - - Given the adjacency relation as described. The matrix sparsity should be: - - v0 v1 v2 v3 - v0 x x - v1 x x x - v2 x x x - v3 x x - - The sparsities for each process are given by: - - proc 1: - - v0 v1 v2 - v0 x x - v1 x s s - v2 h h - - proc 2: - - v1 v2 v3 - v1 h h - v2 s s x - v3 x x - - Here "s" denotes shared and "h" halo. - - Since PETSc divides ownership across rows, the DoFs in (v3, :) are dropped for - process 1 and the DoFs in (v1, :) are dropped for process 2. 
- - """ - comm = PETSc.Sys.getDefaultComm() - assert comm.size == 2 - - if comm.rank == 0: - # v0, v1 and v2 - nnodes = 3 - overlap = [Owned(), Shared(), Halo(RemotePoint(1, 1))] - - # now make the sparsity - mapaxes = ( - MultiAxis([AxisPart(2, label="cells", id="cells")]).add_subaxis( - "cells", [AxisPart(2, label="any")] - ) - ).set_up() - mapdata = MultiArray( - mapaxes, name="map0", data=np.array([0, 1, 1, 2], dtype=IntType) - ) - - iterindex = RangeNode("cells", 2, id="i0") - lmap = rmap = iterindex.add_child( - "i0", - TabulatedMapNode(["cells"], ["nodes"], arity=2, data=mapdata[[iterindex]]), - ) - else: - # v1, v2 and v3 - nnodes = 3 - # FIXME: Unclear on the ordering of Shared and Owned - # should they be handled by some numbering? - overlap = [Owned(), Shared(), Halo(RemotePoint(0, 1))] - - # now make the sparsity - mapaxes = ( - MultiAxis([AxisPart(2, label="cells", id="cells")]).add_subaxis( - "cells", [AxisPart(2, label="any")] - ) - ).set_up() - mapdata = MultiArray( - mapaxes, name="map0", data=np.array([2, 1, 1, 0], dtype=IntType) - ) - - iterindex = RangeNode("cells", 2, id="i0") - lmap = rmap = iterindex.add_child( - "i0", - TabulatedMapNode(["cells"], ["nodes"], arity=2, data=mapdata[[iterindex]]), - ) - - axes = MultiAxis([AxisPart(nnodes, label="nodes", overlap=overlap)]).set_up() - sparsity = make_sparsity(iterindex, lmap, rmap) - - print_with_rank(sparsity) - - # new_sparsity = distribute_sparsity(sparsity, axes, axes) - - # print_with_rank(new_sparsity) - - mat = PetscMatAIJ(axes, axes, sparsity) - - # import pdb; pdb.set_trace() - # - mat.petscmat.getLGMap()[0].view() - mat.petscmat.getLGMap()[1].view() - - mat.petscmat.view() - - -if __name__ == "__main__": - test_make_parallel_matrix()
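
The custom MPI reductions added to pyop3/axtree/parallel.py in the first
commit wrap np.minimum and np.maximum as user-defined MPI operations
(_contig_min_op and _contig_max_op) so that elementwise min/max can be
applied when reducing contiguous numpy buffers, e.g. in star forest
reductions. Below is a minimal standalone sketch of the same pattern, not
pyop3 API: the dtype is hardcoded to int64 where pyop3 derives it from the
MPI datatype via as_numpy_dtype, and the Allreduce demo and script name are
illustrative assumptions.

    import functools

    import numpy as np
    from mpi4py import MPI


    def reduction_op(op, invec, inoutvec, datatype):
        # mpi4py hands the operands over as raw buffers plus the MPI
        # datatype; view them as numpy arrays and combine elementwise,
        # writing the result back into inoutvec as MPI requires.
        dtype = np.dtype(np.int64)  # assumed; pyop3 maps `datatype` to numpy
        invec = np.frombuffer(invec, dtype=dtype)
        inoutvec = np.frombuffer(inoutvec, dtype=dtype)
        inoutvec[:] = op(invec, inoutvec)


    # commute=True lets MPI apply the operation in any order; this is valid
    # because elementwise minimum is commutative.
    contig_min_op = MPI.Op.Create(
        functools.partial(reduction_op, np.minimum), commute=True
    )

    if __name__ == "__main__":
        comm = MPI.COMM_WORLD
        sendbuf = np.full(4, comm.rank, dtype=np.int64)
        recvbuf = np.empty_like(sendbuf)
        comm.Allreduce(sendbuf, recvbuf, op=contig_min_op)
        assert (recvbuf == 0).all()  # the elementwise minimum over ranks
        contig_min_op.Free()

Run with, say, mpiexec -n 2 python demo.py. Creating the ops once at module
scope, as the patch does, avoids re-registering them for every reduction.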