Skip to content

Commit

Permalink
Update llvm to ba7cb620ac002a94af0e1656ba591308f7073ab9
Browse files Browse the repository at this point in the history
Fix post LLVM issue:
operand_segment_sizes ->operandSegmentSizes
Update SPIRV patch
  • Loading branch information
silee2 committed Aug 29, 2023
1 parent cd757a7 commit 8536103
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 89 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ebb0a210995dcf69d9696f8e14629e1378e63a21
ba7cb620ac002a94af0e1656ba591308f7073ab9
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
From 66182e3a1ce8bdc2dac7de82eabb95ed759d5eb6 Mon Sep 17 00:00:00 2001
From 174b0ae71b997aa3252382b6024aab7b6a5d110c Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <[email protected]>
Date: Mon, 5 Dec 2022 19:09:10 -0800
Date: Thu, 24 Aug 2023 09:05:47 -0700
Subject: [PATCH 1/2] Add support for VectorAnyINTEL capability

Allow vector of any lengths between [2-2^63-1].
Expand All @@ -25,20 +25,20 @@ requirement initially, then do the check for capability inferred extension.
- Add support for optionally skipping capability and extension requirement
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +-
mlir/include/mlir/IR/OpBase.td | 86 +++++++++++
mlir/include/mlir/IR/CommonTypeConstraints.td | 86 ++++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 32 +++--
.../SPIRV/Transforms/SPIRVConversion.cpp | 134 ++++++++++++++----
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 24 +++-
.../SPIRV/Transforms/SPIRVConversion.cpp | 132 +++++++++++++++---
.../arith-to-spirv-unsupported.mlir | 4 +-
.../ArithToSPIRV/arith-to-spirv.mlir | 32 +++++
.../ArithToSPIRV/arith-to-spirv.mlir | 33 +++++
.../FuncToSPIRV/types-to-spirv.mlir | 17 ++-
mlir/test/Dialect/SPIRV/IR/bit-ops.mlir | 6 +-
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 42 +++---
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 36 ++---
mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +-
mlir/test/Target/SPIRV/ocl-ops.mlir | 6 +
14 files changed, 314 insertions(+), 73 deletions(-)
14 files changed, 311 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1e61aa747967..6f0f728f811e 100644
Expand Down Expand Up @@ -71,11 +71,11 @@ index 1e61aa747967..6f0f728f811e 100644
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;

class SPIRV_MatrixOrCoopMatrixOf<Type type> :
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 940588b7c0f9..917a79ea65c0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -838,6 +838,92 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a1..74739ecccd0d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -546,6 +546,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

Expand Down Expand Up @@ -188,24 +188,16 @@ index 124d4ed6e8e6..9188f8b699b4 100644
return Type();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 01c694de08a9..741d8069471d 100644
index 39d6603a46f9..741d8069471d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,17 +101,11 @@ bool CompositeType::classof(Type type) {
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
}

bool CompositeType::isValid(VectorType type) {
- switch (type.getNumElements()) {
- case 2:
- case 3:
- case 4:
- case 8:
- case 16:
- break;
- default:
- return false;
- }
- return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType());
- return type.getRank() == 1 &&
- llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
- llvm::isa<ScalarType>(type.getElementType());
+ // Number of elements should be between [2 - 2^63 -1],
+ // since getNumElements() returns an unsigned, the upper limit check is
+ // unnecessary
Expand All @@ -214,7 +206,7 @@ index 01c694de08a9..741d8069471d 100644
}

