Skip to content

Commit

Permalink
Reorder the nodes in the graph to match the dim expansion order
Browse files Browse the repository at this point in the history
This PR modifies the order in which nodes are inserted into
the graph to follow the canonical Cartesian product. This
is especially important for the iter args inside a for loop
that (without this PR) are not in the same order as the
init args and outputs. By reordering the nodes, we ensure
that the init args, iter args and outputs are all in the
same order and map 1-1.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Sep 21, 2024
1 parent 672fe45 commit 6c2a080
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 172 deletions.
62 changes: 31 additions & 31 deletions lit_tests/kernel/wave/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,39 +91,39 @@ def test_read_write_equal_sizes():
# CHECK-NEXT: %c
# CHECK-NEXT: %read_0_0
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %read_1_1
# CHECK-NEXT: %read_0_1
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %read_1_0
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %read_0_1
# CHECK-NEXT: %read_1_1
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_shared_0_0
# CHECK-SAME: (%read_0_0, %allocate, 4, None)
# CHECK-NEXT: %write_shared_1_1
# CHECK-SAME: (%read_1_1, %allocate, 4, None)
# CHECK-NEXT: %write_shared_1_0
# CHECK-SAME: (%read_1_0, %allocate, 4, None)
# CHECK-NEXT: %write_shared_0_1
# CHECK-SAME: (%read_0_1, %allocate, 4, None)
# CHECK-NEXT: %write_shared_1_0
# CHECK-SAME: (%read_1_0, %allocate, 4, None)
# CHECK-NEXT: %write_shared_1_1
# CHECK-SAME: (%read_1_1, %allocate, 4, None)
# CHECK-NEXT: %shared_memory_barrier
# CHECK-NEXT: %read_shared_0_0
# CHECK-SAME: (%allocate, 4, None, [%write_shared_0_0])
# CHECK-NEXT: %read_shared_1_1
# CHECK-SAME: (%allocate, 4, None, [%write_shared_1_1])
# CHECK-NEXT: %read_shared_1_0
# CHECK-SAME: (%allocate, 4, None, [%write_shared_1_0])
# CHECK-NEXT: %read_shared_0_1
# CHECK-SAME: (%allocate, 4, None, [%write_shared_0_1])
# CHECK-NEXT: %read_shared_1_0
# CHECK-SAME: (%allocate, 4, None, [%write_shared_1_0])
# CHECK-NEXT: %read_shared_1_1
# CHECK-SAME: (%allocate, 4, None, [%write_shared_1_1])
# CHECK-NEXT: %write_0_0
# CHECK-SAME: (%read_shared_0_0, %c, 4, None)
# CHECK-NEXT: %write_1_1
# CHECK-SAME: (%read_shared_1_1, %c, 4, None)
# CHECK-NEXT: %write_1_0
# CHECK-SAME: (%read_shared_1_0, %c, 4, None)
# CHECK-NEXT: %write_0_1
# CHECK-SAME: (%read_shared_0_1, %c, 4, None)
# CHECK-NEXT: %write_1_0
# CHECK-SAME: (%read_shared_1_0, %c, 4, None)
# CHECK-NEXT: %write_1_1
# CHECK-SAME: (%read_shared_1_1, %c, 4, None)
# CHECK-NEXT: return None

# CHECK: -----
Expand Down Expand Up @@ -178,38 +178,38 @@ def test_gemm():
# CHECK-NEXT: %b
# CHECK-NEXT: %c
# CHECK-NEXT: %register_0_0_0
# CHECK-NEXT: %register_1_1_0
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_1_1_0
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
# CHECK-NEXT: %getresult_1_1_0
# CHECK-SAME: (%reduction, 3)
# CHECK-NEXT: %getresult_1_0_0
# CHECK-SAME: (%reduction, 2)
# CHECK-SAME (K, [%register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0]
# CHECK-NEXT: %getresult_0_1_0
# CHECK-SAME: (%reduction, 1)
# CHECK-NEXT: %getresult_1_0_0
# CHECK-SAME: (%reduction, 2)
# CHECK-NEXT: %getresult_1_1_0
# CHECK-SAME: (%reduction, 3)
# CHECK-NEXT: %getresult_0_0_0
# CHECK-SAME: (%reduction, 0)
# CHECK-NEXT: %write_0_0_0
# CHECK-SAME: (%getresult_0_0_0, %c, 4, None)
# CHECK-NEXT: %write_1_1_0
# CHECK-SAME: (%getresult_1_1_0, %c, 4, None)
# CHECK-NEXT: %write_1_0_0
# CHECK-SAME: (%getresult_1_0_0, %c, 4, None)
# CHECK-NEXT: %write_0_1_0
# CHECK-SAME: (%getresult_0_1_0, %c, 4, None)
# CHECK-NEXT: %write_1_0_0
# CHECK-SAME: (%getresult_1_0_0, %c, 4, None)
# CHECK-NEXT: %write_1_1_0
# CHECK-SAME: (%getresult_1_1_0, %c, 4, None)
# CHECK-NEXT: return None

# Reduction subgraph:
# CHECK: %acc_0_0_0
# CHECK-NEXT: %acc_1_1_0
# CHECK-NEXT: %acc_1_0_0
# CHECK-NEXT: %acc_0_1_0
# CHECK-NEXT: %acc_1_0_0
# CHECK-NEXT: %acc_1_1_0
# CHECK-NEXT: %a
# CHECK-NEXT: %read_0_0_0
# CHECK-NEXT: %read_0_0_1
Expand Down Expand Up @@ -240,12 +240,12 @@ def test_gemm():
# CHECK-NEXT: %read_shared_0_1_1
# CHECK-NEXT: %mma_0_0_0
# CHECK-NEXT: %mma_0_0_1
# CHECK-NEXT: %mma_1_1_0
# CHECK-NEXT: %mma_1_1_1
# CHECK-NEXT: %mma_1_0_0
# CHECK-NEXT: %mma_1_0_1
# CHECK-NEXT: %mma_0_1_0
# CHECK-NEXT: %mma_0_1_1
# CHECK-NEXT: %mma_1_0_0
# CHECK-NEXT: %mma_1_0_1
# CHECK-NEXT: %mma_1_1_0
# CHECK-NEXT: %mma_1_1_1


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 6c2a080

Please sign in to comment.