Skip to content

Commit

Permalink
Add broken test, need to handle indexed maps
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Dec 12, 2023
1 parent 1834ee1 commit 6387448
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
3 changes: 3 additions & 0 deletions pyop3/ir/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,9 @@ def _scalar_assignment(
# Register data
ctx.add_argument(array)

if array.index_exprs != array.axes._default_index_exprs():
raise NotImplementedError

offset_expr = make_offset_expr(
array.layouts[path],
array_labels_to_jnames,
Expand Down
53 changes: 48 additions & 5 deletions tests/integration/test_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@ def vector_inc_kernel():
return op3.Function(lpy_kernel, [op3.READ, op3.INC])


# TODO make a function not a fixture
@pytest.fixture
def vector2_inc_kernel():
lpy_kernel = lp.make_kernel(
"{ [i]: 0 <= i < 2 }",
"y[0] = y[0] + x[i]",
[
lp.GlobalArg("x", op3.ScalarType, (2,), is_input=True, is_output=False),
lp.GlobalArg("y", op3.ScalarType, (1,), is_input=True, is_output=True),
],
name="vector_inc",
target=LOOPY_TARGET,
lang_version=LOOPY_LANG_VERSION,
)
return op3.Function(lpy_kernel, [op3.READ, op3.INC])


@pytest.fixture
def vec2_inc_kernel():
lpy_kernel = lp.make_kernel(
Expand Down Expand Up @@ -73,7 +90,10 @@ def vec12_inc_kernel():


@pytest.mark.parametrize("nested", [True, False])
def test_inc_from_tabulated_map(scalar_inc_kernel, vector_inc_kernel, nested):
@pytest.mark.parametrize("indexed", [None, "slice", "subset"])
def test_inc_from_tabulated_map(
scalar_inc_kernel, vector_inc_kernel, vector2_inc_kernel, nested, indexed
):
m, n = 4, 3
map_data = np.asarray([[1, 2, 0], [2, 0, 1], [3, 2, 3], [2, 0, 1]])

Expand All @@ -83,13 +103,29 @@ def test_inc_from_tabulated_map(scalar_inc_kernel, vector_inc_kernel, nested):
)
dat1 = op3.HierarchicalArray(axis, name="dat1", dtype=dat0.dtype)

map_axes = op3.AxisTree.from_nest({axis: op3.Axis(n)})
map_axes = op3.AxisTree.from_nest({axis: op3.Axis({"pt0": n}, "ax1")})
map_dat = op3.HierarchicalArray(
map_axes,
name="map0",
data=map_data.flatten(),
dtype=op3.IntType,
)

if indexed == "slice":
map_dat = map_dat[:, :2]
kernel = vector2_inc_kernel
elif indexed == "subset":
subset_ = op3.HierarchicalArray(
op3.Axis({"pt0": 2}, "ax1"),
name="subset",
data=np.asarray([1, 2]),
dtype=op3.IntType,
)
map_dat = map_dat[:, subset_]
kernel = vector2_inc_kernel
else:
kernel = vector_inc_kernel

map0 = op3.Map(
{
pmap({"ax0": "pt0"}): [
Expand All @@ -105,12 +141,19 @@ def test_inc_from_tabulated_map(scalar_inc_kernel, vector_inc_kernel, nested):
op3.loop(q := map0(p).index(), scalar_inc_kernel(dat0[q], dat1[p])),
)
else:
op3.do_loop(p := axis.index(), vector_inc_kernel(dat0[map0(p)], dat1[p]))
op3.do_loop(p := axis.index(), kernel(dat0[map0(p)], dat1[p]))

expected = np.zeros_like(dat1.data_ro)
for i in range(m):
for j in range(n):
expected[i] += dat0.data_ro[map_data[i, j]]
if indexed == "slice":
for j in range(2):
expected[i] += dat0.data_ro[map_data[i, j]]
elif indexed == "subset":
for j in [1, 2]:
expected[i] += dat0.data_ro[map_data[i, j]]
else:
for j in range(n):
expected[i] += dat0.data_ro[map_data[i, j]]
assert np.allclose(dat1.data_ro, expected)


Expand Down

0 comments on commit 6387448

Please sign in to comment.