Skip to content

Commit

Permalink
Axis offset node roughly works
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Aug 16, 2023
1 parent e363acf commit 67dd221
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 32 deletions.
30 changes: 30 additions & 0 deletions pyop3/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,9 @@ def __init__(
def __getitem__(self, indices):
    """Index this object by first converting it with ``as_axis_tree``.

    The subscript is forwarded unchanged to the converted object.
    """
    tree = as_axis_tree(self)
    return tree[indices]

def __call__(self, *args):
    """Call the ``as_axis_tree`` conversion of this object.

    All positional arguments are forwarded unchanged.
    """
    tree = as_axis_tree(self)
    return tree(*args)

def __str__(self) -> str:
    """Return ``ClassName([cpt, ...], label=<label>)`` for this object.

    Each entry of ``self.components`` is rendered with ``str``.
    """
    cls_name = self.__class__.__name__
    rendered_cpts = ', '.join(str(cpt) for cpt in self.components)
    return f"{cls_name}([{rendered_cpts}], label={self.label})"

Expand All @@ -578,6 +581,30 @@ def count(self):
def index(self):
    """Return ``index()`` of the ``as_axis_tree`` conversion of this object."""
    tree = as_axis_tree(self)
    return tree.index()

def enumerate(self):
    """Return ``enumerate()`` of the ``as_axis_tree`` conversion of this object."""
    tree = as_axis_tree(self)
    return tree.enumerate()


class CalledAxisTree:
    """Stand-in for an array that represents the offset of an axis at an index.

    This is supposed to be used in place of an array: it exposes ``name``,
    ``dtype`` and ``datamap`` attributes and stores the axes and indices it
    was built from.

    Parameters
    ----------
    axes
        The axes being called. Must provide a ``datamap`` attribute.
    indices
        The indices the axes are called with. Must provide a ``datamap``
        attribute.
    """

    def __init__(self, axes, indices):
        self.axes = axes
        self.indices = indices

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}"
            f"(axes={self.axes!r}, indices={self.indices!r})"
        )

    # FIXME: the name is a fixed placeholder, so every instance reports the
    # same name.
    @property
    def name(self):
        """Placeholder name shared by all instances."""
        return "my_called_axis"

    @property
    def dtype(self):
        """The data type of the represented offsets (always ``IntType``)."""
        return IntType

    @functools.cached_property
    def datamap(self):
        """The ``datamap`` of the axes merged with that of the indices.

        Entries from ``indices`` take precedence on key collisions because
        they form the right-hand operand of ``|``.
        """
        return self.axes.datamap | self.indices.datamap


