Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to Read and Write Sets #1678

Merged
merged 19 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
145c0ea
Added tests for the `_read_and_write_sets()`.
philip-paul-mueller Oct 11, 2024
38e748b
Added the fix from my MapFusion PR.
philip-paul-mueller Oct 11, 2024
3748c03
Now made `read_and_write_sets()` fully adhere to their own definition.
philip-paul-mueller Oct 11, 2024
3ab4bf3
Updated a test for the `PruneConnectors` transformation.
philip-paul-mueller Oct 11, 2024
b4feddf
Added code to `test_more_than_a_map` to ensure that the transformatio…
philip-paul-mueller Oct 11, 2024
e1c25b2
Merge remote-tracking branch 'spcl/master' into read-write-sets
philip-paul-mueller Oct 14, 2024
70fa3db
Added the new memlet creation syntax.
philip-paul-mueller Oct 14, 2024
b187a82
Modified some comments to make them clearer.
philip-paul-mueller Oct 14, 2024
9c6cb6c
Modified the `tests/transformations/move_loop_into_map_test.py::test_…
philip-paul-mueller Oct 14, 2024
b5fc16f
Merge branch 'master' into read-write-sets
philip-paul-mueller Oct 22, 2024
b7fe242
Added a test to highlight the error.
philip-paul-mueller Oct 22, 2024
b546b07
I now removed the filtering inside the read and write set.
philip-paul-mueller Oct 22, 2024
ae20590
Fixed `state_test.py::test_read_and_write_set_filter`.
philip-paul-mueller Oct 23, 2024
db211fa
Fixed the `state_test.py::test_read_write_set` test.
philip-paul-mueller Oct 23, 2024
570437b
Fixed the `state_test.py::test_read_write_set_y_formation` test.
philip-paul-mueller Oct 23, 2024
cb80f0b
Fixed `move_loop_into_map_test.py::MoveLoopIntoMapTest::test_more_tha…
philip-paul-mueller Oct 23, 2024
b704a43
Fixed `prune_connectors_test.py::test_read_write_*`.
philip-paul-mueller Oct 23, 2024
f74d6e8
General improvements to some tests.
philip-paul-mueller Oct 23, 2024
e103924
Updated `refine_nested_access_test.py::test_rna_read_and_write_sets_d…
philip-paul-mueller Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 67 additions & 36 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,51 +745,82 @@ def update_if_not_none(dic, update):

return defined_syms


def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, List[Subset]]]:
"""
Determines what data is read and written in this subgraph, returning
dictionaries from data containers to all subsets that are read/written.
"""
from dace.sdfg import utils # Avoid cyclic import

# Ensures that the `{src,dst}_subset` are properly set.
# TODO: find where the problems are
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very ugly hack, but I can replicate / encounter the issue too. I am fine with leaving this in, but please make sure there is an issue that keeps track of this TODO somewhere.

for edge in self.edges():
edge.data.try_initialize(self.sdfg, self, edge)

read_set = collections.defaultdict(list)
write_set = collections.defaultdict(list)
from dace.sdfg import utils # Avoid cyclic import
subgraphs = utils.concurrent_subgraphs(self)
for sg in subgraphs:
rs = collections.defaultdict(list)
ws = collections.defaultdict(list)
# Traverse in topological order, so data that is written before it
# is read is not counted in the read set
for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()):
if isinstance(n, nd.AccessNode):
in_edges = sg.in_edges(n)
out_edges = sg.out_edges(n)
# Filter out memlets which go out but the same data is written to the AccessNode by another memlet
for out_edge in list(out_edges):
for in_edge in list(in_edges):
if (in_edge.data.data == out_edge.data.data
and in_edge.data.dst_subset.covers(out_edge.data.src_subset)):
out_edges.remove(out_edge)
break

for e in in_edges:
# skip empty memlets
if e.data.is_empty():
continue
# Store all subsets that have been written
ws[n.data].append(e.data.subset)
for e in out_edges:
# skip empty memlets
if e.data.is_empty():
continue
rs[n.data].append(e.data.subset)
# Union all subgraphs, so an array that was excluded from the read
# set because it was written first is still included if it is read
# in another subgraph
for data, accesses in rs.items():

# NOTE: In a previous version a _single_ read (i.e. leaving Memlet) that was
# fully covered by a single write (i.e. an incoming Memlet) was removed from
# the read set and only the write survived. However, this was never fully
# implemented nor correctly implemented and caused problems.
# So this filtering was removed.

for subgraph in utils.concurrent_subgraphs(self):
subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph.
subgraph_write_set = collections.defaultdict(list)
for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()):
if not isinstance(n, nd.AccessNode):
# Read and writes can only be done through access nodes,
# so ignore every other node.
continue

# Get a list of all incoming (writes) and outgoing (reads) edges of the
# access node, ignore all empty memlets as they do not carry any data.
in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()]
out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()]

# Extract the subsets that describes where we read and write the data
# and store them for the later filtering.
# NOTE: In certain cases the corresponding subset might be None, in this case
# we assume that the whole array is written, which is the default behaviour.
ac_desc = n.desc(self.sdfg)
ac_size = ac_desc.total_size
in_subsets = dict()
for in_edge in in_edges:
# Ensure that if the destination subset is not given, our assumption, that the
# whole array is written to, is valid, by testing if the memlet transfers the
# whole array.
assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size)
in_subsets[in_edge] = (
sbs.Range.from_array(ac_desc)
if in_edge.data.dst_subset is None
else in_edge.data.dst_subset
)
out_subsets = dict()
for out_edge in out_edges:
assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size)
out_subsets[out_edge] = (
sbs.Range.from_array(ac_desc)
if out_edge.data.src_subset is None
else out_edge.data.src_subset
)

# Update the read and write sets of the subgraph.
if in_edges:
subgraph_write_set[n.data].extend(in_subsets.values())
if out_edges:
subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges)

# Add the subgraph's read and write set to the final ones.
for data, accesses in subgraph_read_set.items():
read_set[data] += accesses
for data, accesses in ws.items():
for data, accesses in subgraph_write_set.items():
write_set[data] += accesses
return read_set, write_set

return copy.deepcopy((read_set, write_set))


def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
"""
Expand Down
93 changes: 91 additions & 2 deletions tests/sdfg/state_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace import subsets as sbs
from dace.transformation.helpers import find_sdfg_control_flow


Expand All @@ -19,7 +20,9 @@ def test_read_write_set():
state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet('B[2]'))
state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet('C[2]'))

assert 'B' not in state.read_and_write_sets()[0]
read_set, write_set = state.read_and_write_sets()
assert {'B', 'A'} == read_set
assert {'C', 'B'} == write_set


def test_read_write_set_y_formation():
Expand All @@ -41,7 +44,10 @@ def test_read_write_set_y_formation():
state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet(data='B', subset='0'))
state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet(data='C', subset='0'))

assert 'B' not in state.read_and_write_sets()[0]
read_set, write_set = state.read_and_write_sets()
assert {'B', 'A'} == read_set
assert {'C', 'B'} == write_set


def test_deepcopy_state():
N = dace.symbol('N')
Expand All @@ -58,6 +64,87 @@ def double_loop(arr: dace.float32[N]):
sdfg.validate()


def test_read_and_write_set_filter():
sdfg = dace.SDFG('graph')
state = sdfg.add_state('state')
sdfg.add_array('A', [2, 2], dace.float64)
sdfg.add_scalar('B', dace.float64)
sdfg.add_array('C', [2, 2], dace.float64)
A, B, C = (state.add_access(name) for name in ('A', 'B', 'C'))

state.add_nedge(
A,
B,
dace.Memlet("B[0] -> [0, 0]"),
)
state.add_nedge(
B,
C,
dace.Memlet("C[1, 1] -> [0]"),
)
state.add_nedge(
B,
C,
dace.Memlet("B[0] -> [0, 0]"),
)
sdfg.validate()

expected_reads = {
"A": [sbs.Range.from_string("0, 0")],
"B": [sbs.Range.from_string("0")],
}
expected_writes = {
"B": [sbs.Range.from_string("0")],
"C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")],
}
read_set, write_set = state._read_and_write_sets()

for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]:
assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
for access_data in expected_sets.keys():
for exp in expected_sets[access_data]:
found_match = False
for res in computed_sets[access_data]:
if res == exp:
found_match = True
break
assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'"


def test_read_and_write_set_selection():
sdfg = dace.SDFG('graph')
state = sdfg.add_state('state')
sdfg.add_array('A', [2, 2], dace.float64)
sdfg.add_scalar('B', dace.float64)
A, B = (state.add_access(name) for name in ('A', 'B'))

state.add_nedge(
A,
B,
dace.Memlet("A[0, 0]"),
)
sdfg.validate()

expected_reads = {
"A": [sbs.Range.from_string("0, 0")],
}
expected_writes = {
"B": [sbs.Range.from_string("0")],
}
read_set, write_set = state._read_and_write_sets()

for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]:
assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
for access_data in expected_sets.keys():
for exp in expected_sets[access_data]:
found_match = False
for res in computed_sets[access_data]:
if res == exp:
found_match = True
break
assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'"


def test_add_mapped_tasklet():
sdfg = dace.SDFG("test_add_mapped_tasklet")
state = sdfg.add_state(is_start_block=True)
Expand All @@ -82,6 +169,8 @@ def test_add_mapped_tasklet():


if __name__ == '__main__':
test_read_and_write_set_selection()
test_read_and_write_set_filter()
test_read_write_set()
test_read_write_set_y_formation()
test_deepcopy_state()
Expand Down
64 changes: 60 additions & 4 deletions tests/transformations/move_loop_into_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dace
from dace.transformation.interstate import MoveLoopIntoMap
import unittest
import copy
import numpy as np

I = dace.symbol("I")
Expand Down Expand Up @@ -147,7 +148,12 @@ def test_apply_multiple_times_1(self):
self.assertTrue(np.allclose(val, ref))

def test_more_than_a_map(self):
""" `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """
"""
`out` is read and written indirectly by the MapExit, potentially leading to a RW dependency.

