-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
XXX: Apply WIP patch. This should fix CI.
- Loading branch information
1 parent
12f3940
commit 85cdf54
Showing
2 changed files
with
130 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
commit 3ea3d4bed57c4f6a35bed044bca8c1277fa2bb17 | ||
Author: lipracer <[email protected]> | ||
Date: Fri Mar 29 23:25:07 2024 +0800 | ||
|
||
[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR interface> | ||
|
||
Fixes https://github.com/llvm/llvm-project/issues/86647 | ||
|
||
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h | ||
index bd68c2744574..5ba39b80b513 100644 | ||
--- a/mlir/include/mlir/IR/OpDefinition.h | ||
+++ b/mlir/include/mlir/IR/OpDefinition.h | ||
@@ -22,6 +22,7 @@ | ||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/IR/ODSSupport.h" | ||
#include "mlir/IR/Operation.h" | ||
+#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/PointerLikeTypeTraits.h" | ||
|
||
#include <optional> | ||
@@ -2110,6 +2111,34 @@ struct DenseMapInfo<T, | ||
} | ||
static bool isEqual(T lhs, T rhs) { return lhs == rhs; } | ||
}; | ||
+ | ||
+template <typename To, typename From> | ||
+struct CastInfo< | ||
+ To, From, | ||
+ std::enable_if_t< | ||
+ std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>, | ||
+ To> && | ||
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>, | ||
+ typename std::remove_const_t< | ||
+ From>::InterfaceTraits>, | ||
+ std::remove_const_t<From>>, | ||
+ void>> : NullableValueCastFailed<To>, | ||
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { | ||
+ | ||
+ static bool isPossible(From &val) { | ||
+ if constexpr (std::is_same_v<To, From>) | ||
+ return true; | ||
+ else | ||
+ return mlir::OpInterface<To, typename To::InterfaceTraits>:: | ||
+ InterfaceBase::classof( | ||
+ const_cast<std::remove_const_t<From> &>(val).getOperation()); | ||
+ } | ||
+ | ||
+ static To doCast(From &val) { | ||
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation()); | ||
+ } | ||
+}; | ||
+ | ||
} // namespace llvm | ||
|
||
#endif | ||
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | ||
index 2a7406f42f34..c6409e9ec30e 100644 | ||
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | ||
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | ||
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { | ||
// Emit the main interface class declaration. | ||
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" | ||
"public:\n" | ||
- " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", | ||
+ " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" | ||
+ " using InterfaceTraits = detail::{2};\n", | ||
interfaceName, interfaceName, interfaceTraitsName, | ||
interfaceBaseType); | ||
|
||
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp | ||
index 5ab4d9a10623..7012da669248 100644 | ||
--- a/mlir/unittests/IR/InterfaceTest.cpp | ||
+++ b/mlir/unittests/IR/InterfaceTest.cpp | ||
@@ -16,6 +16,9 @@ | ||
#include "../../test/lib/Dialect/Test/TestAttributes.h" | ||
#include "../../test/lib/Dialect/Test/TestDialect.h" | ||
#include "../../test/lib/Dialect/Test/TestTypes.h" | ||
+#include "mlir/Dialect/Arith/IR/Arith.h" | ||
+#include "mlir/Parser/Parser.h" | ||
+#include "llvm/ADT/TypeSwitch.h" | ||
|
||
using namespace mlir; | ||
using namespace test; | ||
@@ -83,3 +86,40 @@ TEST(InterfaceTest, TestImplicitConversion) { | ||
typeA = typeB; | ||
EXPECT_EQ(typeA, typeB); | ||
} | ||
+ | ||
+TEST(OperationInterfaceTest, CastOpToInterface) { | ||
+ DialectRegistry registry; | ||
+ MLIRContext ctx; | ||
+ | ||
+ const char *ir = R"MLIR( | ||
+ func.func @map(%arg : tensor<1xi64>) { | ||
+ %0 = arith.constant dense<[10]> : tensor<1xi64> | ||
+ %1 = arith.addi %arg, %0 : tensor<1xi64> | ||
+ return | ||
+ } | ||
+ )MLIR"; | ||
+ | ||
+ registry.insert<func::FuncDialect, arith::ArithDialect>(); | ||
+ ctx.appendDialectRegistry(registry); | ||
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); | ||
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front(); | ||
+ | ||
+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op); | ||
+ | ||
+ bool constantOp = | ||
+ llvm::TypeSwitch<OpAsmOpInterface, bool>(interface) | ||
+ .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) { | ||
+ return std::is_same_v<decltype(op), arith::ConstantOp>; | ||
+ }); | ||
+ | ||
+ EXPECT_TRUE(constantOp); | ||
+ | ||
+ EXPECT_FALSE(llvm::isa<VectorUnrollOpInterface>(interface)); | ||
+ EXPECT_FALSE(llvm::dyn_cast<VectorUnrollOpInterface>(interface)); | ||
+ | ||
+ EXPECT_TRUE(llvm::isa<InferTypeOpInterface>(interface)); | ||
+ EXPECT_TRUE(llvm::dyn_cast<InferTypeOpInterface>(interface)); | ||
+ | ||
+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface)); | ||
+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface)); | ||
+} |