Skip to content

Commit

Permalink
Fix to Read and Write Sets (#1678)
Browse files Browse the repository at this point in the history
During my work on the [new map
fusion](#1643) I discovered a bug in
`SDFGState._read_and_write_set()`.
Originally I solved it there, but it was decided to move it into its own
PR.


Lets look at the first, super silly example, that is not useful on its
own.
The main point here, is that the `data` attribute of the Memlet does not
refer to the source of the connection but of the destination.


![test_1](https://github.com/user-attachments/assets/740ee4fc-cfe5-4844-a999-e316cb8f9c16)


BTW: The Webviewer outputs something like `B[0] -> [0, 0]` however, the
parser of the Memlet constructor does not understand this, it must be
written as `B[0] -> 0, 0`, i.e. the second set of brackets must be
omitted, this should be changed!

From the above we would expect the following sets:
- Reads:
	- `A`: `[Range (0, 0)]`
- `B`: Should not be listed in this set, because it is fully read and
written, thus it is excluded.
- Writes
	- `B`: `[Range (0)]`
	- `C`: `[Range (0, 0), Range (1, 1)]`

However, the current implementation gives us:
- Reads: `{'A': [Range (0)], 'B': [Range (1, 1)]}`
- Write: `{'B': [Range (0)], 'C': [Range (1, 1), Range (0)]}`

The current behaviour is wrong because:
- `A` is a `2x2` array, thus the read set should also have two
dimensions.
- `B` inside the read set, it is a scalar, but the range has two
dimensions, furthermore, it is present at all.
- `C` the first member of the write set (`Range(1, 1)`) is correct,
while the second (`Range(0)`) is horrible wrong.


The second example is even more simple.


![test_2](https://github.com/user-attachments/assets/da3d03af-6f10-411f-952e-ab057ed057c6)


From the SDFG we expect the following sets:
- Reads:
	- `A`: `[Range(0, 0)]`
- Writes:
	- `B`: `[Range(0)]`

It is important that in the above example `other_subset` is `None` and
`data` is set to `A`, so it is not one of these "crazy" non standard
Memlets we have seen in the first test.
However, the current implementation gives us:
- Reads: `{'A': [Range (0, 0)]}`
- Writes: `{'B': [Range (0, 0)]}`

This clearly shows, that whatever the implementation does is not
correct.
  • Loading branch information
philip-paul-mueller authored Oct 23, 2024
1 parent 975a065 commit 380554f
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 62 deletions.
103 changes: 67 additions & 36 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,51 +745,82 @@ def update_if_not_none(dic, update):

return defined_syms


def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, List[Subset]]]:
"""
Determines what data is read and written in this subgraph, returning
dictionaries from data containers to all subsets that are read/written.
"""
from dace.sdfg import utils # Avoid cyclic import

# Ensures that the `{src,dst}_subset` are properly set.
# TODO: find where the problems are
for edge in self.edges():
edge.data.try_initialize(self.sdfg, self, edge)

read_set = collections.defaultdict(list)
write_set = collections.defaultdict(list)
from dace.sdfg import utils # Avoid cyclic import
subgraphs = utils.concurrent_subgraphs(self)
for sg in subgraphs:
rs = collections.defaultdict(list)
ws = collections.defaultdict(list)
# Traverse in topological order, so data that is written before it
# is read is not counted in the read set
for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()):
if isinstance(n, nd.AccessNode):
in_edges = sg.in_edges(n)
out_edges = sg.out_edges(n)
# Filter out memlets which go out but the same data is written to the AccessNode by another memlet
for out_edge in list(out_edges):
for in_edge in list(in_edges):
if (in_edge.data.data == out_edge.data.data
and in_edge.data.dst_subset.covers(out_edge.data.src_subset)):
out_edges.remove(out_edge)
break

for e in in_edges:
# skip empty memlets
if e.data.is_empty():
continue
# Store all subsets that have been written
ws[n.data].append(e.data.subset)
for e in out_edges:
# skip empty memlets
if e.data.is_empty():
continue
rs[n.data].append(e.data.subset)
# Union all subgraphs, so an array that was excluded from the read
# set because it was written first is still included if it is read
# in another subgraph
for data, accesses in rs.items():

# NOTE: In a previous version a _single_ read (i.e. leaving Memlet) that was
# fully covered by a single write (i.e. an incoming Memlet) was removed from
# the read set and only the write survived. However, this was never fully
# implemented nor correctly implemented and caused problems.
# So this filtering was removed.

for subgraph in utils.concurrent_subgraphs(self):
subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph.
subgraph_write_set = collections.defaultdict(list)
for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()):
if not isinstance(n, nd.AccessNode):
# Read and writes can only be done through access nodes,
# so ignore every other node.
continue

# Get a list of all incoming (writes) and outgoing (reads) edges of the
# access node, ignore all empty memlets as they do not carry any data.
in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()]
out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()]

