Fixes for maxpool #1664

Open · wants to merge 2 commits into main
Conversation

mtopalovicTT (Contributor) commented Dec 25, 2024

Misc maxpool fixes

This PR fixes the decomposition of max pool so it runs correctly on silicon.

State today:

  • We have the TTIR_Pooling op, which is used as a general op for all pooling operations. StableHLO lowers its pooling operations into this op.
  • This op is then decomposed into TTIR_MaxPool2d if rank(input) == 4 and pooling_method == Max.
    • During this we generate a permutation which rewrites the input into NHWC layout.
    • We generate the inverse permutation to restore the original layout.
  • Next, during the conversion pass of TTIR_MaxPool2d into TTNN:
    • We rewrite the input of MaxPool into (1, 1, N*H*W, C) layout using a reshape. The TTNN lib requires this shape for maxpool.
    • We generate the inverse reshape to bring the output of MaxPool back to the original shape.

The above is fine for StableHLO, but for Forge the MaxPool input is usually in NCHW layout, which when converted into TTNN reshapes the input tensor into
(1, 1, N*C*H, W) and generates bad output, so we need a different approach to deal with this.
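For concreteness, here is a small illustration (not code from the PR) of the flatten reshape that the TTNN conversion performs and why it is only correct for channel-last inputs; the shapes are made-up examples:

```cpp
#include <array>
#include <cstdint>

// TTNN-style flatten for maxpool: collapse everything except the last dim,
// producing (1, 1, d0*d1*d2, d3). This is only correct when the input is NHWC.
std::array<int64_t, 4> flattenForMaxPool(const std::array<int64_t, 4> &shape) {
  return {1, 1, shape[0] * shape[1] * shape[2], shape[3]};
}

// NHWC input (1, 28, 28, 64) -> (1, 1, 784, 64): channels stay intact.
// NCHW input (1, 64, 28, 28) -> (1, 1, 1792, 28): spatial and channel dims get
// mixed together, which is why the output is wrong for Forge's NCHW inputs.
```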

Changes:

  • TTIR_MaxPool2d now has a new attribute channel_last which indicates whether the input tensor is in channel-last format, i.e. NHWC. Forge has this attribute
and passes it into MLIR.
  • Decomposition of TTIR_Pooling into TTIR_MaxPool2d will no longer rewrite the input to NHWC layout. Instead it will rewrite it into NCHW and
set the channel_last attribute to 0.
  • We are adding a new pattern rewriter CanonicalizeChannelLastMaxPoolPattern which will rewrite every TTIR_MaxPool2d whose layout is NCHW (i.e. channel_last is set to 0)
into NHWC (see the sketch after this list):
    • We generate a permutation which rewrites the input into NHWC layout.
    • We generate the inverse permutation to restore the original NCHW layout.
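A minimal sketch of the permutation pair this rewrite generates (the helper is illustrative, not an actual tt-mlir API), assuming the convention that output dim i takes input dim perm[i]:

```cpp
#include <array>
#include <cstdint>

// NCHW -> NHWC: move the channel dimension to the end.
constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};

// Inverse permutation: if p maps output dim i to input dim p[i], then the
// inverse maps output dim p[i] back to i.
constexpr std::array<int64_t, 4> invert(const std::array<int64_t, 4> &p) {
  std::array<int64_t, 4> inv{};
  for (int64_t i = 0; i < 4; ++i)
    inv[p[i]] = i;
  return inv;
}

// invert(kNchwToNhwc) == {0, 3, 1, 2}, i.e. NHWC -> NCHW, the permutation
// inserted after the pool to restore the original layout.
```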

Other changes:

  • Move the TTIR_Pooling checks from the decomposition into the verify method.
  • Some code cleanup to make things more readable/understandable.
  • TOSA uses NHWC by default, so no decompositions are needed.
  • TTNN_MaxPool2d is no longer a DPS op.
  • TTIR_Permute now has a canonicalizer (I will probably move it to `TTNN_Permute`). This is needed because I'm generating two permutes in a row. The canonicalization rules are (a sketch of the folding arithmetic follows this list):
    • Remove identity permutations (i.e. permutation 0, 1, 2, 3 is a no-op).
    • If the input is another permute op, try to merge them into one op, provided the input doesn't have more than 1 user.
      • If the merged permutation is the identity permutation, we remove it.
      • Otherwise we fold the two permutations into one.
    • All of the above is tested in permute_folding.mlir.
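A rough sketch of the arithmetic behind the merge rules (the convention and names are assumptions, not the actual canonicalizer code): composing the producer and consumer permutations yields a single permutation, and if that composition is the identity both ops disappear.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Permutation obtained by applying `first` and then `second`, assuming the
// convention output[i] = input[perm[i]].
std::vector<int64_t> composePermutations(const std::vector<int64_t> &first,
                                         const std::vector<int64_t> &second) {
  std::vector<int64_t> composed(second.size());
  for (size_t i = 0; i < second.size(); ++i)
    composed[i] = first[second[i]];
  return composed;
}

bool isIdentity(const std::vector<int64_t> &perm) {
  std::vector<int64_t> id(perm.size());
  std::iota(id.begin(), id.end(), 0);
  return perm == id;
}

// Example: {0, 2, 3, 1} (NCHW->NHWC) followed by {0, 3, 1, 2} (NHWC->NCHW)
// composes to {0, 1, 2, 3}, the identity, so both permutes can be removed.
```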

fixes #1575

mtopalovicTT and others added 2 commits December 23, 2024 12:24
Change e748e71 broke the Debug build.

Debug build fails with:
```
/usr/bin/ld: /usr/bin/ld: DWARF error: invalid or unhandled FORM value: 0x23
lib/libMLIRTTIRPipelines.a(TTIRPipelines.cpp.o): in function `mlir::tt::ttir::createLinalgToLLVMPipeline(mlir::OpPassManager&, mlir::tt::ttir::LinalgToLLVMPipelineOptions const&)':
TTIRPipelines.cpp:(.text+0x50): undefined reference to `mlir::createConvertElementwiseToLinalgPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x7d): undefined reference to `mlir::createConvertTensorToLinalgPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x118): undefined reference to `mlir::bufferization::buildBufferDeallocationPipeline(mlir::OpPassManager&, mlir::bufferization::BufferDeallocationPipelineOptions const&)'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x12f): undefined reference to `mlir::createBufferizationToMemRefPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x165): undefined reference to `mlir::createConvertLinalgToLoopsPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x19b): undefined reference to `mlir::memref::createExpandStridedMetadataPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x1d1): undefined reference to `mlir::createConvertSCFToCFPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x207): undefined reference to `mlir::createConvertControlFlowToLLVMPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x23d): undefined reference to `mlir::createArithToLLVMConversionPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x273): undefined reference to `mlir::createConvertFuncToLLVMPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x2a9): undefined reference to `mlir::createFinalizeMemRefToLLVMConversionPass()'
/usr/bin/ld: TTIRPipelines.cpp:(.text+0x2df): undefined reference to `mlir::createReconcileUnrealizedCastsPass()'
clang++-17: error: linker command failed with exit code 1 (use -v to see invocation)
```

This PR adds the missing libs. I've checked build times and executable sizes
with the added libs. A couple of seconds are added to the build time, and the
increase in binary size is negligible. Regardless, there is no other way; we
have to include these libs.

Debug build time: 72 sec
Debug executable size:
* opt - 961MB
* translate - 679MB

Release build time: 81 sec
Release executable size:
* opt - 881MB
* translate - 634MB

Debug build time: 75 sec
Debug executable size:
* opt - 953MB
* translate - 679MB

Release build time: 73 sec
Release executable size:
* opt - 873MB
* translate - 634MB
mtopalovicTT changed the title from "Milant/fix max pool" to "Fixes for maxpool" on Dec 25, 2024
mtopalovicTT (Contributor, Author)

@LPanosTT while we were checking the spec of reduce_window we noticed that we are not using base dilation in the pooling op. Is this something that should be fixed?

    // Default baseDilations to all-ones (one entry per window dimension) when
    // the attribute is not provided.
    baseDilations = baseDilations
                        ? baseDilations
                        : rewriter.getDenseI64ArrayAttr(
                              SmallVector<int64_t>(windowDimensions.size(), 1));

This is lowered into TTIR_PoolingOp but not used any further.

sdjordjevicTT (Contributor)

Before I begin looking into the PR, I want to sincerely thank you for the detailed description of the change!

LPanosTT (Contributor)

@LPanosTT while we were checking the spec of reduce_window we noticed that we are not using base dilation in the pooling op. Is this something that should be fixed?

    baseDilations = baseDilations
                        ? baseDilations
                        : rewriter.getDenseI64ArrayAttr(
                              SmallVector<int64_t>(windowDimensions.size(), 1));

This is lowered into TTIR_PoolingOp but not used any further.

Yes, we should probably add it. I do not remember exactly what it does, but if we run into a model where it diverges from its default value and we need to handle it, we will want to have the attribute in the TTIR op.

LPanosTT (Contributor) commented Dec 26, 2024

@mtopalovicTT For forge (which uses NCHW layout) you should be able to lower the forge graph's maxpool2d either by inserting permutations during lowering, or by creating a TTIR_PoolingOp with the window dimensions set for an NCHW configuration.

For example, if you have an NCHW input and you want to perform a pool with a 3x3 window, the window_dimensions attribute should be set to [1, 1, 3, 3] since it is channel-first. The TTIRToTTIRDecomposition pass should be able to lower that into a TTIR_MaxPool2d with the desired transposes/reshapes around it to convert NCHW to NHWC, and then to (1, 1, N*H*W, C).

If the TTIRToTTIRDecomposition pass doesn't already do that, I can take a look on Monday.

EDIT: You will need to format the window_strides similarly. So if the pool is channel-first and you want the stride to be 2x2, you should have window_strides=[1, 1, 2, 2].

For clarification, channel-last pools would have window_dimensions=[1, 3, 3, 1] and window_strides=[1, 2, 2, 1], since the spatial dimensions are in the middle in that case.
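Restating the two cases side by side (illustrative constants only, not code from the PR):

```cpp
#include <array>
#include <cstdint>

// Attribute values for a 3x3 pool with stride 2 on a rank-4 input.
constexpr std::array<int64_t, 4> kChannelFirstWindow  = {1, 1, 3, 3}; // NCHW
constexpr std::array<int64_t, 4> kChannelFirstStrides = {1, 1, 2, 2};
constexpr std::array<int64_t, 4> kChannelLastWindow   = {1, 3, 3, 1}; // NHWC
constexpr std::array<int64_t, 4> kChannelLastStrides  = {1, 2, 2, 1};
```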

mtopalovicTT (Contributor, Author)

@mtopalovicTT For forge (which uses NCHW layout) you should be able to lower the forge graph's maxpool2d either by inserting permutations during lowering, or by creating a TTIR_PoolingOp with the window dimensions set for an NCHW configuration.

For example, if you have an NCHW input and you want to perform a pool with a 3x3 window, the window_dimensions attribute should be set to [1, 1, 3, 3] since it is channel-first. The TTIRToTTIRDecomposition pass should be able to lower that into a TTIR_MaxPool2d with the desired transposes/reshapes around it to convert NCHW to NHWC, and then to (1, 1, N*H*W, C).

If the TTIRToTTIRDecomposition pass doesn't already do that, I can take a look on Monday.

EDIT: You will need to format the window_strides similarly. So if the pool is channel-first and you want the stride to be 2x2, you should have window_strides=[1, 1, 2, 2].

For clarification, channel-last pools would have window_dimensions=[1, 3, 3, 1] and window_strides=[1, 2, 2, 1], since the spatial dimensions are in the middle in that case.

@LPanosTT this sounds like an OK proposal, but we want to avoid adding decompositions into forge. Forge has very simple lowering logic: it maps ForgeOp to TTIROp 1-1.

Regardless of that, some code refactoring was needed.

LPanosTT (Contributor) commented Dec 26, 2024

@mtopalovicTT You can convert the forge maxpool op to TTIR_PoolingOp 1-1.

If your forge maxpool has attributes:

channel_first=True
window_size=(window_height, window_width)
strides=(stride_height, stride_width)
padding=(top, bottom, left, right)

The TTIR_PoolingOp would have:

method: MAX // Pooling method
window_dimensions=(1, 1, window_height, window_width)
window_strides=(1, 1, stride_height, stride_width)
padding=[[0, 0], [0, 0], [top, bottom], [left, right]]

I might be missing some other attributes here but that's the idea.
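A sketch of what that 1-1 mapping amounts to; the struct and function names below are illustrative assumptions, not actual Forge or tt-mlir APIs:

```cpp
#include <array>
#include <cstdint>

// Illustrative container for the Forge maxpool2d attributes listed above.
struct ForgeMaxPool2dAttrs {
  std::array<int64_t, 2> windowSize; // (window_height, window_width)
  std::array<int64_t, 2> strides;    // (stride_height, stride_width)
  std::array<int64_t, 4> padding;    // (top, bottom, left, right)
};

// Attributes for a rank-4, channel-first (NCHW) TTIR_PoolingOp with method = Max.
struct PoolingAttrs {
  std::array<int64_t, 4> windowDimensions;
  std::array<int64_t, 4> windowStrides;
  std::array<int64_t, 8> padding; // [lo, hi] pairs per dimension, flattened
};

PoolingAttrs toChannelFirstPoolingAttrs(const ForgeMaxPool2dAttrs &a) {
  // A channel-last input would instead place these values on dims 1 and 2.
  return PoolingAttrs{
      {1, 1, a.windowSize[0], a.windowSize[1]},
      {1, 1, a.strides[0], a.strides[1]},
      {0, 0, 0, 0, a.padding[0], a.padding[1], a.padding[2], a.padding[3]},
  };
}
```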

mtopalovicTT (Contributor, Author)

@LPanosTT yeah, I understood that. What I was saying is that lowering from forge to MLIR is straightforward in the sense that there are no special cases for any op.

azecevicTT (Contributor) left a comment


Thank you for the great PR description. Apart from my inline comments: please end comment sentences in the code with a period. Overall it looks good.

if (currentLayout[i] == NON_SPATIAL) {
currentLayout[i] = nonSpatialCount;
nonSpatialCount++;
constexpr int32_t NON_SPARTIAL_PLACEHOLDER = -1;

Typo: spartial

nonSpatialCount++;
constexpr int32_t NON_SPARTIAL_PLACEHOLDER = -1;

// Get indices of spartial dimensions (height, width)

Same here.

auto paddingRightAttr =
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]);
// Desired layout for any pooling 2d operation is NCHW
SmallVector<int32_t, 4> desiredLayout = {0, 1, 2, 3};

It would be better to define an enum with DIM_N... than to use magic numbers.


// Use intermediateResult to store the result of each transformation
// and replace the input tensor with it at the end
Value intermediateResult = adaptor.getInput();

I believe this is unnecessary. A create, create, ..., create, replaceOpWithNewOp chain should achieve the same effect.

llvm::ArrayRef<int64_t> inputShape = inputTy.getShape();
llvm::ArrayRef<int64_t> outputShape = outputTy.getShape();

Value intermediateInput = adaptor.getInput();

Same comment as in decomposition.

}

// Helper function to verify 2D pooling operations
bool verify2DPooling(mlir::tt::ttir::PoolingOp &op) {

Pass MLIR classes by value (ADT excluded).

}

// Check if the true window dimensions and strides indices match
if (!(trueWindowDimensionsIndices == trueWindowStrideIndices)) {

Use != instead of !(... == ...).

}

// Helper function to verify 2D pooling operations
bool verify2DPooling(mlir::tt::ttir::PoolingOp &op) {

Is there a reason for writing this in a separate function? Helper functions should also be static in .cpp files.

inputPermuteOp.getOperation()->use_end());
// If number of users of the input permutation is more than 1, then we
// cannot merge or remove the permutation.
if (numUsers > 1) {

We are not replacing the producer, so it shouldn't matter whether it has more than one user. PermuteOp is defined as Pure, so an explicit erase of the producer is not necessary: if it has at least one user (other than the argument op) it will remain for those users, otherwise it will be erased during dead-value removal. I agree that merging in that case isn't as beneficial as when we can remove the producer, but I still believe we shouldn't treat it as a special case.

// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s

Shouldn't the silicon test run the whole pipeline?


Successfully merging this pull request may close these issues.

Resnet maxpool data mismatch