From 4fa5cba389ae35ae14fbd203bf9e776d73324d79 Mon Sep 17 00:00:00 2001 From: Jinyun Joey Ye Date: Fri, 4 Oct 2024 22:40:37 +0800 Subject: [PATCH] [MLIR][Transform] Consolidate result of structured.split into one list E.g.: ``` %0:2 = transform.structured.split ``` is changed to ``` %t = transform.structured.split %0:2 = transform.split_handle %t ``` --- .../Linalg/TransformOps/LinalgTransformOps.td | 24 +++++++++++-------- .../TransformOps/LinalgTransformOps.cpp | 12 +++++----- .../mlir/dialects/transform/structured.py | 1 - .../Linalg/continuous-tiling-full.mlir | 8 +++---- .../continuous-tiling-multiway-split.mlir | 4 ++-- .../Dialect/Linalg/multisize-tiling-full.mlir | 12 ++++++---- .../Dialect/Linalg/transform-op-split.mlir | 13 +++++----- mlir/test/Dialect/Linalg/transform-ops.mlir | 3 ++- .../dialects/transform_structured_ext.py | 10 +++++--- 9 files changed, 50 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index a997502c34299c..c01248f70195e3 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -676,8 +676,10 @@ def MultiTileSizesOp : Op, !transform.param, !transform.param - %low, %high = structured.split %target after %split { dimension = 1 } + %handles = structured.split %target after %split { dimension = 1 } : !transform.any_op, !transform.param + %low, %high = transform.split_handle %handles : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) %tiled_low, %loop1 = structured.tile_using_for %low [0, %sz1] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) @@ -1422,21 +1424,24 @@ def SplitOp : Op:$dynamic_chunk_sizes, I64Attr:$static_chunk_sizes, UnitAttr:$multiway); - let results = (outs TransformHandleTypeInterface:$first, - TransformHandleTypeInterface:$second); + let results = (outs TransformHandleTypeInterface:$split_list); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 0b9223013a0f1b..15af963f1ff3df 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2348,10 +2348,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); }; + SmallVector opList; if (isMultiwaySplit) { // Split a single target operation at multiple points. - SmallVector opList; TilingInterface head, tail; Operation *target = payload.front(); @@ -2391,8 +2391,6 @@ SplitOp::apply(transform::TransformRewriter &rewriter, // Append any leftover parts to the end of the result list. if (tail) opList.push_back(tail.getOperation()); - results.set(cast(getFirst()), opList); - results.set(cast(getSecond()), {}); } else { // Split each target operation. @@ -2438,9 +2436,11 @@ SplitOp::apply(transform::TransformRewriter &rewriter, return diag; } - results.set(cast(getFirst()), first); - results.set(cast(getSecond()), second); + opList.append(first); + if (second.size()) + opList.append(second); } + results.set(cast(getSplitList()), opList); return DiagnosedSilenceableFailure::success(); } @@ -2492,7 +2492,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { result.addAttribute( SplitOp::getStaticChunkSizesAttrName(result.name).getValue(), staticChunkSizes); - result.addTypes({targetType, targetType}); + result.addTypes(targetType); return success(); } diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 41051c0d5b2ffb..1aae669c3437cf 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -445,7 +445,6 @@ def __init__( dynamic_chunk_sizes = chunk_sizes super().__init__( - target.type, target.type, target, dimension=dimension, diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir index 7410ff593d01a2..e02aa0c4db44af 100644 --- a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir @@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op - %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op + %linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op { ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op): %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) @@ -65,7 +65,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param - %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param + %linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param { ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param): %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) @@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op - %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op + %linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op { ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op): %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) @@ -177,4 +177,4 @@ func.func @continuous_tile_dynamic_linalg_matmul( // CHECK: %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]] // CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor) { // CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32> -// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor \ No newline at end of file +// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir index 609766fbdc91f2..12fe8a2a2b6b5c 100644 --- a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir +++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir @@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op - %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op + %splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op transform.yield } } @@ -58,7 +58,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param - %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param + %splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param transform.yield } } diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir index 51332ffce03d1d..af041db9eeffbf 100644 --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -6,14 +6,16 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op - %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op + %split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op + %2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op transform.foreach %5 : !transform.any_op { ^bb0(%inner_linalg: !transform.any_op): %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op - %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op + %split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op + %inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) } @@ -111,14 +113,16 @@ module attributes {transform.with_named_sequence} { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param - %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param + %split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param + %2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param, !transform.param, !transform.param transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param, !transform.param, !transform.param { ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param, %high: !transform.param, %split_point: !transform.param): - %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param + %split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param + %inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param) -> (!transform.any_op, !transform.any_op) } diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir index e072fff4c5d771..68c849385ba6b5 100644 --- a/mlir/test/Dialect/Linalg/transform-op-split.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir @@ -3,7 +3,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op + %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op transform.yield } } @@ -53,7 +53,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op + %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op transform.yield } } @@ -138,8 +138,9 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op - %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op + %t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op + %1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op transform.yield } } @@ -197,7 +198,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>, module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) { // expected-error @below {{expects either a dynamic or a static split point to be provided}} - %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op) transform.yield } } @@ -303,7 +304,7 @@ module attributes {transform.with_named_sequence} { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{splitting does not produce the second part for a subset of targets}} // expected-note @below {{expected splitting to produce the second part of all or none of the targets}} - %1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op + %1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op transform.yield } } diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir index c152fc887a3a39..06a89fccd5c383 100644 --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -18,7 +18,8 @@ transform.sequence failures(propagate) { transform.sequence failures(propagate) { ^bb1(%arg0: !transform.any_op): - %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op + %t = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op + %0:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op } diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 3ea73e8beea368..22a09867231d1f 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -361,11 +361,15 @@ def testScalarize(target): @run @create_sequence def testSplit(target): - split = structured.SplitOp(target, dimension=1, chunk_sizes=42) + handle = structured.SplitOp(target, dimension=1, chunk_sizes=42) + split = transform.SplitHandleOp( + [transform.AnyOpType.get(), transform.AnyOpType.get()], handle + ) structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1]) # CHECK-LABEL: TEST: testSplit - # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 - # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3 + # CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 + # CHECK: %[[F:.+]]:2 = split_handle %[[G]] + # CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3 @run