Skip to content

Commit

Permalink
support flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
yifeizh2 committed Aug 2, 2024
1 parent de0376f commit f741fbd
Show file tree
Hide file tree
Showing 7 changed files with 534 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/gc/Dialect/Linalgx/LinalgxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,33 @@

include "LinalgxDialect.td"

include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"

// Base class for Linalg dialect ops that do not correspond to library calls.
// All Linalgx ops defined below derive from this so they share the
// LinalgxDialect registration and a common trait list hook.
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
    Op<LinalgxDialect, mnemonic, traits>;

// Aggregate attention op. Operand order is (query, key, value,
// attention_mask); the AggregatedOpInterface's decomposeOperation (see
// LinalgxOps.cpp) expands it to
//   softmax(Q @ K^T * 1/sqrt(head_dim) + mask) @ V.
// NOTE(review): the decomposition assumes rank-4 statically-shaped inputs,
// but AnyRankedTensor does not enforce that — the verifier should.
def Linalgx_ScaledDotProductAttentionOp
    : Linalgx_Op<"scaled_dot_product_attention",
      [AttrSizedOperandSegments,
       DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
  let summary = "Attention structure.";
  let description = [{
    Q, K, V, attention_mask.
    Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
  }];
  // Variadic ins/outs; AttrSizedOperandSegments records the split between
  // the input and output operand groups.
  let arguments = (ins
    Variadic<AnyRankedTensor>:$inputs,
    Variadic<AnyRankedTensor>:$outputs);
  let results = (outs Variadic<AnyRankedTensor>:$results);

  // Custom verifier implemented in LinalgxOps.cpp.
  let hasVerifier = 1;
  let assemblyFormat = [{
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    (`->` type($results)^)?
  }];
}
#endif // LINALGX_OPS
11 changes: 11 additions & 0 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
];
}

// Function-scoped pass converting MHA (scaled-dot-product attention) into a
// flash-attention implementation.
// NOTE(review): the scf::SCFDialect dependency suggests the lowering emits
// explicit loop nests — confirm in lib/gc/Transforms/FlashAttentionConversion.cpp.
def FlashAttentionConversion
    : Pass<"flash-attention-conversion", "func::FuncOp"> {
  let summary = "Flash Attention Conversion";
  let description =
      [{The pass converts MHA to flash attention implementation.}];
  // Dialects whose ops the rewrite may create must be declared so they are
  // loaded before the pass runs.
  let dependentDialects = [
    "func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
    "tensor::TensorDialect"
  ];
}

#ifdef GC_USE_GPU
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
let summary = "Convert linalg dialect to XeGPU dialect.";
Expand Down
75 changes: 75 additions & 0 deletions lib/gc/Dialect/Linalgx/LinalgxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
#include "mlir/IR/OpImplementation.h"
#include <cmath>
#include <utility>

//===----------------------------------------------------------------------===//
// Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Down Expand Up @@ -608,6 +609,80 @@ void MultiBatchMatmulOp::getEffects(
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

//===----------------------------------------------------------------------===//
// ScaledDotProductAttentionOp
//===----------------------------------------------------------------------===//

/// Structural verification. decomposeOperation() unconditionally indexes
/// inputs [0..3] (query, key, value, attention mask) and shape dims [0..3],
/// so reject malformed ops here instead of crashing during decomposition.
LogicalResult ScaledDotProductAttentionOp::verify() {
  if (getInputs().size() != 4)
    return emitOpError(
        "expected 4 inputs: query, key, value, attention_mask");
  if (getOutputs().size() != 1)
    return emitOpError("expected 1 output");
  for (Value input : getInputs()) {
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || inputType.getRank() != 4)
      return emitOpError("expected all inputs to be rank-4 ranked tensors");
  }
  return success();
}

/// This method converts ScaledDotProductAttention into the following
/// sequence of operations:
/// output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
/// where scale = 1 / sqrt(head_dim) and the transpose swaps the two
/// innermost dimensions of the key tensor.
FailureOr<SmallVector<Value>>
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  // Operand order is fixed by the op contract: query, key, value, mask.
  Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2],
        mask = getInputs()[3];
  auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
  auto shape = cast<RankedTensorType>(query.getType()).getShape();
  // NOTE(review): assumes a static rank-4 (batch, heads, seq, head_dim)
  // layout; a dynamic head_dim would make this scale meaningless — the
  // verifier should guarantee static shapes.
  float rsqrtHead = 1.0f / std::sqrt(static_cast<float>(shape[3]));

  // transpose(K): swap the two innermost (seq, head_dim) dims.
  SmallVector<int64_t> permutation{0, 1, 3, 2};
  SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
  auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
  auto transpose = b.create<linalg::TransposeOp>(
      /*location=*/loc,
      /*inputs=*/key,
      /*outputs=*/transposeOut,
      /*permutation=*/permutation);

  // Q @ K^T -> (batch, heads, seq, seq) attention scores.
  SmallVector<int64_t> matmulQKShape{shape[0], shape[1], shape[2], shape[2]};
  auto matmulQKOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto matmulQK = b.create<linalgx::MultiBatchMatmulOp>(
      /*location=*/loc, matmulQKOut.getResult().getType(),
      /*inputs=*/ValueRange{query, transpose->getResult(0)},
      /*outputs=*/ValueRange{matmulQKOut.getResult()});

  // Elementwise scale of the scores by 1/sqrt(head_dim). The scalar
  // constant is hoisted out of the generic's region so it is created
  // exactly once with the decomposition builder; the original created it
  // inside the region body via the *outer* builder and location, mixing
  // builders within the nested region.
  Value scale =
      b.create<arith::ConstantOp>(loc, b.getFloatAttr(dtype, rsqrtHead));
  auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  SmallVector<AffineMap, 4> indexingMaps;
  indexingMaps.push_back(b.getMultiDimIdentityMap(4));
  indexingMaps.push_back(b.getMultiDimIdentityMap(4));
  auto mul = b.create<linalg::GenericOp>(
      /*location=*/loc, matmulQKOut.getResult().getType(),
      /*inputs=*/ValueRange{matmulQK->getResult(0)},
      /*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps,
      SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value scaled =
            nestedBuilder.create<arith::MulFOp>(nestedLoc, args[0], scale);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, scaled);
      });

  // + attention mask (broadcast semantics are the identity here: the mask
  // shares the scores' shape per the identity indexing of linalg.add).
  auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto add = b.create<linalg::AddOp>(
      /*location=*/loc, addOut.getResult().getType(),
      /*inputs=*/ValueRange{mul->getResult(0), mask},
      /*outputs=*/ValueRange{addOut.getResult()});

  // softmax over the last (key sequence) dimension.
  auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto softmax = b.create<linalg::SoftmaxOp>(
      /*location=*/loc, softmaxOut.getResult().getType(),
      /*inputs=*/add->getResult(0),
      /*outputs=*/softmaxOut.getResult(), /*dimension=*/3);

  // softmax(...) @ V -> result with the original query shape.
  auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
  auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
      /*location=*/loc, matmulVOut.getResult().getType(),
      /*inputs=*/ValueRange{softmax->getResult(0), value},
      /*outputs=*/ValueRange{matmulVOut.getResult()});
  return SmallVector<Value>{matmulV.getResults()[0]};
}

/////// Operations corresponding to library calls defined with Tablegen ////////

#define GET_OP_CLASSES
Expand Down
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_library(GCPasses
OneDNNGraphToLinalg.cpp
Pipeline.cpp
TileNamed.cpp
FlashAttentionConversion.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include
Expand Down
Loading

0 comments on commit f741fbd

Please sign in to comment.