Note that there is actually no dependency, however, the transformation, because it relies
on `SDFGState.read_and_write_sets()` it can not detect this and can thus not be applied.
"""
sdfg = dace.SDFG('more_than_a_map')
_, aarr = sdfg.add_array('A', (3, 3), dace.float64)
_, barr = sdfg.add_array('B', (3, 3), dace.float64)
Expand All @@ -167,11 +173,12 @@ def test_more_than_a_map(self):
external_edges=True,
input_nodes=dict(out=oread, B=bread),
output_nodes=dict(tmp=twrite))
body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr))
body.add_nedge(aread, oread, dace.Memlet.from_array('A', oarr))
body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr))
sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1')
count = sdfg.apply_transformations(MoveLoopIntoMap)
self.assertFalse(count > 0)

count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True)
self.assertTrue(count == 0)

def test_more_than_a_map_1(self):
"""
Expand Down Expand Up @@ -269,6 +276,55 @@ def test_more_than_a_map_3(self):
count = sdfg.apply_transformations(MoveLoopIntoMap)
self.assertFalse(count > 0)

def test_more_than_a_map_4(self):
"""
The test is very similar to `test_more_than_a_map()`. But a memlet is different
which leads to a RW dependency, which blocks the transformation.
"""
sdfg = dace.SDFG('more_than_a_map')
_, aarr = sdfg.add_array('A', (3, 3), dace.float64)
_, barr = sdfg.add_array('B', (3, 3), dace.float64)
_, oarr = sdfg.add_array('out', (3, 3), dace.float64)
_, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True)
body = sdfg.add_state('map_state')
aread = body.add_access('A')
oread = body.add_access('out')
bread = body.add_access('B')
twrite = body.add_access('tmp')
owrite = body.add_access('out')
body.add_mapped_tasklet('op',
dict(i='0:3', j='0:3'),
dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')),
'__out = __in1 - __in2',
dict(__out=dace.Memlet('tmp[i, j]')),
external_edges=True,
input_nodes=dict(out=oread, B=bread),
output_nodes=dict(tmp=twrite))
body.add_nedge(aread, oread, dace.Memlet('A[Mod(_, 3), 0:3] -> [Mod(_ + 1, 3), 0:3]', aarr))
body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr))
sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1')

