diff --git a/pyop3/__init__.py b/pyop3/__init__.py index bf9384ce..b3e12802 100644 --- a/pyop3/__init__.py +++ b/pyop3/__init__.py @@ -3,14 +3,22 @@ # tracebacks for @property methods so we remove it here. import pytools -del pytools.RecordWithoutPickling.__getattr__ +try: + del pytools.RecordWithoutPickling.__getattr__ +except AttributeError: + pass del pytools -import pyop3.transforms -from pyop3.array import Array, HierarchicalArray, MultiArray, PetscMat -from pyop3.axtree import Axis, AxisComponent, AxisTree # noqa: F401 -from pyop3.buffer import DistributedBuffer # noqa: F401 +import pyop3.ir +import pyop3.transform +from pyop3.array import Array, HierarchicalArray, MultiArray + +# TODO where should these live? +from pyop3.array.harray import AxisVariable +from pyop3.array.petsc import PetscMat, PetscMatAIJ, PetscMatPreallocator # noqa: F401 +from pyop3.axtree import Axis, AxisComponent, AxisTree, PartialAxisTree # noqa: F401 +from pyop3.buffer import DistributedBuffer, NullBuffer # noqa: F401 from pyop3.dtypes import IntType, ScalarType # noqa: F401 from pyop3.itree import ( # noqa: F401 AffineSliceComponent, @@ -23,18 +31,25 @@ Subset, TabulatedMapComponent, ) +from pyop3.itree.tree import ScalarIndex from pyop3.lang import ( # noqa: F401 INC, MAX_RW, MAX_WRITE, MIN_RW, MIN_WRITE, + NA, READ, RW, WRITE, + AddAssignment, + DummyKernelArgument, Function, Loop, + OpaqueKernelArgument, + Pack, + ReplaceAssignment, do_loop, loop, ) -from pyop3.tensor import Dat, Global, Mat, Tensor # noqa: F401 +from pyop3.sf import StarForest, serial_forest, single_star diff --git a/pyop3/array/__init__.py b/pyop3/array/__init__.py index eaef9dc5..8d3e0e3a 100644 --- a/pyop3/array/__init__.py +++ b/pyop3/array/__init__.py @@ -1,7 +1,10 @@ +# arguably put this directly in pyop3/__init__.py +# no use namespacing here really + from .base import Array # noqa: F401 from .harray import ( # noqa: F401 ContextSensitiveMultiArray, HierarchicalArray, MultiArray, ) -from .petsc import PackedPetscMatAIJ, PetscMat, PetscMatAIJ # noqa: F401 +from .petsc import PetscMat, PetscMatAIJ # noqa: F401 diff --git a/pyop3/array/base.py b/pyop3/array/base.py index 6f2b1adc..69ac790f 100644 --- a/pyop3/array/base.py +++ b/pyop3/array/base.py @@ -1,6 +1,6 @@ import abc -from pyop3.lang import KernelArgument +from pyop3.lang import KernelArgument, ReplaceAssignment from pyop3.utils import UniqueNameGenerator @@ -13,7 +13,5 @@ def __init__(self, name=None, *, prefix=None) -> None: raise ValueError("Can only specify one of name and prefix") self.name = name or self._name_generator(prefix or self._prefix) - @property - @abc.abstractmethod - def valid_ranks(self): - pass + def assign(self, other): + return ReplaceAssignment(self, other) diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index 110262e9..41d5687f 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -26,23 +26,24 @@ ContextSensitive, as_axis_tree, ) +from pyop3.axtree.layout import eval_offset from pyop3.axtree.tree import ( AxisVariable, ExpressionEvaluator, Indexed, MultiArrayCollector, - _path_and_indices_from_index_tuple, - _trim_path, + PartialAxisTree, ) from pyop3.buffer import Buffer, DistributedBuffer from pyop3.dtypes import IntType, ScalarType, get_mpi_dtype -from pyop3.itree import IndexTree, as_index_forest, index_axes -from pyop3.itree.tree import CalledMapVariable, collect_loop_indices, iter_axis_tree -from pyop3.lang import KernelArgument +from pyop3.lang import KernelArgument, ReplaceAssignment, do_loop +from pyop3.log import 
warning +from pyop3.sf import single_star from pyop3.utils import ( PrettyTuple, UniqueNameGenerator, as_tuple, + debug_assert, deprecated, is_single_valued, just_one, @@ -59,25 +60,71 @@ class IncompatibleShapeError(Exception): """TODO, also bad name""" -class MultiArrayVariable(pym.primitives.Variable): - mapper_method = sys.intern("map_multi_array") +class ArrayVar(pym.primitives.AlgebraicLeaf): + mapper_method = sys.intern("map_array") - def __init__(self, array, indices): - super().__init__(array.name) + def __init__(self, array, indices, path=None): + if path is None: + if array.axes.is_empty: + path = pmap() + else: + path = just_one(array.axes.leaf_paths) + + super().__init__() self.array = array self.indices = freeze(indices) - - def __repr__(self) -> str: - return f"MultiArrayVariable({self.array!r}, {self.indices!r})" + self.path = freeze(path) def __getinitargs__(self): - return self.array, self.indices + return (self.array, self.indices, self.path) + + # def __str__(self) -> str: + # return f"{self.array.name}[{{{', '.join(f'{i[0]}: {i[1]}' for i in self.indices.items())}}}]" + # + # def __repr__(self) -> str: + # return f"MultiArrayVariable({self.array!r}, {self.indices!r})" + + +from pymbolic.mapper.stringifier import PREC_CALL, PREC_NONE, StringifyMapper + + +# This was adapted from pymbolic's map_subscript +def stringify_array(self, array, enclosing_prec, *args, **kwargs): + index_str = self.join_rec( + ", ", array.index_exprs.values(), PREC_NONE, *args, **kwargs + ) + + return self.parenthesize_if_needed( + self.format("%s[%s]", array.name, index_str), enclosing_prec, PREC_CALL + ) - @property - def datamap(self): - return self.array.datamap | merge_dicts( - idx.datamap for idx in self.indices.values() - ) + +pym.mapper.stringifier.StringifyMapper.map_array = stringify_array + + +CalledMapVariable = ArrayVar + + +# does not belong here! 
+# class CalledMapVariable(ArrayVar): +# mapper_method = sys.intern("map_called_map_variable") +# +# def __init__(self, array, path, input_index_exprs, shape_index_exprs): +# super().__init__(array, {**input_index_exprs, **shape_index_exprs}, path) +# self.input_index_exprs = freeze(input_index_exprs) +# self.shape_index_exprs = freeze(shape_index_exprs) +# +# def __getinitargs__(self): +# return ( +# self.array, +# self.target_path, +# self.input_index_exprs, +# self.shape_index_exprs, +# ) + + +class FancyIndexWriteException(Exception): + pass class HierarchicalArray(Array, Indexed, ContextFree, KernelArgument): @@ -97,11 +144,13 @@ def __init__( *, data=None, max_value=None, + layouts=None, target_paths=None, index_exprs=None, - layouts=None, + outer_loops=None, name=None, prefix=None, + constant=False, ): super().__init__(name=name, prefix=prefix) @@ -124,56 +173,60 @@ def __init__( data = np.asarray(data, dtype=dtype) shape = data.shape else: - shape = axes.size + shape = axes.global_size + data = DistributedBuffer( - shape, dtype, name=self.name, data=data, sf=axes.sf + shape, + axes.sf or axes.comm, + dtype, + name=self.name, + data=data, ) self.buffer = data - - # instead implement "materialize" - self.axes = axes - + self._axes = axes self.max_value = max_value + # TODO This attr really belongs to the buffer not the array + self.constant = constant + if some_but_not_all(x is None for x in [target_paths, index_exprs]): raise ValueError - self._target_paths = target_paths or axes._default_target_paths() - self._index_exprs = index_exprs or axes._default_index_exprs() + if target_paths is None: + target_paths = axes._default_target_paths() + if index_exprs is None: + index_exprs = axes._default_index_exprs() - self.layouts = layouts or axes.layouts + self._target_paths = freeze(target_paths) + self._index_exprs = freeze(index_exprs) + self._outer_loops = outer_loops or () + + self._layouts = layouts if layouts is not None else axes.layouts def __str__(self): return self.name - def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: - from pyop3.itree.tree import ( - _compose_bits, - _index_axes, - as_index_tree, - collect_loop_contexts, - index_axes, - ) + def __getitem__(self, indices): + return self.getitem(indices, strict=False) + + def getitem(self, indices, *, strict=False): + from pyop3.itree.tree import _compose_bits, _index_axes, as_index_forest + + index_forest = as_index_forest(indices, axes=self.axes, strict=strict) + if len(index_forest) == 1 and pmap() in index_forest: + index_tree = just_one(index_forest.values()) + indexed_axes = _index_axes(index_tree, pmap(), self.axes) - loop_contexts = collect_loop_contexts(indices) - if not loop_contexts: - index_tree = just_one(as_index_forest(indices, axes=self.axes)) - ( - indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, - ) = _index_axes(self.axes, index_tree, pmap()) target_paths, index_exprs, layout_exprs = _compose_bits( self.axes, self.target_paths, self.index_exprs, None, indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, + indexed_axes.target_paths, + indexed_axes.index_exprs, + indexed_axes.layout_exprs, ) return HierarchicalArray( @@ -182,19 +235,14 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: max_value=self.max_value, target_paths=target_paths, index_exprs=index_exprs, + outer_loops=indexed_axes.outer_loops, layouts=self.layouts, 
name=self.name, ) array_per_context = {} - for index_tree in as_index_forest(indices, axes=self.axes): - loop_context = index_tree.loop_context - ( - indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, - ) = _index_axes(self.axes, index_tree, loop_context) + for loop_context, index_tree in index_forest.items(): + indexed_axes = _index_axes(index_tree, loop_context, self.axes) ( target_paths, @@ -206,19 +254,20 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: self.index_exprs, None, indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, + indexed_axes.target_paths, + indexed_axes.index_exprs, + indexed_axes.layout_exprs, ) array_per_context[loop_context] = HierarchicalArray( indexed_axes, data=self.array, - max_value=self.max_value, + layouts=self.layouts, target_paths=target_paths, index_exprs=index_exprs, - layouts=self.layouts, + outer_loops=indexed_axes.outer_loops, name=self.name, + max_value=self.max_value, ) return ContextSensitiveMultiArray(array_per_context) @@ -226,10 +275,6 @@ def __getitem__(self, indices) -> Union[MultiArray, ContextSensitiveMultiArray]: # to be iterable (which it's not). This avoids some confusing behaviour. __iter__ = None - @property - def valid_ranks(self): - return frozenset(range(self.axes.depth + 1)) - @property @deprecated("buffer") def array(self): @@ -239,6 +284,12 @@ def array(self): def dtype(self): return self.array.dtype + @property + def kernel_dtype(self): + # TODO Think about the fact that the dtype refers to either to dtype of the + # array entries (e.g. double), or the dtype of the whole thing (double*) + return self.dtype + @property @deprecated(".data_rw") def data(self): @@ -246,22 +297,69 @@ def data(self): @property def data_rw(self): - return self.array.data_rw + self._check_no_copy_access() + return self.buffer.data_rw[self._buffer_indices] @property def data_ro(self): - return self.array.data_ro + if not isinstance(self._buffer_indices, slice): + warning( + "Read-only access to the array is provided with a copy, " + "consider avoiding if possible." + ) + return self.buffer.data_ro[self._buffer_indices] @property def data_wo(self): """ - Have to be careful. If not setting all values (i.e. subsets) should call - `reduce_leaves_to_roots` first. + Have to be careful. If not setting all values (i.e. subsets) should + call `reduce_leaves_to_roots` first. When this is called we set roots_valid, claiming that any (lazy) 'in-flight' writes can be dropped. """ - return self.array.data_wo + self._check_no_copy_access() + return self.buffer.data_wo[self._buffer_indices] + + # TODO: This should be more widely cached, don't want to tabulate more often + # than required. + @cached_property + def _buffer_indices(self): + assert self.size > 0 + + indices = np.full(self.axes.owned.size, -1, dtype=IntType) + # TODO: Handle any outer loops. + # TODO: Generate code for this. + for i, p in enumerate(self.axes.iter()): + indices[i] = self.offset(p.source_exprs, p.source_path) + debug_assert(lambda: (indices >= 0).all()) + + # The packed indices are collected component-by-component so, for + # numbered multi-component axes, they are not in ascending order. + # We sort them so we can test for "affine-ness". + indices.sort() + + # See if we can represent these indices as a slice. This is important + # because slices enable no-copy access to the array. 
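+        # For example, tabulated indices [2, 4, 6, 8] have the single
+        # unique step 2 and become slice(2, 9, 2), whereas [0, 1, 3] has
+        # steps {1, 2} and must be kept as an index array, which forces a
+        # copy on access.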
+ steps = np.unique(indices[1:] - indices[:-1]) + if len(steps) == 1: + start = indices[0] + stop = indices[-1] + 1 + (step,) = steps + return slice(start, stop, step) + else: + return indices + + def _check_no_copy_access(self): + if not isinstance(self._buffer_indices, slice): + raise FancyIndexWriteException( + "Writing to the array directly is not supported for " + "non-trivially indexed (i.e. sliced) arrays." + ) + + @property + def axes(self): + return self._axes @property def target_paths(self): @@ -271,13 +369,26 @@ def target_paths(self): def index_exprs(self): return self._index_exprs + @property + def outer_loops(self): + return self._outer_loops + + @property + def layouts(self): + return self._layouts + @property def sf(self): return self.array.sf + @property + def comm(self): + return self.buffer.comm + @cached_property def datamap(self): - datamap_ = {self.name: self} + datamap_ = {} + datamap_.update(self.buffer.datamap) datamap_.update(self.axes.datamap) for index_exprs in self.index_exprs.values(): for expr in index_exprs.values(): @@ -289,6 +400,7 @@ def datamap(self): return freeze(datamap_) # TODO update docstring + # TODO is this a property of the buffer? def assemble(self, update_leaves=False): """Ensure that stored values are up-to-date. @@ -306,46 +418,27 @@ def assemble(self, update_leaves=False): def materialize(self) -> HierarchicalArray: """Return a new "unindexed" array with the same shape.""" # "unindexed" axis tree - axes = AxisTree(self.axes.parent_to_children) + # strip parallel semantics (in a bad way) + parent_to_children = collections.defaultdict(list) + for p, cs in self.axes.parent_to_children.items(): + for c in cs: + if c is not None and c.sf is not None: + c = c.copy(sf=None) + parent_to_children[p].append(c) + + axes = AxisTree(parent_to_children) return type(self)(axes, dtype=self.dtype) - def offset(self, *args, allow_unused=False, insert_zeros=False): - nargs = len(args) - if nargs == 2: - path, indices = args[0], args[1] - else: - assert nargs == 1 - path, indices = _path_and_indices_from_index_tuple(self.axes, args[0]) - - if allow_unused: - path = _trim_path(self.axes, path) - - if insert_zeros: - # extend the path by choosing the zero offset option every time - # this is needed if we don't have all the internal bits available - while path not in self.layouts: - axis, clabel = self.axes._node_from_path(path) - subaxis = self.axes.child(axis, clabel) - # choose the component that is first in the renumbering - if subaxis.numbering: - cidx = subaxis._component_index_from_axis_number( - subaxis.numbering.data_ro[0] - ) - else: - cidx = 0 - subcpt = subaxis.components[cidx] - path |= {subaxis.label: subcpt.label} - indices |= {subaxis.label: 0} - - offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator) - return strict_int(offset) - - def simple_offset(self, path, indices): - offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator) - return strict_int(offset) - def iter_indices(self, outer_map): - return iter_axis_tree(self.axes, self.target_paths, self.index_exprs, outer_map) + from pyop3.itree.tree import iter_axis_tree + + return iter_axis_tree( + self.axes.index(), + self.axes, + self.target_paths, + self.index_exprs, + outer_map, + ) def _with_axes(self, axes): """Return a new `Dat` with new axes pointing to the same data.""" @@ -357,16 +450,6 @@ def _with_axes(self, axes): name=self.name, ) - def as_var(self): - # must not be branched... 
- indices = freeze( - { - axis: AxisVariable(axis) - for axis, _ in self.axes.path(*self.axes.leaf).items() - } - ) - return MultiArrayVariable(self, indices) - @property def alloc_size(self): return self.axes.alloc_size() if not self.axes.is_empty else 1 @@ -410,11 +493,22 @@ def _get_count_data(cls, data): count.append(y) return flattened, count - def get_value(self, *args, **kwargs): - return self.data[self.offset(*args, **kwargs)] - - def set_value(self, path, indices, value): - self.data[self.simple_offset(path, indices)] = value + def get_value(self, indices, path=None, *, loop_exprs=pmap()): + offset = self.offset(indices, path, loop_exprs=loop_exprs) + return self.buffer.data_ro[offset] + + def set_value(self, indices, value, path=None, *, loop_exprs=pmap()): + offset = self.offset(indices, path, loop_exprs=loop_exprs) + self.buffer.data_wo[offset] = value + + def offset(self, indices, path=None, *, loop_exprs=pmap()): + return eval_offset( + self.axes, + self.subst_layouts, + indices, + path, + loop_exprs=loop_exprs, + ) def select_axes(self, indices): selected = [] @@ -424,6 +518,45 @@ def select_axes(self, indices): current_axis = current_axis.get_part(idx.npart).subaxis return tuple(selected) + def copy(self, other): + """Copy the contents of the array into another.""" + # NOTE: Is copy_to/copy_into a clearer name for this? + # TODO: Check that self and other are compatible, should have same axes and dtype + # for sure + # TODO: We can optimise here and copy the private data attribute and set halo + # validity. Here we do the simple but hopefully correct thing. + other.data_wo[...] = self.data_ro + + # symbolic + def zero(self, *, subset=Ellipsis): + return ReplaceAssignment(self[subset], 0) + + def eager_zero(self, *, subset=Ellipsis): + self.zero(subset=subset)() + + @property + @deprecated(".vec_rw") + def vec(self): + return self.vec_rw + + @property + def vec_rw(self): + # FIXME: This does not work for the case when the array here is indexed in some + # way. E.g. dat[::2] since the full buffer is returned. + return self.buffer.vec_rw + + @property + def vec_ro(self): + # FIXME: This does not work for the case when the array here is indexed in some + # way. E.g. dat[::2] since the full buffer is returned. + return self.buffer.vec_ro + + @property + def vec_wo(self): + # FIXME: This does not work for the case when the array here is indexed in some + # way. E.g. dat[::2] since the full buffer is returned. 
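+        # For instance, for a hypothetical 10-element dat, dat[::2].vec_wo
+        # would expose all 10 entries rather than the expected 5, so writes
+        # through the Vec can touch entries outside the subset.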
+ return self.buffer.vec_wo + # Needs to be subclass for isinstance checks to work # TODO Delete @@ -434,32 +567,27 @@ def __init__(self, *args, **kwargs): # Now ContextSensitiveDat -class ContextSensitiveMultiArray(ContextSensitive, KernelArgument): - def __getitem__(self, indices) -> ContextSensitiveMultiArray: - from pyop3.itree.tree import ( - _compose_bits, - _index_axes, - as_index_tree, - collect_loop_contexts, - index_axes, - ) +class ContextSensitiveMultiArray(Array, ContextSensitive): + def __init__(self, arrays): + name = single_valued(a.name for a in arrays.values()) - loop_contexts = collect_loop_contexts(indices) - if not loop_contexts: - raise NotImplementedError("code path untested") + Array.__init__(self, name) + ContextSensitive.__init__(self, arrays) + + def __getitem__(self, indices) -> ContextSensitiveMultiArray: + from pyop3.itree.tree import _compose_bits, _index_axes, as_index_forest # FIXME for now assume that there is only one context context, array = just_one(self.context_map.items()) + index_forest = as_index_forest(indices, axes=array.axes) + + if len(index_forest) == 1 and pmap() in index_forest: + raise NotImplementedError("code path untested") + array_per_context = {} - for index_tree in as_index_forest(indices, axes=array.axes): - loop_context = index_tree.loop_context - ( - indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, - ) = _index_axes(array.axes, index_tree, loop_context) + for loop_context, index_tree in index_forest.items(): + indexed_axes = _index_axes(index_tree, loop_context, array.axes) ( target_paths, @@ -471,9 +599,9 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray: array.index_exprs, None, indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, + indexed_axes.target_paths, + indexed_axes.index_exprs, + indexed_axes.layout_exprs, ) array_per_context[loop_context] = HierarchicalArray( indexed_axes, @@ -481,6 +609,7 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray: max_value=self.max_value, target_paths=target_paths, index_exprs=index_exprs, + outer_loops=indexed_axes.outer_loops, layouts=self.layouts, name=self.name, ) @@ -500,12 +629,14 @@ def dtype(self): return self._shared_attr("dtype") @property - def max_value(self): - return self._shared_attr("max_value") + def kernel_dtype(self): + # TODO Think about the fact that the dtype refers to either to dtype of the + # array entries (e.g. 
double), or the dtype of the whole thing (double*) + return self.dtype @property - def name(self): - return self._shared_attr("name") + def max_value(self): + return self._shared_attr("max_value") @property def layouts(self): diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index de402262..c7ed603a 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -1,6 +1,8 @@ from __future__ import annotations import abc +import collections +import enum import itertools import numbers from functools import cached_property @@ -13,19 +15,21 @@ from pyop3.array.base import Array from pyop3.array.harray import ContextSensitiveMultiArray, HierarchicalArray from pyop3.axtree import AxisTree -from pyop3.axtree.tree import ContextFree, ContextSensitive, as_axis_tree -from pyop3.buffer import PackedBuffer -from pyop3.dtypes import ScalarType -from pyop3.itree import IndexTree -from pyop3.itree.tree import ( - _compose_bits, - _index_axes, - as_index_forest, - as_index_tree, - collect_loop_contexts, - index_axes, +from pyop3.axtree.layout import collect_external_loops +from pyop3.axtree.tree import ( + ContextFree, + ContextSensitive, + PartialAxisTree, + as_axis_tree, + relabel_axes, ) -from pyop3.utils import just_one, merge_dicts, single_valued, strictly_all +from pyop3.buffer import PackedBuffer +from pyop3.cache import cached +from pyop3.dtypes import IntType, ScalarType +from pyop3.itree.tree import CalledMap, LoopIndex, _index_axes, as_index_forest +from pyop3.lang import PetscMatStore, do_loop, loop +from pyop3.mpi import hash_comm +from pyop3.utils import deprecated, just_one, merge_dicts, single_valued, strictly_all # don't like that I need this @@ -38,19 +42,12 @@ def __init__(self, obj: PetscObject): class PetscObject(Array, abc.ABC): dtype = ScalarType - def as_var(self): - return PetscVariable(self) - class PetscVec(PetscObject): def __new__(cls, *args, **kwargs): # dispatch to different vec types based on -vec_type raise NotImplementedError - @property - def valid_ranks(self): - return frozenset({0, 1}) - class PetscVecStandard(PetscVec): ... @@ -60,150 +57,345 @@ class PetscVecNest(PetscVec): ... -class PetscMat(PetscObject): +class PetscMat(PetscObject, abc.ABC): + DEFAULT_MAT_TYPE = PETSc.Mat.Type.AIJ + prefix = "mat" def __new__(cls, *args, **kwargs): - # TODO dispatch to different mat types based on -mat_type - return object.__new__(PetscMatAIJ) + # If the user called PetscMat(...), as opposed to PetscMatAIJ(...) etc + # then inspect mat_type and return the right object. + if cls is PetscMat: + mat_type = kwargs.pop("mat_type", cls.DEFAULT_MAT_TYPE) + if mat_type == PETSc.Mat.Type.AIJ: + return object.__new__(PetscMatAIJ) + # elif mat_type == PETSc.Mat.Type.BAIJ: + # return object.__new__(PetscMatBAIJ) + else: + raise AssertionError + else: + return object.__new__(cls) + + # like Dat, bad name? handle? + @property + def array(self): + return self.mat @property - def valid_ranks(self): - return frozenset({2}) + def values(self): + if self.raxes.size * self.caxes.size > 1e6: + raise ValueError( + "Printing a dense matrix with more than 1 million " + "entries is not allowed" + ) - @cached_property - def datamap(self): - return freeze({self.name: self}) + self.assemble() + return self.mat[:, :] + def assemble(self): + self.mat.assemble() -# is this required? 
-class ContextSensitiveIndexedPetscMat(ContextSensitive): - pass + def assign(self, other): + return PetscMatStore(self, other) + def eager_zero(self): + self.mat.zeroEntries() -# Not a super important class, could just inspect type of .array instead? -class PackedPetscMatAIJ(PackedBuffer): - pass +class MonolithicPetscMat(PetscMat, abc.ABC): + _row_suffix = "_row" + _col_suffix = "_col" -class PetscMatAIJ(PetscMat): - def __init__(self, raxes, caxes, sparsity, *, comm=None, name: str = None): + def __init__(self, raxes, caxes, sparsity=None, *, name=None): raxes = as_axis_tree(raxes) caxes = as_axis_tree(caxes) - super().__init__(name) - if any(axes.depth > 1 for axes in [raxes, caxes]): - # TODO, good exceptions - # raise InvalidDimensionException("Cannot instantiate PetscMats with nested axis trees") - raise RuntimeError - if any(len(axes.root.components) > 1 for axes in [raxes, caxes]): - # TODO, good exceptions - raise RuntimeError - - sizes = (raxes.leaf_component.count, caxes.leaf_component.count) - nnz = sparsity.axes.leaf_component.count - mat = PETSc.Mat().createAIJ(sizes, nnz=nnz.data, comm=comm) - - # fill with zeros (this should be cached) - # this could be done as a pyop3 loop (if we get ragged local working) or - # explicitly in cython - raxis, rcpt = raxes.leaf - caxis, ccpt = caxes.leaf - # e.g. - # map_ = Map({pmap({raxis.label: rcpt.label}): [TabulatedMapComponent(caxes.label, ccpt.label, sparsity)]}) - # do_loop(p := raxes.index(), write(zeros, mat[p, map_(p)])) - - # but for now do in Python... - assert nnz.max_value is not None - zeros = np.zeros(nnz.max_value, dtype=self.dtype) - for row_idx in range(rcpt.count): - cstart = sparsity.axes.offset([row_idx, 0]) - try: - cstop = sparsity.axes.offset([row_idx + 1, 0]) - except IndexError: - # catch the last one - cstop = len(sparsity.data_ro) - # truncate zeros - mat.setValuesLocal( - [row_idx], sparsity.data_ro[cstart:cstop], zeros[: cstop - cstart] - ) - mat.assemble() + if sparsity is not None: + mat = sparsity.materialize(self.mat_type) + else: + mat = self._make_mat(raxes, caxes, self.mat_type) - self.raxis = raxes.root - self.caxis = caxes.root - self.sparsity = sparsity + super().__init__(name) + self.raxes = raxes + self.caxes = caxes + self.mat = mat - self.axes = AxisTree.from_nest({self.raxis: self.caxis}) + def __getitem__(self, indices): + return self.getitem(indices, strict=False) - # copy only needed if we reuse the zero matrix - self.petscmat = mat.copy() + # Since __getitem__ is implemented, this class is implicitly considered + # to be iterable (which it's not). This avoids some confusing behaviour. + __iter__ = None - def __getitem__(self, indices): + def getitem(self, indices, *, strict=False): # TODO also support context-free (see MultiArray.__getitem__) - array_per_context = {} - for index_tree in as_index_forest(indices, axes=self.axes): - # make a temporary of the right shape - loop_context = index_tree.loop_context - ( - indexed_axes, - # target_path_per_indexed_cpt, - # index_exprs_per_indexed_cpt, - target_paths, - index_exprs, - layout_exprs_per_indexed_cpt, - ) = _index_axes(self.axes, index_tree, loop_context) - - # is this needed? Just use the defaults? 
- # ( - # target_paths, - # index_exprs, - # layout_exprs, - # ) = _compose_bits( - # self.axes, - # # use the defaults because Mats can only be indexed once - # # (then they turn into Dats) - # self.axes._default_target_paths(), - # self.axes._default_index_exprs(), - # None, - # indexed_axes, - # target_path_per_indexed_cpt, - # index_exprs_per_indexed_cpt, - # layout_exprs_per_indexed_cpt, - # ) - - # "freeze" the indexed_axes, we want to tabulate the layout of them - # (when usually we don't) - indexed_axes = indexed_axes.set_up() - - packed = PackedPetscMatAIJ(self) - - array_per_context[loop_context] = HierarchicalArray( - indexed_axes, + if len(indices) != 2: + raise ValueError + + # Combine the loop contexts of the row and column indices. Consider + # a loop over a multi-component axis with components "a" and "b": + # + # loop(p, mat[p, p]) + # + # The row and column index forests with "merged" loop contexts would + # look like: + # + # { + # {p: "a"}: [rtree0, ctree0], + # {p: "b"}: [rtree1, ctree1] + # } + # + # By contrast, distinct loop indices are combined as a product, not + # merged. For example, the loop + # + # loop(p, loop(q, mat[p, q])) + # + # with p still a multi-component loop over "a" and "b" and q the same + # over "x" and "y". This would give the following combined set of + # index forests: + # + # { + # {p: "a", q: "x"}: [rtree0, ctree0], + # {p: "a", q: "y"}: [rtree0, ctree1], + # {p: "b", q: "x"}: [rtree1, ctree0], + # {p: "b", q: "y"}: [rtree1, ctree1], + # } + + rtrees = as_index_forest(indices[0], axes=self.raxes, strict=strict) + ctrees = as_index_forest(indices[1], axes=self.caxes, strict=strict) + rcforest = {} + for rctx, rtree in rtrees.items(): + for cctx, ctree in ctrees.items(): + # skip if the row and column contexts are incompatible + if any(idx in rctx and rctx[idx] != path for idx, path in cctx.items()): + continue + rcforest[rctx | cctx] = (rtree, ctree) + + arrays = {} + for ctx, (rtree, ctree) in rcforest.items(): + indexed_raxes = _index_axes(rtree, ctx, self.raxes) + indexed_caxes = _index_axes(ctree, ctx, self.caxes) + + if indexed_raxes.alloc_size() == 0 or indexed_caxes.alloc_size() == 0: + continue + router_loops = indexed_raxes.outer_loops + couter_loops = indexed_caxes.outer_loops + + # rmap_axes = AxisTree(indexed_raxes.layout_axes.parent_to_children) + # cmap_axes = AxisTree(indexed_caxes.layout_axes.parent_to_children) + + """ + + KEY POINTS + ---------- + + * These maps require new layouts. Typically when we index something + we want to use the prior layout, here we want to materialise them. + This is basically what we always want for temporaries but this time + we actually want to materialise data. + * We then have to use the default target paths and index exprs. If these + are the "indexed" ones then they don't work. For instance the target + paths target non-existent layouts since we are using new layouts. 
+ + """ + + rmap = HierarchicalArray( + indexed_raxes, + # indexed_raxes.layout_axes, + # rmap_axes, + # target_paths=indexed_raxes.target_paths, + index_exprs=indexed_raxes.index_exprs, + target_paths=indexed_raxes._default_target_paths(), + # index_exprs=indexed_raxes._default_index_exprs(), + layouts=indexed_raxes.layouts, + # target_paths=indexed_raxes.layout_axes.target_paths, + # index_exprs=indexed_raxes.layout_axes.index_exprs, + # layouts=indexed_raxes.layout_axes.layouts, + outer_loops=router_loops, + dtype=IntType, + ) + cmap = HierarchicalArray( + indexed_caxes, + # indexed_caxes.layout_axes, + # cmap_axes, + # target_paths=indexed_caxes.target_paths, + index_exprs=indexed_caxes.index_exprs, + target_paths=indexed_caxes._default_target_paths(), + # index_exprs=indexed_caxes._default_index_exprs(), + layouts=indexed_caxes.layouts, + # target_paths=indexed_caxes.layout_axes.target_paths, + # index_exprs=indexed_caxes.layout_axes.index_exprs, + # layouts=indexed_caxes.layout_axes.layouts, + outer_loops=couter_loops, + dtype=IntType, + ) + + from pyop3.axtree.layout import my_product + + # so these are now failing BADLY because I have no real idea what + # I'm doing here... + # So the issue is that cmap is having values set in the wrong place + # when we are building a sparsity. + + for idxs in my_product(router_loops): + # I don't think that source_indices is currently required because + # we express layouts in terms of the LoopIndexVariable instead of + # LocalLoopIndexVariable (which we should fix). + source_indices = {idx.index.id: idx.source_exprs for idx in idxs} + target_indices = {idx.index.id: idx.target_exprs for idx in idxs} + for p in indexed_raxes.iter(idxs): + offset = self.raxes.offset( + p.target_exprs, p.target_path, loop_exprs=target_indices + ) + rmap.set_value( + # p.source_exprs, offset, p.source_path, loop_exprs=source_indices + p.source_exprs, + offset, + p.source_path, + loop_exprs=target_indices, + ) + + for idxs in my_product(couter_loops): + source_indices = {idx.index.id: idx.source_exprs for idx in idxs} + target_indices = {idx.index.id: idx.target_exprs for idx in idxs} + for p in indexed_caxes.iter(idxs): + offset = self.caxes.offset( + p.target_exprs, p.target_path, loop_exprs=target_indices + ) + cmap.set_value( + # p.source_exprs, offset, p.source_path, loop_exprs=source_indices + p.source_exprs, + offset, + p.source_path, + loop_exprs=target_indices, + ) + + shape = (indexed_raxes.size, indexed_caxes.size) + packed = PackedPetscMat(self, rmap, cmap, shape) + + # Since axes require unique labels, relabel the row and column axis trees + # with different suffixes. This allows us to create a combined axis tree + # without clashes. + raxes_relabel = relabel_axes(indexed_raxes, self._row_suffix) + caxes_relabel = relabel_axes(indexed_caxes, self._col_suffix) + + axes = PartialAxisTree(raxes_relabel.parent_to_children) + for leaf in raxes_relabel.leaves: + axes = axes.add_subtree(caxes_relabel, *leaf, uniquify_ids=True) + axes = axes.set_up() + + outer_loops = list(router_loops) + all_ids = [l.id for l in router_loops] + for ol in couter_loops: + if ol.id not in all_ids: + outer_loops.append(ol) + + my_target_paths = indexed_raxes.target_paths | indexed_caxes.target_paths + my_index_exprs = indexed_raxes.index_exprs | indexed_caxes.index_exprs + + arrays[ctx] = HierarchicalArray( + axes, data=packed, - target_paths=target_paths, - index_exprs=index_exprs, + target_paths=my_target_paths, + index_exprs=my_index_exprs, + # TODO ordered set? 
+ outer_loops=outer_loops, name=self.name, ) + return ContextSensitiveMultiArray(arrays) - return ContextSensitiveMultiArray(array_per_context) + @property + @abc.abstractmethod + def mat_type(self) -> str: + pass + + @staticmethod + def _make_mat(raxes, caxes, mat_type): + # TODO: Internal comm? + comm = single_valued([raxes.comm, caxes.comm]) + mat = PETSc.Mat().create(comm) + mat.setType(mat_type) + # None is for the global size, PETSc will determine it + mat.setSizes(((raxes.owned.size, None), (caxes.owned.size, None))) + + rlgmap = PETSc.LGMap().create(raxes.global_numbering(), comm=comm) + clgmap = PETSc.LGMap().create(caxes.global_numbering(), comm=comm) + mat.setLGMap(rlgmap, clgmap) + + return mat + + @cached_property + def datamap(self): + return freeze({self.name: self}) - # like Dat, bad name? handle? @property - def array(self): - return self.petscmat + def kernel_dtype(self): + raise NotImplementedError("opaque type?") -class PetscMatBAIJ(PetscMat): - ... +class PetscMatAIJ(MonolithicPetscMat): + def __init__(self, raxes, caxes, sparsity=None, *, name: str = None): + super().__init__(raxes, caxes, sparsity, name=name) + @property + def mat_type(self) -> str: + return PETSc.Mat.Type.AIJ -class PetscMatNest(PetscMat): - ... +# class PetscMatBAIJ(MonolithicPetscMat): +# ... -class PetscMatDense(PetscMat): - ... +class PetscMatPreallocator(MonolithicPetscMat): + def __init__(self, raxes, caxes, *, name: str = None): + super().__init__(raxes, caxes, name=name) + self._lazy_template = None -class PetscMatPython(PetscMat): - ... + @property + def mat_type(self) -> str: + return PETSc.Mat.Type.PREALLOCATOR + + def materialize(self, mat_type: str) -> PETSc.Mat: + if self._lazy_template is None: + self.assemble() + + template = self._make_mat(self.raxes, self.caxes, mat_type) + template.preallocateWithMatPreallocator(self.mat) + # We can safely set these options since by using a sparsity we + # are asserting that we know where the non-zeros are going. + template.setOption(PETSc.Mat.Option.NEW_NONZERO_LOCATION_ERR, True) + template.setOption(PETSc.Mat.Option.IGNORE_ZERO_ENTRIES, True) + self._lazy_template = template + return self._lazy_template.copy() + + +# class PetscMatDense(MonolithicPetscMat): +# ... + + +# class PetscMatNest(PetscMat): +# ... + + +# class PetscMatPython(PetscMat): +# ... 
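+# A rough usage sketch of the preallocator pattern defined above (the axis
+# trees `raxes`/`caxes` and the insertion step are placeholders, not API
+# from this diff):
+#
+#     prealloc = PetscMatPreallocator(raxes, caxes)
+#     # ... insert entries wherever the real matrix will have nonzeros ...
+#     mat = PetscMatAIJ(raxes, caxes, sparsity=prealloc)
+#
+# materialize() assembles the PREALLOCATOR mat on first use, caches a
+# template built with preallocateWithMatPreallocator, and returns a copy,
+# so several matrices can share a single sparsity pass.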
+ + +class PackedPetscMat(PackedBuffer): + def __init__(self, mat, rmap, cmap, shape): + super().__init__(mat) + self.rmap = rmap + self.cmap = cmap + self.shape = shape + + @property + def mat(self): + return self.array + + @cached_property + def datamap(self): + datamap_ = self.mat.datamap | self.rmap.datamap | self.cmap.datamap + for s in self.shape: + if isinstance(s, HierarchicalArray): + datamap_ |= s.datamap + return datamap_ diff --git a/pyop3/axtree/__init__.py b/pyop3/axtree/__init__.py index bf1f2e0c..9ed3fa2b 100644 --- a/pyop3/axtree/__init__.py +++ b/pyop3/axtree/__init__.py @@ -3,8 +3,10 @@ AxisComponent, AxisTree, AxisVariable, + ContextAware, ContextFree, ContextSensitive, LoopIterable, + PartialAxisTree, as_axis_tree, ) diff --git a/pyop3/axtree/layout.py b/pyop3/axtree/layout.py index 1fda7011..9a8249bd 100644 --- a/pyop3/axtree/layout.py +++ b/pyop3/axtree/layout.py @@ -1,7 +1,10 @@ from __future__ import annotations +import collections import functools +import itertools import numbers +import operator import sys from collections import defaultdict from typing import Optional @@ -10,10 +13,27 @@ import pymbolic as pym from pyrsistent import freeze, pmap -from pyop3.axtree.tree import Axis, AxisComponent, AxisTree +from pyop3.axtree.tree import ( + Axis, + AxisComponent, + AxisTree, + ExpressionEvaluator, + PartialAxisTree, + UnrecognisedAxisException, + component_number_from_offsets, + component_offsets, +) from pyop3.dtypes import IntType, PointerType from pyop3.tree import LabelledTree, MultiComponentLabelledNode -from pyop3.utils import PrettyTuple, merge_dicts, strict_int, strictly_all +from pyop3.utils import ( + PrettyTuple, + as_tuple, + checked_zip, + just_one, + merge_dicts, + strict_int, + strictly_all, +) # hacky class for index_exprs to work, needs cleaning up @@ -91,25 +111,77 @@ def has_constant_start( return isinstance(component.count, numbers.Integral) or outer_axes_are_all_indexed -def has_fixed_size(axes, axis, component): - return not size_requires_external_index(axes, axis, component) +def has_constant_step(axes: AxisTree, axis, cpt, inner_loop_vars, path=pmap()): + # we have a constant step if none of the internal dimensions need to index themselves + # with the current index (numbering doesn't matter here) + if subaxis := axes.child(axis, cpt): + return all( + # not size_requires_external_index(axes, subaxis, c, path | {axis.label: cpt.label}) + not size_requires_external_index(axes, subaxis, c, inner_loop_vars, path) + for c in subaxis.components + ) + else: + return True + + +def has_fixed_size(axes, axis, component, inner_loop_vars): + return not size_requires_external_index(axes, axis, component, inner_loop_vars) + + +def requires_external_index(axtree, axis, component_index): + """Return `True` if more indices are required to index the multi-axis layouts + than exist in the given subaxis. + """ + return size_requires_external_index( + axtree, axis, component_index + ) # or numbering_requires_external_index(axtree, axis, component_index) + + +def size_requires_external_index(axes, axis, component, inner_loop_vars, path=pmap()): + from pyop3.array import HierarchicalArray + + count = component.count + if isinstance(count, HierarchicalArray): + if count.axes.is_empty: + leafpath = pmap() + else: + leafpath = just_one(count.axes.leaf_paths) + layout = count.subst_layouts[leafpath] + required_loop_vars = LoopIndexCollector(linear=False)(layout) + if not required_loop_vars.issubset(inner_loop_vars): + return True + # is the path sufficient? i.e. 
do we have enough externally provided indices + # to correctly index the axis? + if not count.axes.is_empty: + for axlabel, clabel in count.axes.path(*count.axes.leaf).items(): + if axlabel in path: + assert path[axlabel] == clabel + else: + return True + + if subaxis := axes.child(axis, component): + for c in subaxis.components: + # path_ = path | {subaxis.label: c.label} + path_ = path | {axis.label: component.label} + if size_requires_external_index(axes, subaxis, c, inner_loop_vars, path_): + return True + return False def step_size( axes: AxisTree, axis: Axis, component: AxisComponent, - path=pmap(), indices=PrettyTuple(), + *, + loop_indices=pmap(), ): """Return the size of step required to stride over a multi-axis component. Non-constant strides will raise an exception. """ - if not has_constant_step(axes, axis, component) and not indices: - raise ValueError - if subaxis := axes.component_child(axis, component): - return _axis_size(axes, subaxis, path, indices) + if subaxis := axes.child(axis, component): + return _axis_size(axes, subaxis, indices, loop_indices=loop_indices) else: return 1 @@ -126,85 +198,198 @@ def has_halo(axes, axis): return axis.sf is not None or has_halo(axes, subaxis) -def requires_external_index(axtree, axis, component_index): - """Return ``True`` if more indices are required to index the multi-axis layouts - than exist in the given subaxis. - """ - return size_requires_external_index( - axtree, axis, component_index - ) # or numbering_requires_external_index(axtree, axis, component_index) +# NOTE: I am not sure that this is really required any more. We just want to +# check for loop indices in any index_exprs +# No, we need this because loop indices do not necessarily mean we need extra shape. +def collect_externally_indexed_axes(axes, axis=None, component=None, path=pmap()): + assert False, "old code" + from pyop3.array import HierarchicalArray + if axes.is_empty: + return () + + # use a dict as an ordered set + if axis is None: + assert component is None + + external_axes = {} + for component in axes.root.components: + external_axes.update( + { + # NOTE: no longer axes + ax.id: ax + for ax in collect_externally_indexed_axes( + axes, axes.root, component + ) + } + ) + return tuple(external_axes.values()) -def size_requires_external_index(axes, axis, component, path=pmap()): - count = component.count - if not component.has_integer_count: + external_axes = {} + csize = component.count + if isinstance(csize, HierarchicalArray): # is the path sufficient? i.e. do we have enough externally provided indices # to correctly index the axis? 
- if count.axes.is_empty: - return False - for axlabel, clabel in count.axes.path(*count.axes.leaf).items(): - if axlabel in path: - assert path[axlabel] == clabel - else: - return True + loop_indices = collect_external_loops(csize.axes, csize.index_exprs) + for index in sorted(loop_indices, key=lambda i: i.id): + external_axes[index.id] = index else: - if subaxis := axes.component_child(axis, component): - for c in subaxis.components: - # path_ = path | {subaxis.label: c.label} - path_ = path | {axis.label: component.label} - if size_requires_external_index(axes, subaxis, c, path_): - return True - return False + assert isinstance(csize, numbers.Integral) + + path_ = path | {axis.label: component.label} + if subaxis := axes.child(axis, component): + for subcpt in subaxis.components: + external_axes.update( + { + # NOTE: no longer axes + ax.id: ax + for ax in collect_externally_indexed_axes( + axes, subaxis, subcpt, path_ + ) + } + ) + return tuple(external_axes.values()) -def has_constant_step(axes: AxisTree, axis, cpt): - # we have a constant step if none of the internal dimensions need to index themselves - # with the current index (numbering doesn't matter here) - if subaxis := axes.child(axis, cpt): - return all( - # not size_requires_external_index(axes, subaxis, c, freeze({subaxis.label: c.label})) - not size_requires_external_index(axes, subaxis, c) - for c in subaxis.components + +class LoopIndexCollector(pym.mapper.CombineMapper): + def __init__(self, linear: bool): + super().__init__() + self._linear = linear + + def combine(self, values): + if self._linear: + return sum(values, start=()) + else: + return functools.reduce(operator.or_, values, frozenset()) + + def map_algebraic_leaf(self, expr): + return () if self._linear else frozenset() + + def map_constant(self, expr): + return () if self._linear else frozenset() + + def map_loop_index(self, index): + rec = collect_external_loops( + index.index.iterset, index.index.iterset.index_exprs, linear=self._linear ) - else: - return True + if self._linear: + return rec + (index,) + else: + return rec | {index} + + def map_multi_array(self, array): + if self._linear: + return tuple( + item for expr in array.index_exprs.values() for item in self.rec(expr) + ) + else: + return frozenset( + {item for expr in array.index_exprs.values() for item in self.rec(expr)} + ) + + # def map_called_map_variable(self, index): + # result = ( + # idx + # for index_expr in index.input_index_exprs.values() + # for idx in self.rec(index_expr) + # ) + # return tuple(result) if self._linear else frozenset(result) + + +def collect_external_loops(axes, index_exprs, linear=False): + collector = LoopIndexCollector(linear) + keys = [None] + if not axes.is_empty: + nodes = ( + axes.path_with_nodes(*axes.leaf, and_components=True, ordered=True) + if linear + else tuple((ax, cpt) for ax in axes.nodes for cpt in ax.components) + ) + keys.extend((ax.id, cpt.label) for ax, cpt in nodes) + result = ( + loop + for key in keys + for expr in index_exprs.get(key, {}).values() + for loop in collector(expr) + ) + return tuple(result) if linear else frozenset(result) -# use this to build a tree of sizes that we use to construct -# the right count arrays -class CustomNode(MultiComponentLabelledNode): - fields = MultiComponentLabelledNode.fields | {"counts"} +def _collect_inner_loop_vars(axes: AxisTree, axis: Axis, loop_vars): + # Terminate eagerly because axes representing loops must be outermost. 
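+    # As a small worked example (hypothetical labels): with
+    # loop_vars = {"ax0": i0, "ax1": i1} and a single-component chain
+    # ax0 -> ax1 -> ax2, starting from ax0 collects {i0, i1}; the
+    # recursion stops at ax2 because it does not represent a loop.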
+ if axis.label not in loop_vars: + return frozenset() - def __init__(self, counts, **kwargs): - super().__init__(counts, **kwargs) - self.counts = tuple(counts) + loop_var = loop_vars[axis.label] + # Axes representing loops must be single-component. + if subaxis := axes.child(axis, axis.component): + return _collect_inner_loop_vars(axes, subaxis, loop_vars) | {loop_var} + else: + return frozenset({loop_var}) +# TODO: If an axis has size 1 then we don't need a variable for it. def _compute_layouts( axes: AxisTree, + loop_vars, axis=None, - path=pmap(), + layout_path=pmap(), + index_exprs_acc=pmap(), ): - axis = axis or axes.root + """ + Parameters + ---------- + axes + The axis tree to construct a layout for. + loop_vars + Mapping from axis label to loop index variable. Needed for tabulating + indexed layouts because, as we go up the tree, we can identify which + loop indices are materialised. + """ + + from pyop3.array.harray import ArrayVar + + if axis is None: + assert not axes.is_empty + axis = axes.root + # get rid of this + index_exprs_acc |= axes.index_exprs.get(None, {}) + + # Collect the loop variables that are captured by this axis and those below + # it. This lets us determine whether or not something that is indexed is + # sufficiently "within" loops for us to tabulate. + if len(axis.components) == 1 and (subaxis := axes.child(axis, axis.component)): + inner_loop_vars = _collect_inner_loop_vars(axes, subaxis, loop_vars) + else: + inner_loop_vars = frozenset() + inner_loop_vars_with_self = _collect_inner_loop_vars(axes, axis, loop_vars) + layouts = {} steps = {} # Post-order traversal - # make sure to catch children that are None - csubroots = [] csubtrees = [] sublayoutss = [] for cpt in axis.components: - if subaxis := axes.component_child(axis, cpt): - sublayouts, csubroot, csubtree, substeps = _compute_layouts( - axes, subaxis, path | {axis.label: cpt.label} + index_exprs_acc_ = index_exprs_acc | axes.index_exprs.get( + (axis.id, cpt.label), {} + ) + + layout_path_ = layout_path | {axis.label: cpt.label} + + if subaxis := axes.child(axis, cpt): + ( + sublayouts, + csubtree, + substeps, + ) = _compute_layouts( + axes, loop_vars, subaxis, layout_path_, index_exprs_acc_ ) sublayoutss.append(sublayouts) - csubroots.append(csubroot) csubtrees.append(csubtree) steps.update(substeps) else: - csubroots.append(None) csubtrees.append(None) sublayoutss.append(defaultdict(list)) @@ -237,41 +422,43 @@ def _compute_layouts( # 1. do we need to pass further up? i.e. are we variable size? # also if we have halo data then we need to pass to the top - if (not all(has_fixed_size(axes, axis, cpt) for cpt in axis.components)) or ( - has_halo(axes, axis) and axis != axes.root - ): + if ( + not all( + has_fixed_size(axes, axis, cpt, inner_loop_vars_with_self) + for cpt in axis.components + ) + ) or (has_halo(axes, axis) and axis != axes.root): if has_halo(axes, axis) or not all( - has_constant_step(axes, axis, c) for c in axis.components + has_constant_step(axes, axis, c, inner_loop_vars) + for i, c in enumerate(axis.components) ): - croot = CustomNode( - [(cpt.count, axis.label, cpt.label) for cpt in axis.components] - ) + ctree = PartialAxisTree(axis.copy(numbering=None)) + + # we enforce here that all subaxes must be tabulated, is this always + # needed? 
if strictly_all(sub is not None for sub in csubtrees): - cparent_to_children = pmap( - {croot.id: [sub for sub in csubroots]} - ) | merge_dicts(sub for sub in csubtrees) - else: - cparent_to_children = {} - ctree = cparent_to_children + for component, subtree in checked_zip(axis.components, csubtrees): + ctree = ctree.add_subtree(subtree, axis, component) else: # we must be at the bottom of a ragged patch - therefore don't # add to shape of things # in theory if we are ragged and permuted then we do want to include this level - croot = None ctree = None - for c in axis.components: + for i, c in enumerate(axis.components): step = step_size(axes, axis, c) - layouts.update( - { - path - # | {axis.label: c.label}: AffineLayout(axis.label, c.label, step) - | {axis.label: c.label}: AxisVariable(axis.label) * step - } - ) + if (axis.id, c.label) in loop_vars: + axis_var = loop_vars[axis.id, c.label][axis.label] + else: + axis_var = AxisVariable(axis.label) + layouts.update({layout_path | {axis.label: c.label}: axis_var * step}) # layouts and steps are just propagated from below layouts.update(merge_dicts(sublayoutss)) - return layouts, croot, ctree, steps + return ( + layouts, + ctree, + steps, + ) # 2. add layouts here else: @@ -279,103 +466,170 @@ def _compute_layouts( interleaved = len(axis.components) > 1 and axis.numbering is not None if ( interleaved - or not all(has_constant_step(axes, axis, c) for c in axis.components) + or not all( + has_constant_step(axes, axis, c, inner_loop_vars) + for i, c in enumerate(axis.components) + ) or has_halo(axes, axis) - and axis == axes.root + and axis == axes.root # at the top ): - # super ick - bits = [] - for cpt in axis.components: - axlabel, clabel = axis.label, cpt.label - bits.append((cpt.count, axlabel, clabel)) - croot = CustomNode(bits) + ctree = PartialAxisTree(axis.copy(numbering=None)) + # we enforce here that all subaxes must be tabulated, is this always + # needed? 
             if strictly_all(sub is not None for sub in csubtrees):
+                for component, subtree in checked_zip(
+                    axis.components, csubtrees
+                ):
+                    ctree = ctree.add_subtree(subtree, axis, component)
-                cparent_to_children = pmap(
-                    {croot.id: [sub for sub in csubroots]}
-                ) | merge_dicts(sub for sub in csubtrees)
-            else:
-                cparent_to_children = {}
+
-            cparent_to_children |= {None: (croot,)}
-            ctree = LabelledTree(cparent_to_children)
-
-            fulltree = _create_count_array_tree(ctree)
+            fulltree = _create_count_array_tree(ctree, loop_vars)

             # now populate fulltree
             offset = IntRef(0)
-            _tabulate_count_array_tree(axes, axis, fulltree, offset, setting_halo=False)
+            _tabulate_count_array_tree(
+                axes,
+                axis,
+                loop_vars,
+                fulltree,
+                offset,
+                setting_halo=False,
+            )

             # apply ghost offset stuff, the offset from the previous pass is used
-            _tabulate_count_array_tree(axes, axis, fulltree, offset, setting_halo=True)
+            _tabulate_count_array_tree(
+                axes,
+                axis,
+                loop_vars,
+                fulltree,
+                offset,
+                setting_halo=True,
+            )

             for subpath, offset_data in fulltree.items():
-                layouts[path | subpath] = offset_data.as_var()
+                # offset_data must be linear so we can unroll the indices
+                # flat_indices = {
+                #     ax: expr
+                # }
+                source_path = offset_data.axes.path_with_nodes(*offset_data.axes.leaf)
+                index_keys = [None] + [
+                    (axis.id, cpt) for axis, cpt in source_path.items()
+                ]
+                myindices = merge_dicts(
+                    offset_data.index_exprs.get(key, {}) for key in index_keys
+                )
+                offset_var = ArrayVar(offset_data, myindices)
+
+                layouts[layout_path | subpath] = offset_var
             ctree = None
-            steps = {path: _axis_size(axes, axis)}
+
+            # bit of a hack, we can skip this if we aren't passing higher up
+            if axis == axes.root:
+                steps = "not used"
+            else:
+                steps = {layout_path: _axis_size(axes, axis)}

             layouts.update(merge_dicts(sublayoutss))
-            return layouts, None, ctree, steps
+            return (
+                layouts,
+                ctree,
+                steps,
+            )

         # must therefore be affine
         else:
             assert all(sub is None for sub in csubtrees)
-            ctree = None
             layouts = {}
-            steps = [step_size(axes, axis, c) for c in axis.components]
+            steps = [
+                # step_size(axes, axis, c, index_exprs_acc_)
+                step_size(axes, axis, c)
+                # step_size(axes, axis, c, subloops[i])
+                for i, c in enumerate(axis.components)
+            ]
             start = 0
             for cidx, step in enumerate(steps):
                 mycomponent = axis.components[cidx]
                 sublayouts = sublayoutss[cidx].copy()
-                new_layout = AxisVariable(axis.label) * step + start
-                sublayouts[path | {axis.label: mycomponent.label}] = new_layout
+                # key = (axis.id, mycomponent.label)
+                # axis_var = index_exprs[key][axis.label]
+                axis_var = AxisVariable(axis.label)
+                # axis_var = axes.index_exprs[key][axis.label]
+                # if key in index_exprs:
+                #     axis_var = index_exprs[key][axis.label]
+                # else:
+                #     axis_var = AxisVariable(axis.label)
+                new_layout = axis_var * step + start
+
+                sublayouts[layout_path | {axis.label: mycomponent.label}] = new_layout
                 start += _axis_component_size(axes, axis, mycomponent)

             layouts.update(sublayouts)
-            steps = {path: _axis_size(axes, axis)}
-            return layouts, None, None, steps
+            steps = {layout_path: _axis_size(axes, axis)}
+            return (
+                layouts,
+                None,
+                steps,
+            )

-# I don't think that this actually needs to be a tree, just return a dict
-# TODO I need to clean this up a lot now I'm using component labels
 def _create_count_array_tree(
-    ctree, current_node=None, counts=PrettyTuple(), path=pmap()
+    ctree,
+    loop_vars,
+    axis=None,
+    axes_acc=None,
+    path=pmap(),
 ):
     from pyop3.array import HierarchicalArray
+    from pyop3.itree.tree import IndexExpressionReplacer
-
current_node = current_node or ctree.root - arrays = {} + if strictly_all(x is None for x in [axis, axes_acc]): + axis = ctree.root + axes_acc = () - for cidx in range(current_node.degree): - count, axis_label, cpt_label = current_node.counts[cidx] + arrays = {} + for component in axis.components: + path_ = path | {axis.label: component.label} + linear_axis = axis[component.label].root + axes_acc_ = axes_acc + (linear_axis,) - child = ctree.children(current_node)[cidx] - new_path = path | {axis_label: cpt_label} - if child is None: - # make a multiarray here from the given sizes - axes = [ - Axis([(ct, clabel)], axlabel) - for (ct, axlabel, clabel) in counts | current_node.counts[cidx] - ] - root = axes[0] - parent_to_children = {None: (root,)} - for parent, child in zip(axes, axes[1:]): - parent_to_children[parent.id] = (child,) - axtree = AxisTree.from_node_map(parent_to_children) - countarray = HierarchicalArray( - axtree, - data=np.full(axis_tree_size(axtree), -1, dtype=IntType), - ) - arrays[new_path] = countarray - else: + if subaxis := ctree.child(axis, component): arrays.update( _create_count_array_tree( ctree, - child, - counts | current_node.counts[cidx], - new_path, + loop_vars, + subaxis, + axes_acc_, + path_, ) ) + else: + # make a multiarray here from the given sizes + + # do we have any external axes from loop indices? + axtree = AxisTree.from_iterable(axes_acc_) + + if loop_vars: + index_exprs = {} + for myaxis in axes_acc_: + key = (myaxis.id, myaxis.component.label) + if myaxis.label in loop_vars: + loop_var = loop_vars[myaxis.label] + index_expr = {myaxis.label: loop_var} + else: + index_expr = {myaxis.label: AxisVariable(myaxis.label)} + index_exprs[key] = index_expr + else: + index_exprs = axtree._default_index_exprs() + + countarray = HierarchicalArray( + axtree, + target_paths=axtree._default_target_paths(), + index_exprs=index_exprs, + outer_loops=(), # ??? + data=np.full(axtree.global_size, -1, dtype=IntType), + prefix="offset", + ) + arrays[path_] = countarray return arrays @@ -383,100 +637,100 @@ def _create_count_array_tree( def _tabulate_count_array_tree( axes, axis, + loop_vars, count_arrays, offset, - path=pmap(), + path=pmap(), # might not be needed indices=pmap(), is_owned=True, setting_halo=False, + outermost=True, + loop_indices=pmap(), # much nicer to combine into indices? ): - npoints = sum(_as_int(c.count, path, indices) for c in axis.components) + npoints = sum(_as_int(c.count, indices) for c in axis.components) - point_to_component_id = np.empty(npoints, dtype=np.int8) - point_to_component_num = np.empty(npoints, dtype=PointerType) - *strata_offsets, _ = [0] + list( - np.cumsum([_as_int(c.count, path, indices) for c in axis.components]) - ) - pos = 0 - point = 0 - # TODO this is overkill, we can just inspect the ranges? - for cidx, component in enumerate(axis.components): - # can determine this once above - csize = _as_int(component.count, path, indices) - for i in range(csize): - point_to_component_id[point] = cidx - # this is now just the identity with an offset? 
- point_to_component_num[point] = i - point += 1 - pos += csize - - counters = np.zeros(len(axis.components), dtype=int) + offsets = component_offsets(axis, indices) points = axis.numbering.data_ro if axis.numbering is not None else range(npoints) + + counters = {c: itertools.count() for c in axis.components} for new_pt, old_pt in enumerate(points): if axis.sf is not None: - # more efficient outside of loop - _, ilocal, _ = axis.sf._graph - is_owned = new_pt < npoints - len(ilocal) - - # equivalent to plex strata - selected_component_id = point_to_component_id[old_pt] - # selected_component_num = point_to_component_num[old_pt] - selected_component_num = old_pt - strata_offsets[selected_component_id] - selected_component = axis.components[selected_component_id] - - new_strata_pt = counters[selected_component_id] - counters[selected_component_id] += 1 - - new_path = path | {axis.label: selected_component.label} - new_indices = indices | {axis.label: new_strata_pt} - if new_path in count_arrays: + is_owned = new_pt < axis.sf.nowned + + component, _ = component_number_from_offsets(axis, old_pt, offsets) + + new_strata_pt = next(counters[component]) + + path_ = path | {axis.label: component.label} + + if axis.label in loop_vars: + loop_var = loop_vars[axis.label] + loop_indices_ = loop_indices | {loop_var.id: {loop_var.axis: new_strata_pt}} + indices_ = indices + else: + loop_indices_ = loop_indices + indices_ = indices | {axis.label: new_strata_pt} + + if path_ in count_arrays: if is_owned and not setting_halo or not is_owned and setting_halo: - count_arrays[new_path].set_value(new_path, new_indices, offset.value) + count_arrays[path_].set_value( + indices_, offset.value, loop_exprs=loop_indices_ + ) offset += step_size( axes, axis, - selected_component, - new_path, - new_indices, + component, + indices=indices_, + loop_indices=loop_indices_, ) else: - subaxis = axes.component_child(axis, selected_component) + subaxis = axes.component_child(axis, component) assert subaxis _tabulate_count_array_tree( axes, subaxis, + loop_vars, count_arrays, offset, - new_path, - new_indices, + path_, + indices_, is_owned=is_owned, setting_halo=setting_halo, + outermost=False, + loop_indices=loop_indices_, ) # TODO this whole function sucks, should accumulate earlier def _collect_at_leaves( axes, + layout_axes, values, axis: Optional[Axis] = None, path=pmap(), + layout_path=pmap(), prior=0, ): - axis = axis or axes.root - acc = {} + if axis is None: + axis = layout_axes.root - for cpt in axis.components: - new_path = path | {axis.label: cpt.label} - if new_path in values: - # prior_ = prior | {axis.label: values[new_path]} - prior_ = prior + values[new_path] - else: - prior_ = prior - if subaxis := axes.component_child(axis, cpt): - acc.update(_collect_at_leaves(axes, values, subaxis, new_path, prior_)) + acc = {pmap(): prior} if axis == axes.root else {} + for component in axis.components: + layout_path_ = layout_path | {axis.label: component.label} + prior_ = prior + values.get(layout_path_, 0) + + if axis in axes.nodes: + path_ = path | {axis.label: component.label} + acc[path_] = prior_ else: - acc[new_path] = prior_ + path_ = path + if subaxis := layout_axes.child(axis, component): + acc.update( + _collect_at_leaves( + axes, layout_axes, values, subaxis, path_, layout_path_, prior_ + ) + ) return acc @@ -487,19 +741,131 @@ def axis_tree_size(axes: AxisTree) -> int: example, an array with shape ``(10, 3)`` will have a size of 30. 
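+
+    If any axis in the tree has an extent that depends on an outer loop index
+    then a single size is not well-defined; in that case the sizes are
+    tabulated per outer-loop iteration and an array of sizes is returned
+    instead.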
""" + # outer_loops = collect_external_loops(axes, axes.index_exprs) + outer_loops = axes.outer_loops + + # loop_exprs = {} + # for ol in outer_loops: + # assert not ol.iterset.index_exprs.get(None, {}), "not sure what to do here" + # + # loop_exprs[ol.id] = {} + # for axis in ol.iterset.nodes: + # key = (axis.id, axis.component.label) + # for ax, expr in ol.iterset.index_exprs.get(key, {}).items(): + # loop_exprs[ol.id][ax] = expr + + # external_axes = collect_externally_indexed_axes(axes) + # if len(external_axes) == 0: if axes.is_empty: return 1 - return _axis_size(axes, axes.root, pmap(), pmap()) + + if all( + has_fixed_size(axes, axes.root, cpt, outer_loops) + for cpt in axes.root.components + ): + # if not outer_loops: + # return _axis_size(axes, axes.root, loop_exprs=loop_exprs) + return _axis_size(axes, axes.root) + + # axis size is now an array + + # axes_iter = [] + # index_exprs = {} + # outer_loop_map = {} + # for ol in outer_loops_ord: + # iterset = ol.index.iterset + # for axis in iterset.path_with_nodes(*iterset.leaf): + # axis_ = axis.copy(id=Axis.unique_id(), label=Axis.unique_label()) + # # axis_ = axis + # axes_iter.append(axis_) + # index_exprs[axis_.id, axis_.component.label] = {axis.label: ol} + # outer_loop_map[axis_] = ol + # size_axes = PartialAxisTree.from_iterable(axes_iter) + # + # # hack + # target_paths = AxisTree(size_axes.parent_to_children)._default_target_paths() + # layout_exprs = {} + # + # size_axes = AxisTree(size_axes.parent_to_children, target_paths=target_paths, index_exprs=index_exprs, outer_loops=outer_loops_ord[:-1], layout_exprs=layout_exprs) + # + # sizes = HierarchicalArray( + # size_axes, + # target_paths=target_paths, + # index_exprs=index_exprs, + # # outer_loops=frozenset(), # only temporaries need this + # # outer_loops=axes.outer_loops, # causes infinite recursion + # outer_loops=outer_loops_ord[:-1], + # dtype=IntType, + # prefix="size", + # ) + # sizes = HierarchicalArray(AxisTree(), target_paths={}, index_exprs={}, outer_loops=outer_loops_ord[:-1]) + # sizes = HierarchicalArray(AxisTree(outer_loops=outer_loops_ord), target_paths={}, index_exprs={}, outer_loops=outer_loops_ord) + # sizes = HierarchicalArray(axes) + sizes = [] + + # for idxs in itertools.product(*outer_loops_iter): + for idxs in my_product(outer_loops): + print(idxs) + # for idx in size_axes.iter(): + # idxs = [idx] + source_indices = merge_dicts(idx.source_exprs for idx in idxs) + target_indices = merge_dicts(idx.target_exprs for idx in idxs) + + # indices = {} + # target_indices = {} + # # myindices = {} + # for axis in size_axes.nodes: + # loop_var = outer_loop_map[axis] + # idx = just_one(idx for idx in idxs if idx.index == loop_var.index) + # # myindices[axis.label] = just_one(sum(idx.source_exprs.values())) + # + # axlabel = just_one(idx.index.iterset.nodes).label + # value = just_one(idx.target_exprs.values()) + # indices[loop_var.index.id] = {axlabel: value} + + # target_indices[just_one(idx.target_path.keys())] = just_one(idx.target_exprs.values()) + + # this is a hack + if axes.is_empty: + size = 1 + else: + size = _axis_size(axes, axes.root, target_indices) + # sizes.set_value(source_indices, size) + sizes.append(size) + # return sizes + return np.asarray(sizes, dtype=IntType) + + +def my_product(loops): + if len(loops) > 1: + raise NotImplementedError( + "Now we are nesting loops so having multiple is a " + "headache I haven't yet tackled" + ) + # loop, *inner_loops = loops + (loop,) = loops + + if loop.iterset.outer_loops: + for indices in 
my_product(loop.iterset.outer_loops): + context = frozenset(indices) + for index in loop.iter(context): + indices_ = indices + (index,) + yield indices_ + else: + for index in loop.iter(): + yield (index,) def _axis_size( axes: AxisTree, axis: Axis, - path=pmap(), indices=pmap(), -) -> int: + *, + loop_indices=pmap(), +): return sum( - _axis_component_size(axes, axis, cpt, path, indices) for cpt in axis.components + _axis_component_size(axes, axis, cpt, indices, loop_indices=loop_indices) + for cpt in axis.components ) @@ -507,17 +873,18 @@ def _axis_component_size( axes: AxisTree, axis: Axis, component: AxisComponent, - path=pmap(), indices=pmap(), + *, + loop_indices=pmap(), ): - count = _as_int(component.count, path, indices) + count = _as_int(component.count, indices, loop_indices=loop_indices) if subaxis := axes.component_child(axis, component): return sum( _axis_size( axes, subaxis, - path | {axis.label: component.label}, indices | {axis.label: i}, + loop_indices=loop_indices, ) for i in range(count) ) @@ -526,39 +893,131 @@ def _axis_component_size( @functools.singledispatch -def _as_int(arg: Any, path, indices): +def _as_int(arg: Any, indices, path=None, *, loop_indices=pmap()): from pyop3.array import HierarchicalArray if isinstance(arg, HierarchicalArray): + # this shouldn't be here, but it will break things the least to do so + # at the moment + # if index_exprs is None: + # index_exprs = merge_dicts(arg.index_exprs.values()) + # TODO this might break if we have something like [:, subset] # I will need to map the "source" axis (e.g. slice_label0) back # to the "target" axis - return arg.get_value(path, indices, allow_unused=True) + # return arg.get_value(indices, target_path, index_exprs) + return arg.get_value(indices, path, loop_exprs=loop_indices) else: raise TypeError @_as_int.register -def _(arg: numbers.Real, path, indices): +def _(arg: numbers.Real, *args, **kwargs): return strict_int(arg) -def collect_sizes(axes: AxisTree) -> pmap: # TODO value-type of returned pmap? 
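+
+
+# Illustrative note (hypothetical shapes, not part of the API): for a ragged
+# tree Axis(3, "a") -> Axis(counts, "b") with counts == [2, 0, 4], the size
+# recursion above evaluates
+#     sum(_as_int(counts, {"a": i}) for i in range(3)) == 2 + 0 + 4 == 6,
+# i.e. the inner extent is looked up afresh at every value of the outer index.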
+class LoopExpressionReplacer(pym.mapper.IdentityMapper):
+    def __init__(self, loop_exprs):
+        self._loop_exprs = loop_exprs
+
+    def map_multi_array(self, array):
+        index_exprs = {ax: self.rec(expr) for ax, expr in array.index_exprs.items()}
+        return type(array)(array.array, array.target_path, index_exprs)

-def _collect_sizes_rec(axes, axis) -> pmap:
-    sizes = {}
-    for cpt in axis.components:
-        sizes[axis.label, cpt.label] = cpt.count
-
-        if subaxis := axes.component_child(axis, cpt):
-            subsizes = _collect_sizes_rec(axes, subaxis)
-            for loc, size in subsizes.items():
-                # make sure that sizes always match for duplicates
-                if loc not in sizes:
-                    sizes[loc] = size
-                else:
-                    if sizes[loc] != size:
-                        raise RuntimeError
-    return pmap(sizes)
+    def map_loop_index(self, index):
+        return self._loop_exprs[index.id][index.axis]
+
+
+def eval_offset(
+    axes,
+    layouts,
+    indices,
+    path=None,
+    *,
+    loop_exprs=pmap(),
+):
+    if path is None:
+        path = pmap() if axes.is_empty else just_one(axes.leaf_paths)
+
+    # If the provided indices are not a mapping then assume that they apply
+    # in order as we traverse the selected path of the tree.
+    if not isinstance(indices, collections.abc.Mapping):
+        # a single index is treated like a 1-tuple
+        indices = as_tuple(indices)
+
+        indices_ = {}
+        ordered_path = iter(just_one(axes.ordered_leaf_paths))
+        for index in indices:
+            axis_label, _ = next(ordered_path)
+            indices_[axis_label] = index
+        indices = indices_
+
+    layout_subst = layouts[freeze(path)]
+    offset = ExpressionEvaluator(indices, loop_exprs)(layout_subst)
+    return strict_int(offset)
diff --git a/pyop3/axtree/parallel.py b/pyop3/axtree/parallel.py
index f4bce59a..65399266 100644
--- a/pyop3/axtree/parallel.py
+++ b/pyop3/axtree/parallel.py
@@ -49,30 +49,23 @@ def partition_ghost_points(axis, sf):
     return numbering


-# stolen from stackoverflow
-# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy
-def invert(p):
-    """Return an array s with which np.array_equal(arr[p][s], arr) is True.
-    The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
-    """
-    p = np.asanyarray(p)  # in case p is a tuple, etc.
-    s = np.empty_like(p)
-    s[p] = np.arange(p.size)
-    return s
-
-
 def collect_sf_graphs(axes, axis=None, path=pmap(), indices=pmap()):
+    # it does not make sense for temporary-like objects to have SFs
+    if axes.outer_loops:
+        return ()
+
     # NOTE: This function does not check for nested SFs (which should error)
-    axis = axis or axes.root
+    if axis is None:
+        axis = axes.root

     if axis.sf is not None:
         return (grow_dof_sf(axes, axis, path, indices),)
     else:
         graphs = []
         for component in axis.components:
-            subaxis = axes.child(axis, component)
-            if subaxis is not None:
-                for pt in range(_as_int(component.count, path, indices)):
+            if subaxis := axes.child(axis, component):
+                # think path is not needed
+                for pt in range(_as_int(component.count, indices, path)):
                     graphs.extend(
                         collect_sf_graphs(
                             axes,
@@ -85,6 +78,7 @@ def collect_sf_graphs(axes, axis=None, path=pmap(), indices=pmap()):


 # perhaps I can defer renumbering the SF to here?
+# PETSc provides a similar function that composes an SF with a Section, can I use that?
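+# (In the C API this is PetscSFCreateSectionSF: given the point SF plus
+# per-point dof offsets/counts it emits a dof-level SF, e.g. a ghost point
+# whose root offset is 5 with 2 dofs contributes remote dofs 5 and 6. Whether
+# petsc4py exposes it would need checking, so this stays a note for now.)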
def grow_dof_sf(axes, axis, path, indices): point_sf = axis.sf # TODO, use convenience methods @@ -95,19 +89,25 @@ def grow_dof_sf(axes, axis, path, indices): npoints = component_offsets[-1] # renumbering per component, can skip if no renumbering present - renumbering = [np.empty(c.count, dtype=int) for c in axis.components] - counters = [0] * len(axis.components) - for new_pt, old_pt in enumerate(axis.numbering.data_ro): - for cidx, (min_, max_) in enumerate( - zip(component_offsets, component_offsets[1:]) - ): - if min_ <= old_pt < max_: - renumbering[cidx][old_pt - min_] = counters[cidx] - counters[cidx] += 1 - break - assert all(count == c.count for count, c in checked_zip(counters, axis.components)) + if axis.numbering is not None: + renumbering = [np.empty(c.count, dtype=int) for c in axis.components] + counters = [0] * len(axis.components) + for new_pt, old_pt in enumerate(axis.numbering.data_ro): + for cidx, (min_, max_) in enumerate( + zip(component_offsets, component_offsets[1:]) + ): + if min_ <= old_pt < max_: + renumbering[cidx][old_pt - min_] = counters[cidx] + counters[cidx] += 1 + break + assert all( + count == c.count for count, c in checked_zip(counters, axis.components) + ) + else: + renumbering = [np.arange(c.count, dtype=int) for c in axis.components] # effectively build the section + new_nroots = 0 root_offsets = np.full(npoints, -1, IntType) for pt in point_sf.iroot: # convert to a component-wise numbering @@ -124,11 +124,17 @@ def grow_dof_sf(axes, axis, path, indices): assert component_num is not None offset = axes.offset( - path | {axis.label: selected_component.label}, indices | {axis.label: component_num}, - insert_zeros=True, + path | {axis.label: selected_component.label}, ) root_offsets[pt] = offset + new_nroots += step_size( + axes, + axis, + selected_component, + (), + indices | {axis.label: component_num}, + ) point_sf.broadcast(root_offsets, MPI.REPLACE) @@ -151,13 +157,13 @@ def grow_dof_sf(axes, axis, path, indices): assert selected_component is not None assert component_num is not None + # this is wrong? 
offset = axes.offset( - path | {axis.label: selected_component.label}, indices | {axis.label: component_num}, - insert_zeros=True, + path | {axis.label: selected_component.label}, ) local_leaf_offsets[myindex] = offset - leaf_ndofs[myindex] = step_size(axes, axis, selected_component) + leaf_ndofs[myindex] = step_size(axes, axis, selected_component, ()) # construct a new SF with these offsets ndofs = sum(leaf_ndofs) @@ -172,4 +178,4 @@ def grow_dof_sf(axes, axis, path, indices): remote_leaf_dof_offsets[counter] = [rank, root_offsets[pos] + d] counter += 1 - return (nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets) + return (new_nroots, local_leaf_dof_offsets, remote_leaf_dof_offsets) diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index 1784a13f..2a234922 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -22,15 +22,17 @@ import pytools from mpi4py import MPI from petsc4py import PETSc -from pyrsistent import freeze, pmap +from pyrsistent import freeze, pmap, thaw from pyop3 import utils from pyop3.dtypes import IntType, PointerType, get_mpi_dtype -from pyop3.sf import StarForest +from pyop3.extras.debug import print_with_rank +from pyop3.sf import StarForest, serial_forest from pyop3.tree import ( LabelledNodeComponent, LabelledTree, MultiComponentLabelledNode, + as_component_label, postvisit, previsit, ) @@ -38,15 +40,19 @@ PrettyTuple, as_tuple, checked_zip, + debug_assert, deprecated, flatten, frozen_record, has_unique_entries, + invert, is_single_valued, just_one, merge_dicts, + pairwise, single_valued, some_but_not_all, + steps, strict_int, strictly_all, unique, @@ -54,6 +60,11 @@ class Indexed(abc.ABC): + @property + @abc.abstractmethod + def axes(self): + pass + @property @abc.abstractmethod def target_paths(self): @@ -64,6 +75,67 @@ def target_paths(self): def index_exprs(self): pass + @property + @abc.abstractmethod + def outer_loops(self): + pass + + @property + @abc.abstractmethod + def layouts(self): + pass + + @cached_property + def subst_layouts(self): + return self._subst_layouts() + + def _subst_layouts(self, axis=None, path=None, target_path=None, index_exprs=None): + from pyop3 import HierarchicalArray + from pyop3.itree.tree import IndexExpressionReplacer + + # TODO Don't do this every time this function is called + loop_exprs = {} + # for outer_loop in self.outer_loops: + # loop_exprs[outer_loop.id] = {} + # for ax in outer_loop.iterset.nodes: + # key = (ax.id, ax.component.label) + # for ax_, expr in outer_loop.iterset.index_exprs.get(key, {}).items(): + # loop_exprs[outer_loop.id][ax_] = expr + + layouts = {} + if strictly_all(x is None for x in [axis, path, target_path, index_exprs]): + path = pmap() + target_path = self.target_paths.get(None, pmap()) + index_exprs = self.index_exprs.get(None, pmap()) + # target_path = pmap() + # index_exprs = pmap() + + replacer = IndexExpressionReplacer(index_exprs, loop_exprs=loop_exprs) + layouts[path] = replacer(self.layouts.get(target_path, 0)) + + if not self.axes.is_empty: + layouts.update( + self._subst_layouts(self.axes.root, path, target_path, index_exprs) + ) + else: + for component in axis.components: + path_ = path | {axis.label: component.label} + target_path_ = target_path | self.target_paths.get( + (axis.id, component.label), {} + ) + index_exprs_ = index_exprs | self.index_exprs.get( + (axis.id, component.label), {} + ) + + replacer = IndexExpressionReplacer(index_exprs_) + layouts[path_] = replacer(self.layouts.get(target_path_, 0)) + + if subaxis := self.axes.child(axis, component): + 
layouts.update( + self._subst_layouts(subaxis, path_, target_path_, index_exprs_) + ) + return freeze(layouts) + class ContextAware(abc.ABC): @abc.abstractmethod @@ -92,8 +164,10 @@ class ContextSensitive(ContextAware, abc.ABC): # # """ # - def __init__(self, context_map: pmap[pmap[LoopIndex, pmap[str, str]], ContextFree]): - self.context_map = pmap(context_map) + def __init__(self, context_map): + if isinstance(context_map, pyrsistent.PMap): + raise TypeError("context_map must be deterministically ordered") + self.context_map = context_map @cached_property def keys(self): @@ -111,8 +185,8 @@ def filter_context(self, context): key = {} for loop_index, path in context.items(): if loop_index in self.keys: - key.update({loop_index: path}) - return pmap(key) + key.update({loop_index: freeze(path)}) + return freeze(key) # this is basically just syntactic sugar, might not be needed @@ -159,37 +233,47 @@ class ContextSensitiveLoopIterable(LoopIterable, ContextSensitive, abc.ABC): pass -class ExpressionEvaluator(pym.mapper.evaluator.EvaluationMapper): - def map_axis_variable(self, expr): - return self.context[expr.axis_label] - - def map_multi_array(self, expr): - # path = _trim_path(array.axes, self.context[0]) - # not multi-component for now, is that useful to add? - path = expr.array.axes.path(*expr.array.axes.leaf) - # context = [] - # for keyval in self.context.items(): - # context.append(keyval) - # return expr.array.get_value(path, self.context[1]) - replace_map = {axis: self.rec(idx) for axis, idx in expr.indices.items()} - return expr.array.get_value(path, replace_map) - - def map_loop_index(self, expr): - return self.context[expr.name, expr.axis] +class UnrecognisedAxisException(ValueError): + pass - def map_called_map(self, expr): - array = expr.function.map_component.array - indices = {axis: self.rec(idx) for axis, idx in expr.parameters.items()} - path = array.axes.path(*array.axes.leaf) +class ExpressionEvaluator(pym.mapper.evaluator.EvaluationMapper): + def __init__(self, context, loop_exprs): + super().__init__(context) + self._loop_exprs = loop_exprs - # the inner_expr tells us the right mapping for the temporary, however, - # for maps that are arrays the innermost axis label does not always match - # the label used by the temporary. Therefore we need to do a swap here. 
-        inner_axis = array.axes.leaf_axis
-        indices[inner_axis.label] = indices.pop(expr.function.full_map.name)
+    def map_axis_variable(self, expr):
+        try:
+            return self.context[expr.axis_label]
+        except KeyError as e:
+            raise UnrecognisedAxisException from e
+
+    def map_array(self, array_var):
+        from pyop3.itree.tree import ExpressionEvaluator
+
+        array = array_var.array
+
+        # evaluate the index expressions in the current context, then use the
+        # substituted layout to compute a flat offset into the array data
+        indices = {ax: self.rec(idx) for ax, idx in array_var.indices.items()}
+        layout_subst = array.subst_layouts[array_var.path]
+
+        offset = ExpressionEvaluator(indices, self._loop_exprs)(layout_subst)
+        offset = strict_int(offset)
+        return array.data_ro[offset]

-        return array.get_value(path, indices)
+    def map_loop_index(self, expr):
+        return self._loop_exprs[expr.id][expr.axis]


 def _collect_datamap(axis, *subdatamaps, axes):
@@ -243,13 +327,14 @@ def __init__(
         indexed=False,
         lgmap=None,
     ):
+        from pyop3.array import HierarchicalArray
+
+        if not isinstance(count, (numbers.Integral, HierarchicalArray)):
+            raise TypeError("Invalid count type")
+
         super().__init__(label=label)
         self.count = count

-    @property
-    def has_integer_count(self):
-        return isinstance(self.count, numbers.Integral)
-
     # TODO this is just a traversal - clean up
     def alloc_size(self, axtree, axis):
         from pyop3.array import HierarchicalArray
@@ -295,8 +380,9 @@ def __init__(
         if sum(c.count for c in components) != numbering.size:
             raise ValueError

-        super().__init__(components, label=label, id=id)
+        super().__init__(label=label, id=id)

+        self.components = components
         self.numbering = numbering
         self.sf = sf

@@ -304,7 +390,11 @@ def __getitem__(self, indices):
         # NOTE: This *must* return an axis tree because that is where we attach
         # index expression information. Just returning as_axis_tree(self).root
         # here will break things.
+ # Actually this is not the case for "identity" slices since index_exprs + # and labels are unchanged + # TODO return a flat axis in these cases return as_axis_tree(self)[indices] + # if indexed.depth == 1: def __call__(self, *args): return as_axis_tree(self)(*args) @@ -330,6 +420,22 @@ def from_serial(cls, serial: Axis, sf): numbering = partition_ghost_points(serial, sf) return cls(serial.components, serial.label, numbering=numbering, sf=sf) + @property + def component_labels(self): + return tuple(c.label for c in self.components) + + @property + def component(self): + return just_one(self.components) + + def component_index(self, component) -> int: + clabel = as_component_label(component) + return self.component_labels.index(clabel) + + @property + def comm(self): + return self.sf.comm if self.sf else MPI.COMM_SELF + @property def size(self): return as_axis_tree(self).size @@ -366,14 +472,34 @@ def owned_count_per_component(self): def ghost_count_per_component(self): counts = np.zeros_like(self.components, dtype=int) for leaf_index in self.sf.ileaf: - counts[self._component_index_from_axis_number(leaf_index)] += 1 + counts[self._axis_number_to_component_index(leaf_index)] += 1 return freeze( {cpt: count for cpt, count in checked_zip(self.components, counts)} ) + @cached_property + def owned(self): + from pyop3.itree import AffineSliceComponent, Slice + + if self.comm.size == 1: + return self + + slices = [ + AffineSliceComponent( + c.label, + stop=self.owned_count_per_component[c], + ) + for c in self.components + ] + slice_ = Slice(self.label, slices) + return self[slice_].root + def index(self): return self._tree.index() + def iter(self): + return self._tree.iter() + @property def target_path_per_component(self): return self._tree.target_path_per_component @@ -411,57 +537,72 @@ def as_tree(self) -> AxisTree: """ return self._tree - # Note: these functions assume that the numbering follows the plex convention - # of numbering each strata contiguously. I think (?) that I effectively also do this. - # actually this might well be wrong. we have a renumbering after all - this gives us - # the original numbering only - def component_number_to_axis_number(self, component, num): - component_index = self.components.index(component) - canonical = self._component_numbering_offsets[component_index] + num - return self._to_renumbered(canonical) - - def axis_number_to_component(self, num): - # guess, is this the right map (from new numbering to original)? - # I don't think so because we have a funky point SF. can we get rid? 
- # num = self.numbering[num] - component_index = self._component_index_from_axis_number(num) - component_num = num - self._component_numbering_offsets[component_index] - # return self.components[component_index], component_num - return self.components[component_index], component_num - - def _component_index_from_axis_number(self, num): - offsets = self._component_numbering_offsets - for i, (min_, max_) in enumerate(zip(offsets, offsets[1:])): - if min_ <= num < max_: - return i - raise ValueError(f"Axis number {num} not found.") + # Ideally I want to cythonize a lot of these methods + def component_numbering(self, component): + cidx = self.component_index(component) + return self._default_to_applied_numbering[cidx] - @cached_property - def _component_numbering_offsets(self): - return (0,) + tuple(np.cumsum([c.count for c in self.components], dtype=int)) + def component_permutation(self, component): + cidx = self.component_index(component) + return self._default_to_applied_permutation[cidx] - # FIXME bad name - def _to_renumbered(self, num): - """Convert a flat/canonical/unpermuted axis number to its renumbered equivalent.""" - if self.numbering is None: - return num - else: - return self._inverse_numbering[num] + def default_to_applied_component_number(self, component, number): + cidx = self.component_index(component) + return self._default_to_applied_numbering[cidx][number] - @cached_property - def _inverse_numbering(self): - # put in utils.py - from pyop3.axtree.parallel import invert + def applied_to_default_component_number(self, component, number): + cidx = self.component_index(component) + return self._applied_to_default_numbering[cidx][number] - if self.numbering is None: - return np.arange(self.count, dtype=IntType) - else: - return invert(self.numbering.data_ro) + def axis_to_component_number(self, number): + # return axis_to_component_number(self, number) + cidx = self._axis_number_to_component_index(number) + return self.components[cidx], number - self._component_offsets[cidx] + + def component_to_axis_number(self, component, number): + cidx = self.component_index(component) + return self._component_offsets[cidx] + number + + def renumber_point(self, component, point): + renumbering = self.component_numbering(component) + return renumbering[point] @cached_property def _tree(self): return AxisTree(self) + @cached_property + def _component_offsets(self): + return (0,) + tuple(np.cumsum([c.count for c in self.components], dtype=int)) + + @cached_property + def _default_to_applied_numbering(self): + renumbering = [np.empty(c.count, dtype=IntType) for c in self.components] + counters = [itertools.count() for _ in range(self.degree)] + for pt in self.numbering.data_ro: + cidx = self._axis_number_to_component_index(pt) + old_cpt_pt = pt - self._component_offsets[cidx] + renumbering[cidx][old_cpt_pt] = next(counters[cidx]) + assert all(next(counters[i]) == c.count for i, c in enumerate(self.components)) + return tuple(renumbering) + + @cached_property + def _default_to_applied_permutation(self): + # is this right? + return self._applied_to_default_numbering + + # same as the permutation... 
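+    # e.g. for numbering p = [2, 0, 1] we get invert(p) = [1, 2, 0], since
+    # position 2 holds value 0, position 0 holds value 1 and position 1 holds
+    # value 2; in general invert(p)[p[i]] == i for every i.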
+ @cached_property + def _applied_to_default_numbering(self): + return tuple(invert(num) for num in self._default_to_applied_numbering) + + def _axis_number_to_component_index(self, number): + off = self._component_offsets + for i, (min_, max_) in enumerate(zip(off, off[1:])): + if min_ <= number < max_: + return i + raise ValueError(f"{number} not found") + @staticmethod def _parse_components(components): if isinstance(components, collections.abc.Mapping): @@ -489,19 +630,38 @@ def _parse_numbering(numbering): ) -class MultiArrayCollector(pym.mapper.Collector): - def map_called_map(self, expr): - return self.rec(expr.function) | set.union( - *(self.rec(idx) for idx in expr.parameters.values()) - ) +# Do I ever want this? component_offsets is expensive so we don't want to +# do it every time +def axis_to_component_number(axis, number, context=pmap()): + offsets = component_offsets(axis, context) + return component_number_from_offsets(axis, number, offsets) + + +# TODO move into layout.py +def component_number_from_offsets(axis, number, offsets): + cidx = None + for i, (min_, max_) in enumerate(pairwise(offsets)): + if min_ <= number < max_: + cidx = i + break + assert cidx is not None + return axis.components[cidx], number - offsets[cidx] + + +# TODO move into layout.py +def component_offsets(axis, context): + from pyop3.axtree.layout import _as_int - def map_map_variable(self, expr): - return {expr.map_component.array} + return steps([_as_int(c.count, context) for c in axis.components]) - def map_multi_array(self, expr): - return {expr} - def map_nan(self, expr): +class MultiArrayCollector(pym.mapper.Collector): + def map_array(self, array_var): + return {array_var.array}.union( + *(self.rec(expr) for expr in array_var.indices.values()) + ) + + def map_nan(self, nan): return set() @@ -538,9 +698,41 @@ def __init__( ): super().__init__(parent_to_children) + # TODO Move check to generic LabelledTree + self._check_node_labels_unique_in_paths(self.parent_to_children) + # makea cached property, then delete this method self._layout_exprs = AxisTree._default_index_exprs(self) + @classmethod + def from_iterable(cls, iterable): + # NOTE: This currently only works for linear trees + item, *iterable = iterable + tree = PartialAxisTree(as_axis_tree(item).parent_to_children) + for item in iterable: + tree = tree.add_subtree(as_axis_tree(item), *tree.leaf) + return tree + + @classmethod + def _check_node_labels_unique_in_paths( + cls, node_map, node=None, seen_labels=frozenset() + ): + from pyop3.tree import InvalidTreeException + + if not node_map: + return + + if node is None: + node = just_one(node_map[None]) + + if node.label in seen_labels: + raise InvalidTreeException("Duplicate labels found along a path") + + for subnode in filter(None, node_map.get(node.id, [])): + cls._check_node_labels_unique_in_paths( + node_map, subnode, seen_labels | {node.label} + ) + def set_up(self): return AxisTree.from_partial_tree(self) @@ -582,7 +774,9 @@ def leaf_axis(self): @property def leaf_component(self): - return self.leaf[1] + leaf_axis, leaf_clabel = self.leaf + leaf_cidx = leaf_axis.component_index(leaf_clabel) + return leaf_axis.components[leaf_cidx] @cached_property def size(self): @@ -590,16 +784,62 @@ def size(self): return axis_tree_size(self) + @cached_property + def global_size(self): + from pyop3.array import HierarchicalArray + from pyop3.axtree.layout import _axis_size, my_product + + if not self.outer_loops: + return self.size + + mysize = 0 + for idxs in my_product(self.outer_loops): + 
loop_exprs = {idx.index.id: idx.source_exprs for idx in idxs}
+            # this is a hack
+            if self.is_empty:
+                mysize += 1
+            else:
+                mysize += _axis_size(self, self.root, loop_indices=loop_exprs)
+        return mysize
+
+    # rename to local_size?
     def alloc_size(self, axis=None):
         axis = axis or self.root
         return sum(cpt.alloc_size(self, axis) for cpt in axis.components)


+class LoopIndexReplacer(pym.mapper.IdentityMapper):
+    def __init__(self, replace_map):
+        super().__init__()
+        self._replace_map = replace_map
+
+    def map_axis_variable(self, var):
+        try:
+            return self._replace_map[var.axis]
+        except KeyError:
+            return var
+
+    def map_array(self, array_var):
+        indices = {ax: self(expr) for ax, expr in array_var.indices.items()}
+        return type(array_var)(array_var.array, indices, array_var.path)
+
+
 @frozen_record
 class AxisTree(PartialAxisTree, Indexed, ContextFreeLoopIterable):
     fields = PartialAxisTree.fields | {
         "target_paths",
         "index_exprs",
+        "outer_loops",
         "layout_exprs",
     }

@@ -608,32 +848,60 @@ def __init__(
         parent_to_children=pmap(),
         target_paths=None,
         index_exprs=None,
+        outer_loops=None,
         layout_exprs=None,
     ):
-        if some_but_not_all(
-            arg is None for arg in [target_paths, index_exprs, layout_exprs]
-        ):
-            raise ValueError
+        if outer_loops is None:
+            outer_loops = ()
+        else:
+            assert isinstance(outer_loops, tuple)

         super().__init__(parent_to_children)
         self._target_paths = target_paths or self._default_target_paths()
         self._index_exprs = index_exprs or self._default_index_exprs()
         self.layout_exprs = layout_exprs or self._default_layout_exprs()
+        self._outer_loops = tuple(outer_loops)

     def __getitem__(self, indices):
-        from pyop3.itree.tree import as_index_forest, collect_loop_contexts, index_axes
+        from pyop3.itree.tree import _compose_bits, _index_axes, as_index_forest

         if indices is Ellipsis:
+            raise NotImplementedError("TODO")
             indices = index_tree_from_ellipsis(self)

-        if not collect_loop_contexts(indices):
-            index_tree = just_one(as_index_forest(indices, axes=self))
-            return index_axes(self, index_tree)
-
         axis_trees = {}
-        for index_tree in as_index_forest(indices, axes=self):
-            axis_trees[index_tree.loop_context] = index_axes(self, index_tree)
-        return ContextSensitiveAxisTree(axis_trees)
+        for context, index_tree in as_index_forest(indices, axes=self).items():
+            indexed_axes = _index_axes(index_tree, context, self)
+
+            target_paths, index_exprs, layout_exprs = _compose_bits(
+                self,
+                self.target_paths,
+                self.index_exprs,
+                self.layout_exprs,
+                indexed_axes,
+                indexed_axes.target_paths,
+                indexed_axes.index_exprs,
+                indexed_axes.layout_exprs,
+            )
+            axis_tree = AxisTree(
+                indexed_axes.parent_to_children,
+                target_paths,
+                index_exprs,
+                outer_loops=indexed_axes.outer_loops,
+                layout_exprs=layout_exprs,
+            )
+            axis_trees[context] = axis_tree
+
+        if len(axis_trees) == 1 and just_one(axis_trees.keys()) == pmap():
+            return axis_trees[pmap()]
+        else:
+            return ContextSensitiveAxisTree(axis_trees)

     @classmethod
     def from_nest(cls, nest) -> AxisTree:
node_map.update({None: [root]}) return cls.from_node_map(node_map) + @classmethod + def from_iterable( + cls, iterable, *, target_paths=None, index_exprs=None, layout_exprs=None + ) -> AxisTree: + tree = PartialAxisTree.from_iterable(iterable) + return AxisTree( + tree.parent_to_children, + target_paths=target_paths, + index_exprs=index_exprs, + layout_exprs=layout_exprs, + ) + @classmethod def from_node_map(cls, node_map): tree = PartialAxisTree(node_map) @@ -651,17 +931,48 @@ def from_partial_tree(cls, tree: PartialAxisTree) -> AxisTree: target_paths = cls._default_target_paths(tree) index_exprs = cls._default_index_exprs(tree) layout_exprs = index_exprs + outer_loops = () return cls( tree.parent_to_children, target_paths, index_exprs, - layout_exprs, + outer_loops=outer_loops, + layout_exprs=layout_exprs, ) - def index(self): - from pyop3.itree import LoopIndex + def index(self, ghost=False): + from pyop3.itree.tree import ContextFreeLoopIndex, LoopIndex + + iterset = self if ghost else self.owned + # If the iterset is linear (single-component for every axis) then we + # can consider the loop to be "context-free". + if len(iterset.leaves) == 1: + path = iterset.path(*iterset.leaf) + target_path = {} + for ax, cpt in iterset.path_with_nodes(*iterset.leaf).items(): + target_path.update(iterset.target_paths.get((ax.id, cpt), {})) + return ContextFreeLoopIndex(iterset, path, target_path) + else: + return LoopIndex(iterset) + + def iter(self, outer_loops=(), loop_index=None, include=False, ghost=False): + from pyop3.itree.tree import iter_axis_tree + + iterset = self if ghost else self.owned + + return iter_axis_tree( + # hack because sometimes we know the right loop index to use + loop_index or self.index(), + iterset, + self.target_paths, + self.index_exprs, + outer_loops, + include, + ) - return LoopIndex(self.owned) + @property + def axes(self): + return self @property def target_paths(self): @@ -671,39 +982,208 @@ def target_paths(self): def index_exprs(self): return self._index_exprs + @property + def outer_loops(self): + return self._outer_loops + + # This could easily be two functions + @cached_property + def outer_loop_bits(self): + from pyop3.itree.tree import LocalLoopIndexVariable + + if len(self.outer_loops) > 1: + # We do not yet support something like dat[p, q] if p and q + # are independent (i.e. q != f(p) ). + raise NotImplementedError( + "Multiple independent outer loops are not supported." + ) + loop = just_one(self.outer_loops) + + # TODO: Don't think this is needed + # Since loop itersets must be linear, we can unpack target_paths + # and index_exprs from + # + # {(axis_id, component_label): {axis_label: expr}} + # + # to simply + # + # {axis_label: expr} + flat_target_paths = {} + flat_index_exprs = {} + for axis in loop.iterset.nodes: + key = (axis.id, axis.component.label) + flat_target_paths.update(loop.iterset.target_paths.get(key, {})) + flat_index_exprs.update(loop.iterset.index_exprs.get(key, {})) + + # Make sure that the layout axes are uniquely labelled. + suffix = f"_{loop.id}" + loop_axes = relabel_axes(loop.iterset, suffix) + + # Nasty hack: loop_axes need to be a PartialAxisTree so we can add to it. + loop_axes = PartialAxisTree(loop_axes.parent_to_children) + + # When we tabulate the layout, the layout expressions will contain + # axis variables that we actually want to be loop index variables. Here + # we construct the right replacement map. 
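+        # For example (hypothetical labels): if the outer loop has id "loop0"
+        # and iterates over an axis labelled "ax0", then tabulation sees a
+        # relabelled axis "ax0_loop0" and loop_vars records
+        #     {"ax0_loop0": LocalLoopIndexVariable(loop, "ax0")}
+        # so the tabulated layout can later be evaluated per loop iteration.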
+        loop_vars = {
+            axis.label + suffix: LocalLoopIndexVariable(loop, axis.label)
+            for axis in loop.iterset.nodes
+        }
+
+        # Recursively fetch other outer loops and make them the root of
+        # the current axes.
+        if loop.iterset.outer_loops:
+            ax_rec, lv_rec = loop.iterset.outer_loop_bits
+            loop_axes = ax_rec.add_subtree(loop_axes, *ax_rec.leaf)
+            loop_vars.update(lv_rec)
+
+        return loop_axes, freeze(loop_vars)
+
+    @cached_property
+    def layout_axes(self):
+        if not self.outer_loops:
+            return self
+        loop_axes, _ = self.outer_loop_bits
+        return loop_axes.add_subtree(self, *loop_axes.leaf).set_up()
+
     @cached_property
     def layouts(self):
         """Initialise the multi-axis by computing the layout functions."""
         from pyop3.axtree.layout import _collect_at_leaves, _compute_layouts
         from pyop3.itree.tree import IndexExpressionReplacer

-        if self.is_empty:
-            return pmap({pmap(): 0})
+        if self.layout_axes.is_empty:
+            return freeze({pmap(): 0})
+
+        loop_vars = self.outer_loop_bits[1] if self.outer_loops else {}
+        layouts, check_none, _ = _compute_layouts(self.layout_axes, loop_vars)
+        assert check_none is None

-        layouts, _, _, _ = _compute_layouts(self, self.root)
-        layoutsnew = _collect_at_leaves(self, layouts)
+        layoutsnew = _collect_at_leaves(self, self.layout_axes, layouts)
         layouts = freeze(dict(layoutsnew))

-        layouts_ = {}
-        for leaf in self.leaves:
-            orig_path = self.path(*leaf)
-            new_path = {}
-            replace_map = {}
-            for axis, cpt in self.path_with_nodes(*leaf).items():
-                new_path.update(self.target_paths[axis.id, cpt])
-                replace_map.update(self.layout_exprs[axis.id, cpt])
-            new_path = freeze(new_path)
-
-            orig_layout = layouts[orig_path]
-            new_layout = IndexExpressionReplacer(replace_map)(orig_layout)
-            # assert new_layout != orig_layout
-            layouts_[new_path] = new_layout
-
-        return freeze(layouts_)
+        # The tabulated layouts are written in terms of axis variables; swap
+        # in loop index variables for the axes that came from outer loops.
+        if self.outer_loops:
+            layouts_ = {}
+            for k, layout in layouts.items():
+                layouts_[k] = IndexExpressionReplacer(loop_vars)(layout)
+            layouts = freeze(layouts_)
+
+        return freeze(layouts)

+    @cached_property
+    def leaf_target_paths(self):
+        return tuple(
+            merge_dicts(
+                self.target_paths.get((ax.id, clabel), {})
+                for ax, clabel in self.path_with_nodes(*leaf, ordered=True)
+            )
+            for leaf in self.leaves
+        )
+
     @cached_property
     def sf(self):
         return self._default_sf()

+    @property
+    def comm(self):
+        paraxes = [axis for axis in self.nodes if axis.sf is not None]
+        if not paraxes:
+            return MPI.COMM_SELF
+        else:
+            return single_valued(ax.comm for ax in paraxes)
+
     @cached_property
     def datamap(self):
         if self.is_empty:
@@ -716,9 +1196,6 @@ def datamap(self):
             for expr in exprs.values():
                 for array in MultiArrayCollector()(expr):
                     dmap.update(array.datamap)
-        for layout_expr in self.layouts.values():
-            for array in MultiArrayCollector()(layout_expr):
-                dmap.update(array.datamap)
         return pmap(dmap)

     @cached_property
@@ -726,6 +1203,9 @@ def owned(self):
         """Return the owned portion of the axis tree."""
         from pyop3.itree import AffineSliceComponent, Slice

+        if self.comm.size == 1:
+            return self
+
         paraxes = [axis for axis in self.nodes if axis.sf is not None]
         if len(paraxes) == 0:
             return self
@@ -737,46 +1217,40 @@ def owned(self):
             AffineSliceComponent(
                 c.label,
                 stop=paraxis.owned_count_per_component[c],
+                # this feels like a hack, generally don't want this ambiguity
+                label=c.label,
             )
             for c in paraxis.components
         ]
-        slice_ = Slice(paraxis.label, slices)
+        # this feels like a hack, generally don't want this ambiguity
+        slice_ = Slice(paraxis.label, slices, label=paraxis.label)
         return self[slice_]

     def freeze(self):
         return self

-    # needed here? or just for the HierarchicalArray? perhaps a free function?
-    def offset(self, *args, allow_unused=False, insert_zeros=False):
-        nargs = len(args)
-        if nargs == 2:
-            path, indices = args[0], args[1]
-        else:
-            assert nargs == 1
-            path, indices = _path_and_indices_from_index_tuple(self, args[0])
-
-        if allow_unused:
-            path = _trim_path(self, path)
-
-        if insert_zeros:
-            # extend the path by choosing the zero offset option every time
-            # this is needed if we don't have all the internal bits available
-            while path not in self.layouts:
-                axis, clabel = self._node_from_path(path)
-                subaxis = self.component_child(axis, clabel)
-                # choose the component that is first in the renumbering
-                if subaxis.numbering:
-                    cidx = subaxis._component_index_from_axis_number(
-                        subaxis.numbering.data_ro[0]
-                    )
-                else:
-                    cidx = 0
-                subcpt = subaxis.components[cidx]
-                path |= {subaxis.label: subcpt.label}
-                indices |= {subaxis.label: 0}
+    def as_tree(self):
+        return self

-        offset = pym.evaluate(self.layouts[path], indices, ExpressionEvaluator)
-        return strict_int(offset)
+    def offset(self, indices, path=None, *, loop_exprs=pmap()):
+        from pyop3.axtree.layout import eval_offset
+
+        return eval_offset(
+            self,
+            self.subst_layouts,
+            indices,
+            path,
+            loop_exprs=loop_exprs,
+        )

     @cached_property
     def owned_size(self):
@@ -824,11 +1298,12 @@ def _default_sf(self):
         from pyop3.axtree.parallel import collect_sf_graphs

         if self.is_empty:
-            return None
+            # no, this is probably not right. Could have a global
+            return serial_forest(self.global_size)

         graphs = collect_sf_graphs(self)
         if len(graphs) == 0:
-            return None
+            return serial_forest(self.global_size)
         else:
             # merge the graphs
             nroots = 0
@@ -841,8 +1316,35 @@ def _default_sf(self):
                 iremotes.append(iremote)
             ilocal = np.concatenate(ilocals)
             iremote = np.concatenate(iremotes)
-            # fixme, get the right comm (and ensure consistency)
-            return StarForest.from_graph(self.size, nroots, ilocal, iremote)
+            return StarForest.from_graph(self.size, nroots, ilocal, iremote, self.comm)
+
+    # should be a cached property?
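+    # Example: with two ranks owning 3 and 2 points respectively, exscan gives
+    # starts 0 and 3, the owned points get numbers [0, 1, 2] and [3, 4], and
+    # the broadcast then fills every ghost point with its owner's number.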
+    def global_numbering(self):
+        if self.comm.size == 1:
+            return np.arange(self.size, dtype=IntType)
+
+        numbering = np.full(self.size, -1, dtype=IntType)
+
+        start = self.sf.comm.tompi4py().exscan(self.owned.size, MPI.SUM)
+        if start is None:
+            start = 0
+
+        # TODO do I need to account for numbering/layouts? The SF should
+        # probably manage this.
+        numbering[: self.owned.size] = np.arange(
+            start, start + self.owned.size, dtype=IntType
+        )
+
+        self.sf.broadcast(numbering, MPI.REPLACE)
+
+        debug_assert(lambda: (numbering >= 0).all())
+        return numbering


 class ContextSensitiveAxisTree(ContextSensitiveLoopIterable):
@@ -943,8 +1445,8 @@ def _(arg: tuple) -> AxisComponent:


 @functools.singledispatch
-def _as_axis_component_label(arg: Any) -> ComponentLabel:
-    if isinstance(arg, ComponentLabel):
+def _as_axis_component_label(arg: Any):
+    if isinstance(arg, str):
         return arg
     else:
         raise TypeError(f"No handler registered for {type(arg).__name__}")
@@ -955,50 +1457,16 @@ def _(component: AxisComponent):
     return component.label


-def _path_and_indices_from_index_tuple(axes, index_tuple):
-    from pyop3.axtree.layout import _as_int
-
-    path = pmap()
-    indices = pmap()
-    axis = axes.root
-    for index in index_tuple:
-        if axis is None:
-            raise IndexError("Too many indices provided")
-        if isinstance(index, numbers.Integral):
-            if axis.degree > 1:
-                raise IndexError(
-                    "Cannot index multi-component array with integers, a "
-                    "2-tuple of (component index, index value) is needed"
-                )
-            cpt_label = axis.components[0].label
-        else:
-            cpt_label, index = index
-
-        cpt_index = axis.component_labels.index(cpt_label)
-
-        if index < 0:
-            # In theory we could still get this to work...
-            raise IndexError("Cannot use negative indices")
-        # TODO need to pass indices here for ragged things
-        if index >= _as_int(axis.components[cpt_index].count, path, indices):
-            raise IndexError("Index is too large")
-
-        indices |= {axis.label: index}
-        path |= {axis.label: cpt_label}
-        axis = axes.component_child(axis, cpt_label)
-
-    if axis is not None:
-        raise IndexError("Insufficient number of indices given")
-
-    return path, indices
-
-
-def _trim_path(axes: AxisTree, path) -> pmap:
-    """Drop unused axes from the axis path."""
-    new_path = {}
-    axis = axes.root
-    while axis:
-        cpt_label = path[axis.label]
-        new_path[axis.label] = cpt_label
-        axis = axes.component_child(axis, cpt_label)
-    return pmap(new_path)
+def relabel_axes(axes: AxisTree, suffix: str) -> AxisTree:
+    # comprehension?
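+    # (it could be, e.g.
+    #     {pid: [a.copy(label=a.label + suffix) if a is not None else None
+    #            for a in kids]
+    #      for pid, kids in axes.parent_to_children.items()}
+    # but the explicit loop reads more clearly)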
+ parent_to_children = {} + for parent_id, children in axes.parent_to_children.items(): + children_ = [] + for axis in children: + if axis is not None: + axis_ = axis.copy(label=axis.label + suffix) + else: + axis_ = None + children_.append(axis_) + parent_to_children[parent_id] = children_ + return AxisTree(parent_to_children) diff --git a/pyop3/buffer.py b/pyop3/buffer.py index 2bc741c4..870bed28 100644 --- a/pyop3/buffer.py +++ b/pyop3/buffer.py @@ -1,14 +1,18 @@ from __future__ import annotations import abc +import contextlib import numbers from functools import cached_property import numpy as np from mpi4py import MPI +from petsc4py import PETSc +from pyrsistent import freeze, pmap from pyop3.dtypes import ScalarType -from pyop3.lang import KernelArgument +from pyop3.lang import READ, RW, WRITE, KernelArgument +from pyop3.sf import StarForest from pyop3.utils import UniqueNameGenerator, as_tuple, deprecated, readonly @@ -46,8 +50,40 @@ class Buffer(KernelArgument, abc.ABC): def dtype(self): pass + @property + @abc.abstractmethod + def datamap(self): + pass + + @property + def kernel_dtype(self): + return self.dtype + + +class NullBuffer(Buffer): + """A buffer that does not carry data. + + This is useful for handling temporaries when we generate code. For much + of the compilation we want to treat temporaries like ordinary arrays but + they are not passed as kernel arguments nor do they have any parallel + semantics. + + """ + + def __init__(self, dtype=None): + if dtype is None: + dtype = self.DEFAULT_DTYPE + self._dtype = dtype + + @property + def dtype(self): + return self._dtype + + @property + def datamap(self): + return pmap() + -# TODO should AbstractBuffer be a class and then a serial buffer can be its own class? class DistributedBuffer(Buffer): """An array distributed across multiple processors with ghost values.""" @@ -61,9 +97,20 @@ class DistributedBuffer(Buffer): _name_generator = UniqueNameGenerator() def __init__( - self, shape, dtype=None, *, name=None, prefix=None, data=None, sf=None + self, + shape, + sf_or_comm, + dtype=None, + *, + name=None, + prefix=None, + data=None, ): shape = as_tuple(shape) + + if not all(isinstance(s, numbers.Integral) for s in shape): + raise TypeError + if dtype is None: dtype = self.DEFAULT_DTYPE @@ -76,14 +123,24 @@ def __init__( if data.dtype != dtype: raise ValueError - if sf and shape[0] != sf.size: - raise IncompatibleStarForestException + if isinstance(sf_or_comm, StarForest): + sf = sf_or_comm + comm = sf.comm + # TODO I don't really like having shape as an argument... + if sf and shape[0] != sf.size: + raise IncompatibleStarForestException + else: + sf = None + comm = sf_or_comm self.shape = shape self._dtype = dtype self._lazy_data = data self.sf = sf + assert comm is not None + self.comm = comm + self.name = name or self._name_generator(prefix or self._prefix) # counter used to keep track of modifications @@ -94,6 +151,8 @@ def __init__( self._pending_reduction = None self._finalizer = None + self._lazy_vec = None + # @classmethod # def from_array(cls, array: np.ndarray, **kwargs): # return cls(array.shape, array.dtype, data=array, **kwargs) @@ -146,7 +205,47 @@ def data_wo(self): @property def is_distributed(self) -> bool: - return self.sf is not None + return self.comm.size > 1 + + @property + def leaves_valid(self) -> bool: + return self._leaves_valid + + @property + def datamap(self): + return freeze({self.name: self}) + + @contextlib.contextmanager + def vec_context(self, intent): + """Wrap the buffer in a PETSc Vec. 
+ + TODO implement intent parameter + + """ + yield self._vec + # if access is not Access.READ: + # self.halo_valid = False + + @property + @deprecated(".vec_rw") + def vec(self): + return self.vec_rw + + @property + def vec_rw(self): + # TODO I don't think that intent is the right thing here. We really only have + # READ, WRITE or RW + return self.vec_context(RW) + + @property + def vec_ro(self): + # TODO I don't think that intent is the right thing here. We really only have + # READ, WRITE or RW + return self.vec_context(READ) + + @property + def vec_wo(self): + return self.vec_context(WRITE) @property def _data(self): @@ -156,7 +255,7 @@ def _data(self): @property def _owned_data(self): - if self.is_distributed: + if self.is_distributed and self.sf.nleaves > 0: return self._data[: -self.sf.nleaves] else: return self._data @@ -172,9 +271,10 @@ def _transfer_in_flight(self) -> bool: @cached_property def _reduction_ops(self): # TODO Move this import out, requires moving location of these intents - from pyop3.lang import INC + from pyop3.lang import INC, WRITE return { + WRITE: MPI.REPLACE, INC: MPI.SUM, } @@ -239,6 +339,19 @@ def _reduce_then_broadcast(self): self._reduce_leaves_to_roots() self._broadcast_roots_to_leaves() + @property + def _vec(self): + if self.dtype != PETSc.ScalarType: + raise RuntimeError( + f"Cannot create a Vec with data type {self.dtype}, " + "must be {PETSc.ScalarType}" + ) + + if self._lazy_vec is None: + vec = PETSc.Vec().createWithArray(self._owned_data, comm=self.comm) + self._lazy_vec = vec + return self._lazy_vec + class PackedBuffer(Buffer): """Abstract buffer originating from a function call. @@ -248,9 +361,6 @@ class PackedBuffer(Buffer): """ - # TODO Haven't exactly decided on the right API here, subclasses? - # def __init__(self, pack_fn, unpack_fn, dtype): - # self._dtype = dtype def __init__(self, array): self.array = array @@ -258,3 +368,7 @@ def __init__(self, array): @property def dtype(self): return self.array.dtype + + @property + def is_distributed(self) -> bool: + return False diff --git a/pyop3/cache.py b/pyop3/cache.py index 49711daa..4f8d7232 100644 --- a/pyop3/cache.py +++ b/pyop3/cache.py @@ -1,38 +1,3 @@ -# This file is part of PyOP2 -# -# PyOP2 is Copyright (c) 2012, Imperial College London and -# others. Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Provides common base classes for cached objects.""" - import hashlib import os import pickle diff --git a/pyop3/config.py b/pyop3/config.py index 33053a46..1108a896 100644 --- a/pyop3/config.py +++ b/pyop3/config.py @@ -64,6 +64,7 @@ class Configuration(dict): "print_cache_size": ("PYOP3_PRINT_CACHE_SIZE", bool, False), "matnest": ("PYOP3_MATNEST", bool, True), "block_sparsity": ("PYOP3_BLOCK_SPARSITY", bool, True), + "max_static_array_size": ("PYOP3_MAX_STATIC_ARRAY_SIZE", int, 100), } """Default values for PyOP2 configuration parameters""" diff --git a/pyop3/ir/lower.py b/pyop3/ir/lower.py index a388dda9..db11c3c7 100644 --- a/pyop3/ir/lower.py +++ b/pyop3/ir/lower.py @@ -1,34 +1,27 @@ from __future__ import annotations import abc -import collections import contextlib -import copy -import dataclasses import enum import functools -import itertools import numbers -import operator import textwrap +from functools import cached_property from typing import Any, Dict, FrozenSet, Optional, Sequence, Tuple, Union import loopy as lp import loopy.symbolic import numpy as np import pymbolic as pym -import pytools -from petsc4py import PETSc from pyrsistent import freeze, pmap -from pyop3 import utils -from pyop3.array import HierarchicalArray, PackedPetscMatAIJ, PetscMatAIJ -from pyop3.array.harray import ContextSensitiveMultiArray -from pyop3.array.petsc import PetscMat, PetscObject -from pyop3.axtree import Axis, AxisComponent, AxisTree, AxisVariable -from pyop3.axtree.tree import ContextSensitiveAxisTree -from pyop3.buffer import DistributedBuffer, PackedBuffer -from pyop3.dtypes import IntType, PointerType +from pyop3.array import HierarchicalArray +from pyop3.array.harray import CalledMapVariable, ContextSensitiveMultiArray +from pyop3.array.petsc import PetscMat +from pyop3.axtree import Axis, AxisComponent, AxisTree, AxisVariable, ContextFree +from pyop3.buffer import DistributedBuffer, NullBuffer, PackedBuffer +from pyop3.config import config +from pyop3.dtypes import IntType from pyop3.itree import ( AffineSliceComponent, CalledMap, @@ -37,15 +30,16 @@ LocalLoopIndex, LoopIndex, Map, - MapVariable, Slice, Subset, TabulatedMapComponent, ) from pyop3.itree.tree import ( - CalledMapVariable, + ContextFreeLoopIndex, IndexExpressionReplacer, + LocalLoopIndexVariable, LoopIndexVariable, + collect_shape_index_callback, ) from pyop3.lang import ( INC, @@ -53,17 +47,27 @@ MAX_WRITE, MIN_RW, MIN_WRITE, + NA, READ, RW, WRITE, + AddAssignment, Assignment, CalledFunction, + ContextAwareLoop, + DummyKernelArgument, Loop, + PetscMatAdd, + PetscMatInstruction, + PetscMatLoad, + PetscMatStore, + ReplaceAssignment, ) from pyop3.log import logger -from pyop3.tensor import Dat, Tensor from pyop3.utils import ( PrettyTuple, + UniqueNameGenerator, + as_tuple, checked_zip, just_one, merge_dicts, @@ -89,6 +93,18 @@ class AssignmentType(enum.Enum): ZERO = enum.auto() +class Renamer(pym.mapper.IdentityMapper): + def __init__(self, replace_map): + super().__init__() + 
self._replace_map = replace_map + + def map_variable(self, var): + try: + return pym.var(self._replace_map[var.name]) + except KeyError: + return var + + class CodegenContext(abc.ABC): pass @@ -100,10 +116,17 @@ def __init__(self): self._args = [] self._subkernels = [] + self.actual_to_kernel_rename_map = {} + self._within_inames = frozenset() self._last_insn_id = None - self._name_generator = pytools.UniqueNameGenerator() + self._name_generator = UniqueNameGenerator() + + # TODO remove + self._dummy_names = {} + + self._seen_arrays = set() @property def domains(self): @@ -115,13 +138,19 @@ def instructions(self): @property def arguments(self): - # TODO should renumber things here return tuple(self._args) @property def subkernels(self): return tuple(self._subkernels) + @property + def kernel_to_actual_rename_map(self): + return { + kernel: actual + for actual, kernel in self.actual_to_kernel_rename_map.items() + } + def add_domain(self, iname, *args): nargs = len(args) if nargs == 1: @@ -133,6 +162,16 @@ def add_domain(self, iname, *args): self._domains.append(domain_str) def add_assignment(self, assignee, expression, prefix="insn"): + # TODO recover this functionality, in other words we should produce + # non-renamed expressions. This means that the Renamer can also register + # arguments so we only use the ones we actually need! + + # renamer = Renamer(self.actual_to_kernel_rename_map) + # assignee = renamer(assignee) + # expression = renamer(expression) + + # breakpoint() + insn = lp.Assignment( assignee, expression, @@ -166,22 +205,46 @@ def add_function_call(self, assignees, expression, prefix="insn"): ) self._add_instruction(insn) + # TODO wrap into add_argument + def add_dummy_argument(self, arg, dtype): + if arg in self._dummy_names: + name = self._dummy_names[arg] + else: + name = self._dummy_names.setdefault(arg, self._name_generator("dummy")) + self._args.append(lp.ValueArg(name, dtype=dtype)) + + # deprecated def add_argument(self, array): - # FIXME if self._args is a set then we can add duplicates here provided - # that we canonically renumber at a later point - if array.name in [a.name for a in self._args]: - logger.debug( - f"Skipping adding {array.name} to the codegen context as it is already present" - ) + return self.add_array(array) + + # TODO we pass a lot more data here than we need I think, need to use unique *buffers* + def add_array(self, array: HierarchicalArray) -> None: + if array.name in self._seen_arrays: return + self._seen_arrays.add(array.name) + + debug = bool(config["debug"]) - if isinstance(array.buffer, PackedPetscMatAIJ): - arg = lp.ValueArg(array.name, dtype=self._dtype(array)) + injected = array.constant and array.size < config["max_static_array_size"] + if isinstance(array.buffer, NullBuffer) or injected: + name = self.unique_name("t") if not debug else array.name + shape = self._temporary_shapes.get(array.name, (array.alloc_size,)) + initializer = array.buffer.data_ro if injected else None + arg = lp.TemporaryVariable( + name, dtype=array.dtype, shape=shape, initializer=initializer + ) + elif isinstance(array.buffer, PackedBuffer): + name = self.unique_name("packed") if not debug else array.name + arg = lp.ValueArg(name, dtype=self._dtype(array)) else: + name = self.unique_name("array") if not debug else array.name assert isinstance(array.buffer, DistributedBuffer) - arg = lp.GlobalArg(array.name, dtype=self._dtype(array), shape=None) + arg = lp.GlobalArg(name, dtype=self._dtype(array), shape=None) + + self.actual_to_kernel_rename_map[array.name] = 
name self._args.append(arg) + # can this now go? no, not all things are arrays def add_temporary(self, name, dtype=IntType, shape=()): temp = lp.TemporaryVariable(name, dtype=dtype, shape=shape) self._args.append(temp) @@ -191,9 +254,6 @@ def add_subkernel(self, subkernel): # I am not sure that this belongs here, I generate names separately from adding domains etc def unique_name(self, prefix): - # add prefix to the generator so names are generated starting with - # "prefix_0" instead of "prefix" - self._name_generator.add_name(prefix, conflicting_ok=True) return self._name_generator(prefix) @contextlib.contextmanager @@ -235,8 +295,8 @@ def _(self, array): return array.dtype @_dtype.register - def _(self, array: PackedPetscMatAIJ): - return OpaqueType("Mat") + def _(self, array: PackedBuffer): + return self._dtype(array.array) @_dtype.register def _(self, array: PetscMat): @@ -246,21 +306,31 @@ def _add_instruction(self, insn): self._insns.append(insn) self._last_insn_id = insn.id + # FIXME, bad API + def set_temporary_shapes(self, shapes): + self._temporary_shapes = shapes + class CodegenResult: - # TODO also accept a map from input arrays to the renumbered ones, helpful for replacement - def __init__(self, expr, ir): - self.expr = expr + def __init__(self, expr, ir, arg_replace_map): + self.expr = as_tuple(expr) self.ir = ir + self.arg_replace_map = arg_replace_map + + @cached_property + def datamap(self): + return merge_dicts(e.datamap for e in self.expr) def __call__(self, **kwargs): from pyop3.target import compile_loopy - args = [ - _as_pointer(kwargs.get(arg.name, self.expr.datamap[arg.name])) - for arg in self.ir.default_entrypoint.args - ] - compile_loopy(self.ir)(*args) + # breakpoint() + data_args = [] + for kernel_arg in self.ir.default_entrypoint.args: + actual_arg_name = self.arg_replace_map[kernel_arg.name] + array = kwargs.get(actual_arg_name, self.datamap[actual_arg_name]) + data_args.append(_as_pointer(array)) + compile_loopy(self.ir)(*data_args) def target_code(self, target): raise NotImplementedError("TODO") @@ -316,9 +386,37 @@ def generate_preambles(self, target): # prefer generate_code? -def compile(expr: LoopExpr, name="mykernel"): +def compile(expr: Instruction, name="mykernel"): + # preprocess expr before lowering + from pyop3.transform import expand_implicit_pack_unpack, expand_loop_contexts + + cs_expr = expand_loop_contexts(expr) ctx = LoopyCodegenContext() - _compile(expr, pmap(), ctx) + for context, expr in cs_expr: + expr = expand_implicit_pack_unpack(expr) + + # add external loop indices as kernel arguments + loop_indices = {} + for index, (path, _) in context.items(): + if len(path) > 1: + raise NotImplementedError("needs to be sorted") + + # dummy = HierarchicalArray(index.iterset, data=NullBuffer(IntType)) + dummy = HierarchicalArray(Axis(1), dtype=IntType) + # this is dreadful, pass an integer array instead + ctx.add_argument(dummy) + myname = ctx.actual_to_kernel_rename_map[dummy.name] + replace_map = { + axis: pym.subscript(pym.var(myname), (i,)) + for i, axis in enumerate(path.keys()) + } + # FIXME currently assume that source and target exprs are the same, they are not! + loop_indices[index] = (replace_map, replace_map) + + for e in as_tuple(expr): + # context manager? 
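+            # Collect the shapes of the kernel-local temporaries before
+            # lowering: add_array uses these so that a NullBuffer-backed
+            # temporary is declared with the shape the inner loopy kernel
+            # expects, rather than the flat (alloc_size,) fallback.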
+ ctx.set_temporary_shapes(_collect_temporary_shapes(e)) + _compile(e, loop_indices, ctx) # add a no-op instruction touching all of the kernel arguments so they are # not silently dropped @@ -365,34 +463,66 @@ def compile(expr: LoopExpr, name="mykernel"): # add callables tu = lp.register_callable(tu, "bsearch", BinarySearchCallable()) - tu = tu.with_entrypoints("mykernel") + tu = tu.with_entrypoints(name) - # breakpoint() - return CodegenResult(expr, tu) + return CodegenResult(expr, tu, ctx.kernel_to_actual_rename_map) +# put into a class in transform.py? @functools.singledispatch -def _compile(expr: Any, ctx: LoopyCodegenContext) -> None: - raise TypeError +def _collect_temporary_shapes(expr): + raise TypeError(f"No handler defined for {type(expr).__name__}") + + +@_collect_temporary_shapes.register +def _(expr: ContextAwareLoop): + shapes = {} + for stmts in expr.statements.values(): + for stmt in stmts: + for temp, shape in _collect_temporary_shapes(stmt).items(): + if temp in shapes: + assert shapes[temp] == shape + else: + shapes[temp] = shape + return shapes + + +@_collect_temporary_shapes.register +def _(expr: Assignment): + return pmap() + + +@_collect_temporary_shapes.register +def _(expr: PetscMatInstruction): + return pmap() + + +@_collect_temporary_shapes.register +def _(call: CalledFunction): + return freeze( + { + arg.name: lp_arg.shape + for lp_arg, arg in checked_zip( + call.function.code.default_entrypoint.args, call.arguments + ) + } + ) + + +@functools.singledispatch +def _compile(expr: Any, loop_indices, ctx: LoopyCodegenContext) -> None: + raise TypeError(f"No handler defined for {type(expr).__name__}") @_compile.register def _( - loop: Loop, + loop: ContextAwareLoop, loop_indices, codegen_context: LoopyCodegenContext, ) -> None: - loop_context = context_from_indices(loop_indices) - iterset = loop.index.iterset.with_context(loop_context) - - loop_index_replace_map = {} - for _, replace_map in loop_indices.values(): - loop_index_replace_map.update(replace_map) - loop_index_replace_map = pmap(loop_index_replace_map) - parse_loop_properly_this_time( loop, - iterset, + loop.index.iterset, loop_indices, codegen_context, ) @@ -406,90 +536,87 @@ def parse_loop_properly_this_time( *, axis=None, source_path=pmap(), - target_path=pmap(), iname_replace_map=pmap(), - jname_replace_map=pmap(), + target_path=None, + index_exprs=None, ): - outer_replace_map = {} - for _, replace_map in loop_indices.values(): - outer_replace_map.update(replace_map) - outer_replace_map = freeze(outer_replace_map) - if axes.is_empty: raise NotImplementedError("does this even make sense?") - axis = axis or axes.root + if axis is None: + target_path = freeze(axes.target_paths.get(None, {})) + + # again, repeated this pattern all over the place + # target_replace_map = {} + index_exprs = freeze(axes.index_exprs.get(None, {})) + # replacer = JnameSubstitutor(outer_replace_map, codegen_context) + # for axis_label, index_expr in index_exprs.items(): + # target_replace_map[axis_label] = replacer(index_expr) + # target_replace_map = freeze(target_replace_map) - domain_insns = [] - leaf_data = [] + axis = axes.root for component in axis.components: - iname = codegen_context.unique_name("i") - extent_var = register_extent( - component.count, - iname_replace_map | jname_replace_map | outer_replace_map, - codegen_context, - ) - codegen_context.add_domain(iname, extent_var) + axis_index_exprs = axes.index_exprs.get((axis.id, component.label), {}) + index_exprs_ = index_exprs | axis_index_exprs + + # FIXME: This is 
not the cause of my problems + # if component.count != 1: + if True: + iname = codegen_context.unique_name("i") + extent_var = register_extent( + component.count, + iname_replace_map | loop_indices, + codegen_context, + ) + codegen_context.add_domain(iname, extent_var) + axis_replace_map = {axis.label: pym.var(iname)} + within_inames = {iname} + else: + axis_replace_map = {axis.label: 0} + within_inames = set() + + source_path_ = source_path | {axis.label: component.label} + iname_replace_map_ = iname_replace_map | axis_replace_map - new_source_path = source_path | {axis.label: component.label} - new_target_path = target_path | axes.target_paths.get( + target_path_ = target_path | axes.target_paths.get( (axis.id, component.label), {} ) - new_iname_replace_map = iname_replace_map | {axis.label: pym.var(iname)} - - # these aren't jnames! - my_index_exprs = axes.index_exprs.get((axis.id, component.label), {}) - - jname_extras = {} - for axis_label, index_expr in my_index_exprs.items(): - jname_expr = JnameSubstitutor( - new_iname_replace_map | jname_replace_map | outer_replace_map, - codegen_context, - )(index_expr) - # jname_extras[axis_label] = jname_expr - jname_extras[axis_label] = jname_expr - - new_jname_replace_map = jname_replace_map | jname_extras - with codegen_context.within_inames({iname}): - if subaxis := axes.child(axis, component): + with codegen_context.within_inames(within_inames): + subaxis = axes.child(axis, component) + if subaxis: parse_loop_properly_this_time( loop, axes, loop_indices, codegen_context, axis=subaxis, - source_path=new_source_path, - target_path=new_target_path, - iname_replace_map=new_iname_replace_map, - jname_replace_map=new_jname_replace_map, + source_path=source_path_, + iname_replace_map=iname_replace_map_, + target_path=target_path_, + index_exprs=index_exprs_, ) else: - new_iname_replace_map = pmap( - { - (loop.index.local_index.id, myaxislabel): jname_expr - for myaxislabel, jname_expr in new_iname_replace_map.items() - } - ) - new_jname_replace_map = pmap( - { - (loop.index.id, myaxislabel): jname_expr - for myaxislabel, jname_expr in new_jname_replace_map.items() - } + target_replace_map = {} + replacer = JnameSubstitutor( + # outer_replace_map | iname_replace_map_, codegen_context + iname_replace_map_ | loop_indices, + codegen_context, ) - for stmt in loop.statements: + for axis_label, index_expr in index_exprs_.items(): + target_replace_map[axis_label] = replacer(index_expr) + + index_replace_map = target_replace_map + local_index_replace_map = iname_replace_map_ + for stmt in loop.statements[source_path_]: _compile( stmt, loop_indices | { - loop.index: ( - new_target_path, - new_jname_replace_map, - ), - loop.index.local_index: ( - new_source_path, - new_iname_replace_map, + loop.index.id: ( + local_index_replace_map, + index_replace_map, ), }, codegen_context, @@ -498,11 +625,6 @@ def parse_loop_properly_this_time( @_compile.register def _(call: CalledFunction, loop_indices, ctx: LoopyCodegenContext) -> None: - """ - Turn an exprs.FunctionCall into a series of assignment instructions etc. - Handles packing/accessor logic. 
- """ - temporaries = [] subarrayrefs = {} extents = {} @@ -510,425 +632,355 @@ def _(call: CalledFunction, loop_indices, ctx: LoopyCodegenContext) -> None: # loopy args can contain ragged params too loopy_args = call.function.code.default_entrypoint.args[: len(call.arguments)] for loopy_arg, arg, spec in checked_zip(loopy_args, call.arguments, call.argspec): - loop_context = context_from_indices(loop_indices) - - assert isinstance(arg, (HierarchicalArray, ContextSensitiveMultiArray)) - # FIXME materialize is a bad name here, it implies actually packing the values - # into the temporary. - temporary = arg.with_context(loop_context).materialize() - indexed_temp = temporary - - if loopy_arg.shape is None: - shape = (temporary.alloc_size,) + # this check fails because we currently assume that all arrays require packing + # from pyop3.transform import _requires_pack_unpack + # assert not _requires_pack_unpack(arg) + # old names + temporary = arg + indexed_temp = arg + + if isinstance(arg, DummyKernelArgument): + ctx.add_dummy_argument(arg, loopy_arg.dtype) + name = ctx._dummy_names[arg] + subarrayrefs[arg] = pym.var(name) else: - if np.prod(loopy_arg.shape, dtype=int) != temporary.alloc_size: - raise RuntimeError("Shape mismatch between inner and outer kernels") - shape = loopy_arg.shape - - temporaries.append((arg, indexed_temp, spec.access, shape)) - - # Register data - ctx.add_argument(arg) - - ctx.add_temporary(temporary.name, temporary.dtype, shape) - - # subarrayref nonsense/magic - indices = [] - for s in shape: - iname = ctx.unique_name("i") - ctx.add_domain(iname, s) - indices.append(pym.var(iname)) - indices = tuple(indices) - - subarrayrefs[arg.name] = lp.symbolic.SubArrayRef( - indices, pym.subscript(pym.var(temporary.name), indices) - ) - - # we need to pass sizes through if they are only known at runtime (ragged) - # NOTE: If we register an extent to pass through loopy will complain - # unless we register it as an assumption of the local kernel (e.g. "n <= 3") - - # FIXME ragged is broken since I commented this out! determining shape of - # ragged things requires thought! - # for cidx in range(indexed_temp.index.root.degree): - # extents |= self.collect_extents( - # indexed_temp.index, - # indexed_temp.index.root, - # cidx, - # within_indices, - # within_inames, - # depends_on, - # ) + if loopy_arg.shape is None: + shape = (temporary.alloc_size,) + else: + if np.prod(loopy_arg.shape, dtype=int) != temporary.alloc_size: + raise RuntimeError("Shape mismatch between inner and outer kernels") + shape = loopy_arg.shape + + temporaries.append((arg, indexed_temp, spec.access, shape)) + + # Register data + # TODO This might be bad for temporaries + if isinstance(arg, HierarchicalArray): + ctx.add_argument(arg) + + # this should already be done in an assignment + # ctx.add_temporary(temporary.name, temporary.dtype, shape) + + # subarrayref nonsense/magic + indices = [] + for s in shape: + iname = ctx.unique_name("i") + ctx.add_domain(iname, s) + indices.append(pym.var(iname)) + indices = tuple(indices) + + temp_name = ctx.actual_to_kernel_rename_map[temporary.name] + subarrayrefs[arg] = lp.symbolic.SubArrayRef( + indices, pym.subscript(pym.var(temp_name), indices) + ) # TODO this is pretty much the same as what I do in fix_intents in loopexpr.py # probably best to combine them - could add a sensible check there too. 
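+    # For example, a local kernel declared with (READ, WRITE) intents gets
+    # only its WRITE argument in `assignees` below, while the call expression
+    # receives the READ argument, matching loopy's convention that outputs
+    # are assignees and inputs are call parameters.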
assignees = tuple( - subarrayrefs[arg.name] + subarrayrefs[arg] for arg, spec in checked_zip(call.arguments, call.argspec) + # if spec.access in {WRITE, RW, INC, MIN_RW, MIN_WRITE, MAX_RW, MAX_WRITE, NA} if spec.access in {WRITE, RW, INC, MIN_RW, MIN_WRITE, MAX_RW, MAX_WRITE} ) expression = pym.primitives.Call( pym.var(call.function.code.default_entrypoint.name), tuple( - subarrayrefs[arg.name] + subarrayrefs[arg] for arg, spec in checked_zip(call.arguments, call.argspec) - if spec.access in {READ, RW, INC, MIN_RW, MAX_RW} + if spec.access in {READ, RW, INC, MIN_RW, MAX_RW, NA} ) + tuple(extents.values()), ) - # gathers - for arg, temp, access, shape in temporaries: - if access in {READ, RW, MIN_RW, MAX_RW}: - op = AssignmentType.READ - else: - assert access in {WRITE, INC, MIN_WRITE, MAX_WRITE} - op = AssignmentType.ZERO - parse_assignment(arg, temp, shape, op, loop_indices, ctx) - ctx.add_function_call(assignees, expression) ctx.add_subkernel(call.function.code) - # scatters - for arg, temp, access, shape in temporaries: - if access == READ: - continue - elif access in {WRITE, RW, MIN_RW, MIN_WRITE, MAX_RW, MAX_WRITE}: - op = AssignmentType.WRITE - else: - assert access == INC - op = AssignmentType.INC - parse_assignment(arg, temp, shape, op, loop_indices, ctx) - # FIXME this is practically identical to what we do in build_loop +@_compile.register(Assignment) def parse_assignment( - array, - temp, - shape, - op, + assignment, loop_indices, codegen_ctx, ): - # TODO singledispatch - loop_context = context_from_indices(loop_indices) - - if isinstance(array.with_context(loop_context).buffer, PackedBuffer): - if not isinstance(array.with_context(loop_context).buffer, PackedPetscMatAIJ): - raise NotImplementedError("TODO") - parse_assignment_petscmat( - array.with_context(loop_context), temp, shape, op, loop_indices, codegen_ctx - ) - return - else: - assert isinstance(array.with_context(loop_context).buffer, DistributedBuffer) - - # get the right index tree given the loop context - - axes = array.with_context(loop_context).axes - minimal_context = array.filter_context(loop_context) - - target_path = {} - # for _, jnames in new_indices.values(): - for loop_index, (path, iname_expr) in loop_indices.items(): - if loop_index in minimal_context: - # assert all(k not in jname_replace_map for k in iname_expr) - # jname_replace_map.update(iname_expr) - target_path.update(path) - # jname_replace_map = freeze(jname_replace_map) - target_path = freeze(target_path) - - jname_replace_map = merge_dicts(mymap for _, mymap in loop_indices.values()) - + # this seems wrong parse_assignment_properly_this_time( - array, - temp, - shape, - op, - axes, + assignment, loop_indices, codegen_ctx, - iname_replace_map=jname_replace_map, - jname_replace_map=jname_replace_map, - target_path=target_path, ) -def parse_assignment_petscmat(array, temp, shape, op, loop_indices, codegen_context): - ctx = codegen_context +@_compile.register(PetscMatInstruction) +def _(assignment, loop_indices, codegen_context): + # now emit the right line of code, this should properly be a lp.ScalarCallable + # https://petsc.org/release/manualpages/Mat/MatGetValuesLocal/ - (iraxis, ircpt), (icaxis, iccpt) = array.axes.path_with_nodes( - *array.axes.leaf, ordered=True + mat = assignment.mat_arg.buffer.mat + array = assignment.array_arg + rmap = assignment.mat_arg.buffer.rmap + cmap = assignment.mat_arg.buffer.cmap + + codegen_context.add_argument(assignment.mat_arg) + codegen_context.add_argument(array) + codegen_context.add_argument(rmap) + 
codegen_context.add_argument(cmap)
+
+    mat_name = codegen_context.actual_to_kernel_rename_map[mat.name]
+    array_name = codegen_context.actual_to_kernel_rename_map[array.name]
+    rmap_name = codegen_context.actual_to_kernel_rename_map[rmap.name]
+    cmap_name = codegen_context.actual_to_kernel_rename_map[cmap.name]
+
+    # these sizes can be expressions that need evaluating
+    rsize, csize = assignment.mat_arg.buffer.shape
+
+    if not isinstance(rsize, numbers.Integral):
+        # rindex_exprs = merge_dicts(
+        #     rsize.index_exprs.get((ax.id, clabel), {})
+        #     for ax, clabel in rsize.axes.path_with_nodes(*rsize.axes.leaf).items()
+        # )
+        rsize_var = register_extent(
+            # rsize, rindex_exprs, my_replace_map, codegen_context
+            rsize,
+            loop_indices,
+            codegen_context,
+        )
+    else:
+        rsize_var = rsize
+
+    if not isinstance(csize, numbers.Integral):
+        # cindex_exprs = merge_dicts(
+        #     csize.index_exprs.get((ax.id, clabel), {})
+        #     for ax, clabel in csize.axes.path_with_nodes(*csize.axes.leaf).items()
+        # )
+        csize_var = register_extent(
+            # csize, cindex_exprs, my_replace_map, codegen_context
+            csize,
+            loop_indices,
+            codegen_context,
+        )
+    else:
+        csize_var = csize
+
+    # rlayouts = rmap.layouts[
+    #     freeze({rmap.axes.root.label: rmap.axes.root.component.label})
+    # ]
+    rlayouts = rmap.layouts[pmap()]
+    roffset = JnameSubstitutor(loop_indices, codegen_context)(rlayouts)
+
+    # clayouts = cmap.layouts[
+    #     freeze({cmap.axes.root.label: cmap.axes.root.component.label})
+    # ]
+    clayouts = cmap.layouts[pmap()]
+    coffset = JnameSubstitutor(loop_indices, codegen_context)(clayouts)
+
+    irow = f"{rmap_name}[{roffset}]"
+    icol = f"{cmap_name}[{coffset}]"
+
+    call_str = _petsc_mat_insn(
+        assignment, mat_name, array_name, rsize_var, csize_var, irow, icol
     )
-    rkey = (iraxis.id, ircpt)
-    ckey = (icaxis.id, iccpt)
-
-    rexpr = array.index_exprs[rkey][just_one(array.target_paths[rkey])]
-    cexpr = array.index_exprs[ckey][just_one(array.target_paths[ckey])]
-
-    mat = array.buffer.array
-
-    # need to generate code like map0[i0] instead of the usual map0[i0, i1]
-    # this is because we are passing the full map through to the function call
-
-    # similarly we also need to be careful to interrupt this function early
-    # we don't want to emit loops for things!
+    codegen_context.add_cinstruction(call_str)
-    # I believe that this is probably the right place to be flattening the map
-    # expressions. We want to have already done any clever substitution for arity 1
-    # objects.
-    # rexpr = self._flatten(rexpr)
-    # cexpr = self._flatten(cexpr)
+@functools.singledispatch
+def _petsc_mat_insn(assignment, *args):
+    raise TypeError(f"{assignment} not recognised")
-    iname_expr_replace_map = {}
-    for _, replace_map in loop_indices.values():
-        iname_expr_replace_map.update(replace_map)
-
-    # for now assume that we pass exactly the right map through, do no composition
-    if not isinstance(rexpr, CalledMapVariable) or len(rexpr.parameters) != 2:
-        raise NotImplementedError
+@_petsc_mat_insn.register
+def _(assignment: PetscMatLoad, mat_name, array_name, nrow, ncol, irow, icol):
+    return f"PetscCall(MatGetValuesLocal({mat_name}, {nrow}, &({irow}), {ncol}, &({icol}), &({array_name}[0])));"
-    rinner_axis_label = rexpr.function.full_map.name
-
-    # substitute a zero for the inner axis, we want to avoid this inner loop
-    new_rexpr = JnameSubstitutor(
-        iname_expr_replace_map | {rinner_axis_label: 0}, codegen_context
-    )(rexpr)
+@_petsc_mat_insn.register
+def _(assignment: PetscMatStore, mat_name, array_name, nrow, ncol, irow, icol):
+    return f"PetscCall(MatSetValuesLocal({mat_name}, {nrow}, &({irow}), {ncol}, &({icol}), &({array_name}[0]), INSERT_VALUES));"
-    if not isinstance(cexpr, CalledMapVariable) or len(cexpr.parameters) != 2:
-        raise NotImplementedError
-    cinner_axis_label = cexpr.function.full_map.name
-    # substitute a zero for the inner axis, we want to avoid this inner loop
-    new_cexpr = JnameSubstitutor(
-        iname_expr_replace_map | {cinner_axis_label: 0}, codegen_context
-    )(cexpr)
+@_petsc_mat_insn.register
+def _(assignment: PetscMatAdd, mat_name, array_name, nrow, ncol, irow, icol):
+    return f"PetscCall(MatSetValuesLocal({mat_name}, {nrow}, &({irow}), {ncol}, &({icol}), &({array_name}[0]), ADD_VALUES));"
-    # now emit the right line of code, this should properly be a lp.ScalarCallable
-    # https://petsc.org/release/manualpages/Mat/MatGetValuesLocal/
-    # PetscErrorCode MatGetValuesLocal(Mat mat, PetscInt nrow, const PetscInt irow[], PetscInt ncol, const PetscInt icol[], PetscScalar y[])
-    nrow = rexpr.function.map_component.arity
-    irow = new_rexpr
-    ncol = cexpr.function.map_component.arity
-    icol = new_cexpr
-
-    # can only use GetValuesLocal when lgmaps are set (which I don't yet do)
-    call_str = (
-        # f"MatGetValuesLocal({mat.name}, {nrow}, &({irow}), {ncol}, &({icol}), &({temp.name}[0]));"
-        f"MatGetValues({mat.name}, {nrow}, &({irow}), {ncol}, &({icol}), &({temp.name}[0]));"
-    )
-    codegen_context.add_cinstruction(call_str)


 # TODO now I attach a lot of info to the context-free array, do I need to pass axes around?
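+# For illustration only (these names are invented), a PetscMatAdd whose row
+# and column spaces both have extent 3 lowers to a C instruction of the form:
+#
+#     PetscCall(MatSetValuesLocal(mat_0, 3, &(rmap_0[i_0]), 3, &(cmap_0[i_1]),
+#                                 &(array_0[0]), ADD_VALUES));
+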
def parse_assignment_properly_this_time( - array, - temp, - shape, - op, - axes, + assignment, loop_indices, codegen_context, *, - iname_replace_map, - jname_replace_map, - target_path, + iname_replace_map=pmap(), + # TODO document these under "Other Parameters" axis=None, - source_path=pmap(), + path=None, ): - context = context_from_indices(loop_indices) - ctx_free_array = array.with_context(context) + axes = assignment.assignee.axes + + if strictly_all(x is None for x in [axis, path]): + for array in assignment.arrays: + codegen_context.add_argument(array) - if axis is None: axis = axes.root - target_path = target_path | ctx_free_array.target_paths.get(None, pmap()) - my_index_exprs = ctx_free_array.index_exprs.get(None, pmap()) - jname_extras = {} - for axis_label, index_expr in my_index_exprs.items(): - jname_expr = JnameSubstitutor( - iname_replace_map | jname_replace_map, codegen_context - )(index_expr) - jname_extras[axis_label] = jname_expr - jname_replace_map = jname_replace_map | jname_extras + path = pmap() if axes.is_empty: add_leaf_assignment( - array, - temp, - shape, - op, - axes, - source_path, - target_path, - iname_replace_map, - jname_replace_map, + assignment, + path, + iname_replace_map | loop_indices, codegen_context, loop_indices, ) return for component in axis.components: - iname = codegen_context.unique_name("i") - extent_var = register_extent( - component.count, iname_replace_map | jname_replace_map, codegen_context - ) - codegen_context.add_domain(iname, extent_var) + # if component.count != 1: + if True: + iname = codegen_context.unique_name("i") - new_source_path = source_path | {axis.label: component.label} # not used - new_target_path = target_path | ctx_free_array.target_paths.get( - (axis.id, component.label), {} - ) + extent_var = register_extent( + component.count, + iname_replace_map | loop_indices, + codegen_context, + ) + codegen_context.add_domain(iname, extent_var) + new_iname_replace_map = iname_replace_map | {axis.label: pym.var(iname)} + within_inames = {iname} + else: + new_iname_replace_map = iname_replace_map | {axis.label: 0} + within_inames = set() + + path_ = path | {axis.label: component.label} - new_iname_replace_map = iname_replace_map | {axis.label: pym.var(iname)} - - # I don't like that I need to do this here and also when I emit the layout - # instructions. - # Do I need the jnames on the way down? Think so for things like ragged... 
- my_index_exprs = ctx_free_array.index_exprs.get((axis.id, component.label), {}) - jname_extras = {} - for axis_label, index_expr in my_index_exprs.items(): - jname_expr = JnameSubstitutor( - new_iname_replace_map | jname_replace_map, codegen_context - )(index_expr) - jname_extras[axis_label] = jname_expr - new_jname_replace_map = jname_replace_map | jname_extras - # new_jname_replace_map = new_iname_replace_map - - with codegen_context.within_inames({iname}): + with codegen_context.within_inames(within_inames): if subaxis := axes.child(axis, component): parse_assignment_properly_this_time( - array, - temp, - shape, - op, - axes, + assignment, loop_indices, codegen_context, - axis=subaxis, - source_path=new_source_path, - target_path=new_target_path, iname_replace_map=new_iname_replace_map, - jname_replace_map=new_jname_replace_map, + axis=subaxis, + path=path_, ) else: add_leaf_assignment( - array, - temp, - shape, - op, - axes, - new_source_path, - new_target_path, - new_iname_replace_map, - new_jname_replace_map, + assignment, + path_, + new_iname_replace_map | loop_indices, codegen_context, loop_indices, ) def add_leaf_assignment( - array, - temporary, - shape, - op, - axes, - source_path, - target_path, + assignment, + path, iname_replace_map, - jname_replace_map, codegen_context, loop_indices, ): - context = context_from_indices(loop_indices) + larr = assignment.assignee + rarr = assignment.expression - assert isinstance(array, (HierarchicalArray, ContextSensitiveMultiArray)) - - def array_expr(): - array_ = array.with_context(context) - return make_array_expr( - array, - array_.layouts[target_path], - target_path, - iname_replace_map | jname_replace_map, + if isinstance(rarr, HierarchicalArray): + rexpr = make_array_expr( + rarr, + path, + iname_replace_map, codegen_context, ) + else: + assert isinstance(rarr, numbers.Number) + rexpr = rarr - temp_expr = functools.partial( - make_temp_expr, - temporary, - shape, - source_path, + lexpr = make_array_expr( + larr, + path, iname_replace_map, codegen_context, ) - if op == AssignmentType.READ: - lexpr = temp_expr() - rexpr = array_expr() - elif op == AssignmentType.WRITE: - lexpr = array_expr() - rexpr = temp_expr() - elif op == AssignmentType.INC: - lexpr = array_expr() - rexpr = lexpr + temp_expr() - elif op == AssignmentType.ZERO: - lexpr = temp_expr() - rexpr = 0 + # if larr.name == "t_4": + # breakpoint() + + if isinstance(assignment, AddAssignment): + rexpr = lexpr + rexpr else: - raise AssertionError("Invalid assignment type") + assert isinstance(assignment, ReplaceAssignment) codegen_context.add_assignment(lexpr, rexpr) -def make_array_expr(array, layouts, path, jnames, ctx): +def make_array_expr(array, path, inames, ctx): array_offset = make_offset_expr( - layouts, - jnames, - ctx, - ) - return pym.subscript(pym.var(array.name), array_offset) - - -def make_temp_expr(temporary, shape, path, jnames, ctx): - layout = temporary.axes.layouts[path] - temp_offset = make_offset_expr( - layout, - jnames, + array.subst_layouts[path], + inames, ctx, ) # hack to handle the fact that temporaries can have shape but we want to # linearly index it here - extra_indices = (0,) * (len(shape) - 1) - # also has to be a scalar, not an expression - temp_offset_var = ctx.unique_name("off") - ctx.add_temporary(temp_offset_var) - ctx.add_assignment(temp_offset_var, temp_offset) - temp_offset_var = pym.var(temp_offset_var) - return pym.subscript(pym.var(temporary.name), extra_indices + (temp_offset_var,)) + if array.name in ctx._temporary_shapes: + shape 
= ctx._temporary_shapes[array.name] + assert shape is not None + rank = len(shape) + extra_indices = (0,) * (rank - 1) + + # also has to be a scalar, not an expression + temp_offset_name = ctx.unique_name("j") + temp_offset_var = pym.var(temp_offset_name) + ctx.add_temporary(temp_offset_name) + ctx.add_assignment(temp_offset_var, array_offset) + indices = extra_indices + (temp_offset_var,) + else: + indices = (array_offset,) + + name = ctx.actual_to_kernel_rename_map[array.name] + return pym.subscript(pym.var(name), indices) class JnameSubstitutor(pym.mapper.IdentityMapper): def __init__(self, replace_map, codegen_context): - self._labels_to_jnames = replace_map + self._replace_map = replace_map self._codegen_context = codegen_context def map_axis_variable(self, expr): - return self._labels_to_jnames[expr.axis_label] + return self._replace_map[expr.axis_label] # this is cleaner if I do it as a single line expression # rather than register assignments for things. - def map_multi_array(self, expr): - path = expr.array.axes.path(*expr.array.axes.leaf) - replace_map = {axis: self.rec(index) for axis, index in expr.indices.items()} - varname = _scalar_assignment( - expr.array, - path, + def map_array(self, expr): + # Register data + self._codegen_context.add_argument(expr.array) + new_name = self._codegen_context.actual_to_kernel_rename_map[expr.array.name] + + replace_map = {ax: self.rec(expr_) for ax, expr_ in expr.indices.items()} + replace_map.update(self._replace_map) + + offset_expr = make_offset_expr( + expr.array.subst_layouts[expr.path], replace_map, self._codegen_context, ) - return varname + rexpr = pym.subscript(pym.var(new_name), offset_expr) + return rexpr def map_called_map(self, expr): if not isinstance(expr.function.map_component.array, HierarchicalArray): @@ -940,7 +992,7 @@ def map_called_map(self, expr): # handle [map0(p)][map1(p)] where map0 does not have an associated loop try: - jname = self._labels_to_jnames[expr.function.full_map.name] + jname = self._replace_map[expr.function.full_map.name] except KeyError: jname = self._codegen_context.unique_name("j") self._codegen_context.add_temporary(jname) @@ -970,7 +1022,14 @@ def map_called_map(self, expr): return jname_expr def map_loop_index(self, expr): - return self._labels_to_jnames[expr.name, expr.axis] + # if expr.id.endswith("1"): + # breakpoint() + # FIXME pretty sure I have broken local loop index stuff + if isinstance(expr, LocalLoopIndexVariable): + return self._replace_map[expr.id][0][expr.axis] + else: + assert isinstance(expr, LoopIndexVariable) + return self._replace_map[expr.id][1][expr.axis] def map_call(self, expr): if expr.function.name == "mybsearch": @@ -978,14 +1037,12 @@ def map_call(self, expr): else: raise NotImplementedError("hmm") - # def _flatten(self, expr): - # for - def _map_bsearch(self, expr): indices_var, axis_var = expr.parameters indices = indices_var.array - leaf_axis, leaf_component = indices.axes.leaf + leaf_axis = indices.axes.leaf_axis + leaf_component = indices.axes.leaf_component ctx = self._codegen_context # should do elsewhere? 
@@ -1009,32 +1066,42 @@ def _map_bsearch(self, expr): ctx.add_assignment(key_var, key_expr) # base + # replace loop indices with axis variables - this feels very hacky replace_map = {} - for key, replace_expr in self._labels_to_jnames.items(): - # for (LoopIndex_id0, axis0) - if isinstance(key, tuple): - replace_map[key[1]] = replace_expr + for key, replace_expr in self._replace_map.items(): + # loop indices + if isinstance(replace_expr, tuple): + # use target exprs + replace_expr = replace_expr[1] + for ax, rep_expr in replace_expr.items(): + replace_map[ax] = rep_expr else: - assert isinstance(key, str) replace_map[key] = replace_expr # and set start to zero start_replace_map = replace_map.copy() start_replace_map[leaf_axis.label] = 0 start_expr = make_offset_expr( - indices.layouts[indices.axes.path(leaf_axis, leaf_component)], + indices.subst_layouts[indices.axes.path(leaf_axis, leaf_component)], start_replace_map, self._codegen_context, ) base_varname = ctx.unique_name("base") + + # rename things + indices_name = ctx.actual_to_kernel_rename_map[indices.name] + renamer = Renamer(ctx.actual_to_kernel_rename_map) + start_expr = renamer(start_expr) + # breaks if unsigned ctx.add_cinstruction( - f"int32_t* {base_varname} = {indices.name} + {start_expr};", {indices.name} + f"int32_t* {base_varname} = {indices_name} + {start_expr};", {indices_name} ) # nitems nitems_varname = ctx.unique_name("nitems") ctx.add_temporary(nitems_varname) + nitems_expr = register_extent(leaf_component.count, replace_map, ctx) # result @@ -1066,7 +1133,7 @@ def make_offset_expr( return JnameSubstitutor(jname_replace_map, codegen_context)(layouts) -def register_extent(extent, jnames, ctx): +def register_extent(extent, iname_replace_map, ctx): if isinstance(extent, numbers.Integral): return extent @@ -1078,7 +1145,14 @@ def register_extent(extent, jnames, ctx): path = extent.axes.path(*extent.axes.leaf) else: path = pmap() - expr = _scalar_assignment(extent, path, jnames, ctx) + + index_exprs = extent.index_exprs.get(None, {}) + # extent must be linear + if not extent.axes.is_empty: + for axis, cpt in extent.axes.path_with_nodes(*extent.axes.leaf).items(): + index_exprs.update(extent.index_exprs[axis.id, cpt]) + + expr = _scalar_assignment(extent, path, index_exprs, iname_replace_map, ctx) varname = ctx.unique_name("p") ctx.add_temporary(varname) @@ -1086,11 +1160,6 @@ def register_extent(extent, jnames, ctx): return varname -class MultiArrayCollector(pym.mapper.Collector): - def map_multi_array(self, expr): - return {expr} - - class VariableReplacer(pym.mapper.IdentityMapper): def __init__(self, replace_map): self._replace_map = replace_map @@ -1101,29 +1170,37 @@ def map_variable(self, expr): def _scalar_assignment( array, - path, - array_labels_to_jnames, + source_path, + index_exprs, + iname_replace_map, ctx, ): # Register data ctx.add_argument(array) + # can this all go? 
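+    # Loosely speaking, this returns an expression such as `array_0[off]`
+    # (illustrative names), where the offset comes from substituting the
+    # current inames into the array's layout expression.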
+ index_keys = [None] + [ + (axis.id, cpt.label) + for axis, cpt in array.axes.detailed_path(source_path).items() + ] + target_path = merge_dicts(array.target_paths.get(key, {}) for key in index_keys) + # index_exprs = merge_dicts(array.index_exprs.get(key, {}) for key in index_keys) + + jname_replace_map = {} + replacer = JnameSubstitutor(iname_replace_map, ctx) + for axlabel, index_expr in index_exprs.items(): + jname_replace_map[axlabel] = replacer(index_expr) + offset_expr = make_offset_expr( - array.layouts[path], - array_labels_to_jnames, + array.layouts[target_path], + jname_replace_map, ctx, ) - rexpr = pym.subscript(pym.var(array.name), offset_expr) + name = ctx.actual_to_kernel_rename_map[array.name] + rexpr = pym.subscript(pym.var(name), offset_expr) return rexpr -def context_from_indices(loop_indices): - loop_context = {} - for loop_index, (path, _) in loop_indices.items(): - loop_context[loop_index.id] = path - return freeze(loop_context) - - # lives here?? @functools.singledispatch def _as_pointer(array) -> int: @@ -1161,9 +1238,4 @@ def _(arg: PackedBuffer): @_as_pointer.register def _(array: PetscMat): - return array.petscmat.handle - - -@_as_pointer.register -def _(arg: Tensor): - return _as_pointer(arg.data) + return array.mat.handle diff --git a/pyop3/itree/__init__.py b/pyop3/itree/__init__.py index b903aa31..922cb503 100644 --- a/pyop3/itree/__init__.py +++ b/pyop3/itree/__init__.py @@ -6,12 +6,9 @@ LocalLoopIndex, LoopIndex, Map, - MapVariable, Slice, SliceComponent, Subset, TabulatedMapComponent, as_index_forest, - collect_loop_contexts, - index_axes, ) diff --git a/pyop3/itree/tree.py b/pyop3/itree/tree.py index 240a5d2f..fee34093 100644 --- a/pyop3/itree/tree.py +++ b/pyop3/itree/tree.py @@ -9,6 +9,7 @@ import math import numbers import sys +from functools import cached_property from typing import Any, Collection, Hashable, Mapping, Sequence import numpy as np @@ -16,13 +17,15 @@ import pyrsistent import pytools from mpi4py import MPI -from pyrsistent import freeze, pmap +from pyrsistent import PMap, freeze, pmap, thaw +from pyop3.array import HierarchicalArray from pyop3.axtree import ( Axis, AxisComponent, AxisTree, AxisVariable, + ContextAware, ContextFree, ContextSensitive, LoopIterable, @@ -33,9 +36,16 @@ ContextSensitiveLoopIterable, ExpressionEvaluator, PartialAxisTree, + UnrecognisedAxisException, ) from pyop3.dtypes import IntType, get_mpi_dtype -from pyop3.tree import LabelledTree, Node, Tree, postvisit +from pyop3.lang import KernelArgument +from pyop3.tree import ( + LabelledNodeComponent, + LabelledTree, + MultiComponentLabelledNode, + postvisit, +) from pyop3.utils import ( Identified, Labelled, @@ -49,83 +59,50 @@ bsearch = pym.var("mybsearch") -# FIXME this is copied from loopexpr2loopy VariableReplacer class IndexExpressionReplacer(pym.mapper.IdentityMapper): - def __init__(self, replace_map): + def __init__(self, replace_map, loop_exprs=pmap()): self._replace_map = replace_map + self._loop_exprs = loop_exprs def map_axis_variable(self, expr): return self._replace_map.get(expr.axis_label, expr) - def map_multi_array(self, expr): - from pyop3.array.harray import MultiArrayVariable - - indices = {axis: self.rec(index) for axis, index in expr.indices.items()} - return MultiArrayVariable(expr.array, indices) - - def map_called_map(self, expr): - array = expr.function.map_component.array - - # the inner_expr tells us the right mapping for the temporary, however, - # for maps that are arrays the innermost axis label does not always match - # the label 
used by the temporary. Therefore we need to do a swap here. - indices = {axis: self.rec(idx) for axis, idx in expr.parameters.items()} - return CalledMapVariable(expr.function, indices) + def map_array(self, array_var): + indices = {ax: self.rec(expr) for ax, expr in array_var.indices.items()} + return type(array_var)(array_var.array, indices, array_var.path) - def map_loop_index(self, expr): - # this is hacky, if I make this raise a KeyError then we fail in indexing - return self._replace_map.get((expr.name, expr.axis), expr) - - -# index trees are different to axis trees because we know less about -# the possible attaching components. In particular a CalledMap can -# have different "attaching components"/output components depending on -# the loop context. This is awful for a user to have to build since we -# need something like a SplitCalledMap. Instead we will just admit any -# parent_to_children map and do error checking when we convert it to shape. -class IndexTree(Tree): - def __init__(self, parent_to_children=pmap(), *, loop_context=pmap()): - super().__init__(parent_to_children) - # FIXME, don't need to modify parent_to_children in this function - parent_to_children, loop_context = parse_index_tree( - self.parent_to_children, loop_context - ) - self.loop_context = loop_context - - @staticmethod - def _parse_node(node): - if isinstance(node, Index): - return node - elif isinstance(node, Axis): - return Slice( - node.label, [AffineSliceComponent(c.label) for c in node.components] - ) + def map_loop_index(self, index): + if index.id in self._loop_exprs: + return self._loop_exprs[index.id][index.axis] else: - raise TypeError(f"No handler defined for {type(node).__name__}") - + return index -def parse_index_tree(parent_to_children, loop_context): - new_parent_to_children = parse_parent_to_children(parent_to_children, loop_context) - return pmap(new_parent_to_children), loop_context +class IndexTree(LabelledTree): + fields = LabelledTree.fields | {"outer_loops"} - -def parse_parent_to_children(parent_to_children, loop_context, parent=None): - if parent in parent_to_children: - new_children = [] - subparents_to_children = [] - for child in parent_to_children[parent]: - if child is None: - continue - child = apply_loop_context(child, loop_context) - new_children.append(child) - subparents_to_children.append( - parse_parent_to_children(parent_to_children, loop_context, child.id) - ) - - return pmap({parent: tuple(new_children)}) | merge_dicts(subparents_to_children) - else: - return pmap() + # TODO rename to node_map + def __init__(self, parent_to_children=pmap(), outer_loops=()): + super().__init__(parent_to_children) + assert isinstance(outer_loops, tuple) + self.outer_loops = outer_loops + + @classmethod + def from_nest(cls, nest): + root, node_map = cls._from_nest(nest) + node_map.update({None: [root]}) + return cls(node_map) + + @classmethod + def from_iterable(cls, iterable): + # All iterable entries must be indices for now as we do no parsing + root, *rest = iterable + node_map = {None: (root,)} + parent = root + for index in rest: + node_map.update({parent.id: (index,)}) + parent = index + return cls(node_map) class DatamapCollector(pym.mapper.CombineMapper): @@ -149,41 +126,49 @@ def collect_datamap_from_expression(expr: pym.primitives.Expr) -> dict: return _datamap_collector(expr) -class SliceComponent(pytools.ImmutableRecord, abc.ABC): - fields = {"component"} - - def __init__(self, component): - super().__init__() +class SliceComponent(LabelledNodeComponent, abc.ABC): + def 
__init__(self, component, *, label=None): + super().__init__(label) self.component = component class AffineSliceComponent(SliceComponent): fields = SliceComponent.fields | {"start", "stop", "step"} - def __init__(self, component, start=None, stop=None, step=None): - super().__init__(component) - # use None for the default args here since that agrees with Python slices + # use None for the default args here since that agrees with Python slices + def __init__(self, component, start=None, stop=None, step=None, **kwargs): + super().__init__(component, **kwargs) + # could be None here self.start = start if start is not None else 0 self.stop = stop + # could be None here self.step = step if step is not None else 1 @property - def datamap(self): + def datamap(self) -> PMap: return pmap() + @property + def is_full(self): + return self.start == 0 and self.stop is None and self.step == 1 + -class Subset(SliceComponent): +class SubsetSliceComponent(SliceComponent): fields = SliceComponent.fields | {"array"} - def __init__(self, component, array: MultiArray): - super().__init__(component) + def __init__(self, component, array, **kwargs): + super().__init__(component, **kwargs) self.array = array @property - def datamap(self): + def datamap(self) -> PMap: return self.array.datamap +# alternative name, better or worse? +Subset = SubsetSliceComponent + + class MapComponent(pytools.ImmutableRecord, Labelled, abc.ABC): fields = {"target_axis", "target_component", "label"} @@ -201,15 +186,22 @@ def arity(self): # TODO: Implement AffineMapComponent class TabulatedMapComponent(MapComponent): - fields = MapComponent.fields | {"array"} + fields = MapComponent.fields | {"array", "arity"} + + def __init__(self, target_axis, target_component, array, *, arity=None, label=None): + # determine the arity from the provided array + if arity is None: + leaf_axis, leaf_clabel = array.axes.leaf + leaf_cidx = leaf_axis.component_index(leaf_clabel) + arity = leaf_axis.components[leaf_cidx].count - def __init__(self, target_axis, target_component, array, *, label=None): super().__init__(target_axis, target_component, label=label) self.array = array + self._arity = arity @property def arity(self): - return self.array.axes.leaf_component.count + return self._arity # old alias @property @@ -221,69 +213,290 @@ def datamap(self): return self.array.datamap -class Index(Node): +class Index(MultiComponentLabelledNode): + fields = MultiComponentLabelledNode.fields | {"component_labels"} + + def __init__(self, label=None, *, component_labels=None, id=None): + super().__init__(label, id=id) + self._component_labels = component_labels + + @property @abc.abstractmethod - def target_paths(self, context): + def leaf_target_paths(self): + # rename to just target paths? 
pass - -class AbstractLoopIndex(Index, abc.ABC): + @property + def component_labels(self): + # TODO cleanup + if self._component_labels is None: + # do this for now (since leaf_target_paths currently requires an + # instantiated object to determine) + self._component_labels = tuple( + self.unique_label() for _ in self.leaf_target_paths + ) + return self._component_labels + + +class ContextFreeIndex(Index, ContextFree, abc.ABC): + # The following is unimplemented but may prove useful + # @property + # def axes(self): + # return self._tree.axes + # + # @property + # def target_paths(self): + # return self._tree.target_paths + # + # @cached_property + # def _tree(self): + # """ + # + # Notes + # ----- + # This method will deliberately not work for slices since slices + # require additional existing axis information in order to be valid. + # + # """ + # return as_index_tree(self) pass +class ContextSensitiveIndex(Index, ContextSensitive, abc.ABC): + def __init__(self, context_map, *, id=None): + Index.__init__(self, id) + ContextSensitive.__init__(self, context_map) + + +class AbstractLoopIndex( + pytools.ImmutableRecord, KernelArgument, Identified, ContextAware, abc.ABC +): + dtype = IntType + fields = {"id"} + + def __init__(self, id=None): + pytools.ImmutableRecord.__init__(self) + Identified.__init__(self, id) + + @property + def kernel_dtype(self): + return self.dtype + + +# Is this really an index? I dont think it's valid in an index tree class LoopIndex(AbstractLoopIndex): - fields = AbstractLoopIndex.fields | {"iterset"} + """ + Parameters + ---------- + iterset: AxisTree or ContextSensitiveAxisTree (!!!) + Only add context later on + + """ - # does the label ever matter here? def __init__(self, iterset, *, id=None): - super().__init__(id) + super().__init__(id=id) self.iterset = iterset - self.local_index = LocalLoopIndex(self) + + @cached_property + def local_index(self): + return LocalLoopIndex(self) @property def i(self): return self.local_index + # @property + # def paths(self): + # return tuple(self.iterset.path(*leaf) for leaf in self.iterset.leaves) + # + # NOTE: This is confusing terminology. A loop index can be context-sensitive + # in two senses: + # 1. axes.index() is context-sensitive if axes is multi-component + # 2. axes[p].index() is context-sensitive if p is context-sensitive + # I think this can be resolved by considering axes[p] and axes as "iterset" + # and handling that separately. + def with_context(self, context, *args): + iterset = self.iterset.with_context(context) + source_path, path = context[self.id] + + # think I want this sorted... + slices = [] + axis = iterset.root + while axis is not None: + cpt = source_path[axis.label] + slices.append(Slice(axis.label, AffineSliceComponent(cpt))) + axis = iterset.child(axis, cpt) + + # the iterset is a single-component full slice of the overall iterset + iterset_ = iterset[slices] + return ContextFreeLoopIndex(iterset_, source_path, path, id=self.id) + + # unsure if this is required @property - def j(self): - # is this evil? 
+ def datamap(self): + return self.iterset.datamap + + +class LoopIndexReplacer(pym.mapper.IdentityMapper): + def __init__(self, index): + super().__init__() + self._index = index + + def map_axis_variable(self, axis_var): + # this is unconditional, key error should not occur here + return LocalLoopIndexVariable(self._index, axis_var.axis) + + def map_array(self, array_var): + indices = {ax: self.rec(expr) for ax, expr in array_var.indices.items()} + return type(array_var)(array_var.array, indices, array_var.path) + + +# FIXME class hierarchy is very confusing +class ContextFreeLoopIndex(ContextFreeIndex): + fields = {"iterset", "source_path", "path", "id"} + + def __init__(self, iterset: AxisTree, source_path, path, *, id=None): + super().__init__(id=id, label=id, component_labels=("XXX",)) + self.iterset = iterset + self.source_path = freeze(source_path) + self.path = freeze(path) + + # if self.label == "_label_ContextFreeLoopIndex_15": + # breakpoint() + + def with_context(self, context, *args): return self + @property + def leaf_target_paths(self): + return (self.path,) + + # TODO is this better as an alias for iterset? + @property + def axes(self): + return AxisTree() + + @property + def target_paths(self): + return freeze({None: self.path}) + + # should now be ignored + @property + def index_exprs(self): + if self.source_path != self.path and len(self.path) != 1: + raise NotImplementedError("no idea what to do here") + + # Need to replace the index_exprs with LocalLoopIndexVariable equivs + flat_index_exprs = {} + replacer = LoopIndexReplacer(self) + for axis in self.iterset.nodes: + key = axis.id, axis.component.label + for axis_label, orig_expr in self.iterset.index_exprs[key].items(): + new_expr = replacer(orig_expr) + flat_index_exprs[axis_label] = new_expr + + return freeze({None: flat_index_exprs}) + + # target = just_one(self.path.keys()) + # return freeze( + # { + # None: { + # target: LoopIndexVariable(self, axis) + # # for axis in self.source_path.keys() + # for axis in self.path.keys() + # }, + # } + # ) + + @property + def loops(self): + # return self.iterset.outer_loops | { + # LocalLoopIndexVariable(self, axis) + # for axis in self.iterset.path(*self.iterset.leaf).keys() + # } + # return self.iterset.outer_loops + (self,) + return (self,) + + @property + def layout_exprs(self): + # FIXME, no clue if this is right or not + return freeze({None: 0}) + @property def datamap(self): return self.iterset.datamap - def target_paths(self, context): - return (context[self.id],) - def iter(self, stuff=pmap()): - if not isinstance(self.iterset, AxisTree): - raise NotImplementedError return iter_axis_tree( - self.iterset, self.iterset.target_paths, self.iterset.index_exprs, stuff + self, + self.iterset, + self.iterset.target_paths, + self.iterset.index_exprs, + stuff, ) + # return iter_loop( + # self, + # # stuff, + # ) -class LocalLoopIndex(AbstractLoopIndex): - """Class representing a 'local' index.""" +# TODO This is properly awful, needs a big cleanup +class ContextFreeLocalLoopIndex(ContextFreeLoopIndex): + @property + def index_exprs(self): + return freeze( + { + None: { + axis: LocalLoopIndexVariable(self, axis) + for axis in self.path.keys() + } + } + ) - fields = AbstractLoopIndex.fields | {"loop_index"} - def __init__(self, loop_index: LoopIndex, *, id=None): - super().__init__(id) +# class LocalLoopIndex(AbstractLoopIndex): +class LocalLoopIndex: + """Class representing a 'local' index.""" + + def __init__(self, loop_index: LoopIndex): + # super().__init__(id) 
self.loop_index = loop_index - def target_paths(self, context): - return (context[self.id],) + # @property + # def id(self): + # return self.loop_index.id + + @property + def iterset(self): + return self.loop_index.iterset + + def with_context(self, context, axes=None): + # not sure about this + iterset = self.loop_index.iterset.with_context(context) + path, _ = context[self.loop_index.id] # here different from LoopIndex + return ContextFreeLocalLoopIndex(iterset, path, path, id=self.loop_index.id) @property def datamap(self): return self.loop_index.datamap +class ScalarIndex(ContextFreeIndex): + fields = {"axis", "component", "value", "id"} + + def __init__(self, axis, component, value, *, id=None): + super().__init__(axis, component_labels=["XXX"], id=id) + self.axis = axis + self.component = component + self.value = value + + @property + def leaf_target_paths(self): + return (freeze({self.axis: self.component}),) + + # TODO I want a Slice to have "bits" like a Map/CalledMap does -# class Slice(Index, Labelled): -class Slice(Index): +class Slice(ContextFreeIndex): """ A slice can be thought of as a map from a smaller space to the target space. @@ -293,60 +506,179 @@ class Slice(Index): """ - # TODO remove "label" - fields = Index.fields | {"axis", "slices"} - # fields = Index.fields | {"axis", "slices", "label"} + # fields = Index.fields | {"axis", "slices", "numbering"} - {"label", "component_labels"} + fields = {"axis", "slices", "numbering", "label"} - def __init__(self, axis, slices, *, id=None, label=None): - super().__init__(id) - # Index.__init__(self, id) - # Labelled.__init__(self, label) # remove + def __init__(self, axis, slices, *, numbering=None, id=None, label=None): + super().__init__(label=label, id=id) self.axis = axis self.slices = as_tuple(slices) + self.numbering = numbering - def target_paths(self, context): - return tuple(pmap({self.axis: subslice.component}) for subslice in self.slices) + @property + def components(self): + return self.slices + + @cached_property + def leaf_target_paths(self): + return tuple( + freeze({self.axis: subslice.component}) for subslice in self.slices + ) @property def datamap(self): return merge_dicts([s.datamap for s in self.slices]) - @property - def label(self): - return self.axis + +class Map(pytools.ImmutableRecord): + """ + + Notes + ----- + This class *cannot* be used as an index. Instead, one must use a + `CalledMap` which can be formed from a `Map` using call syntax. + """ + + fields = {"connectivity", "name", "numbering"} + + counter = 0 + + def __init__(self, connectivity, name=None, *, numbering=None) -> None: + # FIXME It is not appropriate to attach the numbering here because the + # numbering may differ depending on the loop context. 
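+        # `connectivity` maps each source target-path to the map components
+        # reachable from it. An illustrative sketch (the labels, data and
+        # component-constructor arguments here are hypothetical):
+        #
+        #     connectivity = {
+        #         pmap({"mesh": "cells"}): [
+        #             TabulatedMapComponent("mesh", "vertices", cell_vertex_dat),
+        #         ],
+        #     }
+        #     closure = Map(connectivity, "closure")
+        #     closure(p)  # -> CalledMap (or ContextFreeCalledMap)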
+ if numbering is not None and len(connectivity.keys()) != 1: + raise NotImplementedError + + super().__init__() + self.connectivity = freeze(connectivity) + self.numbering = numbering + + # TODO delete entirely + if name is None: + # lazy unique name + name = f"_Map_{self.counter}" + self.counter += 1 + self.name = name + + def __call__(self, index): + if isinstance(index, (ContextFreeIndex, ContextFreeCalledMap)): + leaf_target_paths = tuple( + freeze({mcpt.target_axis: mcpt.target_component}) + for path in index.leaf_target_paths + for mcpt in self.connectivity[path] + ) + return ContextFreeCalledMap(self, index, leaf_target_paths) + else: + return CalledMap(self, index) + + @cached_property + def datamap(self): + data = {} + for bit in self.connectivity.values(): + for map_cpt in bit: + data.update(map_cpt.datamap) + return pmap(data) -class CalledMap(Index, LoopIterable): - # This function cannot be part of an index tree because it has not specialised - # to a particular loop index path. - # FIXME, is this true? - def __init__(self, map, from_index, **kwargs): +class CalledMap(Identified, Labelled, LoopIterable): + def __init__(self, map, from_index, *, id=None, label=None): + Identified.__init__(self, id=id) + Labelled.__init__(self, label=label) self.map = map self.from_index = from_index - Index.__init__(self, **kwargs) def __getitem__(self, indices): raise NotImplementedError("TODO") + # figure out the current loop context, just a single loop index + from_index = self.from_index + while isinstance(from_index, CalledMap): + from_index = from_index.from_index + existing_loop_contexts = tuple( + freeze({from_index.id: path}) for path in from_index.paths + ) + + index_forest = {} + for existing_context in existing_loop_contexts: + axes = self.with_context(existing_context) + index_forest.update( + as_index_forest(indices, axes=axes, loop_context=existing_context) + ) + + array_per_context = {} + for loop_context, index_tree in index_forest.items(): + indexed_axes = _index_axes(index_tree, loop_context, self.axes) + + ( + target_paths, + index_exprs, + layout_exprs, + ) = _compose_bits( + self.axes, + self.target_paths, + self.index_exprs, + None, + indexed_axes, + indexed_axes.target_paths, + indexed_axes.index_exprs, + indexed_axes.layout_exprs, + ) + + array_per_context[loop_context] = HierarchicalArray( + indexed_axes, + data=self.array, + layouts=self.layouts, + target_paths=target_paths, + index_exprs=index_exprs, + name=self.name, + max_value=self.max_value, + ) + return ContextSensitiveMultiArray(array_per_context) def index(self) -> LoopIndex: - contexts = collect_loop_contexts(self) - # FIXME this assumption is not always true - context = just_one(contexts) - axes, target_paths, index_exprs, layout_exprs = collect_shape_index_callback( - self, loop_indices=context + context_map = { + ctx: _index_axes(itree, ctx) for ctx, itree in as_index_forest(self).items() + } + context_sensitive_axes = ContextSensitiveAxisTree(context_map) + return LoopIndex(context_sensitive_axes) + + def iter(self, outer_loops=()): + loop_context = merge_dicts( + iter_entry.loop_context for iter_entry in outer_loops + ) + cf_called_map = self.with_context(loop_context) + # breakpoint() + return iter_axis_tree( + self.index(), + cf_called_map.axes, + cf_called_map.target_paths, + cf_called_map.index_exprs, + outer_loops, ) - axes = AxisTree.from_node_map(axes.parent_to_children) + def with_context(self, context, axes=None): + # TODO stole this docstring from elsewhere, correct it + """Remove map outputs 
that are not present in the axes. - axes = AxisTree( - axes.parent_to_children, - target_paths, - index_exprs, - layout_exprs, - ) + This is useful for the case where we have a general map acting on a + restricted set of axes. An example would be a cell closure map (maps + cells to cells, edges and vertices) acting on a data structure that + only holds values on vertices. The cell-to-cell and cell-to-edge elements + of the closure map would produce spurious entries in the index tree. - context_sensitive_axes = ContextSensitiveAxisTree({context: axes}) - return LoopIndex(context_sensitive_axes) + If the map has no valid outputs then an exception will be raised. + + """ + cf_index = self.from_index.with_context(context, axes) + leaf_target_paths = tuple( + freeze({mcpt.target_axis: mcpt.target_component}) + for path in cf_index.leaf_target_paths + for mcpt in self.connectivity[path] + ) + if len(leaf_target_paths) == 0: + raise RuntimeError + return ContextFreeCalledMap( + self.map, cf_index, leaf_target_paths, id=self.id, label=self.label + ) @property def name(self): @@ -356,42 +688,82 @@ def name(self): def connectivity(self): return self.map.connectivity - def target_paths(self, context): - targets = [] - for src_path in self.from_index.target_paths(context): - for map_component in self.connectivity[src_path]: - targets.append( - pmap({map_component.target_axis: map_component.target_component}) - ) - return tuple(targets) +# class ContextFreeCalledMap(Index, ContextFree): +# TODO: ContextFreeIndex +class ContextFreeCalledMap(Index): + # FIXME this is clumsy + # fields = Index.fields | {"map", "index", "leaf_target_paths"} - {"label", "component_labels"} + fields = {"map", "index", "leaf_target_paths", "label", "id"} -class Map(pytools.ImmutableRecord): - """ + def __init__(self, map, index, leaf_target_paths, *, id=None, label=None): + super().__init__(id=id, label=label) + self.map = map + # better to call it "input_index"? + self.index = index + self._leaf_target_paths = leaf_target_paths + + # alias for compat with ContextFreeCalledMap + self.from_index = index + + # TODO cleanup + def with_context(self, context, axes=None): + # maybe this line isn't needed? + # cf_index = self.from_index.with_context(context, axes) + cf_index = self.index + leaf_target_paths = tuple( + freeze({mcpt.target_axis: mcpt.target_component}) + for path in cf_index.leaf_target_paths + for mcpt in self.map.connectivity[path] + # if axes is None we are *building* the axes from this map + if axes is None + or axes.is_valid_path( + {mcpt.target_axis: mcpt.target_component}, complete=False + ) + ) + if len(leaf_target_paths) == 0: + raise RuntimeError + return ContextFreeCalledMap(self.map, cf_index, leaf_target_paths, id=self.id) - Notes - ----- - This class *cannot* be used as an index. Instead, one must use a - `CalledMap` which can be formed from a `Map` using call syntax. - """ + @property + def name(self) -> str: + return self.map.name - fields = {"connectivity", "name"} + # is this ever used? 
+ # @property + # def components(self): + # return self.map.connectivity[self.index.target_paths] - def __init__(self, connectivity, name, **kwargs) -> None: - super().__init__(**kwargs) - self.connectivity = connectivity - self.name = name + @property + def leaf_target_paths(self): + return self._leaf_target_paths - def __call__(self, index) -> Union[CalledMap, ContextSensitiveCalledMap]: - return CalledMap(self, index) + # return tuple( + # freeze({mcpt.target_axis: mcpt.target_component}) + # for path in self.index.leaf_target_paths + # for mcpt in self.map.connectivity[path] + # ) - @functools.cached_property - def datamap(self): - data = {} - for bit in self.connectivity.values(): - for map_cpt in bit: - data.update(map_cpt.datamap) - return pmap(data) + @cached_property + def axes(self): + return self._axes_info[0] + + @cached_property + def target_paths(self): + return self._axes_info[1] + + @cached_property + def index_exprs(self): + return self._axes_info[2] + + @cached_property + def layout_exprs(self): + return self._axes_info[3] + + # TODO This is bad design, unroll the traversal and store as properties + @cached_property + def _axes_info(self): + return collect_shape_index_callback(self, (), prev_axes=None) class LoopIndexVariable(pym.primitives.Variable): @@ -418,420 +790,353 @@ def datamap(self): return self.index.datamap -class MapVariable(pym.primitives.Variable): - """Pymbolic variable representing the action of a map.""" +class LoopIndexEnumerateIndexVariable(pym.primitives.Leaf): + """Variable representing the index of an enumerated index. - mapper_method = sys.intern("map_map_variable") + The variable is equivalent to the index ``i`` in the expression - def __init__(self, full_map, map_component): - super().__init__(map_component.array.name) - self.full_map = full_map - self.map_component = map_component + for i, x in enumerate(X): + ... - def __call__(self, *args): - return CalledMapVariable(self, *args) + Here, if ``X`` were composed of multiple axes, this class would + be implemented like - @functools.cached_property - def datamap(self): - return self.map_component.datamap + i = 0 + for x0 in X[0]: + for x1 in X[1]: + x = f(x0, x1) + ... + i += 1 + This class is very important because it allows us to express layouts + when we materialise indexed things. An example is the maps that are + required for indexing PETSc matrices. -class CalledMapVariable(pym.primitives.Call): - def __str__(self) -> str: - return f"{self.function.name}({self.parameters})" + """ - mapper_method = sys.intern("map_called_map") + init_arg_names = ("index",) - @functools.cached_property - def datamap(self): - return self.function.datamap | merge_dicts( - idx.datamap for idx in self.parameters.values() - ) + mapper_method = sys.intern("map_enumerate") + + # This could perhaps support a target_axis argument in future were we + # to have loop indices targeting multiple output axes. + def __init__(self, index): + super().__init__() + self.index = index + + def __getinitargs__(self) -> tuple: + return (self.index,) + + @property + def datamap(self) -> PMap: + return self.index.datamap + + +class LocalLoopIndexVariable(LoopIndexVariable): + pass class ContextSensitiveCalledMap(ContextSensitiveLoopIterable): pass +# TODO make kwargs explicit +def as_index_forest(forest: Any, *, axes=None, strict=False, **kwargs): + # TODO: I think that this is the wrong place for this to exist. Also + # the implementation only seems to work for flat axes. 
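+    # The returned forest maps loop contexts to index trees, roughly (an
+    # illustrative sketch, identifiers hypothetical):
+    #
+    #     as_index_forest(p, axes=axes)
+    #     # -> {pmap({p.id: (source_path, target_path)}): IndexTree(...)}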
+ if forest is Ellipsis: + # full slice of all components + assert axes is not None + if axes.is_empty: + raise NotImplementedError("TODO, think about this") + forest = Slice( + axes.root.label, + [AffineSliceComponent(c.label) for c in axes.root.components], + ) + + forest = _as_index_forest(forest, axes=axes, **kwargs) + assert isinstance(forest, dict), "must be ordered" + + # If axes are provided then check that the index tree is compatible + # and add extra slices if required. + if axes is not None: + forest_ = {} + for ctx, tree in forest.items(): + if not strict: + tree = _complete_index_tree(tree, axes) + if not _index_tree_is_complete(tree, axes): + raise ValueError("Index tree does not completely index axes") + forest_[ctx] = tree + forest = forest_ + + # TODO: Clean this up, and explain why it's here. + forest_ = {} + for ctx, index_tree in forest.items(): + forest_[ctx] = index_tree.copy(outer_loops=axes.outer_loops) + forest = forest_ + return forest + + @functools.singledispatch -def apply_loop_context(arg, loop_context, *, axes, path): +def _as_index_forest(arg: Any, *, axes=None, path=pmap(), **kwargs): + # FIXME no longer a cyclic import from pyop3.array import HierarchicalArray if isinstance(arg, HierarchicalArray): + # NOTE: This is the same behaviour as for slices parent = axes._node_from_path(path) if parent is not None: parent_axis, parent_cpt = parent target_axis = axes.child(parent_axis, parent_cpt) else: target_axis = axes.root - slice_cpts = [] - # potentially a bad idea to apply the subset to all components. Might want to match - # labels. In fact I enforce that here and so multiple components would break things. - # Not sure what the right approach is. This is also potentially tricky for multi-level - # subsets - array_axis, array_component = arg.axes.leaf - for cpt in target_axis.components: - slice_cpt = Subset(cpt.label, arg) - slice_cpts.append(slice_cpt) - return Slice(target_axis.label, slice_cpts) - elif isinstance(arg, str): - # component label - # FIXME this is not right, only works at top level - return Slice(axes.root.label, AffineSliceComponent(arg)) - elif isinstance(arg, numbers.Integral): - return apply_loop_context( - slice(arg, arg + 1), loop_context, axes=axes, path=path - ) - else: - raise TypeError - - -@apply_loop_context.register -def _(index: Index, loop_context, **kwargs): - return index - - -@apply_loop_context.register -def _(index: Axis, *args, **kwargs): - return Slice(index.label, [AffineSliceComponent(c.label) for c in index.components]) - - -@apply_loop_context.register -def _(slice_: slice, loop_context, axes, path): - parent = axes._node_from_path(path) - if parent is not None: - parent_axis, parent_cpt = parent - target_axis = axes.child(parent_axis, parent_cpt) - else: - target_axis = axes.root - slice_cpts = [] - for cpt in target_axis.components: - slice_cpt = AffineSliceComponent( - cpt.label, slice_.start, slice_.stop, slice_.step - ) - slice_cpts.append(slice_cpt) - return Slice(target_axis.label, slice_cpts) - - -def combine_contexts(contexts): - new_contexts = [] - for mycontexts in itertools.product(*contexts): - new_contexts.append(pmap(merge_dicts(mycontexts))) - return new_contexts - -@functools.singledispatch -def collect_loop_indices(arg): - from pyop3.array import HierarchicalArray + if target_axis.degree > 1: + raise ValueError( + "Passing arrays as indices is only allowed when there is no ambiguity" + ) - if isinstance(arg, (HierarchicalArray, Slice, slice, str)): - return () - elif isinstance(arg, 
collections.abc.Iterable): - return sum(map(collect_loop_indices, arg), ()) + slice_cpt = Subset(target_axis.component.label, arg) + slice_ = Slice(target_axis.label, [slice_cpt]) + return {pmap(): IndexTree(slice_)} else: - raise NotImplementedError + raise TypeError(f"No handler provided for {type(arg).__name__}") -@collect_loop_indices.register -def _(arg: LoopIndex): - return (arg,) - - -@collect_loop_indices.register -def _(arg: LocalLoopIndex): - return (arg,) +@_as_index_forest.register +def _(indices: collections.abc.Sequence, *, path=pmap(), loop_context=pmap(), **kwargs): + index, *subindices = indices + # FIXME This fails because strings are considered sequences, perhaps we should + # cast component labels into their own type? + # if isinstance(index, collections.abc.Sequence): + # # what's the right exception? Some sort of BadIndexException? + # raise ValueError("Nested iterables are not supported") + + forest = {} + # TODO, it is a bad pattern to build a forest here when I really just want to convert + # a single index + for context, tree in _as_index_forest( + index, path=path, loop_context=loop_context, **kwargs + ).items(): + # converting a single index should only produce index trees with depth 1 + assert tree.depth == 1 + cf_index = tree.root + + if subindices: + for clabel, target_path in checked_zip( + cf_index.component_labels, cf_index.leaf_target_paths + ): + # if not kwargs["axes"].is_valid_path(path|target_path): + # continue + subforest = _as_index_forest( + subindices, + path=path | target_path, + loop_context=loop_context | context, + **kwargs, + ) + for subctx, subtree in subforest.items(): + forest[subctx] = tree.add_subtree(subtree, cf_index, clabel) + else: + forest[context] = tree + return forest -@collect_loop_indices.register -def _(arg: IndexTree): - return collect_loop_indices(arg.root) + tuple( - loop_index - for child in arg.parent_to_children.values() - for loop_index in collect_loop_indices(child) - ) +@_as_index_forest.register +def _(forest: collections.abc.Mapping, **kwargs): + return forest -@collect_loop_indices.register -def _(arg: CalledMap): - return collect_loop_indices(arg.from_index) +@_as_index_forest.register +def _(index_tree: IndexTree, **kwargs): + return {pmap(): index_tree} -@collect_loop_indices.register -def _(arg: int): - return () +@_as_index_forest.register +def _(index: ContextFreeIndex, **kwargs): + return {pmap(): IndexTree(index)} -def loop_contexts_from_iterable(indices): - all_loop_indices = tuple( - loop_index for index in indices for loop_index in collect_loop_indices(index) - ) - if len(all_loop_indices) == 0: - return {} +# TODO This function can definitely be refactored +@_as_index_forest.register(AbstractLoopIndex) +@_as_index_forest.register(LocalLoopIndex) +def _(index, *, loop_context=pmap(), **kwargs): + local = isinstance(index, LocalLoopIndex) - contexts = combine_contexts( - [collect_loop_contexts(idx) for idx in all_loop_indices] - ) + forest = {} + if isinstance(index.iterset, ContextSensitive): + for context, axes in index.iterset.context_map.items(): + if axes.is_empty: + source_path = pmap() + target_path = axes.target_paths.get(None, pmap()) - # add on context-free contexts, these cannot already be included - for index in indices: - if not isinstance(index, ContextSensitive): - continue - loop_index, paths = index.loop_context - if loop_index in contexts[0].keys(): - raise AssertionError - for ctx in contexts: - ctx[loop_index.id] = paths - return contexts + context_ = ( + loop_context | context | 
{index.id: (source_path, target_path)} + ) + cf_index = index.with_context(context_) + forest[context_] = IndexTree(cf_index) + else: + for leaf in axes.leaves: + source_path = axes.path(*leaf) + target_path = axes.target_paths.get(None, pmap()) + for axis, cpt in axes.path_with_nodes( + *leaf, and_components=True + ).items(): + target_path |= axes.target_paths.get((axis.id, cpt.label), {}) -@functools.singledispatch -def collect_loop_contexts(arg, *args, **kwargs): - from pyop3.array import HierarchicalArray + context_ = ( + loop_context | context | {index.id: (source_path, target_path)} + ) - if isinstance(arg, (HierarchicalArray, numbers.Integral)): - return {} - elif isinstance(arg, collections.abc.Iterable): - return loop_contexts_from_iterable(arg) - if arg is Ellipsis: - return {} + cf_index = index.with_context(context_) + forest[context_] = IndexTree(cf_index) else: - raise TypeError - - -@collect_loop_contexts.register -def _(index_tree: IndexTree): - contexts = {} - for loop_index, paths in index_tree.loop_context.items(): - contexts[loop_index] = [paths] - return contexts - - -@collect_loop_contexts.register -def _(arg: LocalLoopIndex): - return collect_loop_contexts(arg.loop_index, local=True) - - -@collect_loop_contexts.register -def _(arg: LoopIndex, local=False): - if isinstance(arg.iterset, ContextSensitiveAxisTree): - contexts = [] - for loop_context, axis_tree in arg.iterset.context_map.items(): - extra_source_context = {} - extracontext = {} - for leaf in axis_tree.leaves: - source_path = axis_tree.path(*leaf) - target_path = {} - for axis, cpt in axis_tree.path_with_nodes( - *leaf, and_components=True - ).items(): - target_path.update(axis_tree.target_paths[axis.id, cpt.label]) - extra_source_context.update(source_path) - extracontext.update(target_path) - if local: - contexts.append( - loop_context | {arg.local_index.id: pmap(extra_source_context)} - ) - else: - contexts.append(loop_context | {arg.id: pmap(extracontext)}) - return tuple(contexts) - else: - assert isinstance(arg.iterset, AxisTree) - iterset = arg.iterset - contexts = [] - for leaf_axis, leaf_cpt in iterset.leaves: - source_path = iterset.path(leaf_axis, leaf_cpt) - target_path = {} - for axis, cpt in iterset.path_with_nodes( + assert isinstance(index.iterset, ContextFree) + for leaf_axis, leaf_cpt in index.iterset.leaves: + source_path = index.iterset.path(leaf_axis, leaf_cpt) + target_path = index.iterset.target_paths.get(None, pmap()) + for axis, cpt in index.iterset.path_with_nodes( leaf_axis, leaf_cpt, and_components=True ).items(): - target_path.update( - iterset.target_paths[axis.id, cpt.label] - # iterset.paths[axis.id, cpt.label] - ) - if local: - contexts.append(pmap({arg.local_index.id: source_path})) - else: - contexts.append(pmap({arg.id: pmap(target_path)})) - return tuple(contexts) - - -def _paths_from_called_map_loop_index(index, context): - # terminal - if isinstance(index, LoopIndex): - return (context[index][1],) - - assert isinstance(index, CalledMap) - paths = [] - for from_path in _paths_from_called_map_loop_index(index.from_index, context): - for map_component in index.connectivity[from_path]: - paths.append( - ( - pmap({index.label: map_component.label}), - pmap({map_component.target_axis: map_component.target_component}), - ) - ) - return tuple(paths) - + target_path |= index.iterset.target_paths[axis.id, cpt.label] + # TODO cleanup + my_id = index.id if not local else index.loop_index.id + context = loop_context | {my_id: (source_path, target_path)} 
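+            # At this point `context` pairs the loop index id with the
+            # concrete (source_path, target_path) chosen for it, e.g.
+            # (illustrative, labels hypothetical):
+            #
+            #     {my_id: (pmap({"mesh": "cells"}), pmap({"mesh": "cells"}))}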
-@collect_loop_contexts.register -def _(called_map: CalledMap): - return collect_loop_contexts(called_map.from_index) + cf_index = index.with_context(context) + forest[context] = IndexTree(cf_index) + return forest -@collect_loop_contexts.register -def _(slice_: slice): - return () +@_as_index_forest.register(CalledMap) +@_as_index_forest.register(ContextFreeCalledMap) +def _(called_map, *, axes, **kwargs): + forest = {} + input_forest = _as_index_forest(called_map.from_index, axes=axes, **kwargs) + for context in input_forest.keys(): + cf_called_map = called_map.with_context(context, axes) + forest[context] = IndexTree(cf_called_map) + return forest -@collect_loop_contexts.register -def _(slice_: Slice): - return () +@_as_index_forest.register +def _(index: numbers.Integral, **kwargs): + return _as_index_forest(slice(index, index + 1), **kwargs) -def is_fully_indexed(axes: AxisTree, indices: IndexTree) -> bool: - """Check that the provided indices are compatible with the axis tree.""" - # To check for correctness we ensure that all of the paths through the - # index tree generate valid paths through the axis tree. - for leaf_index, component_label in indices.leaves: - # this maps indices to the specific component being accessed - # use this to find the right target_path - index_path = indices.path_with_nodes(leaf_index, component_label) - - full_target_path = {} - for index, cpt_label in index_path.items(): - # select the target_path corresponding to this component label - cidx = index.component_labels.index(cpt_label) - full_target_path |= index.target_paths[cidx] - - # the axis addressed by the full path should be a leaf, else we are - # not fully indexing the array - final_axis, final_cpt = axes._node_from_path(full_target_path) - if axes.child(final_axis, final_cpt) is not None: - return False - - return True +@_as_index_forest.register +def _(slice_: slice, *, axes=None, path=pmap(), loop_context=pmap(), **kwargs): + if axes is None: + raise RuntimeError("invalid slice usage") + parent = axes._node_from_path(path) + if parent is not None: + parent_axis, parent_cpt = parent + target_axis = axes.child(parent_axis, parent_cpt) + else: + target_axis = axes.root -def _collect_datamap(index, *subdatamaps, itree): - return index.datamap | merge_dicts(subdatamaps) + if target_axis.degree > 1: + # badindexexception? 
+ raise ValueError( + "Cannot slice multi-component things using generic slices, ambiguous" + ) + slice_cpt = AffineSliceComponent( + target_axis.component.label, slice_.start, slice_.stop, slice_.step + ) + slice_ = Slice(target_axis.label, [slice_cpt]) + return {loop_context: IndexTree(slice_)} -def index_tree_from_ellipsis(axes, current_axis=None, first_call=True): - current_axis = current_axis or axes.root - slice_components = [] - subroots = [] - subtrees = [] - for component in current_axis.components: - slice_components.append(AffineSliceComponent(component.label)) - if subaxis := axes.child(current_axis, component): - subroot, subtree = index_tree_from_ellipsis(axes, subaxis, first_call=False) - subroots.append(subroot) - subtrees.append(subtree) - else: - subroots.append(None) - subtrees.append({}) +@_as_index_forest.register +def _(label: str, *, axes, **kwargs): + # if we use a string then we assume we are taking a full slice of the + # top level axis + axis = axes.root + component = just_one(c for c in axis.components if c.label == label) + slice_ = Slice(axis.label, [AffineSliceComponent(component.label)]) + return _as_index_forest(slice_, axes=axes, **kwargs) - fullslice = Slice(current_axis.label, slice_components) - myslice = fullslice - if first_call: - return IndexTree(myslice, pmap({myslice.id: subroots}) | merge_dicts(subtrees)) - else: - return myslice, pmap({myslice.id: subroots}) | merge_dicts(subtrees) +def _complete_index_tree( + tree: IndexTree, axes: AxisTree, index=None, axis_path=pmap() +) -> IndexTree: + """Add extra slices to the index tree to match the axes. + Notes + ----- + This function is currently only capable of adding additional slices if + they are "innermost". -def index_tree_from_iterable( - indices, loop_context, axes=None, path=pmap(), first_call=False -): - index, *subindices = indices + """ + if index is None: + index = tree.root - index = apply_loop_context(index, loop_context, axes=axes, path=path) - - if subindices: - children = [] - subtrees = [] - # used to be leaves... - for target_path in index.target_paths(loop_context): - assert target_path - new_path = path | target_path - child, subtree = index_tree_from_iterable( - subindices, loop_context, axes, new_path + tree_ = IndexTree(index) + for component_label, path in checked_zip( + index.component_labels, index.leaf_target_paths + ): + axis_path_ = axis_path | path + if subindex := tree.child(index, component_label): + subtree = _complete_index_tree( + tree, + axes, + subindex, + axis_path_, ) - children.append(child) - subtrees.append(subtree) - - parent_to_children = pmap({index.id: children}) | merge_dicts(subtrees) - else: - parent_to_children = {} - - if first_call: - assert None not in parent_to_children - parent_to_children |= {None: [index]} - return IndexTree(parent_to_children, loop_context=loop_context) - else: - return index, parent_to_children - - -@functools.singledispatch -def as_index_tree(arg, loop_context, **kwargs): - if isinstance(arg, collections.abc.Iterable): - return index_tree_from_iterable(arg, loop_context, first_call=True, **kwargs) - else: - raise TypeError + else: + # At the bottom of the index tree, add any extra slices if needed. 
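+            # E.g. (illustrative): if the index tree only slices a "cells"
+            # axis of a cells -> dofs axis tree, a full slice over "dofs" is
+            # appended here so that the tree indexes the axes completely.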
+ subtree = _complete_index_tree_slices(axes, axis_path_) + tree_ = tree_.add_subtree(subtree, index, component_label) + return tree_ -@as_index_tree.register -def _(index: Index, ctx, **kwargs): - return IndexTree(index, loop_context=ctx) +def _complete_index_tree_slices(axes: AxisTree, path: PMap, axis=None) -> IndexTree: + if axis is None: + axis = axes.root -@functools.singledispatch -def as_index_forest(arg: Any, **kwargs): - from pyop3.array import HierarchicalArray - - if isinstance(arg, HierarchicalArray): - slice_ = apply_loop_context(arg, loop_context=pmap(), path=pmap(), **kwargs) - return (IndexTree(slice_),) - elif isinstance(arg, collections.abc.Sequence): - loop_contexts = collect_loop_contexts(arg) or [pmap()] - forest = [] - for context in loop_contexts: - forest.append(as_index_tree(arg, context, **kwargs)) - return tuple(forest) + if axis.label in path: + if subaxis := axes.child(axis, path[axis.label]): + return _complete_index_tree_slices(axes, path, subaxis) + else: + return IndexTree() else: - raise TypeError - - -@as_index_forest.register -def _(index_tree: IndexTree, **kwargs): - return (index_tree,) - - -@as_index_forest.register -def _(index: Index, **kwargs): - loop_contexts = collect_loop_contexts(index) or [pmap()] - forest = [] - for context in loop_contexts: - forest.append(as_index_tree(index, context, **kwargs)) - return tuple(forest) - - -@as_index_forest.register -def _(slice_: slice, **kwargs): - slice_ = apply_loop_context(slice_, loop_context=pmap(), path=pmap(), **kwargs) - return (IndexTree(slice_),) - + # Axis is missing from the index tree, use a full slice. + slice_ = Slice( + axis.label, [AffineSliceComponent(c.label) for c in axis.components] + ) + tree = IndexTree(slice_) + + for axis_component, index_component in checked_zip( + axis.components, slice_.component_labels + ): + if subaxis := axes.child(axis, axis_component): + subtree = _complete_index_tree_slices(axes, path, subaxis) + tree = tree.add_subtree(subtree, slice_, index_component) + return tree + + +def _index_tree_is_complete(indices: IndexTree, axes: AxisTree): + """Return whether the index tree completely indexes the axis tree.""" + # For each leaf in the index tree, collect the resulting axis path + # and check that this is a leaf of the axis tree. 
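+    # E.g. (illustrative): for an axis tree  ax0 -> ax1  a leaf target path
+    # {ax0: ..., ax1: ...} is complete, whereas one that only reaches
+    # {ax0: ...} is not.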
+ for index_leaf_path in indices.ordered_leaf_paths_with_nodes: + axis_path = {} + for index, index_cpt_label in index_leaf_path: + index_cpt_index = index.component_labels.index(index_cpt_label) + for axis, axis_cpt in index.leaf_target_paths[index_cpt_index].items(): + assert axis not in axis_path, "Paths should not clash" + axis_path[axis] = axis_cpt + axis_path = freeze(axis_path) + + if axis_path not in axes.leaf_paths: + return False -@as_index_forest.register -def _(label: str, *, axes, **kwargs): - # if we use a string then we assume we are taking a full slice of the - # top level axis - axis = axes.root - component = just_one(c for c in axis.components if c.label == label) - slice_ = Slice(axis.label, [AffineSliceComponent(component.label)]) - return as_index_forest(slice_, axes=axes, **kwargs) + # All leaves of the tree are complete + return True @functools.singledispatch @@ -840,80 +1145,137 @@ def collect_shape_index_callback(index, *args, **kwargs): @collect_shape_index_callback.register -def _(loop_index: LoopIndex, *, loop_indices, **kwargs): - iterset = loop_index.iterset - - target_path_per_component = pmap({None: loop_indices[loop_index.id]}) - # fairly sure that here I want the *output* path of the loop indices - index_exprs_per_component = pmap( - { - None: pmap( - { - axis: LoopIndexVariable(loop_index, axis) - for axis in loop_indices[loop_index.id].keys() - } - ) - } - ) - layout_exprs_per_component = pmap({None: 0}) +def _( + loop_index: ContextFreeLoopIndex, + indices, + **kwargs, +): + axes = loop_index.axes + target_paths = loop_index.target_paths + index_exprs = loop_index.index_exprs + return ( - PartialAxisTree(), - target_path_per_component, - index_exprs_per_component, - layout_exprs_per_component, + axes, + target_paths, + index_exprs, + loop_index.layout_exprs, + loop_index.loops, + {}, ) @collect_shape_index_callback.register -def _(local_index: LocalLoopIndex, *args, loop_indices, **kwargs): - path = loop_indices[local_index.id] - - loop_index = local_index.loop_index - iterset = loop_index.iterset - - target_path_per_cpt = pmap({None: path}) - index_exprs_per_cpt = pmap( - { - None: pmap( - {axis: LoopIndexVariable(local_index, axis) for axis in path.keys()} - ) - } - ) - - layout_exprs_per_cpt = pmap({None: 0}) +def _(index: ScalarIndex, indices, **kwargs): + target_path = freeze({None: just_one(index.leaf_target_paths)}) + index_exprs = freeze({None: {index.axis: index.value}}) + layout_exprs = freeze({None: 0}) return ( - PartialAxisTree(), - target_path_per_cpt, - index_exprs_per_cpt, - layout_exprs_per_cpt, + AxisTree(), + target_path, + index_exprs, + layout_exprs, + (), + {}, ) @collect_shape_index_callback.register -def _(slice_: Slice, *, prev_axes, **kwargs): +def _(slice_: Slice, indices, *, target_path_acc, prev_axes, **kwargs): + from pyop3.array.harray import ArrayVar + + # If we are just taking a component from a multi-component array, + # e.g. mesh.points["cells"], then relabelling the axes just leads to + # needless confusion. For instance if we had + # + # myslice0 = Slice("mesh", AffineSliceComponent("cells", step=2)) + # + # then mesh.points[myslice0] would work but mesh.points["cells"][myslice0] + # would fail. + # As a counter example, if we have non-trivial subsets then this sort of + # relabelling is essential for things to make sense. 
 If we have two subsets:
+    #
+    #     subset0 = Slice("mesh", Subset("cells", [1, 2, 3]))
+    #
+    # and
+    #
+    #     subset1 = Slice("mesh", Subset("cells", [4, 5, 6]))
+    #
+    # then mesh.points[subset0][subset1] is confusing: should subset1 be
+    # taken to apply to the already-sliced axis? This is a major source of
+    # confusion for things like interior facets in Firedrake, where the first
+    # slice happens in one function and the second happens elsewhere. We hit
+    # situations where
+    #
+    #     mesh.interior_facets[interior_facets_I_want]
+    #
+    # conflicts with
+    #
+    #     mesh.interior_facets[facets_I_want]
+    #
+    # where one subset is given with facet numbering and the other with interior
+    # facet numbering. The labels are the same, so identifying the mistake is
+    # really difficult.
+    #
+    # We fix this here by requiring that non-full slices perform a relabelling
+    # and full slices do not.
+    is_full_slice = all(
+        isinstance(s, AffineSliceComponent) and s.is_full for s in slice_.slices
+    )
+
+    axis_label = slice_.axis if is_full_slice else slice_.label
+
     components = []
     target_path_per_subslice = []
     index_exprs_per_subslice = []
     layout_exprs_per_subslice = []

-    axis_label = slice_.label
-
     for subslice in slice_.slices:
-        # we are assuming that axes with the same label *must* be identical. They are
-        # only allowed to differ in that they have different IDs.
-        target_axis, target_cpt = prev_axes.find_component(
-            slice_.axis, subslice.component, also_node=True
+        if not prev_axes.is_valid_path(target_path_acc, complete=False):
+            raise NotImplementedError(
+                "If we swap axes around then we must check "
+                "that we don't get clashes."
+            )
+
+        # previous code:
+        # we are assuming that axes with the same label *must* be identical. They are
+        # only allowed to differ in that they have different IDs.
+ # target_axis, target_cpt = prev_axes.find_component( + # slice_.axis, subslice.component, also_node=True + # ) + + if not target_path_acc: + target_axis = prev_axes.root + else: + parent = prev_axes._node_from_path(target_path_acc) + target_axis = prev_axes.child(*parent) + + assert target_axis.label == slice_.axis + target_cpt = just_one( + c for c in target_axis.components if c.label == subslice.component ) + if isinstance(subslice, AffineSliceComponent): - if subslice.stop is None: - stop = target_cpt.count + # TODO handle this is in a test, slices of ragged things + if isinstance(target_cpt.count, HierarchicalArray): + if ( + subslice.stop is not None + or subslice.start != 0 + or subslice.step != 1 + ): + raise NotImplementedError("TODO") + if len(indices) == 0: + size = target_cpt.count + else: + size = target_cpt.count[indices] else: - stop = subslice.stop - size = math.ceil((stop - subslice.start) / subslice.step) + if subslice.stop is None: + stop = target_cpt.count + else: + stop = subslice.stop + size = math.ceil((stop - subslice.start) / subslice.step) else: assert isinstance(subslice, Subset) size = subslice.array.axes.leaf_component.count - cpt = AxisComponent(size, label=subslice.component) + mylabel = subslice.component if is_full_slice else subslice.label + cpt = AxisComponent(size, label=mylabel) components.append(cpt) target_path_per_subslice.append(pmap({slice_.axis: subslice.component})) @@ -921,21 +1283,72 @@ def _(slice_: Slice, *, prev_axes, **kwargs): newvar = AxisVariable(axis_label) layout_var = AxisVariable(slice_.axis) if isinstance(subslice, AffineSliceComponent): - index_exprs_per_subslice.append( - pmap({slice_.axis: newvar * subslice.step + subslice.start}) - ) + if is_full_slice: + index_exprs_per_subslice.append( + freeze( + { + slice_.axis: newvar * subslice.step + subslice.start, + } + ) + ) + else: + index_exprs_per_subslice.append( + freeze( + { + slice_.axis: newvar * subslice.step + subslice.start, + # slice_.label: AxisVariable(slice_.label), + } + ) + ) layout_exprs_per_subslice.append( pmap({slice_.label: (layout_var - subslice.start) // subslice.step}) ) else: - index_exprs_per_subslice.append( - pmap({slice_.axis: subslice.array.as_var()}) + assert isinstance(subslice, Subset) + + # below is also used for maps - cleanup + subset_array = subslice.array + subset_axes = subset_array.axes + + # must be single component + source_path = subset_axes.path(*subset_axes.leaf) + index_keys = [None] + [ + (axis.id, cpt.label) + for axis, cpt in subset_axes.detailed_path(source_path).items() + ] + old_index_exprs = merge_dicts( + subset_array.index_exprs.get(key, {}) for key in index_keys ) + + my_index_exprs = {} + index_expr_replace_map = {subset_axes.leaf_axis.label: newvar} + replacer = IndexExpressionReplacer(index_expr_replace_map) + for axlabel, index_expr in old_index_exprs.items(): + my_index_exprs[axlabel] = replacer(index_expr) + subset_var = ArrayVar(subslice.array, my_index_exprs) + + if is_full_slice: + index_exprs_per_subslice.append( + freeze( + { + slice_.axis: subset_var, + } + ) + ) + else: + index_exprs_per_subslice.append( + freeze( + { + slice_.axis: subset_var, + # slice_.label: AxisVariable(slice_.label), + } + ) + ) layout_exprs_per_subslice.append( - pmap({slice_.label: bsearch(subslice.array.as_var(), layout_var)}) + pmap({slice_.label: bsearch(subset_var, layout_var)}) ) - axis = Axis(components, label=axis_label) + axis = Axis(components, label=axis_label, numbering=slice_.numbering) axes = PartialAxisTree(axis) 
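+    # To recap the relabelling rule above (illustrative): a full slice such
+    # as axes["cells"] keeps the original axis label, whereas a strided slice
+    # like axes[::2], or a Subset, produces an axis labelled slice_.label.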
target_path_per_component = {} index_exprs_per_component = {} @@ -952,19 +1365,36 @@ def _(slice_: Slice, *, prev_axes, **kwargs): return ( axes, target_path_per_component, - index_exprs_per_component, + freeze(index_exprs_per_component), layout_exprs_per_component, + (), # no outer loops + {}, ) @collect_shape_index_callback.register -def _(called_map: CalledMap, **kwargs): +def _( + called_map: ContextFreeCalledMap, + indices, + *, + prev_axes, + **kwargs, +): ( prior_axes, prior_target_path_per_cpt, prior_index_exprs_per_cpt, _, - ) = collect_shape_index_callback(called_map.from_index, **kwargs) + outer_loops, + prior_extra_index_exprs, + ) = collect_shape_index_callback( + called_map.index, + indices, + prev_axes=prev_axes, + **kwargs, + ) + + extra_index_exprs = dict(prior_extra_index_exprs) if not prior_axes: prior_target_path = prior_target_path_per_cpt[None] @@ -974,12 +1404,19 @@ def _(called_map: CalledMap, **kwargs): target_path_per_cpt, index_exprs_per_cpt, layout_exprs_per_cpt, + more_extra_index_exprs, ) = _make_leaf_axis_from_called_map( - called_map, prior_target_path, prior_index_exprs + called_map, + prior_target_path, + prior_index_exprs, + prev_axes, ) axes = PartialAxisTree(axis) + + extra_index_exprs.update(more_extra_index_exprs) + else: - axes = prior_axes + axes = PartialAxisTree(prior_axes.parent_to_children) target_path_per_cpt = {} index_exprs_per_cpt = {} layout_exprs_per_cpt = {} @@ -990,96 +1427,218 @@ def _(called_map: CalledMap, **kwargs): for myaxis, mycomponent_label in prior_axes.path_with_nodes( prior_leaf_axis.id, prior_leaf_cpt ).items(): - prior_target_path |= prior_target_path_per_cpt[ - myaxis.id, mycomponent_label - ] - prior_index_exprs |= prior_index_exprs_per_cpt[ - myaxis.id, mycomponent_label - ] + prior_target_path |= prior_target_path_per_cpt.get( + (myaxis.id, mycomponent_label), {} + ) + prior_index_exprs |= prior_index_exprs_per_cpt.get( + (myaxis.id, mycomponent_label), {} + ) ( subaxis, subtarget_paths, subindex_exprs, sublayout_exprs, + subextra_index_exprs, ) = _make_leaf_axis_from_called_map( - called_map, prior_target_path, prior_index_exprs + called_map, + prior_target_path, + prior_index_exprs, + prev_axes, ) + axes = axes.add_subtree( - PartialAxisTree(subaxis), prior_leaf_axis, prior_leaf_cpt + PartialAxisTree(subaxis), + prior_leaf_axis, + prior_leaf_cpt, ) target_path_per_cpt.update(subtarget_paths) index_exprs_per_cpt.update(subindex_exprs) layout_exprs_per_cpt.update(sublayout_exprs) + extra_index_exprs.update(subextra_index_exprs) return ( axes, - pmap(target_path_per_cpt), - pmap(index_exprs_per_cpt), - pmap(layout_exprs_per_cpt), + freeze(target_path_per_cpt), + freeze(index_exprs_per_cpt), + freeze(layout_exprs_per_cpt), + outer_loops, + freeze(extra_index_exprs), ) -def _make_leaf_axis_from_called_map(called_map, prior_target_path, prior_index_exprs): +def _make_leaf_axis_from_called_map( + called_map, + prior_target_path, + prior_index_exprs, + prev_axes, +): + from pyop3.array.harray import CalledMapVariable + axis_id = Axis.unique_id() components = [] target_path_per_cpt = {} index_exprs_per_cpt = {} layout_exprs_per_cpt = {} + extra_index_exprs = {} + all_skipped = True for map_cpt in called_map.map.connectivity[prior_target_path]: - cpt = AxisComponent(map_cpt.arity, label=map_cpt.label) + if prev_axes is not None and not prev_axes.is_valid_path( + {map_cpt.target_axis: map_cpt.target_component}, complete=False + ): + continue + + all_skipped = False + if isinstance(map_cpt.arity, HierarchicalArray): + arity = 
map_cpt.arity[called_map.index] + else: + arity = map_cpt.arity + cpt = AxisComponent(arity, label=map_cpt.label) components.append(cpt) target_path_per_cpt[axis_id, cpt.label] = pmap( {map_cpt.target_axis: map_cpt.target_component} ) - map_var = MapVariable(called_map, map_cpt) - axisvar = AxisVariable(called_map.name) + axisvar = AxisVariable(called_map.map.name) + + if not isinstance(map_cpt, TabulatedMapComponent): + raise NotImplementedError("Currently we assume only arrays here") + + map_array = map_cpt.array + map_axes = map_array.axes + + assert map_axes.depth == 2 + + source_path = map_axes.path(*map_axes.leaf) + index_keys = [None] + [ + (axis.id, cpt.label) + for axis, cpt in map_axes.detailed_path(source_path).items() + ] + my_target_path = merge_dicts( + map_array.target_paths.get(key, {}) for key in index_keys + ) + + # the outer index is provided from "prior" whereas the inner one requires + # a replacement + map_leaf_axis, map_leaf_component = map_axes.leaf + old_inner_index_expr = map_array.index_exprs[ + map_leaf_axis.id, map_leaf_component + ] + + my_index_exprs = {} + index_expr_replace_map = {map_axes.leaf_axis.label: axisvar} + replacer = IndexExpressionReplacer(index_expr_replace_map) + for axlabel, index_expr in old_inner_index_expr.items(): + my_index_exprs[axlabel] = replacer(index_expr) + new_inner_index_expr = my_index_exprs + + map_var = CalledMapVariable( + map_cpt.array, merge_dicts([prior_index_exprs, new_inner_index_expr]) + ) index_exprs_per_cpt[axis_id, cpt.label] = { - map_cpt.target_axis: map_var(prior_index_exprs | {called_map.name: axisvar}) + map_cpt.target_axis: map_var, + } + + # also one for the new axis + # Nooooo, bad idea + extra_index_exprs[axis_id, cpt.label] = { + # axisvar.axis: axisvar, } # don't think that this is possible for maps layout_exprs_per_cpt[axis_id, cpt.label] = { - called_map.name: pym.primitives.NaN(IntType) + called_map.id: pym.primitives.NaN(IntType) } - axis = Axis(components, label=called_map.name, id=axis_id) + if all_skipped: + raise RuntimeError("map does not target any relevant axes") - return axis, target_path_per_cpt, index_exprs_per_cpt, layout_exprs_per_cpt + axis = Axis( + components, + label=called_map.map.name, + id=axis_id, + numbering=called_map.map.numbering, + ) + + return ( + axis, + target_path_per_cpt, + index_exprs_per_cpt, + layout_exprs_per_cpt, + extra_index_exprs, + ) -def _index_axes(axes, indices: IndexTree, loop_context): +def _index_axes( + indices: IndexTree, + loop_context, + axes=None, +): ( indexed_axes, tpaths, index_expr_per_target, layout_expr_per_target, + outer_loops, ) = _index_axes_rec( indices, + (), + pmap(), # target_path current_index=indices.root, loop_indices=loop_context, prev_axes=axes, ) - if indexed_axes is None: - indexed_axes = {} + outer_loops += indices.outer_loops - # return the new axes plus the new index expressions per leaf - return indexed_axes, tpaths, index_expr_per_target, layout_expr_per_target + # drop duplicates + outer_loops_ = [] + allids = set() + for ol in outer_loops: + if ol.id in allids: + continue + outer_loops_.append(ol) + allids.add(ol.id) + outer_loops = tuple(outer_loops_) + + # check that slices etc have not been missed + if axes is not None: + for leaf_iaxis, leaf_icpt in indexed_axes.leaves: + target_path = dict(tpaths.get(None, {})) + for iaxis, icpt in indexed_axes.path_with_nodes( + leaf_iaxis, leaf_icpt + ).items(): + target_path.update(tpaths.get((iaxis.id, icpt), {})) + if not axes.is_valid_path(target_path, leaf=True): + raise 
ValueError("incorrect/insufficient indices") + + return AxisTree( + indexed_axes.parent_to_children, + target_paths=tpaths, + index_exprs=index_expr_per_target, + layout_exprs=layout_expr_per_target, + outer_loops=outer_loops, + ) def _index_axes_rec( indices, + indices_acc, + target_path_acc, *, current_index, **kwargs, ): - index_data = collect_shape_index_callback(current_index, **kwargs) - axes_per_index, *rest = index_data + index_data = collect_shape_index_callback( + current_index, + indices_acc, + target_path_acc=target_path_acc, + **kwargs, + ) + axes_per_index, *rest, outer_loops, extra_index_exprs = index_data ( target_path_per_cpt_per_index, @@ -1087,81 +1646,88 @@ def _index_axes_rec( layout_exprs_per_cpt_per_index, ) = tuple(map(dict, rest)) + # if ("_id_Axis_132", "XXX") in index_exprs_per_cpt_per_index: + # breakpoint() + + # if extra_index_exprs: + # breakpoint() + if axes_per_index: leafkeys = axes_per_index.leaves else: leafkeys = [None] subaxes = {} - for leafkey in leafkeys: - if current_index.id in indices.parent_to_children: - for subindex in indices.parent_to_children[current_index.id]: - retval = _index_axes_rec( - indices, - current_index=subindex, - **kwargs, - ) - subaxes[leafkey] = retval[0] - - for key in retval[1].keys(): - if key in target_path_per_cpt_per_index: - target_path_per_cpt_per_index[key] = ( - target_path_per_cpt_per_index[key] | retval[1][key] - ) - index_exprs_per_cpt_per_index[key] = ( - index_exprs_per_cpt_per_index[key] | retval[2][key] - ) - layout_exprs_per_cpt_per_index[key] = ( - layout_exprs_per_cpt_per_index[key] | retval[3][key] - ) - else: - target_path_per_cpt_per_index.update({key: retval[1][key]}) - index_exprs_per_cpt_per_index.update({key: retval[2][key]}) - layout_exprs_per_cpt_per_index.update({key: retval[3][key]}) + if current_index.id in indices.parent_to_children: + for leafkey, subindex in checked_zip( + leafkeys, indices.parent_to_children[current_index.id] + ): + if subindex is None: + continue + indices_acc_ = indices_acc + (current_index,) + + target_path_acc_ = dict(target_path_acc) + target_path_acc_.update(target_path_per_cpt_per_index.get(None, {})) + if not axes_per_index.is_empty: + for _ax, _cpt in axes_per_index.path_with_nodes(*leafkey).items(): + target_path_acc_.update( + target_path_per_cpt_per_index.get((_ax.id, _cpt), {}) + ) + target_path_acc_ = freeze(target_path_acc_) + + retval = _index_axes_rec( + indices, + indices_acc_, + target_path_acc_, + current_index=subindex, + **kwargs, + ) + subaxes[leafkey] = retval[0] - target_path_per_component = pmap(target_path_per_cpt_per_index) - index_exprs_per_component = pmap(index_exprs_per_cpt_per_index) - layout_exprs_per_component = pmap(layout_exprs_per_cpt_per_index) + for key in retval[1].keys(): + if key in target_path_per_cpt_per_index: + target_path_per_cpt_per_index[key] = ( + target_path_per_cpt_per_index[key] | retval[1][key] + ) + index_exprs_per_cpt_per_index[key] = ( + index_exprs_per_cpt_per_index[key] | retval[2][key] + ) + layout_exprs_per_cpt_per_index[key] = ( + layout_exprs_per_cpt_per_index[key] | retval[3][key] + ) + else: + target_path_per_cpt_per_index.update({key: retval[1][key]}) + index_exprs_per_cpt_per_index.update({key: retval[2][key]}) + layout_exprs_per_cpt_per_index.update({key: retval[3][key]}) + + outer_loops += retval[4] + + target_path_per_component = freeze(target_path_per_cpt_per_index) + index_exprs_per_component = thaw(index_exprs_per_cpt_per_index) + for key, inner in extra_index_exprs.items(): + if key in 
index_exprs_per_component: + for ax, expr in inner.items(): + assert ax not in index_exprs_per_component[key] + index_exprs_per_component[key][ax] = expr + else: + index_exprs_per_component[key] = inner + index_exprs_per_component = freeze(index_exprs_per_component) + layout_exprs_per_component = freeze(layout_exprs_per_cpt_per_index) - axes = axes_per_index + axes = PartialAxisTree(axes_per_index.parent_to_children) for k, subax in subaxes.items(): if subax is not None: if axes: axes = axes.add_subtree(subax, *k) else: - axes = subax + axes = PartialAxisTree(subax.parent_to_children) return ( axes, target_path_per_component, index_exprs_per_component, layout_exprs_per_component, - ) - - -def index_axes(axes, index_tree): - ( - indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, - ) = _index_axes(axes, index_tree, loop_context=index_tree.loop_context) - - target_paths, index_exprs, layout_exprs = _compose_bits( - axes, - axes.target_paths, - axes.index_exprs, - axes.layout_exprs, - indexed_axes, - target_path_per_indexed_cpt, - index_exprs_per_indexed_cpt, - layout_exprs_per_indexed_cpt, - ) - return AxisTree( - indexed_axes.parent_to_children, - target_paths, - index_exprs, - layout_exprs, + outer_loops, ) @@ -1238,14 +1804,18 @@ def _compose_bits( myaxlabel ] = mycptlabel + # testing, make sure we don't miss any new index_exprs + index_exprs[iaxis.id, icpt.label] |= iindex_exprs[iaxis.id, icpt.label] + # do a replacement for index exprs # compose index expressions, this does an *inside* substitution # so the final replace map is target -> f(src) # loop over the original replace map and substitute each value # but drop some bits if indexed out... and final map is per component of the new axtree - orig_index_exprs = prev_index_exprs[target_axis.id, target_cpt.label] + orig_index_exprs = prev_index_exprs.get( + (target_axis.id, target_cpt.label), pmap() + ) for axis_label, index_expr in orig_index_exprs.items(): - # new_index_expr = IndexExpressionReplacer(new_partial_index_exprs)( new_index_expr = IndexExpressionReplacer(new_partial_index_exprs)( index_expr ) @@ -1267,7 +1837,7 @@ def _compose_bits( if prev_layout_exprs is not None: full_replace_map = merge_dicts( [ - prev_layout_exprs[tgt_ax.id, tgt_cpt.label] + prev_layout_exprs.get((tgt_ax.id, tgt_cpt.label), pmap()) for tgt_ax, tgt_cpt in detailed_path.items() ] ) @@ -1275,13 +1845,20 @@ def _compose_bits( # always 1:1 for layouts mykey, myvalue = just_one(layout_expr.items()) mytargetpath = just_one(itarget_paths[ikey].keys()) - layout_expr_replace_map = { - mytargetpath: full_replace_map[mytargetpath] - } + # layout_expr_replace_map = { + # mytargetpath: full_replace_map[mytargetpath] + # } + layout_expr_replace_map = full_replace_map new_layout_expr = IndexExpressionReplacer(layout_expr_replace_map)( myvalue ) - layout_exprs[ikey][mykey] = new_layout_expr + + # this is a trick to get things working in Firedrake, needs more + # thought to understand what is going on + if ikey in layout_exprs and mykey in layout_exprs[ikey]: + assert layout_exprs[ikey][mykey] == new_layout_expr + else: + layout_exprs[ikey][mykey] = new_layout_expr isubaxis = indexed_axes.child(iaxis, icpt) if isubaxis: @@ -1320,56 +1897,215 @@ def _compose_bits( ) +@dataclasses.dataclass(frozen=True) +class IndexIteratorEntry: + index: LoopIndex + source_path: PMap + target_path: PMap + source_exprs: PMap + target_exprs: PMap + + @property + def loop_context(self): + return freeze({self.index.id: 
(self.source_path, self.target_path)})
+
+    @property
+    def target_replace_map(self):
+        return freeze(
+            {
+                self.index.id: {ax: expr for ax, expr in self.target_exprs.items()},
+                # self.index.id: (
+                #     # {ax: expr for ax, expr in self.source_exprs.items()},
+                #     {ax: expr for ax, expr in self.target_exprs.items()},
+                # )
+            }
+        )
+
+    @property
+    def source_replace_map(self):
+        return freeze(
+            {
+                self.index.id: {ax: expr for ax, expr in self.source_exprs.items()},
+            }
+        )
+
+
+def iter_loop(loop):
+    if len(loop.target_paths) != 1:
+        raise NotImplementedError
+
+    if loop.iterset.outer_loops:
+        outer_loop = just_one(loop.iterset.outer_loops)
+        for indices in outer_loop.iter():
+            for i, index in enumerate(loop.iterset.iter(indices)):
+                # hack needed because we mix up our source and target exprs
+                axis_label = just_one(
+                    just_one(loop.iterset.target_paths.values()).keys()
+                )
+
+                # FIXME: source_path and target_path are not meaningful here;
+                # define them as in the branch below so the entry can at least
+                # be constructed without a NameError.
+                source_path = "NA"
+                target_path = "NA"
+
+                source_expr = {loop.id: {axis_label: i}}
+
+                target_expr_sym = merge_dicts(loop.iterset.index_exprs.values())[
+                    axis_label
+                ]
+                replace_map = {axis_label: i}
+                loop_exprs = merge_dicts(idx.target_replace_map for idx in indices)
+                target_expr = ExpressionEvaluator(replace_map, loop_exprs)(
+                    target_expr_sym
+                )
+                target_expr = {axis_label: target_expr}
+
+                # new_exprs = {}
+                # evaluator = ExpressionEvaluator(
+                #     indices | {axis.label: pt}, outer_replace_map
+                # )
+                # for axlabel, index_expr in myindex_exprs.items():
+                #     new_index = evaluator(index_expr)
+                #     assert new_index != index_expr
+                #     new_exprs[axlabel] = new_index
+
+                index = IndexIteratorEntry(
+                    loop, source_path, target_path, source_expr, target_expr
+                )
+
+                yield indices + (index,)
+    else:
+        for i, index in enumerate(loop.iterset.iter()):
+            # hack needed because we mix up our source and target exprs
+            axis_label = just_one(just_one(loop.iterset.target_paths.values()).keys())
+
+            source_path = "NA"
+            target_path = "NA"
+
+            source_expr = {axis_label: i}
+
+            target_expr_sym = merge_dicts(loop.iterset.index_exprs.values())[axis_label]
+            replace_map = {axis_label: i}
+            target_expr = ExpressionEvaluator(replace_map, {})(target_expr_sym)
+            target_expr = {axis_label: target_expr}
+
+            iter_entry = IndexIteratorEntry(
+                loop,
+                source_path,
+                target_path,
+                freeze(source_expr),
+                freeze(target_expr),
+            )
+            yield (iter_entry,)
+
+
 def iter_axis_tree(
+    loop_index: LoopIndex,
     axes: AxisTree,
     target_paths,
     index_exprs,
-    outermap,
+    outer_loops=(),
+    include_loops=False,
     axis=None,
     path=pmap(),
     indices=pmap(),
     target_path=None,
     index_exprs_acc=None,
 ):
+    outer_replace_map = merge_dicts(
+        # iter_entry.target_replace_map for iter_entry in outer_loops
+        # iter_entry.source_replace_map
+        iter_entry.target_replace_map
+        for iter_entry in outer_loops
+    )
     if target_path is None:
         assert index_exprs_acc is None
         target_path = target_paths.get(None, pmap())

+        # Substitute the index exprs, which map target to source, into
+        # indices, giving target index exprs.
         myindex_exprs = index_exprs.get(None, pmap())
+        evaluator = ExpressionEvaluator(indices, outer_replace_map)
         new_exprs = {}
         for axlabel, index_expr in myindex_exprs.items():
-            new_index = ExpressionEvaluator(outermap)(index_expr)
-            assert new_index != index_expr
+            # try:
+            #     new_index = evaluator(index_expr)
+            #     assert new_index != index_expr
+            #     new_exprs[axlabel] = new_index
+            # except UnrecognisedAxisException:
+            #     pass
+            new_index = evaluator(index_expr)
+            # assert new_index != index_expr
             new_exprs[axlabel] = new_index
         index_exprs_acc = freeze(new_exprs)

     if axes.is_empty:
-        yield pmap(), 
target_path, pmap(), index_exprs_acc + if include_loops: + # source_path = + breakpoint() + else: + source_path = pmap() + source_exprs = pmap() + yield IndexIteratorEntry( + loop_index, source_path, target_path, source_exprs, index_exprs_acc + ) return axis = axis or axes.root for component in axis.components: + # for efficiency do these outside the loop path_ = path | {axis.label: component.label} target_path_ = target_path | target_paths.get((axis.id, component.label), {}) - myindex_exprs = index_exprs[axis.id, component.label] + myindex_exprs = index_exprs.get((axis.id, component.label), pmap()) subaxis = axes.child(axis, component) - for pt in range(_as_int(component.count, path, indices)): + + # bit of a hack, I reckon this can go as we can just get it from component.count + # inside as_int + if isinstance(component.count, HierarchicalArray): + mypath = component.count.target_paths.get(None, {}) + myindices = component.count.index_exprs.get(None, {}) + if not component.count.axes.is_empty: + for cax, ccpt in component.count.axes.path_with_nodes( + *component.count.axes.leaf + ).items(): + mypath.update(component.count.target_paths.get((cax.id, ccpt), {})) + myindices.update( + component.count.index_exprs.get((cax.id, ccpt), {}) + ) + + mypath = freeze(mypath) + myindices = freeze(myindices) + replace_map = indices + else: + mypath = pmap() + myindices = pmap() + replace_map = None + + for pt in range( + _as_int( + component.count, + replace_map, + # mypath, # + # myindices, + loop_indices=outer_replace_map, + ) + ): new_exprs = {} + evaluator = ExpressionEvaluator( + indices | {axis.label: pt}, outer_replace_map + ) for axlabel, index_expr in myindex_exprs.items(): - new_index = ExpressionEvaluator(outermap | indices | {axis.label: pt})( - index_expr - ) + new_index = evaluator(index_expr) assert new_index != index_expr new_exprs[axlabel] = new_index + # breakpoint() index_exprs_ = index_exprs_acc | new_exprs indices_ = indices | {axis.label: pt} if subaxis: yield from iter_axis_tree( + loop_index, axes, target_paths, index_exprs, - outermap, + outer_loops, + include_loops, subaxis, path_, indices_, @@ -1377,7 +2113,20 @@ def iter_axis_tree( index_exprs_, ) else: - yield path_, target_path_, indices_, index_exprs_ + if include_loops: + assert False, "old code" + source_path = path_ | merge_dicts( + ol.source_path for ol in outer_loops + ) + source_exprs = indices_ | merge_dicts( + ol.source_exprs for ol in outer_loops + ) + else: + source_path = path_ + source_exprs = indices_ + yield IndexIteratorEntry( + loop_index, source_path, target_path_, source_exprs, index_exprs_ + ) class ArrayPointLabel(enum.IntEnum): @@ -1417,8 +2166,8 @@ def partition_iterset(index: LoopIndex, arrays): from pyop3.array import HierarchicalArray # take first - if index.iterset.depth > 1: - raise NotImplementedError("Need a good way to sniff the parallel axis") + # if index.iterset.depth > 1: + # raise NotImplementedError("Need a good way to sniff the parallel axis") paraxis = index.iterset.root # FIXME, need indices per component @@ -1442,13 +2191,12 @@ def partition_iterset(index: LoopIndex, arrays): is_root_or_leaf_per_array[array.name] = is_root_or_leaf labels = np.full(paraxis.size, IterationPointType.CORE, dtype=np.uint8) - for path, target_path, indices, target_indices in index.iter(): - parindex = indices[paraxis.label] - assert isinstance(parindex, numbers.Integral) + for p in index.iterset.iter(): + # hack because I wrote bad code and mix up loop indices and itersets + p = dataclasses.replace(p, 
index=index) - replace_map = freeze( - {(index.id, axis): i for axis, i in target_indices.items()} - ) + parindex = p.source_exprs[paraxis.label] + assert isinstance(parindex, numbers.Integral) for array in arrays: # skip purely local arrays @@ -1458,15 +2206,10 @@ def partition_iterset(index: LoopIndex, arrays): continue # loop over stencil - array = array.with_context({index.id: target_path}) + array = array.with_context({index.id: (p.source_path, p.target_path)}) - for ( - array_path, - array_target_path, - array_indices, - array_target_indices, - ) in array.iter_indices(replace_map): - offset = array.simple_offset(array_target_path, array_target_indices) + for q in array.iter_indices({p}): + offset = array.offset(q.target_exprs, q.target_path) point_label = is_root_or_leaf_per_array[array.name][offset] if point_label == ArrayPointLabel.LEAF: @@ -1505,7 +2248,6 @@ def partition_iterset(index: LoopIndex, arrays): Slice( paraxis.label, [Subset(parcpt.label, subsets[0])], - label=paraxis.label, ) ] diff --git a/pyop3/lang.py b/pyop3/lang.py index 156e99ad..8183f105 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -1,3 +1,5 @@ +# TODO Rename this file insn.py - the pyop3 language is everything, not just this + from __future__ import annotations import abc @@ -6,6 +8,7 @@ import dataclasses import enum import functools +import numbers import operator from collections import defaultdict from functools import cached_property, partial @@ -14,15 +17,23 @@ import loopy as lp import numpy as np -import pymbolic as pym import pytools -from pyrsistent import freeze, pmap +from pyrsistent import freeze -from pyop3.axtree import as_axis_tree +from pyop3.axtree import Axis, as_axis_tree from pyop3.axtree.tree import ContextFree, ContextSensitive, MultiArrayCollector from pyop3.config import config from pyop3.dtypes import IntType, dtype_limits -from pyop3.utils import as_tuple, checked_zip, just_one, merge_dicts, unique +from pyop3.utils import ( + UniqueRecord, + as_tuple, + auto, + checked_zip, + just_one, + merge_dicts, + single_valued, + unique, +) # TODO I don't think that this belongs in this file, it belongs to the function? @@ -39,20 +50,22 @@ class Intent(enum.Enum): MIN_RW = "min_rw" MAX_WRITE = "max_write" MAX_RW = "max_rw" + NA = "na" # TODO prefer NONE # old alias Access = Intent -READ = Access.READ -WRITE = Access.WRITE -RW = Access.RW -INC = Access.INC -MIN_RW = Access.MIN_RW -MIN_WRITE = Access.MIN_WRITE -MAX_RW = Access.MAX_RW -MAX_WRITE = Access.MAX_WRITE +READ = Intent.READ +WRITE = Intent.WRITE +RW = Intent.RW +INC = Intent.INC +MIN_RW = Intent.MIN_RW +MIN_WRITE = Intent.MIN_WRITE +MAX_RW = Intent.MAX_RW +MAX_WRITE = Intent.MAX_WRITE +NA = Intent.NA class IntentMismatchError(Exception): @@ -62,65 +75,285 @@ class IntentMismatchError(Exception): class KernelArgument(abc.ABC): """Class representing objects that may be passed as arguments to kernels.""" + @property + @abc.abstractmethod + def kernel_dtype(self): + pass + -class LoopExpr(pytools.ImmutableRecord, abc.ABC): - fields = set() +# this is an expression, like passing an array through to a kernel +# but it is transformed first. 
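(Editorial aside: the comment above describes the idea behind `Pack`, defined next — a "big" indexed array paired with the "small" temporary that the kernel actually receives. A minimal runnable sketch of that gather/compute/scatter pattern, assuming plain numpy arrays and hypothetical names such as `ToyPack`; this is not pyop3 API:)

```python
# Illustrative only: hypothetical names, plain numpy instead of pyop3 arrays.
import numpy as np


class ToyPack:
    """Pair a 'big' source array with the 'small' temporary the kernel sees."""

    def __init__(self, big, small):
        self.big = big
        self.small = small


def call_with_pack(kernel, pack, indices):
    pack.small[...] = pack.big[indices]  # gather ("pack")
    kernel(pack.small)                   # kernel only ever sees the temporary
    pack.big[indices] = pack.small       # scatter ("unpack"), assuming write intent


def double_in_place(t):
    t *= 2


big = np.arange(10.0)
call_with_pack(double_in_place, ToyPack(big, np.empty(3)), slice(2, 5))
assert list(big[2:5]) == [4.0, 6.0, 8.0]
```

In this diff the real transformation is performed later by the pack/unpack expansion pass in `pyop3/transform.py`.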
+class Pack(KernelArgument, ContextFree): + def __init__(self, big, small): + self.big = big + self.small = small + @property + def kernel_dtype(self): + try: + return single_valued([self.big.dtype, self.small.dtype]) + except ValueError: + raise ValueError("dtypes must match") + + +class Instruction(UniqueRecord, abc.ABC): + pass + + +class ContextAwareInstruction(Instruction): @property @abc.abstractmethod def datamap(self): - """Map from names to arrays. + """Map from names to arrays.""" - weakref since we don't want to hold a reference to these things? - """ - pass + # TODO I think this can be combined with datamap + @property + @abc.abstractmethod + def kernel_arguments(self): + """Kernel arguments and their intents. - # nice for drawing diagrams - # @property - # @abc.abstractmethod - # def operands(self) -> tuple["LoopExpr"]: - # pass + The arguments are ordered according to when they first appear in + the expression. + Notes + ----- + At the moment arguments are not allowed to appear in the expression + multiple times with different intents. This would required thought into + how to resolve read-after-write and similar dependencies. + + """ -class Loop(LoopExpr): - fields = LoopExpr.fields | {"index", "statements", "id", "depends_on"} + +class Loop(Instruction): + fields = Instruction.fields | {"index", "statements"} # doubt that I need an ID here id_generator = pytools.UniqueNameGenerator() def __init__( self, - index: IndexTree, - statements: Sequence[LoopExpr], - id=None, - depends_on=frozenset(), + index: LoopIndex, + statements: Iterable[Instruction], + **kwargs, ): - # FIXME - # assert isinstance(index, pyop3.tensors.Indexed) - if not id: - id = self.id_generator("loop") - - super().__init__() - + super().__init__(**kwargs) self.index = index self.statements = as_tuple(statements) - self.id = id - # I think this can go if I generate code properly - self.depends_on = depends_on - # maybe these should not exist? 
backwards compat - @property - def axes(self): - return self.index.axes + def __call__(self, **kwargs): + # TODO just parse into ContextAwareLoop and call that + from pyop3.ir.lower import compile + from pyop3.itree.tree import partition_iterset - @property - def indices(self): - return self.index.indices + if self.is_parallel: + # interleave computation and communication + new_index, (icore, iroot, ileaf) = partition_iterset( + self.index, [a for a, _ in self.kernel_arguments] + ) + + assert self.index.id == new_index.id + + # substitute subsets into loopexpr, should maybe be done in partition_iterset + parallel_loop = self.copy(index=new_index) + code = compile(parallel_loop) + + # interleave communication and computation + initializers, finalizerss = self._array_updates() + + for init in initializers: + init() + + # replace the parallel axis subset with one for the specific indices here + extent = just_one(icore.axes.root.components).count + core_kwargs = merge_dicts( + [kwargs, {icore.name: icore, extent.name: extent}] + ) + code(**core_kwargs) + + # await reductions + for fin in finalizerss[0]: + fin() + + # roots + # replace the parallel axis subset with one for the specific indices here + root_extent = just_one(iroot.axes.root.components).count + root_kwargs = merge_dicts( + [kwargs, {icore.name: iroot, extent.name: root_extent}] + ) + code(**root_kwargs) + + # await broadcasts + for fin in finalizerss[1]: + fin() + + # leaves + leaf_extent = just_one(ileaf.axes.root.components).count + leaf_kwargs = merge_dicts( + [kwargs, {icore.name: ileaf, extent.name: leaf_extent}] + ) + code(**leaf_kwargs) + + # also may need to eagerly assemble Mats, or be clever and spike the accessors? + else: + compile(self)(**kwargs) + + @cached_property + def loopy_code(self): + from pyop3.ir.lower import compile + + return compile(self) + + @cached_property + def is_parallel(self): + return len(self._distarray_args) > 0 + + @cached_property + def kernel_arguments(self): + args = {} + for stmt in self.statements: + for arg, intent in stmt.kernel_arguments: + assert isinstance(arg, KernelArgument) + if arg not in args: + args[arg] = intent + else: + # FIXME, I have disabled this check because currently we + # do something special for temporaries in Firedrake and the + # clash is of those. + pass + # if args[arg] != intent: + # raise NotImplementedError( + # "Kernel argument used with differing intents" + # ) + return tuple((arg, intent) for arg, intent in args.items()) + + @cached_property + def _distarray_args(self): + from pyop3.array import ContextSensitiveMultiArray, HierarchicalArray + from pyop3.buffer import DistributedBuffer + + arrays = {} + for arg, intent in self.kernel_arguments: + if isinstance(arg, ContextSensitiveMultiArray): + # take first + arg, *_ = arg.context_map.values() + + if ( + not isinstance(arg, HierarchicalArray) + or not isinstance(arg.buffer, DistributedBuffer) + or not arg.buffer.is_distributed + ): + continue + + if arg.array not in arrays: + arrays[arg.array] = (intent, _has_nontrivial_stencil(arg)) + else: + if arrays[arg.array][0] != intent: + # I think that it does not make sense to access arrays with + # different intents in the same kernel but that it is + # always OK if the same intent is used. 
+ raise IntentMismatchError + + # We need to know if *any* uses of a particular array touch ghost points + if not arrays[arg.array][1] and _has_nontrivial_stencil(arg): + arrays[arg.array] = (intent, True) + + # now sort + return tuple( + (arr, *arrays[arr]) for arr in sorted(arrays.keys(), key=lambda a: a.name) + ) + + def _array_updates(self): + """Collect appropriate callables for updating shared values in the right order. + + Returns + ------- + (initializers, (finalizers0, finalizers1)) + Collections of callables to be executed at the right times. + + """ + initializers = [] + finalizerss = ([], []) + for array, intent, touches_ghost_points in self._distarray_args: + if intent in {READ, RW}: + if touches_ghost_points: + if not array._roots_valid: + initializers.append(array._reduce_leaves_to_roots_begin) + finalizerss[0].extend( + [ + array._reduce_leaves_to_roots_end, + array._broadcast_roots_to_leaves_begin, + ] + ) + finalizerss[1].append(array._broadcast_roots_to_leaves_end) + else: + initializers.append(array._broadcast_roots_to_leaves_begin) + finalizerss[1].append(array._broadcast_roots_to_leaves_end) + else: + if not array._roots_valid: + initializers.append(array._reduce_leaves_to_roots_begin) + finalizerss[0].append(array._reduce_leaves_to_roots_end) + + elif intent == WRITE: + # Assumes that all points are written to (i.e. not a subset). If + # this is not the case then a manual reduction is needed. + array._leaves_valid = False + array._pending_reduction = None + + elif intent in {INC, MIN_WRITE, MIN_RW, MAX_WRITE, MAX_RW}: # reductions + # We don't need to update roots if performing the same reduction + # again. For example we can increment into an array as many times + # as we want. The reduction only needs to be done when the + # data is read. + if array._roots_valid or intent == array._pending_reduction: + pass + else: + # We assume that all points are visited, and therefore that + # WRITE accesses do not need to update roots. If only a subset + # of entities are written to then a manual reduction is required. + # This is the same assumption that we make for data_wo and is + # explained in the documentation. 
+ if intent in {INC, MIN_RW, MAX_RW}: + assert array._pending_reduction is not None + initializers.append(array._reduce_leaves_to_roots_begin) + finalizerss[0].append(array._reduce_leaves_to_roots_end) + + # We are modifying owned values so the leaves must now be wrong + array._leaves_valid = False + + # If ghost points are not modified then no future reduction is required + if not touches_ghost_points: + array._pending_reduction = None + else: + array._pending_reduction = intent + + # set leaves to appropriate nil value + if intent == INC: + array._data[array.sf.ileaf] = 0 + elif intent in {MIN_WRITE, MIN_RW}: + array._data[array.sf.ileaf] = dtype_limits(array.dtype).max + elif intent in {MAX_WRITE, MAX_RW}: + array._data[array.sf.ileaf] = dtype_limits(array.dtype).min + else: + raise AssertionError + + else: + raise AssertionError + + return initializers, finalizerss + + +class ContextAwareLoop(ContextAwareInstruction): + fields = Instruction.fields | {"index", "statements"} + + def __init__(self, index, statements, **kwargs): + super().__init__(**kwargs) + self.index = index + self.statements = statements @cached_property def datamap(self): return self.index.datamap | merge_dicts( - stmt.datamap for stmt in self.statements + stmt.datamap for stmts in self.statements.values() for stmt in stmts ) def __call__(self, **kwargs): @@ -130,7 +363,7 @@ def __call__(self, **kwargs): if self.is_parallel: # interleave computation and communication new_index, (icore, iroot, ileaf) = partition_iterset( - self.index, [a for a, _ in self.all_function_arguments] + self.index, [a for a, _ in self.kernel_arguments] ) assert self.index.id == new_index.id @@ -175,7 +408,7 @@ def __call__(self, **kwargs): ) code(**leaf_kwargs) - # also may need to eagerly assemble Mats, or be clever? + # also may need to eagerly assemble Mats, or be clever and spike the accessors? else: compile(self)(**kwargs) @@ -190,25 +423,26 @@ def is_parallel(self): return len(self._distarray_args) > 0 @cached_property - def all_function_arguments(self): - # TODO overly verbose - func_args = {} + def kernel_arguments(self): + args = {} for stmt in self.statements: - for arg, intent in stmt.all_function_arguments: - if arg not in func_args: - func_args[arg] = intent - # now sort - return tuple( - (arg, func_args[arg]) - for arg in sorted(func_args.keys(), key=lambda a: a.name) - ) + for arg, intent in stmt.kernel_arguments: + assert isinstance(arg, KernelArgument) + if arg not in args: + args[arg] = intent + else: + if args[arg] != intent: + raise NotImplementedError( + "Kernel argument used with differing intents" + ) + return tuple((arg, intent) for arg, intent in args.items()) @cached_property def _distarray_args(self): from pyop3.buffer import DistributedBuffer arrays = {} - for arg, intent in self.all_function_arguments: + for arg, intent in self.kernel_arguments: if ( not isinstance(arg.array, DistributedBuffer) or not arg.array.is_distributed @@ -313,6 +547,7 @@ def _array_updates(self): # TODO singledispatch +# TODO perhaps this is simply "has non unit stride"? 
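(Editorial aside: the two TODOs above suggest a `functools.singledispatch` rewrite and hint that the predicate may reduce to "has non-unit stride". A sketch of what that dispatch could look like, using stand-in types; this is not the implementation that follows below:)

```python
# Illustrative only: ToyFlatArray / ToyIndexedArray are stand-ins, not pyop3 types.
import functools


class ToyFlatArray:
    """Stand-in for a directly-addressed, unit-stride array."""


class ToyIndexedArray:
    """Stand-in for an array accessed through an indirection map."""


@functools.singledispatch
def has_nontrivial_stencil(array) -> bool:
    # Mirror the real function: unknown types are an error.
    raise TypeError(f"No handler registered for {type(array).__name__}")


@has_nontrivial_stencil.register
def _(array: ToyFlatArray) -> bool:
    return False  # contiguous, unit-stride access never reaches ghost points


@has_nontrivial_stencil.register
def _(array: ToyIndexedArray) -> bool:
    return True  # indirect addressing may reach ghost points


assert has_nontrivial_stencil(ToyIndexedArray())
assert not has_nontrivial_stencil(ToyFlatArray())
```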
def _has_nontrivial_stencil(array): """ @@ -332,6 +567,21 @@ def _has_nontrivial_stencil(array): raise TypeError +class Terminal(Instruction, abc.ABC): + @cached_property + def datamap(self): + return merge_dicts(a.datamap for a, _ in self.kernel_arguments) + + @property + @abc.abstractmethod + def argument_shapes(self): + pass + + @abc.abstractmethod + def with_arguments(self, arguments: Iterable[KernelArgument]): + pass + + @dataclasses.dataclass(frozen=True) class ArgumentSpec: access: Intent @@ -389,35 +639,38 @@ def __call__(self, *args): f"but received {len(args)}" ) if any( - spec.dtype.numpy_dtype != arg.dtype + spec.dtype.numpy_dtype != arg.kernel_dtype for spec, arg in checked_zip(self.argspec, args) + if arg.kernel_dtype is not auto ): raise ValueError("Arguments to the kernel have the wrong dtype") return CalledFunction(self, args) @property def argspec(self): - return tuple( - ArgumentSpec(access, arg.dtype, arg.shape) - for access, arg in zip( - self._access_descrs, self.code.default_entrypoint.args - ) - ) + spec = [] + for access, arg in checked_zip( + self._access_descrs, self.code.default_entrypoint.args + ): + shape = arg.shape if not isinstance(arg, lp.ValueArg) else () + spec.append(ArgumentSpec(access, arg.dtype, shape)) + return tuple(spec) @property def name(self): return self.code.default_entrypoint.name -class CalledFunction(LoopExpr): - def __init__(self, function, arguments): +class CalledFunction(Terminal): + fields = Terminal.fields | {"function", "arguments"} + + def __init__( + self, function: Function, arguments: Iterable[KernelArgument], **kwargs + ) -> None: + super().__init__(**kwargs) self.function = function self.arguments = arguments - @functools.cached_property - def datamap(self): - return merge_dicts([arg.datamap for arg in self.arguments]) - @property def name(self): return self.function.name @@ -426,81 +679,137 @@ def name(self): def argspec(self): return self.function.argspec - # FIXME NEXT: Expand ContextSensitive things here @property - def all_function_arguments(self): + def kernel_arguments(self): return tuple( - sorted( - [ - (arg, intent) - for arg, intent in checked_zip( - self.arguments, self.function._access_descrs - ) - ], - key=lambda a: a[0].name, - ) + (arg, intent) + for arg, intent in checked_zip(self.arguments, self.function._access_descrs) + # this isn't right, loop indices do not count here + if isinstance(arg, KernelArgument) ) + @property + def argument_shapes(self): + return tuple( + arg.shape if not isinstance(arg, lp.ValueArg) else () + for arg in self.function.code.default_entrypoint.args + ) -class Instruction(pytools.ImmutableRecord): - fields = set() + def with_arguments(self, arguments): + return self.copy(arguments=arguments) -class Assignment(Instruction): - fields = Instruction.fields | {"tensor", "temporary", "shape"} +class Assignment(Terminal, abc.ABC): + fields = Terminal.fields | {"assignee", "expression"} - def __init__(self, tensor, temporary, shape, **kwargs): - self.tensor = tensor - self.temporary = temporary - self.shape = shape + def __init__(self, assignee, expression, **kwargs): super().__init__(**kwargs) + self.assignee = assignee + self.expression = expression - # better name - @property - def array(self): - return self.tensor - + def __call__(self): + do_loop(Axis(1).index(), self) -class Read(Assignment): @property - def lhs(self): - return self.temporary + def arguments(self): + # FIXME Not sure this is right for complicated expressions + return (self.assignee, self.expression) @property - def 
rhs(self): - return self.tensor + def arrays(self): + from pyop3.array import HierarchicalArray + arrays_ = [self.assignee] + if isinstance(self.expression, HierarchicalArray): + arrays_.append(self.expression) + else: + if not isinstance(self.expression, numbers.Number): + raise NotImplementedError + return tuple(arrays_) + # collector = MultiArrayCollector() + # return collector(self.assignee) | collector(self.expression) -class Write(Assignment): @property - def lhs(self): - return self.tensor + def argument_shapes(self): + return (None,) * len(self.kernel_arguments) - @property - def rhs(self): - return self.temporary + def with_arguments(self, arguments): + if len(arguments) != 2: + raise ValueError("Must provide 2 arguments") + assignee, expression = arguments + return self.copy(assignee=assignee, expression=expression) -class Increment(Assignment): @property - def lhs(self): - return self.tensor + def _expression_kernel_arguments(self): + from pyop3.array import HierarchicalArray - @property - def rhs(self): - return self.temporary + if isinstance(self.expression, HierarchicalArray): + return ((self.expression, READ),) + elif isinstance(self.expression, numbers.Number): + return () + else: + raise NotImplementedError("Complicated rvalues not yet supported") + + +class ReplaceAssignment(Assignment): + """Like PETSC_INSERT_VALUES.""" + + @cached_property + def kernel_arguments(self): + return ((self.assignee, WRITE),) + self._expression_kernel_arguments + + +class AddAssignment(Assignment): + """Like PETSC_ADD_VALUES.""" + + @cached_property + def kernel_arguments(self): + return ((self.assignee, INC),) + self._expression_kernel_arguments -class Zero(Assignment): +# inherit from Assignment? +class PetscMatInstruction(Instruction): + def __init__(self, mat_arg, array_arg): + self.mat_arg = mat_arg + self.array_arg = array_arg + @property - def lhs(self): - return self.temporary + def datamap(self): + return self.mat_arg.datamap | self.array_arg.datamap + + +class PetscMatLoad(PetscMatInstruction): + ... + + +class PetscMatStore(PetscMatInstruction): + ... + + +# potentially confusing name +class PetscMatAdd(PetscMatInstruction): + ... + + +class OpaqueKernelArgument(KernelArgument, ContextFree): + def __init__(self, dtype=auto): + self._dtype = dtype - # FIXME @property - def rhs(self): - # return 0 - return self.tensor + def kernel_dtype(self): + return self._dtype + + +class DummyKernelArgument(OpaqueKernelArgument): + """Placeholder kernel argument. + + This class is useful when one simply wants to generate code from a loop + expression and not execute it. + + ### dtypes not required here as sniffed from local kernel/context? 
+ + """ def loop(*args, **kwargs): @@ -526,8 +835,8 @@ def fix_intents(tunit, accesses): kernel = tunit.default_entrypoint new_args = [] for arg, access in checked_zip(kernel.args, accesses): - assert access in {READ, WRITE, RW, INC, MIN_RW, MIN_WRITE, MAX_RW, MAX_WRITE} - is_input = access in {READ, RW, INC, MIN_RW, MAX_RW} + assert isinstance(access, Intent) + is_input = access in {READ, RW, INC, MIN_RW, MAX_RW, NA} is_output = access in {WRITE, RW, INC, MIN_RW, MIN_WRITE, MAX_WRITE, MAX_RW} new_args.append(arg.copy(is_input=is_input, is_output=is_output)) return tunit.with_kernel(kernel.copy(args=new_args)) diff --git a/pyop3/mpi.py b/pyop3/mpi.py index 8620997f..6c8905ee 100644 --- a/pyop3/mpi.py +++ b/pyop3/mpi.py @@ -39,6 +39,7 @@ import glob import os import tempfile +import weakref from itertools import count from mpi4py import MPI # noqa @@ -77,6 +78,8 @@ _DUPED_COMM_DICT = {} # Flag to indicate whether we are in cleanup (at exit) PYOP2_FINALIZED = False +# Flag for outputting information at the end of testing (do not abuse!) +_running_on_ci = bool(os.environ.get("PYOP2_CI_TESTS")) class PyOP2CommError(ValueError): @@ -180,32 +183,48 @@ def delcomm_outer(comm, keyval, icomm): :arg icomm: The inner communicator, should have a reference to ``comm``. """ - # This will raise errors at cleanup time as some objects are already - # deleted, so we just skip - if not PYOP2_FINALIZED: - if keyval not in (innercomm_keyval, compilationcomm_keyval): - raise PyOP2CommError("Unexpected keyval") - ocomm = icomm.Get_attr(outercomm_keyval) - if ocomm is None: - raise PyOP2CommError( - "Inner comm does not have expected reference to outer comm" - ) + # Use debug printer that is safe to use at exit time + debug = finalize_safe_debug() + if keyval not in (innercomm_keyval, compilationcomm_keyval): + raise PyOP2CommError("Unexpected keyval") + + if keyval == innercomm_keyval: + debug(f"Deleting innercomm keyval on {comm.name}") + if keyval == compilationcomm_keyval: + debug(f"Deleting compilationcomm keyval on {comm.name}") + + ocomm = icomm.Get_attr(outercomm_keyval) + if ocomm is None: + raise PyOP2CommError( + "Inner comm does not have expected reference to outer comm" + ) - if ocomm != comm: - raise PyOP2CommError("Inner comm has reference to non-matching outer comm") - icomm.Delete_attr(outercomm_keyval) - - # Once we have removed the ref to the inner/compilation comm we can free it - cidx = icomm.Get_attr(cidx_keyval) - cidx = cidx[0] - del _DUPED_COMM_DICT[cidx] - gc.collect() - refcount = icomm.Get_attr(refcount_keyval) - if refcount[0] > 1: - raise PyOP2CommError( - "References to comm still held, this will cause deadlock" - ) - icomm.Free() + if ocomm != comm: + raise PyOP2CommError("Inner comm has reference to non-matching outer comm") + icomm.Delete_attr(outercomm_keyval) + + # An inner comm may or may not hold a reference to a compilation comm + comp_comm = icomm.Get_attr(compilationcomm_keyval) + if comp_comm is not None: + debug("Removing compilation comm on inner comm") + decref(comp_comm) + icomm.Delete_attr(compilationcomm_keyval) + + # Once we have removed the reference to the inner/compilation comm we can free it + cidx = icomm.Get_attr(cidx_keyval) + cidx = cidx[0] + del _DUPED_COMM_DICT[cidx] + gc.collect() + refcount = icomm.Get_attr(refcount_keyval) + if refcount[0] > 1: + # In the case where `comm` is a custom user communicator there may be references + # to the inner comm still held and this is not an issue, but there is not an + # easy way to distinguish this case, so we just 
log the event. + debug( + f"There are still {refcount[0]} references to {comm.name}, " + "this will cause deadlock if the communicator has been incorrectly freed" + ) + icomm.Free() # Reference count, creation index, inner/outer/compilation communicator @@ -224,26 +243,23 @@ def is_pyop2_comm(comm): :arg comm: Communicator to query """ - global PYOP2_FINALIZED if isinstance(comm, PETSc.Comm): ispyop2comm = False elif comm == MPI.COMM_NULL: - if not PYOP2_FINALIZED: - raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL") - else: - ispyop2comm = True + raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL") elif isinstance(comm, MPI.Comm): ispyop2comm = bool(comm.Get_attr(refcount_keyval)) else: raise PyOP2CommError( - f"Argument passed to is_pyop2_comm() is a {type(comm)}, which is not a " - "recognised comm type" + f"Argument passed to is_pyop2_comm() is a {type(comm)}, which is not a recognised comm type" ) return ispyop2comm def pyop2_comm_status(): - """Prints the reference counts for all comms PyOP2 has duplicated""" + """Return string containing a table of the reference counts for all + communicators PyOP2 has duplicated. + """ status_string = "PYOP2 Communicator reference counts:\n" status_string += "| Communicator name | Count |\n" status_string += "==================================================\n" @@ -267,10 +283,7 @@ class temp_internal_comm: def __init__(self, comm): self.user_comm = comm - self.internal_comm = internal_comm(self.user_comm) - - def __del__(self): - decref(self.internal_comm) + self.internal_comm = internal_comm(self.user_comm, self) def __enter__(self): """Returns an internal comm that will be safely decref'd @@ -284,10 +297,12 @@ def __exit__(self, exc_type, exc_value, traceback): pass -def internal_comm(comm): +def internal_comm(comm, obj): """Creates an internal comm from the user comm. If comm is None, create an internal communicator from COMM_WORLD :arg comm: A communicator or None + :arg obj: The object which the comm is an attribute of + (usually `self`) :returns pyop2_comm: A PyOP2 internal communicator """ @@ -310,6 +325,7 @@ def internal_comm(comm): pyop2_comm = comm else: pyop2_comm = dup_comm(comm) + weakref.finalize(obj, decref, pyop2_comm) return pyop2_comm @@ -322,21 +338,20 @@ def incref(comm): def decref(comm): """Decrement communicator reference count""" - if not PYOP2_FINALIZED: + if comm == MPI.COMM_NULL: + # This case occurs if the the outer communicator has already been freed by + # the user + debug("Cannot decref an already freed communicator") + else: assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 - if refcount[0] == 1: - # Freeing the comm is handled by the destruction of the user comm - pass - elif refcount[0] < 1: + # Freeing the internal comm is handled by the destruction of the user comm + if refcount[0] < 1: raise PyOP2CommError( "Reference count is less than 1, decref called too many times" ) - elif comm != MPI.COMM_NULL: - comm.Free() - def dup_comm(comm_in): """Given a communicator return a communicator for internal use. 
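(Editorial aside: the hunks above and below replace `__del__`-based cleanup with `weakref.finalize`, tying each internal communicator's reference count to the lifetime of the object that owns it. A self-contained sketch of that pattern, using a toy refcounted object rather than a real MPI communicator; all names here are illustrative:)

```python
# Illustrative only: ToyComm stands in for an MPI communicator with a refcount.
import weakref


class ToyComm:
    def __init__(self, name):
        self.name = name
        self.refcount = 1


def toy_decref(comm):
    # Runs when the owner is garbage collected, even at interpreter shutdown.
    comm.refcount -= 1


class Owner:
    def __init__(self, comm):
        comm.refcount += 1
        self.comm = comm
        # Tie the decref to the *owner's* lifetime instead of defining __del__,
        # which is unreliable during interpreter finalization.
        weakref.finalize(self, toy_decref, comm)


comm = ToyComm("toy")
owner = Owner(comm)
assert comm.refcount == 2
del owner  # in CPython the finalizer fires immediately
assert comm.refcount == 1
```

Unlike `__del__`, finalizers registered this way run at most once and remain safe during interpreter shutdown, which is what the `PYOP2_FINALIZED` guards previously worked around.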
@@ -388,10 +403,10 @@ def create_split_comm(comm):
     else:
         debug("Creating compilation communicator using MPI_Split + filesystem")
         if comm.rank == 0:
-            if not os.path.exists(config["cache_dir"]):
-                os.makedirs(config["cache_dir"], exist_ok=True)
+            if not os.path.exists(configuration["cache_dir"]):
+                os.makedirs(configuration["cache_dir"], exist_ok=True)
             tmpname = tempfile.mkdtemp(
-                prefix="rank-determination-", dir=config["cache_dir"]
+                prefix="rank-determination-", dir=configuration["cache_dir"]
             )
         else:
             tmpname = None
@@ -438,7 +453,7 @@ def set_compilation_comm(comm, comp_comm):
 
     if not is_pyop2_comm(comp_comm):
         raise PyOP2CommError(
-            "Communicator used for compilation must be a PyOP2 communicator.\n"
+            "Communicator used as compilation communicator must be a PyOP2 communicator.\n"
             "Use pyop2.mpi.dup_comm() to create a PyOP2 comm from an existing comm."
         )
     else:
@@ -446,8 +461,7 @@ def set_compilation_comm(comm, comp_comm):
             # Clean up old_comp_comm before setting new one
             if not is_pyop2_comm(old_comp_comm):
                 raise PyOP2CommError(
-                    "Compilation communicator is not a PyOP2 comm, something is "
-                    "very broken!"
+                    "Compilation communicator is not a PyOP2 comm, something is very broken!"
                 )
             gc.collect()
             decref(old_comp_comm)
@@ -458,10 +472,13 @@ def set_compilation_comm(comm, comp_comm):
 
 
 @collective
-def compilation_comm(comm):
+def compilation_comm(comm, obj):
     """Get a communicator for compilation.
 
     :arg comm: The input communicator, must be a PyOP2 comm.
+    :arg obj: The object which the comm is an attribute of
+        (usually `self`)
+
     :returns: A communicator used for compilation (may be smaller)
     """
     if not is_pyop2_comm(comm):
@@ -483,35 +500,59 @@ def compilation_comm(comm):
     else:
         comp_comm = comm
         incref(comp_comm)
+    weakref.finalize(obj, decref, comp_comm)
     return comp_comm
 
 
+def finalize_safe_debug():
+    """Return function for debug output.
+
+    When Python is finalizing the logging module may be finalized before we have
+    finished writing debug information. In this case we fall back to using the
+    Python `print` function to output debugging information.
+
+    Furthermore, we always want to see this finalization information when
+    running the CI tests. 
+ """ + global debug + if PYOP2_FINALIZED: + if logger.level > DEBUG and not _running_on_ci: + debug = lambda string: None + else: + debug = lambda string: print(string) + return debug + + @atexit.register def _free_comms(): """Free all outstanding communicators.""" global PYOP2_FINALIZED PYOP2_FINALIZED = True - if logger.level > DEBUG: - debug = lambda string: None - else: - debug = lambda string: print(string) + debug = finalize_safe_debug() debug("PyOP2 Finalizing") # Collect garbage as it may hold on to communicator references + debug("Calling gc.collect()") gc.collect() + debug("STATE0") + debug(pyop2_comm_status()) + debug("Freeing PYOP2_COMM_WORLD") COMM_WORLD.Free() + debug("STATE1") + debug(pyop2_comm_status()) + debug("Freeing PYOP2_COMM_SELF") COMM_SELF.Free() + debug("STATE2") debug(pyop2_comm_status()) debug(f"Freeing comms in list (length {len(_DUPED_COMM_DICT)})") - for key in sorted(_DUPED_COMM_DICT.keys()): + for key in sorted(_DUPED_COMM_DICT.keys(), reverse=True): comm = _DUPED_COMM_DICT[key] if comm != MPI.COMM_NULL: refcount = comm.Get_attr(refcount_keyval) debug( - f"Freeing {comm.name}, with index {key}, which has " - f"refcount {refcount[0]}" + f"Freeing {comm.name}, with index {key}, which has refcount {refcount[0]}" ) comm.Free() del _DUPED_COMM_DICT[key] diff --git a/pyop3/sf.py b/pyop3/sf.py index 292a9bb6..e71a1459 100644 --- a/pyop3/sf.py +++ b/pyop3/sf.py @@ -5,6 +5,7 @@ from petsc4py import PETSc from pyop3.dtypes import get_mpi_dtype +from pyop3.mpi import internal_comm from pyop3.utils import just_one @@ -19,12 +20,21 @@ def __init__(self, sf, size: int): self.sf = sf self.size = size + # don't like this pattern + self._comm = internal_comm(sf.comm, self) + @classmethod - def from_graph(cls, size: int, nroots: int, ilocal, iremote, comm=None): - sf = PETSc.SF().create(comm or PETSc.Sys.getDefaultComm()) + def from_graph(cls, size: int, nroots: int, ilocal, iremote, comm): + # from pyop3.extras.debug import print_with_rank + # print_with_rank(nroots, ilocal, iremote) + sf = PETSc.SF().create(comm) sf.setGraph(nroots, ilocal, iremote) return cls(sf, size) + @property + def comm(self): + return self.sf.comm + @cached_property def iroot(self): """Return the indices of roots on the current process.""" @@ -53,6 +63,10 @@ def icore(self): def nroots(self): return self._graph[0] + @property + def nowned(self): + return self.size - self.nleaves + @property def nleaves(self): return len(self.ileaf) @@ -108,3 +122,30 @@ def _prepare_args(self, *args): # what about cdim? dtype, _ = get_mpi_dtype(from_buffer.dtype) return (dtype, from_buffer, to_buffer, op) + + +def single_star(comm, size=1, root=0): + """Construct a star forest containing a single star. + + The single star has leaves on all ranks apart from the "root" rank that + point to the same shared data. This is useful for describing globally + consistent data structures. 
+ + """ + if comm.rank == root: + # there are no leaves on the root process + nroots = size + ilocal = [] + iremote = [] + else: + nroots = 0 + ilocal = np.arange(size, dtype=np.int32) + iremote = [(root, i) for i in ilocal] + return StarForest.from_graph(size, nroots, ilocal, iremote, comm) + + +def serial_forest(size: int) -> StarForest: + nroots = 0 + ilocal = [] + iremote = [] + return StarForest.from_graph(size, nroots, ilocal, iremote, MPI.COMM_SELF) diff --git a/pyop3/target.py b/pyop3/target.py index 67461819..b7b6e914 100644 --- a/pyop3/target.py +++ b/pyop3/target.py @@ -253,14 +253,8 @@ def __init__( self._debug = config["debug"] # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm) - self.comm = mpi.compilation_comm(self.pcomm) - - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "pcomm"): - mpi.decref(self.pcomm) + self.pcomm = mpi.internal_comm(comm, self) + self.comm = mpi.compilation_comm(self.pcomm, self) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" @@ -385,7 +379,7 @@ def compile_library(self, code: str, name: str, argtypes, restype): # atomically (avoiding races). tmpname = os.path.join(cachedir, "%s_p%d.so.tmp" % (basename, pid)) - if config["check_src_hashes"]: + if config["check_src_hashes"] or config["debug"]: matching = self.comm.allreduce(basename, op=_check_op) if matching != basename: # Dump all src code to disk for debugging diff --git a/pyop3/tensor.py b/pyop3/tensor.py deleted file mode 100644 index 100e9eb8..00000000 --- a/pyop3/tensor.py +++ /dev/null @@ -1,43 +0,0 @@ -import abc - -from pyop3.array import Array -from pyop3.utils import UniqueNameGenerator - - -class Tensor(abc.ABC): - """Base class for all :mod:`pyop3` parallel objects.""" - - _prefix = "tensor" - _name_generator = UniqueNameGenerator() - - def __init__(self, array: Array, name=None, *, prefix=None) -> None: - if self.rank not in array.valid_ranks: - raise TypeError("Unsuitable array provided") - if name and prefix: - raise ValueError("Can only specify one of name and prefix") - - self.array = array - self.name = name or self._name_generator(prefix or self._prefix) - - @property - @abc.abstractmethod - def rank(self) -> int: - pass - - -class Global(Tensor): - @property - def rank(self) -> int: - return 0 - - -class Dat(Tensor): - @property - def rank(self) -> int: - return 1 - - -class Mat(Tensor): - @property - def rank(self) -> int: - return 2 diff --git a/pyop3/transform.py b/pyop3/transform.py new file mode 100644 index 00000000..4b1f2da9 --- /dev/null +++ b/pyop3/transform.py @@ -0,0 +1,557 @@ +from __future__ import annotations + +import abc +import collections +import functools +import numbers + +from pyrsistent import freeze, pmap + +from pyop3.array import ContextSensitiveMultiArray, HierarchicalArray, PetscMat +from pyop3.axtree import Axis, AxisTree, ContextFree, ContextSensitive +from pyop3.buffer import DistributedBuffer, NullBuffer, PackedBuffer +from pyop3.itree import Map, TabulatedMapComponent +from pyop3.lang import ( + INC, + NA, + READ, + RW, + WRITE, + AddAssignment, + Assignment, + CalledFunction, + ContextAwareLoop, + DummyKernelArgument, + Instruction, + Loop, + Pack, + PetscMatAdd, + PetscMatLoad, + PetscMatStore, + ReplaceAssignment, + Terminal, +) +from pyop3.utils import UniqueNameGenerator, checked_zip, just_one + + +# TODO Is this generic for other parsers/transformers? Esp. 
lower.py +class Transformer(abc.ABC): + @abc.abstractmethod + def apply(self, expr): + pass + + +""" +TODO +We sometimes want to pass loop indices to functions even without an external loop. +This is particularly useful when we only want to generate code. We should (?) unpick +this so that there is an outer set of loop contexts that applies at the highest level. + +Alternatively, we enforce that this loop exists. But I don't think that that's feasible +right now. +""" + + +class LoopContextExpander(Transformer): + # TODO prefer __call__ instead + def apply(self, expr: Instruction): + return self._apply(expr, context=pmap()) + + @functools.singledispatchmethod + def _apply(self, expr: Instruction, **kwargs): + raise TypeError(f"No handler provided for {type(expr).__name__}") + + @_apply.register + def _(self, loop: Loop, *, context): + # this is very similar to what happens in PetscMat.__getitem__ + outer_context = collections.defaultdict(dict) # ordered set per index + if isinstance(loop.index.iterset, ContextSensitive): + for ctx in loop.index.iterset.context_map.keys(): + for index, paths in ctx.items(): + if index in context: + # assert paths == context[index] + continue + else: + outer_context[index][paths] = None + # convert ordered set to a list + outer_context = {k: tuple(v.keys()) for k, v in outer_context.items()} + + # convert to a product-like structure of [{index: paths, ...}, {index: paths}, ...] + outer_context_ = tuple(context_product(outer_context.items())) + + if not outer_context_: + outer_context_ = (pmap(),) + + loops = [] + for octx in outer_context_: + cf_iterset = loop.index.iterset.with_context(context | octx) + source_paths = cf_iterset.leaf_paths + target_paths = cf_iterset.leaf_target_paths + assert len(source_paths) == len(target_paths) + + if len(source_paths) == 1: + # single component iterset, no branching required + source_path = just_one(source_paths) + target_path = just_one(target_paths) + + context_ = context | {loop.index.id: (source_path, target_path)} + + statements = collections.defaultdict(list) + for stmt in loop.statements: + for myctx, mystmt in self._apply(stmt, context=context_ | octx): + if myctx: + raise NotImplementedError( + "need to think about how to wrap inner instructions " + "that need outer loops" + ) + statements[source_path].append(mystmt) + else: + assert len(source_paths) > 1 + statements = {} + for source_path, target_path in checked_zip(source_paths, target_paths): + context_ = context | {loop.index.id: (source_path, target_path)} + + statements[source_path] = [] + + for stmt in loop.statements: + for myctx, mystmt in self._apply(stmt, context=context_ | octx): + if myctx: + raise NotImplementedError( + "need to think about how to wrap inner instructions " + "that need outer loops" + ) + if mystmt is None: + continue + statements[source_path].append(mystmt) + + # FIXME this does not propagate inner outer contexts + loop = ContextAwareLoop( + loop.index.copy(iterset=cf_iterset), + statements, + ) + loops.append((octx, loop)) + return tuple(loops) + + @_apply.register + def _(self, terminal: CalledFunction, *, context): + # this is very similar to what happens in PetscMat.__getitem__ + outer_context = collections.defaultdict(dict) # ordered set per index + for arg in terminal.arguments: + if not isinstance(arg, ContextSensitive): + continue + + for ctx in arg.context_map.keys(): + for index, paths in ctx.items(): + if index in context: + assert paths == context[index] + else: + outer_context[index][paths] = None + # convert 
ordered set to a list + outer_context = {k: tuple(v.keys()) for k, v in outer_context.items()} + + # convert to a product-like structure of [{index: paths, ...}, {index: paths}, ...] + outer_context_ = tuple(context_product(outer_context.items())) + + if not outer_context_: + outer_context_ = (pmap(),) + + for arg in terminal.arguments: + if isinstance(arg, ContextSensitive): + outer_context.update( + { + index: paths + for ctx in arg.context_map.keys() + for index, paths in ctx.items() + if index not in context + } + ) + + retval = [] + for octx in outer_context_: + cf_args = [a.with_context(octx | context) for a in terminal.arguments] + retval.append((octx, terminal.with_arguments(cf_args))) + return retval + + @_apply.register + def _(self, terminal: Assignment, *, context): + # FIXME for now we assume an outer context of {}. In other words anything + # context sensitive in the assignment is completely handled by the existing + # outer loops. + + valid = True + cf_args = [] + for arg in terminal.arguments: + try: + cf_arg = ( + arg.with_context(context) + if isinstance(arg, ContextSensitive) + else arg + ) + # FIXME We will hit issues here when we are missing outer context I think + except KeyError: + # assignment is not valid in this context, do nothing + valid = False + break + cf_args.append(cf_arg) + + if valid: + return ((pmap(), terminal.with_arguments(cf_args)),) + else: + return ((pmap(), None),) + + +def expand_loop_contexts(expr: Instruction): + return LoopContextExpander().apply(expr) + + +def context_product(contexts, acc=pmap()): + contexts = tuple(contexts) + + if not contexts: + return acc + + ctx, *subctxs = contexts + index, pathss = ctx + for paths in pathss: + acc_ = acc | {index: paths} + if subctxs: + yield from context_product(subctxs, acc_) + else: + yield acc_ + + +class ImplicitPackUnpackExpander(Transformer): + def __init__(self): + self._name_generator = UniqueNameGenerator() + + def apply(self, expr): + return self._apply(expr) + + @functools.singledispatchmethod + def _apply(self, expr: Any): + raise NotImplementedError(f"No handler provided for {type(expr).__name__}") + + # TODO Can I provide a generic "operands" thing? Put in the parent class? + @_apply.register + def _(self, loop: ContextAwareLoop): + return ( + loop.copy( + statements={ + ctx: [stmt_ for stmt in stmts for stmt_ in self._apply(stmt)] + for ctx, stmts in loop.statements.items() + } + ), + ) + + @_apply.register + def _(self, assignment: Assignment): + # same as for CalledFunction + gathers = [] + # NOTE: scatters are executed in LIFO order + scatters = [] + arguments = [] + + # lazy coding, tidy up + if isinstance(assignment, ReplaceAssignment): + access = WRITE + else: + assert isinstance(assignment, AddAssignment) + access = INC + for arg, intent in [ + (assignment.assignee, access), + (assignment.expression, READ), + ]: + if isinstance(arg, numbers.Number): + arguments.append(arg) + continue + + # emit function calls for PetscMat + # this is a separate stage to the assignment operations because one + # can index a packed mat. E.g. mat[p, q][::2] would decompose into + # two calls, one to pack t0 <- mat[p, q] and another to pack t1 <- t0[::2] + if isinstance(arg.buffer, PackedBuffer): + # TODO add PackedPetscMat as a subclass of buffer? + if not isinstance(arg.buffer.array, PetscMat): + raise NotImplementedError("Only handle Mat at the moment") + + axes = AxisTree(arg.axes.parent_to_children) + new_arg = HierarchicalArray( + axes, + data=NullBuffer(arg.dtype), # does this need a size? 
+ prefix="t", + ) + + if intent == READ: + gathers.append(PetscMatLoad(arg, new_arg)) + elif intent == WRITE: + scatters.insert(0, PetscMatStore(arg, new_arg)) + elif intent == RW: + gathers.append(PetscMatLoad(arg, new_arg)) + scatters.insert(0, PetscMatStore(arg, new_arg)) + else: + assert intent == INC + scatters.insert(0, PetscMatAdd(arg, new_arg)) + + arguments.append(new_arg) + else: + arguments.append(arg) + + return (*gathers, assignment.with_arguments(arguments), *scatters) + + @_apply.register + def _(self, terminal: CalledFunction): + gathers = [] + # NOTE: scatters are executed in LIFO order + scatters = [] + arguments = [] + for (arg, intent), shape in checked_zip( + terminal.kernel_arguments, terminal.argument_shapes + ): + assert isinstance( + arg, ContextFree + ), "Loop contexts should already be expanded" + + if isinstance(arg, DummyKernelArgument): + arguments.append(arg) + continue + + # emit function calls for PetscMat + # this is a separate stage to the assignment operations because one + # can index a packed mat. E.g. mat[p, q][::2] would decompose into + # two calls, one to pack t0 <- mat[p, q] and another to pack t1 <- t0[::2] + if ( + isinstance(arg, Pack) + and isinstance(arg.big.buffer, PackedBuffer) + or not isinstance(arg, Pack) + and isinstance(arg.buffer, PackedBuffer) + ): + if isinstance(arg, Pack): + myarg = arg.big + else: + myarg = arg + + # TODO add PackedPetscMat as a subclass of buffer? + if not isinstance(myarg.buffer.array, PetscMat): + raise NotImplementedError("Only handle Mat at the moment") + + axes = AxisTree(myarg.axes.parent_to_children) + new_arg = HierarchicalArray( + axes, + data=NullBuffer(myarg.dtype), # does this need a size? + prefix="t", + ) + + if intent == READ: + gathers.append(PetscMatLoad(myarg, new_arg)) + elif intent == WRITE: + scatters.insert(0, PetscMatStore(myarg, new_arg)) + elif intent == RW: + gathers.append(PetscMatLoad(myarg, new_arg)) + scatters.insert(0, PetscMatStore(myarg, new_arg)) + else: + assert intent == INC + gathers.append(ReplaceAssignment(new_arg, 0)) + scatters.insert(0, PetscMatAdd(myarg, new_arg)) + + # the rest of the packing code is now dealing with the result of this + # function call + arg = new_arg + + # unpick pack/unpack instructions + if intent != NA and _requires_pack_unpack(arg): + if isinstance(arg, Pack): + temporary = arg.small + arg = arg.big + else: + axes = AxisTree(arg.axes.parent_to_children) + temporary = HierarchicalArray( + axes, + data=NullBuffer(arg.dtype), # does this need a size? + prefix="t", + ) + + if intent == READ: + gathers.append(ReplaceAssignment(temporary, arg)) + elif intent == WRITE: + gathers.append(ReplaceAssignment(temporary, 0)) + scatters.insert(0, ReplaceAssignment(arg, temporary)) + elif intent == RW: + gathers.append(ReplaceAssignment(temporary, arg)) + scatters.insert(0, ReplaceAssignment(arg, temporary)) + else: + assert intent == INC + gathers.append(ReplaceAssignment(temporary, 0)) + scatters.insert(0, AddAssignment(arg, temporary)) + + arguments.append(temporary) + + else: + arguments.append(arg) + + return (*gathers, terminal.with_arguments(arguments), *scatters) + + +# TODO check this docstring renders correctly +def expand_implicit_pack_unpack(expr: Instruction): + """Expand implicit pack and unpack operations. + + An implicit pack/unpack is something of the form + + .. code:: + kernel(dat[f(p)]) + + In order for this to work the ``dat[f(p)]`` needs to be packed + into a temporary. 
Assuming that its intent in ``kernel`` is + `pyop3.WRITE`, we would expand this function into + + .. code:: + tmp <- [0, 0, ...] + kernel(tmp) + dat[f(p)] <- tmp + + Notes + ----- + For this routine to work, any context-sensitive loops must have + been expanded already (with `expand_loop_contexts`). This is + because context-sensitive arrays may be packed into temporaries + in some contexts but not others. + + """ + return ImplicitPackUnpackExpander().apply(expr) + + +def _requires_pack_unpack(arg): + # TODO in theory packing isn't required for arrays that are contiguous, + # but this is hard to determine + # FIXME, we inefficiently copy matrix temporaries here because this + # doesn't identify requiring pack/unpack properly. To demonstrate + # kernel(mat[p, q]) + # gets turned into + # t0 <- mat[p, q] + # kernel(t0) + # However, the array mat[p, q] is actually retrieved from MatGetValues + # so we really have something like + # MatGetValues(mat, ..., t0) + # t1 <- t0 + # kernel(t1) + # and the same for unpacking + + # if subst_layouts and layouts are the same I *think* it is safe to avoid a pack/unpack + # however, it is overly restrictive since we could pass something like dat[i0, :] directly + # to a local kernel + # return isinstance(arg, HierarchicalArray) and arg.subst_layouts != arg.layouts + return isinstance(arg, HierarchicalArray) or isinstance(arg, Pack) + + +# *below is old untested code* +# +# def compress(iterset, map_func, *, uniquify=False): +# # TODO Ultimately we should be able to generate code for this set of +# # loops. We would need to have a construct to describe "unique packing" +# # with hash sets like we do in the Python version. PETSc have PetscHSetI +# # which I think would be suitable. +# +# if not uniquify: +# raise NotImplementedError("TODO") +# +# iterset = iterset.as_tree() +# +# # prepare size arrays, we want an array per target path per iterset path +# sizess = {} +# for leaf_axis, leaf_clabel in iterset.leaves: +# iterset_path = iterset.path(leaf_axis, leaf_clabel) +# +# # bit unpleasant to have to create a loop index for this +# sizes = {} +# index = iterset.index() +# cf_map = map_func(index).with_context({index.id: iterset_path}) +# for target_path in cf_map.leaf_target_paths: +# if iterset.depth != 1: +# # TODO For now we assume iterset to have depth 1 +# raise NotImplementedError +# # The axes of the size array correspond only to the specific +# # components selected from iterset by iterset_path. 
+# clabels = (just_one(iterset_path.values()),) +# subiterset = iterset[clabels] +# +# # subiterset is an axis tree with depth 1, we only want the axis +# assert subiterset.depth == 1 +# subiterset = subiterset.root +# +# sizes[target_path] = HierarchicalArray( +# subiterset, dtype=IntType, prefix="nnz" +# ) +# sizess[iterset_path] = sizes +# sizess = freeze(sizess) +# +# # count sizes +# for p in iterset.iter(): +# entries = collections.defaultdict(set) +# for q in map_func(p.index).iter({p}): +# # we expect maps to only output a single target index +# q_value = just_one(q.target_exprs.values()) +# entries[q.target_path].add(q_value) +# +# for target_path, points in entries.items(): +# npoints = len(points) +# nnz = sizess[p.source_path][target_path] +# nnz.set_value(p.source_path, p.source_exprs, npoints) +# +# # prepare map arrays +# flat_mapss = {} +# for iterset_path, sizes in sizess.items(): +# flat_maps = {} +# for target_path, nnz in sizes.items(): +# subiterset = nnz.axes.root +# map_axes = AxisTree.from_nest({subiterset: Axis(nnz)}) +# flat_maps[target_path] = HierarchicalArray( +# map_axes, dtype=IntType, prefix="map" +# ) +# flat_mapss[iterset_path] = flat_maps +# flat_mapss = freeze(flat_mapss) +# +# # populate compressed maps +# for p in iterset.iter(): +# entries = collections.defaultdict(set) +# for q in map_func(p.index).iter({p}): +# # we expect maps to only output a single target index +# q_value = just_one(q.target_exprs.values()) +# entries[q.target_path].add(q_value) +# +# for target_path, points in entries.items(): +# flat_map = flat_mapss[p.source_path][target_path] +# leaf_axis, leaf_clabel = flat_map.axes.leaf +# for i, pt in enumerate(sorted(points)): +# path = p.source_path | {leaf_axis.label: leaf_clabel} +# indices = p.source_exprs | {leaf_axis.label: i} +# flat_map.set_value(path, indices, pt) +# +# # build the actual map +# connectivity = {} +# for iterset_path, flat_maps in flat_mapss.items(): +# map_components = [] +# for target_path, flat_map in flat_maps.items(): +# # since maps only target a single axis, component pair +# target_axlabel, target_clabel = just_one(target_path.items()) +# map_component = TabulatedMapComponent( +# target_axlabel, target_clabel, flat_map +# ) +# map_components.append(map_component) +# connectivity[iterset_path] = map_components +# return Map(connectivity) +# +# +# def split_loop(loop: Loop, path, tile_size: int) -> Loop: +# orig_loop_index = loop.index +# +# # I think I need to transform the index expressions of the iterset? +# # or get a new iterset? let's try that +# # It will not work because then the target path would change and the +# # data structures would not know what to do. +# +# orig_index_exprs = orig_loop_index.index_exprs +# breakpoint() +# # new_index_exprs +# +# new_loop_index = orig_loop_index.copy(index_exprs=new_index_exprs) +# return loop.copy(index=new_loop_index) diff --git a/pyop3/transforms.py b/pyop3/transforms.py deleted file mode 100644 index 9cc57365..00000000 --- a/pyop3/transforms.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - - -def split_loop(loop: Loop, path, tile_size: int) -> Loop: - orig_loop_index = loop.index - - # I think I need to transform the index expressions of the iterset? - # or get a new iterset? let's try that - # It will not work because then the target path would change and the - # data structures would not know what to do. 
- - orig_index_exprs = orig_loop_index.index_exprs - breakpoint() - # new_index_exprs - - new_loop_index = orig_loop_index.copy(index_exprs=new_index_exprs) - return loop.copy(index=new_loop_index) diff --git a/pyop3/tree.py b/pyop3/tree.py index 416648a9..98bb6d4f 100644 --- a/pyop3/tree.py +++ b/pyop3/tree.py @@ -3,6 +3,8 @@ import abc import collections import functools +import operator +from collections import defaultdict from collections.abc import Hashable, Sequence from functools import cached_property from itertools import chain @@ -17,6 +19,7 @@ Identified, Label, Labelled, + UniqueNameGenerator, apply_at, as_tuple, checked_zip, @@ -39,6 +42,10 @@ class EmptyTreeException(Exception): pass +class InvalidTreeException(ValueError): + pass + + class Node(pytools.ImmutableRecord, Identified): fields = {"id"} @@ -47,6 +54,7 @@ def __init__(self, id=None): Identified.__init__(self, id) +# TODO delete this class, no longer different tree types class AbstractTree(pytools.ImmutableRecord, abc.ABC): fields = {"parent_to_children"} @@ -100,15 +108,19 @@ def id_to_node(self): @cached_property def nodes(self): + # NOTE: Keep this sorted! Else strange results occur if self.is_empty: - return frozenset() - return frozenset( - { - node - for node in chain.from_iterable(self.parent_to_children.values()) - if node is not None - } - ) + return () + return self._collect_nodes(self.root) + + def _collect_nodes(self, node): + assert not self.is_empty + nodes = [node] + for subnode in self.children(node): + if subnode is None: + continue + nodes.extend(self._collect_nodes(subnode)) + return tuple(nodes) @property @abc.abstractmethod @@ -205,71 +217,6 @@ def _as_node_id(node): return node.id if isinstance(node, Node) else node -class Tree(AbstractTree): - @cached_property - def leaves(self): - return tuple( - node - for node in self.nodes - if all(c is None for c in self.parent_to_children.get(node.id, ())) - ) - - def add_node( - self, - node, - parent=None, - uniquify=False, - ): - if parent is None: - if not self.is_empty: - raise ValueError("Cannot add multiple roots") - return self.copy(parent_to_children={None: (node,)}) - else: - parent = self._as_node(parent) - if node in self: - if uniquify: - node = node.copy(id=node.unique_id()) - else: - raise ValueError("Cannot insert a node with the same ID") - - parent_to_children = { - k: list(v) for k, v in self.parent_to_children.items() - } - parent_to_children[parent.id].append(node) - # missing root, not used I think - raise NotImplementedError - return self.copy(parent_to_children=parent_to_children) - - @classmethod - def _from_nest(cls, nest): - # TODO add appropriate exception classes - if isinstance(nest, collections.abc.Mapping): - assert len(nest) == 1 - node, subnodes = just_one(nest.items()) - node = cls._parse_node(node) - - if isinstance(subnodes, collections.abc.Mapping): - if len(subnodes) == 1 and isinstance(just_one(subnodes.keys()), Node): - # just one subnode - subnodes = [subnodes] - else: - raise ValueError - elif not isinstance(subnodes, collections.abc.Sequence): - subnodes = [subnodes] - - children = [] - parent_to_children = {} - for subnode in subnodes: - subnode_, sub_p2c = cls._from_nest(subnode) - children.append(subnode_) - parent_to_children.update(sub_p2c) - parent_to_children[node.id] = children - return node, parent_to_children - else: - node = cls._parse_node(nest) - return node, {} - - class LabelledNodeComponent(pytools.ImmutableRecord, Labelled): fields = {"label"} @@ -279,24 +226,24 @@ def __init__(self, 
label=None): class MultiComponentLabelledNode(Node, Labelled): - fields = Node.fields | {"components", "label"} + fields = Node.fields | {"label"} - def __init__(self, components, label=None, *, id=None): + def __init__(self, label=None, *, id=None): Node.__init__(self, id) Labelled.__init__(self, label) - self.components = as_tuple(components) @property def degree(self) -> int: - return len(self.components) + return len(self.component_labels) @property + @abc.abstractmethod def component_labels(self): - return tuple(c.label for c in self.components) + pass @property - def component(self): - return just_one(self.components) + def component_label(self): + return just_one(self.component_labels) class LabelledTree(AbstractTree): @@ -305,7 +252,7 @@ def component_child(self, parent, component): return self.child(parent, component) def child(self, parent, component): - clabel = self._as_component_label(component) + clabel = as_component_label(component) cidx = parent.component_labels.index(clabel) try: return self.parent_to_children[parent.id][cidx] @@ -314,12 +261,22 @@ def child(self, parent, component): @cached_property def leaves(self): - return tuple( - (node, cpt) - for node in self.nodes - for cidx, cpt in enumerate(node.components) - if self.parent_to_children.get(node.id, [None] * node.degree)[cidx] is None - ) + # NOTE: ordered!! + if self.is_empty: + return () + else: + return self._collect_leaves(self.root) + + def _collect_leaves(self, node): + assert not self.is_empty + leaves = [] + for clabel in node.component_labels: + subnode = self.child(node, clabel) + if subnode: + leaves.extend(self._collect_leaves(subnode)) + else: + leaves.append((node, clabel)) + return tuple(leaves) def add_node( self, @@ -342,7 +299,7 @@ def add_node( "Must specify a component for parents with multiple components" ) else: - parent_cpt_label = parent_component + parent_cpt_label = as_component_label(parent_component) cpt_index = parent.component_labels.index(parent_cpt_label) @@ -378,13 +335,13 @@ def with_modified_component(self, node, component, **kwargs): node, node.with_modified_component(component, **kwargs) ) - # invalid for frozen trees def add_subtree( self, subtree, parent=None, component=None, uniquify: bool = False, + uniquify_ids=False, ): """ Parameters @@ -395,16 +352,17 @@ def add_subtree( If ``False``, duplicate ``ids`` between the tree and subtree will raise an exception. If ``True``, the ``ids`` will be changed to avoid the clash. + Also fixes node labels. - Notes - ----- - This function returns a parent-to-children mapping instead of a new tree - because it is non-trivial to unpick the impact of adding new nodes to the - tree. For example a new star forest may need to be computed. It, for now, - is preferable to make trees as "immutable as possible". """ + # FIXME bad API, uniquify implies uniquify labels only + # There are cases where the labels should be distinct but IDs may clash + # e.g. 
adding subaxes for a matrix + if uniquify_ids: + assert not uniquify + if uniquify: - raise NotImplementedError("TODO") + uniquify_ids = True if some_but_not_all([parent, component]): raise ValueError( @@ -414,16 +372,66 @@ def add_subtree( if not parent: raise NotImplementedError("TODO") + if subtree.is_empty: + return self + assert isinstance(parent, MultiComponentLabelledNode) - cidx = parent.component_labels.index(component.label) + clabel = as_component_label(component) + cidx = parent.component_labels.index(clabel) parent_to_children = {p: list(ch) for p, ch in self.parent_to_children.items()} - sub_p2c = dict(subtree.parent_to_children) + sub_p2c = {p: list(ch) for p, ch in subtree.parent_to_children.items()} + if uniquify_ids: + self._uniquify_node_ids(sub_p2c, set(parent_to_children.keys())) + assert ( + len(set(sub_p2c.keys()) & set(parent_to_children.keys()) - {None}) == 0 + ) + subroot = just_one(sub_p2c.pop(None)) parent_to_children[parent.id][cidx] = subroot parent_to_children.update(sub_p2c) + + if uniquify: + self._uniquify_node_labels(parent_to_children) + return self.copy(parent_to_children=parent_to_children) + def _uniquify_node_labels(self, node_map, node=None, seen_labels=None): + if not node_map: + return + + if node is None: + node = just_one(node_map[None]) + seen_labels = frozenset({node.label}) + + for i, subnode in enumerate(node_map.get(node.id, [])): + if subnode is None: + continue + if subnode.label in seen_labels: + new_label = UniqueNameGenerator(set(seen_labels))(subnode.label) + assert new_label not in seen_labels + subnode = subnode.copy(label=new_label) + node_map[node.id][i] = subnode + self._uniquify_node_labels(node_map, subnode, seen_labels | {subnode.label}) + + # do as a traversal since there is an ordering constraint in how we replace IDs + def _uniquify_node_ids(self, node_map, existing_ids, node=None): + if not node_map: + return + + node_id = node.id if node is not None else None + for i, subnode in enumerate(node_map.get(node_id, [])): + if subnode is None: + continue + if subnode.id in existing_ids: + new_id = subnode.unique_id() + assert new_id not in existing_ids + existing_ids.add(new_id) + new_subnode = subnode.copy(id=new_id) + node_map[node_id][i] = new_subnode + node_map[new_id] = node_map.pop(subnode.id) + self._uniquify_node_ids(node_map, existing_ids, new_subnode) + @cached_property def _paths(self): def paths_fn(node, component_label, current_path): @@ -463,7 +471,7 @@ def ancestors(self, node, component_label): ) def path(self, node, component, ordered=False): - clabel = self._as_component_label(component) + clabel = as_component_label(component) node_id = self._as_node_id(node) path_ = self._paths[node_id, clabel] if ordered: @@ -474,7 +482,7 @@ def path(self, node, component, ordered=False): def path_with_nodes( self, node, component_label, ordered=False, and_components=False ): - component_label = self._as_component_label(component_label) + component_label = as_component_label(component_label) node_id = self._as_node_id(node) path_ = self._paths_with_nodes[node_id, component_label] if and_components: @@ -487,6 +495,18 @@ def path_with_nodes( else: return pmap(path_) + @cached_property + def leaf_paths(self): + return tuple(self.path(*leaf) for leaf in self.leaves) + + @cached_property + def ordered_leaf_paths(self): + return tuple(self.path(*leaf, ordered=True) for leaf in self.leaves) + + @cached_property + def ordered_leaf_paths_with_nodes(self): + return tuple(self.path_with_nodes(*leaf, ordered=True) for leaf in 
self.leaves)
+
     def _node_from_path(self, path):
         if not path:
             return None
@@ -515,13 +535,24 @@ def detailed_path(self, path):
         else:
             return self.path_with_nodes(*node, and_components=True)
 
-    # this method is crap, if it fails I don't get any useful feedback!
-    def is_valid_path(self, path):
-        try:
-            self._node_from_path(path)
-            return True
-        except:
-            return False
+    def is_valid_path(self, path, complete=True, leaf=False):
+        if leaf:
+            all_paths = [set(self.path(node, cpt).items()) for node, cpt in self.leaves]
+        else:
+            all_paths = [
+                set(self.path(node, cpt).items())
+                for node in self.nodes
+                for cpt in node.component_labels
+            ]
+
+        path_set = set(path.items())
+
+        compare = operator.eq if complete else operator.le
+
+        for path_ in all_paths:
+            if compare(path_set, path_):
+                return True
+        return False
 
     def find_component(self, node_label, cpt_label, also_node=False):
         """Return the first component in the tree matching the given labels.
@@ -626,12 +657,12 @@ def _parse_node(node):
         else:
             raise TypeError(f"No handler defined for {type(node).__name__}")
 
-    @staticmethod
-    def _as_component_label(component):
-        if isinstance(component, LabelledNodeComponent):
-            return component.label
-        else:
-            return component
+
+def as_component_label(component):
+    if isinstance(component, LabelledNodeComponent):
+        return component.label
+    else:
+        return component
 
 
 def previsit(
diff --git a/pyop3/utils.py b/pyop3/utils.py
index 7ebcfdce..95722eb5 100644
--- a/pyop3/utils.py
+++ b/pyop3/utils.py
@@ -1,8 +1,6 @@
 import abc
 import collections
-import functools
 import itertools
-import operator
 import warnings
 from typing import Any, Collection, Hashable, Optional
 
@@ -10,6 +8,8 @@
 import pytools
 from pyrsistent import pmap
 
+from pyop3.config import config
+
 
 class UniqueNameGenerator(pytools.UniqueNameGenerator):
     """Class for generating unique names."""
@@ -28,6 +28,10 @@ def unique_name(prefix: str) -> str:
     return _unique_name_generator(prefix)
 
 
+class auto:
+    pass
+
+
 # type aliases
 Id = Hashable
 Label = Hashable
@@ -51,6 +55,15 @@ def unique_label(cls) -> str:
     return unique_name(f"_label_{cls.__name__}")
 
 
+# TODO is Identified really useful?
+class UniqueRecord(pytools.ImmutableRecord, Identified):
+    fields = {"id"}
+
+    def __init__(self, id=None):
+        pytools.ImmutableRecord.__init__(self)
+        Identified.__init__(self, id)
+
+
 def as_tuple(item):
     if isinstance(item, collections.abc.Sequence):
         return tuple(item)
@@ -195,6 +208,31 @@ def popwhen(predicate, iterable):
     raise KeyError("Predicate does not hold for any items in iterable")
 
 
+def steps(sizes, drop_last=False):
+    sizes = tuple(sizes)
+    steps_ = (0,) + tuple(np.cumsum(sizes, dtype=int))
+    return steps_[:-1] if drop_last else steps_
+
+
+def pairwise(iterable):
+    # use itertools.tee so this works for arbitrary iterables, not just
+    # sliceable sequences
+    a, b = itertools.tee(iterable)
+    next(b, None)
+    return zip(a, b)
+
+
+# stolen from stackoverflow
+# https://stackoverflow.com/questions/11649577/how-to-invert-a-permutation-array-in-numpy
+def invert(p):
+    """Return an array s with which np.array_equal(arr[p][s], arr) is True.
+    The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
+    """
+    p = np.asanyarray(p)  # in case p is a tuple, etc.
+ s = np.empty_like(p) + s[p] = np.arange(p.size) + return s + + def strict_cast(obj, cast): new_obj = cast(obj) if new_obj != obj: @@ -285,3 +320,11 @@ def frozen_record(cls): raise TypeError("frozen_record is only valid for subclasses of pytools.Record") cls.copy = _disabled_record_copy return cls + + +def debug_assert(predicate, msg=None): + if config["debug"]: + if msg: + assert predicate(), msg + else: + assert predicate() diff --git a/pyproject.toml b/pyproject.toml index b98b8dc5..14831c41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,8 @@ dependencies = [ dev = [ "black", "isort", -] -test = [ "pytest", + "pytest-timeout", "pytest-mpi @ git+https://github.com/firedrakeproject/pytest-mpi", ] @@ -32,3 +31,4 @@ profile = "black" testpaths = [ "tests", ] +timeout = "300" diff --git a/tests/conftest.py b/tests/conftest.py index 668fa5bc..da7cdda6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +import numbers + +import loopy as lp import pytest from mpi4py import MPI from petsc4py import PETSc @@ -62,3 +65,54 @@ def paxis(comm, sf): numbering = [0, 4, 1, 2, 5, 3] serial = op3.Axis(6, numbering=numbering) return op3.Axis.from_serial(serial, sf) + + +class Helper: + @classmethod + def copy_kernel(cls, shape, dtype=op3.ScalarType): + inames = cls._inames_from_shape(shape) + inames_str = ",".join(inames) + insn = f"y[{inames_str}] = x[{inames_str}]" + + lpy_kernel = cls._loopy_kernel(shape, insn, dtype) + return op3.Function(lpy_kernel, [op3.READ, op3.WRITE]) + + @classmethod + def inc_kernel(cls, shape, dtype=op3.ScalarType): + inames = cls._inames_from_shape(shape) + inames_str = ",".join(inames) + insn = f"y[{inames_str}] = y[{inames_str}] + x[{inames_str}]" + + lpy_kernel = cls._loopy_kernel(shape, insn, dtype) + return op3.Function(lpy_kernel, [op3.READ, op3.INC]) + + @classmethod + def _inames_from_shape(cls, shape): + if isinstance(shape, numbers.Number): + shape = (shape,) + return tuple(f"i_{i}" for i, _ in enumerate(shape)) + + @classmethod + def _loopy_kernel(cls, shape, insns, dtype): + if isinstance(shape, numbers.Number): + shape = (shape,) + + inames = cls._inames_from_shape(shape) + domains = tuple( + f"{{ [{iname}]: 0 <= {iname} < {s} }}" for iname, s in zip(inames, shape) + ) + return lp.make_kernel( + domains, + insns, + [ + lp.GlobalArg("x", shape=shape, dtype=dtype), + lp.GlobalArg("y", shape=shape, dtype=dtype), + ], + target=op3.ir.LOOPY_TARGET, + lang_version=op3.ir.LOOPY_LANG_VERSION, + ) + + +@pytest.fixture(scope="session") +def factory(): + return Helper() diff --git a/tests/integration/test_assign.py b/tests/integration/test_assign.py new file mode 100644 index 00000000..dd424c8c --- /dev/null +++ b/tests/integration/test_assign.py @@ -0,0 +1,19 @@ +import pytest + +import pyop3 as op3 + + +@pytest.mark.parametrize("mode", ["scalar", "vector"]) +def test_assign_number(mode): + root = op3.Axis(5) + if mode == "scalar": + axes = op3.AxisTree(root) + else: + assert mode == "vector" + axes = op3.AxisTree.from_nest({root: op3.Axis(3)}) + + dat = op3.HierarchicalArray(axes, dtype=op3.IntType) + assert (dat.data_ro == 0).all() + + op3.do_loop(p := root.index(), dat[p].assign(666)) + assert (dat.data_ro == 666).all() diff --git a/tests/integration/test_axis_ordering.py b/tests/integration/test_axis_ordering.py index df2d4867..38be5852 100644 --- a/tests/integration/test_axis_ordering.py +++ b/tests/integration/test_axis_ordering.py @@ -1,14 +1,8 @@ -import ctypes - import loopy as lp import numpy as np -import pymbolic as pym 
-import pytest from pyrsistent import pmap import pyop3 as op3 -from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET -from pyop3.utils import just_one def test_different_axis_orderings_do_not_change_packing_order(): @@ -23,8 +17,8 @@ def test_different_axis_orderings_do_not_change_packing_order(): lp.GlobalArg("y", op3.ScalarType, (m1, m2), is_input=False, is_output=True), ], name="copy", - target=LOOPY_TARGET, - lang_version=LOOPY_LANG_VERSION, + target=op3.ir.LOOPY_TARGET, + lang_version=op3.ir.LOOPY_LANG_VERSION, ) copy_kernel = op3.Function(lpy_kernel, [op3.READ, op3.WRITE]) @@ -51,16 +45,16 @@ def test_different_axis_orderings_do_not_change_packing_order(): p = axis0.index() path = pmap({axis0.label: axis0.component.label}) - loop_context = pmap({p.id: path}) + loop_context = pmap({p.id: (path, path)}) + cf_p = p.with_context(loop_context) slice0 = op3.Slice(axis1.label, [op3.AffineSliceComponent(axis1.component.label)]) slice1 = op3.Slice(axis2.label, [op3.AffineSliceComponent(axis2.component.label)]) q = op3.IndexTree( { - None: (p,), - p.id: (slice0,), + None: (cf_p,), + cf_p.id: (slice0,), slice0.id: (slice1,), }, - loop_context=loop_context, ) op3.do_loop(p, copy_kernel(dat0_0[q], dat1[q])) diff --git a/tests/integration/test_basics.py b/tests/integration/test_basics.py index e9c29109..29cd0973 100644 --- a/tests/integration/test_basics.py +++ b/tests/integration/test_basics.py @@ -6,22 +6,6 @@ from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET -@pytest.fixture -def scalar_copy_kernel(): - code = lp.make_kernel( - "{ [i]: 0 <= i < 1 }", - "y[i] = x[i]", - [ - lp.GlobalArg("x", op3.ScalarType, (1,), is_input=True, is_output=False), - lp.GlobalArg("y", op3.ScalarType, (1,), is_input=False, is_output=True), - ], - target=LOOPY_TARGET, - name="scalar_copy", - lang_version=(2018, 2), - ) - return op3.Function(code, [op3.READ, op3.WRITE]) - - @pytest.fixture def vector_copy_kernel(): code = lp.make_kernel( @@ -38,11 +22,11 @@ def vector_copy_kernel(): return op3.Function(code, [op3.READ, op3.WRITE]) -def test_scalar_copy(scalar_copy_kernel): +def test_scalar_copy(factory): m = 10 axis = op3.Axis(m) dat0 = op3.HierarchicalArray( - axis, name="dat0", data=np.arange(axis.size), dtype=op3.ScalarType + axis, name="dat0", data=np.arange(axis.size, dtype=op3.ScalarType) ) dat1 = op3.HierarchicalArray( axis, @@ -50,7 +34,10 @@ def test_scalar_copy(scalar_copy_kernel): dtype=dat0.dtype, ) - op3.do_loop(p := axis.index(), scalar_copy_kernel(dat0[p], dat1[p])) + kernel = factory.copy_kernel(1) + # op3.do_loop(p := axis.index(), kernel(dat0[p], dat1[p])) + loop = op3.loop(p := axis.index(), kernel(dat0[p], dat1[p])) + loop() assert np.allclose(dat1.data, dat0.data) @@ -80,7 +67,7 @@ def test_multi_component_vector_copy(vector_copy_kernel): dat0 = op3.HierarchicalArray( axes, name="dat0", - data=np.arange(m * a + n * b), + data=np.arange(axes.size), dtype=op3.ScalarType, ) dat1 = op3.HierarchicalArray( @@ -94,26 +81,37 @@ def test_multi_component_vector_copy(vector_copy_kernel): vector_copy_kernel(dat0[p, :], dat1[p, :]), ) - assert all(dat1.data[: m * a] == 0) - assert all(dat1.data[m * a :] == dat0.data[m * a :]) + assert (dat1.data[: m * a] == 0).all() + assert (dat1.data[m * a :] == dat0.data[m * a :]).all() def test_copy_multi_component_temporary(vector_copy_kernel): m = 4 n0, n1 = 2, 1 - npoints = m * n0 + m * n1 - axes = op3.AxisTree.from_nest({op3.Axis(m): op3.Axis([n0, n1])}) + axes = op3.AxisTree.from_nest( + {op3.Axis(m): op3.Axis({"pt0": n0, "pt1": n1}, "ax1")} + ) dat0 = 
op3.HierarchicalArray(
-        axes, name="dat0", data=np.arange(npoints), dtype=op3.ScalarType
+        axes,
+        name="dat0",
+        data=np.arange(axes.size, dtype=op3.ScalarType),
     )
     dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype)
 
-    op3.do_loop(p := axes.root.index(), vector_copy_kernel(dat0[p, :], dat1[p, :]))
+    # An explicit slice object is required because typical slice notation ":" is
+    # ambiguous when there are multiple components that might be getting sliced.
+    slice_ = op3.Slice(
+        "ax1", [op3.AffineSliceComponent("pt0"), op3.AffineSliceComponent("pt1")]
+    )
+
+    op3.do_loop(
+        p := axes.root.index(), vector_copy_kernel(dat0[p, slice_], dat1[p, slice_])
+    )
 
     assert np.allclose(dat1.data, dat0.data)
 
 
-def test_multi_component_scalar_copy_with_two_outer_loops(scalar_copy_kernel):
+def test_multi_component_scalar_copy_with_two_outer_loops(factory):
     m, n, a, b = 8, 6, 2, 3
 
     axes = op3.AxisTree.from_nest(
@@ -129,6 +127,7 @@
     )
     dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype)
 
-    op3.do_loop(p := axes["pt1", :].index(), scalar_copy_kernel(dat0[p], dat1[p]))
+    kernel = factory.copy_kernel(1)
+    op3.do_loop(p := axes["pt1", :].index(), kernel(dat0[p], dat1[p]))
     assert all(dat1.data[: m * a] == 0)
     assert all(dat1.data[m * a :] == dat0.data[m * a :])
diff --git a/tests/integration/test_codegen.py b/tests/integration/test_codegen.py
new file mode 100644
index 00000000..d1b8eb84
--- /dev/null
+++ b/tests/integration/test_codegen.py
@@ -0,0 +1,53 @@
+import loopy as lp
+
+import pyop3 as op3
+from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET
+
+
+def test_dummy_arguments():
+    kernel = op3.Function(
+        lp.make_kernel(
+            "{ [i]: 0 <= i < 1 }",
+            [lp.CInstruction((), "y[0] = x[0];", read_variables=frozenset({"x", "y"}))],
+            [
+                lp.ValueArg("x", dtype=lp.types.OpaqueType("double*")),
+                lp.ValueArg("y", dtype=lp.types.OpaqueType("double*")),
+            ],
+            name="subkernel",
+            target=LOOPY_TARGET,
+            lang_version=LOOPY_LANG_VERSION,
+        ),
+        [op3.NA, op3.NA],
+    )
+    # ccode = lp.generate_code_v2(kernel.code)
+    # breakpoint()
+    called_kernel = kernel(op3.DummyKernelArgument(), op3.DummyKernelArgument())
+
+    code = op3.ir.lower.compile(called_kernel, name="dummy_kernel")
+    ccode = lp.generate_code_v2(code.ir).device_code()
+
+    # TODO validate that the generated code is correct; at the time of writing
+    # it merely looks right
+
+
+def test_external_loop_index_is_passed_as_kernel_argument():
+    kernel = op3.Function(
+        lp.make_kernel(
+            "{ [i]: 0 <= i < 1 }",
+            "x[0] = 666",
+            [lp.GlobalArg("x", shape=(1,), dtype=op3.IntType)],
+            target=LOOPY_TARGET,
+            lang_version=LOOPY_LANG_VERSION,
+        ),
+        [op3.WRITE],
+    )
+
+    axes = op3.AxisTree.from_iterable((5,))
+    dat = op3.HierarchicalArray(axes, dtype=op3.IntType)
+    index = axes.index()
+    called_kernel = kernel(dat[index])
+
+    lp_code = op3.ir.lower.compile(called_kernel, name="kernel")
+    c_code = lp.generate_code_v2(lp_code.ir).device_code()
+
+    # assert False, "check result"
diff --git a/tests/integration/test_constants.py b/tests/integration/test_constants.py
index 77c00e4e..2cfef683 100644
--- a/tests/integration/test_constants.py
+++ b/tests/integration/test_constants.py
@@ -5,7 +5,6 @@
 from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET
 
 
-# spelling?
def test_loop_over_parametrised_length(scalar_copy_kernel): length = op3.HierarchicalArray(op3.AxisTree(), dtype=int) iter_axes = op3.Axis([op3.AxisComponent(length, "pt0")], "ax0") diff --git a/tests/integration/test_local_indices.py b/tests/integration/test_local_indices.py index 965e5c6f..ba70d252 100644 --- a/tests/integration/test_local_indices.py +++ b/tests/integration/test_local_indices.py @@ -1,3 +1,4 @@ +# TODO arguably a bad file name/test layout import numpy as np import pytest @@ -28,3 +29,31 @@ def test_copy_slice(scalar_copy_kernel): scalar_copy_kernel(dat0[p], dat1[p.i]), ) assert np.allclose(dat1.data_ro, dat0.data_ro[::2]) + + +@pytest.mark.xfail( + reason="Passing loop indices to the local kernel is not currently supported" +) +def test_pass_loop_index_as_argument(factory): + m = 10 + axes = op3.Axis(m) + dat = op3.HierarchicalArray(axes, dtype=op3.IntType) + + copy_kernel = factory.copy_kernel(1, dtype=dat.dtype) + op3.do_loop(p := axes.index(), copy_kernel(p, dat[p])) + assert (dat.data_ro == list(range(m))).all() + + +@pytest.mark.xfail( + reason="Passing loop indices to the local kernel is not currently supported" +) +def test_pass_multi_component_loop_index_as_argument(factory): + m, n = 10, 12 + axes = op3.Axis([m, n]) + dat = op3.HierarchicalArray(axes, dtype=op3.IntType) + + copy_kernel = factory.copy_kernel(1, dtype=dat.dtype) + op3.do_loop(p := axes.index(), copy_kernel(p, dat[p])) + + expected = list(range(m)) + list(range(n)) + assert (dat.data_ro == expected).all() diff --git a/tests/integration/test_maps.py b/tests/integration/test_maps.py index d8a4d2f7..90e5a883 100644 --- a/tests/integration/test_maps.py +++ b/tests/integration/test_maps.py @@ -1,7 +1,7 @@ import loopy as lp import numpy as np import pytest -from pyrsistent import pmap +from pyrsistent import freeze, pmap import pyop3 as op3 from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET @@ -24,6 +24,23 @@ def vector_inc_kernel(): return op3.Function(lpy_kernel, [op3.READ, op3.INC]) +# TODO make a function not a fixture +@pytest.fixture +def vector2_inc_kernel(): + lpy_kernel = lp.make_kernel( + "{ [i]: 0 <= i < 2 }", + "y[0] = y[0] + x[i]", + [ + lp.GlobalArg("x", op3.ScalarType, (2,), is_input=True, is_output=False), + lp.GlobalArg("y", op3.ScalarType, (1,), is_input=True, is_output=True), + ], + name="vector_inc", + target=LOOPY_TARGET, + lang_version=LOOPY_LANG_VERSION, + ) + return op3.Function(lpy_kernel, [op3.READ, op3.INC]) + + @pytest.fixture def vec2_inc_kernel(): lpy_kernel = lp.make_kernel( @@ -73,7 +90,10 @@ def vec12_inc_kernel(): @pytest.mark.parametrize("nested", [True, False]) -def test_inc_from_tabulated_map(scalar_inc_kernel, vector_inc_kernel, nested): +@pytest.mark.parametrize("indexed", [None, "slice", "subset"]) +def test_inc_from_tabulated_map( + scalar_inc_kernel, vector_inc_kernel, vector2_inc_kernel, nested, indexed +): m, n = 4, 3 map_data = np.asarray([[1, 2, 0], [2, 0, 1], [3, 2, 3], [2, 0, 1]]) @@ -83,13 +103,29 @@ def test_inc_from_tabulated_map(scalar_inc_kernel, vector_inc_kernel, nested): ) dat1 = op3.HierarchicalArray(axis, name="dat1", dtype=dat0.dtype) - map_axes = op3.AxisTree.from_nest({axis: op3.Axis(n)}) + map_axes = op3.AxisTree.from_nest({axis: op3.Axis({"pt0": n}, "ax1")}) map_dat = op3.HierarchicalArray( map_axes, name="map0", data=map_data.flatten(), dtype=op3.IntType, ) + + if indexed == "slice": + map_dat = map_dat[:, 1:3] + kernel = vector2_inc_kernel + elif indexed == "subset": + subset_ = op3.HierarchicalArray( + op3.Axis({"pt0": 2}, "ax1"), + 
name="subset", + data=np.asarray([1, 2]), + dtype=op3.IntType, + ) + map_dat = map_dat[:, subset_] + kernel = vector2_inc_kernel + else: + kernel = vector_inc_kernel + map0 = op3.Map( { pmap({"ax0": "pt0"}): [ @@ -100,17 +136,26 @@ def test_inc_from_tabulated_map(scalar_inc_kernel, vector_inc_kernel, nested): ) if nested: - op3.do_loop( + # op3.do_loop( + loop = op3.loop( p := axis.index(), op3.loop(q := map0(p).index(), scalar_inc_kernel(dat0[q], dat1[p])), ) + loop() else: - op3.do_loop(p := axis.index(), vector_inc_kernel(dat0[map0(p)], dat1[p])) + op3.do_loop(p := axis.index(), kernel(dat0[map0(p)], dat1[p])) expected = np.zeros_like(dat1.data_ro) for i in range(m): - for j in range(n): - expected[i] += dat0.data_ro[map_data[i, j]] + if indexed == "slice": + for j in range(1, 3): + expected[i] += dat0.data_ro[map_data[i, j]] + elif indexed == "subset": + for j in [1, 2]: + expected[i] += dat0.data_ro[map_data[i, j]] + else: + for j in range(n): + expected[i] += dat0.data_ro[map_data[i, j]] assert np.allclose(dat1.data_ro, expected) @@ -175,8 +220,8 @@ def test_inc_with_multiple_maps(vector_inc_kernel): ) dat1 = op3.HierarchicalArray(axis, name="dat1", dtype=dat0.dtype) - map_axes0 = op3.AxisTree.from_nest({axis: op3.Axis(arity0)}) - map_axes1 = op3.AxisTree.from_nest({axis: op3.Axis(arity1)}) + map_axes0 = op3.AxisTree.from_nest({axis: op3.Axis(arity0, "ax1")}) + map_axes1 = op3.AxisTree.from_nest({axis: op3.Axis(arity1, "ax1")}) map_dat0 = op3.HierarchicalArray( map_axes0, @@ -198,7 +243,9 @@ def test_inc_with_multiple_maps(vector_inc_kernel): op3.TabulatedMapComponent("ax0", "pt0", map_dat1), ], }, - "map0", + # FIXME + # "map0", + "ax1", ) op3.do_loop(p := axis.index(), vector_inc_kernel(dat0[map0(p)], dat1[p])) @@ -338,34 +385,238 @@ def test_vector_inc_with_map_composition(vec2_inc_kernel, vec12_inc_kernel, nest assert np.allclose(dat1.data_ro, expected) -@pytest.mark.skip( - reason="Passing ragged arguments through to the local is not yet supported" -) -def test_inc_with_variable_arity_map(ragged_inc_kernel): +def test_partial_map_connectivity(vector2_inc_kernel): + axis = op3.Axis({"pt0": 3}, "ax0") + dat0 = op3.HierarchicalArray(axis, data=np.arange(3, dtype=op3.ScalarType)) + dat1 = op3.HierarchicalArray(axis, dtype=dat0.dtype) + + map_axes = op3.AxisTree.from_nest({axis: op3.Axis(2)}) + map_data = [[0, 1], [2, 0], [2, 2]] + map_array = np.asarray(flatten(map_data), dtype=op3.IntType) + map_dat = op3.HierarchicalArray(map_axes, data=map_array) + + # Some elements of map_ are not present in axis, so should be ignored + map_ = op3.Map( + { + freeze({"ax0": "pt0"}): [ + op3.TabulatedMapComponent("ax0", "pt0", map_dat), + op3.TabulatedMapComponent("not_ax0", "not_pt0", map_dat), + ] + }, + ) + + op3.do_loop(p := axis.index(), vector2_inc_kernel(dat0[map_(p)], dat1[p])) + + expected = np.zeros_like(dat1.data_ro) + for i in range(3): + for j in range(2): + expected[i] += dat0.data_ro[map_data[i][j]] + assert np.allclose(dat1.data_ro, expected) + + +def test_inc_with_variable_arity_map(scalar_inc_kernel): m = 3 - nnzdata = np.asarray([3, 2, 1], dtype=IntType) - mapdata = [[2, 1, 0], [2, 1], [2]] + axis = op3.Axis({"pt0": m}, "ax0") + dat0 = op3.HierarchicalArray( + axis, name="dat0", data=np.arange(axis.size, dtype=op3.ScalarType) + ) + dat1 = op3.HierarchicalArray(axis, name="dat1", dtype=dat0.dtype) + + nnz_data = np.asarray([3, 2, 1], dtype=op3.IntType) + nnz = op3.HierarchicalArray(axis, name="nnz", data=nnz_data, max_value=3) + + map_axes = op3.AxisTree.from_nest({axis: 
op3.Axis(nnz)}) + map_data = [[2, 1, 0], [2, 1], [2]] + map_array = np.asarray(flatten(map_data), dtype=op3.IntType) + map_dat = op3.HierarchicalArray(map_axes, name="map0", data=map_array) + map0 = op3.Map( + {freeze({"ax0": "pt0"}): [op3.TabulatedMapComponent("ax0", "pt0", map_dat)]}, + name="map0", + ) + + op3.do_loop( + p := axis.index(), + op3.loop(q := map0(p).index(), scalar_inc_kernel(dat0[q], dat1[p])), + ) + + expected = np.zeros_like(dat1.data_ro) + for i in range(m): + for j in map_data[i]: + expected[i] += dat0.data_ro[j] + assert np.allclose(dat1.data_ro, expected) + + +@pytest.mark.parametrize("method", ["codegen", "python"]) +def test_loop_over_multiple_ragged_maps(factory, method): + m = 5 + axis = op3.Axis({"pt0": m}, "ax0") + dat0 = op3.HierarchicalArray( + axis, name="dat0", data=np.arange(axis.size, dtype=op3.IntType) + ) + dat1 = op3.HierarchicalArray(axis, name="dat1", dtype=dat0.dtype) - axes = AxisTree(Axis(m, "ax0")) - dat0 = MultiArray(axes, name="dat0", data=np.arange(m, dtype=ScalarType)) - dat1 = MultiArray(axes, name="dat1", data=np.zeros(m, dtype=ScalarType)) + # map0 + nnz0_data = np.asarray([3, 2, 1, 0, 3], dtype=op3.IntType) + nnz0 = op3.HierarchicalArray(axis, name="nnz0", data=nnz0_data) + + map0_axes = op3.AxisTree.from_nest({axis: op3.Axis(nnz0)}) + map0_data = [[2, 4, 0], [3, 3], [1], [], [4, 2, 1]] + map0_array = np.asarray(op3.utils.flatten(map0_data), dtype=op3.IntType) + map0_dat = op3.HierarchicalArray(map0_axes, name="map0", data=map0_array) + map0 = op3.Map( + {freeze({"ax0": "pt0"}): [op3.TabulatedMapComponent("ax0", "pt0", map0_dat)]}, + name="map0", + ) - nnz = MultiArray(axes, name="nnz", data=nnzdata, max_value=3) + # map1 + nnz1_data = np.asarray([2, 0, 3, 1, 2], dtype=op3.IntType) + nnz1 = op3.HierarchicalArray(axis, name="nnz1", data=nnz1_data) - maxes = axes.add_subaxis(Axis(nnz, "ax1"), axes.leaf) - map0 = MultiArray( - maxes, name="map0", data=np.asarray(flatten(mapdata), dtype=IntType) + map1_axes = op3.AxisTree.from_nest({axis: op3.Axis(nnz1)}) + map1_data = [[4, 0], [], [1, 0, 0], [3], [2, 3]] + map1_array = np.asarray(op3.utils.flatten(map1_data), dtype=op3.IntType) + map1_dat = op3.HierarchicalArray(map1_axes, name="map1", data=map1_array) + map1 = op3.Map( + {freeze({"ax0": "pt0"}): [op3.TabulatedMapComponent("ax0", "pt0", map1_dat)]}, + name="map1", ) - p = IndexTree(Index(Range("ax0", m))) - q = p.put_node( - Index(TabulatedMap([("ax0", 0)], [("ax0", 0)], arity=nnz[p], data=map0[p])), - p.leaf, + inc = factory.inc_kernel(1, op3.IntType) + + if method == "codegen": + op3.do_loop( + p := axis.index(), + op3.loop( + q := map1(map0(p)).index(), + inc(dat0[q], dat1[p]), + ), + ) + else: + assert method == "python" + for p in axis.iter(): + for q in map1(map0(p.index)).iter({p}): + prev_val = dat1.get_value(p.target_exprs, p.target_path) + inc = dat0.get_value(q.target_exprs, q.target_path) + dat1.set_value(p.target_exprs, prev_val + inc, p.target_path) + + expected = np.zeros_like(dat1.data_ro) + for i in range(m): + for j in map0_data[i]: + for k in map1_data[j]: + expected[i] += dat0.data_ro[k] + assert (dat1.data_ro == expected).all() + + +@pytest.mark.parametrize("method", ["codegen", "python"]) +def test_loop_over_multiple_multi_component_ragged_maps(factory, method): + m, n = 5, 6 + axis = op3.Axis({"pt0": m, "pt1": n}, "ax0") + dat0 = op3.HierarchicalArray( + axis, name="dat0", data=np.arange(axis.size, dtype=op3.IntType) ) + dat1 = op3.HierarchicalArray(axis, name="dat1", dtype=dat0.dtype) - do_loop(p, 
ragged_inc_kernel(dat0[q], dat1[p])) + # pt0 -> pt0 + nnz00_data = np.asarray([3, 2, 1, 0, 3], dtype=op3.IntType) + nnz00 = op3.HierarchicalArray(axis["pt0"], name="nnz00", data=nnz00_data) + map0_axes0 = op3.AxisTree.from_nest({axis["pt0"].root: op3.Axis(nnz00)}) + map0_data0 = [[2, 4, 0], [3, 3], [1], [], [4, 2, 1]] + map0_array0 = np.asarray(op3.utils.flatten(map0_data0), dtype=op3.IntType) + map0_dat0 = op3.HierarchicalArray(map0_axes0, name="map00", data=map0_array0) + + # pt0 -> pt1 + nnz01_data = np.asarray([1, 2, 1, 0, 4], dtype=op3.IntType) + nnz01 = op3.HierarchicalArray(axis["pt0"], name="nnz01", data=nnz01_data) + map0_axes1 = op3.AxisTree.from_nest({axis["pt0"].root: op3.Axis(nnz01)}) + map0_data1 = [[2], [1, 0], [2], [], [1, 4, 2, 1]] + map0_array1 = np.asarray(op3.utils.flatten(map0_data1), dtype=op3.IntType) + map0_dat1 = op3.HierarchicalArray(map0_axes1, name="map01", data=map0_array1) + + # pt1 -> pt1 (pt1 -> pt0 not implemented) + nnz1_data = np.asarray([2, 2, 1, 3, 0, 2], dtype=op3.IntType) + nnz1 = op3.HierarchicalArray(axis["pt1"], name="nnz1", data=nnz1_data) + map1_axes = op3.AxisTree.from_nest({axis["pt1"].root: op3.Axis(nnz1)}) + map1_data = [[2, 5], [0, 1], [3], [5, 5, 5], [], [2, 1]] + map1_array = np.asarray(op3.utils.flatten(map1_data), dtype=op3.IntType) + map1_dat = op3.HierarchicalArray(map1_axes, name="map1", data=map1_array) + + map_ = op3.Map( + { + freeze({"ax0": "pt0"}): [ + op3.TabulatedMapComponent("ax0", "pt0", map0_dat0), + op3.TabulatedMapComponent("ax0", "pt1", map0_dat1), + ], + freeze({"ax0": "pt1"}): [ + op3.TabulatedMapComponent("ax0", "pt1", map1_dat), + ], + }, + name="map_", + ) - assert np.allclose(dat1.data, [sum(xs) for xs in mapdata]) + inc = factory.inc_kernel(1, op3.IntType) + + if method == "codegen": + op3.do_loop( + p := axis["pt0"].index(), + op3.loop( + q := map_(map_(p)).index(), + inc(dat0[q], dat1[p]), + ), + ) + else: + assert method == "python" + for p in axis["pt0"].iter(): + for q in map_(map_(p.index)).iter({p}): + prev_val = dat1.get_value(p.target_exprs, p.target_path) + inc = dat0.get_value(q.target_exprs, q.target_path) + dat1.set_value(p.target_exprs, prev_val + inc, p.target_path) + + # To see what is going on we can determine the expected result in two + # ways: one pythonically and one equivalent to the generated code. + # We leave both here for reference as they aid in understanding what + # the code is doing. 
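+    # (Worked instance from the data above: for i = 0, map0_data0[0] = [2, 4, 0],
+    # so the pt0 -> pt0 -> pt0 leg gathers map0_data0[2], map0_data0[4] and
+    # map0_data0[0], i.e. the flat dat0 entries [1, 4, 2, 1, 2, 4, 0]. Legs
+    # targeting pt1 offset their indices by m = 5 because the pt1 entries are
+    # stored after the m pt0 entries in the flat data array.)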
+ expected_pythonic = np.zeros_like(dat1.data_ro) + for i in range(m): + # pt0 -> pt0 -> pt0 + for j in map0_data0[i]: + for k in map0_data0[j]: + expected_pythonic[i] += dat0.data_ro[k] + # pt0 -> pt0 -> pt1 + for j in map0_data0[i]: + for k in map0_data1[j]: + # add m since we are targeting pt1 + expected_pythonic[i] += dat0.data_ro[k + m] + # pt0 -> pt1 -> pt1 + for j in map0_data1[i]: + for k in map1_data[j]: + # add m since we are targeting pt1 + expected_pythonic[i] += dat0.data_ro[k + m] + + expected_codegen = np.zeros_like(dat1.data_ro) + for i in range(m): + # pt0 -> pt0 -> pt0 + for j in range(nnz00_data[i]): + map_idx = map0_data0[i][j] + for k in range(nnz00_data[map_idx]): + ptr = map0_data0[map_idx][k] + expected_codegen[i] += dat0.data_ro[ptr] + # pt0 -> pt0 -> pt1 + for j in range(nnz00_data[i]): + map_idx = map0_data0[i][j] + for k in range(nnz01_data[map_idx]): + # add m since we are targeting pt1 + ptr = map0_data1[map_idx][k] + m + expected_codegen[i] += dat0.data_ro[ptr] + # pt0 -> pt1 -> pt1 + for j in range(nnz01_data[i]): + map_idx = map0_data1[i][j] + for k in range(nnz1_data[map_idx]): + # add m since we are targeting pt1 + ptr = map1_data[map_idx][k] + m + expected_codegen[i] += dat0.data_ro[ptr] + + assert (expected_pythonic == expected_codegen).all() + assert (dat1.data_ro == expected_pythonic).all() def test_map_composition(vec2_inc_kernel): @@ -374,6 +625,10 @@ def test_map_composition(vec2_inc_kernel): iterset = op3.Axis({"pt0": 2}, "ax0") dat_axis0 = op3.Axis(10) dat_axis1 = op3.Axis(arity1) + dat0 = op3.HierarchicalArray( + dat_axis0, name="dat0", data=np.arange(dat_axis0.size, dtype=op3.ScalarType) + ) + dat1 = op3.HierarchicalArray(dat_axis1, name="dat1", dtype=dat0.dtype) map_axes0 = op3.AxisTree.from_nest({iterset: op3.Axis(arity0)}) map_data0 = np.asarray([[2, 4, 0], [6, 7, 1]]) @@ -388,9 +643,21 @@ def test_map_composition(vec2_inc_kernel): ), ], }, - "map0", ) + # The labelling for intermediate maps is quite opaque, we use the ID of the + # ContextFreeCalledMap nodes in the index tree. This is so we do not hit any + # conflicts when we compose the same map multiple times. I am unsure how to + # expose this to the user nicely, and this is a use case I do not imagine + # anyone actually wanting, so I am unpicking the right label from the + # intermediate indexed object. 
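+    # (The unpicking below: `with_context` resolves the loop index to a
+    # context-free axis tree whose single node is the intermediate
+    # ContextFreeCalledMap; that node's label and component label are then
+    # used as the target axis/component of map1, since map1 indexes into the
+    # entries of map0.)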
+ p = iterset.index() + indexed_dat0 = dat0[map0(p)] + cf_indexed_dat0 = indexed_dat0.with_context( + {p.id: ({"ax0": "pt0"}, {"ax0": "pt0"})} + ) + called_map_node = op3.utils.just_one(cf_indexed_dat0.axes.nodes) + # this map targets the entries in map0 so it can only contain 0s, 1s and 2s map_axes1 = op3.AxisTree.from_nest({iterset: op3.Axis(arity1)}) map_data1 = np.asarray([[0, 2], [2, 1]]) @@ -400,18 +667,14 @@ def test_map_composition(vec2_inc_kernel): map1 = op3.Map( { pmap({"ax0": "pt0"}): [ - op3.TabulatedMapComponent("map0", "a", map_dat1), + op3.TabulatedMapComponent( + called_map_node.label, called_map_node.component.label, map_dat1 + ), ], }, - "map1", ) - dat0 = op3.HierarchicalArray( - dat_axis0, name="dat0", data=np.arange(dat_axis0.size), dtype=op3.ScalarType - ) - dat1 = op3.HierarchicalArray(dat_axis1, name="dat1", dtype=dat0.dtype) - - op3.do_loop(p := iterset.index(), vec2_inc_kernel(dat0[map0(p)][map1(p)], dat1)) + op3.do_loop(p, vec2_inc_kernel(indexed_dat0[map1(p)], dat1)) expected = np.zeros_like(dat1.data_ro) for i in range(iterset.size): @@ -423,7 +686,8 @@ def test_map_composition(vec2_inc_kernel): assert np.allclose(dat1.data_ro, expected) -def test_recursive_multi_component_maps(): +@pytest.mark.parametrize("method", ["codegen", "python"]) +def test_recursive_multi_component_maps(method): m, n = 5, 6 arity0_0, arity0_1, arity1 = 3, 2, 1 @@ -502,9 +766,17 @@ def test_recursive_multi_component_maps(): target=LOOPY_TARGET, lang_version=LOOPY_LANG_VERSION, ) - sum_kernel = op3.Function(lpy_kernel, [op3.READ, op3.WRITE]) + sum_kernel = op3.Function(lpy_kernel, [op3.READ, op3.INC]) - op3.do_loop(p := axis["pt0"].index(), sum_kernel(dat0[map1(map0(p))], dat1[p])) + if method == "codegen": + op3.do_loop(p := axis["pt0"].index(), sum_kernel(dat0[map1(map0(p))], dat1[p])) + else: + assert method == "python" + for p in axis["pt0"].iter(): + for q in map1(map0(p.index)).iter({p}): + prev_val = dat1.get_value(p.target_exprs, p.target_path) + inc = dat0.get_value(q.target_exprs, q.target_path) + dat1.set_value(p.target_exprs, prev_val + inc, p.target_path) expected = np.zeros_like(dat1.data_ro) for i in range(m): diff --git a/tests/integration/test_nested_loops.py b/tests/integration/test_nested_loops.py index 598a90ff..e5270365 100644 --- a/tests/integration/test_nested_loops.py +++ b/tests/integration/test_nested_loops.py @@ -34,12 +34,14 @@ def test_nested_multi_component_loops(scalar_copy_kernel): axes = op3.AxisTree.from_nest({axis0: [axis1, axis1_dup]}) dat0 = op3.HierarchicalArray( - axes, name="dat0", data=np.arange(axes.size), dtype=op3.ScalarType + axes, name="dat0", data=np.arange(axes.size, dtype=op3.ScalarType) ) dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype) - op3.do_loop( + # op3.do_loop( + loop = op3.loop( p := axis0.index(), op3.loop(q := axis1.index(), scalar_copy_kernel(dat0[p, q], dat1[p, q])), ) + loop() assert np.allclose(dat1.data_ro, dat0.data_ro) diff --git a/tests/integration/test_numbering.py b/tests/integration/test_numbering.py index c9e81594..9b75c6cb 100644 --- a/tests/integration/test_numbering.py +++ b/tests/integration/test_numbering.py @@ -1,14 +1,9 @@ -import ctypes - import loopy as lp import numpy as np -import pymbolic as pym import pytest -from pyrsistent import pmap import pyop3 as op3 from pyop3.ir.lower import LOOPY_LANG_VERSION, LOOPY_TARGET -from pyop3.utils import flatten @pytest.fixture @@ -117,10 +112,14 @@ def test_vector_copy_with_permuted_multi_component_axes(vector_copy_kernel): a, b = 2, 3 numbering = 
[4, 2, 0, 3, 1] - root = op3.Axis({"a": m, "b": n}) + root = op3.Axis({"a": m, "b": n}, "ax0") proot = root.copy(numbering=numbering) - axes = op3.AxisTree.from_nest({root: [op3.Axis(a), op3.Axis(b)]}) - paxes = op3.AxisTree.from_nest({proot: [op3.Axis(a), op3.Axis(b)]}) + axes = op3.AxisTree.from_nest( + {root: [op3.Axis({"pt0": a}, "ax1"), op3.Axis({"pt0": b}, "ax2")]} + ) + paxes = op3.AxisTree.from_nest( + {proot: [op3.Axis({"pt0": a}, "ax1"), op3.Axis({"pt0": b}, "ax2")]} + ) dat0 = op3.HierarchicalArray( axes, name="dat0", data=np.arange(axes.size), dtype=op3.ScalarType @@ -136,22 +135,25 @@ def test_vector_copy_with_permuted_multi_component_axes(vector_copy_kernel): assert not np.allclose(dat1.data_ro, dat0.data_ro) izero = [ - [("a", 0), 0], - [("a", 0), 1], - [("a", 1), 0], - [("a", 1), 1], - [("a", 2), 0], - [("a", 2), 1], + {"ax0": 0, "ax1": 0}, + {"ax0": 0, "ax1": 1}, + {"ax0": 1, "ax1": 0}, + {"ax0": 1, "ax1": 1}, + {"ax0": 2, "ax1": 0}, + {"ax0": 2, "ax1": 1}, ] + path = {"ax0": "a", "ax1": "pt0"} + for ix in izero: + assert np.allclose(dat1.get_value(ix, path), 0.0) + icopied = [ - [("b", 0), 0], - [("b", 0), 1], - [("b", 0), 2], - [("b", 1), 0], - [("b", 1), 1], - [("b", 1), 2], + {"ax0": 0, "ax2": 0}, + {"ax0": 0, "ax2": 1}, + {"ax0": 0, "ax2": 2}, + {"ax0": 1, "ax2": 0}, + {"ax0": 1, "ax2": 1}, + {"ax0": 1, "ax2": 2}, ] - for ix in izero: - assert np.allclose(dat1.get_value(ix), 0.0) + path = {"ax0": "b", "ax2": "pt0"} for ix in icopied: - assert np.allclose(dat1.get_value(ix), dat0.get_value(ix)) + assert np.allclose(dat1.get_value(ix, path), dat0.get_value(ix, path)) diff --git a/tests/integration/test_parallel_loops.py b/tests/integration/test_parallel_loops.py index cff1ea43..26f8fd97 100644 --- a/tests/integration/test_parallel_loops.py +++ b/tests/integration/test_parallel_loops.py @@ -98,63 +98,16 @@ def cone_map(comm, mesh_axis): assert comm.rank == 1 mdata = np.asarray([[4, 5], [5, 6], [6, 7], [7, 8]]) - # NOTES - # Question: - # How does one map from the default component-wise numbering to the - # correct component-wise numbering of the renumbered axis? - # - # Example: - # Given the renumbering [c1, v2, v0, c0, v1], generate the maps from default to - # renumbered (component-wise) points: - # - # {c0: c1, c1: c0}, {v0: v1, v1: v2, v2: v0} - # - # Solution: - # - # The specified numbering is a map from the new numbering to the old. Therefore - # the inverse of this maps from the old numbering to the new. To give an example, - # consider the interval mesh numbering [c1, v2, v0, c0, v1]. With plex numbering - # this becomes [1, 4, 2, 0, 3]. This tells us that point 0 in the new numbering - # corresponds to point 1 in the default numbering, point 1 maps to point 4 and - # so on. For this example, the inverse numbering is [3, 0, 2, 4, 1]. This tells - # us that point 0 in the default numbering maps to point 3 in the new numbering - # and so on. - # Given this map, the final thing to do is map from plex-style numbering to - # the component-wise numbering used in pyop3. We should be able to do this by - # looping over the renumbering (NOT the inverse) and have a counter for each - # component. - - # map default cell numbers to their renumbered equivalents - cell_renumbering = np.empty(ncells, dtype=int) - min_cell, max_cell = mesh_axis._component_numbering_offsets[:2] - counter = 0 - for new_pt, old_pt in enumerate(mesh_axis.numbering.data_ro): - # is it a cell? 
- if min_cell <= old_pt < max_cell: - old_cell = old_pt - min_cell - cell_renumbering[old_cell] = counter - counter += 1 - assert counter == ncells - - # map default vertex numbers to their renumbered equivalents - vert_renumbering = np.empty(nverts, dtype=int) - min_vert, max_vert = mesh_axis._component_numbering_offsets[1:] - counter = 0 - for new_pt, old_pt in enumerate(mesh_axis.numbering.data_ro): - # is it a vertex? - if min_vert <= old_pt < max_vert: - old_vert = old_pt - min_vert - vert_renumbering[old_vert] = counter - counter += 1 - assert counter == nverts - # renumber the map mdata_renum = np.empty_like(mdata) for old_cell in range(ncells): - new_cell = cell_renumbering[old_cell] + # new_cell = cell_renumbering[old_cell] + new_cell = mesh_axis.default_to_applied_component_number("cells", old_cell) for i, old_pt in enumerate(mdata[old_cell]): - old_vert = old_pt - min_vert - mdata_renum[new_cell, i] = vert_renumbering[old_vert] + component, old_vert = mesh_axis.axis_to_component_number(old_pt) + assert component.label == "verts" + new_vert = mesh_axis.default_to_applied_component_number("verts", old_vert) + mdata_renum[new_cell, i] = new_vert mdat = op3.HierarchicalArray(maxes, name="cone", data=mdata_renum.flatten()) return op3.Map( @@ -170,6 +123,7 @@ def cone_map(comm, mesh_axis): @pytest.mark.parallel(nprocs=2) # @pytest.mark.parametrize("intent", [op3.INC, op3.MIN, op3.MAX]) @pytest.mark.parametrize(["intent", "fill_value"], [(op3.WRITE, 0), (op3.INC, 0)]) +# @pytest.mark.timeout(5) for now def test_parallel_loop(comm, paxis, intent, fill_value): assert comm.size == 2 @@ -193,6 +147,7 @@ def test_parallel_loop(comm, paxis, intent, fill_value): # can try with P1 and P2 @pytest.mark.parallel(nprocs=2) +@pytest.mark.timeout(5) def test_parallel_loop_with_map(comm, mesh_axis, cone_map, scalar_copy_kernel): assert comm.size == 2 rank = comm.rank @@ -285,10 +240,12 @@ def test_parallel_loop_with_map(comm, mesh_axis, cone_map, scalar_copy_kernel): @pytest.mark.parallel(nprocs=2) +@pytest.mark.timeout(5) def test_same_reductions_commute(): ... @pytest.mark.parallel(nprocs=2) +@pytest.mark.timeout(5) def test_different_reductions_do_not_commute(): ... 
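The renumbering hunk above replaces hand-rolled permutation bookkeeping with the
axis methods `default_to_applied_component_number` and `axis_to_component_number`.
The underlying idea is the inverse-permutation trick described in the deleted
NOTES block (and implemented by the `invert` helper added to pyop3/utils.py
earlier in this patch). A minimal NumPy sketch, reusing the interval-mesh
example from those notes:

    import numpy as np

    # New-to-old: entry i is the default (plex) point stored at renumbered
    # position i, i.e. the numbering [c1, v2, v0, c0, v1].
    numbering = np.asarray([1, 4, 2, 0, 3])

    # Inverting the permutation gives the old-to-new map: s[p] = arange(len(p)).
    inverse = np.empty_like(numbering)
    inverse[numbering] = np.arange(numbering.size)

    # Default point 0 now lives at renumbered position 3, and so on.
    assert list(inverse) == [3, 0, 2, 4, 1]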
diff --git a/tests/integration/test_petscmat.py b/tests/integration/test_petscmat.py index 3a8b5d90..ae661550 100644 --- a/tests/integration/test_petscmat.py +++ b/tests/integration/test_petscmat.py @@ -68,6 +68,7 @@ def test_map_compression(scalar_copy_kernel_int): assert np.allclose(pt_to_dofs.data_ro, expected.flatten()) +@pytest.mark.skip(reason="PetscMat API has changed significantly to use adjacency maps") def test_read_matrix_values(): # Imagine a 1D mesh storing DoFs at vertices: # @@ -86,7 +87,7 @@ def test_read_matrix_values(): # FIXME we need to be able to distinguish row and col DoFs (and the IDs must differ) # this should be handled internally somehow dofs_ = op3.Axis(4, "dofs_") - mat = op3.PetscMat(dofs, dofs_, indices, name="mat") + mat = op3.PetscMatAIJ(dofs, dofs_, indices, name="mat") # put some numbers in the matrix sparsity = [ @@ -125,14 +126,14 @@ def test_read_matrix_values(): "map0", ) # so we don't have axes with the same name, needs cleanup - map1 = op3.Map( - { - pmap({"mesh": "cells"}): [ - op3.TabulatedMapComponent("dofs_", dofs_.component.label, map_dat) - ] - }, - "map1", - ) + # map1 = op3.Map( + # { + # pmap({"mesh": "cells"}): [ + # op3.TabulatedMapComponent("dofs_", dofs_.component.label, map_dat) + # ] + # }, + # "map1", + # ) # perform the computation lpy_kernel = lp.make_kernel( diff --git a/tests/integration/test_subsets.py b/tests/integration/test_subsets.py index c85e8cc6..74d1397d 100644 --- a/tests/integration/test_subsets.py +++ b/tests/integration/test_subsets.py @@ -1,30 +1,8 @@ import loopy as lp import numpy as np import pytest -from pyrsistent import pmap - -from pyop3 import ( - INC, - READ, - WRITE, - AffineSliceComponent, - Axis, - AxisComponent, - AxisTree, - Index, - IndexTree, - IntType, - Map, - MultiArray, - ScalarType, - Slice, - SliceComponent, - TabulatedMapComponent, - do_loop, - loop, -) -from pyop3.ir import LOOPY_LANG_VERSION, LOOPY_TARGET -from pyop3.utils import flatten + +import pyop3 as op3 @pytest.mark.parametrize( @@ -35,29 +13,55 @@ (slice(None, None, 2), slice(1, None, 2)), ], ) -def test_loop_over_slices(scalar_copy_kernel, touched, untouched): +def test_loop_over_slices(touched, untouched, factory): npoints = 10 - axes = AxisTree(Axis(npoints)) - dat0 = MultiArray(axes, name="dat0", data=np.arange(npoints, dtype=ScalarType)) - dat1 = MultiArray(axes, name="dat1", dtype=dat0.dtype) + axes = op3.Axis(npoints) + dat0 = op3.HierarchicalArray( + axes, name="dat0", data=np.arange(npoints), dtype=op3.ScalarType + ) + dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype) - do_loop(p := axes[touched].index(), scalar_copy_kernel(dat0[p], dat1[p])) - assert np.allclose(dat1.data[untouched], 0) - assert np.allclose(dat1.data[touched], dat0.data[touched]) + copy = factory.copy_kernel(1, dat0.dtype) + op3.do_loop(p := axes[touched].index(), copy(dat0[p], dat1[p])) + assert np.allclose(dat1.data_ro[untouched], 0) + assert np.allclose(dat1.data_ro[touched], dat0.data_ro[touched]) @pytest.mark.parametrize("size,touched", [(6, [2, 3, 5, 0])]) -def test_scalar_copy_of_subset(scalar_copy_kernel, size, touched): +def test_scalar_copy_of_subset(size, touched, factory): untouched = list(set(range(size)) - set(touched)) - subset_axes = Axis([AxisComponent(len(touched), "pt0")], "ax0") - subset = MultiArray( - subset_axes, name="subset0", data=np.asarray(touched, dtype=IntType) + subset_axes = op3.Axis(len(touched)) + subset = op3.HierarchicalArray( + subset_axes, name="subset0", data=np.asarray(touched), dtype=op3.IntType ) - 
axes = Axis([AxisComponent(size, "pt0")], "ax0") - dat0 = MultiArray(axes, name="dat0", data=np.arange(axes.size, dtype=ScalarType)) - dat1 = MultiArray(axes, name="dat1", dtype=dat0.dtype) + axes = op3.Axis(size) + dat0 = op3.HierarchicalArray( + axes, name="dat0", data=np.arange(axes.size), dtype=op3.ScalarType + ) + dat1 = op3.HierarchicalArray(axes, name="dat1", dtype=dat0.dtype) + + copy = factory.copy_kernel(1, dat0.dtype) + op3.do_loop(p := axes[subset].index(), copy(dat0[p], dat1[p])) + assert np.allclose(dat1.data_ro[touched], dat0.data_ro[touched]) + assert np.allclose(dat1.data_ro[untouched], 0) + + +@pytest.mark.parametrize("size,indices", [(6, [2, 3, 5, 0])]) +def test_write_to_subset(size, indices, factory): + n = len(indices) + + subset_axes = op3.Axis(n) + subset = op3.HierarchicalArray( + subset_axes, name="subset0", data=np.asarray(indices, dtype=op3.IntType) + ) + + axes = op3.Axis(size) + dat0 = op3.HierarchicalArray( + axes, name="dat0", data=np.arange(axes.size, dtype=op3.IntType) + ) + dat1 = op3.HierarchicalArray(subset_axes, name="dat1", dtype=dat0.dtype) - do_loop(p := axes[subset].index(), scalar_copy_kernel(dat0[p], dat1[p])) - assert np.allclose(dat1.data[touched], dat0.data[touched]) - assert np.allclose(dat1.data[untouched], 0) + copy = factory.copy_kernel(n, dat0.dtype) + op3.do_loop(op3.Axis(1).index(), copy(dat0[subset], dat1)) + assert (dat1.data_ro == indices).all() diff --git a/tests/unit/test_array.py b/tests/unit/test_array.py new file mode 100644 index 00000000..8f41b81b --- /dev/null +++ b/tests/unit/test_array.py @@ -0,0 +1,13 @@ +import pytest + +import pyop3 as op3 + + +def test_eager_zero(): + axes = op3.Axis(5) + array = op3.HierarchicalArray(axes, dtype=op3.IntType) + assert (array.buffer._data == 0).all() + + array.buffer._data[...] 
= 666 + array.eager_zero() + assert (array.buffer._data == 0).all() diff --git a/tests/unit/test_axis.py b/tests/unit/test_axis.py index ccc6380c..93e292ff 100644 --- a/tests/unit/test_axis.py +++ b/tests/unit/test_axis.py @@ -4,7 +4,7 @@ from pyrsistent import freeze, pmap import pyop3 as op3 -from pyop3.utils import UniqueNameGenerator, flatten, just_one, single_valued +from pyop3.utils import UniqueNameGenerator, flatten, just_one, single_valued, steps class RenameMapper(pym.mapper.IdentityMapper): @@ -59,15 +59,15 @@ def collect_multi_arrays(layout): return _ordered_collector(layout) -def check_offsets(axes, indices_and_offsets): - for indices, offset in indices_and_offsets: - assert axes.offset(indices) == offset +def check_offsets(axes, offset_args_and_offsets): + for args, offset in offset_args_and_offsets: + assert axes.offset(*args) == offset def check_invalid_indices(axes, indicess): - for indices in indicess: + for indices, path in indicess: with pytest.raises(IndexError): - axes.offset(indices) + axes.offset(indices, path) @pytest.mark.parametrize("numbering", [None, [2, 3, 0, 4, 1]]) @@ -88,7 +88,11 @@ def test_1d_affine_layout(numbering): ([4], 4), ], ) - check_invalid_indices(axes, [[5]]) + # check_invalid_indices( + # axes, + # [ + # ({"ax0": 5}, {"ax0": "pt0"}), + # ]) def test_2d_affine_layout(): @@ -102,15 +106,15 @@ def test_2d_affine_layout(): check_offsets( axes, [ - ([0, 0], 0), - ([0, 1], 1), - ([1, 0], 2), - ([1, 1], 3), - ([2, 0], 4), - ([2, 1], 5), + ([[0, 0]], 0), + ([[0, 1]], 1), + ([[1, 0]], 2), + ([[1, 1]], 3), + ([[2, 0]], 4), + ([[2, 1]], 5), ], ) - check_invalid_indices(axes, [[3, 0], [0, 2], [1, 2], [2, 2]]) + # check_invalid_indices(axes, [[3, 0], [0, 2], [1, 2], [2, 2]]) def test_1d_multi_component_layout(): @@ -124,24 +128,24 @@ def test_1d_multi_component_layout(): check_offsets( axes, [ - ([("pt0", 0)], 0), - ([("pt0", 1)], 1), - ([("pt0", 2)], 2), - ([("pt1", 0)], 3), - ([("pt1", 1)], 4), - ], - ) - check_invalid_indices( - axes, - [ - [], - [("pt0", -1)], - [("pt0", 3)], - [("pt1", -1)], - [("pt1", 2)], - [("pt0", 0), 0], + ([0, {"ax0": "pt0"}], 0), + ([1, {"ax0": "pt0"}], 1), + ([2, {"ax0": "pt0"}], 2), + ([0, {"ax0": "pt1"}], 3), + ([1, {"ax0": "pt1"}], 4), ], ) + # check_invalid_indices( + # axes, + # [ + # [], + # [("pt0", -1)], + # [("pt0", 3)], + # [("pt1", -1)], + # [("pt1", 2)], + # [("pt0", 0), 0], + # ], + # ) def test_1d_multi_component_permuted_layout(): @@ -163,22 +167,22 @@ def test_1d_multi_component_permuted_layout(): check_offsets( axes, [ - ([("pt0", 0)], 1), - ([("pt0", 1)], 3), - ([("pt0", 2)], 4), - ([("pt1", 0)], 0), - ([("pt1", 1)], 2), - ], - ) - check_invalid_indices( - axes, - [ - [("pt0", -1)], - [("pt0", 3)], - [("pt1", -1)], - [("pt1", 2)], + ([0, {"ax0": "pt0"}], 1), + ([1, {"ax0": "pt0"}], 3), + ([2, {"ax0": "pt0"}], 4), + ([0, {"ax0": "pt1"}], 0), + ([1, {"ax0": "pt1"}], 2), ], ) + # check_invalid_indices( + # axes, + # [ + # [("pt0", -1)], + # [("pt0", 3)], + # [("pt1", -1)], + # [("pt1", 2)], + # ], + # ) def test_1d_zero_sized_layout(): @@ -187,7 +191,7 @@ def test_1d_zero_sized_layout(): layout0 = axes.layouts[pmap({"ax0": "pt0"})] assert as_str(layout0) == "var_0" - check_invalid_indices(axes, [[], [0]]) + # check_invalid_indices(axes, [[], [0]]) def test_multi_component_layout_with_zero_sized_subaxis(): @@ -211,20 +215,20 @@ def test_multi_component_layout_with_zero_sized_subaxis(): check_offsets( axes, [ - ([("pt1", 0), 0], 0), - ([("pt1", 0), 1], 1), - ([("pt1", 0), 2], 2), - ], - ) - check_invalid_indices( - 
axes, - [ - [], - [("pt0", 0), 0], - [("pt1", 0), 3], - [("pt1", 1), 0], + ([[0, 0], {"ax0": "pt1", "ax1": "pt0"}], 0), + ([[0, 1], {"ax0": "pt1", "ax1": "pt0"}], 1), + ([[0, 2], {"ax0": "pt1", "ax1": "pt0"}], 2), ], ) + # check_invalid_indices( + # axes, + # [ + # [], + # [("pt0", 0), 0], + # [("pt1", 0), 3], + # [("pt1", 1), 0], + # ], + # ) def test_permuted_multi_component_layout_with_zero_sized_subaxis(): @@ -249,24 +253,24 @@ def test_permuted_multi_component_layout_with_zero_sized_subaxis(): check_offsets( axes, [ - ([("pt1", 0), 0], 0), - ([("pt1", 0), 1], 1), - ([("pt1", 0), 2], 2), - ([("pt1", 1), 0], 3), - ([("pt1", 1), 1], 4), - ([("pt1", 1), 2], 5), - ], - ) - check_invalid_indices( - axes, - [ - [("pt0", 0), 0], - [("pt1", 0)], - [("pt1", 2), 0], - [("pt1", 0), 3], - [("pt1", 0), 0, 0], + ([[0, 0], {"ax0": "pt1", "ax1": "pt0"}], 0), + ([[0, 1], {"ax0": "pt1", "ax1": "pt0"}], 1), + ([[0, 2], {"ax0": "pt1", "ax1": "pt0"}], 2), + ([[1, 0], {"ax0": "pt1", "ax1": "pt0"}], 3), + ([[1, 1], {"ax0": "pt1", "ax1": "pt0"}], 4), + ([[1, 2], {"ax0": "pt1", "ax1": "pt0"}], 5), ], ) + # check_invalid_indices( + # axes, + # [ + # [("pt0", 0), 0], + # [("pt1", 0)], + # [("pt1", 2), 0], + # [("pt1", 0), 3], + # [("pt1", 0), 0, 0], + # ], + # ) def test_ragged_layout(): @@ -283,26 +287,26 @@ def test_ragged_layout(): check_offsets( axes, [ - ([0, 0], 0), - ([0, 1], 1), - ([1, 0], 2), - ([2, 0], 3), - ([2, 1], 4), - ], - ) - check_invalid_indices( - axes, - [ - [-1, 0], - [0, -1], - [0, 2], - [1, -1], - [1, 1], - [2, -1], - [2, 2], - [3, 0], + ([[0, 0]], 0), + ([[0, 1]], 1), + ([[1, 0]], 2), + ([[2, 0]], 3), + ([[2, 1]], 4), ], ) + # check_invalid_indices( + # axes, + # [ + # [-1, 0], + # [0, -1], + # [0, 2], + # [1, -1], + # [1, 1], + # [2, -1], + # [2, 2], + # [3, 0], + # ], + # ) def test_ragged_layout_with_two_outer_axes(): @@ -326,25 +330,25 @@ def test_ragged_layout_with_two_outer_axes(): check_offsets( axes, [ - ([0, 0, 0], 0), - ([0, 0, 1], 1), - ([0, 1, 0], 2), - ([1, 0, 0], 3), - ([1, 1, 0], 4), - ([1, 1, 1], 5), - ], - ) - check_invalid_indices( - axes, - [ - [0, 0, 2], - [0, 1, 1], - [1, 0, 1], - [1, 1, 2], - [1, 2, 0], - [2, 0, 0], + ([[0, 0, 0]], 0), + ([[0, 0, 1]], 1), + ([[0, 1, 0]], 2), + ([[1, 0, 0]], 3), + ([[1, 1, 0]], 4), + ([[1, 1, 1]], 5), ], ) + # check_invalid_indices( + # axes, + # [ + # [0, 0, 2], + # [0, 1, 1], + # [1, 0, 1], + # [1, 1, 2], + # [1, 2, 0], + # [2, 0, 0], + # ], + # ) @pytest.mark.xfail(reason="Adjacent ragged components do not yet work") @@ -409,3 +413,23 @@ def test_independent_ragged_axes(): # [2, 0, 0], # ], # ) + + +def test_tabulate_nested_ragged_indexed_layouts(): + axis0 = op3.Axis(3) + axis1 = op3.Axis(2) + axis2 = op3.Axis(2) + nnz_data = np.asarray([[1, 0], [3, 2], [1, 1]], dtype=op3.IntType).flatten() + nnz_axes = op3.AxisTree.from_iterable([axis0, axis1]) + nnz = op3.HierarchicalArray(nnz_axes, data=nnz_data) + axes = op3.AxisTree.from_iterable([axis0, axis1, op3.Axis(nnz), axis2]) + # axes = op3.AxisTree.from_iterable([axis0, op3.Axis(nnz), op3.Axis(2)]) + # axes = op3.AxisTree.from_iterable([axis0, op3.Axis(nnz)]) + + p = axis0.index() + indexed_axes = just_one(axes[p].context_map.values()) + + layout = indexed_axes.layouts[indexed_axes.path(*indexed_axes.leaf)] + array0 = just_one(collect_multi_arrays(layout)) + expected = np.asarray(steps(nnz_data, drop_last=True), dtype=op3.IntType) * 2 + assert (array0.data_ro == expected).all() diff --git a/tests/unit/test_distarray.py b/tests/unit/test_distarray.py index 4969ec2c..08ad1434 100644 
--- a/tests/unit/test_distarray.py
+++ b/tests/unit/test_distarray.py
@@ -52,7 +52,7 @@ def array(comm):
     serial = op3.Axis(npoints)
     axis = op3.Axis.from_serial(serial, sf)
     axes = op3.AxisTree.from_nest({axis: op3.Axis(3)}).freeze()
-    return op3.DistributedBuffer(axes.size, sf=axes.sf)
+    return op3.DistributedBuffer(axes.size, axes.sf)
 
 
 @pytest.mark.parallel(nprocs=2)
diff --git a/tests/unit/test_indices.py b/tests/unit/test_indices.py
index d1698423..308ac1e4 100644
--- a/tests/unit/test_indices.py
+++ b/tests/unit/test_indices.py
@@ -1,40 +1,109 @@
-import numpy as np
 import pytest
-from pyrsistent import freeze
+from pyrsistent import freeze, pmap
 
 import pyop3 as op3
 
 
-def test_loop_index_iter_flat():
+def test_axes_iter_flat():
     iterset = op3.Axis({"pt0": 5}, "ax0")
-    expected = [
-        (freeze({"ax0": "pt0"}),) * 2 + (freeze({"ax0": i}),) * 2 for i in range(5)
-    ]
-    assert list(iterset.index().iter()) == expected
+    for i, p in enumerate(iterset.iter()):
+        assert p.source_path == freeze({"ax0": "pt0"})
+        assert p.target_path == p.source_path
+        assert p.source_exprs == freeze({"ax0": i})
+        assert p.target_exprs == p.source_exprs
 
 
-def test_loop_index_iter_nested():
+def test_axes_iter_nested():
     iterset = op3.AxisTree.from_nest(
         {
             op3.Axis({"pt0": 5}, "ax0"): op3.Axis({"pt0": 3}, "ax1"),
         },
     )
-    path = freeze({"ax0": "pt0", "ax1": "pt0"})
-    expected = [
-        (path,) * 2 + (freeze({"ax0": i, "ax1": j}),) * 2
-        for i in range(5)
-        for j in range(3)
-    ]
-    assert list(iterset.index().iter()) == expected
+    iterator = iterset.iter()
+    for i in range(5):
+        for j in range(3):
+            p = next(iterator)
+            assert p.source_path == freeze({"ax0": "pt0", "ax1": "pt0"})
+            assert p.target_path == p.source_path
+            assert p.source_exprs == freeze({"ax0": i, "ax1": j})
+            assert p.target_exprs == p.source_exprs
+
+    # make sure that the iterator is empty
+    try:
+        next(iterator)
+        assert False
+    except StopIteration:
+        pass
 
 
-def test_loop_index_iter_multi_component():
+def test_axes_iter_multi_component():
     iterset = op3.Axis({"pt0": 3, "pt1": 3}, "ax0")
-    path0 = freeze({"ax0": "pt0"})
-    path1 = freeze({"ax0": "pt1"})
-    expected = [(path0,) * 2 + (freeze({"ax0": i}),) * 2 for i in range(3)] + [
-        (path1,) * 2 + (freeze({"ax0": i}),) * 2 for i in range(3)
-    ]
-    assert list(iterset.index().iter()) == expected
+    iterator = iterset.iter()
+    for i in range(3):
+        p = next(iterator)
+        assert p.source_path == freeze({"ax0": "pt0"})
+        assert p.target_path == p.source_path
+        assert p.source_exprs == freeze({"ax0": i})
+        assert p.target_exprs == p.source_exprs
+
+    for i in range(3):
+        p = next(iterator)
+        assert p.source_path == freeze({"ax0": "pt1"})
+        assert p.target_path == p.source_path
+        assert p.source_exprs == freeze({"ax0": i})
+        assert p.target_exprs == p.source_exprs
+
+    # make sure that the iterator is empty
+    try:
+        next(iterator)
+        assert False
+    except StopIteration:
+        pass
+
+
+def test_index_forest_inserts_extra_slices():
+    axes = op3.AxisTree.from_nest(
+        {
+            op3.Axis({"pt0": 5}, "ax0"): op3.Axis({"pt0": 3}, "ax1"),
+        },
+    )
+    iforest = op3.itree.as_index_forest(slice(None), axes=axes)
+
+    # since there are no loop indices, the index forest should contain a single entry
+    assert len(iforest) == 1
+    assert pmap() in iforest.keys()
+
+    itree = iforest[pmap()]
+    assert itree.depth == 2
+
+
+@pytest.mark.xfail(reason="Index tree.leaves currently broken")
+def test_multi_component_index_forest_inserts_extra_slices():
+    axes = op3.AxisTree.from_nest(
+        {
+            op3.Axis({"pt0": 5, "pt1": 4}, "ax0"): {
+                "pt0": op3.Axis({"pt0": 3}, "ax1"),
+                "pt1": op3.Axis({"pt0": 2}, "ax1"),
+            }
+        },
+    )
+    iforest = op3.itree.as_index_forest(
+        op3.Slice("ax1", [op3.AffineSliceComponent("pt0")]), axes=axes
+    )
+
+    # since there are no loop indices, the index forest should contain a single entry
+    assert len(iforest) == 1
+    assert pmap() in iforest.keys()
+
+    itree = iforest[pmap()]
+    assert itree.depth == 2
+    assert itree.root.label == "ax1"
+
+    # FIXME this currently fails because itree.leaves does not work.
+    # This is because it is difficult for loop indices to advertise component labels.
+    # Perhaps they should be an index component themselves? I have made some notes
+    # on this.
+    assert all(index.label == "ax0" for index, _ in itree.leaves)
+    assert len(itree.leaves) == 2
diff --git a/tests/unit/test_parallel.py b/tests/unit/test_parallel.py
index 5fd6b163..5b355be2 100644
--- a/tests/unit/test_parallel.py
+++ b/tests/unit/test_parallel.py
@@ -58,6 +58,7 @@ def maxis(comm, msf):
 
 
 @pytest.mark.parallel(nprocs=2)
+@pytest.mark.timeout(5)
 def test_halo_data_stored_at_end_of_array(comm, paxis):
     if comm.rank == 0:
         reordered = [3, 2, 4, 5, 0, 1]
@@ -69,6 +70,7 @@
 
 
 @pytest.mark.parallel(nprocs=2)
+@pytest.mark.timeout(5)
 def test_multi_component_halo_data_stored_at_end(comm, maxis):
     if comm.rank == 0:
         # unchanged as halo data already at the end
@@ -80,6 +82,7 @@
 
 
 @pytest.mark.parallel(nprocs=2)
+@pytest.mark.timeout(5)
 def test_distributed_subaxes_partition_halo_data(paxis):
     # Check that
     #
@@ -131,6 +134,7 @@
 
 
 @pytest.mark.parallel(nprocs=2)
+@pytest.mark.timeout(5)
 def test_nested_parallel_axes_produce_correct_sf(comm, paxis):
     # Check that
     #
@@ -151,7 +155,7 @@
     rank = comm.rank
     other_rank = (rank + 1) % 2
 
-    array = op3.DistributedBuffer(axes.size, sf=axes.sf)
+    array = op3.DistributedBuffer(axes.size, axes.sf)
     array._data[...] = rank
     array._leaves_valid = False
 
@@ -166,6 +170,7 @@
 
 @pytest.mark.parallel(nprocs=2)
 @pytest.mark.parametrize("with_ghosts", [False, True])
+@pytest.mark.timeout(5)
 def test_partition_iterset_scalar(comm, paxis, with_ghosts):
     array = op3.HierarchicalArray(paxis, dtype=op3.ScalarType)
 
@@ -193,6 +198,7 @@
 
 @pytest.mark.parallel(nprocs=2)
 @pytest.mark.parametrize("with_ghosts", [False, True])
+@pytest.mark.timeout(5)
 def test_partition_iterset_with_map(comm, paxis, with_ghosts):
     axis_label = paxis.label
     component_label = just_one(paxis.components).label
@@ -244,3 +250,78 @@
     assert np.equal(icore.data_ro, expected_icore).all()
     assert np.equal(iroot.data_ro, expected_iroot).all()
     assert np.equal(ileaf.data_ro, expected_ileaf).all()
+
+
+@pytest.mark.parallel(nprocs=2)
+@pytest.mark.parametrize("intent", [op3.WRITE, op3.INC])
+@pytest.mark.timeout(5)
+def test_shared_array(comm, intent):
+    sf = op3.sf.single_star(comm, 3)
+    axes = op3.AxisTree.from_nest({op3.Axis(3, sf=sf): op3.Axis(2)})
+    shared = op3.HierarchicalArray(axes)
+
+    assert (shared.data_ro == 0).all()
+
+    if comm.rank == 0:
+        shared.buffer._data[...] = 1
+    else:
+        assert comm.rank == 1
+        shared.buffer._data[...] = 2
+    shared.buffer._leaves_valid = False
+    shared.buffer._pending_reduction = intent
+
+    shared.assemble()
+
+    if intent == op3.WRITE:
+        # we reduce from leaves (which store a 2) to roots (which store a 1)
+        assert (shared.data_ro == 2).all()
+    else:
+        assert intent == op3.INC
+        assert (shared.data_ro == 3).all()
+
+
+@pytest.mark.parallel(nprocs=2)
+def test_lgmaps(comm):
+    # Create a star forest for the following distribution
+    #
+    #           g  g
+    # rank 0: [0, 1, * 2, 3, 4, 5]
+    #           |  |  *  |  |
+    # rank 1: [0, 1, 2, 3, * 4, 5]
+    #                          g  g
+    if comm.rank == 0:
+        size = 6
+        nroots = 4
+        ilocal = [0, 1]
+        iremote = [(1, 2), (1, 3)]
+    else:
+        assert comm.rank == 1
+        size = 6
+        nroots = 4
+        ilocal = [4, 5]
+        iremote = [(0, 2), (0, 3)]
+    sf = op3.StarForest.from_graph(size, nroots, ilocal, iremote, comm)
+
+    serial_axis = op3.Axis(size)
+    axis0 = op3.Axis.from_serial(serial_axis, sf=sf)
+
+    lgmap = axis0.global_numbering()
+    print_with_rank(lgmap)
+
+    raise NotImplementedError
+    axes = op3.AxisTree.from_iterable((axis0, 2))
+
+    # self.sf.sf.view()
+    sf.sf.view()
+    # lgmap = PETSc.LGMap().createSF(axes.sf.sf, PETSc.DECIDE)
+    lgmap = PETSc.LGMap().createSF(sf.sf, PETSc.DECIDE)
+    lgmap.setType(PETSc.LGMap.Type.BASIC)
+    # self._lazy_lgmap = lgmap
+    lgmap.view()
+    print_with_rank(lgmap.indices)
+
+    raise NotImplementedError
+
+    lgmap = axes.lgmap
+    print_with_rank(lgmap.indices)
+    assert False
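
A minimal sketch (not part of the diff) of the offset() calling convention that the
reworked check_offsets helper above exercises: all indices are passed as one list, and
multi-component axes additionally take a path dict mapping axis labels to component
labels. It assumes Axis, AxisTree.from_nest and AxisTree.offset behave exactly as in
tests/unit/test_axis.py; the 3 x 2 tree mirrors test_2d_affine_layout.

    import pyop3 as op3

    # two nested single-component axes (3 x 2 points), as in test_2d_affine_layout
    axes = op3.AxisTree.from_nest(
        {op3.Axis({"pt0": 3}, "ax0"): op3.Axis({"pt0": 2}, "ax1")}
    )

    # indices for every axis in the tree are passed as a single list argument
    assert axes.offset([1, 0]) == 1 * 2 + 0  # row-major layout, == 2

    # for a multi-component axis a second "path" argument selects the component,
    # mirroring check_offsets above, e.g. axes.offset(0, {"ax0": "pt1"})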
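The expected value in test_tabulate_nested_ragged_indexed_layouts relies on
steps(nnz_data, drop_last=True) being an exclusive running sum: the offset of outer
point i is the number of inner points before it, scaled by the extent-2 innermost axis.
A sketch of that arithmetic under this assumption (steps_sketch is a hypothetical
stand-in, not pyop3's actual implementation of steps):

    import numpy as np

    def steps_sketch(counts, drop_last=False):
        # exclusive prefix sum: entry i is the total of counts[:i]
        out = np.concatenate(([0], np.cumsum(counts)))
        return out[:-1] if drop_last else out

    nnz = np.asarray([1, 0, 3, 2, 1, 1])  # the flattened nnz_data above
    expected = steps_sketch(nnz, drop_last=True) * 2
    assert (expected == [0, 2, 2, 8, 12, 14]).all()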