-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update llvm to ba7cb620ac002a94af0e1656ba591308f7073ab9
Fix post LLVM issue: operand_segment_sizes ->operandSegmentSizes Update SPIRV patch
- Loading branch information
Showing
7 changed files
with
81 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
ebb0a210995dcf69d9696f8e14629e1378e63a21 | ||
ba7cb620ac002a94af0e1656ba591308f7073ab9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
From 66182e3a1ce8bdc2dac7de82eabb95ed759d5eb6 Mon Sep 17 00:00:00 2001 | ||
From 174b0ae71b997aa3252382b6024aab7b6a5d110c Mon Sep 17 00:00:00 2001 | ||
From: Md Abdullah Shahneous Bari <[email protected]> | ||
Date: Mon, 5 Dec 2022 19:09:10 -0800 | ||
Date: Thu, 24 Aug 2023 09:05:47 -0700 | ||
Subject: [PATCH 1/2] Add support for VectorAnyINTEL capability | ||
|
||
Allow vector of any lengths between [2-2^63-1]. | ||
|
@@ -25,20 +25,20 @@ requirement initially, then do the check for capability inferred extension. | |
- Add support for optionally skipping capability and extension requirement | ||
--- | ||
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 11 +- | ||
mlir/include/mlir/IR/OpBase.td | 86 +++++++++++ | ||
mlir/include/mlir/IR/CommonTypeConstraints.td | 86 ++++++++++++ | ||
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +- | ||
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 32 +++-- | ||
.../SPIRV/Transforms/SPIRVConversion.cpp | 134 ++++++++++++++---- | ||
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 24 +++- | ||
.../SPIRV/Transforms/SPIRVConversion.cpp | 132 +++++++++++++++--- | ||
.../arith-to-spirv-unsupported.mlir | 4 +- | ||
.../ArithToSPIRV/arith-to-spirv.mlir | 32 +++++ | ||
.../ArithToSPIRV/arith-to-spirv.mlir | 33 +++++ | ||
.../FuncToSPIRV/types-to-spirv.mlir | 17 ++- | ||
mlir/test/Dialect/SPIRV/IR/bit-ops.mlir | 6 +- | ||
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 2 +- | ||
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 2 +- | ||
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 42 +++--- | ||
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 36 ++--- | ||
mlir/test/Target/SPIRV/arithmetic-ops.mlir | 6 +- | ||
mlir/test/Target/SPIRV/ocl-ops.mlir | 6 + | ||
14 files changed, 314 insertions(+), 73 deletions(-) | ||
14 files changed, 311 insertions(+), 61 deletions(-) | ||
|
||
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | ||
index 1e61aa747967..6f0f728f811e 100644 | ||
|
@@ -71,11 +71,11 @@ index 1e61aa747967..6f0f728f811e 100644 | |
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>; | ||
|
||
class SPIRV_MatrixOrCoopMatrixOf<Type type> : | ||
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td | ||
index 940588b7c0f9..917a79ea65c0 100644 | ||
--- a/mlir/include/mlir/IR/OpBase.td | ||
+++ b/mlir/include/mlir/IR/OpBase.td | ||
@@ -838,6 +838,92 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths, | ||
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td | ||
index 4fc14e30b8a1..74739ecccd0d 100644 | ||
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td | ||
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td | ||
@@ -546,6 +546,92 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks, | ||
ScalableVectorOfLength<allowedLengths>.summary, | ||
"::mlir::VectorType">; | ||
|
||
|
@@ -188,24 +188,16 @@ index 124d4ed6e8e6..9188f8b699b4 100644 | |
return Type(); | ||
} | ||
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | ||
index 01c694de08a9..741d8069471d 100644 | ||
index 39d6603a46f9..741d8069471d 100644 | ||
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | ||
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | ||
@@ -101,17 +101,11 @@ bool CompositeType::classof(Type type) { | ||
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) { | ||
} | ||
|
||
bool CompositeType::isValid(VectorType type) { | ||
- switch (type.getNumElements()) { | ||
- case 2: | ||
- case 3: | ||
- case 4: | ||
- case 8: | ||
- case 16: | ||
- break; | ||
- default: | ||
- return false; | ||
- } | ||
- return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()); | ||
- return type.getRank() == 1 && | ||
- llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && | ||
- llvm::isa<ScalarType>(type.getElementType()); | ||
+ // Number of elements should be between [2 - 2^63 -1], | ||
+ // since getNumElements() returns an unsigned, the upper limit check is | ||
+ // unnecessary | ||
|
@@ -214,7 +206,7 @@ index 01c694de08a9..741d8069471d 100644 | |
} | ||
|
||
Type CompositeType::getElementType(unsigned index) const { | ||
@@ -179,7 +173,21 @@ void CompositeType::getCapabilities( | ||
@@ -171,7 +173,21 @@ void CompositeType::getCapabilities( | ||
.Case<VectorType>([&](VectorType type) { | ||
auto vecSize = getNumElements(); | ||
if (vecSize == 8 || vecSize == 16) { | ||
|
@@ -238,10 +230,10 @@ index 01c694de08a9..741d8069471d 100644 | |
capabilities.push_back(ref); | ||
} | ||
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | ||
index c8d7aef89642..f7c5a72841af 100644 | ||
index c75d217663a9..f7a8a2a3d281 100644 | ||
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | ||
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | ||
@@ -42,9 +42,13 @@ using namespace mlir; | ||
@@ -43,9 +43,13 @@ using namespace mlir; | ||
template <typename LabelT> | ||
static LogicalResult checkExtensionRequirements( | ||
LabelT label, const spirv::TargetEnv &targetEnv, | ||
|
@@ -257,7 +249,7 @@ index c8d7aef89642..f7c5a72841af 100644 | |
continue; | ||
|
||
LLVM_DEBUG({ | ||
@@ -70,9 +74,13 @@ static LogicalResult checkExtensionRequirements( | ||
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements( | ||
template <typename LabelT> | ||
static LogicalResult checkCapabilityRequirements( | ||
LabelT label, const spirv::TargetEnv &targetEnv, | ||
|
@@ -273,12 +265,10 @@ index c8d7aef89642..f7c5a72841af 100644 | |
continue; | ||
|
||
LLVM_DEBUG({ | ||
@@ -89,8 +97,57 @@ static LogicalResult checkCapabilityRequirements( | ||
@@ -90,6 +98,55 @@ static LogicalResult checkCapabilityRequirements( | ||
return success(); | ||
} | ||
|
||
-/// Returns true if the given `storageClass` needs explicit layout when used in | ||
-/// Shader environments. | ||
+/// Check capabilities and extensions requirements, | ||
+/// this function also checks for capability infered extension requirements, | ||
+/// the check is based on capabilities that are passed to the targetEnv. | ||
|
@@ -328,16 +318,15 @@ index c8d7aef89642..f7c5a72841af 100644 | |
+ return success(); | ||
+} | ||
+ | ||
+/// Returns true if the given `storageClass` needs explicit layout when used | ||
+/// in Shader environments. | ||
/// Returns true if the given `storageClass` needs explicit layout when used in | ||
/// Shader environments. | ||
static bool needsExplicitLayout(spirv::StorageClass storageClass) { | ||
switch (storageClass) { | ||
case spirv::StorageClass::PhysicalStorageBuffer: | ||
@@ -246,12 +303,15 @@ convertScalarType(const spirv::TargetEnv &targetEnv, | ||
@@ -247,12 +304,17 @@ convertScalarType(const spirv::TargetEnv &targetEnv, | ||
return nullptr; | ||
} | ||
|
||
- if (auto floatType = dyn_cast<FloatType>(type)) { | ||
+ //if (auto floatType = dyn_cast<FloatType>(type)) { | ||
+ // Convert to 32-bit float and remove floatType related capability | ||
+ // restriction | ||
+ if (auto floatType = type.dyn_cast<FloatType>()) { | ||
|
@@ -346,12 +335,13 @@ index c8d7aef89642..f7c5a72841af 100644 | |
} | ||
|
||
- auto intType = cast<IntegerType>(type); | ||
+ //auto intType = cast<IntegerType>(type); | ||
+ // Convert to 32-bit int and remove intType related capability restriction | ||
+ auto intType = type.cast<IntegerType>(); | ||
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); | ||
return IntegerType::get(targetEnv.getContext(), /*width=*/32, | ||
intType.getSignedness()); | ||
@@ -322,16 +382,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv, | ||
@@ -342,16 +404,40 @@ convertVectorType(const spirv::TargetEnv &targetEnv, | ||
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass); | ||
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass); | ||
|
||
|
@@ -399,7 +389,7 @@ index c8d7aef89642..f7c5a72841af 100644 | |
} | ||
|
||
static Type | ||
@@ -1118,16 +1202,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { | ||
@@ -1150,16 +1236,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { | ||
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; | ||
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; | ||
for (Type valueType : valueTypes) { | ||
|
@@ -444,10 +434,18 @@ index 0d92a8e676d8..d61ace8d6876 100644 | |
} | ||
|
||
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | ||
index d70df982c366..a9a8237e53d4 100644 | ||
index aa2cd649ecd7..b951d7490d64 100644 | ||
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | ||
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | ||
@@ -1298,3 +1298,35 @@ func.func @float_scalar(%arg0: f16) { | ||
@@ -29,6 +29,7 @@ func.func @int32_scalar(%lhs: i32, %rhs: i32) { | ||
|
||
// CHECK-LABEL: @int32_scalar_srem | ||
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) | ||
+ %1 = arith.subi %arg0, %arg0: vector<5xi32> | ||
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) { | ||
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32 | ||
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32 | ||
@@ -1362,3 +1363,35 @@ func.func @float_scalar(%arg0: f16) { | ||
} | ||
|
||
} // end module | ||
|
@@ -484,10 +482,10 @@ index d70df982c366..a9a8237e53d4 100644 | |
+ | ||
+} // end module | ||
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | ||
index ef1ee00b709f..b2abd7504b6d 100644 | ||
index 82d750755ffe..6f364c5b0875 100644 | ||
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | ||
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | ||
@@ -348,8 +348,21 @@ module attributes { | ||
@@ -351,8 +351,21 @@ module attributes { | ||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>> | ||
} { | ||
|
||
|
@@ -569,33 +567,18 @@ index 7dc0bd99f54b..5dd9901828cd 100644 | |
return | ||
} | ||
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | ||
index 29a4a4613615..24fe2f945841 100644 | ||
index 29a4a4613615..869de34c83b1 100644 | ||
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | ||
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | ||
@@ -18,17 +18,17 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () { | ||
|
||
// ----- | ||
|
||
-func.func @exp(%arg0 : i32) -> () { | ||
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} | ||
- %2 = spirv.CL.exp %arg0 : i32 | ||
+func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () { | ||
+ // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> | ||
+ %2 = spirv.CL.exp %arg0 : vector<5xf32> | ||
return | ||
} | ||
|
||
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { | ||
// ----- | ||
|
||
-func.func @exp(%arg0 : vector<5xf32>) -> () { | ||
func.func @exp(%arg0 : vector<5xf32>) -> () { | ||
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} | ||
- %2 = spirv.CL.exp %arg0 : vector<5xf32> | ||
+func.func @exp(%arg0 : i32) -> () { | ||
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} | ||
+ %2 = spirv.CL.exp %arg0 : i32 | ||
+ // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> | ||
%2 = spirv.CL.exp %arg0 : vector<5xf32> | ||
return | ||
} | ||
|
||
@@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () { | ||
return | ||
} | ||
|
@@ -641,20 +624,29 @@ index 29a4a4613615..24fe2f945841 100644 | |
func.func @sabsi64(%arg0 : i64) -> () { | ||
// CHECK: spirv.CL.s_abs {{%.*}} : i64 | ||
%2 = spirv.CL.s_abs %arg0 : i64 | ||
@@ -142,13 +150,7 @@ func.func @sabs(%arg0 : f32) -> () { | ||
@@ -137,21 +145,13 @@ func.func @sabsi8(%arg0 : i8) -> () { | ||
// ----- | ||
|
||
func.func @sabs(%arg0 : f32) -> () { | ||
- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}} | ||
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} | ||
%2 = spirv.CL.s_abs %arg0 : f32 | ||
return | ||
} | ||
|
||
-// ----- | ||
// ----- | ||
|
||
-func.func @sabs(%arg0 : vector<5xi32>) -> () { | ||
- // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} | ||
- %2 = spirv.CL.s_abs %arg0 : vector<5xi32> | ||
- return | ||
-} | ||
|
||
// ----- | ||
|
||
- | ||
-// ----- | ||
- | ||
func.func @sabs(%arg0 : i32, %arg1 : i32) -> () { | ||
// expected-error @+1 {{expected ':'}} | ||
%2 = spirv.CL.s_abs %arg0, %arg1 : i32 | ||
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir | ||
index b1ea13c6854f..90144afc6f3a 100644 | ||
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir | ||
|
@@ -690,4 +682,4 @@ index 9a2e4cf62e37..31a7f616d648 100644 | |
// CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 | ||
%13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32 | ||
-- | ||
2.34.1 | ||
2.42.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.