fix flash attention decomposition
yifeizh2 committed Jun 28, 2024
1 parent 9b02c96 commit 97f85b4
Showing 1 changed file with 30 additions and 8 deletions.
lib/gc/Dialect/Linalgx/LinalgxOps.cpp
@@ -620,12 +620,9 @@ void MultiBatchMatmulOp::getEffects(

LogicalResult ScaledDotProductAttentionOp::verify() { return success(); }

-/// Given an N-dimensional tensor x, this method converts
-/// softmax(x) to the following sequence of operations:
-///
-/// 1. transpose ins[1]
-/// 2. matmul ins[0] @ 1
-///
+/// This method converts ScaledDotProductAttention into the following
+/// sequence of operations:
+///   output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
FailureOr<SmallVector<Value>>
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
@@ -635,6 +632,7 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
  mask = getInputs()[3];
  auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
  auto shape = cast<RankedTensorType>(query.getType()).getShape();
+  float rsqrt_head = 1 / sqrt(shape[3]);

  SmallVector<int64_t> permutation{0, 1, 3, 2};
  SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
@@ -652,16 +650,40 @@
      /*inputs=*/ValueRange{query, transpose->getResult(0)},
      /*outputs=*/ValueRange{matmulQKOut.getResult()});

+  auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
+  // Scale the QK product by 1/sqrt(head_dim) with an elementwise generic op.
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.push_back(b.getMultiDimIdentityMap(4));
+  indexingMaps.push_back(b.getMultiDimIdentityMap(4));
+  auto mul = b.create<linalg::GenericOp>(
+      /*location=*/loc, matmulQKOut.getResult().getType(),
+      /*inputs=*/ValueRange{matmulQK->getResult(0)},
+      /*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps,
+      SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value constant = nestedBuilder.create<arith::ConstantOp>(
+            nestedLoc, nestedBuilder.getFloatAttr(dtype, rsqrt_head));
+        Value scaled =
+            nestedBuilder.create<arith::MulFOp>(nestedLoc, args[0], constant);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, scaled);
+      });

  auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto add = b.create<linalg::AddOp>(
      /*location=*/loc, addOut.getResult().getType(),
-      /*inputs=*/ValueRange{matmulQK->getResult(0), mask},
+      /*inputs=*/ValueRange{mul->getResult(0), mask},
      /*outputs=*/ValueRange{addOut.getResult()});

+  auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
+  auto softmax = b.create<linalg::SoftmaxOp>(
+      /*location=*/loc, softmaxOut.getResult().getType(),
+      /*inputs=*/add->getResult(0),
+      /*outputs=*/softmaxOut.getResult(), /*dimension=*/3);

  auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
  auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
      /*location=*/loc, matmulVOut.getResult().getType(),
-      /*inputs=*/ValueRange{add->getResult(0), value},
+      /*inputs=*/ValueRange{softmax->getResult(0), value},
      /*outputs=*/ValueRange{matmulVOut.getResult()});
  return SmallVector<Value>{matmulV.getResults()[0]};
}
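For reference, the decomposed computation, output = softmax(Q @ K^T * 1/sqrt(head_dim) + mask) @ V, can be checked against a plain-C++ scalar sketch. This is not repository code: the function name, the flat row-major [batch, heads, seq, headDim] layout, and a mask pre-broadcast to [batch, heads, seq, seq] are all assumptions made for illustration. The point of the fix is visible in the last step, where the second batch matmul consumes the softmax result rather than the pre-softmax scores.

// Minimal scalar reference for the decomposition above (illustrative only;
// not part of the repository). Assumes flat row-major tensors laid out as
// [batch, heads, seq, headDim] and a mask already broadcast to
// [batch, heads, seq, seq].
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using Tensor = std::vector<float>;

// output = softmax(Q @ K^T * 1/sqrt(headDim) + mask) @ V
Tensor scaledDotProductAttentionRef(const Tensor &q, const Tensor &k,
                                    const Tensor &v, const Tensor &mask,
                                    int batch, int heads, int seq,
                                    int headDim) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(headDim));
  Tensor out(q.size(), 0.0f);
  std::vector<float> row(seq); // one row of scaled + masked QK^T scores
  for (int b = 0; b < batch; ++b) {
    for (int h = 0; h < heads; ++h) {
      const std::size_t base =
          (static_cast<std::size_t>(b) * heads + h) * seq * headDim;
      const std::size_t maskBase =
          (static_cast<std::size_t>(b) * heads + h) * seq * seq;
      for (int i = 0; i < seq; ++i) {
        // matmulQK, multiply by 1/sqrt(head_dim), add mask.
        for (int j = 0; j < seq; ++j) {
          float dot = 0.0f;
          for (int d = 0; d < headDim; ++d)
            dot += q[base + i * headDim + d] * k[base + j * headDim + d];
          row[j] = dot * scale + mask[maskBase + i * seq + j];
        }
        // Numerically stable softmax along the last axis
        // (dimension 3 in the linalg.softmax above).
        const float maxv = *std::max_element(row.begin(), row.end());
        float sum = 0.0f;
        for (int j = 0; j < seq; ++j) {
          row[j] = std::exp(row[j] - maxv);
          sum += row[j];
        }
        for (int j = 0; j < seq; ++j)
          row[j] /= sum;
        // The fixed matmulV: multiply the softmax output, not the
        // pre-softmax scores, with V.
        for (int d = 0; d < headDim; ++d) {
          float acc = 0.0f;
          for (int j = 0; j < seq; ++j)
            acc += row[j] * v[base + j * headDim + d];
          out[base + i * headDim + d] = acc;
        }
      }
    }
  }
  return out;
}

Before this commit, the decomposition fed add->getResult(0) straight into the final matmul, skipping the softmax entirely, and never applied the 1/sqrt(head_dim) scaling; the sketch above makes both steps explicit.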
