Fixes for maxpool #1664
base: main
Conversation
Change e748e71 broke the Debug build, which now fails with:

```
/usr/bin/ld: /usr/bin/ld: DWARF error: invalid or unhandled FORM value: 0x23
lib/libMLIRTTIRPipelines.a(TTIRPipelines.cpp.o): in function `mlir::tt::ttir::createLinalgToLLVMPipeline(mlir::OpPassManager&, mlir::tt::ttir::LinalgToLLVMPipelineOptions const&)':
TTIRPipelines.cpp:(.text+0x50): undefined reference to `mlir::createConvertElementwiseToLinalgPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x7d): undefined reference to `mlir::createConvertTensorToLinalgPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x118): undefined reference to `mlir::bufferization::buildBufferDeallocationPipeline(mlir::OpPassManager&, mlir::bufferization::BufferDeallocationPipelineOptions const&)'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x12f): undefined reference to `mlir::createBufferizationToMemRefPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x165): undefined reference to `mlir::createConvertLinalgToLoopsPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x19b): undefined reference to `mlir::memref::createExpandStridedMetadataPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x1d1): undefined reference to `mlir::createConvertSCFToCFPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x207): undefined reference to `mlir::createConvertControlFlowToLLVMPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x23d): undefined reference to `mlir::createArithToLLVMConversionPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x273): undefined reference to `mlir::createConvertFuncToLLVMPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x2a9): undefined reference to `mlir::createFinalizeMemRefToLLVMConversionPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x2df): undefined reference to `mlir::createReconcileUnrealizedCastsPass()'
clang++-17: error: linker command failed with exit code 1 (use -v to see invocation)
```

This PR adds the missing libs. I've checked build times and executable sizes with the added libs: the build takes a couple of seconds longer, and the binary size increase is negligible. Either way there is no alternative; we have to link these libs.

Before:
* Debug build time: 72 sec
* Debug executable size: opt 961MB, translate 679MB
* Release build time: 81 sec
* Release executable size: opt 881MB, translate 634MB

After (with added libs):
* Debug build time: 75 sec
* Debug executable size: opt 953MB, translate 679MB
* Release build time: 73 sec
* Release executable size: opt 873MB, translate 634MB
@LPanosTT while we were checking the spec of reduce_window, we noticed that we are not using base dilation in the pooling op. Is this something that should be fixed?
this is lowered into …
Before I begin looking into the PR, I want to sincerely thank you for the detailed description of the change!
Yes, we should probably add it. I do not remember exactly what it does, but if we run into a model where it diverges from its default value and we need to handle it, we will want to have the attribute on the TTIR op.
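For context on what base dilation does, here is a minimal, self-contained sketch of the per-dimension output-size arithmetic for a reduce_window-style pool. This follows the standard StableHLO/XLA shape formula and is not project code; it shows why silently assuming `base_dilation == 1` would change output shapes.

```cpp
#include <cstdint>
#include <iostream>

// Per-dimension output size of a reduce_window-style pooling op.
int64_t reduceWindowOutputDim(int64_t inputDim, int64_t windowDim,
                              int64_t stride, int64_t padLow, int64_t padHigh,
                              int64_t baseDilation, int64_t windowDilation) {
  // Base dilation inserts (baseDilation - 1) holes between *input* elements.
  int64_t dilatedInput = (inputDim - 1) * baseDilation + 1;
  // Window dilation spreads the window taps apart.
  int64_t dilatedWindow = (windowDim - 1) * windowDilation + 1;
  return (dilatedInput + padLow + padHigh - dilatedWindow) / stride + 1;
}

int main() {
  // 3x3 max pool, stride 2, no padding, on a 32-wide dimension:
  std::cout << reduceWindowOutputDim(32, 3, 2, 0, 0, /*base=*/1, /*win=*/1)
            << "\n"; // 15
  // The same op with base_dilation = 2 produces a very different size:
  std::cout << reduceWindowOutputDim(32, 3, 2, 0, 0, /*base=*/2, /*win=*/1)
            << "\n"; // 31
}
```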
@mtopalovicTT For forge (which is NCHW layout) you should be able to lower the forge graph's maxpool2d by either inserting permutations during lowering, or creating a … For example, if you have an NCHW input and you want to perform a pool with a 3x3 window, the … If the TTIRToTTIRDecomposition pass doesn't already do that, I can take a look on Monday.

EDIT: You will need to format the … For clarification, channel-last pools would have …
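A minimal illustration of the permutation round-trip being proposed. This is plain C++, not the actual TTIRToTTIRDecomposition code; `{0, 2, 3, 1}` and `{0, 3, 1, 2}` are the standard NCHW↔NHWC permutations, where result dimension `i` takes source dimension `perm[i]`.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Apply a permutation attribute to a shape: out[i] = shape[perm[i]].
template <size_t N>
std::array<int64_t, N> applyPermutation(const std::array<int64_t, N> &shape,
                                        const std::array<int64_t, N> &perm) {
  std::array<int64_t, N> out{};
  for (size_t i = 0; i < N; ++i)
    out[i] = shape[perm[i]];
  return out;
}

int main() {
  std::array<int64_t, 4> nchw = {1, 64, 32, 32};  // N, C, H, W
  std::array<int64_t, 4> toNHWC = {0, 2, 3, 1};   // NCHW -> NHWC
  std::array<int64_t, 4> toNCHW = {0, 3, 1, 2};   // NHWC -> NCHW (inverse)

  auto nhwc = applyPermutation(nchw, toNHWC);     // {1, 32, 32, 64}
  auto back = applyPermutation(nhwc, toNCHW);     // {1, 64, 32, 32}
  std::cout << nhwc[3] << " " << back[1] << "\n"; // 64 64
}
```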
@LPanosTT this sounds like an OK proposal, but we want to avoid adding decompositions into forge. Forge has very simple lowering logic: it maps each ForgeOp to a TTIROp 1:1. Regardless, some code refactoring was needed.
@mtopalovicTT You can convert the forge maxpool op to … If your forge maxpool has attributes: … The … I might be missing some other attributes here, but that's the idea.
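A hedged sketch of the attribute mapping being described: a forge-style maxpool2d carrying kernel/stride/padding for the two spatial dims becomes a generic reduce_window-style pooling op whose window attributes cover all four dims, with non-spatial dims getting window 1 / stride 1 / pad 0. The struct and field names here are illustrative, not the exact forge/TTIR spelling.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

struct ForgeMaxPool2d {                // NCHW operand (illustrative names)
  int64_t kernelH, kernelW;
  int64_t strideH, strideW;
  int64_t padTop, padBottom, padLeft, padRight;
};

struct GenericPooling {                // reduce_window-style op
  std::array<int64_t, 4> windowDimensions;
  std::array<int64_t, 4> windowStrides;
  std::array<int64_t, 8> padding;      // (lo, hi) per dimension
};

GenericPooling toGenericPooling(const ForgeMaxPool2d &p) {
  return GenericPooling{
      {1, 1, p.kernelH, p.kernelW},    // N and C are not pooled over
      {1, 1, p.strideH, p.strideW},
      {0, 0, 0, 0, p.padTop, p.padBottom, p.padLeft, p.padRight},
  };
}

int main() {
  ForgeMaxPool2d forgePool{3, 3, 2, 2, 1, 1, 1, 1};
  GenericPooling generic = toGenericPooling(forgePool);
  // Matches the NCHW 3x3 example from the thread: window = {1, 1, 3, 3}.
  std::cout << generic.windowDimensions[2] << "x"
            << generic.windowDimensions[3] << "\n"; // 3x3
}
```

A channel-last (NHWC) pool would instead use window `{1, kH, kW, 1}` and strides `{1, sH, sW, 1}`.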
@LPanosTT yeah, I understood that. What I was saying is that lowering from forge to MLIR is straightforward in the sense that there are no special cases for any op.
Thank you for the great PR description. Besides the inline comments: please end comment sentences in the code with a period. Overall looks good.
if (currentLayout[i] == NON_SPATIAL) {
  currentLayout[i] = nonSpatialCount;
  nonSpatialCount++;
constexpr int32_t NON_SPARTIAL_PLACEHOLDER = -1;
Typo: spartial
nonSpatialCount++;
constexpr int32_t NON_SPARTIAL_PLACEHOLDER = -1;

// Get indices of spartial dimensions (height, width)
Same here.
auto paddingRightAttr =
    rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]);
// Desired layout for any pooling 2d operation is NCHW
SmallVector<int32_t, 4> desiredLayout = {0, 1, 2, 3};
It would be better to define an enum with `DIM_N`, … than to use magic numbers.
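For instance (the enum and its names are illustrative, not from the codebase):

```cpp
#include <cstdint>

// Name the NCHW positions once instead of scattering 0/1/2/3 literals.
enum NCHWDim : int32_t { DIM_N = 0, DIM_C = 1, DIM_H = 2, DIM_W = 3 };

// The desired layout then reads as intent rather than magic numbers:
//   SmallVector<int32_t, 4> desiredLayout = {DIM_N, DIM_C, DIM_H, DIM_W};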
```

// Use intermediateResult to store the result of each transformation
// and replace the input tensor with it at the end
Value intermediateResult = adaptor.getInput();
I believe this is unnecessary. A `create`, `create`, ..., `create`, `replaceOpWithNewOp` chain should achieve the same effect; see the sketch below.
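A hedged sketch of the suggested shape. The `ttir::` op names and builder signatures are assumptions for illustration, not the project's exact API; only the `create<>` / `replaceOpWithNewOp<>` pattern itself (both real MLIR `RewriterBase` calls) is the point.

```cpp
// Sketch only: thread each new value into the next create<> call and end
// with replaceOpWithNewOp<>, so no mutable intermediateResult is needed.
Value permuted = rewriter.create<ttir::PermuteOp>(
    loc, permutedType, adaptor.getInput(), permutationAttr);
Value reshaped = rewriter.create<ttir::ReshapeOp>(
    loc, reshapedType, permuted, newShapeAttr);
rewriter.replaceOpWithNewOp<ttir::MaxPool2dOp>(
    op, resultType, reshaped, op->getAttrs());
```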
llvm::ArrayRef<int64_t> inputShape = inputTy.getShape();
llvm::ArrayRef<int64_t> outputShape = outputTy.getShape();

Value intermediateInput = adaptor.getInput();
Same comment as in the decomposition.
}

// Helper function to verify 2D pooling operations
bool verify2DPooling(mlir::tt::ttir::PoolingOp &op) {
Pass MLIR classes by value (ADT excluded).
}

// Check if the true window dimensions and strides indices match
if (!(trueWindowDimensionsIndices == trueWindowStrideIndices)) {
Use `!=` instead of `!(... == ...)`.
}

// Helper function to verify 2D pooling operations
bool verify2DPooling(mlir::tt::ttir::PoolingOp &op) {
Is there a reason for writing this as a separate function? Helper functions should also be `static` in .cpp files (see the sketch below).
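For reference, the convention being asked for might look like this. The body is elided and the signature is illustrative; it also folds in the earlier pass-by-value note, since MLIR op handles are cheap pointer-sized wrappers.

```cpp
// File-local helper: internal linkage via static, op handle passed by value.
static bool verify2DPooling(mlir::tt::ttir::PoolingOp op) {
  // ... checks ...
  return true;
}
```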
inputPermuteOp.getOperation()->use_end());
// If number of users of the input permutation is more than 1, then we
// cannot merge or remove the permutation.
if (numUsers > 1) {
We are not replacing the producer, so it shouldn't matter whether it has more than one user. `PermuteOp` is defined as `Pure`, so an explicit `erase` of the producer is not necessary: if it has at least one user (other than the argument `op`) it will stay alive for those users; otherwise it will be erased during dead-value removal. I agree that merging in that case isn't as beneficial as the case where we can remove the producer, but I still believe we shouldn't treat it as a special case.
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
Shouldn't the silicon test run the whole pipeline?
Misc maxpool fixes

This PR fixes the decomposition of max pool so it runs correctly on silicon.

State today:

- `TTIR_Pooling` is used as the general op for all pooling operations; StableHLO lowers its pooling operations into this op.
- It is decomposed into `TTIR_MaxPool2d` if `rank(input) == 4` and `pooling_method == Max`, rewriting the input into `NHWC` layout.
- Lowering `TTIR_MaxPool2d` into `TTNN`:
  - The input of `MaxPool` is reshaped into `1,1,(NHW),C` layout; the TTNN lib requires this shape for maxpool.
  - The output of `MaxPool` is reshaped back to the original shape.

The above is fine for StableHLO, but for Forge the `MaxPool` input is usually in `NCHW` layout, which when converted into TTNN reshapes the input tensor into `1,1,(NCH),W` and generates bad output. So we need some other approach to deal with this (illustrated by the sketch below).
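To see why the current reshape is layout-sensitive, here is a small self-contained sketch (illustrative values, not project code): flattening the first three dims keeps the channel axis innermost only when the layout is NHWC.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  std::array<int64_t, 4> nhwc = {2, 32, 32, 64};  // N, H, W, C
  std::array<int64_t, 4> nchw = {2, 64, 32, 32};  // N, C, H, W

  // NHWC -> (1, 1, N*H*W, C): rows are pixels, columns are channels. Good.
  std::cout << "1,1," << nhwc[0] * nhwc[1] * nhwc[2] << "," << nhwc[3] << "\n";

  // The same rule applied blindly to NCHW yields (1, 1, N*C*H, W): the
  // axis TTNN treats as "channels" is actually W. Bad output.
  std::cout << "1,1," << nchw[0] * nchw[1] * nchw[2] << "," << nchw[3] << "\n";
}
```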
Changes:

- `TTIR_MaxPool` now has a new attribute `channel_last` which indicates whether the input tensor is in channel-last format, i.e. `NHWC`. Forge has this attribute and passes it into MLIR.
- Decomposition of `TTIR_Pooling` into `TTIR_MaxPool2d` no longer rewrites the input to `NHWC` layout. Instead it rewrites it into `NCHW` and sets the `channel_last` attribute to 0.
- Added `CanonicalizeChannelLastMaxPoolPattern`, which rewrites every `TTIR_MaxPool2d` whose layout is `NCHW` (i.e. `channel_last` is set to 0) into `NHWC`:
  - the input is permuted into `NHWC` layout,
  - the output is permuted back into `NCHW`.
Other changes:

- Moved `TTIR_Pooling` checks from the decomposition into the verify method.
- The input is now `NHWC`, so no decompositions are needed.
- `TTNN_MaxPool2d` is no longer a DPS op.
- `TTIR_Permute` now has a canonicalizer (I will probably move it to `TTNN_Permute`). This is needed because I'm generating two permutes in a row. The canonicalization rules are:
  - an identity permutation (`0, 1, 2, 3`) is a no-op,
  - consecutive permutes are folded into one; see `permute_folding.mlir` and the sketch after this list.

fixes #1575
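A small self-contained sketch of the folding math behind those two rules (an illustrative helper, not the canonicalization pattern's actual code), using the convention `permute(x, p)[i] = x[p[i]]`:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

using Perm = std::array<int64_t, 4>;

// permute(permute(x, f), g) == permute(x, compose(f, g)).
Perm compose(const Perm &f, const Perm &g) {
  Perm out{};
  for (size_t i = 0; i < 4; ++i)
    out[i] = f[g[i]];
  return out;
}

bool isIdentity(const Perm &p) {
  for (size_t i = 0; i < 4; ++i)
    if (p[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

int main() {
  Perm toNHWC = {0, 2, 3, 1}, toNCHW = {0, 3, 1, 2};
  // Two permutes in a row fold into one; the NHWC round-trip folds to
  // {0, 1, 2, 3}, so both permutes disappear entirely.
  Perm folded = compose(toNHWC, toNCHW);
  std::cout << isIdentity(folded) << "\n"; // 1
}
```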