From 67dd221a3f5a10933456fcc6571212414e7021bb Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Wed, 16 Aug 2023 15:18:33 +0100 Subject: [PATCH] Axis offset node roughly works --- pyop3/axis.py | 30 ++++++++ pyop3/codegen/loopexpr2loopy.py | 66 ++++++++++++----- pyop3/loopexpr.py | 4 +- tests/integration/conftest.py | 34 ++++++++- tests/integration/test_local_indices.py | 97 ++++++++++++++++++++++--- 5 files changed, 199 insertions(+), 32 deletions(-) diff --git a/pyop3/axis.py b/pyop3/axis.py index 05875d46..5b90a200 100644 --- a/pyop3/axis.py +++ b/pyop3/axis.py @@ -555,6 +555,9 @@ def __init__( def __getitem__(self, indices): return as_axis_tree(self)[indices] + def __call__(self, *args): + return as_axis_tree(self)(*args) + def __str__(self) -> str: return f"{self.__class__.__name__}([{', '.join(str(cpt) for cpt in self.components)}], label={self.label})" @@ -578,6 +581,30 @@ def count(self): def index(self): return as_axis_tree(self).index() + def enumerate(self): + return as_axis_tree(self).enumerate() + + +# this is supposed to be used in place of an array to represent the offset +# of an axis at a given index +class CalledAxisTree: + def __init__(self, axes, indices): + self.axes = axes + self.indices = indices + + # FIXME + @property + def name(self): + return "my_called_axis" + + @property + def dtype(self): + return IntType + + @functools.cached_property + def datamap(self): + return self.axes.datamap | self.indices.datamap + class AxisTree(LabelledTree): def __init__( @@ -602,6 +629,9 @@ def __getitem__(self, indices): return IndexedAxisTree(self, indices) + def __call__(self, *args): + return CalledAxisTree(self, *args) + def index(self): # cyclic import from pyop3.index import GlobalLoopIndex diff --git a/pyop3/codegen/loopexpr2loopy.py b/pyop3/codegen/loopexpr2loopy.py index 0cb08a39..a25e317a 100644 --- a/pyop3/codegen/loopexpr2loopy.py +++ b/pyop3/codegen/loopexpr2loopy.py @@ -20,7 +20,14 @@ from pyrsistent import pmap from pyop3 import tlang, utils -from pyop3.axis import AffineLayout, Axis, AxisComponent, AxisTree, TabulatedLayout +from pyop3.axis import ( + AffineLayout, + Axis, + AxisComponent, + AxisTree, + CalledAxisTree, + TabulatedLayout, +) from pyop3.distarray import IndexedMultiArray, MultiArray from pyop3.dtypes import IntType from pyop3.index import ( @@ -544,7 +551,9 @@ def _(call: FunctionCall, loop_indices, ctx: LoopyCodegenContext) -> None: temporaries.append((arg, indexed_temp, spec.access, shape)) # Register data - ctx.add_argument(arg.name, arg.dtype) + if not isinstance(arg, CalledAxisTree): + ctx.add_argument(arg.name, arg.dtype) + ctx.add_temporary(temporary.name, temporary.dtype, shape) # subarrayref nonsense/magic @@ -665,10 +674,13 @@ def build_assignment( # unroll the index trees, this should be tidied up array = assignment.array - itrees = [] - while isinstance(array, Indexed): - itrees.insert(0, array.itree) - array = array.obj + if isinstance(array, CalledAxisTree): + itrees = [array.indices] + else: + itrees = [] + while isinstance(array, Indexed): + itrees.insert(0, array.itree) + array = array.obj for indices in itrees: # get the right index tree given the loop context @@ -733,6 +745,10 @@ def _prepare_assignment_rec( jnames: pmap, ctx: LoopyCodegenContext, ) -> tuple[pmap, pmap]: + # catch empty axis trees + # if axes.is_empty: + # return pmap(), pmap({None: 0}), pmap({None: ()}) + jnames_per_axcpt = {} insns_per_leaf = {} array_expr_per_leaf = {} @@ -756,9 +772,15 @@ def _prepare_assignment_rec( insns_per_leaf |= subinsns_per_leaf array_expr_per_leaf |= subarray_expr_per_leaf else: - insns, array_expr = _assignment_array_insn( - assignment, axes, new_path, new_jnames, ctx - ) + if isinstance(assignment.array, CalledAxisTree): + # just the offset instructions here, no subscript + insns, array_expr = emit_assignment_insn( + axes, new_path, new_jnames, ctx + ) + else: + insns, array_expr = _assignment_array_insn( + assignment, axes, new_path, new_jnames, ctx + ) insns_per_leaf[axis.id, axcpt.label] = insns array_expr_per_leaf[axis.id, axcpt.label] = array_expr @@ -901,8 +923,7 @@ def _parse_assignment_final( insns_per_leaf, ctx: LoopyCodegenContext, ): - # catch empty axes here - if not axes.root: + if not axes.root: # catch empty axes here for insn in insns_per_leaf[None]: ctx.add_assignment(*insn) array_expr = array_expr_per_leaf[None] @@ -964,11 +985,18 @@ def _parse_assignment_final_rec( for insn in insns_per_leaf[axis.id, axcpt.label]: ctx.add_assignment(*insn) array_expr = array_expr_per_leaf[axis.id, axcpt.label] - temp_insns, temp_expr = _assignment_temp_insn( - assignment, new_path, new_jnames, ctx - ) - for insn in temp_insns: - ctx.add_assignment(*insn) + if isinstance(assignment.array, CalledAxisTree): + temp_insns, temp_expr = _assignment_temp_insn( + assignment, pmap(), pmap(), ctx + ) + for insn in temp_insns: + ctx.add_assignment(*insn) + else: + temp_insns, temp_expr = _assignment_temp_insn( + assignment, new_path, new_jnames, ctx + ) + for insn in temp_insns: + ctx.add_assignment(*insn) _shared_assignment_insn(assignment, array_expr, temp_expr, ctx) @@ -1240,7 +1268,6 @@ def _assignment_array_insn(assignment, axes, path, jnames, ctx): """ offset_insns, array_offset = emit_assignment_insn( - assignment.array.name, axes, path, jnames, @@ -1260,7 +1287,6 @@ def _assignment_temp_insn(assignment, path, jnames, ctx): """ offset_insns, temp_offset = emit_assignment_insn( - assignment.temporary.name, assignment.temporary.axes, path, jnames, @@ -1298,7 +1324,6 @@ def _shared_assignment_insn(assignment, array_expr, temp_expr, ctx): def emit_assignment_insn( - array_name, axes, path, labels_to_jnames, @@ -1605,6 +1630,9 @@ def _indexed_axes(indexed, loop_indices): axis tree. No instructions need be emitted. """ + # offsets are always scalar + if isinstance(indexed, CalledAxisTree): + return AxisTree() # get the right index tree given the loop context loop_context = {} diff --git a/pyop3/loopexpr.py b/pyop3/loopexpr.py index 4351a09c..bcc2cb40 100644 --- a/pyop3/loopexpr.py +++ b/pyop3/loopexpr.py @@ -217,9 +217,7 @@ class Terminal(LoopExpr): class FunctionCall(Terminal): def __init__(self, function, arguments): self.function = function - self.arguments = tuple( - arg if isinstance(arg, Indexed) else arg[...] for arg in arguments - ) + self.arguments = arguments @functools.cached_property def datamap(self) -> dict[str, DistributedArray]: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a493e547..1970d75a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,7 +1,7 @@ import loopy as lp import pytest -from pyop3 import READ, WRITE, LoopyKernel, ScalarType +from pyop3 import INC, READ, WRITE, IntType, LoopyKernel, ScalarType from pyop3.codegen import LOOPY_LANG_VERSION, LOOPY_TARGET @@ -19,3 +19,35 @@ def scalar_copy_kernel(): lang_version=LOOPY_LANG_VERSION, ) return LoopyKernel(code, [READ, WRITE]) + + +@pytest.fixture +def scalar_copy_kernel_int(): + code = lp.make_kernel( + "{ [i]: 0 <= i < 1 }", + "y[i] = x[i]", + [ + lp.GlobalArg("x", IntType, (1,), is_input=True, is_output=False), + lp.GlobalArg("y", IntType, (1,), is_input=False, is_output=True), + ], + name="scalar_copy_int", + target=LOOPY_TARGET, + lang_version=LOOPY_LANG_VERSION, + ) + return LoopyKernel(code, [READ, WRITE]) + + +@pytest.fixture +def scalar_inc_kernel(): + lpy_kernel = lp.make_kernel( + "{ [i]: 0 <= i < 1 }", + "y[i] = y[i] + x[i]", + [ + lp.GlobalArg("x", ScalarType, (1,), is_input=True, is_output=False), + lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=True), + ], + name="scalar_inc", + target=LOOPY_TARGET, + lang_version=LOOPY_LANG_VERSION, + ) + return LoopyKernel(lpy_kernel, [READ, INC]) diff --git a/tests/integration/test_local_indices.py b/tests/integration/test_local_indices.py index 32f22c2d..e63941e2 100644 --- a/tests/integration/test_local_indices.py +++ b/tests/integration/test_local_indices.py @@ -1,7 +1,16 @@ import numpy as np import pytest -from pyop3 import Axis, AxisComponent, AxisTree, MultiArray, ScalarType, do_loop +from pyop3 import ( + Axis, + AxisComponent, + AxisTree, + IntType, + MultiArray, + ScalarType, + do_loop, + loop, +) def test_copy_with_local_indices(scalar_copy_kernel): @@ -33,12 +42,19 @@ def test_copy_slice(scalar_copy_kernel): assert np.allclose(array1.data, array0.data[::2]) -def test_inc_into_small_array(scalar_copy_kernel): - size = 10 - dim = 3 +# this isn't a very meaningful test since the local and global loop indices are identical +def test_inc_into_small_array(scalar_inc_kernel): + m, n = 10, 3 - big = MultiArray(big_axes) - small = MultiArray(small_axes) + small_axes = Axis(n, "ax1") + big_axes = AxisTree( + Axis([AxisComponent(m, "pt0")], "ax0", id="root"), {"root": small_axes} + ) + + big = MultiArray( + big_axes, name="big", data=np.arange(big_axes.size, dtype=ScalarType) + ) + small = MultiArray(small_axes, name="small", dtype=big.dtype) # The following is equivalent to # for p in big_axes.root: @@ -47,9 +63,72 @@ def test_inc_into_small_array(scalar_copy_kernel): do_loop( p := big_axes.root.index(), loop( - q := big_axes[p, :].enumerate(), - scalar_copy_kernel(big[p, q[1]], small[q[0]]), + q := small_axes.enumerate(), + scalar_inc_kernel(big[p, q.global_index], small[q.local_index]), ), ) - assert False, "TODO" + expected = np.zeros(n) + for i in range(m): + for j in range(n): + expected[j] += big.data.reshape((m, n))[i, j] + assert np.allclose(small.data, expected) + + +# TODO this does not belong in this test file +def test_copy_offset(scalar_copy_kernel_int): + m = 10 + axes = Axis(m) + out = MultiArray(axes, name="out", dtype=IntType) + + # do_loop(p := axes.index(), scalar_copy_kernel_int(axes(p), out[p])) + # debug + from pyrsistent import pmap + + from pyop3.index import IndexTree, SplitIndexTree, SplitLoopIndex + + p = axes.index() + path = pmap({axes.label: axes.components[0].label}) + itree = SplitIndexTree({pmap({p: path}): IndexTree(SplitLoopIndex(p, path))}) + l = loop(p, scalar_copy_kernel_int(axes(itree), out[p])) + l() + assert np.allclose(out.data, np.arange(10)) + + +# TODO this does not belong in this test file +def test_copy_vec_offset(scalar_copy_kernel_int): + m, n = 10, 3 + # axes = AxisTree(Axis(m, id="root"), {"root": Axis(n)}) + axes = AxisTree( + Axis([AxisComponent(m, "pt0")], "ax0", id="root"), + {"root": Axis([AxisComponent(n, "pt0")], "ax1")}, + ) + + out = MultiArray(axes.root, name="out", dtype=IntType) + + # do_loop(p := axes.root.index(), scalar_copy_kernel(axes(p, 0), out[p])) + from pyrsistent import pmap + + from pyop3.index import ( + AffineSliceComponent, + IndexTree, + Slice, + SplitIndexTree, + SplitLoopIndex, + ) + + p = axes.root.index() + path = pmap({"ax0": "pt0"}) + # i.e. [p, 0] + itree = SplitIndexTree( + { + pmap({p: path}): IndexTree( + root := SplitLoopIndex(p, path), + {root.id: Slice("ax1", [AffineSliceComponent("pt0", 0, 1)])}, + ) + } + ) + l = loop(p, scalar_copy_kernel_int(axes(itree), out[p])) + + l() + assert np.allclose(out.data, np.arange(m * n, step=n))