Skip to content

Commit

Permalink
init pass
Browse files Browse the repository at this point in the history
  • Loading branch information
yifeizh2 committed Jun 19, 2024
1 parent 1eaf75d commit 0049c9c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/gc/Dialect/Arith/Utils/EasyBuild.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ inline EBFloatPoint operator-(const EBFloatPoint &a) {
}

#define DEF_EASYBUILD_CMP_OPERATOR(OP, OPCLASS, TYPE, PRED) \
EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \
inline EBUnsigned operator OP(const TYPE &a, const TYPE &b) { \
return OperatorHandlers::handleCmp<OPCLASS>(a, b, PRED); \
} \
template <typename T> EBUnsigned operator OP(const TYPE &a, T b) { \
Expand Down
2 changes: 1 addition & 1 deletion include/gc/IR/EasyBuildSCF.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ inline int IfIterator::operator*() const {

} // namespace impl

impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
inline impl::IfSimulator makeIfRange(const EasyBuilder &s, Operation *op) {
return impl::IfSimulator{s.builder, op};
}

Expand Down
36 changes: 34 additions & 2 deletions lib/gc/Transforms/FlashAttentionConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//


#include "./Tiling.hpp"
#include "gc/Dialect/Arith/Utils/EasyBuild.h"
#include "gc/Dialect/Linalgx/LinalgxOps.h"
Expand Down Expand Up @@ -45,11 +44,44 @@ namespace gc {
#include "gc/Transforms/Passes.h.inc"

namespace {

struct MHAToFlashAttention
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
if (!llvm::isa<linalgx::ScaledDotProductAttentionOp>(linalgOp))
return failure();
if (linalgOp.hasPureBufferSemantics())
return failure();
}
};

struct FlashAttentionConversion
: public impl::FlashAttentionConversionBase<FlashAttentionConversion> {
public:
void runOnOperation() final {
return;
auto &ctx = getContext();
IRRewriter rewriter(&ctx);
RewritePatternSet patterns(&ctx);

patterns.add<MHAToFlashAttention>(patterns.getContext());
// linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
// linalg::ControlDropUnitDims options;
// options.rankReductionStrategy =
// linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
// linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
// tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);

// for (auto *dialect : ctx.getLoadedDialects())
// dialect->getCanonicalizationPatterns(patterns);
// for (RegisteredOperationName op : ctx.getRegisteredOperations())
// op.getCanonicalizationPatterns(patterns, &ctx);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};

Expand Down
8 changes: 4 additions & 4 deletions unittests/Example/Example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//

#include "gc/Dialect/Linalgx/LinalgxDialect.h"
#include "gtest/gtest.h"
// #include "gtest/gtest.h"

TEST(example, HelloWorld) {
ASSERT_EQ(mlir::linalgx::LinalgxDialect::getDialectNamespace(), "linalgx");
}
// TEST(example, HelloWorld) {
// ASSERT_EQ(mlir::linalgx::LinalgxDialect::getDialectNamespace(), "linalgx");
// }

0 comments on commit 0049c9c

Please sign in to comment.