Add optimization for load+transpose pattern. (#933)
Add optimization for the load+transpose pattern. The pattern is transformed into a sequence of:
1. load with VNNI
2. store scatter to SLM
3. 1D block load from SLM
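As a rough illustration of why this sequence realizes a transpose, here is a minimal standalone C++ sketch (plain arrays stand in for the registers and SLM; the 8x16 tile size, f32 element type, and names are hypothetical, and the VNNI packing step that groups sub-32-bit elements into 32-bit units is omitted):

#include <array>
#include <cstdio>

// An 8x16 tile held "in registers" (row-major) is transposed by scattering
// each column into a contiguous chunk of a scratch buffer (the SLM stand-in)
// and then reading that buffer back with contiguous 1D loads.
constexpr int Rows = 8, Cols = 16;

int main() {
  std::array<float, Rows * Cols> regs{}; // loaded tile, row-major
  for (int i = 0; i < Rows * Cols; ++i)
    regs[i] = static_cast<float>(i);

  // Step 2: store scatter -- lane c writes column c as Rows contiguous
  // elements, so SLM ends up holding the tile in column-major order.
  std::array<float, Rows * Cols> slm{};
  for (int lane = 0; lane < Cols; ++lane)
    for (int r = 0; r < Rows; ++r)
      slm[lane * Rows + r] = regs[r * Cols + lane];

  // Step 3: contiguous 1D block loads from SLM now produce the transposed
  // tile row by row: row c of the 16x8 result is column c of the input.
  for (int c = 0; c < Cols; ++c) {
    for (int r = 0; r < Rows; ++r)
      std::printf("%5.0f ", slm[c * Rows + r]);
    std::printf("\n");
  }
  return 0;
}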
chencha3 authored Oct 18, 2024
1 parent bfd47e4 commit e9c2ff5
Showing 12 changed files with 822 additions and 40 deletions.
4 changes: 3 additions & 1 deletion include/imex/Transforms/Passes.td
@@ -180,7 +180,9 @@ def OptimizeTranspose : Pass<"imex-xegpu-optimize-transpose"> {
let summary = "Eliminate in-register vector transpose by fusing with load.";
let constructor = "imex::createOptimizeTransposePass()";
let dependentDialects = [
"::mlir::xegpu::XeGPUDialect"
"::mlir::xegpu::XeGPUDialect",
"::mlir::memref::MemRefDialect",
"::mlir::gpu::GPUDialect"
];
let options = [
Option<"device", "device", "std::string",
12 changes: 12 additions & 0 deletions include/imex/Utils/XeCommon.h
@@ -29,6 +29,18 @@
using namespace mlir::xegpu;
namespace imex {

// Combine vectors vertically while keeping the logical data layout.
// As an example, given two vectors (2x4xf16) p and q, it will merge
// them into a 4x4xf16 vector.
//    p1, p2, p3, p4            p1, p2, p3, p4
//    p5, p6, p7, p8            p5, p6, p7, p8
//                      ==>     q1, q2, q3, q4
//    q1, q2, q3, q4            q5, q6, q7, q8
//    q5, q6, q7, q8
mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
mlir::Location loc,
mlir::PatternRewriter &rewriter);

// It checks each GPUFuncOp in the module to see
// whether it has arguments or results with
// xetile.TileType, which are currently not supported.
15 changes: 10 additions & 5 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -143,9 +143,11 @@ static LogicalResult isValidScatterSetup(Type elemTy, int simd_lanes,
if (!elemTy.isIntOrFloat())
return failure();

if (simd_lanes != 16 && simd_lanes != 32)
return rewriter.notifyMatchFailure(
loc, "A valid simd lane is 16 or 32 for PVC.");
// TODO: temporarily disable the simd_lanes check; it is fine for the SIMD
// pipeline but may not be compatible with the SIMT pipeline.
// if (simd_lanes != 16 && simd_lanes != 32)
// return rewriter.notifyMatchFailure(
// loc, "A valid simd lane is 16 or 32 for PVC.");

if (!llvm::is_contained({1, 2, 3, 4, 8, 16, 32, 64}, chunk_size))
return rewriter.notifyMatchFailure(
@@ -260,8 +262,11 @@ static func::CallOp genRawLSCIntrinsicCall(
auto elemTy = predTy.getElementType();
assert(predTy.getRank() == 1 && "predicate must be a 1D vector type.");
assert(elemTy.isInteger(1) && "predicate type must be i1.");
assert(llvm::is_contained({1, 16, 32}, predTy.getNumElements()) &&
"predicate size must be 1, 16 or 32.");
// TODO: temporarily disable the predicate_size check. It is
// fine for the SIMD pipeline but may not match the SIMT pipeline.
//
// assert(llvm::is_contained({1, 16, 32}, predTy.getNumElements()) &&
// "predicate size must be 1, 16 or 32.");

// arg1: i8 subopcode, LSC_LOAD for load/prefetch, LSC_STORE for store
assert((opCode == LSC_LOAD || opCode == LSC_STORE) && "unsupported opcode.");
25 changes: 0 additions & 25 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
@@ -41,31 +41,6 @@ using VectorTypedValue = mlir::TypedValue<mlir::VectorType>;
using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
mlir::PatternRewriter &);

// Combine vectors vertically while keeping the logical data layout.
// As an example, given two vectors (2x4xf16) p and q, it will merge
// them into a 4x4xf16 vector.
//    p1, p2, p3, p4            p1, p2, p3, p4
//    p5, p6, p7, p8            p5, p6, p7, p8
//                      ==>     q1, q2, q3, q4
//    q1, q2, q3, q4            q5, q6, q7, q8
//    q5, q6, q7, q8
static VectorTypedValue stack(mlir::Value vecUp, mlir::Value vecDown,
mlir::Location loc,
mlir::PatternRewriter &rewriter) {
auto vecUpTy = llvm::cast<mlir::VectorType>(vecUp.getType());
auto vecDownTy = llvm::cast<mlir::VectorType>(vecDown.getType());
assert(vecUpTy.getRank() == 2 && vecDownTy.getRank() == vecUpTy.getRank() &&
"only supports 2D vectors.");
assert(vecUpTy.getShape()[1] == vecDownTy.getShape()[1] &&
"Operands of stack() do not have the same number of columns.");

llvm::SmallVector<int64_t> mask(vecUpTy.getShape()[0] +
vecDownTy.getShape()[0]);
std::iota(mask.begin(), mask.end(), 0);
auto op = rewriter.create<ShuffleOp>(loc, vecUp, vecDown, mask);
return op;
}

// generate linearized shuffle mask for concat.
static llvm::SmallVector<int64_t>
getShuffleMask(llvm::ArrayRef<int64_t> shape1, llvm::ArrayRef<int64_t> shape2) {
224 changes: 218 additions & 6 deletions lib/Transforms/OptimizeTranspose.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "imex/Utils/XeArch.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -48,6 +49,9 @@ namespace imex {
#include "imex/Transforms/Passes.h.inc"
} // namespace imex

#define index_val(value) \
rewriter.create<mlir::arith::ConstantOp>(loc, rewriter.getIndexAttr(value))

namespace optimizetranspose {

// Convenience interface for defining an op pattern.
@@ -210,6 +214,126 @@ bool withinRange(int val, imex::Range range) {
// NON_PACKED represents any other usage.
enum TransposeUsageType { PACKED = 1, NON_PACKED = 2 };

// Helper function to pack the given value in vnni format into its 32-bit
// representation, e.g., vector<8x8x2xf16> to vector<8x8xf32>.
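// For the example above, the value goes through vector<8x8x2xf16> ->
// shape_cast to vector<128xf16> -> bitcast to vector<64xf32> ->
// shape_cast to vector<8x8xf32>.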
static mlir::Value pack(mlir::Value value, mlir::PatternRewriter &rewriter) {
auto type = mlir::dyn_cast<mlir::VectorType>(value.getType());
if (!type || type.getRank() != 3)
return value;
auto elemTy = type.getElementType();
if (!elemTy.isIntOrFloat())
return value;

auto shape = type.getShape();
auto factor = shape[2];
auto bits = elemTy.getIntOrFloatBitWidth();
if (factor * bits != 32)
return value;

auto loc = value.getLoc();
// first cast the value to 1D shape
auto vecTy = mlir::VectorType::get({type.getNumElements()}, elemTy);
value = rewriter.create<mlir::vector::ShapeCastOp>(loc, vecTy, value);

// cast to 32-bit data, use i32 for integers and f32 for floats.
elemTy = elemTy.isInteger() ? (mlir::Type)rewriter.getIntegerType(32)
: (mlir::Type)rewriter.getF32Type();
vecTy = mlir::VectorType::get({shape[0] * shape[1]}, elemTy);
value = rewriter.create<mlir::vector::BitCastOp>(loc, vecTy, value);

// cast to 2D shape
vecTy = mlir::VectorType::get({shape[0], shape[1]}, elemTy);
return rewriter.create<mlir::vector::ShapeCastOp>(loc, vecTy, value);
}

static void createStoreScatter(mlir::Value data, mlir::Value slm,
mlir::Value base,
mlir::PatternRewriter &rewriter) {
auto type = mlir::dyn_cast<mlir::VectorType>(data.getType());
if (!type || type.getRank() > 2)
return;

auto loc = data.getLoc();
auto shape = type.getShape();
auto chunkSize = type.getRank() == 2 ? shape[0] : 1;
auto simdLanes = type.getRank() == 2 ? shape[1] : shape[0];

llvm::SmallVector<int64_t> staticOffsets;
for (auto i = 0; i < simdLanes; i++) {
staticOffsets.push_back(i * chunkSize);
}
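// Illustrative offsets (assuming simdLanes = 16 and chunkSize = 8): the lane
// offsets are {0, 8, 16, ..., 120}, i.e., each lane writes its chunk to a
// contiguous region of SLM.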
auto addrTy = mlir::VectorType::get(simdLanes, base.getType());
auto denseOffsets = mlir::DenseIntElementsAttr::get(addrTy, staticOffsets);
mlir::Value offsets =
rewriter.create<mlir::arith::ConstantOp>(loc, denseOffsets);
base = rewriter.create<mlir::vector::BroadcastOp>(loc, addrTy, base);
offsets = rewriter.create<mlir::arith::AddIOp>(loc, base, offsets);
llvm::SmallVector<int64_t> tdescShape({simdLanes});
if (chunkSize > 1)
tdescShape.push_back(chunkSize);

auto tdescTy = mlir::xegpu::TensorDescType::get(
tdescShape, type.getElementType(), chunkSize,
mlir::xegpu::MemorySpace::SLM);
auto desc =
rewriter.create<mlir::xegpu::CreateDescOp>(loc, tdescTy, slm, offsets);

auto transposeAttr = rewriter.getUnitAttr();
auto maskTy = mlir::VectorType::get(simdLanes, rewriter.getI1Type());
auto mask = rewriter.create<mlir::arith::ConstantOp>(
loc, mlir::DenseElementsAttr::get(maskTy, rewriter.getBoolAttr(true)));
rewriter.create<mlir::xegpu::StoreScatterOp>(loc, data, desc, mask,
transposeAttr, nullptr /*L1*/,
nullptr /*L2*/, nullptr /*L3*/);
}

static mlir::Value createBlockLoad(mlir::TypedValue<mlir::MemRefType> slm,
mlir::Value base, int numElems,
mlir::Type slmElemTy, mlir::Type opElemTy,
llvm::ArrayRef<int64_t> shape,
mlir::PatternRewriter &rewriter) {
auto loc = base.getLoc();
// choose the largest chunk size that evenly divides numElems.
std::vector<int> chunkSizes({64, 32, 16, 8, 4, 3, 2, 1});
auto it = std::find_if(chunkSizes.begin(), chunkSizes.end(),
[&](int s) { return numElems % s == 0; });
auto vectSize = *it;
auto bitWidth = opElemTy.getIntOrFloatBitWidth();
auto factor = bitWidth >= 32 ? 1 : 32 / bitWidth;
auto numLoads = numElems / vectSize;
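// Illustrative numbers (assuming a 16x16xf16 transpose): numElems = 128
// 32-bit SLM elements, so vectSize = 64, numLoads = 2, and factor = 2; each
// load yields vector<64xf32>, which is bitcast to vector<128xf16> and
// shape_cast to vector<8x16xf16>, and the two pieces are stacked below into
// the final vector<16x16xf16>.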
auto tdescTy = mlir::xegpu::TensorDescType::get(
vectSize, slmElemTy, 1, false, mlir::xegpu::MemorySpace::SLM);
auto loadTy = mlir::VectorType::get(vectSize, slmElemTy);
auto target1DTy = mlir::VectorType::get(vectSize * factor, opElemTy);
auto target2DTy =
mlir::VectorType::get({shape[0] / numLoads, shape[1]}, opElemTy);
llvm::SmallVector<mlir::Value> loads;

for (auto i = 0; i < numLoads; i++) {
mlir::Value offset = rewriter.create<mlir::arith::AddIOp>(
loc, base, index_val(i * vectSize));
auto tdesc = rewriter.create<mlir::xegpu::CreateNdDescOp>(
loc, tdescTy, slm, llvm::ArrayRef<mlir::OpFoldResult>({offset}));
mlir::Value value = rewriter.create<mlir::xegpu::LoadNdOp>(
loc, loadTy, tdesc, nullptr /*packed*/, nullptr /*transpose*/,
nullptr /*transpose_bit_width*/, nullptr /*l1_hint*/,
nullptr /*l2_hint*/, nullptr /*l3_hint*/);
// if the original data is not 32-bit, bitcast the current 32-bit data
// back to the original element type.
if (bitWidth < 32)
value = rewriter.create<mlir::vector::BitCastOp>(loc, target1DTy, value);

// shape cast the value to 2D shape.
value = rewriter.create<mlir::vector::ShapeCastOp>(loc, target2DTy, value);
loads.push_back(value);
}
auto result = loads[0];
for (size_t i = 1; i < loads.size(); i++) {
result = imex::stack(result, loads[i], loc, rewriter);
}
return result;
}

// This pattern detects a transpose op that is using the result of a load op and
// replaces it with a new load op that does the load+transpose together. The
// pattern is only applied if the transpose is used in DPAS B. In addition packed layout
@@ -278,17 +402,105 @@ struct TransposeRewritePattern
// Check if this transpose is using a load op.
auto loadOp = llvm::dyn_cast_if_present<mlir::xegpu::LoadNdOp>(
op.getVector().getDefiningOp());
if (!loadOp)
return mlir::failure();
// Check if this load op is part of the analysis result.
if (!analysis.contains(loadOp))
if (!loadOp || !loadOp->hasOneUse())
return mlir::failure();

auto opVectorType = op.getType();
auto opElementTy = opVectorType.getElementType();

// Only transposes that cannot be folded with the load op are considered.
bool foldable = analysis.contains(loadOp) &&
(canTranspose(loadOp, TransposeUsageType::PACKED) ||
canTranspose(loadOp, TransposeUsageType::NON_PACKED));

if (!foldable && !loadOp.getPacked() && opElementTy.isIntOrFloat()) {
// Try to optimize the load+transpose sequence using only SLM. It covers
// 8-bit/16-bit data types as well as hardware-unsupported shapes of
// 32-bit data types, e.g., <8x32xf32>.
auto tdescTy = op.getSourceVectorType();
auto bitWidth = opElementTy.getIntOrFloatBitWidth();
auto bytes = tdescTy.getNumElements() * bitWidth / 8;

// limit the total data size to <= 512 bytes, which is the maximum size
// that can be handled by a single load/store lsc intrinsic.
if (bytes > 512)
return rewriter.notifyMatchFailure(
op, "total data size is larger than 512 bytes.");

// Element type for SLM; all operations on SLM are done at 32-bit
// or 64-bit granularity.
auto elemTy = bitWidth >= 32 ? opElementTy
: opElementTy.isInteger()
? (mlir::Type)rewriter.getIntegerType(32)
: (mlir::Type)rewriter.getF32Type();

auto shape = tdescTy.getShape();
// make sure each simd lane writes a column. Ideally, simdLanes should be
// at most 32, but IGC can split the store into multiple instructions
// if it is larger than 32.
auto simdLanes = shape[1];
// number of 32-bit elements.
auto numElems = bytes / 4;
// number of elements each simd lane writes
int chunkSize = numElems / simdLanes;
llvm::SmallVector<int> validChunkSizes = {64, 32, 16, 8, 4, 3, 2, 1};

// numElems has to be evenly divisible by simdLanes, and chunkSize
// has to be one of the valid chunk sizes.
if (numElems % simdLanes != 0 ||
!llvm::is_contained(validChunkSizes, chunkSize))
return mlir::failure();
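// Illustrative numbers (assuming a 16x16xf16 source tile): bytes = 512,
// simdLanes = 16, numElems = 128, and chunkSize = 8, which is valid.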

auto loc = loadOp.getLoc();
auto data = loadOp.getResult();
if (bitWidth < 32) {
// vnni factor, i.e., the number of elements packed into one 32-bit unit
auto factor = 32 / bitWidth;
// add vnni transformation to load op, and the result (data) type
// will be updated from, e.g., vector<8x16xf16> to vector<4x16x2xf16>
int64_t vnniShape[3] = {shape[0] / factor, shape[1], factor};
auto vnniTy = mlir::VectorType::get(vnniShape, opElementTy);
loadOp.setPacked(true);
loadOp.getResult().setType(vnniTy);

// pack the result into 32-bit format, e.g., vector<4x16x2xf16> to
// vector<4x16xf32>
data = pack(loadOp.getResult(), rewriter);
}

// Allocate shared local memory for the data. Note that SLM is shared among
// subgroups in a workgroup. The total size needed is numElems *
// numSubgroups. However, dynamic allocation is currently not supported,
// so we assume the maximum number of subgroups is 64, given that
// PVC has 8 EUs per subslice and 8 threads per EU.
// TODO: get the number from uArch.
int64_t totSlmSize = 64 * numElems;
auto slmTy = mlir::MemRefType::get({totSlmSize}, elemTy, {}, 3);
auto slm = rewriter.create<mlir::memref::AllocOp>(loc, slmTy);

auto sgId = rewriter.create<mlir::gpu::SubgroupIdOp>(
loc, rewriter.getIndexType(), nullptr /* upper_bound*/);
auto offset = rewriter.create<mlir::arith::MulIOp>(
loc, sgId, index_val(numElems), nullptr /* overflowFlags */);

// store data using store_scatter to SLM at the given offset.
createStoreScatter(data, slm, offset, rewriter);

// load numElems elements from SLM at the given offset using 1D block
// load.
auto result = createBlockLoad(slm, offset, numElems, elemTy, opElementTy,
opVectorType.getShape(), rewriter);
rewriter.replaceOp(op, result);
return mlir::success();
}

// trying to optimize the load+transpose+dpasB sequence.

// Check if the transpose has a single user and it has the desired packed
// layout conversion op sequence.
if (!op->hasOneUse())
return mlir::failure();
auto opVectorType = op.getType();
auto opElementTy = opVectorType.getElementType();

// If the element type is < 32 bits, we need to clean up the packed layout
// conversion op sequence.
if (opElementTy.getIntOrFloatBitWidth() < 32) {
17 changes: 17 additions & 0 deletions lib/Utils/XeCommon.cpp
@@ -194,4 +194,21 @@ llvm::SmallVector<int64_t> defaultStrides(llvm::ArrayRef<int64_t> shape) {
return strides;
}

mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
mlir::Location loc,
mlir::PatternRewriter &rewriter) {
auto vecUpTy = llvm::cast<mlir::VectorType>(vecUp.getType());
auto vecDownTy = llvm::cast<mlir::VectorType>(vecDown.getType());
assert(vecUpTy.getRank() == 2 && vecDownTy.getRank() == vecUpTy.getRank() &&
"only supports 2D vectors.");
assert(vecUpTy.getShape()[1] == vecDownTy.getShape()[1] &&
"Operands of stack() do not have the same number of columns.");

llvm::SmallVector<int64_t> mask(vecUpTy.getShape()[0] +
vecDownTy.getShape()[0]);
std::iota(mask.begin(), mask.end(), 0);
auto op = rewriter.create<mlir::vector::ShuffleOp>(loc, vecUp, vecDown, mask);
return op;
}

} // namespace imex
