Skip to content

Commit

Permalink
[WIP] #34 Fixed support for user-defined functions. Currently only wo…
Browse files Browse the repository at this point in the history
…rks for functions without args.
  • Loading branch information
pthomadakis committed Oct 13, 2023
1 parent 16a6de6 commit bcefa44
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 18 deletions.
2 changes: 2 additions & 0 deletions frontends/comet_dsl/comet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
Expand Down Expand Up @@ -370,6 +371,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

// Finally lowering index tree to SCF dialect
optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
pm.addPass(mlir::func::createFuncBufferizePass()); // Needed for func

// Dump index tree dialect.
if (emitLoops)
Expand Down
7 changes: 7 additions & 0 deletions frontends/comet_dsl/include/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ namespace tensorAlgebra
}
// CallExprAST is generated for random()
}
else
{
if (args.size() == 0)
{
args.push_back(nullptr);
}
}
comet_debug() << "generate CallExprAST node\n ";
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args[0]));
}
Expand Down
36 changes: 34 additions & 2 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "comet/Dialect/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -1336,6 +1337,32 @@ namespace
sumVal = builder.create<mlir::tensorAlgebra::ReduceOp>(location, builder.getF64Type(), tensorValue);
}
}
else
{
auto *expr = call.getArgs();
if(expr)
{
assert(false && "functions with argument are currently not supported!");
}
mlir::Value tensorValue;
tensorValue = mlir::Value();
ArrayRef<mlir::Value> args{};
if(tensorValue)
args = ArrayRef<mlir::Value> (tensorValue);

auto c = functionMap.lookup(callee);
if(c.getFunctionType().getResults().size() > 0) // Function that returns a value
{
auto res = builder.create<GenericCallOp>(location, c.getFunctionType().getResults()[0], callee, args);
sumVal = res.getResults()[0];
}
else // Void function
{
builder.create<GenericCallOp>(location, callee, args);
sumVal = mlir::Value();
}
}
// comet_debug() << "Called: " << callee << "\n";

// Otherwise this is a call to a user-defined function. Calls to ser-defined
// functions are mapped to a custom call that takes the callee name as an
Expand Down Expand Up @@ -2298,8 +2325,13 @@ namespace

// Generic expression dispatch codegen.
comet_debug() << " expr->getKind(): " << expr->getKind() << "\n";
if (!mlirGen(*expr))
return mlir::failure();

// If calling a void function this will return null, thus we cannot count on this for
// error checking
mlirGen(*expr);
// return mlir::failure();
// if (!mlirGen(*expr))
// return mlir::failure();
}
return mlir::success();
}
Expand Down
2 changes: 1 addition & 1 deletion include/comet/Dialect/TensorAlgebra/IR/TAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def GenericCallOp : TA_Op<"generic_call",

// The generic call operation returns a single value of TensorType or
// StructType.
let results = (outs TA_AnyTensor);
let results = (outs Optional<TA_AnyTensor>);

// Specialize assembly printing and parsing using a declarative format.
let assemblyFormat = [{
Expand Down
34 changes: 34 additions & 0 deletions lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,40 @@ namespace
}
};

class ReturnOpLowering : public ConversionPattern
{
public:
explicit ReturnOpLowering(MLIRContext *ctx)
: ConversionPattern(tensorAlgebra::PrintElapsedTimeOp::getOperationName(), 1, ctx) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override
{
auto ctx = rewriter.getContext();
auto module = op->getParentOfType<ModuleOp>();

auto start = operands[0];
auto end = operands[1];
std::string printElapsedTimeStr = "printElapsedTime";
auto f64Type = rewriter.getF64Type();

if (!hasFuncDeclaration(module, printElapsedTimeStr))
{
auto printElapsedTimeFunc = FunctionType::get(ctx, {f64Type, f64Type}, {});
// func @printElapsedTime(f64, f64) -> ()
func::FuncOp func1 = func::FuncOp::create(op->getLoc(), printElapsedTimeStr,
printElapsedTimeFunc, ArrayRef<NamedAttribute>{});
func1.setPrivate();
module.push_back(func1);
}

rewriter.replaceOpWithNewOp<func::CallOp>(op, printElapsedTimeStr, SmallVector<Type, 2>{}, ValueRange{start, end});

return success();
}
};

} // end anonymous namespace.

