Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added all possible elementwise binary ops for #316. #344

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/imex/Dialect/PTensor/IR/PTensorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ enum EWBinOpId : int {
LOGICAL_AND,
LOGICAL_OR,
LOGICAL_XOR,
LSHIFT,
// LSHIFT,
MATMUL,
MAXIMUM,
MINIMUM,
Expand All @@ -53,6 +53,7 @@ enum EWBinOpId : int {
NOT_EQUAL,
OR,
POWER,
// RSHIFT,
SUBTRACT,
TRUE_DIVIDE,
XOR,
Expand Down
91 changes: 75 additions & 16 deletions lib/Conversion/PTensorToLinalg/PTensorToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/Dialect/Math/IR/Math.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/Shape/IR/Shape.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/Dialect/Tosa/IR/TosaOps.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/Pass.h>

Expand Down Expand Up @@ -293,19 +295,63 @@ static BodyType buildTrivial(::mlir::Type typ) {
};
}

// Builder for TOSA body.
// The builder function takes an extra agrument ::mlir::Type
// Many TOSA functions take two arguments and do the
// same operation on both arguments.
Comment on lines +300 to +301
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the functionality we need? Why is TOSA any better than math and arith?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fschlimb I found those logical ops only in TOSA, is there any other alternative?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. An alternative is of course to write them manually.

// TODO:
// 1. Find a way to merge this function with trivialBuilder().
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we need buildTrivial if we use TOSA?

Copy link
Author

@chudur-budur chudur-budur Oct 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fschlimb That's my thought too, should we just move to TOSA completely?

Copy link
Contributor

@fschlimb fschlimb Oct 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I investigated I realized that TOSA does not have unsigned types. I could not figure out if that's a gap or a design decision. If we can assume that TOSA can support unsigned ints then the only reason for not using TOSA might be compile-time. But I tend to think that's probably acceptable at this point.

// 2. Detect if there is rhs2, if yes, build accordingly.
// 3. We need to introduce a Boolean type.
template <typename IOP, typename FOP = void>
static BodyType buildTosa(::mlir::Type typ) {
return [typ](mlir::OpBuilder &builder, ::mlir::Location loc,
::mlir::ValueRange args) -> void {
auto lhs_typ = args[0].getType();
if (lhs_typ.isIntOrIndex()) {
if constexpr (!std::is_same_v<IOP, void>) {
auto lhs = doSignCast(builder, loc, args[0]);
auto rhs1 = doSignCast(builder, loc, args[1]);
// auto rhs2 = doSignCast(builder, loc, args[2]); // sometimes there can
// be two params
yield(builder, loc, typ,
builder.create<IOP>(loc, typ, lhs, rhs1).getResult());
return;
} else
assert("Found integer type but binary op not defined for integers" ==
nullptr);
} else if (lhs_typ.isIntOrIndexOrFloat()) {
if constexpr (!std::is_same_v<FOP, void>) {
yield(builder, loc, typ,
builder.create<FOP>(loc, typ, args[0], args[1]).getResult());
return;
} else
assert("Found float type but binary op not defined for floats" ==
nullptr);
} else {
assert("Only integers and floats supported for binary ops" == nullptr);
}
};
}

// get a body builder for given binary operation and result type
// we accept a result type to insert a cast after the operation if needed
static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop,
::mlir::Type typ) {
switch (bop) {
case ptensor::ADD:
return buildTrivial<mlir::arith::AddIOp, mlir::arith::AddFOp>(typ);
// case ptensor::ATAN2] =
case ptensor::ATAN2:
return buildTrivial<void, mlir::math::Atan2Op>(typ);
case ptensor::FLOOR_DIVIDE:
return buildTrivial<mlir::arith::FloorDivSIOp>(typ);
// case ptensor::LOGADDEXP] =
// case ptensor::LSHIFT] =
// case ptensor::MATMUL] =
// case ptensor::LSHIFT:
// return buildTosa<mlir::tosa::LogicalLeftShiftOp, void>(typ);
// case ptensor::RSHIFT:
// return buildTosa<mlir::tosa::LogicalRightShiftOp, void>(typ);
case ptensor::MATMUL:
return buildTosa<void, mlir::tosa::MatMulOp>(typ);
case ptensor::MAXIMUM:
return buildTrivial<mlir::arith::MaxSIOp, mlir::arith::MaxFOp>(typ);
case ptensor::MINIMUM:
Expand All @@ -314,24 +360,35 @@ static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop,
return buildTrivial<mlir::arith::RemSIOp, mlir::arith::RemFOp>(typ);
case ptensor::MULTIPLY:
return buildTrivial<mlir::arith::MulIOp, mlir::arith::MulFOp>(typ);
// case ptensor::POW] =
case ptensor::POWER:
return buildTrivial<mlir::math::IPowIOp, mlir::math::PowFOp>(typ);
case ptensor::SUBTRACT:
return buildTrivial<mlir::arith::SubIOp, mlir::arith::SubFOp>(typ);
// case ptensor::TRUE_DIVIDE] =
// case ptensor::BITWISE_AND] =
// case ptensor::BITWISE_LEFT_SHIFT] =
// case ptensor::BITWISE_OR] =
// case ptensor::BITWISE_RIGHT_SHIFT] =
// case ptensor::BITWISE_XOR] =

// case ptensor::EQUAL] =
// case ptensor::GREATER] =
// case ptensor::GREATER_EQUAL] =
case ptensor::BITWISE_AND:
return buildTosa<mlir::tosa::BitwiseAndOp, void>(typ);
case ptensor::BITWISE_LEFT_SHIFT:
return buildTosa<mlir::tosa::LogicalLeftShiftOp, void>(typ);
case ptensor::BITWISE_OR:
return buildTosa<mlir::tosa::BitwiseOrOp, void>(typ);
case ptensor::BITWISE_RIGHT_SHIFT:
return buildTosa<mlir::tosa::LogicalRightShiftOp, void>(typ);
case ptensor::BITWISE_XOR:
return buildTosa<mlir::tosa::BitwiseXorOp, void>(typ);
case ptensor::EQUAL:
return buildTosa<mlir::tosa::EqualOp, void>(typ);
case ptensor::GREATER:
return buildTosa<mlir::tosa::GreaterOp, void>(typ);
case ptensor::GREATER_EQUAL:
return buildTosa<mlir::tosa::GreaterEqualOp, void>(typ);
// case ptensor::LESS] =
// case ptensor::LESS_EQUAL] =
// case ptensor::LOGICAL_AND] =
// case ptensor::LOGICAL_OR] =
// case ptensor::LOGICAL_XOR] =
case ptensor::LOGICAL_AND:
return buildTosa<mlir::tosa::LogicalAndOp, void>(typ);
case ptensor::LOGICAL_OR:
return buildTosa<mlir::tosa::LogicalOrOp, void>(typ);
case ptensor::LOGICAL_XOR:
return buildTosa<mlir::tosa::LogicalXorOp, void>(typ);
// case ptensor::NOT_EQUAL] =
default:
assert("unsupported elementwise binary operation" == nullptr);
Expand Down Expand Up @@ -592,6 +649,8 @@ struct ConvertPTensorToLinalgPass
target.addLegalDialect<::mlir::AffineDialect>();
target.addLegalDialect<::mlir::tensor::TensorDialect>();
target.addLegalDialect<::mlir::arith::ArithmeticDialect>();
target.addLegalDialect<::mlir::math::MathDialect>();
target.addLegalDialect<::mlir::tosa::TosaDialect>();
target.addLegalDialect<::mlir::shape::ShapeDialect>();
target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME
target.addDynamicallyLegalOp<::mlir::func::FuncOp>(
Expand Down