Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Transform] Consolidate result of structured.split into one list #111171

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
{ target_size = 10, dimension = 1 }
: !transform.any_op, !transform.param<i64>,
!transform.param<i64>, !transform.param<i64>
%low, %high = structured.split %target after %split { dimension = 1 }
%handles = structured.split %target after %split { dimension = 1 }
: !transform.any_op, !transform.param<i64>
%low, %high = transform.split_handle %handles : (!transform.any_op)
-> (!transform.any_op, !transform.any_op)
%tiled_low, %loop1 = structured.tile_using_for %low [0, %sz1]
: (!transform.any_op, !transform.param<i64>)
-> (!transform.any_op, !transform.any_op)
Expand Down Expand Up @@ -1452,30 +1454,32 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
operations pointed to by the target handle.

The operation consumes the target handle, but preserves the chunk size
handle if provided. Without the `multiway` attribute, it produces two
new handles pointing to the two parts of the structured op after splitting,
in the same order as the target operand, with the first handle
corresponding to the part with lower iteration space indices.
handle if provided. Without the `multiway` attribute, it produces a
new handle that is a list of the two parts of the structured op after
splitting, whose lower index part corresponding to the part with lower
iteration space indices.

Multiway split mode is enabled by specifying the `multiway` attribute.
In this mode a single `target` op is split into multiple parts covering
the iteration space of the specified dimension. `static_chunk_sizes` and
`dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
dimension should be split into. With `multiway` it produces two handles;
the first handle is a list of the multiple parts of the structured op
dimension should be split into. With `multiway` it also produces a handle;
The result handle is a list of the multiple parts of the structured op
after splitting, where the target dimensions for each linalg op in the
list corresponds to the chunk sizes specfied in the input split list.
If the chunk sizes do not cover the entire iteration space, the leftover
chunk is the last payload in the first handle. The second handle is empty.
chunk is the last payload in the result handle.

As the result handle is most of time a list, an `transform.split_handle`
is needed to access individual handle.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$dimension,
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
I64Attr:$static_chunk_sizes,
UnitAttr:$multiway);
let results = (outs TransformHandleTypeInterface:$first,
TransformHandleTypeInterface:$second);
let results = (outs TransformHandleTypeInterface:$split_list);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2363,10 +2363,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
};

SmallVector<Operation *> opList;
if (isMultiwaySplit) {

// Split a single target operation at multiple points.
SmallVector<Operation *> opList;
TilingInterface head, tail;
Operation *target = payload.front();

Expand Down Expand Up @@ -2406,8 +2406,6 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Append any leftover parts to the end of the result list.
if (tail)
opList.push_back(tail.getOperation());
results.set(cast<OpResult>(getFirst()), opList);
results.set(cast<OpResult>(getSecond()), {});

} else {
// Split each target operation.
Expand Down Expand Up @@ -2453,9 +2451,11 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return diag;
}

results.set(cast<OpResult>(getFirst()), first);
results.set(cast<OpResult>(getSecond()), second);
opList.append(first);
if (second.size())
opList.append(second);
}
results.set(cast<OpResult>(getSplitList()), opList);
return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -2507,7 +2507,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute(
SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
staticChunkSizes);
result.addTypes({targetType, targetType});
result.addTypes(targetType);
return success();
}

Expand Down
1 change: 0 additions & 1 deletion mlir/python/mlir/dialects/transform/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,6 @@ def __init__(
dynamic_chunk_sizes = chunk_sizes

super().__init__(
target.type,
target.type,
target,
dimension=dimension,
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
Expand Down Expand Up @@ -65,7 +65,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
Expand Down Expand Up @@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
Expand Down Expand Up @@ -177,4 +177,4 @@ func.func @continuous_tile_dynamic_linalg_matmul(
// CHECK: %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
%low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
%splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -58,7 +58,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
%low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
%splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
transform.yield
}
}
Expand Down
12 changes: 8 additions & 4 deletions mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
%split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
%2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
transform.foreach %5 : !transform.any_op {
^bb0(%inner_linalg: !transform.any_op):
%low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
%split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
%inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
Expand Down Expand Up @@ -111,14 +113,16 @@ module attributes {transform.with_named_sequence} {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param<i64>
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param<i64>
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
%split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
%2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
%split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
%inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
}
Expand Down
13 changes: 7 additions & 6 deletions mlir/test/Dialect/Linalg/transform-op-split.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
%1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -53,7 +53,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
%1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -138,8 +138,9 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
%t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
%1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -197,7 +198,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
// expected-error @below {{expects either a dynamic or a static split point to be provided}}
%0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -303,7 +304,7 @@ module attributes {transform.with_named_sequence} {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{splitting does not produce the second part for a subset of targets}}
// expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
%1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
%1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
transform.yield
}
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Dialect/Linalg/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ transform.sequence failures(propagate) {

transform.sequence failures(propagate) {
^bb1(%arg0: !transform.any_op):
%0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
%t = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
%0:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op
}

Expand Down
10 changes: 7 additions & 3 deletions mlir/test/python/dialects/transform_structured_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,15 @@ def testScalarize(target):
@run
@create_sequence
def testSplit(target):
split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
handle = structured.SplitOp(target, dimension=1, chunk_sizes=42)
split = transform.SplitHandleOp(
[transform.AnyOpType.get(), transform.AnyOpType.get()], handle
)
structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
# CHECK-LABEL: TEST: testSplit
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
# CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: %[[F:.+]]:2 = split_handle %[[G]]
# CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3


@run
Expand Down
Loading