Skip to content

Commit

Permalink
Merge pull request #3 from connorjward/better-slice-composition
Browse files Browse the repository at this point in the history
Better slice composition
  • Loading branch information
connorjward authored Aug 15, 2023
2 parents 61a4bdc + 3a5906c commit b625ee3
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 43 deletions.
21 changes: 13 additions & 8 deletions pyop3/codegen/loopexpr2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,19 +1595,24 @@ def _indexed_axes(indexed, loop_indices):
"""

# handle 'indexed of indexed' things
if isinstance(indexed.obj, Indexed):
orig_axes = _indexed_axes(indexed.obj, loop_indices)
else:
assert isinstance(indexed.obj, MultiArray)
orig_axes = indexed.obj.axes

# get the right index tree given the loop context
loop_context = {}
for loop_index, (path, _) in loop_indices.items():
loop_context[loop_index] = pmap(path)
loop_context = pmap(loop_context)
index_tree = indexed.indices[loop_context]

# nasty hack, pass a tuple of axis tree and index tree sometimes
if isinstance(indexed, tuple):
orig_axes, split_index_tree = indexed
index_tree = split_index_tree[loop_context]
else:
# handle 'indexed of indexed' things
if isinstance(indexed.obj, Indexed):
orig_axes = _indexed_axes(indexed.obj, loop_indices)
else:
assert isinstance(indexed.obj, MultiArray)
orig_axes = indexed.obj.axes
index_tree = indexed.indices[loop_context]

axes = visit_indices(
index_tree,
Expand Down
13 changes: 7 additions & 6 deletions pyop3/distarray/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytools
from mpi4py import MPI
from petsc4py import PETSc
from pyrsistent import pmap

from pyop3 import utils
from pyop3.axis import Axis, AxisComponent, AxisTree, as_axis_tree, get_bottom_part
Expand Down Expand Up @@ -125,6 +126,12 @@ def __init__(

self._sync_thread = None

# don't like this, instead use something singledispatch in the right place
# split_axes is only used for a very specific use case
@property
def split_axes(self):
return pmap({pmap(): self.axes})

@property
def data(self):
import warnings
Expand Down Expand Up @@ -313,12 +320,6 @@ def root(self):

# maybe I could check types here and use instead of get_value?
def __getitem__(self, indices: IndexTree | Index):
indices = as_split_index_tree(indices, axes=self.axes)

# TODO recover this
# if not is_fully_indexed(self.axes, indices):
# raise ValueError("Provided indices are insufficient to address the array")

return Indexed(self, indices)

def select_axes(self, indices):
Expand Down
49 changes: 34 additions & 15 deletions pyop3/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import collections
import dataclasses
import functools
import numbers
from typing import Any, Collection, Hashable, Mapping, Sequence

import pytools
Expand Down Expand Up @@ -84,25 +85,35 @@ class Indexed:
"""

def __init__(self, obj, indices):
from pyop3.codegen.loopexpr2loopy import _indexed_axes

# The following tricksy bit of code builds a pretend AxisTree for the
# indexed object. It is complicated because the resultant AxisTree will
# have a different shape depending on the loop context (which is why we have
# SplitIndexTrees). We therefore store axes here split by loop context.
split_indices = {}
split_axes = {}
for loop_ctx, axes in obj.split_axes.items():
indices = as_split_index_tree(indices, axes=axes, loop_context=loop_ctx)
split_indices |= indices.index_trees
for loop_ctx_, itree in indices.index_trees.items():
# nasty hack because _indexed_axes currently expects a 2-tuple per loop index
assert set(loop_ctx.keys()) <= set(loop_ctx_.keys())
my_loop_context = {
idx: (path, "not used") for idx, path in loop_ctx_.items()
}
split_axes[loop_ctx_] = _indexed_axes((axes, indices), my_loop_context)

self.obj = obj
self.indices = as_split_index_tree(indices)
self.split_axes = pmap(split_axes)
self.indices = SplitIndexTree(split_indices)

# old alias, not right now we have a pmap of index trees rather than just a single one
@property
def itree(self):
return self.indices

def __getitem__(self, indices):
from pyop3.distarray import MultiArray

if not isinstance(self.obj, MultiArray) and not isinstance(
indices, (IndexTree, Index)
):
raise NotImplementedError(
"Need to compute the temporary/intermediate axes for this to be allowed"
)

indices = as_split_index_tree(indices)
return Indexed(self, indices)

