From 7ac6b72a863cb4d1ab181cec24774a76d6b33eea Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 20 Nov 2023 17:15:29 -0800 Subject: [PATCH] Add XeGPU test cases covering various combinations of Ops supported by XeGPU dialect Co-Authored-By: Chang, Liangliang --- test/Conversion/XeGPUToSPIRV/lit.local.cfg | 4 + test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir | 70 ++++++++++ test/Dialect/XeGPU/IR/atomic_rmw.mlir | 19 +-- test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir | 38 ++++++ test/Dialect/XeGPU/IR/create_nd_tdesc.mlir | 120 +++++++++--------- test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir | 107 ++++++++++++++++ ...create_tdesc.mlir => create_tdesc_vc.mlir} | 0 .../IR/{invalid.mlir => invalid_vc.mlir} | 0 .../{load_gather.mlir => load_gather_vc.mlir} | 0 test/Dialect/XeGPU/IR/load_nd.mlir | 51 +++++++- test/Dialect/XeGPU/IR/load_nd_vc.mlir | 61 +++++++++ .../{prefetch_nd.mlir => prefetch_nd_vc.mlir} | 26 ++++ test/Dialect/XeGPU/IR/simple_gemm.mlir | 53 ++++---- test/Dialect/XeGPU/IR/simple_gemm_vc.mlir | 65 ++++++++++ test/Dialect/XeGPU/IR/store_nd.mlir | 33 ----- test/Dialect/XeGPU/IR/store_nd_vc.mlir | 92 ++++++++++++++ test/Dialect/XeGPU/IR/store_scatter.mlir | 30 ----- test/Dialect/XeGPU/IR/store_scatter_vc.mlir | 33 +++++ ...date_offset.mlir => update_offset_vc.mlir} | 28 ---- .../Dialect/XeGPU/load2d-padding.mlir | 20 --- 20 files changed, 649 insertions(+), 201 deletions(-) create mode 100644 test/Conversion/XeGPUToSPIRV/lit.local.cfg create mode 100644 test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir create mode 100644 test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir create mode 100644 test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir rename test/Dialect/XeGPU/IR/{create_tdesc.mlir => create_tdesc_vc.mlir} (100%) rename test/Dialect/XeGPU/IR/{invalid.mlir => invalid_vc.mlir} (100%) rename test/Dialect/XeGPU/IR/{load_gather.mlir => load_gather_vc.mlir} (100%) create mode 100644 test/Dialect/XeGPU/IR/load_nd_vc.mlir rename test/Dialect/XeGPU/IR/{prefetch_nd.mlir => prefetch_nd_vc.mlir} (54%) create mode 100644 test/Dialect/XeGPU/IR/simple_gemm_vc.mlir delete mode 100644 test/Dialect/XeGPU/IR/store_nd.mlir create mode 100644 test/Dialect/XeGPU/IR/store_nd_vc.mlir create mode 100644 test/Dialect/XeGPU/IR/store_scatter_vc.mlir rename test/Dialect/XeGPU/IR/{update_offset.mlir => update_offset_vc.mlir} (56%) diff --git a/test/Conversion/XeGPUToSPIRV/lit.local.cfg b/test/Conversion/XeGPUToSPIRV/lit.local.cfg new file mode 100644 index 000000000..d23a14a3b --- /dev/null +++ b/test/Conversion/XeGPUToSPIRV/lit.local.cfg @@ -0,0 +1,4 @@ +local_excludes = [ + 'gemm_basic.mlir' + ] +config.excludes.update(local_excludes) diff --git a/test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir b/test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir new file mode 100644 index 000000000..36d438780 --- /dev/null +++ b/test/Conversion/XeGPUToSPIRV/xegpu-to-vc.mlir @@ -0,0 +1,70 @@ +// RUN: imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s + +gpu.module @test attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + // CHECK: spirv.ConvertPtrToU + // CHECK: spirv.VectorInsertDynamic + gpu.func @create_nd_tdesc(%src: memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c16 = arith.constant 16 : index + %0 = xegpu.create_nd_tdesc %src[%c16, 0] {mode = vc} : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + gpu.return + } + + + // CHECK-LABEL: spirv.func @llvm_genx_raw_send2_v128i32_i1_v8i32 + // CHECK-SAME: (i8, i8, i1, i8, i8,
i8, i32, i32, vector<8xi32>, vector<128xi32>) + // CHECK-SAME: -> vector<128xi32> "None" attributes + // CHECK-SAME: {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} + // CHECK-LABEL: spirv.func @load_nd + // CHECK-SAME: %[[ptr:.*]]: !spirv.ptr, CrossWorkgroup> + // CHECK: %[[ptr_i64:.*]] = spirv.ConvertPtrToU %[[ptr]] : !spirv.ptr, CrossWorkgroup> to i64 + // CHECK: %{{.*}} = spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32 + + gpu.func @load_nd(%src : memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %1 = xegpu.create_nd_tdesc %src[0, 0] {mode = vc} : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16> + %3 = xegpu.load_nd %1 {vnni_axis = 0, mode = vc} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + gpu.return + } + + // CHECK-LABEL: spirv.func @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32(vector<128xi32>, vector<64xi32>, i32) + // CHECK-SAME: -> vector<128xf32> "None" attributes {VectorComputeFunctionINTEL, linkage_attributes = + // CHECK-SAME: #spirv.linkage_attributes>} + // CHECK-LABEL: spirv.func @dpas + // CHECK-SAME: (%[[A:.*]]: vector<64xi32>, %[[B:.*]]: vector<128xi32>) + // CHECK-NEXT: %[[cst134744586_i32:.*]] = spirv.Constant 134744586 : i32 + // CHECK-NEXT: %{{.*}} = spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32(%[[B]], %[[A]], %[[cst134744586_i32]]) + // CHECK-SAME: (vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> + gpu.func @dpas(%A : vector<8x8x2xf16>, %B : vector<8x16x2xf16>) + kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %C = xegpu.dpas %A, %B {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + gpu.return + } + + + // CHECK: (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<128xf32>) + // CHECK-SAME: "None" attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} + // CHECK: (%[[value:.*]]: vector<128xf32>, %[[ptr:.*]]: !spirv.ptr, CrossWorkgroup>) + // CHECK: %[[ptr_i64:.*]] = spirv.ConvertPtrToU %[[ptr]] : !spirv.ptr, CrossWorkgroup> to i64 + // CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32 + gpu.func @store_nd(%value : vector<8x16xf32>, %dest : memref<64x64xf32>) + kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %1 = xegpu.create_nd_tdesc %dest[0, 0] {mode = vc} : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %value, %1 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + // CHECK: (i8, i8, i1, i8, i8, i32, i32, vector<8xi32>) + // CHECK-SAME: "None" attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} + // CHECK: (%[[ptr:.*]]: !spirv.ptr, CrossWorkgroup>) + // CHECK: spirv.ConvertPtrToU %[[ptr]] : !spirv.ptr, CrossWorkgroup> to i64 + // CHECK: spirv.VectorInsertDynamic + // CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32 + gpu.func @prefetch(%src : memref<64x64xf16>) + kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %src[0, 0] {mode = vc} : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf16> + gpu.return + } + +} diff --git a/test/Dialect/XeGPU/IR/atomic_rmw.mlir b/test/Dialect/XeGPU/IR/atomic_rmw.mlir index 5f4ea2919..dc5bdc70a 100644 --- a/test/Dialect/XeGPU/IR/atomic_rmw.mlir +++ b/test/Dialect/XeGPU/IR/atomic_rmw.mlir @@ -4,35
+4,36 @@ // Verify the generic form can be parsed. // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s +#sg_map_fp32 = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> // CHECK-LABEL: func @test_atomic_rmw({{.*}}) { func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x1xf32>, %mask : vector<16xi1>) { - %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #sg_map_fp32> // CHECK: xegpu.atomic_rmw - // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> - xegpu.atomic_rmw "addf" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32> + // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>>, vector<16xi1>, vector<16x1xf32> + xegpu.atomic_rmw "addf" %1, %mask, %value : !xegpu.tensor_desc<16xf32, #sg_map_fp32>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32> return } // CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) { func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xf32>, %mask : vector<16xi1>) { - %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #sg_map_fp32> // CHECK: xegpu.atomic_rmw - // CHECK-SAME: !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> - xegpu.atomic_rmw "mulf" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> + // CHECK-SAME: !xegpu.tensor_desc<16x2xf32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> + xegpu.atomic_rmw "mulf" %1, %mask, %value : !xegpu.tensor_desc<16x2xf32, #sg_map_fp32>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> return } // CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) { func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) { - %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #sg_map_fp32> // CHECK: xegpu.atomic_rmw - // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> - xegpu.atomic_rmw "andi" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32> + // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32> + xegpu.atomic_rmw "andi" %1, %mask, %value : !xegpu.tensor_desc<16x2xi32, #sg_map_fp32>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32> return } diff --git a/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir b/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir new file mode 100644 index 000000000..5f4ea2919 --- /dev/null +++
b/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir @@ -0,0 +1,38 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +// CHECK-LABEL: func @test_atomic_rmw({{.*}}) { +func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x1xf32>, %mask : vector<16xi1>) { + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + + // CHECK: xegpu.atomic_rmw + // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> + xegpu.atomic_rmw "addf" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32> + + return +} + +// CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) { +func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xf32>, %mask : vector<16xi1>) { + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered> + + // CHECK: xegpu.atomic_rmw + // CHECK-SAME: !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> + xegpu.atomic_rmw "mulf" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> + + return +} + +// CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) { +func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) { + %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered> + + // CHECK: xegpu.atomic_rmw + // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> + xegpu.atomic_rmw "andi" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32> + + return +} diff --git a/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir b/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir index 5ebb32297..d5aa32cb8 100644 --- a/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir +++ b/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir @@ -3,114 +3,120 @@ // RUN: imex-opt %s | imex-opt | FileCheck %s // Verify the generic form can be parsed.
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_create_nd_tdesc_vc_0({{.*}}) { -func.func @test_create_nd_tdesc_vc_0(%src: memref<24x32xf32>) { + +#sg_map_fp16 = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> + +func.func @test_create_nd_tdesc_0(%src: memref<24x32xf16>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} - : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] + : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16> // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc} - : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %2 = xegpu.create_nd_tdesc %src[2, 4] + : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_1({{.*}}) { -func.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_1({{.*}}) { +func.func @test_create_nd_tdesc_1(%src: memref<24x32xf16>, %x : index, %y : index) { // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg1, %arg2] - // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} - : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%x, %y] + : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_2({{.*}}) { -func.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_2({{.*}}) { +func.func @test_create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : ui64 -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_3({{.*}}) { -func.func @test_create_nd_tdesc_vc_3(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_3({{.*}}) { +func.func @test_create_nd_tdesc_3(%src: memref<?x?xf16>, %w :
index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_4({{.*}}) { -func.func @test_create_nd_tdesc_vc_4(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_4({{.*}}) { +func.func @test_create_nd_tdesc_4(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {boundary_check = true} : memref<?x?xf16> + -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_5({{.*}}) { -func.func @test_create_nd_tdesc_vc_5(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_5({{.*}}) { +func.func @test_create_nd_tdesc_5(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} - : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, memory_scope = slm, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] + : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, memory_scope = slm, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_6({{.*}}) { -func.func @test_create_nd_tdesc_vc_6(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_6({{.*}}) { +func.func @test_create_nd_tdesc_6(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> - %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} - : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, memory_scope = slm, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data =
[1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {boundary_check = true} + : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, memory_scope = slm, #sg_map_fp16> return } - -// CHECK-LABEL: func @test_create_nd_tdesc_vc_7({{.*}}) { -func.func @test_create_nd_tdesc_vc_7(%src: memref<1024xf32>, %offset : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_7({{.*}}) { +func.func @test_create_nd_tdesc_7(%src: memref<1024xf16>, %offset : index) { // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32> - %1 = xegpu.create_nd_tdesc %src[%offset] {mode = vc} : memref<1024xf32> -> !xegpu.tensor_desc<16xf32> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<1024xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[%offset] : memref<1024xf16> -> !xegpu.tensor_desc<16xf16, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_8({{.*}}) { -func.func @test_create_nd_tdesc_vc_8(%src: memref<?x?xf32>, %w : index, %h : index, %x : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_8({{.*}}) { +func.func @test_create_nd_tdesc_8(%src: memref<?x?xf16>, %w : index, %h : index, %x : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> - %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} - : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, memory_scope = slm, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {boundary_check = true} + : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, memory_scope = slm, #sg_map_fp16> return } -// CHECK-LABEL: func @test_create_nd_tdesc_vc_9({{.*}}) { -func.func @test_create_nd_tdesc_vc_9(%src: memref<?x?xf32>, %w : index, %h : index, %x : index) { +// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) { +func.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index) { %c1 = arith.constant 1 : index // CHECK: xegpu.create_nd_tdesc // CHECK-SAME: {mode = simt, boundary_check = true} - // CHECK-SAME: !xegpu.tensor_desc<64x128xf32, memory_scope = slm, #xegpu.xe_map> - %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {boundary_check = true} : memref<?x?xf32> - -> !xegpu.tensor_desc<64x128xf32, memory_scope = slm, #xegpu.xe_map> + // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<64x128xf16, memory_scope = slm, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {boundary_check = true} : memref<?x?xf16> + -> !xegpu.tensor_desc<64x128xf16, memory_scope = slm, #sg_map_fp16> return } diff --git a/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir b/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir new file mode 100644 index 000000000..d44014687 --- /dev/null +++ b/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir @@ -0,0 +1,107 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed.
+// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +// ----- SIMD ----- +// CHECK-LABEL: func @test_create_nd_tdesc_vc_0({{.*}}) { +func.func @test_create_nd_tdesc_vc_0(%src: memref<24x32xf32>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} + : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc} + : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + + return +} + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_1({{.*}}) { +func.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>, %x : index, %y : index) { + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg1, %arg2] + // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} + : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + return +} + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_2({{.*}}) { +func.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] + // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : ui64 -> !xegpu.tensor_desc<8x16xf32> + return +} + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_3({{.*}}) { +func.func @test_create_nd_tdesc_vc_3(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] + // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> + return +} + + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_4({{.*}}) { +func.func @test_create_nd_tdesc_vc_4(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] + // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32> + return +} + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_5({{.*}}) { +func.func @test_create_nd_tdesc_vc_5(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] + // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} + : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + return +} + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_6({{.*}}) { +func.func @test_create_nd_tdesc_vc_6(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] + // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode
= vc, boundary_check = true} + : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + return +} + + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_7({{.*}}) { +func.func @test_create_nd_tdesc_vc_7(%src: memref<1024xf32>, %offset : index) { + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32> + %1 = xegpu.create_nd_tdesc %src[%offset] {mode = vc} : memref<1024xf32> -> !xegpu.tensor_desc<16xf32> + return +} + + +// CHECK-LABEL: func @test_create_nd_tdesc_vc_8({{.*}}) { +func.func @test_create_nd_tdesc_vc_8(%src: memref<?x?xf32>, %w : index, %h : index, %x : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + %1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} + : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, memory_scope = slm> + return +} diff --git a/test/Dialect/XeGPU/IR/create_tdesc.mlir b/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir similarity index 100% rename from test/Dialect/XeGPU/IR/create_tdesc.mlir rename to test/Dialect/XeGPU/IR/create_tdesc_vc.mlir diff --git a/test/Dialect/XeGPU/IR/invalid.mlir b/test/Dialect/XeGPU/IR/invalid_vc.mlir similarity index 100% rename from test/Dialect/XeGPU/IR/invalid.mlir rename to test/Dialect/XeGPU/IR/invalid_vc.mlir diff --git a/test/Dialect/XeGPU/IR/load_gather.mlir b/test/Dialect/XeGPU/IR/load_gather_vc.mlir similarity index 100% rename from test/Dialect/XeGPU/IR/load_gather.mlir rename to test/Dialect/XeGPU/IR/load_gather_vc.mlir diff --git a/test/Dialect/XeGPU/IR/load_nd.mlir b/test/Dialect/XeGPU/IR/load_nd.mlir index 91d2b6025..3616c05bd 100644 --- a/test/Dialect/XeGPU/IR/load_nd.mlir +++ b/test/Dialect/XeGPU/IR/load_nd.mlir @@ -67,6 +67,56 @@ func.func @test_load_nd_fp16(%A: memref<24x32xf16>, %B : memref<24x32xf16>, %C : return } +#sg_map_bf16_a = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> +#sg_map_bf16_b = #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +#sg_map_bf16_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +// CHECK-LABEL: func @test_load_nd_bf16({{.*}}) { +func.func @test_load_nd_bf16(%A: memref<24x32xbf16>, %B : memref<24x32xbf16>, %C : memref<24x32xf32>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<24x32xbf16> + // CHECK-SAME: -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %A[%c0, %c1] + : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt, vnni_axis = 1} + // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + // CHECK-SAME: -> vector<4x1x2xbf16> + %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> -> vector<4x1x2xbf16> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<24x32xbf16> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %3 = xegpu.create_nd_tdesc %B[%c0, %c1] + : memref<24x32xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> + + // CHECK: xegpu.load_nd + // CHECK-SAME:
{mode = simt, vnni_axis = 0} + // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<8x1x2xbf16> + %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> -> vector<8x1x2xbf16> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = simt, boundary_check = true} + // CHECK-SAME: memref<24x32xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %5 = xegpu.create_nd_tdesc %C[%c0, %c1] + : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = simt} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + // CHECK-SAME: -> vector<8x1xf32> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c> -> vector<8x1xf32> + + return +} + #sg_map_i8_a = #xegpu.sg_map<{mma_block_size = [8, 32], wi_layout = [2, 8], wi_data = [1, 4]}> #sg_map_i8_b = #xegpu.sg_map<{mma_block_size = [32, 16], wi_layout = [1, 16], wi_data = [1, 1]}> #sg_map_i8_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> @@ -117,7 +167,6 @@ func.func @test_load_nd_i8(%A: memref<64x64xi8>, %B : memref<64x64xi8>, %C : mem return } - #sg_map_f64_a = #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}> #sg_map_f64_b = #xegpu.sg_map<{mma_block_size = [8, 8], wi_layout = [2, 8], wi_data = [1, 1]}> #sg_map_f64_c = #xegpu.sg_map<{mma_block_size = [4, 8], wi_layout = [2, 8], wi_data = [1, 1]}> diff --git a/test/Dialect/XeGPU/IR/load_nd_vc.mlir b/test/Dialect/XeGPU/IR/load_nd_vc.mlir new file mode 100644 index 000000000..dd794285b --- /dev/null +++ b/test/Dialect/XeGPU/IR/load_nd_vc.mlir @@ -0,0 +1,61 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed.
+// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +// -- SIMD --- +// CHECK-LABEL: func @test_load_nd_simd_f32({{.*}}) { +func.func @test_load_nd_simd_f32(%src: memref<24x32xf32>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc, boundary_check = true} + : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = xegpu.load_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint = streaming} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> + %3 = xegpu.load_nd %1 {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint = streaming} : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> + return +} + +// CHECK-LABEL: func @test_load_nd_simd_f16({{.*}}) { +func.func @test_load_nd_simd_f16(%src: memref<24x32xf16>, %x : index, %y : index) { + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg1, %arg2] + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc, boundary_check = true} + : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + return +} + +// CHECK-LABEL: func @test_load_nd_simd_bf16({{.*}}) { +func.func @test_load_nd_simd_bf16(%src: ui64, %w : index, %h : index, %x : index, %y : index) { + %c1 = arith.constant 1 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xbf16> + %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc, boundary_check = true} : ui64 -> !xegpu.tensor_desc<8x16xbf16> + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc, vnni_axis = 1, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16> + %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 1, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16> + + return +} diff --git a/test/Dialect/XeGPU/IR/prefetch_nd.mlir b/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir similarity index 54% rename from test/Dialect/XeGPU/IR/prefetch_nd.mlir rename to test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir index 402302610..5d8a2fd0c 100644 --- a/test/Dialect/XeGPU/IR/prefetch_nd.mlir +++ b/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir @@ -27,3 +27,29 @@ func.func @test_prefetch_nd_tdesc_vc_1(%src: memref<24x32xf16>, %x : index, %y : xegpu.prefetch_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> return } + + +// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_i8({{.*}}) { +func.func @test_prefetch_nd_tdesc_vc_i8(%src: memref<24x32xi8>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 :
index + + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8> + + // CHECK: xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xi8> + xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xi8> + + return +} + +// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_bf16({{.*}}) { +func.func @test_prefetch_nd_tdesc_vc_bf16(%src: memref<24x32xbf16>, %x : index, %y : index) { + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} + : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + // CHECK: xegpu.prefetch_nd %0 {mode = vc, l1_hint = uncached, l2_hint = cached} : !xegpu.tensor_desc<8x16xbf16> + xegpu.prefetch_nd %1 {mode = vc, l1_hint = uncached, l2_hint = cached}: !xegpu.tensor_desc<8x16xbf16> + return +} diff --git a/test/Dialect/XeGPU/IR/simple_gemm.mlir b/test/Dialect/XeGPU/IR/simple_gemm.mlir index 2785fae26..7c0d59827 100644 --- a/test/Dialect/XeGPU/IR/simple_gemm.mlir +++ b/test/Dialect/XeGPU/IR/simple_gemm.mlir @@ -4,9 +4,13 @@ // Verify the generic form can be parsed. // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s +// ---- BF16 ------ -// CHECK-LABEL: func @test_gemm_vc({{.*}}) { -func.func @test_gemm_vc(%a : memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { +#sg_map_fp16_a = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}> +#sg_map_fp16_b = #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +#sg_map_fp16_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}> +// CHECK-LABEL: func @test_gemm_bf16({{.*}}) { +func.func @test_gemm_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xbf16>, %c: memref<1024x1024xf32>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index @@ -20,44 +24,47 @@ func.func @test_gemm_vc(%a : memref<1024x1024xf16>, %b: memref<1024x1024xf16>, % scf.for %i= %c0 to %c1024 step %c8 { scf.for %j= %c0 to %c1024 step %c16 { // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %a[%i, %c0] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-SAME: memref<1024x1024xbf16> + // CHECK-SAME: -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>> + %1 = xegpu.create_nd_tdesc %a[%i, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a> // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> - %2 = xegpu.create_nd_tdesc %b[%c0, %j] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-SAME: memref<1024x1024xbf16> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}>> + %2 = xegpu.create_nd_tdesc %b[%c0, %j] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b> - %3 = arith.constant dense<0.0> : vector<8x16xf32> + %3 = arith.constant dense<0.0> : vector<8x1xf32> - %tmp0, %tmp1, %result = scf.for %k= %c0 to %c1024 step %c16 - iter_args(%subA = %1, %subB = %2, %subC = %3) - -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) { + %tmp0, %tmp1, %result = scf.for %k= %c0 to %c1024 step %c16 iter_args(%subA 
= %1, %subB = %2, %subC = %3) + -> (!xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>, !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>, vector<8x1xf32>) { // CHECK: xegpu.load_nd - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %4 = xegpu.load_nd %subA {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + // CHECK-SAME: vector<4x1x2xbf16> + %4 = xegpu.load_nd %subA {vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a> -> vector<4x1x2xbf16> // CHECK: xegpu.load_nd - // CHECK-SAME: !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %5 = xegpu.load_nd %subB {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-SAME: vector<8x1x2xbf16> + %5 = xegpu.load_nd %subB {vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b> -> vector<8x1x2xbf16> // CHECK: xegpu.dpas - // CHECK-SAME: vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %6 = xegpu.dpas %4, %5, %subC {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-SAME: vector<4x1x2xbf16>, vector<8x1x2xbf16>, vector<8x1xf32> -> vector<8x1xf32> + %6 = xegpu.dpas %4, %5, %subC : vector<4x1x2xbf16>, vector<8x1x2xbf16>, vector<8x1xf32> -> vector<8x1xf32> - %7 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %7 = xegpu.update_nd_offset %subA, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a> + -> !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a> - %8 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %8 = xegpu.update_nd_offset %subB, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b> + -> !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b> - scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32> + scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>, !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>, vector<8x1xf32> } // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %9 = xegpu.create_nd_tdesc %c[%i, %j] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: memref<1024x1024xf32> + %9 = xegpu.create_nd_tdesc %c[%i, %j] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> // CHECK: xegpu.store_nd - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %result, %9 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: vector<8x1xf32> + xegpu.store_nd %result, %9 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> } } return diff --git a/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir b/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir new file mode 100644 index 000000000..108a08d06 --- /dev/null +++ b/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir @@ -0,0 +1,65 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed. 
+// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +// ---- BF16 VC ------ + +// CHECK-LABEL: func @test_gemm_vc_bf16({{.*}}) { +func.func @test_gemm_vc_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xbf16>, %c: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c1024 = arith.constant 1024 : index + + %c0_1 = arith.constant 0 : i32 + %c1_1 = arith.constant 1 : i32 + + + scf.for %i= %c0 to %c1024 step %c8 { + scf.for %j= %c0 to %c1024 step %c16 { + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16> + %1 = xegpu.create_nd_tdesc %a[%i, %c0] {mode = vc} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16> + %2 = xegpu.create_nd_tdesc %b[%c0, %j] {mode = vc} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16> + + %3 = arith.constant dense<0.0> : vector<8x16xf32> + + %tmp0, %tmp1, %result = scf.for %k= %c0 to %c1024 step %c16 + iter_args(%subA = %1, %subB = %2, %subC = %3) + -> (!xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>, vector<8x16xf32>) { + // CHECK: xegpu.load_nd + // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16> + %4 = xegpu.load_nd %subA {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16> + + // CHECK: xegpu.load_nd + // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16> + %5 = xegpu.load_nd %subB {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16> + + // CHECK: xegpu.dpas + // CHECK-SAME: vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xegpu.dpas %4, %5, %subC {mode = vc} : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + + %7 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode = vc} : !xegpu.tensor_desc<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16> + + %8 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode = vc} : !xegpu.tensor_desc<16x16xbf16> -> !xegpu.tensor_desc<16x16xbf16> + + scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>, vector<8x16xf32> + } + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %9 = xegpu.create_nd_tdesc %c[%i, %j] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + + // CHECK: xegpu.store_nd + // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %result, %9 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + } + } + return +} diff --git a/test/Dialect/XeGPU/IR/store_nd.mlir b/test/Dialect/XeGPU/IR/store_nd.mlir deleted file mode 100644 index 47e714b6c..000000000 --- a/test/Dialect/XeGPU/IR/store_nd.mlir +++ /dev/null @@ -1,33 +0,0 @@ -// RUN: imex-opt %s | FileCheck %s -// Verify the printed output can be parsed. -// RUN: imex-opt %s | imex-opt | FileCheck %s -// Verify the generic form can be parsed. 
-// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_store_nd_vc_0({{.*}}) { -func.func @test_store_nd_vc_0(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) { - %c0 = arith.constant 2 : index - %c1 = arith.constant 4 : index - - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc, boundary_check = true} - // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} - : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc, boundary_check = true} - // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} - : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - - // CHECK: xegpu.store_nd - // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} - // CHECK-SAME: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - return -} diff --git a/test/Dialect/XeGPU/IR/store_nd_vc.mlir b/test/Dialect/XeGPU/IR/store_nd_vc.mlir new file mode 100644 index 000000000..16a2824f1 --- /dev/null +++ b/test/Dialect/XeGPU/IR/store_nd_vc.mlir @@ -0,0 +1,92 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +// CHECK-LABEL: func @test_store_nd_vc_bf16({{.*}}) { +func.func @test_store_nd_vc_bf16(%src: memref<24x32xbf16>, %dst: memref<24x32xbf16>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} + : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} + : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16> + %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16> + + // CHECK: xegpu.store_nd + // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} + // CHECK-SAME: vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16> + xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16> + return +} + +// CHECK-LABEL: func @test_store_nd_vc_f64({{.*}}) { +func.func @test_store_nd_vc_f64(%src: memref<24x32xf64>, %dst: memref<24x32xf64>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64> + 
%1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} + : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64> + %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} + : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf64> -> vector<8x16xf64> + %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf64> -> vector<8x16xf64> + + // CHECK: xegpu.store_nd + // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} + // CHECK-SAME: vector<8x16xf64>, !xegpu.tensor_desc<8x16xf64> + xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf64>, !xegpu.tensor_desc<8x16xf64> + return +} + +// CHECK-LABEL: func @test_store_nd_vc_i8({{.*}}) { +func.func @test_store_nd_vc_i8(%src: memref<24x32xi8>, %dst: memref<24x32xi8>) { + %c0 = arith.constant 2 : index + %c1 = arith.constant 4 : index + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8> + %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} + : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8> + + // CHECK: xegpu.create_nd_tdesc + // CHECK-SAME: {mode = vc, boundary_check = true} + // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8> + %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} + : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8> + + // CHECK: xegpu.load_nd + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<8x16xi8> -> vector<8x16xi8> + %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xi8> -> vector<8x16xi8> + + // CHECK: xegpu.store_nd + // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} + // CHECK-SAME: vector<8x16xi8>, !xegpu.tensor_desc<8x16xi8> + xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xi8>, !xegpu.tensor_desc<8x16xi8> + return +} diff --git a/test/Dialect/XeGPU/IR/store_scatter.mlir b/test/Dialect/XeGPU/IR/store_scatter.mlir index 19341dc74..8924aefb8 100644 --- a/test/Dialect/XeGPU/IR/store_scatter.mlir +++ b/test/Dialect/XeGPU/IR/store_scatter.mlir @@ -4,36 +4,6 @@ // Verify the generic form can be parsed. 
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s - -// CHECK-LABEL: func @test_store_scatter_vc({{.*}}) { -func.func @test_store_scatter_vc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) { - %0 = arith.constant dense<1>: vector<16xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 1} - // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {mode = vc} - : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 1} - // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - %2 = xegpu.create_tdesc %dst, %offsets {mode = vc} - : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - - // CHECK: xegpu.load - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} - // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> - %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} - : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> - // CHECK: xegpu.store - // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} - // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> - xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached} - : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> - return -} - - // CHECK-LABEL: func @test_store_scatter({{.*}}) { func.func @test_store_scatter(%src: ui64, %offsets : index, %dst: ui64) { %0 = arith.constant 1: i1 diff --git a/test/Dialect/XeGPU/IR/store_scatter_vc.mlir b/test/Dialect/XeGPU/IR/store_scatter_vc.mlir new file mode 100644 index 000000000..e8650efb0 --- /dev/null +++ b/test/Dialect/XeGPU/IR/store_scatter_vc.mlir @@ -0,0 +1,33 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed. 
+// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +// CHECK-LABEL: func @test_store_scatter_vc({{.*}}) { +func.func @test_store_scatter_vc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) { + %0 = arith.constant dense<1>: vector<16xi1> + // CHECK: xegpu.create_tdesc + // CHECK-SAME: {mode = vc, chunk_size_per_lane = 1} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %1 = xegpu.create_tdesc %src, %offsets {mode = vc} + : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + + // CHECK: xegpu.create_tdesc + // CHECK-SAME: {mode = vc, chunk_size_per_lane = 1} + // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + %2 = xegpu.create_tdesc %dst, %offsets {mode = vc} + : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + + // CHECK: xegpu.load + // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> + %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} + : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> + // CHECK: xegpu.store + // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} + // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> + xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached} + : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> + return +} diff --git a/test/Dialect/XeGPU/IR/update_offset.mlir b/test/Dialect/XeGPU/IR/update_offset_vc.mlir similarity index 56% rename from test/Dialect/XeGPU/IR/update_offset.mlir rename to test/Dialect/XeGPU/IR/update_offset_vc.mlir index 416fd477a..812bbace2 100644 --- a/test/Dialect/XeGPU/IR/update_offset.mlir +++ b/test/Dialect/XeGPU/IR/update_offset_vc.mlir @@ -4,7 +4,6 @@ // Verify the generic form can be parsed. 
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s - // CHECK-LABEL: func @test_update_offset_VC({{.*}}) { func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) { %0 = arith.constant dense<1>: vector<16xi1> @@ -30,30 +29,3 @@ func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) { return } - -// SIMT test code -// CHECK-LABEL: func @test_update_offset({{.*}}) { -func.func @test_update_offset(%src: ui64, %offsets : index) { - %0 = arith.constant dense<1>: vector<8xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = simt, chunk_size_per_lane = 8} - // CHECK-SAME: ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered> - %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 8} - : ui64, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered> - - // CHECK: xegpu.load - // CHECK-SAME: {mode = simt, l1_hint = cached, l2_hint = uncached} - // CHECK-SAME: !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32> - %2 = xegpu.load %1, %0 {l1_hint = cached, l2_hint = uncached} - : !xegpu.tensor_desc<8xf32, #xegpu.scattered>, vector<8xi1> -> vector<8xf32> - - %3 = arith.constant 16: index - %4 = arith.addi %offsets, %3: index - - // CHECK: xegpu.update_offset - // CHECK-SAME: !xegpu.tensor_desc<8xf32, #xegpu.scattered>, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered> - %5 = xegpu.update_offset %1, %4 - : !xegpu.tensor_desc<8xf32, #xegpu.scattered>, index -> !xegpu.tensor_desc<8xf32, #xegpu.scattered> - - return -} diff --git a/test/Integration/Dialect/XeGPU/load2d-padding.mlir b/test/Integration/Dialect/XeGPU/load2d-padding.mlir index c86282770..24c3f7797 100644 --- a/test/Integration/Dialect/XeGPU/load2d-padding.mlir +++ b/test/Integration/Dialect/XeGPU/load2d-padding.mlir @@ -10,24 +10,6 @@ module @gemm attributes {gpu.container_module} { // memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<1.0> - memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<[ -[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], -[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]]> - func.func @test(%arg0: memref<8x16xf16>,%arg3:index) -> memref<8x16xf16> attributes 
{llvm.emit_c_interface} { %c1 = arith.constant 1 : index %memref = gpu.alloc host_shared () : memref<8x16xf16> @@ -52,8 +34,6 @@ module @gemm attributes {gpu.container_module} { } func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16> - %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16> - %f32identity = memref.get_global @__constant_8x16xf16 : memref<8x16xf16> %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index