Check host_maps and host_data in the GPU transformations #1701
base: main
Changes from all commits: 789571d, a1906c6, 223db14, d832980, dde556b, 2f6c2b3, 90d8ab2, b68538f, 63fa672
File: dace/transformation/interstate/gpu_transform_sdfg.py
```diff
@@ -5,7 +5,7 @@
 from dace.sdfg import nodes, scope
 from dace.sdfg import utils as sdutil
 from dace.transformation import transformation, helpers as xfh
-from dace.properties import Property, make_properties
+from dace.properties import ListProperty, Property, make_properties
 from collections import defaultdict
 from copy import deepcopy as dc
 from sympy import floor
```
```diff
@@ -128,6 +128,12 @@ class GPUTransformSDFG(transformation.MultiStateTransformation):
                            dtype=str,
                            default='')

+    host_maps = ListProperty(desc='List of map GUIDs, the passed maps are not offloaded to the GPU',
+                             element_type=str, default=None, allow_none=True)
+
+    host_data = ListProperty(desc='List of data names, the passed data are not offloaded to the GPU',
+                             element_type=str, default=None, allow_none=True)
+
     @staticmethod
     def annotates_memlets():
         # Skip memlet propagation for now
```
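For context, both options are exposed through `SDFG.apply_gpu_transformations`, which is how the tests at the bottom of this diff exercise them. A minimal usage sketch, assuming this patch is applied; the `add_one` program is illustrative and not part of the PR:

```python
import dace

@dace.program
def add_one(A: dace.float32[500]):
    for i in dace.map[0:500]:
        A[i] = A[i] + 1

sdfg = add_one.to_sdfg()

# Collect the GUIDs of the maps that should stay on the host;
# the GUID-based lookup mirrors test_host_map below.
host_maps = [n.guid for s in sdfg.states() for n in s.nodes()
             if isinstance(n, dace.nodes.EntryNode)]

# Keep those maps and the array 'A' on the host; everything else
# is offloaded as usual.
sdfg.apply_gpu_transformations(host_maps=host_maps, host_data=['A'])
```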
```diff
@@ -154,19 +160,38 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
                 return False
         return True

-    def apply(self, _, sdfg: sd.SDFG):
+    def _output_or_input_is_marked_host(self, state, entry_node):
+        if (self.host_data is None or self.host_data == []) and (self.host_maps is None or self.host_maps == []):
+            return False
+        marked_accesses = [e.data.data for e in state.in_edges(entry_node) + state.out_edges(state.exit_node(entry_node))
+                           if e.data.data is not None and e.data.data in self.host_data]
+        return len(marked_accesses) > 0
+
+    def apply(self, _, sdfg: sd.SDFG):
         #######################################################
         # Step 0: SDFG metadata

         # Find all input and output data descriptors
         input_nodes = []
         output_nodes = []
         global_code_nodes: Dict[sd.SDFGState, nodes.Tasklet] = defaultdict(list)
+        if self.host_maps is None:
+            self.host_maps = []
+        if self.host_data is None:
+            self.host_data = []
+
+        # Propagate memlets to ensure that we can find the true array subsets that are written.
+        propagate_memlets_sdfg(sdfg)
+
+        # Inputs and outputs of all host_maps need to be marked as host_data
+        for state in sdfg.nodes():
+            for node in state.nodes():
+                if isinstance(node, nodes.EntryNode) and node.guid in self.host_maps:
+                    accesses = {e.data.data for e in state.in_edges(node) + state.out_edges(state.exit_node(node))
+                                if e.data.data is not None and node.guid in self.host_maps}
+                    self.host_data.extend(accesses)
+
         for state in sdfg.nodes():
             sdict = state.scope_dict()
             for node in state.nodes():
```

Review comment (on the `accesses` set comprehension): See comment above.
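Note that `propagate_memlets_sdfg` does not appear in the imports touched by this diff; it presumably comes from `dace.sdfg.propagation`. To illustrate what the marking loop computes, a hand-traced sketch (the GUID and data names are hypothetical):

```python
# Hypothetical: a map with GUID 'abc123' reads A and writes B.
host_maps = ['abc123']
host_data = []

# The loop above collects e.data.data for the edges entering the map's
# entry node and leaving its exit node, i.e. {'A', 'B'} here...
accesses = {'A', 'B'}
host_data.extend(accesses)

# ...so the storage pass in the last hunk leaves both arrays on host.
assert sorted(host_data) == ['A', 'B']
```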
```diff
@@ -176,12 +201,13 @@ def apply(self, _, sdfg: sd.SDFG):
                     # map ranges must stay on host
                     for e in state.out_edges(node):
                         last_edge = state.memlet_path(e)[-1]
-                        if (isinstance(last_edge.dst, nodes.EntryNode) and last_edge.dst_conn
-                                and not last_edge.dst_conn.startswith('IN_') and sdict[last_edge.dst] is None):
+                        if (isinstance(last_edge.dst, nodes.EntryNode) and ((last_edge.dst_conn
+                                and not last_edge.dst_conn.startswith('IN_') and sdict[last_edge.dst] is None) or
+                                (last_edge.dst in self.host_maps))):
                             break
                     else:
                         input_nodes.append((node.data, node.desc(sdfg)))
-                if (state.in_degree(node) > 0 and node.data not in output_nodes):
+                if (state.in_degree(node) > 0 and node.data not in output_nodes and node.data not in self.host_data):
                     output_nodes.append((node.data, node.desc(sdfg)))

                 # Input nodes may also be nodes with WCR memlets and no identity
```
```diff
@@ -312,11 +338,13 @@ def apply(self, _, sdfg: sd.SDFG):
             for node in state.nodes():
                 if sdict[node] is None:
                     if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)):
-                        node.schedule = dtypes.ScheduleType.GPU_Default
-                        gpu_nodes.add((state, node))
+                        if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node):
+                            node.schedule = dtypes.ScheduleType.GPU_Default
+                            gpu_nodes.add((state, node))
                     elif isinstance(node, nodes.EntryNode):
-                        node.schedule = dtypes.ScheduleType.GPU_Device
-                        gpu_nodes.add((state, node))
+                        if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node):
+                            node.schedule = dtypes.ScheduleType.GPU_Device
+                            gpu_nodes.add((state, node))
                 elif self.sequential_innermaps:
                     if isinstance(node, (nodes.EntryNode, nodes.LibraryNode)):
                         node.schedule = dtypes.ScheduleType.Sequential
```

Review comment (on the LibraryNode/NestedSDFG branch): If the node is a LibraryNode or NestedSDFG, why would its ID be in host_maps?
```diff
@@ -423,7 +451,8 @@ def apply(self, _, sdfg: sd.SDFG):
                 continue

             # NOTE: the cloned arrays match too but it's the same storage so we don't care
-            nodedesc.storage = dtypes.StorageType.GPU_Global
+            if node.data not in self.host_data:
+                nodedesc.storage = dtypes.StorageType.GPU_Global

             # Try to move allocation/deallocation out of loops
             dsyms = set(map(str, nodedesc.free_symbols))
```
New test file (+92 lines):

```python
import dace
import pytest


def create_assign_sdfg():
    sdfg = dace.SDFG('single_iteration_map')
    state = sdfg.add_state()
    array_size = 1
    A, _ = sdfg.add_array('A', [array_size], dace.float32)
    map_entry, map_exit = state.add_map('map_1_iter', {'i': '0:1'})
    tasklet = state.add_tasklet('set_to_1', {}, {'OUT__a'}, '_a = 1')
    map_exit.add_in_connector('IN__a')
    map_exit.add_out_connector('OUT__a')
    tasklet.add_out_connector('OUT__a')
    an = state.add_write('A')
    state.add_edge(map_entry, None, tasklet, None, dace.Memlet())
    state.add_edge(tasklet, 'OUT__a', map_exit, 'IN__a', dace.Memlet('A[0]'))
    state.add_edge(map_exit, 'OUT__a', an, None, dace.Memlet('A[0]'))
    sdfg.validate()
    return A, sdfg


def create_increment_sdfg():
    sdfg = dace.SDFG('increment_map')
    state = sdfg.add_state()
    array_size = 500
    A, _ = sdfg.add_array('A', [array_size], dace.float32)
    map_entry, map_exit = state.add_map('map_1_iter', {'i': f'0:{array_size}'})
    tasklet = state.add_tasklet('inc_by_1', {}, {'OUT__a'}, '_a = _a + 1')
    map_entry.add_in_connector('IN__a')
    map_entry.add_out_connector('OUT__a')
    map_exit.add_in_connector('IN__a')
    map_exit.add_out_connector('OUT__a')
    tasklet.add_in_connector('IN__a')
    tasklet.add_out_connector('OUT__a')
    an1 = state.add_read('A')
    an2 = state.add_write('A')
    state.add_edge(an1, None, map_entry, 'IN__a', dace.Memlet('A[i]'))
    state.add_edge(map_entry, 'OUT__a', tasklet, 'IN__a', dace.Memlet())
    state.add_edge(tasklet, 'OUT__a', map_exit, 'IN__a', dace.Memlet('A[i]'))
    state.add_edge(map_exit, 'OUT__a', an2, None, dace.Memlet('A[i]'))
    sdfg.validate()
    return A, sdfg


@pytest.mark.parametrize("sdfg_creator", [
    create_assign_sdfg,
    create_increment_sdfg,
])
class TestHostDataHostMapParams:

    def test_host_data(self, sdfg_creator):
        """Test that arrays marked as host_data remain on host after GPU transformation."""
        A, sdfg = sdfg_creator()
        sdfg.apply_gpu_transformations(host_data=['A'])
        sdfg.validate()

        assert sdfg.arrays[A].storage != dace.dtypes.StorageType.GPU_Global

    def test_host_map(self, sdfg_creator):
        """Test that maps marked as host_maps remain on host after GPU transformation."""
        A, sdfg = sdfg_creator()
        host_maps = [
            n.guid for s in sdfg.states()
            for n in s.nodes()
            if isinstance(n, dace.nodes.EntryNode)
        ]
        sdfg.apply_gpu_transformations(host_maps=host_maps)
        sdfg.validate()
        assert sdfg.arrays[A].storage != dace.dtypes.StorageType.GPU_Global

    @pytest.mark.parametrize("pass_empty", [True, False])
    def test_no_host_map_or_data(self, sdfg_creator, pass_empty):
        """Test default GPU transformation behavior with no host constraints."""
        A, sdfg = sdfg_creator()

        if pass_empty:
            sdfg.apply_gpu_transformations(host_maps=[], host_data=[])
        else:
            sdfg.apply_gpu_transformations()

        sdfg.validate()

        # Verify array storage locations
        assert 'A' in sdfg.arrays and 'gpu_A' in sdfg.arrays
        assert sdfg.arrays['A'].storage != dace.dtypes.StorageType.GPU_Global
        assert sdfg.arrays['gpu_A'].storage == dace.dtypes.StorageType.GPU_Global

        # Verify map schedules
        for s in sdfg.states():
            for n in s.nodes():
                if isinstance(n, dace.nodes.MapEntry):
                    assert n.map.schedule == dace.ScheduleType.GPU_Device


if __name__ == '__main__':
    pytest.main([__file__])
```
Review comment: Wouldn't checking memlet paths and adding the src and destination data make more sense here?

Review comment (follow-up): To clarify, I am considering issues similar to those identified in #1708.
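As a sketch of what the suggested alternative might look like (an assumption about the reviewer's intent, not part of the PR): walk the full memlet path of each edge adjacent to the host map and record the data of both path endpoints, so accesses reached through nested scopes are caught as well.

```python
from dace.sdfg import nodes

def collect_host_accesses(state, entry_node):
    """Hypothetical helper: gather data names at both endpoints of every
    memlet path touching the map, rather than only on the edges adjacent
    to its entry/exit nodes."""
    accesses = set()
    for e in state.in_edges(entry_node) + state.out_edges(state.exit_node(entry_node)):
        path = state.memlet_path(e)
        for endpoint in (path[0].src, path[-1].dst):
            if isinstance(endpoint, nodes.AccessNode):
                accesses.add(endpoint.data)
    return accesses
```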