Skip to content

Commit

Permalink
Revert "Reapply "[DispatchCreation] Extend multi-use producer fusion" (
Browse files Browse the repository at this point in the history
…#19032) (#19070)

This reverts commit 2a5d123.

Seems to cause accuracy regressions
nod-ai/shark-ai#437

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Nov 8, 2024
1 parent cf95c94 commit be41632
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 215 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ jobs:
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1527 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1141 \
--goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2280000 \
Expand All @@ -241,7 +241,7 @@ jobs:
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1527 \
--goldendispatch-rocm-unet 1531 \
--goldendispatch-rocm-clip 1141 \
--goldendispatch-rocm-vae 246 \
--goldensize-rocm-unet-bytes 2270000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
Expand Down Expand Up @@ -108,6 +107,25 @@ static bool isEmptyFillContractionDAGRootOp(
return true;
}

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the `slice` of the operation does not contain any op
/// from the group.
static bool isHorizontalToGroup(Operation *op,
const llvm::SetVector<Operation *> &currGroup,
const DominanceInfo &dominanceInfo,
Operation *seedOp) {
BackwardSliceOptions options;
// Limit the slice to the seed to make sure the slice is small.
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, seedOp);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
return slice.contains(groupedOp);
});
}

/// Get user of operation that is a truncate operation.
static std::optional<linalg::GenericOp>
getTruncateOp(Operation *op,
Expand All @@ -131,8 +149,8 @@ getTruncateOp(Operation *op,
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
return std::nullopt;
}
if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
dominanceInfo, seedTruncateOp.value())) {
if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
seedTruncateOp.value())) {
return std::nullopt;
}
}
Expand Down Expand Up @@ -208,8 +226,7 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}
if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
seedOp)) {
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
return false;
}
return true;
Expand Down Expand Up @@ -329,6 +346,40 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
}

/// During horizontal fusion, there might be operands of the fused operations
/// whose definitions are interspersed between the fused operations. For groups
/// chosen to fuse horizontally, such operations can be moved before the
/// seed contraction operation (where the fused operation is generated).
template <typename T>
static LogicalResult
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
Operation *insertionPoint, DominanceInfo &dominanceInfo,
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
BackwardSliceOptions options;
llvm::DenseSet<Operation *> ignoreOperationsSet;
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
!ignoreOperationsSet.contains(op);
};
// Set inclusive to true cause the slice is computed from the operand, and
// we want to include the defining op (which is the point here)
options.inclusive = true;

