diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td index 3c8c9018f..b2baf565c 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td @@ -610,7 +610,7 @@ def XeGPU_CreateNbarrierOp DefaultValuedAttr: $mode ); - let results = (outs Builtin_Vector: $result); + let results = (outs XeGPU_Nbarrier: $result); let assemblyFormat = [{ $nbarrier_id `,` $nbarrier_role @@ -626,7 +626,7 @@ def XeGPU_NbarrierArriveOp let summary = "arrive at a named barrier."; let arguments = (ins - Builtin_Vector: $payload + XeGPU_Nbarrier: $payload ); let assemblyFormat = [{ @@ -639,7 +639,7 @@ def XeGPU_NbarrierWaitOp let summary = "wait for a named barrier."; let arguments = (ins - Builtin_Vector: $payload + XeGPU_Nbarrier: $payload ); let assemblyFormat = [{ @@ -647,8 +647,8 @@ def XeGPU_NbarrierWaitOp }]; } -def XeGPU_CompilerHintOp - : XeGPU_Op<"compiler_hint", []> { +def XeGPU_CompileHintOp + : XeGPU_Op<"compile_hint", []> { let summary = "prevents the compiler from scheduling."; let assemblyFormat = [{ diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td b/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td index b15072834..6f6f3df59 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td @@ -40,7 +40,7 @@ class XeGPUTypeDef { let summary = "TensorDesc type describing all kinds of memory and tensors scatter tensor, 1d tensor, 2d tensor, … 5d tensor"; @@ -102,4 +102,15 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", let assemblyFormat = "`<` custom($shape, $elementType)``custom($memory_scope, $encoding)`>`"; } + +def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> { + let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier."; + + let extraClassDeclaration = [{ + static NbarrierType get(mlir::MLIRContext *context) { + return Base::get(context); + }; + }]; +} + #endif // _XEGPU_TYPES_TD_INCLUDED_ diff --git a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index cc35fe96c..c149f5aaa 100644 --- a/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -156,6 +156,10 @@ void GPUXToSPIRVPass::runOnOperation() { eraseOp->erase(); } target->addIllegalDialect(); + typeConverter.addConversion([&](xegpu::NbarrierType type) -> ::mlir::Type { + auto i32Type = ::mlir::IntegerType::get(context, 32); + return mlir::VectorType::get(8, i32Type); + }); typeConverter.addConversion( [&](xegpu::TensorDescType type) -> ::mlir::Type { auto i64Type = ::mlir::IntegerType::get(context, 64); diff --git a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp index 86025f9ae..a4cd4e437 100644 --- a/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp +++ b/lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp @@ -1066,7 +1066,7 @@ class NbarrierArriveToVCPattern : public OpConversionPattern { matchAndRewrite(NbarrierArriveOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto payload = op.getPayload(); + auto payload = adaptor.getPayload(); std::string funcName = "llvm_genx_raw_send2_noresult_i1_v8i32"; @@ -1101,7 +1101,7 @@ class NbarrierWaitToVCPattern : public OpConversionPattern { matchAndRewrite(NbarrierWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto payload = op.getPayload(); + auto payload = adaptor.getPayload(); auto i8Type = rewriter.getIntegerType(8); auto i32Type = rewriter.getIntegerType(32); @@ -1127,11 +1127,11 @@ class NbarrierWaitToVCPattern : public OpConversionPattern { } }; -class CompilerHintToVCPattern : public OpConversionPattern { +class CompilerHintToVCPattern : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CompilerHintOp op, OpAdaptor adaptor, + matchAndRewrite(CompileHintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); diff --git a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 21d083d25..225aae701 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -83,7 +83,8 @@ bool dpasSupportedTypes(mlir::Type type, bool isResult) { else return false; } else { - if (type.isF16() || type.isBF16() || type.isInteger(8)) + if (type.isF16() || type.isBF16() || type.isInteger(16) || + type.isInteger(8)) return true; else return false; diff --git a/test/Conversion/XeGPUToSPIRV/barrier_basic.mlir b/test/Conversion/XeGPUToSPIRV/barrier_basic.mlir index e95c9eb78..0d72915de 100644 --- a/test/Conversion/XeGPUToSPIRV/barrier_basic.mlir +++ b/test/Conversion/XeGPUToSPIRV/barrier_basic.mlir @@ -24,11 +24,11 @@ module @gemm attributes {gpu.container_module} { xegpu.alloc_nbarrier 16 %nbarrier_id = arith.constant 1 : i8 %nbarrier_role = arith.constant 0 : i8 - %payload = xegpu.create_nbarrier %nbarrier_id, %nbarrier_role {num_producers = 32 : i8, num_consumers = 32 : i8} : (i8, i8) -> vector<8xi32> - xegpu.nbarrier_arrive %payload : vector<8xi32> + %payload = xegpu.create_nbarrier %nbarrier_id, %nbarrier_role {num_producers = 32 : i8, num_consumers = 32 : i8} : (i8, i8) -> !xegpu.nbarrier + xegpu.nbarrier_arrive %payload : !xegpu.nbarrier xegpu.mfence {memory_kind = "ugm" , fence_op = "none", fence_scope = "local"} - xegpu.compiler_hint - xegpu.nbarrier_wait %payload : vector<8xi32> + xegpu.compile_hint + xegpu.nbarrier_wait %payload : !xegpu.nbarrier gpu.return } } diff --git a/test/Dialect/XeGPU/IR/barrier_ops.mlir b/test/Dialect/XeGPU/IR/barrier_ops.mlir index 079e916b4..af63fc5d6 100644 --- a/test/Dialect/XeGPU/IR/barrier_ops.mlir +++ b/test/Dialect/XeGPU/IR/barrier_ops.mlir @@ -17,32 +17,32 @@ func.func @create_nbarrier() { %nbarrier_role = arith.constant 0 : i8 // CHECK: xegpu.create_nbarrier // CHECK-SAME: {num_consumers = 32 : i8, num_producers = 32 : i8} - // CHECK-SAME: (i8, i8) -> vector<8xi32> + // CHECK-SAME: (i8, i8) -> !xegpu.nbarrier %nbarrier = xegpu.create_nbarrier %nbarrier_id, %nbarrier_role {num_producers = 32 :i8 , num_consumers = 32 : i8} - : (i8, i8) -> vector<8xi32> + : (i8, i8) -> !xegpu.nbarrier return } // CHECK-LABEL: func @nbarrier_arrive({{.*}}) { -func.func @nbarrier_arrive(%nbarrier : vector<8xi32>) { +func.func @nbarrier_arrive(%nbarrier : !xegpu.nbarrier) { // CHECK: xegpu.nbarrier_arrive - // CHECK-SAME: vector<8xi32> - xegpu.nbarrier_arrive %nbarrier : vector<8xi32> + // CHECK-SAME: !xegpu.nbarrier + xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier return } // CHECK-LABEL: func @nbarrier_wait({{.*}}) { -func.func @nbarrier_wait(%nbarrier : vector<8xi32>) { +func.func @nbarrier_wait(%nbarrier : !xegpu.nbarrier) { // CHECK: xegpu.nbarrier_wait - // CHECK-SAME: vector<8xi32> - xegpu.nbarrier_wait %nbarrier : vector<8xi32> + // CHECK-SAME: !xegpu.nbarrier + xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier return } -// CHECK-LABEL: func @compiler_hint({{.*}}) { -func.func @compiler_hint() { - // CHECK: xegpu.compiler_hint - xegpu.compiler_hint +// CHECK-LABEL: func @compile_hint({{.*}}) { +func.func @compile_hint() { + // CHECK: xegpu.compile_hint + xegpu.compile_hint return }