Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better CopyToMap #1675

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
90 changes: 68 additions & 22 deletions dace/transformation/dataflow/copy_to_map.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.

from dace import dtypes, symbolic, data, subsets, Memlet
from dace import dtypes, symbolic, data, subsets, Memlet, properties
from dace.sdfg.scope import is_devicelevel_gpu
from dace.transformation import transformation as xf
from dace.sdfg import SDFGState, SDFG, nodes, utils as sdutil
from typing import Tuple
import itertools


@properties.make_properties
class CopyToMap(xf.SingleStateTransformation):
"""
Converts an access node -> access node copy into a map. Useful for generating manual code and
controlling schedules for N-dimensional strided copies.
"""
a = xf.PatternNode(nodes.AccessNode)
b = xf.PatternNode(nodes.AccessNode)
ignore_strides = properties.Property(
default=False,
desc='Ignore the stride of the data container; Defaults to `False`.',
)

@classmethod
def expressions(cls):
Expand All @@ -31,7 +36,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
if isinstance(self.b.desc(sdfg), data.View):
if sdutil.get_view_node(graph, self.b) == self.a:
return False
if self.a.desc(sdfg).strides == self.b.desc(sdfg).strides:
if (not self.ignore_strides) and self.a.desc(sdfg).strides == self.b.desc(sdfg).strides:
return False
# Ensures that the edge goes from `a` -> `b`.
if not any(edge.dst is self.b for edge in graph.out_edges(self.a)):
return False

return True
Expand Down Expand Up @@ -62,31 +70,69 @@ def delinearize_linearize(self, desc: data.Array, copy_shape: Tuple[symbolic.Sym
return subsets.Range([(ind, ind, 1) for ind in cur_index])

def apply(self, state: SDFGState, sdfg: SDFG):
adesc = self.a.desc(sdfg)
bdesc = self.b.desc(sdfg)
edge = state.edges_between(self.a, self.b)[0]
avnode = self.a
av = avnode.data
adesc = avnode.desc(sdfg)
bvnode = self.b
bv = bvnode.data
bdesc = bvnode.desc(sdfg)

edge = state.edges_between(avnode, bvnode)[0]
src_subset = edge.data.get_src_subset(edge, state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we may need to call the try_initialize method before attempting to get src/dst subsets here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am actually not sure.
I was hit by this a few times in the past, so I am in favour of keeping it.
However, the main reason this function is used (I think I "discovered" it here) is that the original code already uses it (see old line 70).
For that reason I decided to keep it.
But if you think we should remove that I will not object.

Copy link
Contributor

@alexnick83 alexnick83 Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, should we also add a call to the method that tries to initialize the edge/memlet, to ensure that the src/dst subsets are not None when they are not supposed to be?

if src_subset is None:
src_subset = subsets.Range.from_array(adesc)
src_subset_size = src_subset.size()
red_src_subset_size = tuple(s for s in src_subset_size if s != 1)

dst_subset = edge.data.get_dst_subset(edge, state)
if dst_subset is None:
dst_subset = subsets.Range.from_array(bdesc)
dst_subset_size = dst_subset.size()
red_dst_subset_size = tuple(s for s in dst_subset_size if s != 1)

if len(adesc.shape) >= len(bdesc.shape):
copy_shape = edge.data.get_src_subset(edge, state).size()
copy_shape = src_subset_size
copy_a = True
else:
copy_shape = edge.data.get_dst_subset(edge, state).size()
copy_shape = dst_subset_size
copy_a = False

maprange = {f'__i{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)}

av = self.a.data
bv = self.b.data
avnode = self.a
bvnode = self.b

# Linearize and delinearize to get index expression for other side
if copy_a:
a_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))]
b_index = self.delinearize_linearize(bdesc, copy_shape, edge.data.get_dst_subset(edge, state))
if tuple(src_subset_size) == tuple(dst_subset_size):
# The two subsets have exactly the same shape, so we can just copy with an offset.
# We use different index variables (`__j*`) only so that the tests can identify this case.
maprange = {f'__j{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)}
a_index = [symbolic.pystr_to_symbolic(f'__j{i} + ({src_subset[i][0]})') for i in range(len(copy_shape))]
b_index = [symbolic.pystr_to_symbolic(f'__j{i} + ({dst_subset[i][0]})') for i in range(len(copy_shape))]
elif red_src_subset_size == red_dst_subset_size and (len(red_dst_subset_size) > 0):
# If we remove all size 1 dimensions, the two subsets have the same size.
# This is essentially the memlet `a[0:10, 2, 0:10] -> 0:10, 10:20`
# We use another index variable only for the tests but we would have to
# recreate the index anyways.
maprange = {f'__j{i}': (0, s - 1, 1) for i, s in enumerate(red_src_subset_size)}
cnt = itertools.count(0)
a_index = [
symbolic.pystr_to_symbolic(f'{src_subset[i][0]}')
if s == 1
else symbolic.pystr_to_symbolic(f'__j{next(cnt)} + ({src_subset[i][0]})')
for i, s in enumerate(src_subset_size)
]
cnt = itertools.count(0)
b_index = [
symbolic.pystr_to_symbolic(f'{dst_subset[i][0]}')
if s == 1
else symbolic.pystr_to_symbolic(f'__j{next(cnt)} + ({dst_subset[i][0]})')
for i, s in enumerate(dst_subset_size)
]
else:
a_index = self.delinearize_linearize(adesc, copy_shape, edge.data.get_src_subset(edge, state))
b_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))]
# We have to delinearize and linearize
# We use another index variable for the tests.
maprange = {f'__i{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)}
if copy_a:
a_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))]
b_index = self.delinearize_linearize(bdesc, copy_shape, edge.data.get_dst_subset(edge, state))
else:
a_index = self.delinearize_linearize(adesc, copy_shape, edge.data.get_src_subset(edge, state))
b_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))]

