add XeGPU createDesc/load2d/store2d/dpas to spirv genISA lowering
add a pass option to differentiate between VC intrinsics and genISA intrinsics
lower load2d/store2d/dpas to the corresponding genISA intrinsics
Dewei-Wang-sh authored and silee2 committed Nov 20, 2023
1 parent f58d7ce commit 05e7fbc
Showing 11 changed files with 368 additions and 36 deletions.
4 changes: 2 additions & 2 deletions include/imex/Conversion/Passes.td
@@ -251,8 +251,8 @@ memref, arith and math.
let constructor = "imex::createConvertGPUXToSPIRVPass()";
let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
let options = [
Option<"enableSimtIntrinsic", "enable-simt-intrinsic","bool", "false",
"Enable XeGPU.simt Ops lowered to intel genISA simt Intrinsics">
Option<"enableVCIntrinsic", "enable-vc-intrinsic","bool", "true",
"Enable XeGPU Ops lowered to intel vc Intrinsics">
];
}
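The option defaults to true, so existing pipelines keep the VC-intrinsic lowering and the genISA path is opt-in. A minimal invocation of the genISA path (mirroring the RUN line added to gemm_basic.mlir below; the input file name is illustrative):

  imex-opt -imex-convert-gpu-to-spirv='enable-vc-intrinsic=false' input.mlir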

3 changes: 3 additions & 0 deletions include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h
@@ -27,6 +27,9 @@ namespace imex {
// XeGPU to VC Intrinsics pattern
void populateXeGPUToVCIntrinsicsPatterns(
mlir::SPIRVTypeConverter &typeConverter, mlir::RewritePatternSet &patterns);
// XeGPU to genISA Intrinsics pattern
void populateXeGPUToGenISAPatterns(mlir::SPIRVTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns);
} // namespace imex

#endif // IMEX_CONVERSION_XEGPUTOSPIRV_H
5 changes: 4 additions & 1 deletion lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -198,7 +198,10 @@ void GPUXToSPIRVPass::runOnOperation() {
mlir::populateSCFToSPIRVPatterns(typeConverter, scfToSpirvCtx, patterns);
mlir::cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns);
if (this->enableVCIntrinsic)
imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns);
else
imex::populateXeGPUToGenISAPatterns(typeConverter, patterns);

if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
return signalPassFailure();
278 changes: 276 additions & 2 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
@@ -200,7 +200,8 @@ unsigned encodeOpcode(xegpu::AtomicRMWKind kind) {
}

void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
std::string name, FunctionType funcType) {
std::string name, FunctionType funcType,
bool isVC = true) {
auto funcAttr = StringAttr::get(rewriter.getContext(), name);
Operation *found = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
if (!found) {
@@ -215,7 +216,8 @@ void lookupOrInsertIntrinsic(ConversionPatternRewriter &rewriter, Operation *op,
auto linkage = spirv::LinkageAttributesAttr::get(rewriter.getContext(),
nameAttr, linkageTypeAttr);
func.setLinkageAttributesAttr(linkage);
func->setAttr("VectorComputeFunctionINTEL", rewriter.getUnitAttr());
if (isVC)
func->setAttr("VectorComputeFunctionINTEL", rewriter.getUnitAttr());
}
}
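With the new isVC flag, a caller can import an intrinsic declaration without tagging it VectorComputeFunctionINTEL, which only the VC path requires. A hypothetical call site opting out of the attribute would look like:

  // Plain import of a genISA symbol; no VC-specific function attribute is set.
  lookupOrInsertIntrinsic(rewriter, op, funcName, funcType, /*isVC=*/false);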

@@ -1246,3 +1248,275 @@ void imex::populateXeGPUToVCIntrinsicsPatterns(
LoadStorePrefetchNdToRawSend<PrefetchNDOp>>(
typeConverter, patterns.getContext());
}

/// Below is the lowering from XeGPU to SPIR-V genISA intrinsics.

/// @brief e.g. encodeGenISAVectorType(rewriter, 8x8x2xf16) returns ["v64i32", vector<64xi32>];
/// with use32bitData = false it returns ["v128i16", vector<128xi16>].
std::pair<std::string, VectorType>
encodeGenISAVectorType(ConversionPatternRewriter &rewriter, VectorType type,
bool use32bitData = true) {
auto elemType = type.getElementType();
auto bitWidth = elemType.getIntOrFloatBitWidth();
int size = type.getNumElements() * bitWidth / 16;
if (use32bitData) {
size /= 2;
}
std::string str = "v";
str += std::to_string(size);
if (!use32bitData) {
str += "i16";
elemType = rewriter.getI16Type();
} else if (elemType == rewriter.getF32Type())
str += "f32";
else if (elemType == rewriter.getF16Type()) {
str += "i32";
elemType = rewriter.getI32Type();
} else
assert(0 && "add more support");
auto newType = VectorType::get(size, elemType);
return std::make_pair(str, newType);
}
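A worked example of the encoding, as a sketch (vecTy assumed to be the vector<8x8x2xf16> from the doc comment, i.e. 128 f16 elements with bitWidth 16, so size = 128 * 16 / 16 = 128):

  // use32bitData (the default) halves the lane count and widens f16 to i32.
  auto [str32, ty32] = encodeGenISAVectorType(rewriter, vecTy);        // "v64i32",  vector<64xi32>
  // With use32bitData = false the 128 lanes stay and the element narrows to i16.
  auto [str16, ty16] = encodeGenISAVectorType(rewriter, vecTy, false); // "v128i16", vector<128xi16>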

class CreateNdDescToGenISA : public OpConversionPattern<CreateNdDescOp> {
public:
using OpConversionPattern<CreateNdDescOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(CreateNdDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
auto v4i32 = VectorType::get(4, i32Type);
auto v2i64 = VectorType::get(2, i64Type);
Value payLoad = rewriter.create<spirv::UndefOp>(loc, v2i64);
auto createIntConstant = [&](Type type, unsigned value) {
auto attr = rewriter.getIntegerAttr(type, value);
return rewriter.create<spirv::ConstantOp>(loc, type, attr);
};
auto base = rewriter.create<spirv::ConvertPtrToUOp>(loc, i64Type,
adaptor.getSource());
auto idx0 = createIntConstant(i32Type, 0);
payLoad =
rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad, base, idx0);
auto tileType = op.getTensorDesc().getType();
auto rank = tileType.getRank();
if (rank == 2) {
payLoad = rewriter.create<spirv::BitcastOp>(loc, v4i32, payLoad);
auto createOffset = [&](unsigned idx) -> Value {
Value val;
if (ShapedType::isDynamic(op.getStaticOffsets()[idx])) {
val = op.getOffsets()[idx];
val = rewriter.create<arith::TruncIOp>(loc, i32Type, val);
} else {
val = createIntConstant(i32Type, op.getStaticOffsets()[idx]);
}
return val;
};
auto offsetX = createOffset(1);
auto offsetY = createOffset(0);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetX, idx2);
payLoad = rewriter.create<spirv::VectorInsertDynamicOp>(loc, payLoad,
offsetY, idx3);
payLoad = rewriter.create<spirv::BitcastOp>(loc, v2i64, payLoad);
}
rewriter.replaceOp(op, payLoad);
return success();
}
};
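The descriptor built above is a 128-bit payload. Read through the v4i32 bitcast, its lanes are laid out as follows (inferred from the inserts in this pattern):

  // lanes 0-1: 64-bit base address (spirv.ConvertPtrToU of the source memref)
  // lane 2   : offsetX = offsets[1] (innermost dim), truncated to i32 if dynamic
  // lane 3   : offsetY = offsets[0] (outermost dim)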

template <typename OpType>
class LoadStorePrefetchNdToGenISA : public OpConversionPattern<OpType> {
public:
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tileType = op.getTensorDesc().getType();
int rank = tileType.getRank();
assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now");
auto loc = op.getLoc();
::mlir::VectorType vecType;
std::string funcName;
constexpr bool isLoad = std::is_same_v<OpType, LoadNDOp>;
constexpr bool isPrefetch = std::is_same_v<OpType, PrefetchNDOp>;
if constexpr (isLoad) {
vecType = cast<VectorType>(op.getResult().getType());
funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DBlockRead."
: "llvm.genx.GenISA.LSCLoadBlock.";
} else if constexpr (isPrefetch) {
vecType = VectorType::get({8, 16}, rewriter.getF32Type());
funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DPrefetch."
: "llvm.genx.GenISA.LSCPrefetch";
} else {
vecType = cast<VectorType>(op.getValue().getType());
funcName = rank == 2 ? "llvm.genx.GenISA.LSC2DBlockWrite."
: "llvm.genx.GenISA.LSCStoreBlock";
}
auto createIntConstant = [&](Type type, unsigned value) {
auto attr = rewriter.getIntegerAttr(type, value);
return rewriter.create<spirv::ConstantOp>(loc, type, attr);
};
auto i1Type = rewriter.getI1Type();
auto i8Type = rewriter.getI8Type();
auto i32Type = rewriter.getI32Type();
auto vnni = false;
auto transpose = false;
if constexpr (isLoad) {
auto vnniValue = op.getVnniAxis();
vnni = vnniValue.has_value() && vnniValue.value() == 0;
auto transposeValue = op.getTranspose();
transpose = transposeValue.has_value() && transposeValue.value()[0] == 1;
}
unsigned dataSize = vecType.getElementType().getIntOrFloatBitWidth();
auto elemSize = createIntConstant(i8Type, dataSize);
auto trans = createIntConstant(i1Type, transpose ? 1 : 0);
// number of blocks (1 for now)
auto nBlks = createIntConstant(i8Type, 1);
auto tensorDesc = adaptor.getTensorDesc();
auto idx0 = createIntConstant(i32Type, 0);
auto base =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx0);
auto [typeStr, newType] = encodeGenISAVectorType(rewriter, vecType, false);
SmallVector<Value> args;
if (rank == 2) {
auto blockWidth = tileType.getShape()[1];
auto blockHeight = tileType.getShape()[0];
auto blockW = createIntConstant(i32Type, blockWidth);
auto blockH = createIntConstant(i32Type, blockHeight);
auto transform = createIntConstant(i1Type, vnni ? 1 : 0);
// static memref for now
auto createDescOp =
op.getTensorDesc().template getDefiningOp<CreateNdDescOp>();
auto memType = cast<MemRefType>(createDescOp.getSource().getType());
unsigned bitWidth = memType.getElementType().getIntOrFloatBitWidth();
auto surfaceWidth = memType.getShape()[1] * (bitWidth / 8) - 1;
auto surfaceHeight = memType.getShape()[0] - 1;
// pitch = width for now
auto surfacePitch = surfaceWidth;
auto surfaceW = createIntConstant(i32Type, surfaceWidth);
auto surfaceH = createIntConstant(i32Type, surfaceHeight);
auto surfaceP = createIntConstant(i32Type, surfacePitch);
auto v4i32 = VectorType::get(4, i32Type);
tensorDesc = rewriter.create<spirv::BitcastOp>(loc, v4i32, tensorDesc);
auto idx2 = createIntConstant(i32Type, 2);
auto idx3 = createIntConstant(i32Type, 3);
auto offsetX =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx2);
auto offsetY =
rewriter.create<spirv::VectorExtractDynamicOp>(loc, tensorDesc, idx3);
args.assign({base, surfaceW, surfaceH, surfaceP, offsetX, offsetY,
elemSize, blockW, blockH, nBlks, trans, transform});
if constexpr (!isLoad && !isPrefetch) {
args.push_back(adaptor.getValue());
}
}
if constexpr (isLoad)
funcName += typeStr;
else if constexpr (!isPrefetch)
funcName += "isVoid";
if constexpr (isLoad) {
auto funcType =
rewriter.getFunctionType(ValueRange(args).getTypes(), newType);
Operation *opPtr = op;
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
auto funcOp =
rewriter.create<spirv::FunctionCallOp>(loc, newType, funcName, args);
auto castTy = this->getTypeConverter()->convertType(op.getType());
auto cast =
rewriter.create<spirv::BitcastOp>(loc, castTy, funcOp->getResult(0));
rewriter.replaceOp(op, cast);
} else {
auto funcType = rewriter.getFunctionType(ValueRange(args).getTypes(), {});
Operation *opPtr = op;
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
rewriter.create<spirv::FunctionCallOp>(loc, TypeRange(), funcName, args);
rewriter.eraseOp(op);
}
return success();
}
};
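For the 2D block ops, the operand list passed to the GenISA intrinsic is assembled in the order below (a reading aid reconstructed from the args.assign above, not an intrinsic specification):

  // base               i64  flat address extracted from the descriptor
  // surfaceW/H/P       i32  surface width-1 (bytes), height-1 (rows), pitch (= width for now)
  // offsetX, offsetY   i32  block offsets read back from the descriptor
  // elemSize           i8   element size in bits
  // blockW, blockH     i32  2D block shape from the tensor_desc
  // nBlks              i8   number of blocks, fixed to 1
  // trans, transform   i1   transpose and vnni flags
  // value              the stored vector, appended for BlockWrite only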

class DpasToGenISA : public OpConversionPattern<DpasOp> {
public:
using OpConversionPattern<DpasOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(DpasOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
{
OpBuilder::InsertionGuard guard(rewriter);
auto func = op->getParentOfType<spirv::FuncOp>();
rewriter.setInsertionPointAfter(func);
rewriter.create<spirv::ExecutionModeOp>(
op.getLoc(), func, spirv::ExecutionMode::SubgroupSize, 16);
}
auto i32Type = rewriter.getI32Type();
auto createIntConstant = [&](Type type, unsigned value) {
auto attr = rewriter.getIntegerAttr(type, value);
return rewriter.create<spirv::ConstantOp>(loc, type, attr);
};
auto encodePrecision = [&](Type type) -> uint8_t {
if (type == rewriter.getBF16Type())
return 9;
else if (type == rewriter.getF16Type())
return 10;
else if (type == rewriter.getTF32Type())
return 12;
else {
assert(0 && "add more support");
return 0;
}
};
auto lType = op.getLhs().getType().cast<VectorType>();
auto rType = op.getRhs().getType().cast<VectorType>();
auto resultType = op.getResultType().cast<VectorType>();
auto [lhsStr, lhsType] = encodeGenISAVectorType(rewriter, lType, false);
auto [rhsStr, rhsType] = encodeGenISAVectorType(rewriter, rType, false);
auto [newStr, newType] = encodeGenISAVectorType(rewriter, resultType);
auto lhs =
rewriter.create<spirv::BitcastOp>(loc, lhsType, adaptor.getLhs());
auto rhs =
rewriter.create<spirv::BitcastOp>(loc, rhsType, adaptor.getRhs());
uint8_t preca = encodePrecision(lType.getElementType());
uint8_t precb = encodePrecision(rType.getElementType());
auto precA = createIntConstant(i32Type, preca);
auto precB = createIntConstant(i32Type, precb);
// fixed for now
auto rc = createIntConstant(i32Type, 8);
auto sd = createIntConstant(i32Type, 8);
auto dpasW = createIntConstant(rewriter.getI1Type(), 0);
Value acc = op.getAcc() ? adaptor.getAcc()
: rewriter.create<spirv::UndefOp>(loc, newType);
SmallVector<Value, 8> args{acc, lhs, rhs, precA, precB, sd, rc, dpasW};
std::string funcName = "llvm.genx.GenISA.sub.group.dpas.";
funcName += newStr;
funcName += ".";
funcName += newStr;
funcName += ".";
funcName += lhsStr;
funcName += ".";
funcName += rhsStr;
auto funcType =
rewriter.getFunctionType(ValueRange(args).getTypes(), newType);
Operation *opPtr = op;
lookupOrInsertIntrinsic(rewriter, opPtr, funcName, funcType, true);
auto funcOp =
rewriter.create<spirv::FunctionCallOp>(loc, newType, funcName, args);
rewriter.replaceOp(op, funcOp);
return success();
}
};
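Instantiated on the fp16 GEMM in the updated test below, the name assembly produces exactly the symbol the test checks for:

  // result vector<8x1xf32>   -> "v8f32"  (32-bit data)
  // lhs    vector<4x1x2xf16> -> "v8i16"  (16-bit data kept as i16)
  // rhs    vector<8x1x2xf16> -> "v16i16" (16-bit data kept as i16)
  // => llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v16i16
  // encodePrecision feeds the PA/PB operands: bf16 -> 9, f16 -> 10, tf32 -> 12.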

void imex::populateXeGPUToGenISAPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<CreateNdDescToGenISA, DpasToGenISA,
LoadStorePrefetchNdToGenISA<LoadNDOp>,
LoadStorePrefetchNdToGenISA<StoreNDOp>,
LoadStorePrefetchNdToGenISA<PrefetchNDOp>>(
typeConverter, patterns.getContext());
}
55 changes: 24 additions & 31 deletions test/Conversion/XeGPUToSPIRV/gemm_basic.mlir
@@ -1,5 +1,8 @@
// RUN: imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s
// RUN: IMEX_NOT_PREFER_RAWSEND=1 imex-opt -imex-convert-gpu-to-spirv %s | FileCheck %s --check-prefix=LSC
// RUN: imex-opt -imex-convert-gpu-to-spirv='enable-vc-intrinsic=false' %s | FileCheck %s

#sg_map_fp16_a = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [2, 8], wi_data = [1, 2]}>
#sg_map_fp16_b = #xegpu.sg_map<{mma_block_size = [16, 16], wi_layout = [1, 16], wi_data = [1, 1]}>
#sg_map_fp16_c = #xegpu.sg_map<{mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]}>
module @gemm attributes {gpu.container_module} {
memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01>
memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00>
@@ -15,35 +18,25 @@ module @gemm attributes {gpu.container_module} {
gpu.dealloc %memref_0 : memref<16x16xf16>
return %memref_1 : memref<8x16xf32>
}
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64
// LSC: spirv.FunctionCall @llvm_genx_lsc_prefetch2d_stateless_i1_i64
// LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v64i32_i1_i64
// LSC: spirv.FunctionCall @llvm_genx_lsc_load2d_stateless_v128i32_i1_i64
// LSC: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32
// LSC: spirv.FunctionCall @llvm_genx_lsc_store2d_stateless_i1_i64_v128f32
// CHECK: %[[BASE:.*]] = spirv.ConvertPtrToU %arg0 : !spirv.ptr<!spirv.array<128 x f16>, CrossWorkgroup> to i64
// CHECK: %[[BASE1:.*]] = spirv.VectorInsertDynamic %[[BASE]]
// CHECK: %[[BASE2:.*]] = spirv.Bitcast %[[BASE1]]
// CHECK: spirv.VectorInsertDynamic
// CHECK: spirv.VectorInsertDynamic
// CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32
// CHECK: spirv.FunctionCall @llvm_genx_raw_send2_noresult_i1_v8i32
// CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v64i32_i1_v8i32
// CHECK: spirv.FunctionCall @llvm_genx_raw_send2_v128i32_i1_v8i32
// CHECK: spirv.FunctionCall @llvm_genx_dpas_nosrc0_v128f32_v128i32_v64i32
// CHECK: spirv.FunctionCall @llvm_genx_raw_sends2_noresult_i1_v8i32_v128f32
%0 = xegpu.create_nd_tdesc %arg0[0, 0] {mode = vc} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %arg1[0, 0] {mode = vc} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] {mode = vc} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<16x16xf16>

%3 = xegpu.load_nd %0 {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
%4 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
%5 = xegpu.dpas %3, %4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
xegpu.store_nd %5, %2 {mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// CHECK: %[[a:.*]] = spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockRead.v8i16
// CHECK: %[[a0:.*]] = spirv.Bitcast %[[a]]
// CHECK: %[[b:.*]] = spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockRead.v16i16
// CHECK: %[[b0:.*]] = spirv.Bitcast %[[b]]
// CHECK: %[[A:.*]] = spirv.Bitcast %[[a0]]
// CHECK: %[[B:.*]] = spirv.Bitcast %[[b0]]
// CHECK: %[[C:.*]] = spirv.FunctionCall @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v16i16
// CHECK-SAME: %[[A]], %[[B]]
// CHECK: spirv.FunctionCall @llvm.genx.GenISA.LSC2DBlockWrite.isVoid
// CHECK-SAME: %[[C]]
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a>
%1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c>
%3 = xegpu.load_nd %0 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> -> vector<4x1x2xf16>
%4 = xegpu.load_nd %1 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> -> vector<8x1x2xf16>
%5 = xegpu.dpas %3, %4 : vector<4x1x2xf16>, vector<8x1x2xf16> -> vector<8x1xf32>
xegpu.store_nd %5, %2 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c>
gpu.return
}
}