diff --git a/pyop3/codegen/loopexpr2loopy.py b/pyop3/codegen/loopexpr2loopy.py index 30d2a277..eb0f452d 100644 --- a/pyop3/codegen/loopexpr2loopy.py +++ b/pyop3/codegen/loopexpr2loopy.py @@ -1595,19 +1595,24 @@ def _indexed_axes(indexed, loop_indices): """ - # handle 'indexed of indexed' things - if isinstance(indexed.obj, Indexed): - orig_axes = _indexed_axes(indexed.obj, loop_indices) - else: - assert isinstance(indexed.obj, MultiArray) - orig_axes = indexed.obj.axes - # get the right index tree given the loop context loop_context = {} for loop_index, (path, _) in loop_indices.items(): loop_context[loop_index] = pmap(path) loop_context = pmap(loop_context) - index_tree = indexed.indices[loop_context] + + # nasty hack, pass a tuple of axis tree and index tree sometimes + if isinstance(indexed, tuple): + orig_axes, split_index_tree = indexed + index_tree = split_index_tree[loop_context] + else: + # handle 'indexed of indexed' things + if isinstance(indexed.obj, Indexed): + orig_axes = _indexed_axes(indexed.obj, loop_indices) + else: + assert isinstance(indexed.obj, MultiArray) + orig_axes = indexed.obj.axes + index_tree = indexed.indices[loop_context] axes = visit_indices( index_tree, diff --git a/pyop3/distarray/multiarray.py b/pyop3/distarray/multiarray.py index b2ba5c71..7bb26800 100644 --- a/pyop3/distarray/multiarray.py +++ b/pyop3/distarray/multiarray.py @@ -14,6 +14,7 @@ import pytools from mpi4py import MPI from petsc4py import PETSc +from pyrsistent import pmap from pyop3 import utils from pyop3.axis import Axis, AxisComponent, AxisTree, as_axis_tree, get_bottom_part @@ -125,6 +126,12 @@ def __init__( self._sync_thread = None + # don't like this, instead use something singledispatch in the right place + # split_axes is only used for a very specific use case + @property + def split_axes(self): + return pmap({pmap(): self.axes}) + @property def data(self): import warnings @@ -313,12 +320,6 @@ def root(self): # maybe I could check types here and use instead of get_value? def __getitem__(self, indices: IndexTree | Index): - indices = as_split_index_tree(indices, axes=self.axes) - - # TODO recover this - # if not is_fully_indexed(self.axes, indices): - # raise ValueError("Provided indices are insufficient to address the array") - return Indexed(self, indices) def select_axes(self, indices): diff --git a/pyop3/index.py b/pyop3/index.py index 90da3720..d268c3a5 100644 --- a/pyop3/index.py +++ b/pyop3/index.py @@ -4,6 +4,7 @@ import collections import dataclasses import functools +import numbers from typing import Any, Collection, Hashable, Mapping, Sequence import pytools @@ -84,8 +85,28 @@ class Indexed: """ def __init__(self, obj, indices): + from pyop3.codegen.loopexpr2loopy import _indexed_axes + + # The following tricksy bit of code builds a pretend AxisTree for the + # indexed object. It is complicated because the resultant AxisTree will + # have a different shape depending on the loop context (which is why we have + # SplitIndexTrees). We therefore store axes here split by loop context. + split_indices = {} + split_axes = {} + for loop_ctx, axes in obj.split_axes.items(): + indices = as_split_index_tree(indices, axes=axes, loop_context=loop_ctx) + split_indices |= indices.index_trees + for loop_ctx_, itree in indices.index_trees.items(): + # nasty hack because _indexed_axes currently expects a 2-tuple per loop index + assert set(loop_ctx.keys()) <= set(loop_ctx_.keys()) + my_loop_context = { + idx: (path, "not used") for idx, path in loop_ctx_.items() + } + split_axes[loop_ctx_] = _indexed_axes((axes, indices), my_loop_context) + self.obj = obj - self.indices = as_split_index_tree(indices) + self.split_axes = pmap(split_axes) + self.indices = SplitIndexTree(split_indices) # old alias, not right now we have a pmap of index trees rather than just a single one @property @@ -93,16 +114,6 @@ def itree(self): return self.indices def __getitem__(self, indices): - from pyop3.distarray import MultiArray - - if not isinstance(self.obj, MultiArray) and not isinstance( - indices, (IndexTree, Index) - ): - raise NotImplementedError( - "Need to compute the temporary/intermediate axes for this to be allowed" - ) - - indices = as_split_index_tree(indices) return Indexed(self, indices) @functools.cached_property @@ -570,6 +581,12 @@ def _split_index_tree_from_iterable( ] elif isinstance(index, MultiArray): slice_cpts = [Subset(cpt.label, index) for cpt in current_axis.components] + elif isinstance(index, numbers.Integral): + # an integer is just a one-sized slice (assumed over all components) + slice_cpts = [ + AffineSliceComponent(cpt.label, index, index + 1) + for cpt in current_axis.components + ] else: raise TypeError index = Slice(current_axis.label, slice_cpts) @@ -583,7 +600,7 @@ def _split_index_tree_from_iterable( index.component_labels, index.target_paths ): split_subtree = _split_index_tree_from_iterable( - subindices, axes, path | target_path + subindices, axes, path | target_path, loop_context ) for loopctx, subtree in split_subtree.index_trees.items(): if loopctx not in index_trees: @@ -600,7 +617,9 @@ def _split_index_tree_from_iterable( def _split_index_tree_from_ellipsis( - axes: AxisTree, current_axis: Axis | None = None + axes: AxisTree, + current_axis: Axis | None = None, + loop_context=pmap(), ) -> IndexTree: current_axis = current_axis or axes.root @@ -611,7 +630,7 @@ def _split_index_tree_from_ellipsis( subaxis = axes.child(current_axis, cpt) if subaxis: - subtrees.append(_index_tree_from_ellipsis(axes, subaxis)) + subtrees.append(_index_tree_from_ellipsis(axes, subaxis, loop_context)) else: subtrees.append(None) @@ -620,7 +639,7 @@ def _split_index_tree_from_ellipsis( for subslice, subtree in checked_zip(subslices, subtrees): if subtree is not None: tree = tree.add_subtree(subtree, slice_, subslice.component) - return tree + return SplitIndexTree({pmap(): tree}) def is_fully_indexed(axes: AxisTree, indices: IndexTree) -> bool: diff --git a/tests/integration/test_axis_ordering.py b/tests/integration/test_axis_ordering.py index 1234b396..d3b16112 100644 --- a/tests/integration/test_axis_ordering.py +++ b/tests/integration/test_axis_ordering.py @@ -11,6 +11,7 @@ from pyop3.codegen import LOOPY_LANG_VERSION, LOOPY_TARGET from pyop3.distarray import MultiArray from pyop3.dtypes import IntType, ScalarType +from pyop3.index import SplitIndexTree from pyop3.loopexpr import INC, READ, WRITE, LoopyKernel, do_loop, loop from pyop3.utils import flatten, just_one @@ -51,13 +52,15 @@ def test_different_axis_orderings_do_not_change_packing_order(): dat1 = MultiArray(axes0, name="dat1", data=np.zeros(npoints, dtype=ScalarType)) p = axis0.index() - q = IndexTree(SplitLoopIndex(p, just_one(axis0.target_paths))) + path = just_one(axis0.target_paths) + q = IndexTree(SplitLoopIndex(p, path)) q = q.put_node( Slice("ax1", [AffineSliceComponent(axis1.component_labels[0])]), *q.leaf ) q = q.put_node( Slice("ax2", [AffineSliceComponent(axis2.component_labels[0])]), *q.leaf ) + q = SplitIndexTree({pmap({p: path}): q}) do_loop(p, copy_kernel(dat0_0[q], dat1[q])) assert np.allclose(dat1.data, dat0_0.data) diff --git a/tests/integration/test_nested_loops.py b/tests/integration/test_nested_loops.py new file mode 100644 index 00000000..46f67bcc --- /dev/null +++ b/tests/integration/test_nested_loops.py @@ -0,0 +1,26 @@ +import numpy as np + +from pyop3 import Axis, AxisTree, MultiArray, ScalarType, do_loop, loop + + +def test_transpose(scalar_copy_kernel): + npoints = 5 + axis0 = Axis(npoints) + axis1 = Axis(npoints) + + axes0 = AxisTree(axis0, {axis0.id: axis1}) + axes1 = AxisTree(axis1, {axis1.id: axis0}) + + array0 = MultiArray( + axes0, name="array0", data=np.arange(axes0.size, dtype=ScalarType) + ) + array1 = MultiArray(axes1, name="array1", dtype=array0.dtype) + + do_loop( + p := axis0.index(), + loop(q := axis1.index(), scalar_copy_kernel(array0[p, q], array1[q, p])), + ) + assert np.allclose( + array1.data.reshape((npoints, npoints)), + array0.data.reshape((npoints, npoints)).T, + ) diff --git a/tests/integration/test_slice_composition.py b/tests/integration/test_slice_composition.py index edfb64c9..a344aca8 100644 --- a/tests/integration/test_slice_composition.py +++ b/tests/integration/test_slice_composition.py @@ -41,12 +41,7 @@ def test_1d_slice_composition(vec2_copy_kernel): ) dat1 = MultiArray(Axis([(n, "cpt0")], "ax0"), name="dat1", dtype=dat0.dtype) - # this is needed because we currently do not compute the axis tree for the - # intermediate indexed object, so it cannot be indexed with shorthand - itree = IndexTree(Slice("ax0", [AffineSliceComponent("cpt0", 1, 3)])) - - # do_loop(Axis(1).index(), vec2_copy_kernel(dat0[::2][1:3], dat1[...])) - do_loop(Axis(1).index(), vec2_copy_kernel(dat0[::2][itree], dat1[...])) + do_loop(Axis(1).index(), vec2_copy_kernel(dat0[::2][1:3], dat1[...])) assert np.allclose(dat1.data, dat0.data[::2][1:3]) @@ -62,16 +57,10 @@ def test_2d_slice_composition(vec2_copy_kernel): dat0 = MultiArray(axes0, name="dat0", data=np.arange(axes0.size, dtype=ScalarType)) dat1 = MultiArray(axes1, name="dat1", dtype=dat0.dtype) - itree = IndexTree( - Slice("ax0", [AffineSliceComponent("cpt0", 2, 4)], id="slice1"), - {"slice1": Slice("ax1", [AffineSliceComponent("cpt0", 1, 2)])}, - ) - do_loop( Axis(1).index(), vec2_copy_kernel( - # dat0[::2, 1:][2:4, 1], - dat0[::2, 1:][itree], + dat0[::2, 1:][2:4, 1], dat1[...], ), )