Skip to content

Commit

Permalink
update XeGPUToSPIRV pass accommodating xegpu.NbarrierType
Browse files Browse the repository at this point in the history
add i16 support for dpas
add xegpu.nbarrier type
  • Loading branch information
chencha3 authored and silee2 committed Nov 16, 2023
1 parent 58c5124 commit 73249a3
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 28 deletions.
10 changes: 5 additions & 5 deletions include/imex/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def XeGPU_CreateNbarrierOp
DefaultValuedAttr<XeGPU_ModeAttr, "imex::xegpu::Mode::SIMT">: $mode
);

let results = (outs Builtin_Vector: $result);
let results = (outs XeGPU_Nbarrier: $result);

let assemblyFormat = [{
$nbarrier_id `,` $nbarrier_role
Expand All @@ -626,7 +626,7 @@ def XeGPU_NbarrierArriveOp
let summary = "arrive at a named barrier.";

let arguments = (ins
Builtin_Vector: $payload
XeGPU_Nbarrier: $payload
);

let assemblyFormat = [{
Expand All @@ -639,16 +639,16 @@ def XeGPU_NbarrierWaitOp
let summary = "wait for a named barrier.";

let arguments = (ins
Builtin_Vector: $payload
XeGPU_Nbarrier: $payload
);

let assemblyFormat = [{
$payload attr-dict `:` qualified(type($payload))
}];
}

def XeGPU_CompilerHintOp
: XeGPU_Op<"compiler_hint", []> {
def XeGPU_CompileHintOp
: XeGPU_Op<"compile_hint", []> {
let summary = "prevents the compiler from scheduling.";

let assemblyFormat = [{
Expand Down
13 changes: 12 additions & 1 deletion include/imex/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class XeGPUTypeDef<string name, string typeMnemonic,
let mnemonic = typeMnemonic;
}

// TODO:
// TensorDesc contains dim and element type info
def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
[ShapedTypeInterface], "::mlir::TensorType"> {
let summary = "TensorDesc type describing all kinds of memory and tensors scatter tensor, 1d tensor, 2d tensor, … 5d tensor";
Expand Down Expand Up @@ -102,4 +102,15 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let assemblyFormat = "`<` custom<ShapeAndType>($shape, $elementType)``custom<TensorDescAttr>($memory_scope, $encoding)`>`";
}


// Parameterless XeGPU type used as the handle for a named barrier.
// It is the result type of xegpu.create_nbarrier and the operand type of
// xegpu.nbarrier_arrive / xegpu.nbarrier_wait.
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
  let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";

  let extraClassDeclaration = [{
    // Builder taking only the context: the type declares no parameters, so
    // the call simply forwards to the base-class getter.
    static NbarrierType get(mlir::MLIRContext *context) {
      return Base::get(context);
    }
  }];
}

#endif // _XEGPU_TYPES_TD_INCLUDED_
4 changes: 4 additions & 0 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ void GPUXToSPIRVPass::runOnOperation() {
eraseOp->erase();
}
target->addIllegalDialect<imex::xegpu::XeGPUDialect>();
typeConverter.addConversion([&](xegpu::NbarrierType type) -> ::mlir::Type {
auto i32Type = ::mlir::IntegerType::get(context, 32);
return mlir::VectorType::get(8, i32Type);
});
typeConverter.addConversion(
[&](xegpu::TensorDescType type) -> ::mlir::Type {
auto i64Type = ::mlir::IntegerType::get(context, 64);
Expand Down
10 changes: 5 additions & 5 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ class NbarrierArriveToVCPattern : public OpConversionPattern<NbarrierArriveOp> {
matchAndRewrite(NbarrierArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto payload = op.getPayload();
auto payload = adaptor.getPayload();

std::string funcName = "llvm_genx_raw_send2_noresult_i1_v8i32";

Expand Down Expand Up @@ -1101,7 +1101,7 @@ class NbarrierWaitToVCPattern : public OpConversionPattern<NbarrierWaitOp> {
matchAndRewrite(NbarrierWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto payload = op.getPayload();
auto payload = adaptor.getPayload();

auto i8Type = rewriter.getIntegerType(8);
auto i32Type = rewriter.getIntegerType(32);
Expand All @@ -1127,11 +1127,11 @@ class NbarrierWaitToVCPattern : public OpConversionPattern<NbarrierWaitOp> {
}
};

class CompilerHintToVCPattern : public OpConversionPattern<CompilerHintOp> {
class CompilerHintToVCPattern : public OpConversionPattern<CompileHintOp> {
public:
using OpConversionPattern<CompilerHintOp>::OpConversionPattern;
using OpConversionPattern<CompileHintOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CompilerHintOp op, OpAdaptor adaptor,
matchAndRewrite(CompileHintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();

Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ bool dpasSupportedTypes(mlir::Type type, bool isResult) {
else
return false;
} else {
if (type.isF16() || type.isBF16() || type.isInteger(8))
if (type.isF16() || type.isBF16() || type.isInteger(16) ||
type.isInteger(8))
return true;
else
return false;
Expand Down
8 changes: 4 additions & 4 deletions test/Conversion/XeGPUToSPIRV/barrier_basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ module @gemm attributes {gpu.container_module} {
xegpu.alloc_nbarrier 16
%nbarrier_id = arith.constant 1 : i8
%nbarrier_role = arith.constant 0 : i8
%payload = xegpu.create_nbarrier %nbarrier_id, %nbarrier_role {num_producers = 32 : i8, num_consumers = 32 : i8} : (i8, i8) -> vector<8xi32>
xegpu.nbarrier_arrive %payload : vector<8xi32>
%payload = xegpu.create_nbarrier %nbarrier_id, %nbarrier_role {num_producers = 32 : i8, num_consumers = 32 : i8} : (i8, i8) -> !xegpu.nbarrier
xegpu.nbarrier_arrive %payload : !xegpu.nbarrier
xegpu.mfence {memory_kind = "ugm" , fence_op = "none", fence_scope = "local"}
xegpu.compiler_hint
xegpu.nbarrier_wait %payload : vector<8xi32>
xegpu.compile_hint
xegpu.nbarrier_wait %payload : !xegpu.nbarrier
gpu.return
}
}
Expand Down
24 changes: 12 additions & 12 deletions test/Dialect/XeGPU/IR/barrier_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,32 @@ func.func @create_nbarrier() {
%nbarrier_role = arith.constant 0 : i8
// CHECK: xegpu.create_nbarrier
// CHECK-SAME: {num_consumers = 32 : i8, num_producers = 32 : i8}
// CHECK-SAME: (i8, i8) -> vector<8xi32>
// CHECK-SAME: (i8, i8) -> !xegpu.nbarrier
%nbarrier = xegpu.create_nbarrier %nbarrier_id, %nbarrier_role {num_producers = 32 :i8 , num_consumers = 32 : i8}
: (i8, i8) -> vector<8xi32>
: (i8, i8) -> !xegpu.nbarrier
return
}

// CHECK-LABEL: func @nbarrier_arrive({{.*}}) {
func.func @nbarrier_arrive(%nbarrier : vector<8xi32>) {
func.func @nbarrier_arrive(%nbarrier : !xegpu.nbarrier) {
// CHECK: xegpu.nbarrier_arrive
// CHECK-SAME: vector<8xi32>
xegpu.nbarrier_arrive %nbarrier : vector<8xi32>
// CHECK-SAME: !xegpu.nbarrier
xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier
return
}

// CHECK-LABEL: func @nbarrier_wait({{.*}}) {
func.func @nbarrier_wait(%nbarrier : vector<8xi32>) {
func.func @nbarrier_wait(%nbarrier : !xegpu.nbarrier) {
// CHECK: xegpu.nbarrier_wait
// CHECK-SAME: vector<8xi32>
xegpu.nbarrier_wait %nbarrier : vector<8xi32>
// CHECK-SAME: !xegpu.nbarrier
xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier
return
}

// CHECK-LABEL: func @compiler_hint({{.*}}) {
func.func @compiler_hint() {
// CHECK: xegpu.compiler_hint
xegpu.compiler_hint
// CHECK-LABEL: func @compile_hint({{.*}}) {
func.func @compile_hint() {
// CHECK: xegpu.compile_hint
xegpu.compile_hint
return
}

Expand Down

0 comments on commit 73249a3

Please sign in to comment.