From 328973ec3f300efaf5c4497031abb61b0b3b07b3 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 8 Dec 2023 14:09:42 +0000 Subject: [PATCH] More cleanup, tests passing --- pyop3/array/harray.py | 5 - pyop3/axtree/parallel.py | 8 - pyop3/axtree/tree.py | 93 +------- pyop3/buffer.py | 1 - pyop3/ir/lower.py | 45 ---- pyop3/itree/tree.py | 114 --------- pyop3/lang.py | 1 - pyop3/sf.py | 5 - pyop3/space.py | 280 ----------------------- tests/integration/test_parallel_loops.py | 1 - tests/unit/test_axis_ordering_old.py | 119 ---------- tests/unit/test_distarray.py | 1 - tests/unit/test_sparsity_old.py | 271 ---------------------- 13 files changed, 1 insertion(+), 943 deletions(-) delete mode 100644 pyop3/space.py delete mode 100644 tests/unit/test_axis_ordering_old.py delete mode 100644 tests/unit/test_sparsity_old.py diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index bfe8dc64..110262e9 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -36,7 +36,6 @@ ) from pyop3.buffer import Buffer, DistributedBuffer from pyop3.dtypes import IntType, ScalarType, get_mpi_dtype -from pyop3.extras.debug import print_if_rank, print_with_rank from pyop3.itree import IndexTree, as_index_forest, index_axes from pyop3.itree.tree import CalledMapVariable, collect_loop_indices, iter_axis_tree from pyop3.lang import KernelArgument @@ -342,14 +341,10 @@ def offset(self, *args, allow_unused=False, insert_zeros=False): return strict_int(offset) def simple_offset(self, path, indices): - print_if_rank(0, "self.layouts", self.layouts) - print_if_rank(0, "path", path) - print_if_rank(0, "indices", indices) offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator) return strict_int(offset) def iter_indices(self, outer_map): - print_with_rank(0, "myiexpr!!!!!!!!!!!!!!!!!!", self.index_exprs) return iter_axis_tree(self.axes, self.target_paths, self.index_exprs, outer_map) def _with_axes(self, axes): diff --git a/pyop3/axtree/parallel.py b/pyop3/axtree/parallel.py index 68542326..f4bce59a 100644 --- a/pyop3/axtree/parallel.py +++ b/pyop3/axtree/parallel.py @@ -9,7 +9,6 @@ from pyop3.axtree.layout import _as_int, _axis_component_size, step_size from pyop3.dtypes import IntType, as_numpy_dtype, get_mpi_dtype -from pyop3.extras.debug import print_with_rank from pyop3.utils import checked_zip, just_one, strict_int @@ -131,8 +130,6 @@ def grow_dof_sf(axes, axis, path, indices): ) root_offsets[pt] = offset - print_with_rank("root offsets before", root_offsets) - point_sf.broadcast(root_offsets, MPI.REPLACE) # for sanity reasons remove the original root values from the buffer @@ -175,9 +172,4 @@ def grow_dof_sf(axes, axis, path, indices): remote_leaf_dof_offsets[counter] = [rank, root_offsets[pos] + d] counter += 1 - print_with_rank("root offsets: ", root_offsets) - print_with_rank("local leaf offsets", local_leaf_offsets) - print_with_rank("local dof offsets: ", local_leaf_dof_offsets) - print_with_rank("remote offsets: ", remote_leaf_dof_offsets) - return (nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets) diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index 326f2b1a..1784a13f 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -26,7 +26,6 @@ from pyop3 import utils from pyop3.dtypes import IntType, PointerType, get_mpi_dtype -from pyop3.extras.debug import print_if_rank, print_with_rank from pyop3.sf import StarForest from pyop3.tree import ( LabelledNodeComponent, @@ -187,15 +186,9 @@ def map_called_map(self, expr): # the inner_expr tells us the right mapping for the temporary, however, # for maps that are arrays the innermost axis label does not always match # the label used by the temporary. Therefore we need to do a swap here. - # I don't like this. - # print_if_rank(0, repr(array.axes)) - # print_if_rank(0, "before: ",indices) inner_axis = array.axes.leaf_axis indices[inner_axis.label] = indices.pop(expr.function.full_map.name) - # print_if_rank(0, "after:",indices) - # print_if_rank(0, repr(expr)) - # print_if_rank(0, self.context) return array.get_value(path, indices) @@ -580,55 +573,6 @@ def add_node( parent_cpt_label = _as_axis_component_label(parent_component) return super().add_node(axis, parent, parent_cpt_label, **kwargs) - # alias - add_subaxis = add_node - - # currently untested but should keep - @classmethod - def from_layout(cls, layout: Sequence[ConstrainedMultiAxis]) -> Any: # TODO - return order_axes(layout) - - # TODO this is just a regular tree search - @deprecated(internal=True) # I think? - def get_part_from_path(self, path, axis=None): - axis = axis or self.root - - label, *sublabels = path - - (component, component_index) = just_one( - [ - (cpt, cidx) - for cidx, cpt in enumerate(axis.components) - if (axis.label, cidx) == label - ] - ) - if sublabels: - return self.get_part_from_path( - sublabels, self.component_child(axis, component) - ) - else: - return axis, component - - @deprecated(internal=True) - def drop_last(self): - """Remove the last subaxis""" - if not self.part.subaxis: - return None - else: - return self.copy( - parts=[self.part.copy(subaxis=self.part.subaxis.drop_last())] - ) - - @property - @deprecated(internal=True) - def is_linear(self): - """Return ``True`` if the multi-axis contains no branches at any level.""" - if self.nparts == 1: - return self.part.subaxis.is_linear if self.part.subaxis else True - else: - return False - - @deprecated() def add_subaxis(self, subaxis, *loc): return self.add_node(subaxis, *loc) @@ -657,8 +601,6 @@ class AxisTree(PartialAxisTree, Indexed, ContextFreeLoopIterable): "target_paths", "index_exprs", "layout_exprs", - "layouts", - "sf", } def __init__( @@ -667,7 +609,6 @@ def __init__( target_paths=None, index_exprs=None, layout_exprs=None, - sf=None, ): if some_but_not_all( arg is None for arg in [target_paths, index_exprs, layout_exprs] @@ -678,7 +619,6 @@ def __init__( self._target_paths = target_paths or self._default_target_paths() self._index_exprs = index_exprs or self._default_index_exprs() self.layout_exprs = layout_exprs or self._default_layout_exprs() - self.sf = sf or self._default_sf() def __getitem__(self, indices): from pyop3.itree.tree import as_index_forest, collect_loop_contexts, index_axes @@ -762,7 +702,7 @@ def layouts(self): @cached_property def sf(self): - return cls._default_sf(tree) + return self._default_sf() @cached_property def datamap(self): @@ -771,17 +711,6 @@ def datamap(self): else: dmap = postvisit(self, _collect_datamap, axes=self) - # for cleverdict in [self.layouts, self.orig_layout_fn]: - # for layout in cleverdict.values(): - # for layout_expr in layout.values(): - # # catch invalid layouts - # if isinstance(layout_expr, pym.primitives.NaN): - # continue - # for array in MultiArrayCollector()(layout_expr): - # dmap.update(array.datamap) - - # TODO - # for cleverdict in [self.index_exprs, self.layout_exprs]: for cleverdict in [self.index_exprs]: for exprs in cleverdict.values(): for expr in exprs.values(): @@ -939,26 +868,6 @@ def datamap(self): return merge_dicts(axes.datamap for axes in self.context_map.values()) -@dataclasses.dataclass(frozen=True) -class Path: - # TODO Make a persistent dict? - from_axes: Tuple[Any] # axis part IDs I guess (or labels) - to_axess: Tuple[Any] # axis part IDs I guess (or labels) - arity: int - selector: Optional[Any] = None - """The thing that chooses between the different possible output axes at runtime.""" - - @property - def degree(self): - return len(self.to_axess) - - @property - def to_axes(self): - if self.degree != 1: - raise RuntimeError("Only for degree 1 paths") - return self.to_axess[0] - - @functools.singledispatch def as_axis_tree(arg: Any): from pyop3.array import HierarchicalArray # cyclic import diff --git a/pyop3/buffer.py b/pyop3/buffer.py index 32a91d6e..2bc741c4 100644 --- a/pyop3/buffer.py +++ b/pyop3/buffer.py @@ -8,7 +8,6 @@ from mpi4py import MPI from pyop3.dtypes import ScalarType -from pyop3.extras.debug import print_if_rank from pyop3.lang import KernelArgument from pyop3.utils import UniqueNameGenerator, as_tuple, deprecated, readonly diff --git a/pyop3/ir/lower.py b/pyop3/ir/lower.py index 8e8c41d3..a388dda9 100644 --- a/pyop3/ir/lower.py +++ b/pyop3/ir/lower.py @@ -29,7 +29,6 @@ from pyop3.axtree.tree import ContextSensitiveAxisTree from pyop3.buffer import DistributedBuffer, PackedBuffer from pyop3.dtypes import IntType, PointerType -from pyop3.extras.debug import print_with_rank from pyop3.itree import ( AffineSliceComponent, CalledMap, @@ -442,11 +441,6 @@ def parse_loop_properly_this_time( # these aren't jnames! my_index_exprs = axes.index_exprs.get((axis.id, component.label), {}) - print_with_rank("myindexexprs", my_index_exprs) - print_with_rank("new_iname_rplacemap", new_iname_replace_map) - print_with_rank("jname_replace_map", jname_replace_map) - print_with_rank("outerreplac", outer_replace_map) - jname_extras = {} for axis_label, index_expr in my_index_exprs.items(): jname_expr = JnameSubstitutor( @@ -854,10 +848,6 @@ def array_expr(): array_ = array.with_context(context) return make_array_expr( array, - # I think... - # not calling substitute layouts from above so loop indices not - # present in the layout... - # subst_layout(axes, source_path, target_path), array_.layouts[target_path], target_path, iname_replace_map | jname_replace_map, @@ -919,14 +909,6 @@ def make_temp_expr(temporary, shape, path, jnames, ctx): return pym.subscript(pym.var(temporary.name), extra_indices + (temp_offset_var,)) -def subst_layout(axes, source_path, target_path): - replace_map = {} - for axis, cpt in axes.detailed_path(source_path).items(): - replace_map.update(axes.layout_exprs[axis.id, cpt]) - - return IndexExpressionReplacer(replace_map)(axes.layouts[target_path]) - - class JnameSubstitutor(pym.mapper.IdentityMapper): def __init__(self, replace_map, codegen_context): self._labels_to_jnames = replace_map @@ -1117,17 +1099,6 @@ def map_variable(self, expr): return self._replace_map.get(expr.name, expr) -def collect_arrays(expr: pym.primitives.Expr): - collector = MultiArrayCollector() - return collector(expr) - - -def replace_variables( - expr: pym.primitives.Expr, replace_map: dict[str, pym.primitives.Variable] -): - return VariableReplacer(replace_map)(expr) - - def _scalar_assignment( array, path, @@ -1146,22 +1117,6 @@ def _scalar_assignment( return rexpr -def find_axis(axes, path, target, current_axis=None): - """Return the axis matching ``target`` along ``path``. - - ``path`` is a mapping between axis labels and the selected component indices. - """ - current_axis = current_axis or axes.root - - if current_axis.label == target: - return current_axis - else: - subaxis = axes.child(current_axis, path[current_axis.label]) - if not subaxis: - assert False, "oops" - return find_axis(axes, path, target, subaxis) - - def context_from_indices(loop_indices): loop_context = {} for loop_index, (path, _) in loop_indices.items(): diff --git a/pyop3/itree/tree.py b/pyop3/itree/tree.py index de77f957..240a5d2f 100644 --- a/pyop3/itree/tree.py +++ b/pyop3/itree/tree.py @@ -35,7 +35,6 @@ PartialAxisTree, ) from pyop3.dtypes import IntType, get_mpi_dtype -from pyop3.extras.debug import print_if_rank, print_with_rank from pyop3.tree import LabelledTree, Node, Tree, postvisit from pyop3.utils import ( Identified, @@ -56,33 +55,21 @@ def __init__(self, replace_map): self._replace_map = replace_map def map_axis_variable(self, expr): - # print_if_rank(0, "replace map ", self._replace_map) - # return self._replace_map[expr.axis_label] return self._replace_map.get(expr.axis_label, expr) def map_multi_array(self, expr): from pyop3.array.harray import MultiArrayVariable - # print_if_rank(0, self._replace_map) - # print_if_rank(0, expr.indices) indices = {axis: self.rec(index) for axis, index in expr.indices.items()} return MultiArrayVariable(expr.array, indices) def map_called_map(self, expr): array = expr.function.map_component.array - # should the following only exist at eval time? - # the inner_expr tells us the right mapping for the temporary, however, # for maps that are arrays the innermost axis label does not always match # the label used by the temporary. Therefore we need to do a swap here. - # I don't like this. - # inner_axis = array.axes.leaf_axis - # print_if_rank(0, self._replace_map) - # print_if_rank(0, expr.parameters) indices = {axis: self.rec(idx) for axis, idx in expr.parameters.items()} - # indices[inner_axis.label] = indices.pop(expr.function.full_map.name) - return CalledMapVariable(expr.function, indices) def map_loop_index(self, expr): @@ -90,13 +77,6 @@ def map_loop_index(self, expr): return self._replace_map.get((expr.name, expr.axis), expr) -# just use a pmap for this -# class IndexForest: -# def __init__(self, trees: Mapping[Mapping, IndexTree]): -# # per loop context -# self.trees = trees - - # index trees are different to axis trees because we know less about # the possible attaching components. In particular a CalledMap can # have different "attaching components"/output components depending on @@ -148,9 +128,6 @@ def parse_parent_to_children(parent_to_children, loop_context, parent=None): return pmap() -IndexLabel = collections.namedtuple("IndexLabel", ["axis", "component"]) - - class DatamapCollector(pym.mapper.CombineMapper): def combine(self, values): return merge_dicts(values) @@ -366,7 +343,6 @@ def index(self) -> LoopIndex: target_paths, index_exprs, layout_exprs, - axes.layouts, ) context_sensitive_axes = ContextSensitiveAxisTree({context: axes}) @@ -1029,7 +1005,6 @@ def _(called_map: CalledMap, **kwargs): ) = _make_leaf_axis_from_called_map( called_map, prior_target_path, prior_index_exprs ) - # axes = axes.add_node(subaxis, prior_leaf_axis, prior_leaf_cpt) axes = axes.add_subtree( PartialAxisTree(subaxis), prior_leaf_axis, prior_leaf_cpt ) @@ -1219,10 +1194,6 @@ def _compose_bits( if iaxis is None: target_path |= itarget_paths.get(None, {}) partial_index_exprs |= iindex_exprs.get(None, {}) - # partial_layout_exprs |= ilayout_exprs.get(None, {}) - - # no idea why I put this line here - # visited_target_axes = visited_target_axes.union(target_path.keys()) iaxis = indexed_axes.root target_path_per_cpt = collections.defaultdict(dict) @@ -1247,8 +1218,6 @@ def _compose_bits( ] # if target_path is "complete" then do stuff, else pass responsibility to next func down - # index_exprs[iaxis.id, icpt.label] = {} - # layout_exprs[iaxis.id, icpt.label] = {} new_visited_target_axes = visited_target_axes if axes.is_valid_path(new_target_path): detailed_path = axes.detailed_path(new_target_path) @@ -1287,7 +1256,6 @@ def _compose_bits( new_index_exprs_acc = new_index_exprs_acc | { axis_label: new_index_expr } - # new_partial_index_exprs = pmap() # now do the layout expressions, this is simpler since target path magic isnt needed # compose layout expressions, this does an *outside* substitution @@ -1344,11 +1312,7 @@ def _compose_bits( else: pass - # assert not skip huh? - # assert not new_partial_index_exprs - # assert not new_partial_layout_exprs - # breakpoint() return ( freeze(dict(target_path_per_cpt)), freeze(dict(index_exprs)), @@ -1375,10 +1339,6 @@ def iter_axis_tree( new_exprs = {} for axlabel, index_expr in myindex_exprs.items(): new_index = ExpressionEvaluator(outermap)(index_expr) - print_with_rank("initialrepr", repr(index_expr)) - print_with_rank("replacedrepr", repr(new_index)) - print_with_rank("initial", index_expr) - print_with_rank("replaced", new_index) assert new_index != index_expr new_exprs[axlabel] = new_index index_exprs_acc = freeze(new_exprs) @@ -1400,14 +1360,9 @@ def iter_axis_tree( new_index = ExpressionEvaluator(outermap | indices | {axis.label: pt})( index_expr ) - print_with_rank("initialrepr", repr(index_expr)) - print_with_rank("replacedrepr", repr(new_index)) - print_with_rank("initial", index_expr) - print_with_rank("replaced", new_index) assert new_index != index_expr new_exprs[axlabel] = new_index index_exprs_ = index_exprs_acc | new_exprs - # index_exprs_ = index_exprs | myindex_exprs indices_ = indices | {axis.label: pt} if subaxis: yield from iter_axis_tree( @@ -1422,9 +1377,7 @@ def iter_axis_tree( index_exprs_, ) else: - # yield path_, index_exprs_, indices_ yield path_, target_path_, indices_, index_exprs_ - # yield path_, indices_ class ArrayPointLabel(enum.IntEnum): @@ -1464,7 +1417,6 @@ def partition_iterset(index: LoopIndex, arrays): from pyop3.array import HierarchicalArray # take first - # paraxis = [axis for axis in index.iterset.nodes if axis.sf is not None][0] if index.iterset.depth > 1: raise NotImplementedError("Need a good way to sniff the parallel axis") paraxis = index.iterset.root @@ -1480,13 +1432,6 @@ def partition_iterset(index: LoopIndex, arrays): if not array.array.is_distributed: continue - # take first - # array_paraxes = [ - # axis for axis in array.orig_array.axes.nodes if axis.sf is not None - # ] - # - # array_paraxis = array_paraxes[0] - # sf = array_paraxis.sf sf = array.array.sf # the dof sf # mark leaves and roots @@ -1494,18 +1439,6 @@ def partition_iterset(index: LoopIndex, arrays): is_root_or_leaf[sf.iroot] = ArrayPointLabel.ROOT is_root_or_leaf[sf.ileaf] = ArrayPointLabel.LEAF - # do this because we need to think of the indices here as a selector - # rather than a map. We need to transform to the new numbering, hence we - # need to apply the map default -> reordered, but the indexing semantics - # are the opposite of this - # is_root_or_leaf = is_root_or_leaf[array_paraxis.numbering] - # this is equivalent to: - # new_labels = np.empty_like(labels) - # for i, l in enumerate(labels): - # j = array_paraxis._inverse_numbering[i] - # new_labels[j] = l - # labels = new_labels - is_root_or_leaf_per_array[array.name] = is_root_or_leaf labels = np.full(paraxis.size, IterationPointType.CORE, dtype=np.uint8) @@ -1527,59 +1460,15 @@ def partition_iterset(index: LoopIndex, arrays): # loop over stencil array = array.with_context({index.id: target_path}) - print_with_rank("array axes", array.axes) - print_with_rank("array targetpaths", array.target_paths) - print_with_rank("array idxsexprs", array.index_exprs) for ( array_path, array_target_path, array_indices, array_target_indices, ) in array.iter_indices(replace_map): - # allexprs = dict(array.axes.index_exprs.get(None, {})) - # if not array.axes.is_empty: - # for myaxis, mycpt in array.axes.path_with_nodes( - # *array.axes._node_from_path(array_path) - # ).items(): - # allexprs.update(array.axes.index_exprs[myaxis.id, mycpt]) - # - # offset = array.axes.offset(array_path, array_indices | replace_map) - # offset = array.offset( - # array_target_path, array_target_indices | replace_map - # ) - # offset = array.simple_offset(array_path, array_indices | replace_map) - print_with_rank("array path", array_path) - print_with_rank("array idxs", array_indices) - print_with_rank("array target path", array_target_path) - print_with_rank("array target idxs", array_target_indices) - # offset = array.simple_offset(array_target_path, array_target_indices | replace_map) - - print_with_rank("myindices", array_indices | replace_map) - # offset = array.simple_offset(array_target_path, array_indices | replace_map) - # offset = array.simple_offset(array_target_path, array_target_indices | replace_map) - offset = array.simple_offset(array_target_path, array_target_indices) - # allexprs is indexed with the "source" labels but we want a particular - # "target" label, need to go backwards... or something - # if len(target_path) != 1: - # raise NotImplementedError - # target_parallel_axis_label = just_one(target_path.keys()) - # the_expr_i_want = allexprs[target_parallel_axis_label] - # - # # but this is for a particular component!! need to map component index to - # # "full" one, how? or just do offset? - # pt_index = pym.evaluate( - # the_expr_i_want, - # replace_map | array_indices, - # ExpressionEvaluator, - # ) - # print_if_rank(1, "ptindex", pt_index) - # assert isinstance(pt_index, numbers.Integral) - - # point_label = is_root_or_leaf_per_array[array.name][pt_index] point_label = is_root_or_leaf_per_array[array.name][offset] - print_if_rank(1, "ptlabel", point_label) if point_label == ArrayPointLabel.LEAF: labels[parindex] = IterationPointType.LEAF break # no point doing more analysis @@ -1592,9 +1481,6 @@ def partition_iterset(index: LoopIndex, arrays): parcpt = just_one(paraxis.components) # for now - print_with_rank("arrayper", is_root_or_leaf_per_array) - print_with_rank("labels", labels) - core = just_one(np.nonzero(labels == IterationPointType.CORE)) root = just_one(np.nonzero(labels == IterationPointType.ROOT)) leaf = just_one(np.nonzero(labels == IterationPointType.LEAF)) diff --git a/pyop3/lang.py b/pyop3/lang.py index b335efbb..156e99ad 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -22,7 +22,6 @@ from pyop3.axtree.tree import ContextFree, ContextSensitive, MultiArrayCollector from pyop3.config import config from pyop3.dtypes import IntType, dtype_limits -from pyop3.extras.debug import print_with_rank from pyop3.utils import as_tuple, checked_zip, just_one, merge_dicts, unique diff --git a/pyop3/sf.py b/pyop3/sf.py index 74288021..292a9bb6 100644 --- a/pyop3/sf.py +++ b/pyop3/sf.py @@ -5,7 +5,6 @@ from petsc4py import PETSc from pyop3.dtypes import get_mpi_dtype -from pyop3.extras.debug import print_with_rank from pyop3.utils import just_one @@ -85,14 +84,10 @@ def reduce(self, *args): def reduce_begin(self, *args): reduce_args = self._prepare_args(*args) self.sf.reduceBegin(*reduce_args) - print_with_rank(reduce_args) - print_with_rank("reduce begin") def reduce_end(self, *args): reduce_args = self._prepare_args(*args) self.sf.reduceEnd(*reduce_args) - print_with_rank(reduce_args) - print_with_rank("reduce end") @cached_property def _graph(self): diff --git a/pyop3/space.py b/pyop3/space.py deleted file mode 100644 index 1379d3cb..00000000 --- a/pyop3/space.py +++ /dev/null @@ -1,280 +0,0 @@ -"""This is an old file that isn't used any more.""" - -import functools -import numbers -from typing import Any, Dict, FrozenSet, Hashable, Optional - -import numpy as np -import pytools -from petsc4py import PETSc - -from pyop3.axes import Axis, AxisTree -from pyop3.axes.tree import has_independently_indexed_subaxis_parts -from pyop3.utils import as_tuple - -DEFAULT_AXIS_PRIORITY = 100 - - -class InvalidConstraintsException(Exception): - pass - - -class ConstrainedAxis(pytools.ImmutableRecord): - fields = {"axis", "priority", "within_labels"} - # TODO We could use 'label' to set the priority - # via commandline options - - def __init__( - self, - axis: Axis, - *, - priority: int = DEFAULT_AXIS_PRIORITY, - within_labels: FrozenSet[Hashable] = frozenset(), - ): - self.axis = axis - self.priority = priority - self.within_labels = frozenset(within_labels) - super().__init__() - - def __str__(self) -> str: - return f"{self.__class__.__name__}(axis=({', '.join(str(axis_cpt) for axis_cpt in self.axis)}), priority={self.priority}, within_labels={self.within_labels})" - - -class Space: - def __init__(self, mesh, layout): - # TODO mesh.axes is an iterable (could be extruded for example) - if mesh.axes.depth > 1: - raise NotImplementedError("need to unpack here somehow") - meshaxis = ConstrainedAxis(mesh.axes.root, priority=10) - axes = order_axes([meshaxis] + layout) - - self.mesh = mesh - self.layout = layout - self.axes = axes - - @property - def comm(self): - return self.mesh.comm - - # I don't like that this is an underscored property. I think internal_comm might be better - @property - def _comm(self): - return self.mesh._comm - - # TODO I think that this could be replaced with callbacks or something. - # DMShell supports passing a callback - # https://petsc.org/release/manualpages/DM/DMShellSetCreateGlobalVector/ - # so we could avoid allocating something here. - @functools.cached_property - def layout_vec(self): - """A PETSc Vec compatible with the dof layout of this DataSet.""" - vec = PETSc.Vec().create(comm=self.comm) - vec.setSizes((self.axes.calc_size(self.axes.root), None), bsize=1) - vec.setUp() - return vec - - @functools.cached_property - def dm(self): - from firedrake import dmhooks - from firedrake.mg.utils import get_level - - dm = PETSc.DMShell().create(comm=self._comm) - dm.setGlobalVector(self.layout_vec) - _, level = get_level(self.mesh) - - # We need to pass sf and section for preconditioners that are not - # implemented in Python (e.g. PCPATCH). For Python-level preconditioning - # we can extract all of this information from the function space. - # Since pyop3 spaces are more dynamic than a classical PETSc Vec we can only - # emit sections for "simple" structures where we have points and DoFs/point. - # TODO for "mixed" problems we could use fields in this section as well. - # it is still very fragile compared with pyop3 - # Extruded meshes are currently outlawed (axis tree depth must be 2) and to - # get them to work the numbering would need to be flattened. - # Ephemeral meshes would probably be rather helpful here. - - # this algorithm is basically equivalent to get_global_numbering - # we are allowed to do this if the inner size is always fixed. This is - # regardless of any depth considerations from vector shape etc. - if self.axes.root.label == "mesh" and all( - has_independently_indexed_subaxis_parts( - self.axes, self.axes.root, cpt, cidx - ) - for cidx, cpt in enumerate(self.axes.root.components) - ): - section = PETSc.Section().create(comm=self._comm) - section.setChart(0, self.axes.root.count) - for d in range(self.mesh.dimension + 1): - # in pyop3 points per tdim are counted from zero - # i.e. cell0, vertex0 instead of all cells being < all vertices - subaxis = self.axes.find_node((self.axes.root.id, d)) - ndofs = self.axes.calc_size(subaxis) - if ndofs > 0: - for idx, pt in enumerate(range(*self.mesh.plex.getDepthStratum(d))): - # get the offset for the zeroth sub axis element - indices = [(d, idx)] + [0] * (self.axes.depth - 1) - offset = self.axes.offset(indices) - section.setOffset(pt, offset) - section.setDof(pt, ndofs) - else: - for pt in range(*self.mesh.plex.getDepthStratum(d)): - section.setDof(pt, 0) - section.setOffset(pt, 0) - - sf = self.mesh.plex.getPointSF() - else: - section = None - sf = None - - dmhooks.attach_hooks(dm, level=level, section=section, sf=sf) - return dm - - """note the following code reproduces the same thing as the current - "global_numbering" section from Firedrake. But, I don't see why this should - be correct. I already handle the permutation when I compute the offsets. - - if self.axes.depth == 2 and self.axes.root.label == "mesh": - section = PETSc.Section().create(comm=self._comm) - section.setChart(0, self.axes.root.count) - for d in range(self.mesh.dimension + 1): - # in pyop3 points per tdim are counted from zero - # i.e. cell0, vertex0 instead of all cells being < all vertices - subaxis = self.axes.find_node((self.axes.root.id, d)) - dof = self.axes.calc_size(subaxis) - for pt in range(*self.mesh.plex.getDepthStratum(d)): - section.setDof(pt, dof) - else: - for pt in range(*self.mesh.plex.getDepthStratum(d)): - section.setDof(pt, 0) - - #TODO could make this a property - iset = PETSc.IS().createGeneral(self.mesh.axes.root.permutation, comm=self._comm) - section.setPermutation(iset) - section.setUp() - - sf = self.mesh.plex.getPointSF() - else: - section = None - sf = None - - dmhooks.attach_hooks(dm, level=level, section=section, sf=sf) - return dm - - """ - - -def order_axes(layout): - axes = AxisTree() - layout = list(map(as_constrained_axis, as_tuple(layout))) - axis_to_constraint = {caxis.axis.label: caxis for caxis in layout} - history = set() - while layout: - if tuple(layout) in history: - raise ValueError("Seen this before, cyclic") - history.add(tuple(layout)) - - constrained_axis = layout.pop(0) - axes, inserted = _insert_axis( - axes, constrained_axis, axes.root, axis_to_constraint - ) - if not inserted: - layout.append(constrained_axis) - return axes - - -def _insert_axis( - axes: AxisTree, - new_caxis: ConstrainedAxis, - current_axis: Axis, - axis_to_caxis: Dict[Axis, ConstrainedAxis], - path: Optional[Dict[Hashable, Dict]] = None, -): - path = path or {} - - within_labels = set(path.items()) - - # alias - remove - axis_to_constraint = axis_to_caxis - - if not axes.root: - if not new_caxis.within_labels: - axes = axes.put_node(new_caxis.axis) - return axes, True - else: - return axes, False - - # current_axis = current_axis or axes.root - current_caxis = axis_to_constraint[current_axis.label] - - if new_caxis.priority < current_caxis.priority: - raise NotImplementedError("TODO subtrees") - if new_caxis.within_labels <= within_labels: - # diagram or something? - parent_axis = axes.parent(current_axis) - subtree = axes.pop_subtree(current_axis) - betterid = new_caxis.axis.copy(id=next(Axis._id_generator)) - if not parent_axis: - axes.add_node(betterid) - else: - axes.add_node(betterid, path) - - # must already obey the constraints - so stick back in for all sub components - for comp in betterid.components: - stree = subtree.copy() - # stree.replace_node(stree.root.copy(id=next(MultiAxis._id_generator))) - mypath = (axes._node_to_path[betterid.id] or {}) | { - betterid.label: comp.label - } - axes.add_subtree(stree, mypath, uniquify=True) - axes._parent_and_label_to_child[(betterid, comp.label)] = stree.root.id - # need to register the right parent label - return True - else: - # The priority is less so the axes should definitely - # not be inserted below here - do not recurse - return False - else: - inserted = False - for cidx, cpt in enumerate(current_axis.components): - subaxis = axes.children(current_axis)[cidx] - if subaxis: - # axes can be unchanged - axes, now_inserted = _insert_axis( - axes, - new_caxis, - subaxis, - axis_to_constraint, - path | {current_axis.label: cidx}, - ) - else: - assert new_caxis.priority >= current_caxis.priority - if new_caxis.within_labels <= within_labels | { - (current_axis.label, cidx) - }: - # bad uniquify - betterid = new_caxis.axis.copy(id=next(Axis._id_generator)) - axes = axes.put_node(betterid, current_axis.id, cidx) - now_inserted = True - - inserted = inserted or now_inserted - return axes, inserted - - -@functools.singledispatch -def as_constrained_axis(arg: Any): - raise TypeError - - -@as_constrained_axis.register -def _(arg: ConstrainedAxis): - return arg - - -@as_constrained_axis.register -def _(arg: Axis): - return ConstrainedAxis(arg) - - -@as_constrained_axis.register -def _(arg: numbers.Integral): - return ConstrainedAxis(Axis([arg])) diff --git a/tests/integration/test_parallel_loops.py b/tests/integration/test_parallel_loops.py index 168aa4f8..cff1ea43 100644 --- a/tests/integration/test_parallel_loops.py +++ b/tests/integration/test_parallel_loops.py @@ -5,7 +5,6 @@ from pyrsistent import freeze import pyop3 as op3 -from pyop3.extras.debug import print_with_rank from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET from pyop3.utils import just_one diff --git a/tests/unit/test_axis_ordering_old.py b/tests/unit/test_axis_ordering_old.py deleted file mode 100644 index b8701caa..00000000 --- a/tests/unit/test_axis_ordering_old.py +++ /dev/null @@ -1,119 +0,0 @@ -import pytest - -from pyop3 import * - -# from pyop3.multiaxis import ConstrainedMultiAxis - - -# This file is pretty outdated -pytest.skip(allow_module_level=True) - - -def id_tuple(nodes): - return tuple(node.id for node in nodes) - - -def label_tuple(nodes): - return tuple(node.label for node in nodes) - - -def test_axis_ordering(): - axis0 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax0") - axis1 = MultiAxis([MultiAxisComponent(1, "cpt0")], "ax1") - - layout = [ConstrainedMultiAxis(axis0), ConstrainedMultiAxis(axis1)] - axes = MultiAxisTree.from_layout(layout) - - assert axes.depth == 2 - assert axes.root == axis0 - assert just_one(axes.children({"ax0": "cpt0"})).label == "ax1" - assert not axes.children({"ax0": "cpt0", "ax1": "cpt0"}) - - ### - - layout = [ConstrainedMultiAxis(axis0), ConstrainedMultiAxis(axis1, priority=0)] - axes = MultiAxisTree.from_layout(layout) - assert axes.depth == 2 - assert axes.root.label == "ax1" - assert just_one(axes.children({"ax1": "cpt0"})).label == "ax0" - assert not axes.children({"ax0": "cpt0", "ax1": "cpt0"}) - - -def test_multicomponent_constraints(): - axis0 = MultiAxis( - [MultiAxisComponent(3, "cpt0"), MultiAxisComponent(3, "cpt1")], "ax0" - ) - axis1 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax1") - - ### - - layout = [ - ConstrainedMultiAxis(axis0, within_labels={("ax1", "cpt0")}), - ConstrainedMultiAxis(axis1), - ] - axes = order_axes(layout) - - assert axes.depth == 2 - assert axes.root.label == "ax1" - assert just_one(axes.children({"ax1": "cpt0"})).label == "ax0" - assert not axes.children({"ax1": "cpt0", "ax0": "cpt0"}) - assert not axes.children({"ax1": "cpt0", "ax0": "cpt1"}) - - ### - - with pytest.raises(ValueError): - layout = [ - ConstrainedMultiAxis(axis0, within_labels={("ax1", "cpt0")}), - ConstrainedMultiAxis(axis1, within_labels={("ax0", "cpt0")}), - ] - order_axes(layout) - - ### - - layout = [ - ConstrainedMultiAxis(axis0), - ConstrainedMultiAxis(axis1, within_labels={("ax0", "cpt1")}), - ] - axes = order_axes(layout) - - assert axes.depth == 2 - assert axes.root.label == "ax0" - assert not axes.children({"ax0": "cpt0"}) - assert just_one(axes.children({"ax0": "cpt1"})).label == "ax1" - assert not axes.children({"ax0": "cpt1", "ax1": "cpt0"}) - - ### - - -def test_multicomponent_constraints_more(): - # ax0 - # ├──➤ cpt0 : ax1 - # │ └──➤ cpt0 - # └──➤ cpt1 : ax2 - # └──➤ cpt0 : ax1 - # └──➤ cpt0 - - axis0 = MultiAxis( - [ - MultiAxisComponent(3, "cpt0"), - MultiAxisComponent(3, "cpt1"), - ], - "ax0", - ) - axis1 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax1") - axis2 = MultiAxis([MultiAxisComponent(3, "cpt0")], "ax2") - - layout = [ - ConstrainedMultiAxis(axis0, priority=0), - ConstrainedMultiAxis(axis1, priority=20), - ConstrainedMultiAxis(axis2, within_labels={("ax0", "cpt1")}, priority=10), - ] - axes = order_axes(layout) - - assert axes.depth == 3 - assert axes.root.label == "ax0" - assert just_one(axes.children({"ax0": "cpt0"})).label == "ax1" - assert just_one(axes.children({"ax0": "cpt1"})).label == "ax2" - assert not axes.children({"ax0": "cpt0", "ax1": "cpt0"}) - assert just_one(axes.children({"ax0": "cpt1", "ax2": "cpt0"})).label == "ax1" - assert not axes.children({"ax0": "cpt1", "ax2": "cpt0", "ax1": "cpt0"}) diff --git a/tests/unit/test_distarray.py b/tests/unit/test_distarray.py index b191be10..4969ec2c 100644 --- a/tests/unit/test_distarray.py +++ b/tests/unit/test_distarray.py @@ -7,7 +7,6 @@ from petsc4py import PETSc import pyop3 as op3 -from pyop3.extras.debug import print_with_rank @pytest.fixture diff --git a/tests/unit/test_sparsity_old.py b/tests/unit/test_sparsity_old.py deleted file mode 100644 index 5c8b9f90..00000000 --- a/tests/unit/test_sparsity_old.py +++ /dev/null @@ -1,271 +0,0 @@ -import pytest - -# This file is pretty outdated -pytest.skip(allow_module_level=True) - -from pyop3 import * -from pyop3.distarray.petsc import * -from pyop3.extras.debug import print_with_rank - - -def test_read_sparse_matrix(): - """Read values from a matrix that looks like: - - 0 10 20 - 0 30 40 - 50 0 0 - - """ - nnzaxes = MultiAxis([AxisPart(3, id="p1")]).set_up() - nnz = MultiArray(nnzaxes, name="nnz", data=np.array([2, 2, 1], dtype=np.uint64)) - - indices = MultiArray.from_list( - [[1, 2], [1, 2], [0]], labels=["p1", "any"], name="indices", dtype=np.uint64 - ) - - mataxes = ( - nnzaxes.copy().add_subaxis("p1", [AxisPart(nnz, indices=indices)]).set_up() - ) - mat = MultiArray(mataxes, name="mat", data=np.arange(10, 51, 10)) - - assert mat.get_value([0, 1]) == 10 - assert mat.get_value([0, 2]) == 20 - assert mat.get_value([1, 1]) == 30 - assert mat.get_value([1, 2]) == 40 - assert mat.get_value([2, 0]) == 50 - - -def test_read_sparse_rank_3_tensor(): - """Read values from a matrix that looks like: - - 0 A - B C - - with: - A : [10 0 20] - B : [30 40 0] - C : [50 0 60] - - - """ - ax1 = MultiAxis([AxisPart(2, id="p1")]).set_up() - nnz = MultiArray(ax1, name="nnz", data=np.array([1, 2], dtype=np.uint64)) - - indices1 = MultiArray.from_list( - [[1], [0, 1]], labels=["p1", "any"], name="indices", dtype=np.uint64 - ) - - indices2 = MultiArray.from_list( - [[[0, 2]], [[0, 1], [0, 2]]], - labels=["p1", "any1", "any2"], - name="indices", - dtype=np.uint64, - ) - - ax2 = ( - ax1.copy() - .add_subaxis("p1", [AxisPart(nnz, indices=indices1, id="p2")]) - .add_subaxis("p2", [AxisPart(2, indices=indices2)]) - ).set_up() - tensor = MultiArray(ax2, name="tensor", data=np.arange(10, 61, 10)) - - assert tensor.get_value([0, 1, 0]) == 10 - assert tensor.get_value([0, 1, 2]) == 20 - assert tensor.get_value([1, 0, 0]) == 30 - assert tensor.get_value([1, 0, 1]) == 40 - assert tensor.get_value([1, 1, 0]) == 50 - assert tensor.get_value([1, 1, 2]) == 60 - - -@pytest.fixture -def sparsity1dp1(): - """ - - The cone sparsity of the following mesh: - - v0 v1 v2 v3 - x---x---x---x - c0 c1 c2 - - should look like: - - v0 v1 v2 v3 - v0 x x - v1 x x x - v2 x x x - v3 x x - - """ - mapaxes = ( - MultiAxis([AxisPart(3, label="cells", id="cells")]).add_subaxis( - "cells", [AxisPart(2, label="any")] - ) - ).set_up() - mapdata = MultiArray( - mapaxes, name="map0", data=np.array([0, 1, 1, 2, 2, 3], dtype=IntType) - ) - - iterindex = RangeNode("cells", 3, id="i0") - lmap = rmap = iterindex.add_child( - "i0", TabulatedMapNode(["cells"], ["nodes"], arity=2, data=mapdata[[iterindex]]) - ) - - return make_sparsity(iterindex, lmap, rmap) - - -def test_make_sparsity(sparsity1dp1): - expected = { - (("nodes",), ("nodes",)): { - ((0,), (0,)), - ((0,), (1,)), - ((1,), (0,)), - ((1,), (1,)), - ((1,), (2,)), - ((2,), (1,)), - ((2,), (2,)), - ((2,), (3,)), - ((3,), (2,)), - ((3,), (3,)), - }, - } - - assert sparsity1dp1 == expected - - -def test_make_matrix(sparsity1dp1): - raxes = MultiAxis([AxisPart(4, label="nodes")]) - caxes = raxes.copy() - - mat = PetscMatAIJ(raxes, caxes, sparsity1dp1) - - import pdb - - pdb.set_trace() - - assert False - - -@pytest.mark.parallel(nprocs=2) -def test_make_parallel_matrix(): - """TODO - - Construct a P1 matrix for the following 1D mesh: - - v0 v1 v2 v3 - x-----x-----x-----x - c0 c1 c2 - - The mesh is distributed between 2 processes so the local meshes are: - - v0 v1 v2 - proc 1: x-----x-----o - c0 c1 - - v1 v2 v3 - proc 2: o~~~~~x-----x - c1 c2 - - Where o and ~ (instead of x and -) denote that points are halo, not owned. - - It is essential that all owned points fully store all points in their - adjacency. For FEM the adjacency is given by cl(support(pt)). For process 1 - this means that v2 must be stored, but for process 2, owning v2 requires - that c1 and v1 both exist in the halo. - - Given the adjacency relation as described. The matrix sparsity should be: - - v0 v1 v2 v3 - v0 x x - v1 x x x - v2 x x x - v3 x x - - The sparsities for each process are given by: - - proc 1: - - v0 v1 v2 - v0 x x - v1 x s s - v2 h h - - proc 2: - - v1 v2 v3 - v1 h h - v2 s s x - v3 x x - - Here "s" denotes shared and "h" halo. - - Since PETSc divides ownership across rows, the DoFs in (v3, :) are dropped for - process 1 and the DoFs in (v1, :) are dropped for process 2. - - """ - comm = PETSc.Sys.getDefaultComm() - assert comm.size == 2 - - if comm.rank == 0: - # v0, v1 and v2 - nnodes = 3 - overlap = [Owned(), Shared(), Halo(RemotePoint(1, 1))] - - # now make the sparsity - mapaxes = ( - MultiAxis([AxisPart(2, label="cells", id="cells")]).add_subaxis( - "cells", [AxisPart(2, label="any")] - ) - ).set_up() - mapdata = MultiArray( - mapaxes, name="map0", data=np.array([0, 1, 1, 2], dtype=IntType) - ) - - iterindex = RangeNode("cells", 2, id="i0") - lmap = rmap = iterindex.add_child( - "i0", - TabulatedMapNode(["cells"], ["nodes"], arity=2, data=mapdata[[iterindex]]), - ) - else: - # v1, v2 and v3 - nnodes = 3 - # FIXME: Unclear on the ordering of Shared and Owned - # should they be handled by some numbering? - overlap = [Owned(), Shared(), Halo(RemotePoint(0, 1))] - - # now make the sparsity - mapaxes = ( - MultiAxis([AxisPart(2, label="cells", id="cells")]).add_subaxis( - "cells", [AxisPart(2, label="any")] - ) - ).set_up() - mapdata = MultiArray( - mapaxes, name="map0", data=np.array([2, 1, 1, 0], dtype=IntType) - ) - - iterindex = RangeNode("cells", 2, id="i0") - lmap = rmap = iterindex.add_child( - "i0", - TabulatedMapNode(["cells"], ["nodes"], arity=2, data=mapdata[[iterindex]]), - ) - - axes = MultiAxis([AxisPart(nnodes, label="nodes", overlap=overlap)]).set_up() - sparsity = make_sparsity(iterindex, lmap, rmap) - - print_with_rank(sparsity) - - # new_sparsity = distribute_sparsity(sparsity, axes, axes) - - # print_with_rank(new_sparsity) - - mat = PetscMatAIJ(axes, axes, sparsity) - - # import pdb; pdb.set_trace() - # - mat.petscmat.getLGMap()[0].view() - mat.petscmat.getLGMap()[1].view() - - mat.petscmat.view() - - -if __name__ == "__main__": - test_make_parallel_matrix()