llvm::SetVector<Operation *> slice;
for (auto op : operations) {
for (auto operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
}

mlir::topologicalSort(slice);
for (auto op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

/// On finding this pattern
/// ```
/// %0 = linalg.matmul ins(%arg0, %arg1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -49,55 +45,25 @@ static llvm::cl::opt<int64_t> clLinalgMaxConstantFoldElements(
llvm::cl::desc("Maximum number of elements to try to constant fold."),
llvm::cl::init(0));

static Operation *getMostDominantUse(Operation *op,
const DominanceInfo &dominanceInfo) {
auto uses = op->getUses();
auto it = llvm::find_if(uses, [&](OpOperand &source) {
Operation *sourceOp = source.getOwner();

return llvm::all_of(uses, [&](OpOperand &target) {
Operation *targetOp = target.getOwner();
return dominanceInfo.dominates(sourceOp, targetOp);
});
});
if (it != uses.end()) {
return it->getOwner();
}
return nullptr;
}

/// Check if any of the use dominates all other uses of the operation.
static Operation *getFusableUse(Operation *op,
const DominanceInfo &dominanceInfo) {
static std::optional<OpOperand *> getFusableUse(Operation *op,
DominanceInfo &dominanceInfo) {
auto uses = op->getUses();
Operation *fusableUse = nullptr;
for (OpOperand &source : uses) {
Operation *sourceOp = source.getOwner();

bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
bool dominatesAllUsers = true;
for (OpOperand &target : uses) {
Operation *targetOp = target.getOwner();
return !isa<linalg::GenericOp>(targetOp) ||
dominanceInfo.dominates(sourceOp, targetOp);
});
if (dominatesAllFusableOps) {
fusableUse = sourceOp;
break;
if (!dominanceInfo.dominates(sourceOp, targetOp)) {
dominatesAllUsers = false;
break;
}
}
if (dominatesAllUsers) {
return &source;
}
}
Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
if (!fusableUse || !mostDominantOp) {
return nullptr;
}

// If `fusableUse` dominates all other users, there's nothing else to do.
if (fusableUse == mostDominantOp) {
return fusableUse;
}

SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
? fusableUse
: nullptr;
return std::nullopt;
}

static OpOperand *getFirstUseInConsumer(Operation *producer,
Expand Down Expand Up @@ -125,7 +91,6 @@ static SmallVector<OpOperand *> getAllUsesInConsumer(Operation *producer,
/// using elementwise fusion.
static LogicalResult doMultiUseFusion(Operation *rootOp,
llvm::SetVector<Operation *> &fusableOps,
const DominanceInfo &dominanceInfo,
RewriterBase &rewriter) {
assert(rootOp && "root op cant be null");

Expand All @@ -147,20 +112,11 @@ static LogicalResult doMultiUseFusion(Operation *rootOp,
Operation *consumerOp = rootOp;
OpBuilder::InsertionGuard g(rewriter);
for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
// Fuse all uses from producer -> consumer. It has been checked
// before that all uses are fusable.
while (OpOperand *fusedOperand =
getFirstUseInConsumer(producerOp, consumerOp)) {
rewriter.setInsertionPoint(consumerOp);

if (consumerOp != mostDominantUser &&
failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
mostDominantUser, dominanceInfo))) {
return rewriter.notifyMatchFailure(consumerOp,
"failed to move operand defs");
}
rewriter.moveOpBefore(consumerOp, mostDominantUser);
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
linalg::fuseElementwiseOps(rewriter, fusedOperand);
if (failed(fusionResult)) {
Expand Down Expand Up @@ -234,8 +190,9 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
}

// 6. Check that the `genericOp` dominates all uses of `producer`.
Operation *fusableUse = getFusableUse(producer, dominanceInfo);
if (!fusableUse || fusableUse != genericOp) {
std::optional<OpOperand *> fusableUse =
getFusableUse(producer, dominanceInfo);
if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
continue;
}

Expand Down Expand Up @@ -275,8 +232,7 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,

IRRewriter rewriter(context);
for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
if (failed(
doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
return funcOp->emitOpError("failed multi use fusion");
}
}
Expand Down
22 changes: 0 additions & 22 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir::iree_compiler::DispatchCreation {

Expand Down Expand Up @@ -101,22 +97,4 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
return true;
}

bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo,
Operation *seedOp) {
assert(dominanceInfo.properlyDominates(seedOp, op) &&
op->getParentRegion() == seedOp->getParentRegion());
BackwardSliceOptions options;
options.omitUsesFromAbove = false;
// Limit the slice to the seed to make sure the slice is small.
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, seedOp);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
return slice.contains(groupedOp);
});
}

} // namespace mlir::iree_compiler::DispatchCreation
45 changes: 0 additions & 45 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

namespace mlir::iree_compiler::DispatchCreation {
Expand All @@ -23,45 +19,4 @@ namespace mlir::iree_compiler::DispatchCreation {
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool fuseMultiReduction);

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the program slice of the operation (from op back to seedOp)
/// does not contain any op from the group.
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo, Operation *seedOp);

/// Moves the operands and transitive defs for each op in `operations` directly
/// after `insertionPoint`. Note: this does not check if it is legal to move the
/// operands.
template <typename T>
static LogicalResult
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
Operation *insertionPoint, const DominanceInfo &dominanceInfo,
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
BackwardSliceOptions options;
options.omitUsesFromAbove = false;
llvm::DenseSet<Operation *> ignoreOperationsSet;
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
!ignoreOperationsSet.contains(op);
};
// Set inclusive to true cause the slice is computed from the operand, and
// we want to include the defining op (which is the point here)
options.inclusive = true;

llvm::SetVector<Operation *> slice;
for (auto op : operations) {
assert(insertionPoint->getBlock() == op->getBlock());
for (auto operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
}

mlir::topologicalSort(slice);
for (auto op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

} // namespace mlir::iree_compiler::DispatchCreation
Loading

0 comments on commit be41632

Please sign in to comment.