LSC pattern: Handle low precision 1D load/store/prefetch. (#938)
silee2 authored Oct 23, 2024
1 parent d444d47 commit e40f1ea
Showing 3 changed files with 201 additions and 26 deletions.
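
In short, the patch teaches the 1D LSC load/store/prefetch lowering to handle 8- and 16-bit element types: the new getScaled1DTdesc helper re-expresses such a descriptor over i32 or i64 lanes so the block intrinsic sees a supported payload width, and the lowered load/store bitcasts between the original vector type and the scaled one. Below is a minimal standalone sketch of that byte-count arithmetic (scaledLaneBits and the driver are illustrative names, not part of the patch):

#include <cstdint>
#include <cstdio>

// Sketch of the totalBytes -> lane-width mapping used by getScaled1DTdesc:
// returns the scaled lane width in bits, or elemBits if nothing changes.
static unsigned scaledLaneBits(unsigned elemBits, long long numElems) {
  if (elemBits >= 32 || elemBits < 8)   // only 8/16-bit elements are rescaled
    return elemBits;
  long long totalBytes = numElems * elemBits / 8;
  switch (totalBytes) {
  case 4: case 8: case 12: case 16:
  case 32: case 64: case 128: case 256:
    return 32;                          // pack into i32 lanes
  case 24: case 512:
    return 64;                          // pack into i64 lanes
  default:
    return elemBits;                    // unsupported sizes are left as-is
  }
}

int main() {
  // 32 x bf16 = 64 bytes -> 16 x i32, matching the v16i32 intrinsic in the new test.
  unsigned bits = scaledLaneBits(/*elemBits=*/16, /*numElems=*/32);
  std::printf("lane bits: %u, lanes: %lld\n", bits, 32LL * 16 / bits);
  return 0;
}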
114 changes: 88 additions & 26 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -722,6 +722,63 @@ static func::CallOp gen2DStoreIntrinsicCall(
nblks, shape, payload, data);
}

auto get1DTdescNumTotalElems = [](TensorDescType tdescTy) -> int64_t {
return tdescTy.getNumElements() * tdescTy.getArrayLength();
};

auto getElemBitWidth = [](TensorDescType tdescTy) -> unsigned {
return tdescTy.getElementType().getIntOrFloatBitWidth();
};

auto isLowPrecision = [](TensorDescType tdescTy) -> bool {
// Note: Handling for sub 8bit types is unclear so report as false
auto width = getElemBitWidth(tdescTy);
return width < 32 && width >= 8;
};

auto getScaled1DTdesc =
[](TensorDescType tdescTy,
ConversionPatternRewriter &rewriter) -> TensorDescType {
// return if not 1D tensor desc
if (tdescTy.getShape().size() != 1)
return tdescTy;
// return if not low precision
if (!isLowPrecision(tdescTy))
return tdescTy;

auto scaledTy = tdescTy.getElementType();
auto totalBytes =
get1DTdescNumTotalElems(tdescTy) * getElemBitWidth(tdescTy) / 8;
switch (totalBytes) {
// i32 for 4, 8, 12, 16, 32, 64, 128, 256
// i64 for 24 and 512
case 4:
case 8:
case 12:
case 16:
case 32:
case 64:
case 128:
case 256:
scaledTy = rewriter.getI32Type();
break;
case 24:
case 512:
scaledTy = rewriter.getI64Type();
break;
default:
break;
}
return TensorDescType::get(
tdescTy.getContext(),
{totalBytes / (scaledTy.getIntOrFloatBitWidth() / 8)}, scaledTy,
tdescTy.getEncoding(), /*sg_map*/ nullptr);
};

auto isScaled = [](TensorDescType tdescTy, TensorDescType scaledTy) -> bool {
return getElemBitWidth(tdescTy) != getElemBitWidth(scaledTy);
};

#define shrui(...) rewriter.createOrFold<arith::ShRUIOp>(loc, __VA_ARGS__)
class LoadNdPattern : public OpConversionPattern<LoadNdOp> {
using OpConversionPattern<LoadNdOp>::OpConversionPattern;
@@ -759,16 +816,25 @@ class LoadNdPattern : public OpConversionPattern<LoadNdOp> {
return rewriter.notifyMatchFailure(
op, "transpose is not supported for slm and 1D tensor desc");

auto elems = tdescTy.getNumElements() * tdescTy.getArrayLength();
auto scaledTdescTy = getScaled1DTdesc(tdescTy, rewriter);
auto scaledElems = get1DTdescNumTotalElems(scaledTdescTy);
auto scaledElemTy = scaledTdescTy.getElementType();

if (failed(isValid1DBlockSetup(elemTy, elems, loc, rewriter))) {
if (failed(
isValid1DBlockSetup(scaledElemTy, scaledElems, loc, rewriter))) {
return rewriter.notifyMatchFailure(
loc, "unsupported 1D/SLM TensorDescType.");
}

bool scaled = isScaled(tdescTy, scaledTdescTy);
auto resTy =
scaled ? VectorType::get({scaledElems}, scaledElemTy) : op.getType();
auto newValue = gen1DLoadInstrinsicCall(
rewriter, loc, op.getType(), l1hint, l3hint, elemTy, elems,
rewriter, loc, resTy, l1hint, l3hint, scaledElemTy, scaledElems,
tdescTy.getMemorySpace(), adaptor.getTensorDesc());
if (scaled) {
newValue =
rewriter.create<vector::BitCastOp>(loc, op.getType(), newValue);
}
rewriter.replaceOp(op, newValue);
return success();
} else if (rank == 2) { // 2d.ugm.desc
@@ -866,7 +932,6 @@ class PrefetchNdPattern : public OpConversionPattern<PrefetchNdOp> {

auto loc = op.getLoc();
auto tdescTy = op.getTensorDescType();
auto elemTy = tdescTy.getElementType();
auto rank = tdescTy.getRank();
auto scope = tdescTy.getMemorySpace();

@@ -881,15 +946,17 @@ class PrefetchNdPattern : public OpConversionPattern<PrefetchNdOp> {
return success();
}

auto elems = tdescTy.getNumElements() * tdescTy.getArrayLength();
auto scaledTdescTy = getScaled1DTdesc(tdescTy, rewriter);
auto scaledElems = get1DTdescNumTotalElems(scaledTdescTy);
auto scaledElemTy = scaledTdescTy.getElementType();

if (failed(isValid1DBlockSetup(elemTy, elems, loc, rewriter)))
if (failed(isValid1DBlockSetup(scaledElemTy, scaledElems, loc, rewriter)))
return rewriter.notifyMatchFailure(
loc, "unsupported 1D/SLM TensorDescType.");

auto callOp =
gen1DPrefetchIntrinsicCall(rewriter, loc, l1hint, l3hint, elemTy,
elems, scope, adaptor.getTensorDesc());
auto callOp = gen1DPrefetchIntrinsicCall(rewriter, loc, l1hint, l3hint,
scaledElemTy, scaledElems, scope,
adaptor.getTensorDesc());
rewriter.replaceOp(op, callOp);
return success();
} else if (rank == 2) { // 2d.ugm.desc
@@ -914,7 +981,6 @@ class StoreNdPattern : public OpConversionPattern<StoreNdOp> {

auto loc = op.getLoc();
auto tdescTy = op.getTensorDescType();
auto elemTy = tdescTy.getElementType();
auto rank = tdescTy.getRank();
auto scope = tdescTy.getMemorySpace();

@@ -925,25 +991,21 @@ class StoreNdPattern : public OpConversionPattern<StoreNdOp> {
auto data = adaptor.getValue();

if (rank == 1) {
// for slm and 1D tensor desc, use lsc.store,
// all non 32-bit data has to be encoded as i32.

// get intrinsic name, the data type has to be encoded
// as vNi32 for 8-bit/16-bit data in regular store.
// for example, Vector<8x16xf16> should be encoded as V128I32.
auto lscTy = getOrigOrI32VectorType(op.getValueType());
auto typeStr = convertVectorType(lscTy).first;
auto intrinsicStr = getLSCIntrinsicStr("store", 1, scope, typeStr);

auto elems = tdescTy.getNumElements();
auto scaledTdescTy = getScaled1DTdesc(tdescTy, rewriter);
auto scaledElems = get1DTdescNumTotalElems(scaledTdescTy);
auto scaledElemTy = scaledTdescTy.getElementType();

if (failed(isValid1DBlockSetup(elemTy, elems, loc, rewriter)))
if (failed(isValid1DBlockSetup(scaledElemTy, scaledElems, loc, rewriter)))
return rewriter.notifyMatchFailure(
loc, "unsupported 1D/SLM TensorDescType.");

auto callOp =
gen1DStoreInstrinsicCall(rewriter, loc, l1hint, l3hint, elemTy, elems,
scope, adaptor.getTensorDesc(), data);
if (isScaled(tdescTy, scaledTdescTy)) {
auto scaledVecTy = VectorType::get({scaledElems}, scaledElemTy);
data = rewriter.create<vector::BitCastOp>(loc, scaledVecTy, data);
}
auto callOp = gen1DStoreInstrinsicCall(rewriter, loc, l1hint, l3hint,
scaledElemTy, scaledElems, scope,
adaptor.getTensorDesc(), data);

rewriter.replaceOp(op, callOp);
return success();
27 changes: 27 additions & 0 deletions test/Conversion/XeGPUToVC/load_store_prefetch_1D_bf16.mlir
@@ -0,0 +1,27 @@
// RUN: imex-opt -convert-xegpu-to-vc -cse %s | FileCheck %s

gpu.module @load_store_bf16 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Bfloat16ConversionINTEL, BFloat16TypeKHR, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorAnyINTEL, VectorComputeINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_bfloat16, SPV_KHR_expect_assume, SPV_INTEL_bfloat16_conversion, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @load_store_bf16(%arg0: memref<4x2x128xbf16>, %arg1: memref<4x2x128xbf16>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 4, 2, 4>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c32 = arith.constant 32 : index
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%thread_id_z = gpu.thread_id z
%0 = arith.muli %thread_id_z, %c32 : index
%1 = xegpu.create_nd_tdesc %arg0[%thread_id_x, %thread_id_y, %0], [4, 2, 128], [256, 128, 1] : memref<4x2x128xbf16> -> !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>

// CHECK: func.call @llvm.genx.lsc.prefetch.stateless.v1i1.v1i64
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, %thread_id_y, %0], [4, 2, 128], [256, 128, 1] : memref<4x2x128xbf16> -> !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>

// CHECK: %[[LOAD_VAL:.*]] = func.call @llvm.genx.lsc.load.stateless.v16i32.v1i1.v1i64
// CHECK: %[[REAL_VAL:.*]] = vector.bitcast %[[LOAD_VAL]] : vector<16xi32> to vector<32xbf16>
%3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32xbf16>
%4 = xegpu.create_nd_tdesc %arg1[%thread_id_x, %thread_id_y, %0], [4, 2, 128], [256, 128, 1] : memref<4x2x128xbf16> -> !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>

// CHECK: %[[STORE_VAL:.*]] = vector.bitcast %[[REAL_VAL]] : vector<32xbf16> to vector<16xi32>
// CHECK: func.call @llvm.genx.lsc.store.stateless.v1i1.v1i64.v16i32
// CHECK: %[[STORE_VAL]], %[[LAST_ARG:.*]]) :
xegpu.store_nd %3, %4 : vector<32xbf16>, !xegpu.tensor_desc<32xbf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
gpu.return
}
}
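
The CHECK lines above rely on vector.bitcast being a pure reinterpretation of the underlying bytes: vector<32xbf16> and vector<16xi32> describe the same 64 bytes, so loading as i32 lanes and bitcasting back is lossless. A small standalone C++ illustration of that round trip (plain arrays stand in for the vectors; not part of the test):

#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  // 32 bf16 lanes occupy 64 bytes; the same bytes reinterpret as 16 i32 lanes.
  uint16_t bf16_bits[32];
  for (int i = 0; i < 32; ++i)
    bf16_bits[i] = static_cast<uint16_t>(0x3F80 + i);   // arbitrary bf16 bit patterns

  uint32_t i32_lanes[16];
  std::memcpy(i32_lanes, bf16_bits, sizeof(bf16_bits)); // vector<32xbf16> -> vector<16xi32>

  uint16_t roundtrip[32];
  std::memcpy(roundtrip, i32_lanes, sizeof(i32_lanes)); // vector<16xi32> -> vector<32xbf16>

  std::printf("lossless: %s\n",
              std::memcmp(bf16_bits, roundtrip, sizeof(bf16_bits)) == 0 ? "true" : "false");
  return 0;
}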
86 changes: 86 additions & 0 deletions test/Integration/Dialect/XeGPU/load_store_with_1d_bf16_tile.mlir
@@ -0,0 +1,86 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
memref.global "private" constant @__constant_8x32xbf16 : memref<8x32xbf16> = dense<0.0>
func.func @test(%arg0: memref<8x32xbf16>, %arg1: memref<8x32xbf16>) -> memref<8x32xbf16> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index

%memref = gpu.alloc host_shared () : memref<8x32xbf16>
memref.copy %arg0, %memref : memref<8x32xbf16> to memref<8x32xbf16>
%memref_1 = gpu.alloc host_shared () : memref<8x32xbf16>
memref.copy %arg1, %memref_1 : memref<8x32xbf16> to memref<8x32xbf16>
%memref_2 = gpu.alloc host_shared () : memref<8x32xbf16>
gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x32xbf16>, %memref_1 : memref<8x32xbf16>, %memref_2 : memref<8x32xbf16>)
gpu.dealloc %memref : memref<8x32xbf16>
gpu.dealloc %memref_1 : memref<8x32xbf16>
return %memref_2 : memref<8x32xbf16>
}
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%arg0: memref<8x32xbf16>, %arg1: memref<8x32xbf16>, %arg2: memref<8x32xbf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%thread_id_x = gpu.thread_id x
cf.br ^bb1
^bb1:
%0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32xbf16> -> vector<32xbf16>
%2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16>
%3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32xbf16> -> vector<32xbf16>
%4 = arith.addf %3, %1 : vector<32xbf16>
%5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16>
xegpu.store_nd %4, %5 : vector<32xbf16>, !xegpu.tensor_desc<32xbf16>
gpu.return
}
}
func.func @main() attributes {llvm.emit_c_interface} {
%c_gen_int = arith.constant 0 : i1
%cf_lower = arith.constant -0.5 : f32
%cf_upper = arith.constant 0.5 : f32

%A = memref.alloc() : memref<8x32xbf16>
%A_random = memref.cast %A : memref<8x32xbf16> to memref<*xbf16>
call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> ()

%B = memref.alloc() : memref<8x32xbf16>
%B_random = memref.cast %B : memref<8x32xbf16> to memref<*xbf16>
call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> ()

// calculate the result C matrix
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%ref = memref.alloc() : memref<8x32xf32>
scf.for %i = %c0 to %c8 step %c1 {
scf.for %j = %c0 to %c32 step %c1 {
%a = memref.load %A[%i, %j] : memref<8x32xbf16>
%b = memref.load %B[%i, %j] : memref<8x32xbf16>
%a_ext = arith.extf %a : bf16 to f32
%b_ext = arith.extf %b : bf16 to f32
%c = arith.addf %a_ext, %b_ext : f32
%c_trunc = arith.truncf %c : f32 to bf16
%c_ext = arith.extf %c_trunc : bf16 to f32
memref.store %c_ext, %ref[%i, %j] : memref<8x32xf32>
}
}

%C = call @test(%A, %B) : (memref<8x32xbf16>, memref<8x32xbf16>) -> memref<8x32xbf16>

%C_cast = memref.cast %C : memref<8x32xbf16> to memref<*xbf16>
%ref_cast = memref.cast %ref : memref<8x32xf32> to memref<*xf32>
//call @printMemrefBF16(%C_cast) : (memref<*xbf16>) -> ()
//call @printMemrefF32(%ref_cast) : (memref<*xf32>) -> ()
// CHECK: [ALLCLOSE: TRUE]
call @printAllcloseBF16(%C_cast, %ref_cast) : (memref<*xbf16>, memref<*xf32>) -> ()
return
}
func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
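
The reference loop in @main computes the expected result on the host: each bf16 input is extended to f32, added, truncated back to bf16, and re-extended, so the comparison happens at bf16 precision. A hypothetical standalone sketch of those conversions (bf16 is the upper 16 bits of an IEEE-754 f32; the truncation below uses round-to-nearest-even and ignores NaN handling):

#include <cstdint>
#include <cstring>
#include <cstdio>

// bf16 is the upper 16 bits of an IEEE-754 f32, so extension is a shift.
static float bf16_to_f32(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Truncate f32 to bf16 with round-to-nearest-even (NaNs not handled).
static uint16_t f32_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t bias = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<uint16_t>((bits + bias) >> 16);
}

int main() {
  // Mirrors one iteration of the reference loop: extf, addf, truncf, extf.
  uint16_t a = f32_to_bf16(0.3f), b = f32_to_bf16(0.2f);
  float ref = bf16_to_f32(f32_to_bf16(bf16_to_f32(a) + bf16_to_f32(b)));
  std::printf("reference element: %f\n", ref);
  return 0;
}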