# Extract the subsets that describes where we read and write the data
# and store them for the later filtering.
# NOTE: In certain cases the corresponding subset might be None, in this case
# we assume that the whole array is written, which is the default behaviour.
ac_desc = n.desc(self.sdfg)
ac_size = ac_desc.total_size
in_subsets = dict()
for in_edge in in_edges:
# Ensure that if the destination subset is not given, our assumption, that the
# whole array is written to, is valid, by testing if the memlet transfers the
# whole array.
assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size)
in_subsets[in_edge] = (
sbs.Range.from_array(ac_desc)
if in_edge.data.dst_subset is None
else in_edge.data.dst_subset
)
out_subsets = dict()
for out_edge in out_edges:
assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size)
out_subsets[out_edge] = (
sbs.Range.from_array(ac_desc)
if out_edge.data.src_subset is None
else out_edge.data.src_subset
)

# Update the read and write sets of the subgraph.
if in_edges:
subgraph_write_set[n.data].extend(in_subsets.values())
if out_edges:
subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges)

# Add the subgraph's read and write set to the final ones.
for data, accesses in subgraph_read_set.items():
read_set[data] += accesses
for data, accesses in ws.items():
for data, accesses in subgraph_write_set.items():
write_set[data] += accesses
return read_set, write_set

return copy.deepcopy((read_set, write_set))


