Skip to content

Commit

Permalink
Enable custom alloc-like ops in promoteBufferResultsToOutParams (#1…
Browse files Browse the repository at this point in the history
…20288)

In `buffer-results-to-out-params`, when `hoist-static-allocs` option is
enabled the pass was looking for `memref.alloc`s in order to attempt to
avoid copies when it can. Which makes it not extensible to external ops
that have allocation like properties. This patch simply changes
`memref::AllocOp` to `AllocationOpInterface` in the check to enable for
any allocation op.
Moreover, for function call updates, we enable setting an allocation
function callback in `BufferResultsToOutParamsOpts` to allow users to
emit their own alloc-like op.
  • Loading branch information
srcarroll authored Dec 26, 2024
1 parent 831e1ac commit 8906b7b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
25 changes: 22 additions & 3 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H

#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class FunctionOpInterface;
class MemRefType;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
Expand Down Expand Up @@ -38,7 +40,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());

/// Creates a pass that finds all temporary allocations
/// and attempts to move the deallocation after the last user/dependency
/// and attempts to move the deallocation after the last user/dependency
/// of the allocation, thereby optimizing allocation liveness.
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();

Expand Down Expand Up @@ -157,6 +159,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOpts {
/// Allocator function: Generate a memref allocation with the given type.
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
/// results, we don't allow passing a range of values for dynamic dims.
using AllocationFn =
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;

/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
Expand All @@ -167,9 +175,20 @@ struct BufferResultsToOutParamsOpts {
return true;
};

/// Allocation function; used to allocate a memref.
/// Default memref.alloc is used
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
MemRefType type) {
return builder.create<memref::AllocOp>(loc, type).getResult();
};

/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
/// Default memref.copy is used.
MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
Value to) {
builder.create<memref::CopyOp>(loc, from, to);
return success();
};

/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -21,6 +22,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
Expand Down Expand Up @@ -105,10 +107,9 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn,
bool hoistStaticAllocs) {
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
Expand All @@ -120,13 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
}
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs &&
isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
if (options.hoistStaticAllocs &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
return WalkResult::interrupt();
}
}
Expand Down Expand Up @@ -175,7 +177,14 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
auto maybeOutParam =
options.allocationFn(builder, op.getLoc(), allocType);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
return;
}
Value outParam = maybeOutParam.value();
if (!hasStaticIdentityLayout(memrefType)) {
// Layout maps are already checked in `updateFuncOp`.
assert(hasFullyDynamicLayoutMap(memrefType) &&
Expand Down Expand Up @@ -213,14 +222,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
Value to) {
builder.create<memref::CopyOp>(loc, from, to);
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn),
options.hoistStaticAllocs))) {
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
return failure();
}
}
Expand Down

0 comments on commit 8906b7b

Please sign in to comment.