Type CompositeType::getElementType(unsigned index) const {
@@ -179,7 +173,21 @@ void CompositeType::getCapabilities(
@@ -171,7 +173,21 @@ void CompositeType::getCapabilities(
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
if (vecSize == 8 || vecSize == 16) {
Expand All @@ -238,10 +230,10 @@ index 01c694de08a9..741d8069471d 100644
capabilities.push_back(ref);
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c8d7aef89642..f7c5a72841af 100644
index c75d217663a9..f7a8a2a3d281 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -42,9 +42,13 @@ using namespace mlir;
@@ -43,9 +43,13 @@ using namespace mlir;
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
Expand All @@ -257,7 +249,7 @@ index c8d7aef89642..f7c5a72841af 100644
continue;

LLVM_DEBUG({
@@ -70,9 +74,13 @@ static LogicalResult checkExtensionRequirements(
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
Expand All @@ -273,12 +265,10 @@ index c8d7aef89642..f7c5a72841af 100644
continue;

LLVM_DEBUG({
@@ -89,8 +97,57 @@ static LogicalResult checkCapabilityRequirements(
@@ -90,6 +98,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}

-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Check capabilities and extensions requirements,
+/// this function also checks for capability infered extension requirements,
+/// the check is based on capabilities that are passed to the targetEnv.
Expand Down Expand Up @@ -328,16 +318,15 @@ index c8d7aef89642..f7c5a72841af 100644
+ return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
@@ -246,12 +303,15 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
@@ -247,12 +304,17 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
return nullptr;
}

- if (auto floatType = dyn_cast<FloatType>(type)) {
+ //if (auto floatType = dyn_cast<FloatType>(type)) {
+ // Convert to 32-bit float and remove floatType related capability
+ // restriction
+ if (auto floatType = type.dyn_cast<FloatType>()) {
Expand All @@ -346,12 +335,13 @@ index c8d7aef89642..f7c5a72841af 100644
}

- auto intType = cast<IntegerType>(type);
+ //auto intType = cast<IntegerType>(type);
+ // Convert to 32-bit int and remove intType related capability restriction
+ auto intType = type.cast<IntegerType>();
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
intType.getSignedness());
@@ -322,16 +382,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
@@ -342,16 +404,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

Expand Down Expand Up @@ -399,7 +389,7 @@ index c8d7aef89642..f7c5a72841af 100644
}

static Type
@@ -1118,16 +1202,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
@@ -1150,16 +1236,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
Expand Down Expand Up @@ -444,10 +434,18 @@ index 0d92a8e676d8..d61ace8d6876 100644
}

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index d70df982c366..a9a8237e53d4 100644
index aa2cd649ecd7..b951d7490d64 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1298,3 +1298,35 @@ func.func @float_scalar(%arg0: f16) {
@@ -29,6 +29,7 @@ func.func @int32_scalar(%lhs: i32, %rhs: i32) {

// CHECK-LABEL: @int32_scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+ %1 = arith.subi %arg0, %arg0: vector<5xi32>
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
@@ -1362,3 +1363,35 @@ func.func @float_scalar(%arg0: f16) {
}

} // end module
Expand Down Expand Up @@ -484,10 +482,10 @@ index d70df982c366..a9a8237e53d4 100644
+
+} // end module
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index ef1ee00b709f..b2abd7504b6d 100644
index 82d750755ffe..6f364c5b0875 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -348,8 +348,21 @@ module attributes {
@@ -351,8 +351,21 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {

Expand Down Expand Up @@ -569,33 +567,18 @@ index 7dc0bd99f54b..5dd9901828cd 100644
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 29a4a4613615..24fe2f945841 100644
index 29a4a4613615..869de34c83b1 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -18,17 +18,17 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {

// -----

-func.func @exp(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
- %2 = spirv.CL.exp %arg0 : i32
+func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () {
+ // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32>
+ %2 = spirv.CL.exp %arg0 : vector<5xf32>
return
}

@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
// -----

-func.func @exp(%arg0 : vector<5xf32>) -> () {
func.func @exp(%arg0 : vector<5xf32>) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
- %2 = spirv.CL.exp %arg0 : vector<5xf32>
+func.func @exp(%arg0 : i32) -> () {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %2 = spirv.CL.exp %arg0 : i32
+ // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32>
%2 = spirv.CL.exp %arg0 : vector<5xf32>
return
}

@@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () {
return
}
Expand Down Expand Up @@ -641,20 +624,29 @@ index 29a4a4613615..24fe2f945841 100644
func.func @sabsi64(%arg0 : i64) -> () {
// CHECK: spirv.CL.s_abs {{%.*}} : i64
%2 = spirv.CL.s_abs %arg0 : i64
@@ -142,13 +150,7 @@ func.func @sabs(%arg0 : f32) -> () {
@@ -137,21 +145,13 @@ func.func @sabsi8(%arg0 : i8) -> () {
// -----

func.func @sabs(%arg0 : f32) -> () {
- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%2 = spirv.CL.s_abs %arg0 : f32
return
}

-// -----
// -----

-func.func @sabs(%arg0 : vector<5xi32>) -> () {
- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
- %2 = spirv.CL.s_abs %arg0 : vector<5xi32>
- return
-}

// -----

-
-// -----
-
func.func @sabs(%arg0 : i32, %arg1 : i32) -> () {
// expected-error @+1 {{expected ':'}}
%2 = spirv.CL.s_abs %arg0, %arg1 : i32
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854f..90144afc6f3a 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
Expand Down Expand Up @@ -690,4 +682,4 @@ index 9a2e4cf62e37..31a7f616d648 100644
// CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32
--
2.34.1
2.42.0
6 changes: 3 additions & 3 deletions test/Conversion/GPUToGPUX/gpux-alloc-dealloc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ func.func @main() attributes {llvm.emit_c_interface} {
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
// CHECK: %[[STREAM:.*]] = "gpux.create_stream"() : () -> !gpux.StreamType
// CHECK: %[[ALLOC_0:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: %[[ALLOC_0:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref = gpu.alloc () : memref<8xf32>
// CHECK: %[[ALLOC_1:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: %[[ALLOC_1:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_1 = gpu.alloc () : memref<8xf32>
// CHECK: %[[ALLOC_2:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: %[[ALLOC_2:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_2 = gpu.alloc () : memref<8xf32>
// CHECK: "gpux.dealloc"(%[[STREAM:.*]], %[[ALLOC_0:.*]]) : (!gpux.StreamType, memref<8xf32>) -> ()
gpu.dealloc %memref : memref<8xf32>
Expand Down
8 changes: 4 additions & 4 deletions test/Conversion/GPUToGPUX/gpux-launch-func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ func.func @main() attributes {llvm.emit_c_interface} {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[STREAM:.*]] = "gpux.create_stream"() : () -> !gpux.StreamType
// CHECK: %[[ALLOC_0:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: %[[ALLOC_0:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref = gpu.alloc () : memref<8xf32>
// CHECK: %[[ALLOC_1:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: %[[ALLOC_1:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_1 = gpu.alloc () : memref<8xf32>
// CHECK: %[[ALLOC_2:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: %[[ALLOC_2:.*]] = "gpux.alloc"(%[[STREAM:.*]]) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_2 = gpu.alloc () : memref<8xf32>
// CHECK: "gpux.launch_func"(%[[STREAM:.*]], %[[C8:.*]], %[[C1:.*]], %[[C1:.*]], %[[C1:.*]], %[[C1:.*]], %[[C1:.*]], %[[ALLOC_0:.*]], %[[ALLOC_1:.*]], %[[ALLOC_2:.*]]) {kernel = @Kernels::@kernel_1, operand_segment_sizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 1, 0, 3>} : (!gpux.StreamType, index, index, index, index, index, index, memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
// CHECK: "gpux.launch_func"(%[[STREAM:.*]], %[[C8:.*]], %[[C1:.*]], %[[C1:.*]], %[[C1:.*]], %[[C1:.*]], %[[C1:.*]], %[[ALLOC_0:.*]], %[[ALLOC_1:.*]], %[[ALLOC_2:.*]]) {kernel = @Kernels::@kernel_1, operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 1, 0, 3>} : (!gpux.StreamType, index, index, index, index, index, index, memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
gpu.launch_func @Kernels::@kernel_1 blocks in (%c8, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8xf32>, %memref_1 : memref<8xf32>, %memref_2 : memref<8xf32>)
gpu.dealloc %memref : memref<8xf32>
gpu.dealloc %memref_1 : memref<8xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module attributes {gpu.container_module}{
// CHECK: %[[STREAM:.*]] = llvm.call @gpuCreateStream(%[[DEVICE:.*]], %[[CONTEXT:.*]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.ptr<i8>
%0 = "gpux.create_stream"() : () -> !gpux.StreamType
// CHECK: llvm.call @gpuMemAlloc(%[[stream:.*]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i32) -> !llvm.ptr<i8>
%memref = "gpux.alloc"(%0) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref = "gpux.alloc"(%0) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
// CHECK: llvm.call @gpuMemFree(%[[stream:.*]], %{{.*}}) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
"gpux.dealloc"(%0, %memref) : (!gpux.StreamType, memref<8xf32>) -> ()
"gpux.destroy_stream"(%0) : (!gpux.StreamType) -> ()
Expand Down
8 changes: 4 additions & 4 deletions test/Conversion/GPUXToLLVM/launch-func-to-gpu-runtime.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#s
// CHECK: %[[CONTEXT:.*]] = llvm.mlir.null : !llvm.ptr<i8>
// CHECK: %[[STREAM:.*]] = llvm.call @gpuCreateStream(%[[DEVICE:.*]], %[[CONTEXT:.*]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.ptr<i8>
%0 = "gpux.create_stream"() : () -> !gpux.StreamType
%memref = "gpux.alloc"(%0) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_0 = "gpux.alloc"(%0) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_1 = "gpux.alloc"(%0) {operand_segment_sizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref = "gpux.alloc"(%0) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_0 = "gpux.alloc"(%0) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>
%memref_1 = "gpux.alloc"(%0) {operandSegmentSizes = array<i32: 0, 1, 0, 0>} : (!gpux.StreamType) -> memref<8xf32>

// CHECK: llvm.mlir.addressof @Kernels_spirv_binary : !llvm.ptr<array<552 x i8>>
// CHECK: %[[MODULE:.*]] = llvm.call @gpuModuleLoad(%[[STREAM:.*]], %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
// CHECK: llvm.mlir.addressof @Kernels_kernel_1_kernel_name : !llvm.ptr<array<9 x i8>>
// CHECK: %[[KERNEL:.*]] = llvm.call @gpuKernelGet(%[[STREAM:.*]], %[[MODULE:.*]], %{{.*}}) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.call @gpuLaunchKernel(%[[STREAM:.*]], %[[KERNEL:.*]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i64, i64, i64, i64, i64, i64, i32, !llvm.ptr<struct<(ptr<i8>, i64)>>) -> ()
"gpux.launch_func"(%0, %c8, %c1, %c1, %c1, %c1, %c1, %memref, %memref_0, %memref_1) {kernel = @Kernels::@kernel_1, operand_segment_sizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 1, 0, 3>} : (!gpux.StreamType, index, index, index, index, index, index, memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
"gpux.launch_func"(%0, %c8, %c1, %c1, %c1, %c1, %c1, %memref, %memref_0, %memref_1) {kernel = @Kernels::@kernel_1, operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 1, 0, 3>} : (!gpux.StreamType, index, index, index, index, index, index, memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
"gpux.dealloc"(%0, %memref) : (!gpux.StreamType, memref<8xf32>) -> ()
"gpux.dealloc"(%0, %memref_0) : (!gpux.StreamType, memref<8xf32>) -> ()
"gpux.dealloc"(%0, %memref_1) : (!gpux.StreamType, memref<8xf32>) -> ()
Expand Down
Loading

0 comments on commit 8536103

Please sign in to comment.