Prefetch 2D: encode dummy data type as part of intrinsic name as the last suffix (#949)

Revert "LSC 2D prefetch: Unify type of dummy data arg to i32 (#940)"
silee2 authored Oct 31, 2024
1 parent d6876bd commit 3b8c2af
Showing 2 changed files with 24 additions and 32 deletions.
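For illustration, a minimal standalone sketch of the naming scheme this commit introduces, where the dummy data type is appended as the last suffix of the prefetch intrinsic name, so f16 and f32 descriptors resolve to distinct symbols. The helper buildPrefetchIntrinsicName and the plain std::string formatting are illustrative assumptions; the patch itself uses getBlockIntrinsicStr and llvm::formatv.

#include <iostream>
#include <string>

// Post-change scheme: keep the cache-control vector type (vNi8, N = number of
// cache levels) and append the element type, e.g. ...desc.v2i8.f16.
static std::string buildPrefetchIntrinsicName(int cacheLevels, bool isInteger,
                                              bool isBF16, unsigned bitWidth) {
  std::string prefix = isInteger ? "i" : (isBF16 ? "bf" : "f");
  return "llvm.genx.lsc.prefetch.2d.ugm.desc.v" + std::to_string(cacheLevels) +
         "i8." + prefix + std::to_string(bitWidth);
}

int main() {
  std::cout << buildPrefetchIntrinsicName(2, false, false, 16) << "\n";
  // llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16
  std::cout << buildPrefetchIntrinsicName(2, false, false, 32) << "\n";
  // llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f32
  return 0;
}

With the previous #940 scheme, both calls above would have produced the same name, llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8, which is why the dummy argument had to be forced to i32.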
40 changes: 14 additions & 26 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -199,7 +199,7 @@ static std::string getLSCIntrinsicStr(llvm::StringRef opName, int simd_lanes,
// lsc.load/store/prefetch.2d.ugm. The fullname is in format of
// 1. lsc.load.2d.ugm.desc.<transform>.<retType>.<cache_controls>
// 2. lsc.store.2d.ugm.desc.<cacheCtrType>.<dataType>
- // 3. lsc.prefetch.2d.ugm.desc.<predType>
+ // 3. lsc.prefetch.2d.ugm.desc.<predType>.<dataType>
// All the types are encoded as vN[i/f]M, where N is the number of elements,
// and M is the bit width. So for vector<16xf32>, it will be v16f32, and for
// vector<16xi1>, it will be v16i1. cacheCtrType is fixed to vNi8, where N is
@@ -228,8 +228,8 @@ static std::string getBlockIntrinsicStr(llvm::StringRef opName,
cache_levels, dataTyStr)
.str();
} else if (opName == "prefetch") {
- return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8",
-                      cache_levels)
+ return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8.{1}",
+                      cache_levels, dataTyStr)
.str();
}
llvm_unreachable("unsupported opName");
@@ -675,34 +675,22 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc,
assert(tdescTy.getRank() == 2 && !tdescTy.isScattered() &&
"Only works on 2D block TensorDesc.");

- auto intrinsicStr = getBlockIntrinsicStr("prefetch");
auto nblks = tdescTy.getArrayLength();
auto shape = tdescTy.getShape();
+ auto elemTy = tdescTy.getElementType();
auto noRetTy = TypeRange({});
- auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-
- // Sub 32bit data types are packed into 32bit data types (i32).
- auto packFactor = 32 / bitWidth;
-
- // If packing is needed, the innermost dimensions gets scaled by the packing
- // factor. In such case, the shape[1] must be a multiple of the pack factor.
- // Otherwise, packing cannot be done correctly
- if (packFactor > 1) {
-   assert(
-       shape[1] % packFactor == 0 &&
-       "shape[1] must be a multiple of pack factor (32 / element bitwidth)");
- }
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto prefix = elemTy.isInteger() ? "i" : elemTy.isBF16() ? "bf" : "f";
+ auto typeStr = llvm::formatv("{0}{1}", prefix, bitWidth).str();
+ auto intrinsicStr = getBlockIntrinsicStr("prefetch", typeStr);

- // for arg8: dummy value, type has to be always the same since intrinsic
- // func name for prefetch is the same regardless of the element type.
- // Different type used for dummy causes type conflict in case of multiple
- // calls with different dummy arg type.
- auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
+ // for arg8: dummy value
+ auto attr = elemTy.isInteger()
+             ? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
+             : (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
auto dummy = constant_val(attr);
- return gen2DBlockIntrinsicCall(
-     rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks,
-     {shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor},
-     payload, dummy);
+ return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3,
+                                nblks, shape, payload, dummy);
}

// generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is
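For context, a minimal sketch (not repository code) of the dummy-argument rule the updated gen2DPrefetchIntrinsicCall follows: the placeholder operand is a zero constant of the descriptor's element type, so its type always matches the data-type suffix now encoded in the intrinsic name. makeDummyZero and the standalone main are assumptions for illustration; only the isInteger() check and the getIntegerAttr/getFloatAttr calls mirror the patch.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Integer elements get a zero IntegerAttr, floating-point elements
// (f16/bf16/f32/...) get a zero FloatAttr, both typed with elemTy itself.
static mlir::TypedAttr makeDummyZero(mlir::OpBuilder &b, mlir::Type elemTy) {
  return elemTy.isInteger()
             ? (mlir::TypedAttr)b.getIntegerAttr(elemTy, 0)
             : (mlir::TypedAttr)b.getFloatAttr(elemTy, 0.0);
}

int main() {
  mlir::MLIRContext ctx;
  mlir::OpBuilder b(&ctx);
  makeDummyZero(b, b.getF16Type()).dump(); // 0.000000e+00 : f16
  makeDummyZero(b, b.getF32Type()).dump(); // 0.000000e+00 : f32
  makeDummyZero(b, b.getI32Type()).dump(); // 0 : i32
  return 0;
}

Because each element type now maps to its own intrinsic symbol (...v2i8.f16, ...v2i8.f32, and so on), two prefetches with different element types in one module no longer declare conflicting signatures for the same function, which was the failure mode the i32 dummy from #940 worked around.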
16 changes: 10 additions & 6 deletions test/Conversion/XeGPUToVC/prefetchnd.mlir
@@ -4,6 +4,7 @@
module @gemm attributes {gpu.container_module} {

gpu.module @test_kernel {
+
gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index
@@ -50,14 +51,15 @@ module @gemm attributes {gpu.container_module} {
//CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

+ //LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
- //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+ //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16>

@@ -94,14 +96,16 @@ module @two_type attributes {gpu.container_module} {
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

//LSC: %[[c0_i32:.*]] = arith.constant 0 : i32
+ //LSC: %[[c0_f16:.*]] = arith.constant 0.000000e+00 : f16
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
- //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+ //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_f16]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+ //LSC: %[[c0_f32:.*]] = arith.constant 0.000000e+00 : f32
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f32(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_f32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f32) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32>