class AxisTree(LabelledTree):
def __init__(
Expand All @@ -602,6 +629,9 @@ def __getitem__(self, indices):

return IndexedAxisTree(self, indices)

def __call__(self, *args):
    """Wrap this tree and the given arguments in a ``CalledAxisTree``.

    NOTE(review): ``CalledAxisTree.__init__`` accepts exactly
    ``(axes, indices)``, so calling this with anything other than a single
    positional argument raises ``TypeError`` - confirm whether multiple
    indices (e.g. ``axes(p, 0)``) are meant to be supported.
    """
    called = CalledAxisTree(self, *args)
    return called

def index(self):
# cyclic import
from pyop3.index import GlobalLoopIndex
Expand Down
66 changes: 47 additions & 19 deletions pyop3/codegen/loopexpr2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from pyrsistent import pmap

from pyop3 import tlang, utils
from pyop3.axis import AffineLayout, Axis, AxisComponent, AxisTree, TabulatedLayout
from pyop3.axis import (
AffineLayout,
Axis,
AxisComponent,
AxisTree,
CalledAxisTree,
TabulatedLayout,
)
from pyop3.distarray import IndexedMultiArray, MultiArray
from pyop3.dtypes import IntType
from pyop3.index import (
Expand Down Expand Up @@ -544,7 +551,9 @@ def _(call: FunctionCall, loop_indices, ctx: LoopyCodegenContext) -> None:
temporaries.append((arg, indexed_temp, spec.access, shape))

# Register data
ctx.add_argument(arg.name, arg.dtype)
if not isinstance(arg, CalledAxisTree):
ctx.add_argument(arg.name, arg.dtype)

ctx.add_temporary(temporary.name, temporary.dtype, shape)

# subarrayref nonsense/magic
Expand Down Expand Up @@ -665,10 +674,13 @@ def build_assignment(

# unroll the index trees, this should be tidied up
array = assignment.array
itrees = []
while isinstance(array, Indexed):
itrees.insert(0, array.itree)
array = array.obj
if isinstance(array, CalledAxisTree):
itrees = [array.indices]
else:
itrees = []
while isinstance(array, Indexed):
itrees.insert(0, array.itree)
array = array.obj

for indices in itrees:
# get the right index tree given the loop context
Expand Down Expand Up @@ -733,6 +745,10 @@ def _prepare_assignment_rec(
jnames: pmap,
ctx: LoopyCodegenContext,
) -> tuple[pmap, pmap]:
# catch empty axis trees
# if axes.is_empty:
# return pmap(), pmap({None: 0}), pmap({None: ()})

jnames_per_axcpt = {}
insns_per_leaf = {}
array_expr_per_leaf = {}
Expand All @@ -756,9 +772,15 @@ def _prepare_assignment_rec(
insns_per_leaf |= subinsns_per_leaf
array_expr_per_leaf |= subarray_expr_per_leaf
else:
insns, array_expr = _assignment_array_insn(
assignment, axes, new_path, new_jnames, ctx
)
if isinstance(assignment.array, CalledAxisTree):
# just the offset instructions here, no subscript
insns, array_expr = emit_assignment_insn(
axes, new_path, new_jnames, ctx
)
else:
insns, array_expr = _assignment_array_insn(
assignment, axes, new_path, new_jnames, ctx
)
insns_per_leaf[axis.id, axcpt.label] = insns
array_expr_per_leaf[axis.id, axcpt.label] = array_expr

Expand Down Expand Up @@ -901,8 +923,7 @@ def _parse_assignment_final(
insns_per_leaf,
ctx: LoopyCodegenContext,
):
# catch empty axes here
if not axes.root:
if not axes.root: # catch empty axes here
for insn in insns_per_leaf[None]:
ctx.add_assignment(*insn)
array_expr = array_expr_per_leaf[None]
Expand Down Expand Up @@ -964,11 +985,18 @@ def _parse_assignment_final_rec(
for insn in insns_per_leaf[axis.id, axcpt.label]:
ctx.add_assignment(*insn)
array_expr = array_expr_per_leaf[axis.id, axcpt.label]
temp_insns, temp_expr = _assignment_temp_insn(
assignment, new_path, new_jnames, ctx
)
for insn in temp_insns:
ctx.add_assignment(*insn)
if isinstance(assignment.array, CalledAxisTree):
temp_insns, temp_expr = _assignment_temp_insn(
assignment, pmap(), pmap(), ctx
)
for insn in temp_insns:
ctx.add_assignment(*insn)
else:
temp_insns, temp_expr = _assignment_temp_insn(
assignment, new_path, new_jnames, ctx
)
for insn in temp_insns:
ctx.add_assignment(*insn)
_shared_assignment_insn(assignment, array_expr, temp_expr, ctx)


Expand Down Expand Up @@ -1240,7 +1268,6 @@ def _assignment_array_insn(assignment, axes, path, jnames, ctx):
"""
offset_insns, array_offset = emit_assignment_insn(
assignment.array.name,
axes,
path,
jnames,
Expand All @@ -1260,7 +1287,6 @@ def _assignment_temp_insn(assignment, path, jnames, ctx):
"""
offset_insns, temp_offset = emit_assignment_insn(
assignment.temporary.name,
assignment.temporary.axes,
path,
jnames,
Expand Down Expand Up @@ -1298,7 +1324,6 @@ def _shared_assignment_insn(assignment, array_expr, temp_expr, ctx):


def emit_assignment_insn(
array_name,
axes,
path,
labels_to_jnames,
Expand Down Expand Up @@ -1605,6 +1630,9 @@ def _indexed_axes(indexed, loop_indices):
axis tree. No instructions need be emitted.
"""
# offsets are always scalar
if isinstance(indexed, CalledAxisTree):
return AxisTree()

# get the right index tree given the loop context
loop_context = {}
Expand Down
4 changes: 1 addition & 3 deletions pyop3/loopexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,7 @@ class Terminal(LoopExpr):
class FunctionCall(Terminal):
def __init__(self, function, arguments):
self.function = function
self.arguments = tuple(
arg if isinstance(arg, Indexed) else arg[...] for arg in arguments
)
self.arguments = arguments

@functools.cached_property
def datamap(self) -> dict[str, DistributedArray]:
Expand Down
34 changes: 33 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import loopy as lp
import pytest

from pyop3 import READ, WRITE, LoopyKernel, ScalarType
from pyop3 import INC, READ, WRITE, IntType, LoopyKernel, ScalarType
from pyop3.codegen import LOOPY_LANG_VERSION, LOOPY_TARGET


Expand All @@ -19,3 +19,35 @@ def scalar_copy_kernel():
lang_version=LOOPY_LANG_VERSION,
)
return LoopyKernel(code, [READ, WRITE])


@pytest.fixture
def scalar_copy_kernel_int():
    """Fixture: a one-element ``IntType`` copy kernel, ``y[i] = x[i]``.

    ``x`` is declared input-only and ``y`` output-only; the kernel is wrapped
    with access descriptors ``[READ, WRITE]``.
    """
    kernel_args = [
        lp.GlobalArg("x", IntType, (1,), is_input=True, is_output=False),
        lp.GlobalArg("y", IntType, (1,), is_input=False, is_output=True),
    ]
    knl = lp.make_kernel(
        "{ [i]: 0 <= i < 1 }",
        "y[i] = x[i]",
        kernel_args,
        name="scalar_copy_int",
        target=LOOPY_TARGET,
        lang_version=LOOPY_LANG_VERSION,
    )
    return LoopyKernel(knl, [READ, WRITE])


@pytest.fixture
def scalar_inc_kernel():
    """Fixture: a one-element ``ScalarType`` increment kernel, ``y[i] += x[i]``.

    ``x`` is declared input-only while ``y`` is both input and output; the
    kernel is wrapped with access descriptors ``[READ, INC]``.
    """
    kernel_args = [
        lp.GlobalArg("x", ScalarType, (1,), is_input=True, is_output=False),
        lp.GlobalArg("y", ScalarType, (1,), is_input=True, is_output=True),
    ]
    knl = lp.make_kernel(
        "{ [i]: 0 <= i < 1 }",
        "y[i] = y[i] + x[i]",
        kernel_args,
        name="scalar_inc",
        target=LOOPY_TARGET,
        lang_version=LOOPY_LANG_VERSION,
    )
    return LoopyKernel(knl, [READ, INC])
97 changes: 88 additions & 9 deletions tests/integration/test_local_indices.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import numpy as np
import pytest

from pyop3 import Axis, AxisComponent, AxisTree, MultiArray, ScalarType, do_loop
from pyop3 import (
Axis,
AxisComponent,
AxisTree,
IntType,
MultiArray,
ScalarType,
do_loop,
loop,
)


def test_copy_with_local_indices(scalar_copy_kernel):
Expand Down Expand Up @@ -33,12 +42,19 @@ def test_copy_slice(scalar_copy_kernel):
assert np.allclose(array1.data, array0.data[::2])


def test_inc_into_small_array(scalar_copy_kernel):
size = 10
dim = 3
# this isn't a very meaningful test since the local and global loop indices are identical
def test_inc_into_small_array(scalar_inc_kernel):
m, n = 10, 3

big = MultiArray(big_axes)
small = MultiArray(small_axes)
small_axes = Axis(n, "ax1")
big_axes = AxisTree(
Axis([AxisComponent(m, "pt0")], "ax0", id="root"), {"root": small_axes}
)

big = MultiArray(
big_axes, name="big", data=np.arange(big_axes.size, dtype=ScalarType)
)
small = MultiArray(small_axes, name="small", dtype=big.dtype)

# The following is equivalent to
# for p in big_axes.root:
Expand All @@ -47,9 +63,72 @@ def test_inc_into_small_array(scalar_copy_kernel):
do_loop(
p := big_axes.root.index(),
loop(
q := big_axes[p, :].enumerate(),
scalar_copy_kernel(big[p, q[1]], small[q[0]]),
q := small_axes.enumerate(),
scalar_inc_kernel(big[p, q.global_index], small[q.local_index]),
),
)

assert False, "TODO"
expected = np.zeros(n)
for i in range(m):
for j in range(n):
expected[j] += big.data.reshape((m, n))[i, j]
assert np.allclose(small.data, expected)


# TODO this does not belong in this test file
def test_copy_offset(scalar_copy_kernel_int):
    """Copy the offset of each point of a flat axis into an integer array.

    For a single axis of size ``m`` the offsets are expected to be
    ``0, 1, ..., m - 1``.
    """
    m = 10
    axes = Axis(m)
    out = MultiArray(axes, name="out", dtype=IntType)

    # TODO: the intended spelling is
    #   do_loop(p := axes.index(), scalar_copy_kernel_int(axes(p), out[p]))
    # Until that works the index tree is assembled by hand (debug code).
    from pyrsistent import pmap

    from pyop3.index import IndexTree, SplitIndexTree, SplitLoopIndex

    p = axes.index()
    path = pmap({axes.label: axes.components[0].label})
    itree = SplitIndexTree({pmap({p: path}): IndexTree(SplitLoopIndex(p, path))})
    copy_loop = loop(p, scalar_copy_kernel_int(axes(itree), out[p]))
    copy_loop()

    # compare against the axis size rather than a duplicated literal
    assert np.allclose(out.data, np.arange(m))


# TODO this does not belong in this test file
def test_copy_vec_offset(scalar_copy_kernel_int):
    """Copy the offset of entry ``[p, 0]`` of a two-level axis tree.

    The tree has an outer axis of size ``m`` and an inner axis of size ``n``;
    the expected offsets are ``0, n, 2n, ...`` (``m`` values in total).
    """
    m, n = 10, 3
    outer_axis = Axis([AxisComponent(m, "pt0")], "ax0", id="root")
    axes = AxisTree(outer_axis, {"root": Axis([AxisComponent(n, "pt0")], "ax1")})

    out = MultiArray(axes.root, name="out", dtype=IntType)

    # Intended spelling:
    #   do_loop(p := axes.root.index(), scalar_copy_kernel(axes(p, 0), out[p]))
    from pyrsistent import pmap

    from pyop3.index import (
        AffineSliceComponent,
        IndexTree,
        Slice,
        SplitIndexTree,
        SplitLoopIndex,
    )

    p = axes.root.index()
    path = pmap({"ax0": "pt0"})

    # hand-built index tree for [p, 0]
    loop_context = pmap({p: path})
    root = SplitLoopIndex(p, path)
    children = {root.id: Slice("ax1", [AffineSliceComponent("pt0", 0, 1)])}
    itree = SplitIndexTree({loop_context: IndexTree(root, children)})

    offset_loop = loop(p, scalar_copy_kernel_int(axes(itree), out[p]))
    offset_loop()

    expected = np.arange(m * n, step=n)
    assert np.allclose(out.data, expected)

0 comments on commit 67dd221

Please sign in to comment.