/// This is a partial lowering to linear algebra of the tensor algebra operations that are
Expand Down
80 changes: 66 additions & 14 deletions lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,30 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
// *********** For debug purpose *********//
// #ifndef DEBUG_MODE_LOWER_FUNC
// #define DEBUG_MODE_LOWER_FUNC
// #endif

#ifdef DEBUG_MODE_LOWER_FUNC
#define comet_debug() llvm::errs() << __FILE__ << " " << __LINE__ << " "
#define comet_pdump(n) \
llvm::errs() << __FILE__ << " " << __LINE__ << " "; \
n->dump()
#define comet_vdump(n) \
llvm::errs() << __FILE__ << " " << __LINE__ << " "; \
n.dump()
#else
#define comet_debug() llvm::nulls()
#define comet_pdump(n)
#define comet_vdump(n)
#endif
// *********** For debug purpose *********//

//===----------------------------------------------------------------------===//
// tensorAlgebra::FuncOp to func::FuncOp RewritePatterns
Expand All @@ -25,20 +44,20 @@ namespace {

struct FuncOpLowering : public OpConversionPattern<tensorAlgebra::FuncOp> {
using OpConversionPattern<tensorAlgebra::FuncOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(tensorAlgebra::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// We only lower the main function as we expect that all other functions
// have been inlined.
if (op.getName() != "main")
return failure();

// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getFunctionType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
if (op.getName() == "main")
{
// return failure();
// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getFunctionType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
}
}

// Create a new non-tensorAlgebra function, with the same region.
Expand Down Expand Up @@ -78,15 +97,48 @@ struct ReturnOpLowering : public OpRewritePattern<tensorAlgebra::TAReturnOp> {
PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
return failure();
// if (op.hasOperand())
// return failure();

if(op.hasOperand())
{
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, op.getOperands());
}
else
{
rewriter.replaceOpWithNewOp<func::ReturnOp>(op);
}

return success();
}
};

struct GenericCallOpLowering : public OpRewritePattern<mlir::tensorAlgebra::GenericCallOp> {
using OpRewritePattern<mlir::tensorAlgebra::GenericCallOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::tensorAlgebra::GenericCallOp op,
PatternRewriter &rewriter) const final {

// During this lowering, we expect that all function calls have been
// inlined.
// if (op.hasOperand())
// return failure();

// We lower "toy.return" directly to "func.return".
rewriter.replaceOpWithNewOp<func::ReturnOp>(op);
if(op.getResults().size() > 0)
{
auto res = rewriter.replaceOpWithNewOp<func::CallOp>(op, op->getAttrOfType<SymbolRefAttr>("callee"), op.getType(0), op.getOperands());
}
else
{
auto res = rewriter.replaceOpWithNewOp<func::CallOp>(op, op->getAttrOfType<SymbolRefAttr>("callee"), mlir::TypeRange(), op.getOperands());
}

return success();
}
};


void FuncOpLoweringPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
Expand All @@ -110,7 +162,7 @@ void FuncOpLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<FuncOpLowering, ReturnOpLowering>(
patterns.add<FuncOpLowering, ReturnOpLowering, GenericCallOpLowering>(
&getContext());

// With the target and rewrite patterns defined, we can now attempt the
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TensorAlgebra/IR/TADialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value> arguments)
{
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
// state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee",
mlir::SymbolRefAttr::get(builder.getContext(), callee));
Expand Down

0 comments on commit bcefa44

Please sign in to comment.