[XeTileToXeGPU] Add preop support for Load+Transpose optimization (#947)
charithaintc authored Oct 30, 2024
1 parent 07920ba commit 362d432
Showing 3 changed files with 208 additions and 4 deletions.
3 changes: 3 additions & 0 deletions include/imex/Utils/XeCommon.h
@@ -17,6 +17,7 @@
#define _IMEX_XECOMMON_H_

#include "imex/Dialect/XeTile/IR/XeTileOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/Dialect/GPU/IR/GPUDialect.h>
@@ -120,6 +121,8 @@ class TileUsageAnalysis {
// info is needed for downstream optimizations.
transposeBeforeDPAS = true;
q.push_back(transpose);
} else if (mlir::OpTrait::hasElementwiseMappableTraits(user)) {
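// An elementwise (preop) op, e.g. math.exp or arith.addf, sitting between the
// transpose and the MMA: keep following its result so the load+transpose chain
// is still tracked by the analysis.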
q.push_back(user->getResult(0));
}
}
}
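With this change the tile-usage analysis also walks through elementwise (preop) ops between the transpose and the MMA. A minimal sketch of the IR shape it now recognizes, mirroring the new test below (the math.exp preop and the 32x32 shapes are just one example):

%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
%b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16>
%preop = math.exp %b_transpose : vector<32x32xf16>
%c_new_value = xetile.tile_mma %a_value, %preop, %c_value : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32>

The result of %preop is followed just like the transpose result, so the tests below still check that array_length is 1 on the B operand's descriptor for this case.
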
52 changes: 48 additions & 4 deletions test/Conversion/XeTileToXeGPU/sg_gemm_transpose_b.mlir
@@ -25,10 +25,10 @@ gpu.module @test_kernel {
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
-> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) {
%a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
// Check if array_length is 1 for the load + transpose + MMA B case.
//
// xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
// xegpu.load_nd %[[ARG6]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
// Check if array_length is 1 for the load + transpose + MMA B case.
//
// CHECK: xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
// CHECK: xegpu.load_nd %[[ARG6]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
%b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16>
%c_new_value = xetile.tile_mma %a_value, %b_transpose, %c_value : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
@@ -41,3 +41,47 @@ gpu.module @test_kernel {
gpu.return
}
}

// -----
gpu.module @test_kernel {
// CHECK-LABEL: gpu.func @test_gemm_preop(
// CHECK-SAME: %[[A:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>, %[[C:.*]]: memref<1024x1024xf32>)
gpu.func @test_gemm_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {

%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index

%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y

%m = arith.muli %block_id_x, %c32 : index
%n = arith.muli %block_id_y, %c32 : index
%c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32>
%c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32>
%a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// CHECK: scf.for %{{.*}}= %{{.*}}to %{{.*}}step %{{.*}}iter_args(%{{.*}}= %{{.*}}, %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[T2]], %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}} = %{{.*}}) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
%out:3 = scf.for %k = %c0 to %c1024 step %c32
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
-> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) {
%a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
// Check if array_length is 1 for the load + transpose + preop + MMA B case.
//
// CHECK: xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
// CHECK: xegpu.load_nd %[[ARG6]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
%b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16>
%preop = math.exp %b_transpose : vector<32x32xf16>
%c_new_value = xetile.tile_mma %a_value, %preop, %c_value : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
%a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<32x32xf16>
%b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x32xf16>
scf.yield %a_next_tile, %b_next_tile, %c_new_value
: !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>
}
xetile.store_tile %out#2, %c_init_tile: vector<32x32xf32>, !xetile.tile<32x32xf32>
gpu.return
}
}
157 changes: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc-optimize-transpose.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc-optimize-transpose.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck

// NOTES:
// This example assumes one subgroup per workgroup; the kernel specifies the computation
// done by a single subgroup.

module @gemm attributes {gpu.container_module} {
func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c512 = arith.constant 512 : index
%A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16>
memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16>
%B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16>
memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16>
%C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32>
memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32>
gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>)
gpu.dealloc %A_gpu : memref<1024x1024xf16>
gpu.dealloc %B_gpu : memref<1024x1024xf16>
return %C_gpu : memref<1024x1024xf32>
}
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y
%m = arith.muli %block_id_x, %c16 : index
%n = arith.muli %block_id_y, %c32 : index
// initialize C tile and load it
%c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32>
%c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32>
// initialize A and B tiles
%a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16>
%b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// compute the value of the C tile by iterating over tiles in the k-dimension and doing dpas
%out:3 = scf.for %k = %c0 to %c1024 step %c32
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
-> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) {

// load A and B tiles
%a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16>
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
%b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16>
%preop = arith.addf %b_trans, %b_trans : vector<32x32xf16>
// perform dpas and accumulate
%c_new_value = xetile.tile_mma %a_value, %preop, %c_value
: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
// update the offsets for A and B tiles
%a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32]
: !xetile.tile<16x32xf16>
%b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32]
: !xetile.tile<32x32xf16>
// partial C tile result
scf.yield %a_next_tile, %b_next_tile, %c_new_value
: !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>
}
// store the final accumulated C tile result back to memory
xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32>
gpu.return
}
}
func.func @main() attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1024 = arith.constant 1024 : index
%cf_0 = arith.constant 0.0 : f16
%cf_1 = arith.constant 1.0 : f16
%A = memref.alloc() : memref<1024x1024xf16>
%B = memref.alloc() : memref<1024x1024xf16>
%C = memref.alloc() : memref<1024x1024xf32>
%C_ref = memref.alloc() : memref<1024x1024xf32>
// initialize matrix B; B[i, j] = j
scf.for %i = %c0 to %c1024 step %c1 {
scf.for %j = %c0 to %c1024 step %c1 {
%t = index.castu %j : index to i16
%val = arith.uitofp %t : i16 to f16
memref.store %val, %B[%i, %j] : memref<1024x1024xf16>
}
}
// make matrix A an identity matrix
scf.for %i = %c0 to %c1024 step %c1 {
scf.for %j = %c0 to %c1024 step %c1 {
%i_i32 = index.castu %i : index to i32
%j_i32 = index.castu %j : index to i32
%i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32

scf.if %i_j_same {
memref.store %cf_1, %A[%i, %j] : memref<1024x1024xf16>
} else {
memref.store %cf_0, %A[%i, %j] : memref<1024x1024xf16>
}
}
}
// initialize matrix C and C_ref; C[i, j] = 0
%c0_f32 = arith.constant 0.0 : f32
scf.for %i = %c0 to %c1024 step %c1 {
scf.for %j = %c0 to %c1024 step %c1 {
memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32>
memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32>
}
}
// compute the reference result: C[i, j] = sum_k A[i, k] * (B[j, k] + B[j, k])
scf.for %i = %c0 to %c1024 step %c1 {
scf.for %j = %c0 to %c1024 step %c1 {
%c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32>
%c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 {
%a_val = memref.load %A[%i, %k] : memref<1024x1024xf16>
%b_val = memref.load %B[%j, %k] : memref<1024x1024xf16>
%preop = arith.addf %b_val, %b_val : f16
%t = arith.mulf %a_val, %preop : f16
%t_cast = arith.extf %t : f16 to f32
%c_sum = arith.addf %t_cast, %c_partial : f32
scf.yield %c_sum : f32
}
memref.store %c_val, %C_ref[%i, %j] : memref<1024x1024xf32>
}
}
%2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32>
// %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16>
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
%cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32>
%cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32>
// call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
// call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()
// %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32>
// %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32>
// call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> ()
// CHECK: [ALLCLOSE: TRUE]
call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
memref.dealloc %A : memref<1024x1024xf16>
memref.dealloc %B : memref<1024x1024xf16>
memref.dealloc %C : memref<1024x1024xf32>
memref.dealloc %C_ref : memref<1024x1024xf32>
return
}
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
