temp cache
yifeizh2 committed Jun 26, 2024
1 parent 8648723 commit 738986c
Showing 2 changed files with 169 additions and 20 deletions.
17 changes: 16 additions & 1 deletion lib/gc/Dialect/Linalgx/LinalgxOps.cpp
@@ -635,9 +635,12 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
mask = getInputs()[3];
auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
auto shape = cast<RankedTensorType>(query.getType()).getShape();
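// The operands are presumably laid out as [batch, heads, seq_len, head_dim];
// shape[3] is the per-head dimension used to scale the attention scores.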
float rsqrt_head = 1 / sqrt(shape[3]);

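// Swap the last two dims (seq_len and head_dim) to form the transposed key
// for the Q * K^T batched matmul below.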
SmallVector<int64_t> permutation{0, 1, 3, 2};
SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
auto constant =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(dtype, rsqrt_head));
auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
auto transpose = b.create<linalg::TransposeOp>(
/*location=*/loc,
@@ -652,16 +655,28 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
/*inputs=*/ValueRange{query, transpose->getResult(0)},
/*outputs=*/ValueRange{matmulQKOut.getResult()});

auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
auto scaleOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
// Scale QK^T by 1/sqrt(head_dim): broadcast the scalar constant with a fill,
// then multiply elementwise (one possible lowering of the scale step).
auto scaleBcast = b.create<linalg::FillOp>(
/*location=*/loc, ValueRange{constant}, ValueRange{scaleOut.getResult()});
auto mul = b.create<linalg::MulOp>(
/*location=*/loc, mulOut.getResult().getType(),
/*inputs=*/ValueRange{matmulQK->getResult(0), scaleBcast->getResult(0)},
/*outputs=*/ValueRange{mulOut.getResult()});

auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
auto add = b.create<linalg::AddOp>(
/*location=*/loc, addOut.getResult().getType(),
/*inputs=*/ValueRange{mul->getResult(0), mask},
/*outputs=*/ValueRange{addOut.getResult()});

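// Softmax of the masked scores along dimension 3, the key-position axis.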
auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
auto softmax = b.create<linalg::SoftmaxOp>(
/*location=*/loc, softmaxOut.getResult().getType(),
/*inputs=*/add->getResult(0),
/*outputs=*/softmaxOut.getResult(), 3);

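// Multiply the softmax probabilities with V to produce the attention output.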
auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
/*location=*/loc, matmulVOut.getResult().getType(),
/*inputs=*/ValueRange{softmax->getResult(0), value},
/*outputs=*/ValueRange{matmulVOut.getResult()});
return SmallVector<Value>{matmulV.getResults()[0]};
}
172 changes: 153 additions & 19 deletions lib/gc/Transforms/FlashAttentionConversion.cpp
@@ -36,6 +36,7 @@

#include <llvm/Support/Debug.h>

#include <iostream>
#include <memory>

namespace mlir {
@@ -45,16 +46,161 @@ namespace gc {

namespace {

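// Blocking factors along the row (query) and column (key/value) sequence
// dimensions of the attention score matrix; currently left unset.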
struct FlashAttentionConfig {
int RowBlock, ColumnBlock;
};

static FlashAttentionConfig
getDefaultFlashAttentionConfig(linalgx::ScaledDotProductAttentionOp &sdpaOp) {
// TODO: allow tuning
auto seqLen = sdpaOp.getShape(sdpaOp.getDpsInputOperand(0))[2];
FlashAttentionConfig cfg;

// cfg.RowBlock = seqLen / 64;
// cfg.ColumnBlock = seqLen / 64;
(void)seqLen; // blocking is not derived yet
return cfg;
}

struct MHAToFlashAttention
: public OpRewritePattern<linalgx::ScaledDotProductAttentionOp> {
using OpRewritePattern<
linalgx::ScaledDotProductAttentionOp>::OpRewritePattern;

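// Bookkeeping for the (currently disabled) outer-loop tiling draft below.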
struct OuterLoopGenerationResult {
/// Tiled operations that are generated during tiling. The order does not
/// matter except the last op. The replacements are expected to be the
/// results of the last op.
SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
SmallVector<LoopLikeOpInterface> reductionLoops;
};

// FailureOr<OuterLoopGenerationResult>
// outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp)
// const {
// SmallVector<unsigned> RowDimPos, ColDimPos;
// linalgOp.getReductionDims(KDimPos);
// getMatmulParallelDims(linalgOp, 0, MDimPos);
// getMatmulParallelDims(linalgOp, 1, NDimPos);

// OuterLoopGenerationOption option;
// auto iteratorTypes = linalgOp.getIteratorTypesArray();
// auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
// auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
// auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
// auto KParallelBlockSize =
// KDimPos.size() > 1
// ? divAndCeil(KFirstDim, cfg.KThreads)
// : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
// cfg.KBlock;
// auto MParallelBlockSize =
// MDimPos.size() > 1
// ? divAndCeil(MFirstDim, cfg.MThreads)
// : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
// cfg.MBlock;
// auto NParallelBlockSize =
// NDimPos.size() > 1
// ? divAndCeil(NFirstDim, cfg.NThreads)
// : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
// cfg.NBlock;
// auto KOuterBlockSize = KDimPos.size() > 1
// ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
// : cfg.KBlock;
// auto MOuterBlockSize = MDimPos.size() > 1
// ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
// : cfg.MBlock;
// auto NOuterBlockSize = NDimPos.size() > 1
// ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
// : cfg.NBlock;
// // Outer
// option.nestedTileSizes.emplace_back(SmallVector<int>{
// MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
// option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
// option.loopDim.emplace_back(
// SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]});
// // Middle
// for (auto [tile, dim] :
// llvm::zip(SmallVector<int>{MOuterBlockSize, NOuterBlockSize,
// KOuterBlockSize},
// SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0],
// (int)KDimPos[0]})) {
// option.nestedTileSizes.emplace_back(SmallVector<int>{tile});
// option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
// option.loopDim.emplace_back(SmallVector<int>{dim});
// }
// // Inner
// if (KDimPos.size() == 1) {
// option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
// option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
// option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
// }
// if (MDimPos.size() == 1) {
// option.nestedTileSizes.emplace_back(
// SmallVector<int>{cfg.innerMostMBlock});
// option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
// option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
// }
// if (NDimPos.size() == 1) {
// option.nestedTileSizes.emplace_back(
// SmallVector<int>{cfg.innerMostNBlock});
// option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
// option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
// }
// for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
// if (dim != MDimPos.back() && dim != NDimPos.back() &&
// iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
// option.nestedTileSizes.emplace_back(SmallVector<int>{1});
// option.loopType.emplace_back(
// OuterLoopGenerationOption::LoopType::ForOp);
// option.loopDim.emplace_back(SmallVector<int>{(int)dim});
// }
// }

// auto lowPrecisionCast =
// [&](RewriterBase &rewriter, Location loc,
// linalg::LinalgOp linalgop) -> FailureOr<linalg::LinalgOp> {
// auto legalizedResult = matmulDtypeLegalize(
// rewriter, linalgop.getOperation(), !hasFillOp, true);
// if (legalizedResult->castOp && legalizedResult->linalgOp) {
// auto linalgOp = legalizedResult->linalgOp;
// rewriter.replaceOp(linalgop,
// linalgOp->getResult(linalgOp->getNumResults() -
// 1));
// return dyn_cast<linalg::LinalgOp>(linalgOp);
// }
// return failure();
// };
// option.innermostFullResultCallBacks.push_back(lowPrecisionCast);

// if (hasFillOp) {
// auto removeReduncantFill =
// [&](RewriterBase &rewriter, Location loc,
// const linalg::ForallReductionTilingResult &result)
// -> FailureOr<linalg::LinalgOp> {
// auto initValue = result.initialValues;
// if (initValue.size() == 1 &&
// isa<linalg::FillOp>(initValue[0].getDefiningOp())) {
// rewriter.replaceOp(initValue[0].getDefiningOp(),
// dyn_cast<DestinationStyleOpInterface>(
// initValue[0].getDefiningOp())
// .getDpsInits()[0]);
// }
// return dyn_cast<linalg::LinalgOp>(result.parallelTiledOp);
// };
// option.finalReduceCallBacks.push_back(removeReduncantFill);
// }
// return generateOuterLoop(rewriter, linalgOp, option);
// }

LogicalResult matchAndRewrite(linalgx::ScaledDotProductAttentionOp sdpaOp,
PatternRewriter &rewriter) const override {
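// For now the SDPA op is decomposed in one shot via its AggregatedOpInterface;
// the tiled flash-attention loop nest drafted above stays commented out.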
if (sdpaOp.hasPureBufferSemantics())
return failure();
auto decomposableOp =
dyn_cast<mlir::linalg::AggregatedOpInterface>(sdpaOp.getOperation());
if (!decomposableOp)
return failure();
FailureOr<SmallVector<Value>> maybeNewResults =
decomposableOp.decomposeOperation(rewriter);
if (failed(maybeNewResults))
return failure();
rewriter.replaceOp(decomposableOp, *maybeNewResults);
return success();
}
};

@@ -65,19 +211,7 @@ struct FlashAttentionConversion
auto &ctx = getContext();
IRRewriter rewriter(&ctx);
RewritePatternSet patterns(&ctx);

patterns.add<MHAToFlashAttention>(patterns.getContext());
// linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
// linalg::ControlDropUnitDims options;
// options.rankReductionStrategy =
// linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
// linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
// tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);

// for (auto *dialect : ctx.getLoadedDialects())
// dialect->getCanonicalizationPatterns(patterns);
// for (RegisteredOperationName op : ctx.getRegisteredOperations())
// op.getCanonicalizationPatterns(patterns, &ctx);
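// Greedily apply the conversion (and any accompanying folding) to the op.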
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