@functools.cached_property
Expand Down Expand Up @@ -570,6 +581,12 @@ def _split_index_tree_from_iterable(
]
elif isinstance(index, MultiArray):
slice_cpts = [Subset(cpt.label, index) for cpt in current_axis.components]
elif isinstance(index, numbers.Integral):
# an integer is just a one-sized slice (assumed over all components)
slice_cpts = [
AffineSliceComponent(cpt.label, index, index + 1)
for cpt in current_axis.components
]
else:
raise TypeError
index = Slice(current_axis.label, slice_cpts)
Expand All @@ -583,7 +600,7 @@ def _split_index_tree_from_iterable(
index.component_labels, index.target_paths
):
split_subtree = _split_index_tree_from_iterable(
subindices, axes, path | target_path
subindices, axes, path | target_path, loop_context
)
for loopctx, subtree in split_subtree.index_trees.items():
if loopctx not in index_trees:
Expand All @@ -600,7 +617,9 @@ def _split_index_tree_from_iterable(


def _split_index_tree_from_ellipsis(
axes: AxisTree, current_axis: Axis | None = None
axes: AxisTree,
current_axis: Axis | None = None,
loop_context=pmap(),
) -> IndexTree:
current_axis = current_axis or axes.root

Expand All @@ -611,7 +630,7 @@ def _split_index_tree_from_ellipsis(

subaxis = axes.child(current_axis, cpt)
if subaxis:
subtrees.append(_index_tree_from_ellipsis(axes, subaxis))
subtrees.append(_index_tree_from_ellipsis(axes, subaxis, loop_context))
else:
subtrees.append(None)

Expand All @@ -620,7 +639,7 @@ def _split_index_tree_from_ellipsis(
for subslice, subtree in checked_zip(subslices, subtrees):
if subtree is not None:
tree = tree.add_subtree(subtree, slice_, subslice.component)
return tree
return SplitIndexTree({pmap(): tree})


def is_fully_indexed(axes: AxisTree, indices: IndexTree) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_axis_ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyop3.codegen import LOOPY_LANG_VERSION, LOOPY_TARGET
from pyop3.distarray import MultiArray
from pyop3.dtypes import IntType, ScalarType
from pyop3.index import SplitIndexTree
from pyop3.loopexpr import INC, READ, WRITE, LoopyKernel, do_loop, loop
from pyop3.utils import flatten, just_one

Expand Down Expand Up @@ -51,13 +52,15 @@ def test_different_axis_orderings_do_not_change_packing_order():
dat1 = MultiArray(axes0, name="dat1", data=np.zeros(npoints, dtype=ScalarType))

p = axis0.index()
q = IndexTree(SplitLoopIndex(p, just_one(axis0.target_paths)))
path = just_one(axis0.target_paths)
q = IndexTree(SplitLoopIndex(p, path))
q = q.put_node(
Slice("ax1", [AffineSliceComponent(axis1.component_labels[0])]), *q.leaf
)
q = q.put_node(
Slice("ax2", [AffineSliceComponent(axis2.component_labels[0])]), *q.leaf
)
q = SplitIndexTree({pmap({p: path}): q})

do_loop(p, copy_kernel(dat0_0[q], dat1[q]))
assert np.allclose(dat1.data, dat0_0.data)
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_nested_loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np

from pyop3 import Axis, AxisTree, MultiArray, ScalarType, do_loop, loop


def test_transpose(scalar_copy_kernel):
npoints = 5
axis0 = Axis(npoints)
axis1 = Axis(npoints)

axes0 = AxisTree(axis0, {axis0.id: axis1})
axes1 = AxisTree(axis1, {axis1.id: axis0})

array0 = MultiArray(
axes0, name="array0", data=np.arange(axes0.size, dtype=ScalarType)
)
array1 = MultiArray(axes1, name="array1", dtype=array0.dtype)

do_loop(
p := axis0.index(),
loop(q := axis1.index(), scalar_copy_kernel(array0[p, q], array1[q, p])),
)
assert np.allclose(
array1.data.reshape((npoints, npoints)),
array0.data.reshape((npoints, npoints)).T,
)
15 changes: 2 additions & 13 deletions tests/integration/test_slice_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ def test_1d_slice_composition(vec2_copy_kernel):
)
dat1 = MultiArray(Axis([(n, "cpt0")], "ax0"), name="dat1", dtype=dat0.dtype)

# this is needed because we currently do not compute the axis tree for the
# intermediate indexed object, so it cannot be indexed with shorthand
itree = IndexTree(Slice("ax0", [AffineSliceComponent("cpt0", 1, 3)]))

# do_loop(Axis(1).index(), vec2_copy_kernel(dat0[::2][1:3], dat1[...]))
do_loop(Axis(1).index(), vec2_copy_kernel(dat0[::2][itree], dat1[...]))
do_loop(Axis(1).index(), vec2_copy_kernel(dat0[::2][1:3], dat1[...]))
assert np.allclose(dat1.data, dat0.data[::2][1:3])


Expand All @@ -62,16 +57,10 @@ def test_2d_slice_composition(vec2_copy_kernel):
dat0 = MultiArray(axes0, name="dat0", data=np.arange(axes0.size, dtype=ScalarType))
dat1 = MultiArray(axes1, name="dat1", dtype=dat0.dtype)

itree = IndexTree(
Slice("ax0", [AffineSliceComponent("cpt0", 2, 4)], id="slice1"),
{"slice1": Slice("ax1", [AffineSliceComponent("cpt0", 1, 2)])},
)

do_loop(
Axis(1).index(),
vec2_copy_kernel(
# dat0[::2, 1:][2:4, 1],
dat0[::2, 1:][itree],
dat0[::2, 1:][2:4, 1],
dat1[...],
),
)
Expand Down

0 comments on commit b625ee3

Please sign in to comment.