diff --git a/include/imex/Dialect/PTensor/IR/PTensorOps.h b/include/imex/Dialect/PTensor/IR/PTensorOps.h index a70b3c47f..0d3c61d57 100644 --- a/include/imex/Dialect/PTensor/IR/PTensorOps.h +++ b/include/imex/Dialect/PTensor/IR/PTensorOps.h @@ -44,7 +44,7 @@ enum EWBinOpId : int { LOGICAL_AND, LOGICAL_OR, LOGICAL_XOR, - LSHIFT, + // LSHIFT, MATMUL, MAXIMUM, MINIMUM, @@ -53,6 +53,7 @@ enum EWBinOpId : int { NOT_EQUAL, OR, POWER, + // RSHIFT, SUBTRACT, TRUE_DIVIDE, XOR, diff --git a/lib/Conversion/PTensorToLinalg/PTensorToLinalg.cpp b/lib/Conversion/PTensorToLinalg/PTensorToLinalg.cpp index 7e3b98d3a..cab17ebcc 100644 --- a/lib/Conversion/PTensorToLinalg/PTensorToLinalg.cpp +++ b/lib/Conversion/PTensorToLinalg/PTensorToLinalg.cpp @@ -24,9 +24,11 @@ #include #include #include +#include #include #include #include +#include #include #include @@ -293,6 +295,45 @@ static BodyType buildTrivial(::mlir::Type typ) { }; } +// Builder for TOSA body. +// The builder function takes an extra agrument ::mlir::Type +// Many TOSA functions take two arguments and do the +// same operation on both arguments. +// TODO: +// 1. Find a way to merge this function with trivialBuilder(). +// 2. Detect if there is rhs2, if yes, build accordingly. +// 3. We need to introduce a Boolean type. +template +static BodyType buildTosa(::mlir::Type typ) { + return [typ](mlir::OpBuilder &builder, ::mlir::Location loc, + ::mlir::ValueRange args) -> void { + auto lhs_typ = args[0].getType(); + if (lhs_typ.isIntOrIndex()) { + if constexpr (!std::is_same_v) { + auto lhs = doSignCast(builder, loc, args[0]); + auto rhs1 = doSignCast(builder, loc, args[1]); + // auto rhs2 = doSignCast(builder, loc, args[2]); // sometimes there can + // be two params + yield(builder, loc, typ, + builder.create(loc, typ, lhs, rhs1).getResult()); + return; + } else + assert("Found integer type but binary op not defined for integers" == + nullptr); + } else if (lhs_typ.isIntOrIndexOrFloat()) { + if constexpr (!std::is_same_v) { + yield(builder, loc, typ, + builder.create(loc, typ, args[0], args[1]).getResult()); + return; + } else + assert("Found float type but binary op not defined for floats" == + nullptr); + } else { + assert("Only integers and floats supported for binary ops" == nullptr); + } + }; +} + // get a body builder for given binary operation and result type // we accept a result type to insert a cast after the operation if needed static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop, @@ -300,12 +341,17 @@ static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop, switch (bop) { case ptensor::ADD: return buildTrivial(typ); - // case ptensor::ATAN2] = + case ptensor::ATAN2: + return buildTrivial(typ); case ptensor::FLOOR_DIVIDE: return buildTrivial(typ); // case ptensor::LOGADDEXP] = - // case ptensor::LSHIFT] = - // case ptensor::MATMUL] = + // case ptensor::LSHIFT: + // return buildTosa(typ); + // case ptensor::RSHIFT: + // return buildTosa(typ); + case ptensor::MATMUL: + return buildTosa(typ); case ptensor::MAXIMUM: return buildTrivial(typ); case ptensor::MINIMUM: @@ -314,24 +360,35 @@ static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop, return buildTrivial(typ); case ptensor::MULTIPLY: return buildTrivial(typ); - // case ptensor::POW] = + case ptensor::POWER: + return buildTrivial(typ); case ptensor::SUBTRACT: return buildTrivial(typ); // case ptensor::TRUE_DIVIDE] = - // case ptensor::BITWISE_AND] = - // case ptensor::BITWISE_LEFT_SHIFT] = - // case ptensor::BITWISE_OR] = - // case ptensor::BITWISE_RIGHT_SHIFT] = - // case ptensor::BITWISE_XOR] = - - // case ptensor::EQUAL] = - // case ptensor::GREATER] = - // case ptensor::GREATER_EQUAL] = + case ptensor::BITWISE_AND: + return buildTosa(typ); + case ptensor::BITWISE_LEFT_SHIFT: + return buildTosa(typ); + case ptensor::BITWISE_OR: + return buildTosa(typ); + case ptensor::BITWISE_RIGHT_SHIFT: + return buildTosa(typ); + case ptensor::BITWISE_XOR: + return buildTosa(typ); + case ptensor::EQUAL: + return buildTosa(typ); + case ptensor::GREATER: + return buildTosa(typ); + case ptensor::GREATER_EQUAL: + return buildTosa(typ); // case ptensor::LESS] = // case ptensor::LESS_EQUAL] = - // case ptensor::LOGICAL_AND] = - // case ptensor::LOGICAL_OR] = - // case ptensor::LOGICAL_XOR] = + case ptensor::LOGICAL_AND: + return buildTosa(typ); + case ptensor::LOGICAL_OR: + return buildTosa(typ); + case ptensor::LOGICAL_XOR: + return buildTosa(typ); // case ptensor::NOT_EQUAL] = default: assert("unsupported elementwise binary operation" == nullptr); @@ -592,6 +649,8 @@ struct ConvertPTensorToLinalgPass target.addLegalDialect<::mlir::AffineDialect>(); target.addLegalDialect<::mlir::tensor::TensorDialect>(); target.addLegalDialect<::mlir::arith::ArithmeticDialect>(); + target.addLegalDialect<::mlir::math::MathDialect>(); + target.addLegalDialect<::mlir::tosa::TosaDialect>(); target.addLegalDialect<::mlir::shape::ShapeDialect>(); target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME target.addDynamicallyLegalOp<::mlir::func::FuncOp>(