Skip to content

Commit

Permalink
XXX: Apply WIP patch. This should fix CI.
Browse files Browse the repository at this point in the history
  • Loading branch information
ingomueller-net committed Jul 9, 2024
1 parent 12f3940 commit 85cdf54
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/buildAndTestStructured.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ jobs:
path: sandbox
submodules: recursive

- name: Patch LLVM with WIP patch for testing
run: |
cd ${STRUCTURED_MAIN_SRC_DIR}/third_party/llvm-project
patch -p1 < ../../cast.patch
git diff
- name: Install Ninja
uses: llvm/actions/install-ninja@6a57890d0e3f9f35dfc72e7e48bc5e1e527cdd6c # Jan 17

Expand Down
124 changes: 124 additions & 0 deletions cast.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
commit 3ea3d4bed57c4f6a35bed044bca8c1277fa2bb17
Author: lipracer <[email protected]>
Date: Fri Mar 29 23:25:07 2024 +0800

[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR interface>

Fixes https://github.com/llvm/llvm-project/issues/86647

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c2744574..5ba39b80b513 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

#include <optional>
@@ -2110,6 +2111,34 @@ struct DenseMapInfo<T,
}
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};
+
+template <typename To, typename From>
+struct CastInfo<
+ To, From,
+ std::enable_if_t<
+ std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
+ To> &&
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+ typename std::remove_const_t<
+ From>::InterfaceTraits>,
+ std::remove_const_t<From>>,
+ void>> : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static bool isPossible(From &val) {
+ if constexpr (std::is_same_v<To, From>)
+ return true;
+ else
+ return mlir::OpInterface<To, typename To::InterfaceTraits>::
+ InterfaceBase::classof(
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+
+ static To doCast(From &val) {
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+};
+
} // namespace llvm

#endif
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34..c6409e9ec30e 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
"public:\n"
- " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
+ " using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
+ " using InterfaceTraits = detail::{2};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);

diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 5ab4d9a10623..7012da669248 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -16,6 +16,9 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace test;
@@ -83,3 +86,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+ DialectRegistry registry;
+ MLIRContext ctx;
+
+ const char *ir = R"MLIR(
+ func.func @map(%arg : tensor<1xi64>) {
+ %0 = arith.constant dense<[10]> : tensor<1xi64>
+ %1 = arith.addi %arg, %0 : tensor<1xi64>
+ return
+ }
+ )MLIR";
+
+ registry.insert<func::FuncDialect, arith::ArithDialect>();
+ ctx.appendDialectRegistry(registry);
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
+
+ bool constantOp =
+ llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+ .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
+ return std::is_same_v<decltype(op), arith::ConstantOp>;
+ });
+
+ EXPECT_TRUE(constantOp);
+
+ EXPECT_FALSE(llvm::isa<VectorUnrollOpInterface>(interface));
+ EXPECT_FALSE(llvm::dyn_cast<VectorUnrollOpInterface>(interface));
+
+ EXPECT_TRUE(llvm::isa<InferTypeOpInterface>(interface));
+ EXPECT_TRUE(llvm::dyn_cast<InferTypeOpInterface>(interface));
+
+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface));
+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
+}

0 comments on commit 85cdf54

Please sign in to comment.