From a3ec44979facb2a75d152419a929f37f0b760ae7 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 19 Oct 2023 19:04:57 -0500 Subject: [PATCH] Extend XeGPU with SIMT mode Co-authored with Charitha Saumya. --- include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td | 47 +- include/imex/Dialect/XeGPU/IR/XeGPUDialect.td | 6 + include/imex/Dialect/XeGPU/IR/XeGPUOps.td | 195 +++--- include/imex/Dialect/XeGPU/IR/XeGPUTypes.td | 14 +- include/imex/Utils/XeUtils.h | 2 +- lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 197 ++++++ lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 650 ++++++++++++------ .../gemm_1024x1024xf16.runnable.mlir | 16 +- test/Conversion/XeGPUToVC/gemm_basic.mlir | 18 +- test/Dialect/XeGPU/IR/XeGPUOps.mlir | 76 +- test/Dialect/XeGPU/IR/atomic_rmw.mlir | 12 +- test/Dialect/XeGPU/IR/create_nd_tdesc.mlir | 67 +- test/Dialect/XeGPU/IR/create_tdesc.mlir | 73 +- test/Dialect/XeGPU/IR/invalid.mlir | 66 +- test/Dialect/XeGPU/IR/load_gather.mlir | 64 +- test/Dialect/XeGPU/IR/load_nd.mlir | 158 ++++- test/Dialect/XeGPU/IR/prefetch_nd.mlir | 22 +- test/Dialect/XeGPU/IR/simple_gemm.mlir | 23 +- test/Dialect/XeGPU/IR/store_nd.mlir | 20 +- test/Dialect/XeGPU/IR/store_scatter.mlir | 54 +- test/Dialect/XeGPU/IR/update_nd_offset.mlir | 16 +- test/Dialect/XeGPU/IR/update_offset.mlir | 45 +- 22 files changed, 1274 insertions(+), 567 deletions(-) diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td b/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td index 2c414b7da..5f4864511 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -6,12 +6,12 @@ include "imex/Dialect/XeGPU/IR/XeGPUDialect.td" include "mlir/IR/EnumAttr.td" -class XeGPUAttrDef traits = [], string baseCppClass = "::mlir::Attribute"> +class XeGPUAttr traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { let mnemonic = attrMnemonic; } -def XeGPU_ScatteredAttr : XeGPUAttrDef<"Scattered", "scattered", []> { +def XeGPU_ScatteredAttr : XeGPUAttr<"Scattered", "scattered"> { let summary = "Scattered attribute for scattered read and write operation."; let description = [{An attribute represent scattered read and write operation. It does not (need to) have meaningful input values. 
The existence of itself @@ -20,10 +20,49 @@ def XeGPU_ScatteredAttr : XeGPUAttrDef<"Scattered", "scattered", []> { let assemblyFormat = ""; } +def XeGPU_SgMapAttr: XeGPUAttr<"SgMap", "sg_map"> { + let parameters = (ins + ArrayRefParameter<"unsigned">:$mmaBlockSize, + ArrayRefParameter<"unsigned">:$wiLayout, + ArrayRefParameter<"unsigned">:$wiData); + + // In format of #xegpu.sg_map<{mma_block_size = [2, 4], wi_layout = [2, 4], wi_data = [2, 4]}> + let assemblyFormat = "`<` custom($mmaBlockSize, $wiLayout, $wiData) `>`"; +} + +def XeGPU_WgMapAttr: XeGPUAttr<"WgMap", "wg_map"> { + let parameters = (ins + ArrayRefParameter<"unsigned">:$sgLayout, + ArrayRefParameter<"unsigned">:$sgData); + + // In format of #xegpu.wg_map<{sg_layout = [2, 4], sg_data = [2, 4]}> + let assemblyFormat = "`<` custom($sgLayout, $sgData) `>`"; +} + +def XeGPU_XeMapAttr: XeGPUAttr<"XeMap", "xe_map"> { + let parameters = (ins + XeGPU_WgMapAttr: $wg, + XeGPU_SgMapAttr: $sg); + + // In format of #xegpu.xe_map + let hasCustomAssemblyFormat = 1; +} + +def XeGPU_ArgTypeAttr : I32EnumAttr< + "ArgType", "", [ I32EnumAttrCase<"Vector", 0, "vector">, + I32EnumAttrCase<"Scalar", 1, "scalar"> ]> { + let cppNamespace = "::imex::xegpu"; +} + +def XeGPU_ModeAttr : I32EnumAttr< + "Mode", "", [ I32EnumAttrCase<"SIMT", 0, "simt">, + I32EnumAttrCase<"VC", 1, "vc"> ]> { + let cppNamespace = "::imex::xegpu"; +} def XeGPU_MemoryScopeAttr : I32EnumAttr< - "MemoryScope", "", [ I32EnumAttrCase<"GLOBAL", 1, "global">, - I32EnumAttrCase<"SLM", 2, "slm"> ]> { + "MemoryScope", "", [ I32EnumAttrCase<"GLOBAL", 0, "global">, + I32EnumAttrCase<"SLM", 1, "slm"> ]> { let cppNamespace = "::imex::xegpu"; } diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUDialect.td b/include/imex/Dialect/XeGPU/IR/XeGPUDialect.td index 517e13566..b2b94111c 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUDialect.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUDialect.td @@ -38,9 +38,15 @@ def XeGPUDialect : Dialect { let dependentDialects = ["::mlir::memref::MemRefDialect"]; + // let extraClassDeclaration = [{ + // mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser, mlir::Type type) const; + // void printAttribute(mlir::Attribute attr, mlir::DialectAsmPrinter &printer) const; + // }]; + // TODO: temporary disable it. 
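+  // For reference, with the custom printers added in this patch the combined mapping
+  // attribute is expected to print roughly as follows (the sizes below are illustrative,
+  // not taken from a real kernel):
+  //   #xegpu.xe_map<wg = {sg_layout = [2, 2], sg_data = [32, 64]},
+  //                 sg = {mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>
+  // It is attached to a tensor descriptor as its encoding, e.g.
+  //   !xegpu.tensor_desc<64x128xf16, #xegpu.xe_map<...>>
+  // Under such a mapping each subgroup owns a 32x64 sub-tile and, per the SIMT-mode
+  // verifiers added later in this patch, each work item loads/stores a vector<16x8xf16>
+  // fragment (every dimension divided by sg_layout and then by wi_layout).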
let useDefaultTypePrinterParser = true; let useDefaultAttributePrinterParser = true; + // let useDefaultAttributePrinterParser = false; } #endif //XEGPU_DIALECT diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td index 11489bef9..fe5372471 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td @@ -83,8 +83,8 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe Variadic: $strides, DenseI64ArrayAttr: $static_offsets, DefaultValuedAttr: $memory_scope, - DefaultValuedAttr: $boundary_check - ); + DefaultValuedAttr: $boundary_check, + DefaultValuedAttr: $mode); let results = (outs XeGPU_TensorDesc:$TensorDesc); @@ -96,7 +96,7 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe OpBuilder<(ins "::mlir::Type": $TensorDesc, "::mlir::Value": $source, "::mlir::ValueRange": $offsets, "::mlir::ValueRange": $shape, "::mlir::ValueRange": $strides, "::llvm::ArrayRef": $static_offsets, CArg<"::imex::xegpu::MemoryScope", "xegpu::MemoryScope::GLOBAL">: $memory_scope, - CArg<"bool", "true">: $boundary_check), + CArg<"bool", "true">: $boundary_check, CArg<"::imex::xegpu::Mode", "imex::xegpu::Mode::SIMT">: $mode), [{ $_state.addOperands(source); $_state.addOperands(offsets); $_state.addOperands(shape); @@ -105,11 +105,12 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe $_state.addAttribute(getStaticOffsetsAttrName($_state.name), $_builder.getDenseI64ArrayAttr(static_offsets)); $_state.addAttribute(getMemoryScopeAttrName($_state.name), ::imex::xegpu::MemoryScopeAttr::get($_builder.getContext(), memory_scope)); $_state.addAttribute(getBoundaryCheckAttrName($_state.name), $_builder.getBoolAttr(boundary_check)); + $_state.addAttribute(getBoundaryCheckAttrName($_state.name), ::imex::xegpu::ModeAttr::get($_builder.getContext(), mode)); $_state.addTypes(TensorDesc); }]>, OpBuilder<(ins "::mlir::Type": $tdesc, "::mlir::Value": $source, "::llvm::ArrayRef": $offsets, CArg<"::imex::xegpu::MemoryScope", "::imex::xegpu::MemoryScope::GLOBAL">:$memory_scope, - CArg<"bool", "true">:$boundary_check), + CArg<"bool", "true">:$boundary_check, CArg<"::imex::xegpu::Mode", "imex::xegpu::Mode::SIMT">: $mode), [{ assert(offsets.size() == getRankOf(source)); llvm::SmallVector staticOffsets; llvm::SmallVector dynamicOffsets; @@ -120,13 +121,14 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe ::mlir::ValueRange({}) /* empty dynamic strides */, staticOffsets /* static offsets */, memory_scope, - boundary_check); }]>, + boundary_check, + mode); }]>, OpBuilder<(ins "::mlir::Type": $tdesc, "::mlir::Value": $source, "::llvm::ArrayRef": $offsets, "::mlir::ValueRange": $shape, "::mlir::ValueRange": $stride, CArg<"::imex::xegpu::MemoryScope", "xegpu::MemoryScope::GLOBAL">:$memory_scope, - CArg<"bool", "true">:$boundary_check), + CArg<"bool", "true">:$boundary_check, CArg<"::imex::xegpu::Mode", "imex::xegpu::Mode::SIMT">: $mode), [{ assert((!isMemRef(source) || getRankOf(source) == offsets.size()) && shape.size() == stride.size() && offsets.size() == shape.size() && isIntegerOrDynamicShapedMemref(source)); @@ -140,7 +142,8 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe stride /* empty dynamic strides */, staticOffsets /* static offsets */, memory_scope, - boundary_check); }]> + boundary_check, + mode); }]> ]; let extraClassDeclaration = [{ @@ -252,7 +255,10 @@ def XeGPU_CreateDescOp 
(scattered) subviews. It accepts the following parameters: * source: a 1D memref or pointer (uint64_t) represents the memory object. - * offsets: a 1D vector containing offsets of each access point, the size is aligned with supportted group size, e.g., vector<16xindex>. + * offsets: In VectorCompute (VC) mode, it is a 1D vector containing offsets of each access point, the size is aligned with + supportted group size, e.g., vector<16xindex>. And each element in the vector corresponds to a + work item (SIMT lane) in the subgroup. + In SIMT mode (default), it is an index scalar representing the offset of the access point. * memory_scope: [optional attribute] indicates where the memory is located, "global" for global memory (default), and "slm" for shared memory. * chunk_size_per_lane: [optional attribute] indicates number of continious elements accessed for each offset, default is 1. @@ -261,35 +267,45 @@ def XeGPU_CreateDescOp %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex> %1 = xegpu.create_tdesc %a, %c0: memref<1024xf32> -> TensorDesc<4xf32> - Example 2. It assumes subgroup size is 4, and each workitem access 8 elements. It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71] + Example 2. It assumes subgroup size is 4, and each workitem access 8 elements. + It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71] %0 = memref.alloc() : memref<1024xf32> %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex> - %1 = xegpu.create_nd_tdesc %0, %c0 {{chunk_size_per_lane = 8}}: memref<1024xf32> -> TensorDesc<4x8xf32> + %1 = xegpu.create_tdesc %0, %c0 {chunk_size_per_lane = 8}: memref<1024xf32> -> TensorDesc<4x8xf32> + + Example 3. an SIMT mode example, accessing a[16]. + %a = memref.alloc() : memref<1024xf32> + %c0 = arith.constant 16 : index + %1 = xegpu.create_tdesc %a, %c0: memref<1024xf32> -> TensorDesc<1xf32> }]; let arguments = (ins XeGPU_BaseAddrType: $source, - VectorOfRankAndType<[1], [Index]>: $offsets, + XeGPU_OffsetType: $offsets, DefaultValuedAttr: $memory_scope, - DefaultValuedAttr: $chunk_size_per_lane); + DefaultValuedAttr: $chunk_size_per_lane, + DefaultValuedAttr: $mode); let results = (outs XeGPU_TensorDesc:$TensorDesc); let extraClassDeclaration = [{ static size_t getRankOf(mlir::Value value) { - if (llvm::isa<::mlir::IntegerType>(value.getType())) + if (value.getType().isIntOrIndexOrFloat()) return 0; - if (llvm::isa<::mlir::MemRefType>(value.getType())) - return llvm::cast<::mlir::MemRefType>(value.getType()).getRank(); + if (llvm::isa(value.getType())) + return llvm::cast(value.getType()).getRank(); + if (llvm::isa(value.getType())) + return llvm::cast(value.getType()).getRank(); assert(0 && "Unreachable"); } }]; + // Format: xegpu.create_tdesc %src, %offsets {mode=simt, memory_scope=slm, chunk_size_per_lane=1} + // : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; } @@ -310,20 +326,12 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> { OptionalAttr: $transpose, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint - ); - let results = (outs XeGPU_VectorType: $value); - - // let assemblyFormat = [{ - // $TensorDesc - // (`vnni_axis` $vnni_axis^)? - // (`transpose` $transpose^)? - // (`l1_hint` $l1_hint^)? - // (`l2_hint` $l2_hint^)? - // (`l3_hint` $l3_hint^)? 
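+  // As exercised by the tests updated in this patch, a VC-mode load with a VNNI
+  // transform looks like:
+  //   %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1}
+  //        : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
+  // The result keeps the same number of elements but folds the VNNI factor
+  // (32 / bitwidth(f16) = 2) into a trailing dimension.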
- // attr-dict `:` qualified(type($TensorDesc)) `->` qualified(type($value)) - // }]; + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); + let results = (outs XeGPU_ValueType: $value); + // Format: xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming} + // : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -333,20 +341,15 @@ def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> { let summary = "stores a n-D block register region back to memory, currently only supports 2D"; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - XeGPU_Vector2DType: $value, + XeGPU_ValueType: $value, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint - ); - - // let assemblyFormat = [{ - // $TensorDesc `,`` `$value - // (`l1_hint` $l1_hint^)? - // (`l2_hint` $l2_hint^)? - // (`l3_hint` $l3_hint^)? - // attr-dict `:` `(` qualified(type($TensorDesc)) `,` qualified(type($value)) `)` - // }]; + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode + ); + // Format: xegpu.store_nd %3, %2 {l1_hint = write_back, l2_hint = uncached} + // : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -356,30 +359,26 @@ def XeGPU_PrefetchNDOp : XeGPU_Op<"prefetch_nd", []> { let arguments = (ins XeGPU_TensorDesc: $TensorDesc, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode ); - // let assemblyFormat = [{ - // $TensorDesc - // (`l1_hint` $l1_hint^)? - // (`l2_hint` $l2_hint^)? - // (`l3_hint` $l3_hint^)? - // attr-dict `:` qualified(type($TensorDesc)) - // }]; - + // In format of: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}: + // !xegpu.tensor_desc<8x16xf16> let hasCustomAssemblyFormat = 1; } def XeGPU_DpasOp : XeGPU_Op<"dpas"> { let summary = "performs dpas computation"; let arguments = (ins - XeGPU_Vector3DType : $lhs, - XeGPU_Vector3DType : $rhs, - Optional: $acc + XeGPU_DpasOpType : $lhs, + XeGPU_DpasOpType : $rhs, + Optional: $acc, + DefaultValuedAttr: $mode ); let results = (outs XeGPU_Vector2DType: $result); let assemblyFormat = [{ - $lhs `,` $rhs (`,` $acc^)? attr-dict `:` + $lhs `,` $rhs (`,` $acc^)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result)) }]; @@ -407,23 +406,19 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load"> { let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - VectorOfRankAndType<[1, 2], [I1]>: $mask, + XeGPU_MaskType: $mask, OptionalAttr: $vnni_axis, OptionalAttr: $transpose, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode ); - let results = (outs XeGPU_VectorType: $value); + let results = (outs XeGPU_ValueType: $value); - // let assemblyFormat = [{ - // $TensorDesc - // (`l1_hint` $l1_hint^)? - // (`l2_hint` $l2_hint^)? - // (`l3_hint` $l3_hint^)? 
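+  // For the dpas op defined above, the tests updated in this patch use the
+  // following VC-mode forms (without and with an accumulator):
+  //   %5 = xegpu.dpas %3, %4 {mode = vc}
+  //        : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  //   %11 = xegpu.dpas %9, %10, %arg4 {mode = vc}
+  //        : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>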
- // attr-dict `:` qualified(type($TensorDesc)) `->` qualified(type($value)) - // }]; + // In format of: %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + // : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -432,21 +427,17 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> { let summary = "store a scalar to source[offset]."; let arguments = (ins - XeGPU_VectorType: $value, + XeGPU_ValueType: $value, XeGPU_TensorDesc: $TensorDesc, - VectorOfRankAndType<[1, 2], [I1]>: $mask, + XeGPU_MaskType: $mask, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode ); - // let assemblyFormat = [{ - // $value `,` $TensorDesc, `,` $mask - // (`l1_hint` $l1_hint^)? - // (`l2_hint` $l2_hint^)? - // (`l3_hint` $l3_hint^)? - // attr-dict `:` `(` qualified(type($value)) `,` qualified(type($TensorDesc)) `,` qualified(type($mask)) `)` - // }]; + // Format: %3 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached} + // : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -456,12 +447,13 @@ def XeGPU_UpdateNDOffsetOp : XeGPU_Op<"update_nd_offset", []> { let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - Variadic: $offsets); + Variadic: $offsets, + DefaultValuedAttr: $mode); let results = (outs XeGPU_TensorDesc: $result); let assemblyFormat = [{ - $TensorDesc `,` (`[` $offsets^ `]`)? + $TensorDesc `,` (`[` $offsets^ `]`)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:` qualified(type($TensorDesc)) `->` qualified(type($result)) }]; @@ -474,18 +466,52 @@ def XeGPU_UpdateOffsetOp let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - VectorOfRankAndType<[1], [Index]> : $offsets + XeGPU_OffsetType: $offsets, + DefaultValuedAttr: $mode ); let results = (outs XeGPU_TensorDesc: $result); let assemblyFormat = [{ - $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) + $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? + attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) }]; let hasVerifier = 1; } +def XeGPU_InvokeSIMDOp : XeGPU_Op<"invoke_SIMD", []> { + let summary = "Invoke_SIMD operation"; + let description = [{ + The `xegpu.invoke_SIMD` operation works similar to a direct call to a function. But it is + special to Intel GPU. 
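+    It takes a `callee` symbol, variadic operands and results, and an `argType`
+    attribute (`vector` or `scalar`) indicating how the arguments are passed to
+    the SIMD callee.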
+ }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, + Variadic:$operands, + XeGPU_ArgTypeAttr: $argType); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::TypeRange":$results, + "imex::xegpu::ArgTypeAttr":$argType, CArg<"mlir::ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("argType", argType); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "mlir::StringAttr":$callee, "mlir::TypeRange":$results, + "imex::xegpu::ArgTypeAttr":$argType, CArg<"mlir::ValueRange", "{}">:$operands), [{ + build($_builder, $_state, mlir::SymbolRefAttr::get(callee), results, argType, operands); + }]>, + OpBuilder<(ins "llvm::StringRef":$callee, "mlir::TypeRange":$results, + "imex::xegpu::ArgTypeAttr":$argType, CArg<"mlir::ValueRange", "{}">:$operands), [{ + build($_builder, $_state, mlir::StringAttr::get($_builder.getContext(), callee), + results, argType, operands); + }]>]; + +} + def XeGPU_AtomicRMWOp : XeGPU_Op<"atomic_rmw", []> { let summary = "performa ready-modify-write operation that is free from data races."; @@ -494,11 +520,13 @@ def XeGPU_AtomicRMWOp XeGPU_AtomicRMWKindAttr:$kind, XeGPU_Vector2DType:$value, XeGPU_TensorDesc:$tensorDesc, - XeGPU_Vector1DType:$mask + XeGPU_MaskType:$mask, + DefaultValuedAttr: $mode ); let assemblyFormat = [{ - $kind $value `,` $tensorDesc `,` $mask attr-dict + $kind $value `,` $tensorDesc `,` $mask (`{` `mode` `=` $mode^ `}`)? + attr-dict `:` `(` qualified(type($value)) `,` qualified(type($tensorDesc)) `,` qualified(type($mask)) `)` }]; } @@ -525,16 +553,19 @@ def XeGPU_CreateNbarrierOp I8: $nbarrier_id, I8: $nbarrier_role, I8Attr: $num_producers, - I8Attr: $num_consumers + I8Attr: $num_consumers, + DefaultValuedAttr: $mode ); let results = (outs Builtin_Vector: $result); let assemblyFormat = [{ - $nbarrier_id `,` $nbarrier_role attr-dict `:` - `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` + $nbarrier_id `,` $nbarrier_role + attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` `->` qualified(type($result)) }]; + + // let hasVerifier = 1; } def XeGPU_NbarrierArriveOp diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td b/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td index d3093aa61..cef55c67d 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUTypes.td @@ -26,10 +26,15 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2FNUZ]>; def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>; def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64]>; -def XeGPU_VectorType: VectorOfRankAndType<[1,2,3], [XeGPU_ScalarType]>; -def XeGPU_Vector3DType: VectorOfRankAndType<[3], [XeGPU_ScalarType]>; +def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>; +def XeGPU_OffsetType: AnyTypeOf<[VectorOfRankAndType<[1], [Index]>, Index]>; +def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1,2], [I1]>, I1]>; +def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3], [XeGPU_ScalarType]>, XeGPU_ScalarType]>; + +// def XeGPU_VectorType: VectorOfRankAndType<[1,2,3], [XeGPU_ScalarType]>; +// def XeGPU_Vector3DType: VectorOfRankAndType<[3], [XeGPU_ScalarType]>; def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>; -def XeGPU_Vector1DType: 
VectorOfRankAndType<[1], [XeGPU_ScalarType]>; +// def XeGPU_Vector1DType: VectorOfRankAndType<[1], [XeGPU_ScalarType]>; // common base class for types in XeGPU dialect class XeGPUTypeDef static std::string makeString(llvm::ArrayRef array) { +template static std::string makeString(T array) { std::string buf; buf.clear(); llvm::raw_string_ostream os(buf); diff --git a/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 58330af68..1b245244b 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -77,6 +77,202 @@ static void printShapeAndType(mlir::AsmPrinter &printer, printer << type; } +template +static mlir::LogicalResult parseArrayList(mlir::AsmParser &parser, + llvm::SmallVector &array, + bool parsePrecedenceEqual = false) { + mlir::FailureOr> result; + // Parse literal '=' + if (parsePrecedenceEqual) + if (parser.parseEqual()) return mlir::failure(); + + // Parse literal '[' + if (parser.parseLSquare()) return mlir::failure(); + + result = mlir::FieldParser<::llvm::SmallVector>::parse(parser); + + if (::mlir::failed(result)) return mlir::failure(); + + // Parse literal ']' + if (parser.parseRSquare()) return mlir::failure(); + + array = result.value(); + return mlir::success(); +} + +template +static void printArrayElement(mlir::AsmPrinter &printer, + llvm::StringRef keyword, + llvm::ArrayRef array) { + printer << keyword; + printer << ' ' << "="; + printer << ' ' << "["; + printer.printStrippedAttrOrType(array); + printer << "]"; +} + + +static mlir::LogicalResult parseSgMapAttrElements(mlir::AsmParser &parser, + llvm::SmallVector &mmaBlockSize, + llvm::SmallVector &layout, + llvm::SmallVector &data) { + auto loc = parser.getCurrentLocation(); + auto parseElt = [&]() -> mlir::LogicalResult { + return mlir::AsmParser::KeywordSwitch(parser) + .Case("mma_block_size", [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, mmaBlockSize, true); + }) + .Case("wi_layout", [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, layout, true); + }) + .Case("wi_data", [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, data, true); + }) + .Default([&](llvm::StringRef keyword, llvm::SMLoc) { + llvm::dbgs() << "\n3. Default currLoc: " << llvm::StringRef(parser.getCurrentLocation().getPointer()) << "\n"; + llvm::dbgs() << "\n3. 
keyword: " << keyword << "\n"; + return mlir::failure(); + }); + }; + + if (parser.parseLBrace()) return mlir::failure(); + if (parser.parseCommaSeparatedList(parseElt)) return mlir::failure(); + if (parser.parseRBrace()) return mlir::failure(); + if (mmaBlockSize.size() != 2) { + parser.emitError(loc, "failed to parse SgMapAttr: missing mma_block_size which is to be a `llvm::ArrayRef` with size 2"); + return mlir::failure(); + } + if (layout.size() != 2) { + parser.emitError(loc, "failed to parse SgMapAttr: missing wi_layout which is to be a `llvm::ArrayRef` with size 2"); + return mlir::failure(); + } + if (data.size() != 2) { + parser.emitError(loc, "failed to parse SgMapAttr: missing wi_data which is to be a `llvm::ArrayRef` with size 2"); + return mlir::failure(); + } + return mlir::success(); +} + +static void printSgMapAttrElements(mlir::AsmPrinter &printer, + llvm::ArrayRef mmaBlockSize, + llvm::ArrayRef layout, + llvm::ArrayRef data) { + printer << "{"; + printArrayElement(printer, "mma_block_size", mmaBlockSize); + printer << "," << ' '; + printArrayElement(printer, "wi_layout", layout); + printer << "," << ' '; + printArrayElement(printer, "wi_data", data); + printer << "}"; +} + +static mlir::LogicalResult parseWgMapAttrElements(mlir::AsmParser &parser, + llvm::SmallVector &layout, + llvm::SmallVector &data) { + auto loc = parser.getCurrentLocation(); + auto parseElt = [&]() -> mlir::LogicalResult { + return mlir::AsmParser::KeywordSwitch(parser) + .Case("sg_layout", [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, layout, true); + }) + .Case("sg_data", [&](llvm::StringRef, llvm::SMLoc) { + return parseArrayList(parser, data, true); + }) + .Default([&](llvm::StringRef keyword, llvm::SMLoc) { + return mlir::failure(); + }); + }; + + if (parser.parseLBrace()) return mlir::failure(); + if (parser.parseCommaSeparatedList(parseElt)) return mlir::failure(); + if (parser.parseRBrace()) return mlir::failure(); + if (layout.size() != 2) { + parser.emitError(loc, "failed to parse WgMapAttr: missing sg_layout which is to be a `llvm::ArrayRef` with size 2"); + return mlir::failure(); + } + if (data.size() != 2) { + parser.emitError(loc, "failed to parse WgMapAttr: missing sg_data which is to be a `llvm::ArrayRef` with size 2"); + return mlir::failure(); + } + return mlir::success(); +} + +static void printWgMapAttrElements(mlir::AsmPrinter &printer, + llvm::ArrayRef layout, + llvm::ArrayRef data) { + printer << "{"; + printArrayElement(printer, "sg_layout", layout); + printer << "," << ' '; + printArrayElement(printer, "sg_data", data); + printer << "}"; +} + + +mlir::Attribute XeMapAttr::parse(mlir::AsmParser &parser, mlir::Type type) { + imex::xegpu::WgMapAttr wg; + imex::xegpu::SgMapAttr sg; + // Parse literal '<' + if (parser.parseLess()) return {}; + + auto parseElt = [&]() -> mlir::ParseResult { + mlir::OptionalParseResult result = mlir::AsmParser::KeywordSwitch(parser) + .Case("sg", [&](llvm::StringRef, llvm::SMLoc) { + if (parser.parseEqual()) return mlir::failure(); + llvm::SmallVector mmaBlockSize; + llvm::SmallVector wiLayout; + llvm::SmallVector wiData; + if (mlir::failed(parseSgMapAttrElements(parser, mmaBlockSize, wiLayout, wiData))) + return mlir::failure(); + sg = imex::xegpu::SgMapAttr::get(parser.getContext(), mmaBlockSize, wiLayout, wiData); + return mlir::success(!!sg); + }) + .Case("wg", [&](llvm::StringRef, llvm::SMLoc) { + if (parser.parseEqual()) return mlir::failure(); + llvm::SmallVector sgLayout; + llvm::SmallVector sgData; + 
if(mlir::failed(parseWgMapAttrElements(parser, sgLayout, sgData))) + return mlir::failure(); + wg = imex::xegpu::WgMapAttr::get(parser.getContext(), sgLayout, sgData); + return mlir::success(!!wg); + }) + .Default([&](llvm::StringRef keyword, llvm::SMLoc) { + return std::nullopt; + }); + return result.value(); + }; + + // Parse wg and sg attrs + if (parser.parseCommaSeparatedList(parseElt)) return {}; + + // Parse literal '>' + if (parser.parseGreater()) return {}; + + if(!wg && !sg) { + parser.emitError(parser.getCurrentLocation(), "Expecting at least one of sg and wg attributes.\n"); + return {}; + } + + return XeMapAttr::get(parser.getContext(), wg, sg); +} + +void XeMapAttr::print(mlir::AsmPrinter &printer) const { + bool printSep = false; + printer << "<"; + if (getWg()) { + printer << "wg = "; + printWgMapAttrElements(printer, getWg().getSgLayout(), getWg().getSgData()); + printSep = true; + } + + if (getSg()) { + if (printSep) printer << ", "; + printer << "sg = "; + printSgMapAttrElements(printer, getSg().getMmaBlockSize(), getSg().getWiLayout(), getSg().getWiData()); + } + + printer << ">"; +} + } // namespace xegpu } // namespace imex @@ -85,3 +281,4 @@ static void printShapeAndType(mlir::AsmPrinter &printer, #include #define GET_TYPEDEF_CLASSES #include + diff --git a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index cb22fede9..562964f63 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -44,23 +44,44 @@ const int TN_SIZE = 16; const int TK_SIZE_FOR_D16 = 16; const int TK_SIZE_FOR_D8 = 32; -static bool vnniVerifier(size_t axis, llvm::ArrayRef tdescShape, - llvm::ArrayRef valueShape, - size_t elemTyBitWidth) { - bool isValid = valueShape.size() == tdescShape.size() + 1; - - for (size_t i = 0; i < tdescShape.size(); i++) { - if ((i == axis && valueShape[i] * valueShape.back() != tdescShape[i]) || - (i != axis && valueShape[i] != tdescShape[i])) - isValid = false; - } - - const static size_t xeSIMDLaneBitWidth = 32; - auto vnni_factor = valueShape.back(); +// static bool vnniVerifier(size_t axis, llvm::ArrayRef tdescShape, +// llvm::ArrayRef valueShape, +// size_t elemTyBitWidth) { +// bool isValid = valueShape.size() == tdescShape.size() + 1; + +// for (size_t i = 0; i < tdescShape.size(); i++) { +// if ((i == axis && valueShape[i] * valueShape.back() != tdescShape[i]) || +// (i != axis && valueShape[i] != tdescShape[i])) +// isValid = false; +// } + +// const static size_t xeSIMDLaneBitWidth = 32; +// auto vnni_factor = valueShape.back(); + +// isValid &= vnni_factor == xeSIMDLaneBitWidth / elemTyBitWidth; + +// return isValid; +// } + +static void transpose(llvm::ArrayRef trans, + std::vector &shape) { + std::vector old = shape; + for (size_t i = 0; i < trans.size(); i++) + shape[i] = old[trans[i]]; +}; - isValid &= vnni_factor == xeSIMDLaneBitWidth / elemTyBitWidth; +static void dropOnes(std::vector &array) { + std::vector old = array; + array.clear(); + for(auto v: old) { + if (v != 1) array.push_back(v); + } +}; - return isValid; +static bool isMappingAttr(mlir::Attribute attr) { + return attr && (llvm::isa(attr) + || llvm::isa(attr) + || llvm::isa(attr)); } bool dpasSupportedTypes(mlir::Type type, bool isResult) { @@ -77,29 +98,28 @@ bool dpasSupportedTypes(mlir::Type type, bool isResult) { } } -bool dpasSupportedShapes(DpasOp op) { - - mlir::Type lhsElemType = op.getLhsType().getElementType(); - // TODO: handle dynamic shape cast - auto lhsShape = op.getLhsType().cast().getShape(); - auto rhsShape = 
op.getRhsType().cast().getShape(); - // Retrieve 2D shapes(MxK * KxN) from 3D. Verify this - auto m = lhsShape[0]; - auto k = lhsShape[1] * lhsShape[2]; - auto n = rhsShape[1]; - - if ((lhsElemType.isF16() || lhsElemType.isBF16()) && m <= MAX_TM_SIZE && - n == TN_SIZE && k == TK_SIZE_FOR_D16) { - return true; - } - - if (lhsElemType.isInteger(8) && m <= MAX_TM_SIZE && n == TN_SIZE && - k == TK_SIZE_FOR_D8) { - return true; - } - - return false; -} +// bool dpasSupportedShapes(DpasOp op) { +// mlir::Type lhsElemType = op.getLhsType().getElementType(); +// // TODO: handle dynamic shape cast +// auto lhsShape = op.getLhsType().cast().getShape(); +// auto rhsShape = op.getRhsType().cast().getShape(); +// // Retrieve 2D shapes(MxK * KxN) from 3D. Verify this +// auto m = lhsShape[0]; +// auto k = lhsShape[1] * lhsShape[2]; +// auto n = rhsShape[1]; + +// if ((lhsElemType.isF16() || lhsElemType.isBF16()) && m <= MAX_TM_SIZE && +// n == TN_SIZE && k == TK_SIZE_FOR_D16) { +// return true; +// } + +// if (lhsElemType.isInteger(8) && m <= MAX_TM_SIZE && n == TN_SIZE && +// k == TK_SIZE_FOR_D8) { +// return true; +// } + +// return false; +// } template static mlir::ParseResult parseCustomEnumAttr(mlir::OpAsmParser &parser, @@ -179,6 +199,10 @@ parseOptionalAttrDict(mlir::OpAsmParser &parser, mlir::OperationState &result, parser, result, nameId); } + if (nameId == "mode") { + return parseCustomEnumAttr(parser, result, nameId); + } + if (nameId == "chunk_size_per_lane" || nameId == "vnni_axis") return parseBoolAndIntegerAttr(parser, result, nameId); @@ -271,7 +295,7 @@ mlir::ParseResult CreateNdDescOp::parse(mlir::OpAsmParser &parser, return ::mlir::failure(); } - if (parseOptionalAttrDict(parser, result, {"memory_scope", "boundary_check"})) + if (parseOptionalAttrDict(parser, result, {"memory_scope", "boundary_check", "mode"})) return mlir::failure(); if (parser.parseColon()) @@ -330,8 +354,9 @@ void CreateNdDescOp::print(::mlir::OpAsmPrinter &printer) { printer << "]"; } - // printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << "{"; + printer << "mode = " << getMode(); + printer << "," << ' '; printer << "memory_scope = " << getMemoryScope(); printer << "," << ' '; printer << "boundary_check = " << getBoundaryCheck(); @@ -346,9 +371,13 @@ void CreateNdDescOp::print(::mlir::OpAsmPrinter &printer) { } mlir::LogicalResult CreateNdDescOp::verify() { - LLVM_DEBUG(llvm::dbgs() << "Op: " << getValueAsString(*this) - << "\n\tstatic offsets: " - << makeString(getStaticOffsets()) << "\n\n"); + auto mode = getMode(); + auto encoding = getTensorDesc().getType().getEncoding(); + + if (mode == imex::xegpu::Mode::SIMT && !isMappingAttr(encoding)) { + return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for SIMT mode operators.\n"); + } + // it is invalid to have both dynamic and static shape if (!(hasDynamicShape() ^ hasStaticShape())) return emitOpError("It is invalid to have both or none of dynamic shape " @@ -383,7 +412,7 @@ mlir::ParseResult CreateDescOp::parse(mlir::OpAsmParser &parser, return mlir::failure(); if (parseOptionalAttrDict(parser, result, - {"memory_scope", "chunk_size_per_lane"})) + {"memory_scope", "chunk_size_per_lane", "mode"})) return mlir::failure(); if (parser.parseColon()) @@ -426,6 +455,8 @@ void CreateDescOp::print(::mlir::OpAsmPrinter &printer) { printer << getOffsets(); printer << ' ' << "{"; + printer << "mode = " << getMode(); + printer << "," << ' '; printer << "memory_scope = " << getMemoryScope(); printer << "," << ' '; printer << 
"chunk_size_per_lane = " << getChunkSizePerLane(); @@ -443,29 +474,36 @@ void CreateDescOp::print(::mlir::OpAsmPrinter &printer) { } mlir::LogicalResult CreateDescOp::verify() { - auto offsetType = getOffsets().getType(); - auto tdescType = getTensorDesc().getType(); - auto chunkSize = getChunkSizePerLane(); - - auto offsetShape = offsetType.getShape(); - auto tdescShape = tdescType.getShape(); - if (getRankOf(getSource()) > 1) return emitOpError( "Expecting the source is a 1D memref or pointer (uint64_t)."); - if (offsetShape.size() != 1) - return emitOpError("Expecting the offset is a 1D vector."); + std::vector shape; + + auto offsetTy = getOffsets().getType(); + auto tdescTy = getTensorDesc().getType(); + auto chunkSize = getChunkSizePerLane(); + + auto tdescShape = tdescTy.getShape(); + + if (llvm::isa(offsetTy)) { + shape = llvm::dyn_cast(offsetTy).getShape().vec(); + if (shape.size() != 1) + return emitOpError("Expecting the offset is either a 1D vector (for VC) or scalar (for SIMT)."); + } + + if (offsetTy.isIndex() || chunkSize != 1) { + shape.push_back(chunkSize); + } - if (offsetShape != tdescShape && (offsetShape != tdescShape.drop_back() || - tdescShape.back() != chunkSize)) { + if (shape != tdescShape.vec()) { return emitOpError("Expecting dimensions of offsets is the same as the " "tensor descriptor, or one less than."); } - if (!tdescType.getEncoding()) - return emitOpError("Expecting the presence of scattered attribute for " - "scattered tensor descriptor."); + if (!tdescTy.getEncoding()) + return emitOpError("Expecting the presence of scattered attribute for tensor descriptor."); + return mlir::success(); } @@ -480,7 +518,7 @@ mlir::ParseResult LoadNDOp::parse(::mlir::OpAsmParser &parser, if (parseOptionalAttrDict( parser, result, - {"vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) + {"mode", "vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) return mlir::failure(); if (parser.parseColon()) @@ -512,22 +550,20 @@ void LoadNDOp::print(::mlir::OpAsmPrinter &printer) { printer << getTensorDesc(); if ((*this)->getAttrs().size()) { - bool printSep = false; printer << ' ' << "{"; + printer << "mode = " << getMode(); if (getVnniAxisAttr()) { + printer << "," << ' '; printer << "vnni_axis = " << getVnniAxis().value(); - printSep = true; } if (getTransposeAttr()) { - if (printSep) - printer << ", "; + printer << "," << ' '; printer << "transpose = "; getTransposeAttr().print(printer); - printSep = true; } - printCacheHintAttrs(printer, *this, printSep); + printCacheHintAttrs(printer, *this, true); printer << "}"; } @@ -539,76 +575,138 @@ void LoadNDOp::print(::mlir::OpAsmPrinter &printer) { printer << getValue().getType(); } +// mlir::LogicalResult CreateNbarrierOp::verify() { +// llvm::dbgs() << "\nOp: " << getValueAsString(*this) +// << "\n\tnum producers: " << getNumProducers() +// << "\n\tnum consumers: " << getNumConsumers() +// << "\n\n"; +// return mlir::success(); +// } + mlir::LogicalResult LoadNDOp::verify() { - auto input = getTensorDesc(); + auto tdescTy = getTensorDesc().getType(); + auto valueTy = llvm::dyn_cast(getValue().getType()); - auto tdescShape = getTensorDesc().getType().getShape().vec(); - auto valueShape = getValue().getType().getShape().vec(); + if (tdescTy.getRank() != 2) + return emitOpError("The TensorDesc for LoadNDOp should be a 2D TensorDesc."); - auto tdescElemTy = getTensorDesc().getType().getElementType(); - auto valueElemTy = getValue().getType().getElementType(); + if (!valueTy) + return emitOpError("Invalid result, it 
should be a VectorType.\n"); - if (!llvm::isa(input.getType()) || - input.getType().getRank() != 2) { - return emitOpError("The input to LoadNDOp should be a 2D TensorDesc."); - } + auto tdescElemTy = tdescTy.getElementType(); + auto valueElemTy = valueTy.getElementType(); if (tdescElemTy != valueElemTy) - return emitOpError( - "Value should have the same element type as TensorDesc."); + return emitOpError("Value should have the same element type as TensorDesc."); + + { // TODO: The following logic are architecture dependent, pending to be moved out + auto width = tdescTy.getShape()[1]; + auto height = tdescTy.getShape()[0]; + auto elemTyByteWidth = tdescElemTy.getIntOrFloatBitWidth() / 8; + + if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS || + width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS || + (width * elemTyByteWidth) % 4 != 0) { + return emitOpError("Invalid width size for 2D block load. \ + The specification expects the value to \ + be in range [1, 64], and The the total \ + data size (width * elemTyBytes) to be multiple of 4.\n"); + } - if (!tdescElemTy.isIntOrFloat()) { - // FIXME: currently only int and float type are supported for estimating - // size info improve it to make it more robust if neccessary - return emitOpError( - "Currently only IntType or FloatType are supported for Load2DOp."); + if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS || + height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) { + return emitOpError( + "Invalid height size for 2D block load. The specification expects the " + "value to be in range [1, 32].\n"); + } } - auto width = input.getType().getShape()[1]; - auto height = input.getType().getShape()[0]; - auto elemTyByteWidth = tdescElemTy.getIntOrFloatBitWidth() / 8; - - if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS || - width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS || - (width * elemTyByteWidth) % 4 != 0) { - return emitOpError("Invalid width size for 2D block load. \ - The specification expects the value to \ - be in range [1, 64], and The the total \ - data size (width * elemTyBytes) to be multiple of 4.\n"); - } - if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS || - height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) { - return emitOpError( - "Invalid height size for 2D block load. The specification expects the " - "value to be in range [1, 32].\n"); + auto mode = getMode(); + auto tdescShape = tdescTy.getShape().vec(); + auto valueShape = valueTy.getShape().vec(); + + if (mode == imex::xegpu::Mode::SIMT) { + imex::xegpu::WgMapAttr wgMap; + imex::xegpu::SgMapAttr sgMap; + + auto encoding = tdescTy.getEncoding(); + if (!isMappingAttr(encoding)) { + return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for SIMT mode operators.\n"); + } + + if (auto xeMapAttr = llvm::dyn_cast(encoding)) { + wgMap = xeMapAttr.getWg(); + sgMap = xeMapAttr.getSg(); + } else { + wgMap = llvm::dyn_cast(encoding); + sgMap = llvm::dyn_cast(encoding); + } + + if (wgMap) { + auto sgData = wgMap.getSgData(); + auto sgLayout = wgMap.getSgLayout(); + for(size_t i = 0; i < sgData.size(); i++) { + if (tdescShape[i] % sgLayout[i] != 0 || + tdescShape[i] % sgData[i] != 0 || + tdescShape[i] % sgData[i] != 0) + return emitOpError("Invalid WgMapAttr. 
It should meet the following conditions: " + "tdescShape[i] % sgLayout[i] == 0 && " + "tdescShape[i] % sgData[i] == 0 && " + "tdescShape[i] % sgData[i] == 0"); + tdescShape[i] /= sgLayout[i]; + } + // dropOnes(tdescShape); + } + + if (sgMap) { + auto blockSize = sgMap.getMmaBlockSize(); + auto wiLayout = sgMap.getWiLayout(); + auto wiData = sgMap.getWiData(); + for (size_t i = 0; i < blockSize.size(); i++) { + if (tdescShape[i] % blockSize[i] != 0 || + blockSize[i] % wiLayout[i] != 0 || + blockSize[i] % wiData[i] != 0 || + blockSize[i] % (wiLayout[i] * wiData[i]) != 0) { + return emitOpError("Invalid SgMapAttr. It should meet the following conditions: " + "blockSize[i] % wiLayout[i] == 0 && " + "blockSize[i] % wiData[i] == 0 && " + "blockSize[i] % wiData[i] == 0 && " + "tdescShape[i] % blockSize[i] == 0"); + + } + auto tmp = blockSize[i]/wiLayout[i]; + tdescShape[i] /= blockSize[i]; + tdescShape[i] *= tmp; + } + } } if (getTranspose()) { - auto dim0 = getTranspose().value()[0]; - auto dim1 = getTranspose().value()[1]; - auto tmp = valueShape[dim0]; - valueShape[dim0] = valueShape[dim1]; - valueShape[dim1] = tmp; + auto trans = getTranspose().value(); + if (tdescShape.size() >= trans.size()) + transpose(trans, tdescShape); + else emitWarning("Invalid transpose attr. It is ignored."); } - if (!getVnniAxis()) { - if (valueShape != tdescShape) - return emitOpError("Value should have the same shape as TensorDesc when " - "vnni is not enabled."); - } else { + if (getVnniAxis()) { auto axis = getVnniAxis().value(); - auto bits = getTensorDesc().getType().getElementTypeBitWidth(); - if (!vnniVerifier(axis, tdescShape, valueShape, bits)) - return emitOpError("Invalid vnni transform. When vnni is enabled, value " - "should have one more" - "dimention than the TensorDesc, but having same " - "number of data elements." - "Also, vnni factor should be calculated as " - "simd_lane_width / elementTypeBitWidth." - "For element type having more than 32 bits, vnni " - "shouldn't be used.\n"); + auto vnni_factor = valueShape.back(); + tdescShape[axis] /= vnni_factor; + tdescShape.push_back(vnni_factor); + dropOnes(tdescShape); } + if (tdescShape != valueShape) + return emitOpError("Result shape doesn't match TensorDesc shape." + "The expected shape is " + makeString(tdescShape) + ", while " + "the given shape is " + makeString(valueShape) + ". " + "In VC mode, when VNNI is not enabled, the result should have the same " + "shape (or transposed shape if transpose is also enabled) as TensorDesc; " + "when VNNI is enabled, the result should have one more dimention than the " + "TensorDesc, with last dimention having vnni factor, but having same number " + "of total data elements. The vnni factor are typically calculated as simd_lane_width / elementTypeBitWidth. " + "For element type having more than 32 bits, vnni shouldn't be used. 
" + "In SIMT mode, the shape is derived from the mapping attributes.\n"); return mlir::success(); } @@ -631,7 +729,7 @@ ::mlir::ParseResult StoreNDOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(TensorDescRawOperands[0])) return mlir::failure(); - if (parseOptionalAttrDict(parser, result, {"l1_hint", "l2_hint", "l3_hint"}, + if (parseOptionalAttrDict(parser, result, {"mode", "l1_hint", "l2_hint", "l3_hint"}, true)) return mlir::failure(); @@ -669,9 +767,9 @@ void StoreNDOp::print(::mlir::OpAsmPrinter &printer) { printer << ' '; printer << getTensorDesc(); if ((*this)->getAttrs().size()) { - bool printSep = false; printer << ' ' << "{"; - printCacheHintAttrs(printer, *this, printSep); + printer << "mode = " << getMode(); + printCacheHintAttrs(printer, *this, true); printer << "}"; } printer << ' ' << ":"; @@ -683,47 +781,97 @@ void StoreNDOp::print(::mlir::OpAsmPrinter &printer) { } mlir::LogicalResult StoreNDOp::verify() { - auto dst = getTensorDesc(); // Tile - auto val = getValue(); // Vector + auto dstTy = getTensorDesc().getType(); // Tile + auto valTy = llvm::dyn_cast(getValue().getType()); // Vector - if (dst.getType().getShape() != val.getType().getShape()) { - return emitOpError( - "The value (vector) shape doesn't match the memory (dst) shape.\n"); - } - auto dstElemTy = dst.getType().getElementType(); - auto valElemTy = val.getType().getElementType(); + if (dstTy.getRank() != 2) + return emitOpError("The TensorDesc for StoreNdOp should be a 2D TensorDesc."); + + if (!valTy) + return emitOpError("Invalid value operand, it should be a VectorType.\n"); + + auto dstElemTy = dstTy.getElementType(); + auto valElemTy = valTy.getElementType(); if (dstElemTy != valElemTy) { return emitOpError("The elem type of value (vector) shape doesn't match " "the elem type of memory (dst) shape.\n"); } - if (!dstElemTy.isIntOrFloat()) { - // FIXME: currently only int and float type are supported for estimating - // size info improve it to make it more robust if neccessary - return emitOpError( - "Currently only IntType or FloatType are supported for Store2DOp."); - } - auto width = dst.getType().getShape()[1]; - auto height = dst.getType().getShape()[0]; - auto elemTyByteWidth = dstElemTy.getIntOrFloatBitWidth() / 8; - - if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS || - width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS || - (width * elemTyByteWidth) % 4 != 0) { - return emitOpError("Invalid width size for 2D block write. \ - The specification expects the value to \ - be in range [1, 64], and The the total \ - data size (width * elemTyBytes) to be multiple of 4.\n"); + { // TODO: The following logic are architecture dependent, pending to be moved out + auto width = dstTy.getShape()[1]; + auto height = dstTy.getShape()[0]; + auto elemTyByteWidth = dstElemTy.getIntOrFloatBitWidth() / 8; + if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS || + width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS || + (width * elemTyByteWidth) % 4 != 0) { + return emitOpError("Invalid width size for 2D block write. \ + The specification expects the value to \ + be in range [1, 64], and The the total \ + data size (width * elemTyBytes) to be multiple of 4.\n"); + } + + if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS || + height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) { + return emitOpError("Invalid height size for 2D block write. 
The specification" + "expects the value to be in range [1, 32].\n"); + } } - if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS || - height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) { - return emitOpError( - "Invalid height size for 2D block write. The specification expects the " - "value to be in range [1, 32].\n"); + auto mode = getMode(); + + if (mode == imex::xegpu::Mode::VC) { // for VC mode, no attr attached + if (dstTy.getShape() != valTy.getShape()) + return emitOpError("In VC mode, the value (vector) shape doesn't match the memory (dst) shape.\n"); + } else { + auto encoding = dstTy.getEncoding(); + if (!isMappingAttr(encoding)) { + return emitOpError("Expecting either SgMap, WgMap or XeMap attribute for SIMT mode operators.\n"); + } + + imex::xegpu::WgMapAttr wgMap; + imex::xegpu::SgMapAttr sgMap; + std::vector shape = dstTy.getShape().vec(); + + if (auto xeMapAttr = llvm::dyn_cast(encoding)) { + wgMap = xeMapAttr.getWg(); + sgMap = xeMapAttr.getSg(); + } else { + wgMap = llvm::dyn_cast(encoding); + sgMap = llvm::dyn_cast(encoding); + } + + if (wgMap) { + auto sgData = wgMap.getSgData(); + auto sgLayout = wgMap.getSgLayout(); + for(size_t i = 0; i < sgData.size(); i++) { + assert(shape[i] % sgLayout[i] == 0); + assert(shape[i] % sgData[i] == 0); + assert(shape[i] % (sgLayout[i] * sgData[i]) == 0); + shape[i] /= sgLayout[i]; + } + } + + if (sgMap) { + auto blockSize = sgMap.getMmaBlockSize(); + auto wiLayout = sgMap.getWiLayout(); + auto wiData = sgMap.getWiData(); + for (size_t i = 0; i < shape.size(); i++) { + assert(blockSize[i] % (wiLayout[i] * wiData[i]) == 0); + assert(blockSize[i] % wiLayout[i] == 0); + assert(blockSize[i] % wiData[i] == 0); + assert(shape[i] % blockSize[i] == 0); + auto tmp = blockSize[i]/wiLayout[i]; + shape[i] /= blockSize[i]; + shape[i] *= tmp; + } + } + + if (shape != valTy.getShape().vec()) + return emitOpError("In SIMT mode, the value (vector) shape doesn't match the memory" + "(dst) shape as derived according to the mapping rule.\n"); } return mlir::success(); @@ -742,7 +890,7 @@ ::mlir::ParseResult PrefetchNDOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(TensorDescRawOperands[0])) return ::mlir::failure(); - if (parseOptionalAttrDict(parser, result, {"l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDict(parser, result, {"mode", "l1_hint", "l2_hint", "l3_hint"})) return mlir::failure(); if (parser.parseColon()) @@ -761,9 +909,9 @@ void PrefetchNDOp::print(::mlir::OpAsmPrinter &printer) { printer << getTensorDesc(); // printer.printOptionalAttrDict((*this)->getAttrs()); if ((*this)->getAttrs().size()) { - bool printSep = false; printer << ' ' << "{"; - printCacheHintAttrs(printer, *this, printSep); + printer << "mode = " << getMode(); + printCacheHintAttrs(printer, *this, true); printer << "}"; } @@ -780,10 +928,12 @@ mlir::LogicalResult DpasOp::verify() { mlir::Type rhsElemType = getRhsType().getElementType(); mlir::Type resultElemType = getResultType().getElementType(); + // TODO: this is hardware specific, need to be moved out. if (!dpasSupportedTypes(lhsElemType, 0)) { return emitOpError("Unsupported src datatype for dpas op"); } + // TODO: this is hardware specific, need to be moved out. 
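+  // Note: with the SIMT extension the shape checks here are intentionally relaxed.
+  // The verifier now only requires supported element types, matching lhs/rhs ranks
+  // (2-D and 3-D operands are both allowed by XeGPU_DpasOpType), and, when present,
+  // an accumulator whose type equals the result type exactly.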
if (!dpasSupportedTypes(resultElemType, 1)) { return emitOpError("Unsupported result datatype for dpas op"); } @@ -793,25 +943,21 @@ mlir::LogicalResult DpasOp::verify() { } if (getAcc()) { - mlir::Type accElemType = getAccType().getElementType(); - if (accElemType != resultElemType) { - return emitOpError( - "Accumulator and Result element type does not match for dpas op"); - } + if (getAccType() != getResultType()) + return emitOpError("Accumulator and Result for dpas op should have the same type (both shape and element type)."); } - if (!dpasSupportedShapes(*this)) { - return emitOpError("Incorrect shapes for dpas op"); - } + // TODO: SIMT makes it harder to check semantic errors for DPAS op. + // the only thing we can check seems to be vnni factor. But it + // depends on hardware though. + // if (!dpasSupportedShapes(*this)) { + // return emitOpError("Incorrect shapes for dpas op"); + // } if (lhsRank != rhsRank) { return emitOpError("lhs and rhs rank does not match for dpas op"); } - if (lhsRank < 3) { - return emitOpError("dpas op requires 3d vector. Rank is not 3"); - } - return mlir::success(); } @@ -844,9 +990,8 @@ ::mlir::ParseResult LoadGatherOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(maskRawOperands[0])) return mlir::failure(); - if (parseOptionalAttrDict( - parser, result, - {"vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDict(parser, result, + {"mode", "vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) return mlir::failure(); if (parser.parseColon()) @@ -886,22 +1031,19 @@ void LoadGatherOp::print(mlir::OpAsmPrinter &printer) { printer << ' '; printer << getMask(); if ((*this)->getAttrs().size()) { - bool printSep = false; printer << ' ' << "{"; - if (getVnniAxisAttr()) { - printer << "vnni_axis = " << getVnniAxis().value(); - printSep = true; - } + + printer << "mode = " << getMode(); + if (getVnniAxisAttr()) + printer << ", vnni_axis = " << getVnniAxis().value(); + if (getTransposeAttr()) { - if (printSep) - printer << ", "; - printer << "transpose = "; + printer << ", transpose = "; getTransposeAttr().print(printer); - printSep = true; } - printCacheHintAttrs(printer, *this, printSep); + printCacheHintAttrs(printer, *this, true); printer << "}"; } @@ -918,47 +1060,67 @@ void LoadGatherOp::print(mlir::OpAsmPrinter &printer) { } mlir::LogicalResult LoadGatherOp::verify() { - // length of the offsets vector must match the dim-0 of the tensor descriptor - auto tdescShape = getTensorDesc().getType().getShape().vec(); - auto maskShape = getMask().getType().getShape().vec(); - auto valueShape = getValue().getType().getShape().vec(); + auto tdescTy = getTensorDesc().getType(); + auto maskTy = getMask().getType(); + auto valueTy = getValue().getType(); + + auto getElementType = [&](mlir::Type type) -> mlir::Type { + if (type.isIntOrIndexOrFloat()) + return type; + else if (llvm::isa(type)) + return llvm::dyn_cast(type).getElementType(); + else if (llvm::isa(type)) + return llvm::dyn_cast(type).getElementType(); + assert(0 && "Unreachable !!!"); + return type; + }; + + auto tdescElemTy = getElementType(tdescTy); + auto valueElemTy = getElementType(valueTy); + if (tdescElemTy != valueElemTy) + return emitOpError("Value should have the same element type as TensorDesc."); + + auto getShape = [&](mlir::Type type, std::vector &shape) -> void { + if (type.isIntOrIndexOrFloat()) + shape.push_back(1); + else if (llvm::isa(type)) + shape = llvm::dyn_cast(type).getShape().vec(); + else + assert(0 && "Unreachable !!!"); + }; - 
auto tdescElemTy = getTensorDesc().getType().getElementType(); - auto valueElemTy = getValue().getType().getElementType(); + std::vector maskShape, valueShape; + getShape(maskTy, maskShape); + getShape(valueTy, valueShape); + auto tdescShape = tdescTy.getShape().vec(); if (tdescShape != maskShape) return emitOpError("Mask should have the same shape as TensorDesc."); - if (tdescElemTy != valueElemTy) - return emitOpError( - "Value should have the same element type as TensorDesc."); - if (getTranspose()) { - auto dim0 = getTranspose().value()[0]; - auto dim1 = getTranspose().value()[1]; - auto tmp = valueShape[dim0]; - valueShape[dim0] = valueShape[dim1]; - valueShape[dim1] = tmp; + auto trans = getTranspose().value(); + if (tdescShape.size() >= trans.size()) + transpose(trans, tdescShape); + else emitWarning("Invalid transpose attr. It is ignored."); } - if (!getVnniAxis()) { - if (valueShape != tdescShape) - return emitOpError("Value should have the same shape as TensorDesc when " - "vnni is not enabled."); - } else { + if (getVnniAxis()) { auto axis = getVnniAxis().value(); - auto bits = getTensorDesc().getType().getElementTypeBitWidth(); - if (!vnniVerifier(axis, tdescShape, valueShape, bits)) - return emitOpError("Invalid vnni transform. When vnni is enabled, value " - "should have one more" - "dimention than the TensorDesc, but having same " - "number of data elements." - "Also, vnni factor should be calculated as " - "simd_lane_width / elementTypeBitWidth." - "For element type having more than 32 bits, vnni " - "shouldn't be used.\n"); + auto vnni_factor = valueShape.back(); + tdescShape[axis] /= vnni_factor; + tdescShape.push_back(vnni_factor); + dropOnes(tdescShape); } + if (valueShape != tdescShape) + return emitOpError("Result shape doesn't match TensorDesc shape. when VNNI is not enabled," + "the result should have the same shape (or transposed shape if transpose" + "is also enabled) as TensorDesc. When VNNI is enabled, the result should" + "have one more dimention than the TensorDesc, with last dimention having" + "vnni factor, but having same number of total data elements. The vnni " + "factor are typically calculated as simd_lane_width / elementTypeBitWidth." 
+ "For element type having more than 32 bits, vnni shouldn't be used.\n"); + return ::mlir::success(); } @@ -1006,16 +1168,13 @@ ::mlir::ParseResult StoreScatterOp::parse(::mlir::OpAsmParser &parser, if (parser.parseOperand(maskRawOperands[0])) return mlir::failure(); - if (parseOptionalAttrDict(parser, result, {"l1_hint", "l2_hint", "l3_hint"}, + if (parseOptionalAttrDict(parser, result, {"mode", "l1_hint", "l2_hint", "l3_hint"}, true)) return mlir::failure(); if (parser.parseColon()) return ::mlir::failure(); - // if (parser.parseLParen()) - // return ::mlir::failure(); - if (parser.parseType(valueRawTypes[0])) return ::mlir::failure(); @@ -1031,9 +1190,6 @@ ::mlir::ParseResult StoreScatterOp::parse(::mlir::OpAsmParser &parser, if (parser.parseType(maskRawTypes[0])) return ::mlir::failure(); - // if (parser.parseRParen()) - // return ::mlir::failure(); - if (parser.resolveOperands(valueOperands, valueTypes, valueOperandsLoc, result.operands)) return ::mlir::failure(); @@ -1058,9 +1214,9 @@ void StoreScatterOp::print(::mlir::OpAsmPrinter &printer) { printer << ' '; printer << getMask(); if ((*this)->getAttrs().size()) { - bool printSep = false; printer << ' ' << "{"; - printCacheHintAttrs(printer, *this, printSep); + printer << "mode = " << getMode(); + printCacheHintAttrs(printer, *this, true); printer << "}"; } @@ -1076,19 +1232,63 @@ void StoreScatterOp::print(::mlir::OpAsmPrinter &printer) { } ::mlir::LogicalResult StoreScatterOp::verify() { - // length of the offsets vector must match the dim-0 of the tensor descriptor - if (getTensorDesc().getType().getShape() != getMask().getType().getShape()) { - return emitOpError("Mask should have the same shape as TensorDesc."); + auto valueTy = getValue().getType(); + auto tdescTy = getTensorDesc().getType(); + auto maskTy = getMask().getType(); + + std::vector valueShape, maskShape; + auto getShape = [&](mlir::Type type, std::vector &shape) -> void { + if (type.isIntOrIndexOrFloat()) + shape.push_back(1); + else if (llvm::isa(type)) + shape = llvm::dyn_cast(type).getShape().vec(); + else + assert(0 && "Unreachable !!!"); + }; + + getShape(valueTy, valueShape); + getShape(maskTy, maskShape); + + if (tdescTy.getShape().vec() != maskShape || valueShape != maskShape ) { + return emitOpError("Mask and value should have the same shape/size as TensorDesc." + "Mask and Value can be scalar if TensorDesc is in form of TensorDesc<1xf16>."); } return ::mlir::success(); } ::mlir::LogicalResult UpdateOffsetOp::verify() { - // length of the offsets vector must match the dim-0 of the tensor descriptor - if (getTensorDesc().getType().getShape()[0] != - getOffsets().getType().getShape()[0]) { - return emitOpError("invalid number of offsets."); + auto srcTy = getTensorDesc().getType(); + auto offTy = getOffsets().getType(); + auto resTy = getResult().getType(); + + if (srcTy != resTy) + return emitOpError("The result should have the same type" + "(shape and encoding attribute) as the input TensorDesc."); + + auto shape = srcTy.getShape(); + auto encoding = srcTy.getEncoding(); + + if (!encoding || !llvm::isa(encoding)) { + return emitOpError("Invalid TensorDesc, it should have a scattered attribute."); } + + // For VC mode with chunkSize > 1. For chunkSize == 1, it is hard to distinguish + // between VC and SIMT mode by only looking at updateOffsetOp itself. So current + // verifier skipped these two cases. 
+ if (shape.size() == 2) { + if (!llvm::isa(offTy)) + return emitOpError("Based on TensorDesc shape, it is an VC tensor descriptor, " + "in which the offset should be an 1D vector."); + + auto vecTy = llvm::dyn_cast(offTy); + if (vecTy.getRank() != 1) + return emitOpError("The index should be an 1D vector Type for VC mode tensor descriptor."); + + if (shape[0] != vecTy.getShape()[0]) + return emitOpError("For VC Mode TensorDesc. The offset should have same" + "length as the dim-0 of TensorDesc."); + } + return ::mlir::success(); } diff --git a/test/Conversion/XeGPUToVC/gemm_1024x1024xf16.runnable.mlir b/test/Conversion/XeGPUToVC/gemm_1024x1024xf16.runnable.mlir index d24f3bcd5..c5422da9b 100644 --- a/test/Conversion/XeGPUToVC/gemm_1024x1024xf16.runnable.mlir +++ b/test/Conversion/XeGPUToVC/gemm_1024x1024xf16.runnable.mlir @@ -34,18 +34,18 @@ module @gemm attributes {gpu.container_module} { %1 = gpu.block_id y %2 = arith.muli %0, %c8 : index %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = xegpu.create_nd_tdesc %arg2[%2, %3] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %5 = xegpu.load_nd %4 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // each work-group has 1 subgroup. the subgroup caculates a [8x16 = 8x1024 * 1024x16] block %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> - %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> - %9 = xegpu.load_nd %7 {vnni_axis = 1}: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %10 = xegpu.load_nd %8 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %11 = xegpu.dpas %9, %10, %arg4 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %9 = xegpu.load_nd %7 {mode = vc, vnni_axis = 1}: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %10 = xegpu.load_nd %8 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %11 = xegpu.dpas %9, %10, %arg4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> scf.yield %11 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %6, %4 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } diff --git a/test/Conversion/XeGPUToVC/gemm_basic.mlir b/test/Conversion/XeGPUToVC/gemm_basic.mlir index a168cc791..47716f446 100644 --- a/test/Conversion/XeGPUToVC/gemm_basic.mlir +++ b/test/Conversion/XeGPUToVC/gemm_basic.mlir @@ -29,16 +29,16 @@ module @gemm attributes {gpu.container_module} { // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32 // CHECK: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32 // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32 - %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> - %2 = xegpu.create_nd_tdesc 
%arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16> - xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16> + %0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<16x16xf16> - %3 = xegpu.load_nd %0 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %4 = xegpu.load_nd %1 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %5 = xegpu.dpas %3, %4 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - xegpu.store_nd %5, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %5 = xegpu.dpas %3, %4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + xegpu.store_nd %5, %2 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } diff --git a/test/Dialect/XeGPU/IR/XeGPUOps.mlir b/test/Dialect/XeGPU/IR/XeGPUOps.mlir index e0f837e08..98bb5b15e 100644 --- a/test/Dialect/XeGPU/IR/XeGPUOps.mlir +++ b/test/Dialect/XeGPU/IR/XeGPUOps.mlir @@ -4,103 +4,103 @@ // Verify the generic form can be parsed. // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_create_nd_tdesc({{.*}}) { -func.func @test_create_nd_tdesc(%src: memref<24x32xf32>) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc({{.*}}) { +func.func @test_create_nd_tdesc_vc(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.create_nd_tdesc %src[2, 4] + %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_tdesc({{.*}}) { -func.func @test_create_tdesc(%src: ui64, %offsets : vector<16 x index>) { +// CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) { +func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = slm, chunk_size_per_lane = 2} + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {memory_scope = slm, chunk_size_per_lane = 2} + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> return } -// CHECK-LABEL: func @test_load_nd({{.*}}) { -func.func @test_load_nd(%src: memref<24x32xf16>, %x : index, %y : index) { +// CHECK-LABEL: func @test_load_nd_vc({{.*}}) { +func.func @test_load_nd_vc(%src: memref<24x32xf16>, %x : index, %y : index) { // CHECK: 
xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg1, %arg2] // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y] + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 0, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> - %2 = xegpu.load_nd %1 {vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> return } -// CHECK-LABEL: func @test_store_nd({{.*}}) { -func.func @test_store_nd(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) { +// CHECK-LABEL: func @test_store_nd_vc({{.*}}) { +func.func @test_store_nd_vc(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] + %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.load_nd - // CHECK-SAME: {l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %3 = xegpu.load_nd %1 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 {mode=vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> // CHECK: xegpu.store_nd - // CHECK-SAME: {l1_hint = write_back, l2_hint = uncached} + // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} // CHECK-SAME: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %3, %2 {l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> return } -// CHECK-LABEL: func @test_dpas({{.*}}) { -func.func @test_dpas(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) { +// CHECK-LABEL: func @test_dpas_vc({{.*}}) { +func.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) { // CHECK: xegpu.dpas // CHECK-SAME: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - %1 = xegpu.dpas %a, %b: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %1 = xegpu.dpas %a, %b {mode = vc}: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> return } -// CHECK-LABEL: func @test_update_nd_offset({{.*}}) { -func.func @test_update_nd_offset(%src: memref<24x32xf32>) { +// CHECK-LABEL: func @test_update_nd_offset_vc({{.*}}) { +func.func @test_update_nd_offset_vc(%src: 
memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.load_nd - // CHECK-SAME: {l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - %2 = xegpu.load_nd %1 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // CHECK: xegpu.update_nd_offset // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> @@ -109,14 +109,14 @@ func.func @test_update_nd_offset(%src: memref<24x32xf32>) { return } -// CHECK-LABEL: func @test_prefetch_nd({{.*}}) { -func.func @test_prefetch_nd(%src: memref<24x32xf16>, %x : index, %y : index) { +// CHECK-LABEL: func @test_prefetch_nd_vc({{.*}}) { +func.func @test_prefetch_nd_vc(%src: memref<24x32xf16>, %x : index, %y : index) { // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.prefetch_nd - // CHECK-SAME: {l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> - xegpu.prefetch_nd %1 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> return } diff --git a/test/Dialect/XeGPU/IR/atomic_rmw.mlir b/test/Dialect/XeGPU/IR/atomic_rmw.mlir index 35f4b3cba..9de90500b 100644 --- a/test/Dialect/XeGPU/IR/atomic_rmw.mlir +++ b/test/Dialect/XeGPU/IR/atomic_rmw.mlir @@ -6,33 +6,33 @@ // CHECK-LABEL: func @test_atomic_rmw({{.*}}) { func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x1xf32>, %mask : vector<16xi1>) { - %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> // CHECK: xegpu.atomic_rmw // CHECK-SAME: (vector<16x1xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>) - xegpu.atomic_rmw "addf" %value, %1, %mask : (vector<16x1xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>) + xegpu.atomic_rmw "addf" %value, %1, %mask {mode = vc} : (vector<16x1xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>) return } // CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) { func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xf32>, %mask : vector<16xi1>) { - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2}: ui64, vector<16 x index> -> 
!xegpu.tensor_desc<16x2xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> // CHECK: xegpu.atomic_rmw // CHECK-SAME: (vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>) - xegpu.atomic_rmw "mulf" %value, %1, %mask : (vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>) + xegpu.atomic_rmw "mulf" %value, %1, %mask {mode = vc} : (vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>) return } // CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) { func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) { - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered> // CHECK: xegpu.atomic_rmw // CHECK-SAME: (vector<16x2xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>) - xegpu.atomic_rmw "andi" %value, %1, %mask : (vector<16x2xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>) + xegpu.atomic_rmw "andi" %value, %1, %mask {mode = vc} : (vector<16x2xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>) return } diff --git a/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir b/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir index 588ee76a8..20882542c 100644 --- a/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir +++ b/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir @@ -3,100 +3,111 @@ // RUN: imex-opt %s | imex-opt | FileCheck %s // Verify the generic form can be parsed. 
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_create_nd_tdesc_0({{.*}}) { -func.func @test_create_nd_tdesc_0(%src: memref<24x32xf32>) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_0({{.*}}) { +func.func @test_create_nd_tdesc_vc_0(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.create_nd_tdesc %src[2, 4] + %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_1({{.*}}) { -func.func @test_create_nd_tdesc_1(%src: memref<24x32xf32>, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_1({{.*}}) { +func.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>, %x : index, %y : index) { // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg1, %arg2] // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y] + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_2({{.*}}) { -func.func @test_create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_2({{.*}}) { +func.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : ui64 -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_3({{.*}}) { -func.func @test_create_nd_tdesc_3(%src: memref, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_3({{.*}}) { +func.func @test_create_nd_tdesc_vc_3(%src: memref, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : memref -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_4({{.*}}) { -func.func @test_create_nd_tdesc_4(%src: memref, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_4({{.*}}) { +func.func @test_create_nd_tdesc_vc_4(%src: memref, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {boundary_check = true} : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc 
%src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} : memref -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_5({{.*}}) { -func.func @test_create_nd_tdesc_5(%src: memref, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_5({{.*}}) { +func.func @test_create_nd_tdesc_vc_5(%src: memref, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {memory_scope = slm} : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, memory_scope = slm} : memref -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_6({{.*}}) { -func.func @test_create_nd_tdesc_6(%src: memref, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_6({{.*}}) { +func.func @test_create_nd_tdesc_vc_6(%src: memref, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {memory_scope = slm, boundary_check = true} : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, memory_scope = slm, boundary_check = true} : memref -> !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_7({{.*}}) { -func.func @test_create_nd_tdesc_7(%src: memref<1024xf32>, %offset : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_7({{.*}}) { +func.func @test_create_nd_tdesc_vc_7(%src: memref<1024xf32>, %offset : index) { // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32> - %1 = xegpu.create_nd_tdesc %src[%offset] : memref<1024xf32> -> !xegpu.tensor_desc<16xf32> + %1 = xegpu.create_nd_tdesc %src[%offset] {mode = vc} : memref<1024xf32> -> !xegpu.tensor_desc<16xf32> return } -// CHECK-LABEL: func @test_create_nd_tdesc_8({{.*}}) { -func.func @test_create_nd_tdesc_8(%src: memref, %w : index, %h : index, %x : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_vc_8({{.*}}) { +func.func @test_create_nd_tdesc_vc_8(%src: memref, %w : index, %h : index, %x : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {memory_scope = slm, boundary_check = true} : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {mode = vc, memory_scope = slm, boundary_check = true} : memref -> !xegpu.tensor_desc<8x16xf32> + return +} + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_9({{.*}}) { +func.func @test_create_nd_tdesc_vc_9(%src: memref, %w : index, %h : index, %x : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = slm, boundary_check = true} + // CHECK-SAME: !xegpu.tensor_desc<64x128xf32, #xegpu.xe_map> + %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {memory_scope = slm, boundary_check = true} : memref + -> !xegpu.tensor_desc<64x128xf32, #xegpu.xe_map> return } diff --git 
a/test/Dialect/XeGPU/IR/create_tdesc.mlir b/test/Dialect/XeGPU/IR/create_tdesc.mlir index 1c828e569..079159655 100644 --- a/test/Dialect/XeGPU/IR/create_tdesc.mlir +++ b/test/Dialect/XeGPU/IR/create_tdesc.mlir @@ -5,41 +5,74 @@ // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_create_tdesc({{.*}}) { -func.func @test_create_tdesc(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 {memory_scope = global, chunk_size_per_lane = 1} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> +// CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) { +func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> return } -// CHECK-LABEL: func @test_create_tdesc_2({{.*}}) { -func.func @test_create_tdesc_2(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 {memory_scope = slm, chunk_size_per_lane = 1} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {memory_scope=slm}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> +// CHECK-LABEL: func @test_create_tdesc_vc_2({{.*}}) { +func.func @test_create_tdesc_vc_2(%src: ui64, %offsets : vector<16 x index>) { + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 1} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope=slm} + : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> return } -// CHECK-LABEL: func @test_create_tdesc_3({{.*}}) { -func.func @test_create_tdesc_3(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 {memory_scope = global, chunk_size_per_lane = 8} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> +// CHECK-LABEL: func @test_create_tdesc_vc_3({{.*}}) { +func.func @test_create_tdesc_vc_3(%src: ui64, %offsets : vector<16 x index>) { + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 8} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} + : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> return } -// CHECK-LABEL: func @test_create_tdesc_4({{.*}}) { -func.func @test_create_tdesc_4(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 {memory_scope = slm, chunk_size_per_lane = 2} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {memory_scope = slm, chunk_size_per_lane = 2} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> +// CHECK-LABEL: func 
@test_create_tdesc_vc_4({{.*}}) { +func.func @test_create_tdesc_vc_4(%src: ui64, %offsets : vector<16 x index>) { + // CHECK: xegpu.create_tdesc %arg0, %arg1 + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} + : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> return } -// CHECK-LABEL: func @test_create_tdesc_5({{.*}}) { -func.func @test_create_tdesc_5(%src: memref, %offsets : vector<16 x index>) { +// CHECK-LABEL: func @test_create_tdesc_vc_5({{.*}}) { +func.func @test_create_tdesc_vc_5(%src: memref, %offsets : vector<16 x index>) { // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = slm, chunk_size_per_lane = 2} + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} // CHECK-SAME: memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {memory_scope = slm, chunk_size_per_lane = 2} + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} : memref, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> return } + + +// CHECK-LABEL: func @test_create_tdesc_vc_6({{.*}}) { +func.func @test_create_tdesc_vc_6(%src: memref, %offset : index) { + // CHECK: xegpu.create_tdesc + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} + // CHECK-SAME: memref, index -> !xegpu.tensor_desc<2xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offset {mode = vc, memory_scope = slm, chunk_size_per_lane = 2} + : memref, index -> !xegpu.tensor_desc<2xf32, #xegpu.scattered> + return +} + +// CHECK-LABEL: func @test_create_tdesc_vc_7({{.*}}) { +func.func @test_create_tdesc_vc_7(%src: memref, %offset : index) { + // CHECK: xegpu.create_tdesc + // CHECK-SAME: {mode = vc, memory_scope = slm, chunk_size_per_lane = 1} + // CHECK-SAME: memref, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offset {mode = vc, memory_scope = slm, chunk_size_per_lane = 1} + : memref, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered> + return +} + diff --git a/test/Dialect/XeGPU/IR/invalid.mlir b/test/Dialect/XeGPU/IR/invalid.mlir index 92ba3169a..893741193 100644 --- a/test/Dialect/XeGPU/IR/invalid.mlir +++ b/test/Dialect/XeGPU/IR/invalid.mlir @@ -1,17 +1,17 @@ // RUN: imex-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics // ----- -func.func @test_create_nd_tdesc_1(%src: memref<24xf32>) { +func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // expected-error@+1 {{Expecting the rank of shape, strides and offsets should match with each other}} - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32> return } // ----- -func.func @test_create_nd_tdesc_2(%input: memref<24x32xf32>) { +func.func @test_create_nd_tdesc_vc_2(%input: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index @@ -19,13 +19,13 @@ func.func @test_create_nd_tdesc_2(%input: memref<24x32xf32>) { %c16 = arith.constant 16 : index // expected-error@+1 {{It is invalid to have both or none of dynamic shape and static shape. 
Only one of them is needed.}} - %1 = xegpu.create_nd_tdesc %input[%c0, %c1], [%c8, %c16], [%c16, %c1] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %input[%c0, %c1], [%c8, %c16], [%c16, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> return } // ----- -func.func @test_create_nd_tdesc_3(%input: memref) { +func.func @test_create_nd_tdesc_vc_3(%input: memref) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index @@ -33,35 +33,38 @@ func.func @test_create_nd_tdesc_3(%input: memref) { %c16 = arith.constant 16 : index // expected-error@+1 {{Expecting the rank of shape, strides and offsets should match with each other}} - %1 = xegpu.create_nd_tdesc %input[%c0, %c1], [%c8, %c16], [%c16, %c1] : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %input[%c0, %c1], [%c8, %c16], [%c16, %c1] {mode = vc} : memref -> !xegpu.tensor_desc<8x16xf32> return } // ----- -func.func @test_create_nd_tdesc_4(%input: memref) { +func.func @test_create_nd_tdesc_vc_4(%input: memref) { %c1 = arith.constant 2 : index %c8 = arith.constant 8 : index // expected-error@+1 {{Expecting the rank of shape, strides and offsets should match with each other}} - %1 = xegpu.create_nd_tdesc %input[%c1], [%c8], [%c1] : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %input[%c1], [%c8], [%c1] {mode = vc} + : memref -> !xegpu.tensor_desc<8x16xf32> return } // ----- -func.func @test_create_nd_tdesc_5(%input: memref<24x32x64xf32>) { +func.func @test_create_nd_tdesc_vc_5(%input: memref<24x32x64xf32>) { %c1 = arith.constant 2 : index %c8 = arith.constant 8 : index // expected-error@+1 {{operand #0 must be 1D/2D memref}} - %1 = xegpu.create_nd_tdesc %input[%c1, %c1, %c8] : memref<24x32x64xf32> -> !xegpu.tensor_desc<8x16x8xf32> + %1 = xegpu.create_nd_tdesc %input[%c1, %c1, %c8] {mode = vc} + : memref<24x32x64xf32> -> !xegpu.tensor_desc<8x16x8xf32> return } // ----- func.func @test_create_tdesc(%src: ui64, %offsets : vector<16x8xindex>) { // expected-error@+1 {{operand #1 must be vector of index values of ranks 1}} - %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16x8xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} + : ui64, vector<16x8xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> return } @@ -69,44 +72,15 @@ func.func @test_create_tdesc(%src: ui64, %offsets : vector<16x8xindex>) { func.func @test_load_gather(%src: ui64, %offsets : vector<16xindex>) { %0 = arith.constant dense<1>: vector<16x8xi1> // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 8} + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 8} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8}: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} + : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scattered> - // for fp16 the vnni factor should be 2 instead of 4. 
- // expected-error@+1 {{Invalid vnni transform.}} - %2 = xegpu.load %1, %0 {vnni_axis = 0, l1_hint = cached, l2_hint = uncached} - : !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>, vector<16x8xi1> -> vector<4x8x4xf16> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + %2 = xegpu.load %1, %0 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} + : !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>, vector<16x8xi1> -> vector<8x8x4xf16> return } -// ----- -func.func @test_load_gather_2(%src: ui64, %offsets : vector<16xindex>) { - %0 = arith.constant dense<1>: vector<16x8xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 8} - // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8}: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - - // for fp32, no vnni available. - // expected-error@+1 {{Invalid vnni transform.}} - %2 = xegpu.load %1, %0 {transpose = [1, 0], vnni_axis = 1, l1_hint = cached, l2_hint = uncached} - : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x8x2xf32> - return -} - -// ----- -func.func @test_load_gather_3(%src: ui64, %offsets : vector<16xindex>) { - %0 = arith.constant dense<1>: vector<16x8xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 8} - // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8}: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scattered> - - // for fp16 the vnni factor should be 2 instead of 4. - // expected-error@+1 {{Invalid vnni transform.}} - %2 = xegpu.load %1, %0 {vnni_axis = 0, l1_hint = cached, l2_hint = uncached} - : !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>, vector<16x8xi1> -> vector<4x8x4xf16> - return -} diff --git a/test/Dialect/XeGPU/IR/load_gather.mlir b/test/Dialect/XeGPU/IR/load_gather.mlir index 38e9501ce..39ce00088 100644 --- a/test/Dialect/XeGPU/IR/load_gather.mlir +++ b/test/Dialect/XeGPU/IR/load_gather.mlir @@ -5,32 +5,72 @@ // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_load_gather({{.*}}) { -func.func @test_load_gather(%src: ui64, %offsets : vector<16xindex>) { +// CHECK-LABEL: func @test_load_gather_vc({{.*}}) { +func.func @test_load_gather_vc(%src: ui64, %offsets : vector<16xindex>) { %0 = arith.constant dense<1>: vector<16xi1> // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 1} + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc}: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> // CHECK: xegpu.load - // CHECK-SAME: {l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> - %2 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> + %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} + : 
!xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> return } -// CHECK-LABEL: func @test_load_gather_2({{.*}}) { -func.func @test_load_gather_2(%src: ui64, %offsets : vector<16xindex>) { +// CHECK-LABEL: func @test_load_gather_vc_2({{.*}}) { +func.func @test_load_gather_vc_2(%src: ui64, %offsets : vector<16xindex>) { %0 = arith.constant dense<1>: vector<16x8xi1> // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 8} + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 8} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8}: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} + : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> // CHECK: xegpu.load - // CHECK-SAME: {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> - %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> + %2 = xegpu.load %1, %0 {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> return } + + +// CHECK-LABEL: func @test_load_gather_vc_3({{.*}}) { +func.func @test_load_gather_vc_3(%src: ui64, %offset : index) { + %0 = arith.constant dense<1>: vector<8xi1> + // CHECK: xegpu.create_tdesc + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 8} + // CHECK-SAME: ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offset {mode = vc, chunk_size_per_lane = 8} + : ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered> + + // CHECK: xegpu.load + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32> + %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} + : !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32> + return +} + + +// CHECK-LABEL: func @test_load_gather_vc_4({{.*}}) { +func.func @test_load_gather_vc_4(%src: ui64, %offsets : vector<16xindex>) { + %0 = arith.constant dense<1>: vector<16xi1> + // CHECK: xegpu.create_tdesc + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 1} + : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + + // CHECK: xegpu.load + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> + %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} + : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> + return +} + diff --git a/test/Dialect/XeGPU/IR/load_nd.mlir b/test/Dialect/XeGPU/IR/load_nd.mlir index 9d305e2bc..c67cab01b 100644 --- a/test/Dialect/XeGPU/IR/load_nd.mlir +++ b/test/Dialect/XeGPU/IR/load_nd.mlir @@ -3,52 +3,154 @@ // RUN: 
imex-opt %s | imex-opt | FileCheck %s // Verify the generic form can be parsed. // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_load_nd_0({{.*}}) { -func.func @test_load_nd_0(%src: memref<24x32xf32>) { + +#sg_map_fp16_a = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> +#sg_map_fp16_b = #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +#sg_map_fp16_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +// CHECK-LABEL: func @test_load_nd_fp16({{.*}}) { +func.func @test_load_nd_fp16(%A: memref<24x32xf16>, %B : memref<24x32xf16>, %C : memref<24x32xf16>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] - : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> + // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %A[%c0, %c1] + : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> // CHECK: xegpu.load_nd - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + // CHECK-SAME: {mode = simt, vnni_axis = 1} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + // CHECK-SAME: -> vector<4x2xf16> + %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> -> vector<4x2xf16> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %3 = xegpu.create_nd_tdesc %B[%c0, %c1] + : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> // CHECK: xegpu.load_nd - // CHECK-SAME:{transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint = streaming} - // CHECK-SAME:!xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> - %3 = xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming} : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> + // CHECK-SAME: {mode = simt, vnni_axis = 0} + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<8x2xf16> + %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> -> vector<8x2xf16> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> + // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %5 = xegpu.create_nd_tdesc %C[%c0, %c1] + : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<8x1xf32> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> -> vector<8x1xf32> + 
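  // The per-work-item result shapes above follow from the sg_map attributes: with
  // wi_layout = [2, 8] the 8x16 A block is distributed across 16 work items, so each
  // work item receives 8 fp16 values, packed as vector<4x2xf16> once vnni_axis = 1
  // folds the wi_data pair into the trailing dimension. The same arithmetic gives
  // vector<8x2xf16> for B (wi_layout = [1, 16], vnni_axis = 0) and vector<8x1xf32>
  // for the f32 C block.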
return } -// CHECK-LABEL: func @test_load_nd_1({{.*}}) { -func.func @test_load_nd_1(%src: memref<24x32xf16>, %x : index, %y : index) { +#sg_map_i8_a = #xegpu.sg_map<{mma_block_size = [8, 32], wi_layout = [2, 8], wi_data = [1, 4]}> +#sg_map_i8_b = #xegpu.sg_map<{mma_block_size = [32, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +#sg_map_i8_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +// CHECK-LABEL: func @test_load_nd_i8({{.*}}) { +func.func @test_load_nd_i8(%A: memref<64x64xi8>, %B : memref<64x64xi8>, %C : memref<64x64xi8>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<64x64xi8> + // CHECK-SAME: -> !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map<{mma_block_size = [8, 32], wi_layout = [2, 8], wi_data = [1, 4]}>> + %1 = xegpu.create_nd_tdesc %A[%c0, %c1] + : memref<64x64xi8> -> !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt, vnni_axis = 1} + // CHECK-SAME: !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map<{mma_block_size = [8, 32], wi_layout = [2, 8], wi_data = [1, 4]}>> + // CHECK-SAME: -> vector<4x4xi8> + %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a> -> vector<4x4xi8> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<64x64xi8> + // CHECK-SAME: -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<{mma_block_size = [32, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %3 = xegpu.create_nd_tdesc %B[%c0, %c1] + : memref<64x64xi8> -> !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt, vnni_axis = 0} + // CHECK-SAME: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<{mma_block_size = [32, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<8x4xi8> + %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b> -> vector<8x4xi8> + // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg1, %arg2] - // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y] - : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<64x64xi8> + // CHECK-SAME: -> !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %5 = xegpu.create_nd_tdesc %C[%c0, %c1] + : memref<64x64xi8> -> !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c> // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 0, l1_hint = cached, l2_hint = uncached} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> - %2 = xegpu.load_nd %1 {vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + // CHECK-SAME: {mode = simt} + // CHECK-SAME: !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<8x1xi32> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c> -> vector<8x1xi32> + return } -// CHECK-LABEL: func @test_load_nd_2({{.*}}) { -func.func @test_load_nd_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { - %c1 = arith.constant 1 : index + +#sg_map_f64_a = #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}> +#sg_map_f64_b = 
#xegpu.sg_map<{mma_block_size = [8, 8], wi_layout = [2, 8], wi_data = [1, 1]}> +#sg_map_f64_c = #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}> +// CHECK-LABEL: func @test_load_nd_f64({{.*}}) { +func.func @test_load_nd_f64(%A: memref<64x64xf64>, %B : memref<64x64xf64>, %C : memref<64x64xf64>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf16> + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<64x64xf64> + // CHECK-SAME: -> !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}>> + %1 = xegpu.create_nd_tdesc %A[%c0, %c1] + : memref<64x64xf64> -> !xegpu.tensor_desc<4x8xf64, #sg_map_f64_a> + // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 1, l1_hint = cached, l2_hint = uncached} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %2 = xegpu.load_nd %1 {vnni_axis = 1, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + // CHECK-SAME: {mode = simt} + // CHECK-SAME: !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<2x1xf64> + %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<4x8xf64, #sg_map_f64_a> -> vector<2x1xf64> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<64x64xf64> + // CHECK-SAME: -> !xegpu.tensor_desc<8x8xf64, #xegpu.sg_map<{mma_block_size = [8, 8], wi_layout = [2, 8], wi_data = [1, 1]}>> + %3 = xegpu.create_nd_tdesc %B[%c0, %c1] + : memref<64x64xf64> -> !xegpu.tensor_desc<8x8xf64, #sg_map_f64_b> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt} + // CHECK-SAME: !xegpu.tensor_desc<8x8xf64, #xegpu.sg_map<{mma_block_size = [8, 8], wi_layout = [2, 8], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<4x1xf64> + %4 = xegpu.load_nd %3 : !xegpu.tensor_desc<8x8xf64, #sg_map_f64_b> -> vector<4x1xf64> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, memory_scope = global, boundary_check = true} + // CHECK-SAME: memref<64x64xf64> + // CHECK-SAME: -> !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}>> + %5 = xegpu.create_nd_tdesc %C[%c0, %c1] + : memref<64x64xf64> -> !xegpu.tensor_desc<4x8xf64, #sg_map_f64_c> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt} + // CHECK-SAME: !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<2x1xf64> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<4x8xf64, #sg_map_f64_c> -> vector<2x1xf64> + return } diff --git a/test/Dialect/XeGPU/IR/prefetch_nd.mlir b/test/Dialect/XeGPU/IR/prefetch_nd.mlir index 9d9371132..3604c91a6 100644 --- a/test/Dialect/XeGPU/IR/prefetch_nd.mlir +++ b/test/Dialect/XeGPU/IR/prefetch_nd.mlir @@ -3,27 +3,27 @@ // RUN: imex-opt %s | imex-opt | FileCheck %s // Verify the generic form can be parsed. 
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_prefetch_nd_tdesc_0({{.*}}) { -func.func @test_prefetch_nd_tdesc_0(%src: memref<24x32xf32>) { +// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_0({{.*}}) { +func.func @test_prefetch_nd_tdesc_vc_0(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32> + // CHECK: xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf32> + xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xf32> return } -// CHECK-LABEL: func @test_prefetch_nd_tdesc_1({{.*}}) { -func.func @test_prefetch_nd_tdesc_1(%src: memref<24x32xf16>, %x : index, %y : index) { +// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_1({{.*}}) { +func.func @test_prefetch_nd_tdesc_vc_1(%src: memref<24x32xf16>, %x : index, %y : index) { // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y] + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.prefetch_nd %0 {l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> - xegpu.prefetch_nd %1 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> + // CHECK: xegpu.prefetch_nd %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> return } diff --git a/test/Dialect/XeGPU/IR/simple_gemm.mlir b/test/Dialect/XeGPU/IR/simple_gemm.mlir index 4a394460b..2785fae26 100644 --- a/test/Dialect/XeGPU/IR/simple_gemm.mlir +++ b/test/Dialect/XeGPU/IR/simple_gemm.mlir @@ -5,8 +5,8 @@ // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_gemm({{.*}}) { -func.func @test_gemm(%a : memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { +// CHECK-LABEL: func @test_gemm_vc({{.*}}) { +func.func @test_gemm_vc(%a : memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index @@ -21,11 +21,11 @@ func.func @test_gemm(%a : memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: scf.for %j= %c0 to %c1024 step %c16 { // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %a[%i, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %a[%i, %c0] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> - %2 = xegpu.create_nd_tdesc %b[%c0, %j] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.create_nd_tdesc %b[%c0, %j] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> %3 = arith.constant dense<0.0> : vector<8x16xf32> @@ -34,29 +34,30 @@ func.func @test_gemm(%a : memref<1024x1024xf16>, %b: 
memref<1024x1024xf16>, %c: -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) { // CHECK: xegpu.load_nd // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %4 = xegpu.load_nd %subA {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %4 = xegpu.load_nd %subA {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> // CHECK: xegpu.load_nd // CHECK-SAME: !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %5 = xegpu.load_nd %subB {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %5 = xegpu.load_nd %subB {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> // CHECK: xegpu.dpas // CHECK-SAME: vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %6 = xegpu.dpas %4, %5, %subC: vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xegpu.dpas %4, %5, %subC {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %7 = xegpu.update_nd_offset %subA, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %7 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + + %8 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> - %8 = xegpu.update_nd_offset %subB, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32> } // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %9 = xegpu.create_nd_tdesc %c[%i, %j] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %9 = xegpu.create_nd_tdesc %c[%i, %j] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: xegpu.store_nd // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %result, %9: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %result, %9 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> } } return diff --git a/test/Dialect/XeGPU/IR/store_nd.mlir b/test/Dialect/XeGPU/IR/store_nd.mlir index 17ea3bc79..ceae5645c 100644 --- a/test/Dialect/XeGPU/IR/store_nd.mlir +++ b/test/Dialect/XeGPU/IR/store_nd.mlir @@ -3,31 +3,31 @@ // RUN: imex-opt %s | imex-opt | FileCheck %s // Verify the generic form can be parsed. 
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_store_nd_0({{.*}}) { -func.func @test_store_nd_0(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) { +// CHECK-LABEL: func @test_store_nd_vc_0({{.*}}) { +func.func @test_store_nd_vc_0(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {memory_scope = global, boundary_check = true} + // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] + %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: xegpu.load_nd - // CHECK-SAME: {l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %3 = xegpu.load_nd %1 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> // CHECK: xegpu.store_nd - // CHECK-SAME: {l1_hint = write_back, l2_hint = uncached} + // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} // CHECK-SAME: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %3, %2 {l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> return } diff --git a/test/Dialect/XeGPU/IR/store_scatter.mlir b/test/Dialect/XeGPU/IR/store_scatter.mlir index b8ae40d1a..19238afab 100644 --- a/test/Dialect/XeGPU/IR/store_scatter.mlir +++ b/test/Dialect/XeGPU/IR/store_scatter.mlir @@ -5,26 +5,60 @@ // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_create_tdesc({{.*}}) { -func.func @test_create_tdesc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) { +// CHECK-LABEL: func @test_store_scatter_vc({{.*}}) { +func.func @test_store_scatter_vc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) { %0 = arith.constant dense<1>: vector<16xi1> // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 1} + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} + : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> // CHECK: xegpu.create_tdesc - // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 1} + // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %2 = xegpu.create_tdesc %dst, %offsets: ui64, vector<16 x index> 
diff --git a/test/Dialect/XeGPU/IR/store_scatter.mlir b/test/Dialect/XeGPU/IR/store_scatter.mlir
index b8ae40d1a..19238afab 100644
--- a/test/Dialect/XeGPU/IR/store_scatter.mlir
+++ b/test/Dialect/XeGPU/IR/store_scatter.mlir
@@ -5,26 +5,60 @@
 // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s

-// CHECK-LABEL: func @test_create_tdesc({{.*}}) {
-func.func @test_create_tdesc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) {
+// CHECK-LABEL: func @test_store_scatter_vc({{.*}}) {
+func.func @test_store_scatter_vc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) {
   %0 = arith.constant dense<1>: vector<16xi1>
   // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 1}
+  // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+  %1 = xegpu.create_tdesc %src, %offsets {mode = vc}
+        : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>

   // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 1}
+  // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  %2 = xegpu.create_tdesc %dst, %offsets: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+  %2 = xegpu.create_tdesc %dst, %offsets {mode = vc}
+        : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>

   // CHECK: xegpu.load
-  // CHECK-SAME: {l1_hint = cached, l2_hint = uncached}
+  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
-  %3 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+  %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached}
+        : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
   // CHECK: xegpu.store
-  // CHECK-SAME: {l1_hint = write_back, l2_hint = uncached}
+  // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached}
   // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
-  xegpu.store %3, %2, %0 {l1_hint = write_back, l2_hint = uncached}: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+  xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached}
+        : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
   return
 }
+
+
+// CHECK-LABEL: func @test_store_scatter({{.*}}) {
+func.func @test_store_scatter(%src: ui64, %offsets : index, %dst: ui64) {
+  %0 = arith.constant 1: i1
+  // CHECK: xegpu.create_tdesc
+  // CHECK-SAME: {mode = simt, memory_scope = global, chunk_size_per_lane = 1}
+  // CHECK-SAME: ui64, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered>
+  %1 = xegpu.create_tdesc %src, %offsets
+        : ui64, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered>
+
+  // CHECK: xegpu.create_tdesc
+  // CHECK-SAME: {mode = simt, memory_scope = global, chunk_size_per_lane = 1}
+  // CHECK-SAME: ui64, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered>
+  %2 = xegpu.create_tdesc %dst, %offsets
+        : ui64, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered>
+
+  // CHECK: xegpu.load
+  // CHECK-SAME: {mode = simt, l1_hint = cached, l2_hint = uncached}
+  // CHECK-SAME: !xegpu.tensor_desc<1xf32, #xegpu.scattered>, i1 -> f32
+  %3 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached}
+        : !xegpu.tensor_desc<1xf32, #xegpu.scattered>, i1 -> f32
+  // CHECK: xegpu.store
+  // CHECK-SAME: {mode = simt, l1_hint = write_back, l2_hint = uncached}
+  // CHECK-SAME: f32, !xegpu.tensor_desc<1xf32, #xegpu.scattered>, i1
+  xegpu.store %3, %2, %0 {l1_hint = write_back, l2_hint = uncached}
+        : f32, !xegpu.tensor_desc<1xf32, #xegpu.scattered>, i1
+  return
+}
+
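The two functions above contrast the scattered ops in VC and SIMT mode: VC mode works on whole-subgroup vectors (vector<16xindex> offsets, vector<16xi1> mask, vector<16xf32> payload), while SIMT mode takes a per-work-item scalar offset and mask, with a scalar f32 payload when chunk_size_per_lane is 1. When each lane owns several elements the payload becomes a small per-lane vector; a minimal sketch mirroring the chunk_size_per_lane = 8 pattern used in update_offset.mlir below (value names are illustrative, and the store line is an assumption, not taken from the tests):

  // Each work-item addresses 8 contiguous f32 elements (default SIMT mode assumed).
  %tdesc = xegpu.create_tdesc %src, %offset {chunk_size_per_lane = 8}
        : ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered>
  %mask = arith.constant dense<1> : vector<8xi1>
  %val = xegpu.load %tdesc, %mask {l1_hint = cached, l2_hint = uncached}
        : !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32>
  // Assumed: the store takes the same per-lane vector and mask types.
  xegpu.store %val, %tdesc, %mask {l1_hint = write_back, l2_hint = uncached}
        : vector<8xf32>, !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1>
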
diff --git a/test/Dialect/XeGPU/IR/update_nd_offset.mlir b/test/Dialect/XeGPU/IR/update_nd_offset.mlir
index 4ab22ad71..93403ea5f 100644
--- a/test/Dialect/XeGPU/IR/update_nd_offset.mlir
+++ b/test/Dialect/XeGPU/IR/update_nd_offset.mlir
@@ -3,25 +3,27 @@
 // RUN: imex-opt %s | imex-opt | FileCheck %s
 // Verify the generic form can be parsed.
 // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s

-// CHECK-LABEL: func @test_update_nd_offset_0({{.*}}) {
-func.func @test_update_nd_offset_0(%src: memref<24x32xf32>) {
+// CHECK-LABEL: func @test_update_nd_offset_vc_0({{.*}}) {
+func.func @test_update_nd_offset_vc_0(%src: memref<24x32xf32>) {
   %c0 = arith.constant 2 : index
   %c1 = arith.constant 4 : index
   // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {memory_scope = global, boundary_check = true}
+  // CHECK-SAME: {mode = vc, memory_scope = global, boundary_check = true}
   // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-  %1 = xegpu.create_nd_tdesc %src[%c0, %c1]
+  %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
     : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

   // CHECK: xegpu.load_nd
-  // CHECK-SAME: {l1_hint = cached, l2_hint = uncached}
+  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
-  %2 = xegpu.load_nd %1 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}
+    : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

   // CHECK: xegpu.update_nd_offset
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  %3 = xegpu.update_nd_offset %1, [%c0, %c1]: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.update_nd_offset %1, [%c0, %c1] {mode = vc}
+    : !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
 }
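Because update_nd_offset produces a new descriptor rather than mutating its operand, the usual pattern is to thread the descriptor through loop iteration arguments. A minimal VC-mode sketch; the scf.for loop, bounds, and step are illustrative assumptions, and only the xegpu ops follow the forms tested above:

  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %init = xegpu.create_nd_tdesc %src[%c0, %c0] {mode = vc}
        : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  %last = scf.for %j = %c0 to %c32 step %c16
      iter_args(%desc = %init) -> (!xegpu.tensor_desc<8x16xf32>) {
    %v = xegpu.load_nd %desc {mode = vc, l1_hint = cached, l2_hint = uncached}
          : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    // ... consume %v ...
    %next = xegpu.update_nd_offset %desc, [%c0, %c16] {mode = vc}
          : !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
    scf.yield %next : !xegpu.tensor_desc<8x16xf32>
  }
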
diff --git a/test/Dialect/XeGPU/IR/update_offset.mlir b/test/Dialect/XeGPU/IR/update_offset.mlir
index fd92a6e6b..b1e712a3c 100644
--- a/test/Dialect/XeGPU/IR/update_offset.mlir
+++ b/test/Dialect/XeGPU/IR/update_offset.mlir
@@ -5,25 +5,56 @@
 // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s

-// CHECK-LABEL: func @test_create_tdesc({{.*}}) {
-func.func @test_create_tdesc(%src: ui64, %offsets : vector<16 x index>) {
+// CHECK-LABEL: func @test_update_offset_VC({{.*}}) {
+func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) {
   %0 = arith.constant dense<1>: vector<16xi1>
   // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {memory_scope = global, chunk_size_per_lane = 1}
+  // CHECK-SAME: {mode = vc, memory_scope = global, chunk_size_per_lane = 1}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  %1 = xegpu.create_tdesc %src, %offsets: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+  %1 = xegpu.create_tdesc %src, %offsets {mode = vc}
+        : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>

   // CHECK: xegpu.load
-  // CHECK-SAME: {l1_hint = cached, l2_hint = uncached}
+  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
-  %2 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+  %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached}
+        : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
   %3 = arith.constant dense<16>: vector<16 x index>
   %4 = arith.addi %offsets, %3: vector<16 x index>

   // CHECK: xegpu.update_offset
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  %5 = xegpu.update_offset %1, %4: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+  %5 = xegpu.update_offset %1, %4 {mode = vc}
+        : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   return
 }
+
+// SIMT test code
+// CHECK-LABEL: func @test_update_offset({{.*}}) {
+func.func @test_update_offset(%src: ui64, %offsets : index) {
+  %0 = arith.constant dense<1>: vector<8xi1>
+  // CHECK: xegpu.create_tdesc
+  // CHECK-SAME: {mode = simt, memory_scope = global, chunk_size_per_lane = 8}
+  // CHECK-SAME: ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered>
+  %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8}
+        : ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered>
+
+  // CHECK: xegpu.load
+  // CHECK-SAME: {mode = simt, l1_hint = cached, l2_hint = uncached}
+  // CHECK-SAME: !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32>
+  %2 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached}
+        : !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32>
+
+  %3 = arith.constant 16: index
+  %4 = arith.addi %offsets, %3: index
+
+  // CHECK: xegpu.update_offset
+  // CHECK-SAME: !xegpu.tensor_desc<8xf32, #xegpu.scattered>, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered>
+  %5 = xegpu.update_offset %1, %4
+        : !xegpu.tensor_desc<8xf32, #xegpu.scattered>, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered>
+
+  return
+}
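One detail the SIMT tests above rely on: attributes omitted in the input are printed back with their default values (mode = simt, memory_scope = global, and chunk_size_per_lane = 1 unless overridden), which is exactly what the CHECK-SAME lines match. A sketch of that round trip for the scalar case from store_scatter.mlir:

  // As written in the test source (defaults omitted):
  %1 = xegpu.create_tdesc %src, %offsets : ui64, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered>
  // The same op as printed back by imex-opt, which the CHECK-SAME lines match:
  %1 = xegpu.create_tdesc %src, %offsets {mode = simt, memory_scope = global, chunk_size_per_lane = 1}
        : ui64, index -> !xegpu.tensor_desc<1xf32, #xegpu.scattered>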