sdfg_args_ref = {
"A": np.array(np.random.rand(3, 3), dtype=np.float64),
"B": np.array(np.random.rand(3, 3), dtype=np.float64),
"out": np.array(np.random.rand(3, 3), dtype=np.float64),
}
sdfg_args_res = copy.deepcopy(sdfg_args_ref)

# Perform the reference execution
sdfg(**sdfg_args_ref)

# Apply the transformation and execute the SDFG again.
count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True)
sdfg(**sdfg_args_res)

for name in sdfg_args_ref.keys():
self.assertTrue(
np.allclose(sdfg_args_ref[name], sdfg_args_res[name]),
f"Miss match for {name}",
)
self.assertFalse(count > 0)


if __name__ == '__main__':
unittest.main()
22 changes: 2 additions & 20 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def _make_read_write_sdfg(

Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B`
will either be associated to `inner_A` (`True`) or `inner_B` (`False`).
This choice has consequences on if the transformation can apply or not.

Notes:
This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643),
Expand Down Expand Up @@ -332,16 +331,6 @@ def test_unused_retval_2():
assert np.allclose(a, 1)


def test_read_write_1():
# Because the memlet is conforming, we can apply the transformation.
sdfg = _make_read_write_sdfg(True)

assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False)





def test_prune_connectors_with_dependencies():
sdfg = dace.SDFG('tester')
A, A_desc = sdfg.add_array('A', [4], dace.float64)
Expand Down Expand Up @@ -420,18 +409,11 @@ def test_prune_connectors_with_dependencies():
assert np.allclose(np_d, np_d_)


def test_read_write_1():
# Because the memlet is conforming, we can apply the transformation.
def test_read_write():
sdfg, nsdfg = _make_read_write_sdfg(True)
assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)

assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)
sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True)


def test_read_write_2():
# Because the memlet is not conforming, we can not apply the transformation.
sdfg, nsdfg = _make_read_write_sdfg(False)

assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)


Expand Down
Loading
Loading