Lower shape.cstr_broadcastable op in ShapeLegalizeToHLO #2384

Closed
5 changes: 5 additions & 0 deletions docs/generated/stablehlo_passes.md
@@ -12,6 +12,11 @@ An experimental pass that legalizes shape-related ops to StableHLO ops.
Bringing shape and data computations together via an optional pass will
make it possible for the StableHLO ecosystem to potentially leverage the
compilation pipelines that use StableHLO operations to model dynamism.

#### Options
```
-legalize-constraints : Whether to legalize Cstr Ops to shape_assertion custom_call
```
### `-stablehlo-aggressive-folder`

_Folds StableHLO operations_
104 changes: 104 additions & 0 deletions stablehlo/tests/shape_cstr_legalize_to_stablehlo.mlir
@@ -0,0 +1,104 @@
// RUN: stablehlo-opt --shape-legalize-to-stablehlo=legalize-constraints=true --split-input-file --verify-diagnostics %s | FileCheck %s

// -----

// CHECK-LABEL: func.func @shape_cstr_broadcastable
func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) {
%0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex>
shape.assuming %0 {
}
func.return
// CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32>
// CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32>
// CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32>
// CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
// CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
// CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1>
// CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1>
// CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
// CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1>
// CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
// CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
// CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
// CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
// CHECK-NEXT: shape.assuming %[[WITNESS]] {
// CHECK-NEXT: }
// CHECK-NEXT: return
}

// -----

func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) {
// expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}}
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
func.return
}

// -----

func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) {
%0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex>
shape.assuming %0 {
}
func.return
// CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32>
// CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32>
// CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32>
// CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
// CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32>
// CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1>
// CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]], NOTYPE : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1>
// CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1>
// CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1>
// CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
// CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1>
// CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1>
// CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
// CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
// CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
// CHECK-NEXT: shape.assuming %[[WITNESS]] {
// CHECK-NEXT: }
// CHECK-NEXT: return
}

// -----

func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<0xindex>) {
%0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<0xindex>
shape.assuming %0 {
}
func.return
// CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32>
// CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<0xindex> to tensor<0xi32>
// CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32>
// CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<0xi32>) -> tensor<1xi32>
// CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<1xi32>
// CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]], NOTYPE : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]], NOTYPE : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<1xi1>
// CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]], NOTYPE : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<1xi1>
// CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1>
// CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<1xi1>) -> tensor<1xi1>
// CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1>
// CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1>
// CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
// CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true
// CHECK-NEXT: shape.assuming %[[WITNESS]] {
// CHECK-NEXT: }
// CHECK-NEXT: return
}

// -----

func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) {
// expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}}
%0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex>
func.return
}
7 changes: 6 additions & 1 deletion stablehlo/transforms/Passes.h
@@ -83,10 +83,15 @@ void populateStablehloLegalizeDeprecatedOpsPatterns(

/// Collection of shape dialect to StableHLO patterns.
void populateShapeToStablehloPatterns(MLIRContext *context,
-                                      RewritePatternSet *patterns);
+                                      RewritePatternSet *patterns,
+                                      bool legalizeConstraints);

//// Additional pass constructors ////

// Legalizes from the Shape dialect to the StableHLO dialect.
std::unique_ptr<mlir::OperationPass<func::FuncOp>>
createShapeLegalizeToStablehloPass(bool legalizeConstraints);

std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes);

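Note (not part of this diff): a pipeline that wants the constraint lowering could construct the pass through the new overload declared above and follow it with a canonicalizer, so that `shape.assuming` regions fed by a constant-true witness fold away. A minimal sketch, assuming the standard MLIR pass-manager setup; the helper function name is made up for illustration:

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "stablehlo/transforms/Passes.h"

// Hypothetical helper: run shape legalization with constraint lowering,
// then canonicalize to remove shape.assuming regions that now consume a
// shape.const_witness true.
void addShapeLegalizationPasses(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::stablehlo::createShapeLegalizeToStablehloPass(
          /*legalizeConstraints=*/true));
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
}
```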
4 changes: 4 additions & 0 deletions stablehlo/transforms/Passes.td
@@ -125,6 +125,10 @@ def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::Fu
compilation pipelines that use StableHLO operations to model dynamism.
}];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
let options = [
Option<"legalize_constraints_", "legalize-constraints", "bool",
/*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call">
];
}

def StablehloLegalizeDeprecatedOpsPass : Pass<"stablehlo-legalize-deprecated-ops", "func::FuncOp"> {
107 changes: 105 additions & 2 deletions stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
@@ -118,6 +118,16 @@ Value castToIndex(PatternRewriter& rewriter, Location loc, Value value) {
return cast.getResult(0);
}

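// Emits a `shape_assertion` custom call on the given i1 value. The call has
// no results, is marked has_side_effect so it is not folded or DCE'd away,
// and carries a fixed "Shape assertion failed" error message.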
void insertShapeAssertionCustomCall(OpBuilder builder, Location loc,
Value assert) {
auto customCall = builder.create<stablehlo::CustomCallOp>(loc, TypeRange{},
ValueRange{assert});
customCall.setCallTargetName("shape_assertion");
customCall.setHasSideEffect(true);
customCall->setAttr("error_message",
builder.getStringAttr("Shape assertion failed"));
}

Value maybeCastToIndex(Value result, Value value, PatternRewriter& rewriter) {
if (isShapedOfI32(result)) return value;
return castToIndex(rewriter, value.getLoc(), value);
@@ -491,6 +501,75 @@ struct ConvertTensorFromElementsPattern
}
};

struct ConvertCstrBroadcastableOp
: public OpRewritePattern<shape::CstrBroadcastableOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter& rewriter) const override {
// The way CstrBroadcastableOp is defined, its inputs must be 1D tensors
// or !shape.shape. We only support inputs of two 1D tensors.
if (op.getShapes().size() != 2) return failure();
auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front());
auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back());
if (!shape1 || !shape2) return failure();
auto tensorType1 = dyn_cast<RankedTensorType>(shape1.getType());
auto tensorType2 = dyn_cast<RankedTensorType>(shape2.getType());
if (!tensorType1 || !tensorType2) return failure();

// If the two operand shapes are of different sizes, the smaller one is
// padded with 1's from the left.
if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) {
shape1 =
padFromLeft(rewriter, op.getLoc(), shape1,
tensorType2.getDimSize(0) - tensorType1.getDimSize(0));
} else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) {
shape2 =
padFromLeft(rewriter, op.getLoc(), shape2,
tensorType1.getDimSize(0) - tensorType2.getDimSize(0));
}

// Compute if each dim is broadcastable. A dim is broadcastable iff
// dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1
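// For example, shapes [2, 1] and [2, 3] are broadcastable, while
// [2, 3] and [2, 4] are not.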
int32_t rank =
std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0));
auto allOne = rewriter.create<stablehlo::ConstantOp>(
op.getLoc(), DenseIntElementsAttr::get<int32_t>(
RankedTensorType::get({rank}, rewriter.getI32Type()),
static_cast<int32_t>(1)));
Value dimSize1Is1 = rewriter.create<stablehlo::CompareOp>(
op.getLoc(), shape1, allOne, ComparisonDirection::EQ);
Value dimSize2Is1 = rewriter.create<stablehlo::CompareOp>(
op.getLoc(), shape2, allOne, ComparisonDirection::EQ);
Value eitherDimSizeIs1 =
rewriter.create<stablehlo::OrOp>(op.getLoc(), dimSize1Is1, dimSize2Is1);
Value dimSizeEq = rewriter.create<stablehlo::CompareOp>(
op.getLoc(), shape1, shape2, ComparisonDirection::EQ);
Value dimBroadcastable = rewriter.create<stablehlo::OrOp>(
op.getLoc(), eitherDimSizeIs1, dimSizeEq);

// Iterate over each dim to check that all dims are broadcastable.
auto boolType = RankedTensorType::get({1}, rewriter.getI1Type());
Value allBroadcastable = rewriter.create<stablehlo::ConstantOp>(
op.getLoc(), DenseIntElementsAttr::get<bool>(boolType, true));
for (auto i = 0; i < rank; ++i) {
Value broadcastable =
rewriter.create<SliceOp>(op.getLoc(), dimBroadcastable, i, i + 1, 1);
allBroadcastable =
rewriter.create<AndOp>(op.getLoc(), allBroadcastable, broadcastable);
}
Value allBroadcastableScalar = rewriter.create<ReshapeOp>(
op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()),
allBroadcastable);

// Add CustomCallOp and replace Cstr op with const witness, which is useful
// for canonicalizer to remove the shape.assuming region.
insertShapeAssertionCustomCall(rewriter, op->getLoc(),
allBroadcastableScalar);
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
return success();
}
};

template <typename OpType>
struct CastOperandsPattern : public OpRewritePattern<OpType> {
using OpRewritePattern<OpType>::OpRewritePattern;
@@ -523,6 +602,12 @@ struct ShapeLegalizeToStablehloPass
ShapeLegalizeToStablehloPass> {
using ShapeLegalizeToStablehloPassBase::ShapeLegalizeToStablehloPassBase;

explicit ShapeLegalizeToStablehloPass(bool legalizeConstraints)
: impl::ShapeLegalizeToStablehloPassBase<
ShapeLegalizeToStablehloPass>::ShapeLegalizeToStablehloPassBase() {
this->legalize_constraints_ = legalizeConstraints;
}

LogicalResult initialize(MLIRContext* context) override {
// In order to make dynamic StableHLO programs compatible with HLO, we need
// to get rid of all non-StableHLO ops.
@@ -548,6 +633,10 @@
// is able to remove unnecessary cruft. At the moment, this pass is a
// work in progress, so not all of these ops are supported.
//
// When legalize_constraints_ is set to true, cstr* ops are also legalized.
// A shape_assertion custom_call is used to check the constraint, and the
// shape.assuming region consumes a shape.const_witness that evaluates to
// true, so that it can be removed later by a canonicalizer pass.
target = std::make_shared<ConversionTarget>(*context);
target->addIllegalDialect<shape::ShapeDialect>();
target->addIllegalDialect<tensor::TensorDialect>();
Expand All @@ -559,6 +648,10 @@ struct ShapeLegalizeToStablehloPass
});
target->addLegalOp<tensor::CastOp>();
target->addLegalOp<UnrealizedConversionCastOp>();
if (this->legalize_constraints_) {
target->addLegalOp<shape::ConstWitnessOp, shape::AssumingOp,
shape::AssumingYieldOp>();
}
Comment on lines +651 to +654

Contributor:
IMO it would be cleaner to consolidate the handling of this->legalize_constraints_ into a single if. If it's true, add the pattern to the set of patterns and add the legal ops.

Member Author (@sdasgup3), Jun 7, 2024:
I agree. The current choice is driven by how the code is currently structured: the addition of legal ops is handled separately from adding the pattern. Using a single if would create an outlier for the legalize_constraints_ case. wdyt?

Contributor:
Right, I think both takes are valid, but IMO in this case it would be more readable to have a single if since that would make it clearer what happens when legalize_constraints_ is true. Up to you though.
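For reference, a rough sketch of the consolidation the reviewer describes. This is hypothetical and not what the PR does (the PR registers ConvertCstrBroadcastableOp inside populateShapeToStablehloPatterns and only adds the legal ops here); it also assumes the RewritePatternSet `patterns_` is in scope at this point in initialize():

```cpp
// Hypothetical single-if variant: both the extra legal ops and the
// constraint-lowering pattern are registered in one place.
if (this->legalize_constraints_) {
  target->addLegalOp<shape::ConstWitnessOp, shape::AssumingOp,
                     shape::AssumingYieldOp>();
  patterns_.add<ConvertCstrBroadcastableOp>(context);
}
```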


// The patterns do what one might expect, converting between MLIR-style
// and HLO-style shape computations.
@@ -569,7 +662,8 @@
// to ultimately annihilate with each other upon canonicalization if
// everything went right.
RewritePatternSet patterns_(context);
-    populateShapeToStablehloPatterns(context, &patterns_);
+    populateShapeToStablehloPatterns(context, &patterns_,
+                                     this->legalize_constraints_);
patterns = std::move(patterns_);

return success();
@@ -588,7 +682,8 @@
} // namespace

void populateShapeToStablehloPatterns(MLIRContext* context,
-                                      RewritePatternSet* patterns) {
+                                      RewritePatternSet* patterns,
+                                      bool legalizeConstraints) {
patterns->add<ConvertConstShapeOpPattern>(context);
patterns->add<ConvertMulIOpPattern>(context);
patterns->add<ConvertIndexCastOpPattern>(context);
Expand All @@ -600,6 +695,14 @@ void populateShapeToStablehloPatterns(MLIRContext* context,
patterns->add<ConvertTensorDimPattern>(context);
patterns->add<ConvertTensorExtractPattern>(context);
patterns->add<ConvertTensorFromElementsPattern>(context);
if (legalizeConstraints) {
patterns->add<ConvertCstrBroadcastableOp>(context);
}
}

std::unique_ptr<mlir::OperationPass<func::FuncOp>>
createShapeLegalizeToStablehloPass(bool legalizeConstraints) {
return std::make_unique<ShapeLegalizeToStablehloPass>(legalizeConstraints);
}

} // namespace stablehlo