From 7e69165b5a434a3f94125b3b7db054d453b9863f Mon Sep 17 00:00:00 2001 From: kuku929 Date: Fri, 20 Sep 2024 19:25:58 +0530 Subject: [PATCH 1/2] added fill library node with test --- AUTHORS | 1 + dace/libraries/standard/nodes/fill.py | 55 ++++++++++++++++++++++++++ tests/library/fill_test.py | 57 +++++++++++++++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 dace/libraries/standard/nodes/fill.py create mode 100644 tests/library/fill_test.py diff --git a/AUTHORS b/AUTHORS index 48cb4c05ec..558f73eadb 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,5 +37,6 @@ Yihang Luo Alexandru Calotoiu Phillip Lane Samuel Martin +Krutarth Patel and other contributors listed in https://github.com/spcl/dace/graphs/contributors diff --git a/dace/libraries/standard/nodes/fill.py b/dace/libraries/standard/nodes/fill.py new file mode 100644 index 0000000000..44a821e705 --- /dev/null +++ b/dace/libraries/standard/nodes/fill.py @@ -0,0 +1,55 @@ +from dace import library, nodes, properties +from dace.transformation.transformation import ExpandTransformation +from numbers import Number +import dace.subsets + + +@library.expansion +class ExpandPure(ExpandTransformation): + """Implements pure expansion of the Fill library node.""" + + environments = [] + + @staticmethod + def expansion(node, parent_state, parent_sdfg): + output = None + for e in parent_state.out_edges(node): + if e.src_conn == "_output": + output = parent_sdfg.arrays[e.data.data] + sdfg = dace.SDFG(f"{node.label}_sdfg") + _, out_arr = sdfg.add_array( + "_output", + output.shape, + output.dtype, + output.storage, + strides=output.strides, + ) + + state = sdfg.add_state(f"{node.label}_state") + map_params = [f"__i{i}" for i in range(len(out_arr.shape))] + map_rng = {i: f"0:{s}" for i, s in zip(map_params, out_arr.shape)} + out_mem = dace.Memlet(expr=f"_output[{','.join(map_params)}]") + inputs = {} + outputs = {"_out": out_mem} + code = f"_out = {node.value}" + state.add_mapped_tasklet( + f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True + ) + + return sdfg + + +@library.node +class Fill(nodes.LibraryNode): + """Implements filling data containers with a single value""" + + implementations = {"pure": ExpandPure} + default_implementation = "pure" + value = properties.SymbolicProperty( + dtype=Number, default=0, desc="value to fill data container" + ) + + def __init__(self, name, value=0): + super().__init__(name, outputs={"_output"}) + self.value = value + self.name = name diff --git a/tests/library/fill_test.py b/tests/library/fill_test.py new file mode 100644 index 0000000000..7624a53b38 --- /dev/null +++ b/tests/library/fill_test.py @@ -0,0 +1,57 @@ +# # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np +from dace.memlet import Memlet +from dace.libraries.standard.nodes import std_nodes + + +def pure_graph(implementation, dtype, size): + sdfg_name = f"fill_{implementation}_{dtype.ctype}_w{size}" + sdfg = dace.SDFG(sdfg_name) + + state = sdfg.add_state("fill") + + value = dace.symbol("value") + sdfg.add_array("r", [size], dtype) + result = state.add_write("r") + + fill_node = std_nodes.Fill("fill") + fill_node.implementation = implementation + fill_node.value = value + + # how to initialize memlet here? + state.add_memlet_path(fill_node, result, src_conn="_output", memlet=Memlet()) + + return sdfg + + +def run_test(target, size, value): + if target == "pure": + sdfg = pure_graph("pure", dace.float32, size) + # expand the nested sdfg returned by fill node + sdfg.expand_library_nodes() + else: + print(f"Unsupported target: {target}") + exit(-1) + + # we get the function we can call + fill = sdfg.compile() + + # supposed to be filled + result = np.ndarray(size, dtype=np.float32) + + # the parameters are all the symbols defined in the sdfg + fill(value=value, r=result) + for val in result: + if val != value: + raise ValueError(f"expected {value}, found {val}") + return sdfg + + +def test_fill_pure(): + # should not return a value error + assert isinstance(run_test("pure", 64, 1), dace.SDFG) + + +if __name__ == "__main__": + test_fill_pure() From 8a14c5306252df7cb65e8d43c0a21403c3b014a1 Mon Sep 17 00:00:00 2001 From: kuku929 Date: Fri, 20 Sep 2024 20:00:01 +0530 Subject: [PATCH 2/2] fixed import errors --- dace/libraries/standard/nodes/fill.py | 3 +-- dace/viewer/webclient | 2 +- tests/library/fill_test.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dace/libraries/standard/nodes/fill.py b/dace/libraries/standard/nodes/fill.py index 44a821e705..e5dca7a9ef 100644 --- a/dace/libraries/standard/nodes/fill.py +++ b/dace/libraries/standard/nodes/fill.py @@ -1,8 +1,7 @@ +import dace from dace import library, nodes, properties from dace.transformation.transformation import ExpandTransformation from numbers import Number -import dace.subsets - @library.expansion class ExpandPure(ExpandTransformation): diff --git a/dace/viewer/webclient b/dace/viewer/webclient index c6b8fe4fd2..27174b1918 160000 --- a/dace/viewer/webclient +++ b/dace/viewer/webclient @@ -1 +1 @@ -Subproject commit c6b8fe4fd2c3616b0480ead4c24d8012b91a31fd +Subproject commit 27174b19180d6cf41e70a77a3a63bfef67ef6983 diff --git a/tests/library/fill_test.py b/tests/library/fill_test.py index 7624a53b38..42a5624a71 100644 --- a/tests/library/fill_test.py +++ b/tests/library/fill_test.py @@ -2,7 +2,7 @@ import dace import numpy as np from dace.memlet import Memlet -from dace.libraries.standard.nodes import std_nodes +from dace.libraries.standard.nodes import fill def pure_graph(implementation, dtype, size): @@ -15,7 +15,7 @@ def pure_graph(implementation, dtype, size): sdfg.add_array("r", [size], dtype) result = state.add_write("r") - fill_node = std_nodes.Fill("fill") + fill_node = fill.Fill("fill") fill_node.implementation = implementation fill_node.value = value