Skip to content

Commit

Permalink
Add xetile-tiling and xetile-to-xegpu convertion passes.
Browse files Browse the repository at this point in the history
co-authored by:
    Chao Chen: [email protected]
    Dimpalben R Prajapati: [email protected]
  • Loading branch information
chencha3 committed Nov 8, 2023
1 parent c6f1e56 commit 1577e15
Show file tree
Hide file tree
Showing 50 changed files with 5,170 additions and 168 deletions.
1 change: 1 addition & 0 deletions include/imex/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ add_public_tablegen_target(IMEXConversionPassIncGen)

add_mlir_doc(Passes IMEXConversionPasses ./ -gen-pass-doc)
add_subdirectory(DistToStandard)
add_subdirectory(XeTileToXeGPU)
1 change: 1 addition & 0 deletions include/imex/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <imex/Conversion/GPUXToLLVM/GPUXToLLVMPass.h>
#include <imex/Conversion/PTensorToLinalg/PTensorToLinalg.h>
#include <imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h>
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>

namespace imex {

Expand Down
44 changes: 44 additions & 0 deletions include/imex/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,48 @@ def ConvertGPUXToLLVM : Pass<"convert-gpux-to-llvm", "::mlir::ModuleOp"> {
}


//===----------------------------------------------------------------------===//
// XeTileToXeGPU
//===----------------------------------------------------------------------===//

def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::ModuleOp"> {
let summary = "Convert from the XeTile dialect to the XeGPU dialect.";
let description = [{
Convert XeTile dialect operations into the XeGPU dialect operations. It expects
the input code is tiled using xetile-tiling.

#### Input invariant

func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16>
%2 = xetile.load_tile %1 : !xetile.tile<2x1x8x16xf16> -> vector<2x1x8x16xf16>
return
}

#### Output IR

func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c64] {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>
%c8 = arith.constant 8 : index
%c64_0 = arith.constant 64 : index
%1 = xegpu.create_nd_tdesc %arg0[%c8, %c64_0] {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.load_nd %0 {mode = vc, l1_hint = uncached, l2_hint = uncached, l3_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%3 = xegpu.load_nd %1 {mode = vc, l1_hint = uncached, l2_hint = uncached, l3_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
return
}
}];

let constructor = "::imex::createConvertXeTileToXeGPUPass()";
let dependentDialects = ["::imex::xegpu::XeGPUDialect",
"::imex::xetile::XeTileDialect",
"::mlir::vector::VectorDialect",
"::mlir::arith::ArithDialect",
];
let options = [];
}

#endif // _IMEX_CONVERSION_PASSES_TD_INCLUDED_
Empty file.
45 changes: 45 additions & 0 deletions include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===- XeTileToXeGPU.h - XeTileToXeGPU conversion -------*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the XeTileToXeGPU conversion, converting the XeTile
/// dialect to the XeGPU dialect.
///
//===----------------------------------------------------------------------===//

#ifndef _XeTileToXeGPU_H_INCLUDED_
#define _XeTileToXeGPU_H_INCLUDED_

#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/DialectConversion.h>

#include "XeTileToXeGPUConversion.h"

namespace mlir {
class MLIRContext;
class ModuleOp;
template <typename T> class OperationPass;
class RewritePatternSet;
} // namespace mlir

namespace imex {
class XeGPUTypeConverter;

/// Populate the given list with patterns rewrite XeTile Ops
void populateXeTileToXeGPUConversionPatterns(XeGPUTypeConverter &converter,
mlir::RewritePatternSet &patterns);

/// Create a pass to convert the XeTile dialect to the XeGPU dialect.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createConvertXeTileToXeGPUPass();

} // namespace imex

#endif // _XeTileToXeGPU_H_INCLUDED_
190 changes: 190 additions & 0 deletions include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
//===- TypeConverter.h - XeTileToXeGPU conversion -------*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the SgXeTileToXeGPUConversion, the base class for
/// XeTileToXeGPU conversion, XeGPUTypeConverter, converting types used in
/// XeTile dialect to types used in XeGPU dialect, XeGPUOneToNPatterRewriter a
/// wrapper around ConversionPatterRewriter providng interface for supporting
/// OneToN replace.
///
//===----------------------------------------------------------------------===//

#ifndef _XeTileToXeGPUConversion_H_INCLUDED_
#define _XeTileToXeGPUConversion_H_INCLUDED_

#include <llvm/Support/Debug.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Dialect/Vector/IR/VectorOps.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Value.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/OneToNTypeConversion.h>

#include "imex/Dialect/XeGPU/IR/XeGPUOps.h"
#include "imex/Dialect/XeTile/IR/XeTileOps.h"
#include "imex/Utils/DebugUtils.h"
#include "imex/Utils/PassWrapper.h"
#include "imex/Utils/XeCommon.h"

namespace imex {

class XeGPUTypeConverter : public imex::XeTypeConverter {
public:
XeGPUTypeConverter(mlir::MLIRContext &context, ValueAttributeMap &map);

std::optional<mlir::LogicalResult>
convertTileType(xetile::TileType tileTy,
llvm::SmallVectorImpl<mlir::Type> &resultTypes) override;

std::optional<mlir::LogicalResult>
convertVectorType(mlir::VectorType vectorTy,
llvm::SmallVectorImpl<mlir::Type> &resultTypes) override;
};

class XeGPUOneToNPatterRewriter : public mlir::PatternRewriter,
public mlir::RewriterBase::Listener {
public:
explicit XeGPUOneToNPatterRewriter(mlir::ConversionPatternRewriter &rewriter,
XeGPUTypeConverter &converter)
: mlir::PatternRewriter(rewriter.getContext()), typeConverter(converter),
rewriter(rewriter) {
setListener(this);
}

mlir::Block *
applySignatureConversion(mlir::Region *region,
mlir::TypeConverter::SignatureConversion &conversion,
const mlir::TypeConverter *converter = nullptr);

template <typename OpTy, typename... Args>
OpTy create(mlir::Location location, Args &&...args) {
return rewriter.create<OpTy>(location, std::forward<Args>(args)...);
}

mlir::FailureOr<mlir::Block *> convertRegionTypes(
mlir::Region *region, const mlir::TypeConverter &converter,
mlir::TypeConverter::SignatureConversion *entryConversion = nullptr) {
return rewriter.convertRegionTypes(region, converter, entryConversion);
}

void inlineRegionBefore(mlir::Region &region, mlir::Region &parent,
mlir::Region::iterator before) override {
rewriter.inlineRegionBefore(region, parent, before);
}

void replaceOp(mlir::Operation *op, mlir::Operation *newOp) override {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
}

void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) override;

void eraseOp(mlir::Operation *op) override { rewriter.eraseOp(op); }

template <typename CallableT>
void updateRootInPlace(mlir::Operation *root, CallableT &&callable) {
rewriter.updateRootInPlace(root, callable);
}

mlir::ConversionPatternRewriter &mlirConversionPatterRewriter() {
return rewriter;
};

private:
XeGPUTypeConverter &typeConverter;
mlir::ConversionPatternRewriter &rewriter;
};

template <typename SourceOp>
class SgXeTileToXeGPUConversion : public XeConversionPattern {
public:
SgXeTileToXeGPUConversion(mlir::MLIRContext *context,
XeGPUTypeConverter &typeConverter,
mlir::PatternBenefit benefit = 1)
: XeConversionPattern(typeConverter, SourceOp::getOperationName(),
benefit, context) {}

using RangeT = llvm::ArrayRef<mlir::ValueRange>;
using OpAdaptor = typename SourceOp::template GenericAdaptor<RangeT>;

/*
* This overwrites the RewritePattern::matchAndRewrite as it is the entry
* point. It will set up the OpAdaptor such that it contains the converted
* values, and wrap the ConversionPatternRewriter with
* XeGPUOneToNPatterRewriter to provide a clean interface for users.
*/
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const final {
llvm::SmallVector<mlir::ValueRange> convertedValues;

// converted into convertionPatternRewriter since applyPartialConversion
// used it
auto &convertionPatternRewriter =
static_cast<mlir::ConversionPatternRewriter &>(rewriter);

// One-To-One mapping provided by mlir::ConversionPatternRewriter.
// remappedValues contains new values for each operand of the operation. It
// is supposed to be a UnrealizedConversionCastOp (created by the replaceOp
// of XeGPUOneToNPatternRewriter in form of cast newvalues to oldType) for
// each operand that has One-to-N mapping.
llvm::SmallVector<mlir::Value> remappedValues;
if (mlir::failed(convertionPatternRewriter.getRemappedValues(
op->getOperands(), remappedValues))) {
return op->emitOpError("Failed to get remapped values.\n");
// return mlir::failure();
}

// get the One-to-N converted types.
auto operandTys = op->getOperandTypes();
mlir::OneToNTypeMapping operandMapping(operandTys);
if (mlir::failed(
typeConverter.computeTypeMapping(operandTys, operandMapping))) {
return op->emitOpError("Failed to compute Type mapping.\n");
// return mlir::failure();
}

// retrive mapped values for each operand. If its type is not convereted
// (convertedTypes.size() == 1) we will reuse the current value. Otherwise,
// it has one-to-n mapping, and the new value should be an
// UnrealizedConversionCastOp.
for (auto [idx, value] : llvm::enumerate(remappedValues)) {
mlir::TypeRange convertedTypes = operandMapping.getConvertedTypes(idx);
if (convertedTypes.size() == 1) {
convertedValues.push_back(value);
} else if (auto castOp =
llvm::dyn_cast_or_null<mlir::UnrealizedConversionCastOp>(
value.getDefiningOp())) {
convertedValues.push_back(castOp.getInputs());
} else {
return op->emitError(
"[SgXeTileToXeGPUConversion::matchAndRewrite] Unexpected that "
"cannot figure out the remapped input value.");
}
}

auto sourceOp = llvm::dyn_cast<SourceOp>(op);
OpAdaptor adaptor(convertedValues, sourceOp);
XeGPUOneToNPatterRewriter OneToNRewriter(
convertionPatternRewriter, getTypeConverter<XeGPUTypeConverter>());
return matchAndRewrite(sourceOp, adaptor, OneToNRewriter);
}

virtual mlir::LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
XeGPUOneToNPatterRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
};

} // namespace imex

