diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 2ae6109b31..09e7607d65 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -745,51 +745,82 @@ def update_if_not_none(dic, update): return defined_syms + def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, List[Subset]]]: """ Determines what data is read and written in this subgraph, returning dictionaries from data containers to all subsets that are read/written. """ + from dace.sdfg import utils # Avoid cyclic import + + # Ensures that the `{src,dst}_subset` are properly set. + # TODO: find where the problems are + for edge in self.edges(): + edge.data.try_initialize(self.sdfg, self, edge) + read_set = collections.defaultdict(list) write_set = collections.defaultdict(list) - from dace.sdfg import utils # Avoid cyclic import - subgraphs = utils.concurrent_subgraphs(self) - for sg in subgraphs: - rs = collections.defaultdict(list) - ws = collections.defaultdict(list) - # Traverse in topological order, so data that is written before it - # is read is not counted in the read set - for n in utils.dfs_topological_sort(sg, sources=sg.source_nodes()): - if isinstance(n, nd.AccessNode): - in_edges = sg.in_edges(n) - out_edges = sg.out_edges(n) - # Filter out memlets which go out but the same data is written to the AccessNode by another memlet - for out_edge in list(out_edges): - for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data - and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): - out_edges.remove(out_edge) - break - - for e in in_edges: - # skip empty memlets - if e.data.is_empty(): - continue - # Store all subsets that have been written - ws[n.data].append(e.data.subset) - for e in out_edges: - # skip empty memlets - if e.data.is_empty(): - continue - rs[n.data].append(e.data.subset) - # Union all subgraphs, so an array that was excluded from the read - # set because it was written first is still included if it is read - # in another subgraph - for data, accesses in rs.items(): + + # NOTE: In a previous version a _single_ read (i.e. leaving Memlet) that was + # fully covered by a single write (i.e. an incoming Memlet) was removed from + # the read set and only the write survived. However, this was never fully + # implemented nor correctly implemented and caused problems. + # So this filtering was removed. + + for subgraph in utils.concurrent_subgraphs(self): + subgraph_read_set = collections.defaultdict(list) # read and write set of this subgraph. + subgraph_write_set = collections.defaultdict(list) + for n in utils.dfs_topological_sort(subgraph, sources=subgraph.source_nodes()): + if not isinstance(n, nd.AccessNode): + # Read and writes can only be done through access nodes, + # so ignore every other node. + continue + + # Get a list of all incoming (writes) and outgoing (reads) edges of the + # access node, ignore all empty memlets as they do not carry any data. + in_edges = [in_edge for in_edge in subgraph.in_edges(n) if not in_edge.data.is_empty()] + out_edges = [out_edge for out_edge in subgraph.out_edges(n) if not out_edge.data.is_empty()] + + # Extract the subsets that describes where we read and write the data + # and store them for the later filtering. + # NOTE: In certain cases the corresponding subset might be None, in this case + # we assume that the whole array is written, which is the default behaviour. + ac_desc = n.desc(self.sdfg) + ac_size = ac_desc.total_size + in_subsets = dict() + for in_edge in in_edges: + # Ensure that if the destination subset is not given, our assumption, that the + # whole array is written to, is valid, by testing if the memlet transfers the + # whole array. + assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size) + in_subsets[in_edge] = ( + sbs.Range.from_array(ac_desc) + if in_edge.data.dst_subset is None + else in_edge.data.dst_subset + ) + out_subsets = dict() + for out_edge in out_edges: + assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size) + out_subsets[out_edge] = ( + sbs.Range.from_array(ac_desc) + if out_edge.data.src_subset is None + else out_edge.data.src_subset + ) + + # Update the read and write sets of the subgraph. + if in_edges: + subgraph_write_set[n.data].extend(in_subsets.values()) + if out_edges: + subgraph_read_set[n.data].extend(out_subsets[out_edge] for out_edge in out_edges) + + # Add the subgraph's read and write set to the final ones. + for data, accesses in subgraph_read_set.items(): read_set[data] += accesses - for data, accesses in ws.items(): + for data, accesses in subgraph_write_set.items(): write_set[data] += accesses - return read_set, write_set + + return copy.deepcopy((read_set, write_set)) + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index 7ba43ac4c0..4bde3788e0 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace import subsets as sbs from dace.transformation.helpers import find_sdfg_control_flow @@ -19,7 +20,9 @@ def test_read_write_set(): state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet('B[2]')) state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet('C[2]')) - assert 'B' not in state.read_and_write_sets()[0] + read_set, write_set = state.read_and_write_sets() + assert {'B', 'A'} == read_set + assert {'C', 'B'} == write_set def test_read_write_set_y_formation(): @@ -41,7 +44,10 @@ def test_read_write_set_y_formation(): state.add_memlet_path(rw_b, task2, dst_conn='B', memlet=dace.Memlet(data='B', subset='0')) state.add_memlet_path(task2, write_c, src_conn='C', memlet=dace.Memlet(data='C', subset='0')) - assert 'B' not in state.read_and_write_sets()[0] + read_set, write_set = state.read_and_write_sets() + assert {'B', 'A'} == read_set + assert {'C', 'B'} == write_set + def test_deepcopy_state(): N = dace.symbol('N') @@ -58,6 +64,87 @@ def double_loop(arr: dace.float32[N]): sdfg.validate() +def test_read_and_write_set_filter(): + sdfg = dace.SDFG('graph') + state = sdfg.add_state('state') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_scalar('B', dace.float64) + sdfg.add_array('C', [2, 2], dace.float64) + A, B, C = (state.add_access(name) for name in ('A', 'B', 'C')) + + state.add_nedge( + A, + B, + dace.Memlet("B[0] -> [0, 0]"), + ) + state.add_nedge( + B, + C, + dace.Memlet("C[1, 1] -> [0]"), + ) + state.add_nedge( + B, + C, + dace.Memlet("B[0] -> [0, 0]"), + ) + sdfg.validate() + + expected_reads = { + "A": [sbs.Range.from_string("0, 0")], + "B": [sbs.Range.from_string("0")], + } + expected_writes = { + "B": [sbs.Range.from_string("0")], + "C": [sbs.Range.from_string("0, 0"), sbs.Range.from_string("1, 1")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + +def test_read_and_write_set_selection(): + sdfg = dace.SDFG('graph') + state = sdfg.add_state('state') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_scalar('B', dace.float64) + A, B = (state.add_access(name) for name in ('A', 'B')) + + state.add_nedge( + A, + B, + dace.Memlet("A[0, 0]"), + ) + sdfg.validate() + + expected_reads = { + "A": [sbs.Range.from_string("0, 0")], + } + expected_writes = { + "B": [sbs.Range.from_string("0")], + } + read_set, write_set = state._read_and_write_sets() + + for expected_sets, computed_sets in [(expected_reads, read_set), (expected_writes, write_set)]: + assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'." + for access_data in expected_sets.keys(): + for exp in expected_sets[access_data]: + found_match = False + for res in computed_sets[access_data]: + if res == exp: + found_match = True + break + assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'" + + def test_add_mapped_tasklet(): sdfg = dace.SDFG("test_add_mapped_tasklet") state = sdfg.add_state(is_start_block=True) @@ -82,6 +169,8 @@ def test_add_mapped_tasklet(): if __name__ == '__main__': + test_read_and_write_set_selection() + test_read_and_write_set_filter() test_read_write_set() test_read_write_set_y_formation() test_deepcopy_state() diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index dca775bb7a..ad51941cb0 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -2,6 +2,7 @@ import dace from dace.transformation.interstate import MoveLoopIntoMap import unittest +import copy import numpy as np I = dace.symbol("I") @@ -147,7 +148,12 @@ def test_apply_multiple_times_1(self): self.assertTrue(np.allclose(val, ref)) def test_more_than_a_map(self): - """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. """ + """ + `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. + + Note that there is actually no dependency, however, the transformation, because it relies + on `SDFGState.read_and_write_sets()` it can not detect this and can thus not be applied. + """ sdfg = dace.SDFG('more_than_a_map') _, aarr = sdfg.add_array('A', (3, 3), dace.float64) _, barr = sdfg.add_array('B', (3, 3), dace.float64) @@ -167,11 +173,12 @@ def test_more_than_a_map(self): external_edges=True, input_nodes=dict(out=oread, B=bread), output_nodes=dict(tmp=twrite)) - body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) + body.add_nedge(aread, oread, dace.Memlet.from_array('A', oarr)) body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) + + count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) + self.assertTrue(count == 0) def test_more_than_a_map_1(self): """ @@ -269,6 +276,55 @@ def test_more_than_a_map_3(self): count = sdfg.apply_transformations(MoveLoopIntoMap) self.assertFalse(count > 0) + def test_more_than_a_map_4(self): + """ + The test is very similar to `test_more_than_a_map()`. But a memlet is different + which leads to a RW dependency, which blocks the transformation. + """ + sdfg = dace.SDFG('more_than_a_map') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + body = sdfg.add_state('map_state') + aread = body.add_access('A') + oread = body.add_access('out') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(out=oread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(aread, oread, dace.Memlet('A[Mod(_, 3), 0:3] -> [Mod(_ + 1, 3), 0:3]', aarr)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') + + sdfg_args_ref = { + "A": np.array(np.random.rand(3, 3), dtype=np.float64), + "B": np.array(np.random.rand(3, 3), dtype=np.float64), + "out": np.array(np.random.rand(3, 3), dtype=np.float64), + } + sdfg_args_res = copy.deepcopy(sdfg_args_ref) + + # Perform the reference execution + sdfg(**sdfg_args_ref) + + # Apply the transformation and execute the SDFG again. + count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) + sdfg(**sdfg_args_res) + + for name in sdfg_args_ref.keys(): + self.assertTrue( + np.allclose(sdfg_args_ref[name], sdfg_args_res[name]), + f"Miss match for {name}", + ) + self.assertFalse(count > 0) + if __name__ == '__main__': unittest.main() diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 63bbe5843f..b7b287d77e 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -153,7 +153,6 @@ def _make_read_write_sdfg( Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B` will either be associated to `inner_A` (`True`) or `inner_B` (`False`). - This choice has consequences on if the transformation can apply or not. Notes: This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643), @@ -332,16 +331,6 @@ def test_unused_retval_2(): assert np.allclose(a, 1) -def test_read_write_1(): - # Because the memlet is conforming, we can apply the transformation. - sdfg = _make_read_write_sdfg(True) - - assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False) - - - - - def test_prune_connectors_with_dependencies(): sdfg = dace.SDFG('tester') A, A_desc = sdfg.add_array('A', [4], dace.float64) @@ -420,18 +409,11 @@ def test_prune_connectors_with_dependencies(): assert np.allclose(np_d, np_d_) -def test_read_write_1(): - # Because the memlet is conforming, we can apply the transformation. +def test_read_write(): sdfg, nsdfg = _make_read_write_sdfg(True) + assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) - assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) - sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True) - - -def test_read_write_2(): - # Because the memlet is not conforming, we can not apply the transformation. sdfg, nsdfg = _make_read_write_sdfg(False) - assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) diff --git a/tests/transformations/refine_nested_access_test.py b/tests/transformations/refine_nested_access_test.py index d9fb9a7392..81640665ed 100644 --- a/tests/transformations/refine_nested_access_test.py +++ b/tests/transformations/refine_nested_access_test.py @@ -156,7 +156,115 @@ def inner_sdfg(A: dace.int32[5], B: dace.int32[5, 5], idx_a: int, idx_b: int): assert np.allclose(ref, val) +def _make_rna_read_and_write_set_sdfg(diff_in_out: bool) -> dace.SDFG: + """Generates the SDFG for the `test_rna_read_and_write_sets_*()` tests. + + If `diff_in_out` is `False` then the output is also used as temporary storage + within the nested SDFG. Because of the definition of the read/write sets, + this usage of the temporary storage is not picked up and it is only considered + as write set. + + If `diff_in_out` is true, then a different storage container, which is classified + as output, is used as temporary storage. + + This test was added during [PR#1678](https://github.com/spcl/dace/pull/1678). + """ + + def _make_nested_sdfg(diff_in_out: bool) -> dace.SDFG: + sdfg = dace.SDFG("inner_sdfg") + state = sdfg.add_state(is_start_block=True) + sdfg.add_array("A", dtype=dace.float64, shape=(2,), transient=False) + sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False) + + A = state.add_access("A") + T1_output = state.add_access("T1") + if diff_in_out: + sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False) + T1_input = state.add_access("T2") + else: + T1_input = state.add_access("T1") + + tsklt = state.add_tasklet( + "comp", + inputs={"__in1": None, "__in2": None}, + outputs={"__out": None}, + code="__out = __in1 + __in2", + ) + + state.add_edge(A, None, tsklt, "__in1", dace.Memlet("A[1]")) + # An alternative would be to write to a different location here. + # Then, the data would be added to the access node. + state.add_edge(A, None, T1_input, None, dace.Memlet("A[0] -> [0]")) + state.add_edge(T1_input, None, tsklt, "__in2", dace.Memlet(T1_input.data + "[0]")) + state.add_edge(tsklt, "__out", T1_output, None, dace.Memlet(T1_output.data + "[1]")) + return sdfg + + sdfg = dace.SDFG("Parent_SDFG") + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array("A", dtype=dace.float64, shape=(2,), transient=False) + sdfg.add_array("T1", dtype=dace.float64, shape=(2,), transient=False) + sdfg.add_array("T2", dtype=dace.float64, shape=(2,), transient=False) + A = state.add_access("A") + T1 = state.add_access("T1") + + nested_sdfg = _make_nested_sdfg(diff_in_out) + + nsdfg = state.add_nested_sdfg( + nested_sdfg, + parent=sdfg, + inputs={"A"}, + outputs={"T2", "T1"} if diff_in_out else {"T1"}, + symbol_mapping={}, + ) + + state.add_edge(A, None, nsdfg, "A", dace.Memlet("A[0:2]")) + state.add_edge(nsdfg, "T1", T1, None, dace.Memlet("T1[0:2]")) + + if diff_in_out: + state.add_edge(nsdfg, "T2", state.add_access("T2"), None, dace.Memlet("T2[0:2]")) + sdfg.validate() + return sdfg + + +def test_rna_read_and_write_sets_doule_use(): + # The transformation does not apply because we access element `0` of both arrays that we + # pass inside the nested SDFG. + sdfg = _make_rna_read_and_write_set_sdfg(False) + nb_applied = sdfg.apply_transformations_repeated( + [RefineNestedAccess], + validate=True, + validate_all=True, + ) + assert nb_applied == 0 + + +def test_rna_read_and_write_sets_different_storage(): + + # There is a dedicated temporary storage used. + sdfg = _make_rna_read_and_write_set_sdfg(True) + + nb_applied = sdfg.apply_transformations_repeated( + [RefineNestedAccess], + validate=True, + validate_all=True, + ) + assert nb_applied > 0 + + args = { + "A": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T2": np.array(np.random.rand(2), dtype=np.float64, copy=True), + "T1": np.zeros(2, dtype=np.float64), + } + ref = args["A"][0] + args["A"][1] + sdfg(**args) + res = args["T1"][1] + assert np.allclose(res, ref), f"Expected '{ref}' but got '{res}'." + + if __name__ == '__main__': test_refine_dataflow() test_refine_interstate() test_free_symbols_only_by_indices() + test_rna_read_and_write_sets_different_storage() + test_rna_read_and_write_sets_doule_use()