diff --git a/AUTHORS b/AUTHORS index 48cb4c05ec..558f73eadb 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,5 +37,6 @@ Yihang Luo Alexandru Calotoiu Phillip Lane Samuel Martin +Krutarth Patel and other contributors listed in https://github.com/spcl/dace/graphs/contributors diff --git a/dace/libraries/standard/nodes/fill.py b/dace/libraries/standard/nodes/fill.py new file mode 100644 index 0000000000..e5dca7a9ef --- /dev/null +++ b/dace/libraries/standard/nodes/fill.py @@ -0,0 +1,54 @@ +import dace +from dace import library, nodes, properties +from dace.transformation.transformation import ExpandTransformation +from numbers import Number + +@library.expansion +class ExpandPure(ExpandTransformation): + """Implements pure expansion of the Fill library node.""" + + environments = [] + + @staticmethod + def expansion(node, parent_state, parent_sdfg): + output = None + for e in parent_state.out_edges(node): + if e.src_conn == "_output": + output = parent_sdfg.arrays[e.data.data] + sdfg = dace.SDFG(f"{node.label}_sdfg") + _, out_arr = sdfg.add_array( + "_output", + output.shape, + output.dtype, + output.storage, + strides=output.strides, + ) + + state = sdfg.add_state(f"{node.label}_state") + map_params = [f"__i{i}" for i in range(len(out_arr.shape))] + map_rng = {i: f"0:{s}" for i, s in zip(map_params, out_arr.shape)} + out_mem = dace.Memlet(expr=f"_output[{','.join(map_params)}]") + inputs = {} + outputs = {"_out": out_mem} + code = f"_out = {node.value}" + state.add_mapped_tasklet( + f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True + ) + + return sdfg + + +@library.node +class Fill(nodes.LibraryNode): + """Implements filling data containers with a single value""" + + implementations = {"pure": ExpandPure} + default_implementation = "pure" + value = properties.SymbolicProperty( + dtype=Number, default=0, desc="value to fill data container" + ) + + def __init__(self, name, value=0): + super().__init__(name, outputs={"_output"}) + self.value = value + self.name = name diff --git a/dace/viewer/webclient b/dace/viewer/webclient index c6b8fe4fd2..27174b1918 160000 --- a/dace/viewer/webclient +++ b/dace/viewer/webclient @@ -1 +1 @@ -Subproject commit c6b8fe4fd2c3616b0480ead4c24d8012b91a31fd +Subproject commit 27174b19180d6cf41e70a77a3a63bfef67ef6983 diff --git a/tests/library/fill_test.py b/tests/library/fill_test.py new file mode 100644 index 0000000000..42a5624a71 --- /dev/null +++ b/tests/library/fill_test.py @@ -0,0 +1,57 @@ +# # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np +from dace.memlet import Memlet +from dace.libraries.standard.nodes import fill + + +def pure_graph(implementation, dtype, size): + sdfg_name = f"fill_{implementation}_{dtype.ctype}_w{size}" + sdfg = dace.SDFG(sdfg_name) + + state = sdfg.add_state("fill") + + value = dace.symbol("value") + sdfg.add_array("r", [size], dtype) + result = state.add_write("r") + + fill_node = fill.Fill("fill") + fill_node.implementation = implementation + fill_node.value = value + + # how to initialize memlet here? + state.add_memlet_path(fill_node, result, src_conn="_output", memlet=Memlet()) + + return sdfg + + +def run_test(target, size, value): + if target == "pure": + sdfg = pure_graph("pure", dace.float32, size) + # expand the nested sdfg returned by fill node + sdfg.expand_library_nodes() + else: + print(f"Unsupported target: {target}") + exit(-1) + + # we get the function we can call + fill = sdfg.compile() + + # supposed to be filled + result = np.ndarray(size, dtype=np.float32) + + # the parameters are all the symbols defined in the sdfg + fill(value=value, r=result) + for val in result: + if val != value: + raise ValueError(f"expected {value}, found {val}") + return sdfg + + +def test_fill_pure(): + # should not return a value error + assert isinstance(run_test("pure", 64, 1), dace.SDFG) + + +if __name__ == "__main__": + test_fill_pure()