a_subset = subsets.Range([(ind, ind, 1) for ind in a_index])
b_subset = subsets.Range([(ind, ind, 1) for ind in b_index])
Expand All @@ -101,7 +147,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
schedule = dtypes.ScheduleType.GPU_Device

# Add copy map
t, _, _ = state.add_mapped_tasklet('copy',
t, _, _ = state.add_mapped_tasklet(f'copy_{av}_{bv}',
maprange,
dict(__inp=Memlet(data=av, subset=a_subset)),
'__out = __inp',
Expand Down
164 changes: 161 additions & 3 deletions tests/transformations/copy_to_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import copy
import pytest
import numpy as np
import re
from typing import Tuple, Optional


def _copy_to_map(storage: dace.StorageType):
Expand Down Expand Up @@ -102,9 +104,165 @@ def test_preprocess():
assert np.allclose(out, inp)


def _perform_non_lin_delin_test(
    sdfg: dace.SDFG,
) -> bool:
    """Check the `CopyToMap` special case that bypasses linearizing and delinearizing.

    The SDFG must consist of a single state that holds exactly two access nodes
    joined by one edge, and it must be callable with the arguments `a` and `b`.
    It is executed once before and once after `CopyToMap` is applied (with
    `ignore_strides=True`), and the two results for `b` are compared.

    :param sdfg: The SDFG to test; the transformation is applied to it.
    :return: `True` if all checks passed; a failed check raises `AssertionError`.
    """
    assert sdfg.number_of_nodes() == 1
    state: dace.SDFGState = sdfg.states()[0]
    assert state.number_of_nodes() == 2
    assert state.number_of_edges() == 1
    assert all(isinstance(node, dace.nodes.AccessNode) for node in state.nodes())
    sdfg.validate()

    # Both output buffers start from the same random content, so regions of `b`
    # that the copy does not write compare equal as well.
    a = np.random.rand(*sdfg.arrays["a"].shape)
    b_unopt = np.random.rand(*sdfg.arrays["b"].shape)
    b_opt = b_unopt.copy()
    sdfg(a=a, b=b_unopt)

    nb_runs = sdfg.apply_transformations_repeated(CopyToMap, validate=True, options={"ignore_strides": True})
    assert nb_runs == 1, f"Expected 1 application, but {nb_runs} were performed."

    # Now look for the tasklet and check that its memlets follow the expected
    # simple pattern: `__jN`, `__jN + <offset>`, or a plain integer index.
    tasklet: dace.nodes.Tasklet = next(node for node in state.nodes() if isinstance(node, dace.nodes.Tasklet))
    # Every alternative accepts multi-digit map indices (e.g. `__j10`).
    pattern: re.Pattern = re.compile(r"(__j[0-9]+)|(__j[0-9]+\s*\+\s*[0-9]+)|([0-9]+)")

    assert state.in_degree(tasklet) == 1
    assert state.out_degree(tasklet) == 1
    in_edge = next(iter(state.in_edges(tasklet)))
    out_edge = next(iter(state.out_edges(tasklet)))

    assert all(pattern.fullmatch(str(idxs[0]).strip())
               for idxs in in_edge.data.src_subset), f"IN: {in_edge.data.src_subset}"
    assert all(pattern.fullmatch(str(idxs[0]).strip())
               for idxs in out_edge.data.dst_subset), f"OUT: {out_edge.data.dst_subset}"

    # Now call it again after the optimization.
    sdfg(a=a, b=b_opt)
    assert np.allclose(b_unopt, b_opt)

    return True

def _make_non_lin_delin_sdfg(
    shape_a: Tuple[int, ...],
    shape_b: Optional[Tuple[int, ...]] = None
) -> Tuple[dace.SDFG, dace.SDFGState, dace.nodes.AccessNode, dace.nodes.AccessNode]:
    """Build a one-state SDFG with the non-transient float64 arrays `a` and `b`.

    No edge is added between the two access nodes; the caller does that.

    :param shape_a: Shape of array `a`.
    :param shape_b: Shape of array `b`; if `None`, `shape_a` is used.
    :return: The SDFG, its state, and the access nodes for `a` and `b` (in that order).
    """
    shapes = {'a': shape_a, 'b': shape_a if shape_b is None else shape_b}

    sdfg = dace.SDFG("bypass1")
    state = sdfg.add_state(is_start_block=True)

    def _declare(name: str) -> dace.nodes.AccessNode:
        # Register the container first, then create its access node in the state.
        sdfg.add_array(name=name, shape=shapes[name], dtype=dace.float64, transient=False)
        return state.add_access(name)

    access_a = _declare('a')
    access_b = _declare('b')
    return sdfg, state, access_a, access_b


def test_non_lin_delin_1():
    """Copy the full range of a (10, 10) array into an array of the same shape."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((10, 10))
    memlet = dace.Memlet("a[0:10, 0:10] -> [0:10, 0:10]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)

def test_non_lin_delin_2():
    """Copy a full (10, 10) array into the window [50:60, 40:50] of a (100, 100) array."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((10, 10), (100, 100))
    memlet = dace.Memlet("a[0:10, 0:10] -> [50:60, 40:50]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


def test_non_lin_delin_3():
    """Copy the window [1:11, 20:30] of `a` into the window [50:60, 40:50] of `b`."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((100, 100), (100, 100))
    memlet = dace.Memlet("a[1:11, 20:30] -> [50:60, 40:50]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


def test_non_lin_delin_4():
    """Copy a 3D source subset with a fixed middle index into a 2D destination window."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((100, 4, 100), (100, 100))
    memlet = dace.Memlet("a[1:11, 2, 20:30] -> [50:60, 40:50]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


def test_non_lin_delin_5():
    """Copy between two 3D arrays where both subsets fix the middle index."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((100, 4, 100), (100, 10, 100))
    memlet = dace.Memlet("a[1:11, 2, 20:30] -> [50:60, 4, 40:50]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


def test_non_lin_delin_6():
    """Copy a 2D source window into a 3D destination subset with a fixed middle index."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((100, 100), (100, 10, 100))
    memlet = dace.Memlet("a[1:11, 20:30] -> [50:60, 4, 40:50]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


def test_non_lin_delin_7():
    """Use a memlet that names only the destination side: `b[5:15, 6:16]`."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((10, 10), (20, 20))
    memlet = dace.Memlet("b[5:15, 6:16]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


def test_non_lin_delin_8():
    """Use a memlet that names only the source side: `a[5:15, 6:16]`."""
    sdfg, state, src_node, dst_node = _make_non_lin_delin_sdfg((20, 20), (10, 10))
    memlet = dace.Memlet("a[5:15, 6:16]")
    state.add_nedge(src_node, dst_node, memlet)
    _perform_non_lin_delin_test(sdfg)


if __name__ == '__main__':
test_non_lin_delin_1()
test_non_lin_delin_2()
test_non_lin_delin_3()
test_non_lin_delin_4()
test_non_lin_delin_5()
test_non_lin_delin_6()
test_non_lin_delin_7()
test_non_lin_delin_8()

test_copy_to_map()
test_copy_to_map_gpu()
test_flatten_to_map()
test_flatten_to_map_gpu()
test_preprocess()
try:
import cupy
test_copy_to_map_gpu()
test_flatten_to_map_gpu()
test_preprocess()
except ModuleNotFoundError as E:
if "'cupy'" not in str(E):
raise
Loading