#endif
12 changes: 6 additions & 6 deletions include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def XeGPU_ScatteredAttr : XeGPUAttr<"Scattered", "scattered"> {
let assemblyFormat = "";
}

def XeGPU_SgMapAttr: XeGPUAttr<"SgMap", "sg_map"> {
def XeGPU_SgMapAttr: XeGPUAttr<"SubGroupMap", "sg_map"> {
let parameters = (ins
ArrayRefParameter<"unsigned">:$wiLayout,
ArrayRefParameter<"unsigned">:$wiData,
ArrayRefParameter<"unsigned">:$mmaBlockSize);

// In format of #xegpu.sg_map<{mma_block_size = [2, 4], wi_layout = [2, 4], wi_data = [2, 4]}>
let assemblyFormat = "`<` custom<SgMapAttrElements>($wiLayout, $wiData, $mmaBlockSize) `>`";
let assemblyFormat = "`<` custom<SubGroupMapAttrElements>($wiLayout, $wiData, $mmaBlockSize) `>`";

let genVerifyDecl = true;

Expand All @@ -52,7 +52,7 @@ def XeGPU_SgMapAttr: XeGPUAttr<"SgMap", "sg_map"> {
let skipDefaultBuilders = 1;
}

def XeGPU_WgMapAttr: XeGPUAttr<"WgMap", "wg_map"> {
def XeGPU_WgMapAttr: XeGPUAttr<"WorkGroupMap", "wg_map"> {
let parameters = (ins
ArrayRefParameter<"unsigned">:$sgLayout,
ArrayRefParameter<"unsigned">:$sgData);
Expand All @@ -71,7 +71,7 @@ def XeGPU_WgMapAttr: XeGPUAttr<"WgMap", "wg_map"> {
let skipDefaultBuilders = 1;

// In format of #xegpu.wg_map<{sg_layout = [2, 4], sg_data = [2, 4]}>
let assemblyFormat = "`<` custom<WgMapAttrElements>($sgLayout, $sgData) `>`";
let assemblyFormat = "`<` custom<WorkGroupMapAttrElements>($sgLayout, $sgData) `>`";
}

def XeGPU_XeMapAttr: XeGPUAttr<"XeMap", "xe_map"> {
Expand All @@ -90,8 +90,8 @@ def XeGPU_XeMapAttr: XeGPUAttr<"XeMap", "xe_map"> {
assert(sgLayout.size() == 2 && sgData.size() == 2 && "sgLayout and sgData should be 2D arrays.\n");
assert(wiLayout.size() == 2 && wiData.size() == 2 && "wiLayout and wiData should be 2D arrays.\n");
assert((mmaBlockSize.size() == 2 || mmaBlockSize.size() == 0) && "mmaBlockSize can be either empty or a 2D array.\n");
auto wg = WgMapAttr::get($_ctxt, sgLayout, sgData);
auto sg = SgMapAttr::get($_ctxt, wiLayout, wiData, mmaBlockSize);
auto wg = WorkGroupMapAttr::get($_ctxt, sgLayout, sgData);
auto sg = SubGroupMapAttr::get($_ctxt, wiLayout, wiData, mmaBlockSize);
return $_get($_ctxt, wg, sg);
}]>
];
Expand Down
2 changes: 1 addition & 1 deletion include/imex/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> {
);
let results = (outs XeGPU_Vector2DType: $result);
let assemblyFormat = [{
$lhs `,` $rhs (`,` $acc^)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:`
$lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:`
qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result))
}];

Expand Down
2 changes: 1 addition & 1 deletion include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def XeTile_PrefetchTileOp : XeTile_Op<"prefetch_tile", []> {
}];
}

def XeTile_TileMMAOp : XeTile_Op<"tile_mma", [Pure]> {
def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> {
let summary = "matrix multiplication in blocked layout";
let description = [{
"tile_mma" operation represents matrix multiplication on 2D or 4D vectors. This operation
Expand Down
11 changes: 6 additions & 5 deletions include/imex/Dialect/XeTile/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ class RewritePatternSet;

namespace imex {

class XeTypeConverter;

//===----------------------------------------------------------------------===//
/// XeTile passes.
//===----------------------------------------------------------------------===//

/// Create a pass for converting XeTile Ops to XeGPU Ops
std::unique_ptr<::mlir::Pass> createXeTileToXeGPUPass();
std::unique_ptr<mlir::Pass> createXeTileTilingPass();

/// Populate the given list with patterns that eliminate XeTile ops
void populateXeTileToXeGPUPatterns(::mlir::LLVMTypeConverter &converter,
::mlir::RewritePatternSet &patterns);
///
void populateXeTileTilingPatterns(imex::XeTypeConverter &converter,
mlir::RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// Registration
Expand Down
Loading

0 comments on commit 1577e15

Please sign in to comment.