def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
"""
Expand Down
93 changes: 91 additions & 2 deletions tests/sdfg/state_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace import subsets as sbs
from dace.transformation.helpers import find_sdfg_control_flow


Expand All @@ -19,7 +20,9 @@ def test_read_write_set():
state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet('B[2]'))
state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet('C[2]'))

assert 'B' not in state.read_and_write_sets()[0]
read_set, write_set = state.read_and_write_sets()
assert {'B', 'A'} == read_set
assert {'C', 'B'} == write_set


def test_read_write_set_y_formation():
Expand All @@ -41,7 +44,10 @@ def test_read_write_set_y_formation():
state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet(data='B', subset='0'))
state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet(data='C', subset='0'))

assert 'B' not in state.read_and_write_sets()[0]
read_set, write_set = state.read_and_write_sets()
assert {'B', 'A'} == read_set
assert {'C', 'B'} == write_set


def test_deepcopy_state():
N = dace.symbol('N')
Expand All @@ -58,6 +64,87 @@ def double_loop(arr: dace.float32[N]):
sdfg.validate()


def test_read_and_write_set_filter():
sdfg = dace.SDFG('graph')
state = sdfg.add_state('state')
sdfg.add_array('A', [2, 2], dace.float64)
sdfg.add_scalar('B', dace.float64)
sdfg.add_array('C', [2, 2], dace.float64)
A, B, C = (state.add_access(name) for name in ('A', 'B', 'C'))

state.add_nedge(
A,
B,
dace.Memlet("B[0] -> [0, 0]"),
)
state.add_nedge(
B,
C,
dace.Memlet("C[1, 1] -> [0]"),
)
state.add_nedge(
B,
C,
dace.Memlet("B[0] -> [0, 0]"),
)
sdfg.validate()

expected_reads = {
"A": [sbs.Range.from_string("0, 0")],
"B": [sbs.Range.from_string("0")],
}
expected_writes = {
"B": [sbs.Range.from_string("0")],
"C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")],
}
read_set, write_set = state._read_and_write_sets()

for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]:
assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
for access_data in expected_sets.keys():
for exp in expected_sets[access_data]:
found_match = False
for res in computed_sets[access_data]:
if res == exp:
found_match = True
break
assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'"


def test_read_and_write_set_selection():
sdfg = dace.SDFG('graph')
state = sdfg.add_state('state')
sdfg.add_array('A', [2, 2], dace.float64)
sdfg.add_scalar('B', dace.float64)
A, B = (state.add_access(name) for name in ('A', 'B'))

state.add_nedge(
A,
B,
dace.Memlet("A[0, 0]"),
)
sdfg.validate()

expected_reads = {
"A": [sbs.Range.from_string("0, 0")],
}
expected_writes = {
"B": [sbs.Range.from_string("0")],
}
read_set, write_set = state._read_and_write_sets()

for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]:
assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
for access_data in expected_sets.keys():
for exp in expected_sets[access_data]:
found_match = False
for res in computed_sets[access_data]:
if res == exp:
found_match = True
break
assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'"


def test_add_mapped_tasklet():
sdfg = dace.SDFG("test_add_mapped_tasklet")
state = sdfg.add_state(is_start_block=True)
Expand All @@ -82,6 +169,8 @@ def test_add_mapped_tasklet():


if __name__ == '__main__':
test_read_and_write_set_selection()
test_read_and_write_set_filter()
test_read_write_set()
test_read_write_set_y_formation()
test_deepcopy_state()
Expand Down
64 changes: 60 additions & 4 deletions tests/transformations/move_loop_into_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dace
from dace.transformation.interstate import MoveLoopIntoMap
import unittest
import copy
import numpy as np

I = dace.symbol("I")
Expand Down Expand Up @@ -147,7 +148,12 @@ def test_apply_multiple_times_1(self):
self.assertTrue(np.allclose(val, ref))

def test_more_than_a_map(self):
""" `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """
"""
`out` is read and written indirectly by the MapExit, potentially leading to a RW dependency.
Note that there is actually no dependency, however, the transformation, because it relies
on `SDFGState.read_and_write_sets()` it can not detect this and can thus not be applied.
"""
sdfg = dace.SDFG('more_than_a_map')
_, aarr = sdfg.add_array('A', (3, 3), dace.float64)
_, barr = sdfg.add_array('B', (3, 3), dace.float64)
Expand All @@ -167,11 +173,12 @@ def test_more_than_a_map(self):
external_edges=True,
input_nodes=dict(out=oread, B=bread),
output_nodes=dict(tmp=twrite))
body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr))
body.add_nedge(aread, oread, dace.Memlet.from_array('A', oarr))
body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr))
sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1')
count = sdfg.apply_transformations(MoveLoopIntoMap)
self.assertFalse(count > 0)

count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True)
self.assertTrue(count == 0)

def test_more_than_a_map_1(self):
"""
Expand Down Expand Up @@ -269,6 +276,55 @@ def test_more_than_a_map_3(self):
count = sdfg.apply_transformations(MoveLoopIntoMap)
self.assertFalse(count > 0)

def test_more_than_a_map_4(self):
"""
The test is very similar to `test_more_than_a_map()`. But a memlet is different
which leads to a RW dependency, which blocks the transformation.
"""
sdfg = dace.SDFG('more_than_a_map')
_, aarr = sdfg.add_array('A', (3, 3), dace.float64)
_, barr = sdfg.add_array('B', (3, 3), dace.float64)
_, oarr = sdfg.add_array('out', (3, 3), dace.float64)
_, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True)
body = sdfg.add_state('map_state')
aread = body.add_access('A')
oread = body.add_access('out')
bread = body.add_access('B')
twrite = body.add_access('tmp')
owrite = body.add_access('out')
body.add_mapped_tasklet('op',
dict(i='0:3', j='0:3'),
dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')),
'__out = __in1 - __in2',
dict(__out=dace.Memlet('tmp[i, j]')),
external_edges=True,
input_nodes=dict(out=oread, B=bread),
output_nodes=dict(tmp=twrite))
body.add_nedge(aread, oread, dace.Memlet('A[Mod(_, 3), 0:3] -> [Mod(_ + 1, 3), 0:3]', aarr))
body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr))
sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1')

sdfg_args_ref = {
"A": np.array(np.random.rand(3, 3), dtype=np.float64),
"B": np.array(np.random.rand(3, 3), dtype=np.float64),
"out": np.array(np.random.rand(3, 3), dtype=np.float64),
}
sdfg_args_res = copy.deepcopy(sdfg_args_ref)

# Perform the reference execution
sdfg(**sdfg_args_ref)

# Apply the transformation and execute the SDFG again.
count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True)
sdfg(**sdfg_args_res)

for name in sdfg_args_ref.keys():
self.assertTrue(
np.allclose(sdfg_args_ref[name], sdfg_args_res[name]),
f"Miss match for {name}",
)
self.assertFalse(count > 0)


if __name__ == '__main__':
unittest.main()
22 changes: 2 additions & 20 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def _make_read_write_sdfg(
Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B`
will either be associated to `inner_A` (`True`) or `inner_B` (`False`).
This choice has consequences on if the transformation can apply or not.
Notes:
This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643),
Expand Down Expand Up @@ -332,16 +331,6 @@ def test_unused_retval_2():
assert np.allclose(a, 1)


def test_read_write_1():
# Because the memlet is conforming, we can apply the transformation.
sdfg = _make_read_write_sdfg(True)

assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False)





def test_prune_connectors_with_dependencies():
sdfg = dace.SDFG('tester')
A, A_desc = sdfg.add_array('A', [4], dace.float64)
Expand Down Expand Up @@ -420,18 +409,11 @@ def test_prune_connectors_with_dependencies():
assert np.allclose(np_d, np_d_)


def test_read_write_1():
# Because the memlet is conforming, we can apply the transformation.
def test_read_write():
sdfg, nsdfg = _make_read_write_sdfg(True)
assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)

assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)
sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True)


def test_read_write_2():
# Because the memlet is not conforming, we can not apply the transformation.
sdfg, nsdfg = _make_read_write_sdfg(False)

assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)


Expand Down
Loading

0 comments on commit 380554f

Please sign in to comment.