Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

[MLIR][OpenMP] Prevent CSE and constant materialization from crossing some OpenMP region boundaries #227

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlir/include/mlir/Interfaces/CSEInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- CSEInterfaces.h ------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_CSEINTERFACES_H_
#define MLIR_INTERFACES_CSEINTERFACES_H_

#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {

/// Define an interface to allow for dialects to control specific aspects of
/// common subexpression elimination behavior for operations they define.
class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
public:
DialectCSEInterface(Dialect *dialect) : Base(dialect) {}

/// Registered hook to check if an operation that is *not* isolated from
/// above, should allow common subexpressions to be extracted out of its
/// regions.
virtual bool subexpressionExtractionAllowed(Operation *op) const {
return true;
}
};

} // namespace mlir

#endif // MLIR_INTERFACES_CSEINTERFACES_H_
13 changes: 12 additions & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/CSEInterfaces.h"
#include "mlir/Interfaces/FoldInterfaces.h"

#include "llvm/ADT/BitVector.h"
Expand Down Expand Up @@ -46,12 +47,21 @@ struct PointerLikeModel
}
};

struct OpenMPDialectCSEInterface : public DialectCSEInterface {
using DialectCSEInterface::DialectCSEInterface;

bool subexpressionExtractionAllowed(Operation *op) const final {
// Avoid extracting common subexpressions across op boundaries
return !isa<TargetOp, TeamsOp, ParallelOp>(op);
}
};

struct OpenMPDialectFoldInterface : public DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;

bool shouldMaterializeInto(Region *region) const final {
// Avoid folding constants across target regions
return isa<TargetOp>(region->getParentOp());
return isa<TargetOp, TeamsOp, ParallelOp>(region->getParentOp());
}
};
} // namespace
Expand All @@ -66,6 +76,7 @@ void OpenMPDialect::initialize() {
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
>();

addInterface<OpenMPDialectCSEInterface>();
addInterface<OpenMPDialectFoldInterface>();
LLVM::LLVMPointerType::attachInterface<
PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
Expand Down
11 changes: 9 additions & 2 deletions mlir/lib/Transforms/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CSEInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -61,7 +62,8 @@ namespace {
class CSEDriver {
public:
CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
: rewriter(rewriter), domInfo(domInfo) {}
: rewriter(rewriter), domInfo(domInfo),
interfaces(rewriter.getContext()) {}

/// Simplify all operations within the given op.
void simplify(Operation *op, bool *changed = nullptr);
Expand Down Expand Up @@ -122,6 +124,8 @@ class CSEDriver {
DominanceInfo *domInfo = nullptr;
MemEffectsCache memEffectsCache;

DialectInterfaceCollection<DialectCSEInterface> interfaces;

// Various statistics.
int64_t numCSE = 0;
int64_t numDCE = 0;
Expand Down Expand Up @@ -289,7 +293,10 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
// If this operation is isolated above, we can't process nested regions
// with the given 'knownValues' map. This would cause the insertion of
// implicit captures in explicit capture only regions.
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
const DialectCSEInterface *cseInterface = interfaces.getInterfaceFor(&op);
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
LLVM_UNLIKELY(cseInterface &&
!cseInterface->subexpressionExtractionAllowed(&op))) {
ScopedMapTy nestedKnownValues;
for (auto &region : op.getRegions())
simplifyRegion(nestedKnownValues, region);
Expand Down