From 8906b7be918be653d3c5f2ef3dbd923561603969 Mon Sep 17 00:00:00 2001 From: srcarroll <50210727+srcarroll@users.noreply.github.com> Date: Thu, 26 Dec 2024 11:32:51 -0600 Subject: [PATCH] Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (#120288) In `buffer-results-to-out-params`, when `hoist-static-allocs` option is enabled the pass was looking for `memref.alloc`s in order to attempt to avoid copies when it can. Which makes it not extensible to external ops that have allocation like properties. This patch simply changes `memref::AllocOp` to `AllocationOpInterface` in the check to enable for any allocation op. Moreover, for function call updates, we enable setting an allocation function callback in `BufferResultsToOutParamsOpts` to allow users to emit their own alloc-like op. --- .../Dialect/Bufferization/Transforms/Passes.h | 25 ++++++++++++-- .../Transforms/BufferResultsToOutParams.cpp | 34 ++++++++++--------- 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index fe43a05c81fdc..c8e456a1d7e38 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -2,10 +2,12 @@ #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" namespace mlir { class FunctionOpInterface; +class MemRefType; class ModuleOp; class RewritePatternSet; class OpBuilder; @@ -38,7 +40,7 @@ std::unique_ptr createOwnershipBasedBufferDeallocationPass( DeallocationOptions options = DeallocationOptions()); /// Creates a pass that finds all temporary allocations -/// and attempts to move the deallocation after the last user/dependency +/// and attempts to move the deallocation after the last user/dependency /// of the allocation, thereby optimizing allocation liveness. std::unique_ptr createOptimizeAllocationLivenessPass(); @@ -157,6 +159,12 @@ std::unique_ptr createBufferLoopHoistingPass(); // Options struct for BufferResultsToOutParams pass. // Note: defined only here, not in tablegen. struct BufferResultsToOutParamsOpts { + /// Allocator function: Generate a memref allocation with the given type. + /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped + /// results, we don't allow passing a range of values for dynamic dims. + using AllocationFn = + std::function(OpBuilder &, Location, MemRefType)>; + /// Memcpy function: Generate a memcpy between two memrefs. using MemCpyFn = std::function; @@ -167,9 +175,20 @@ struct BufferResultsToOutParamsOpts { return true; }; + /// Allocation function; used to allocate a memref. + /// Default memref.alloc is used + AllocationFn allocationFn = [](OpBuilder &builder, Location loc, + MemRefType type) { + return builder.create(loc, type).getResult(); + }; + /// Memcpy function; used to create a copy between two memrefs. - /// If this is empty, memref.copy is used. - std::optional memCpyFn; + /// Default memref.copy is used. + MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from, + Value to) { + builder.create(loc, from, to); + return success(); + }; /// If true, the pass adds a "bufferize.result" attribute to each output /// parameter. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index b7755b2be8483..2502744cb3f58 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,6 +22,7 @@ namespace bufferization { } // namespace mlir using namespace mlir; +using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; /// Return `true` if the given MemRef type has a fully dynamic layout. @@ -105,10 +107,9 @@ updateFuncOp(func::FuncOp func, // Updates all ReturnOps in the scope of the given func::FuncOp by either // keeping them as return values or copying the associated buffer contents into // the given out-params. -static LogicalResult updateReturnOps(func::FuncOp func, - ArrayRef appendedEntryArgs, - MemCpyFn memCpyFn, - bool hoistStaticAllocs) { +static LogicalResult +updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, + const bufferization::BufferResultsToOutParamsOpts &options) { auto res = func.walk([&](func::ReturnOp op) { SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; @@ -120,13 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func, } OpBuilder builder(op); for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { - if (hoistStaticAllocs && - isa_and_nonnull(orig.getDefiningOp()) && + if (options.hoistStaticAllocs && + isa_and_nonnull( + orig.getDefiningOp()) && mlir::cast(orig.getType()).hasStaticShape()) { orig.replaceAllUsesWith(arg); orig.getDefiningOp()->erase(); } else { - if (failed(memCpyFn(builder, op.getLoc(), orig, arg))) + if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) return WalkResult::interrupt(); } } @@ -175,7 +177,14 @@ updateCalls(ModuleOp module, auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); - Value outParam = builder.create(op.getLoc(), allocType); + auto maybeOutParam = + options.allocationFn(builder, op.getLoc(), allocType); + if (failed(maybeOutParam)) { + op.emitError() << "failed to create allocation op"; + didFail = true; + return; + } + Value outParam = maybeOutParam.value(); if (!hasStaticIdentityLayout(memrefType)) { // Layout maps are already checked in `updateFuncOp`. assert(hasFullyDynamicLayoutMap(memrefType) && @@ -213,14 +222,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( return failure(); if (func.isExternal()) continue; - auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from, - Value to) { - builder.create(loc, from, to); - return success(); - }; - if (failed(updateReturnOps(func, appendedEntryArgs, - options.memCpyFn.value_or(defaultMemCpyFn), - options.hoistStaticAllocs))) { + if (failed(updateReturnOps(func, appendedEntryArgs, options))) { return failure(); } }