Fixing reduction ops #1673

Open · wants to merge 3 commits into main
Conversation

mtopalovicTT
Contributor

Currently, all reduction ops are failing when lowered from Forge. Reduction ops in TTIR/TTNN have an optional dim_arg attribute which can be used to specify the dimension along which the reduction should be applied.
Forge uses a dim attribute to specify reduction dims, so when lowered to TTIR it is completely ignored by our compiler.
This PR fixes the naming of this attribute and also widens the set of allowed values for this attribute from:
OptionalAttr<I32ArrayAttr> -> OptionalAttr<AnyAttrOf<[SI32Attr, I32ArrayAttr]>>

IR before change:

%2 = "ttir.mean"(%arg0, %1) <{dim_arg = [-2 : i32], keep_dim = true}> : (tensor<1x1x49x2048xf32>, tensor<1x1x1x2048xf32>) -> tensor<1x1x1x2048xf32>

IR after change:

%2 = "ttir.mean"(%arg0, %1) <{dim = [-2 : i32], keep_dim = true}> : (tensor<1x1x49x2048xf32>, tensor<1x1x1x2048xf32>) -> tensor<1x1x1x2048xf32>
%2 = "ttir.mean"(%arg0, %1) <{dim = -2 : si32, keep_dim = true}> : (tensor<1x1x49x2048xf32>, tensor<1x1x1x2048xf32>) -> tensor<1x1x1x2048xf32>

fixes #1574

@mrakitaTT
Contributor

Forge uses a dim attribute to specify reduction dims, so when lowered to TTIR it is completely ignored by our compiler.

Just curious about the reasoning: why did you decide to make this change in TTIR instead of Forge? You had to change a bunch of code and a bunch of tests because of this attribute renaming (and missed some).

If we are already renaming, I would choose dimensions instead of dim, because we allow reducing over multiple dimensions. But it seems more natural to just fix Forge in the first place to serialize this attribute name as dim_arg instead of dim.

@mtopalovicTT
Contributor Author

mtopalovicTT commented Dec 27, 2024

@mrakitaTT dim is used in both Forge and torch. dim_arg is used in Metal.

When I was talking with someone from Forge a couple of days ago, they told me they are trying to mirror torch as closely as possible with ops, so I decided not to do it in Forge.

If we were to change it in Forge, then we would have to either:

  • Fix it at the op level in Forge - this would also require a bunch of test changes in Forge because of the renaming
  • Fix it during lowering to MLIR - I would personally like to avoid this since it would complicate the lowering logic, which is today very simple.

But regardless of the naming, we still have an issue with the attribute type, which is not 1:1 with what Forge can lower.

You had to change a bunch of code and a bunch of tests because of this attribute renaming (and missed some).

Personally, I thought the change was small and simplified things, like getting the reduce dims. Unfortunately, with these kinds of changes, test changes are inevitable. Regarding the tests, I fixed that; I didn't have the StableHLO flag on...

If we are already renaming, I would choose dimensions

Yeah, I thought this also; not sure why the torch developers decided on that name.

@mrakitaTT
Contributor

mrakitaTT commented Dec 27, 2024

When I was talking with someone from Forge a couple of days ago, they told me they are trying to mirror torch as closely as possible with ops, so I decided not to do it in Forge.

Yeah, but this is not part of the Forge external interface (i.e. changing the Forge ops); this is part of the serialization from the Forge graph to the TTIR graph.

Fix it during lowering to MLIR - I would personally like to avoid this since it would complicate the lowering logic, which is today very simple.

I am not too familiar with the Forge codebase, so please pardon my ignorance, but shouldn't this just be a simple serialization string change?

Edit: I've just checked the Forge repo and I see what you mean; Forge serialization is just matching attribute names from the Forge OpNode, but I don't think this is sustainable long term anyway. We can't just adapt TTIR to be 1:1 with Forge; TTIR needs to be general enough with respect to many external dialects. We need to pick something that both makes sense to us (for example, the op or attribute name that makes the most sense and is most commonly used in third-party dialects) and is also general enough that every other dialect can easily convert to it (for example, we try to have 1:1 conversions for most of the ops in third-party dialects).

I often see this pattern where we focus too much on Forge, but I would implore everyone to always keep in mind that there are also other frontends and dialects, and to check their definitions too when deciding on something like this. For example, we've named an op Select in TTIR just because Forge uses that name, even though all other dialects use that name in the context of some other op (Select = Where) (#1675).

But regardless of the naming, we still have an issue with the attribute type, which is not 1:1 with what Forge can lower.

I am not sure about this change either. Imagine if some other dialect used SI16Attr, another used SI8Attr, another used I16ArrayAttr, etc. Should we define the op in TTIR with OptionalAttr<AnyAttrOf<[SI32Attr, I32ArrayAttr, SI16Attr, SI8Attr, I16ArrayAttr, ...]>>? Or should we agree on one definition which is general enough, like OptionalAttr<I32ArrayAttr>, and require third-party dialects to convert to it during the ThirdPartyDialect->TTIR conversion? I would argue for the latter.

For example, StableHLO uses the i64 type for dimensions, so we do this during the StableHLO->TTIR conversion:

mlir::ArrayAttr dimArg = rewriter.getI32ArrayAttr(
        llvm::SmallVector<int32_t>(srcOp.getDimensions()));
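By analogy, a hypothetical Forge->TTIR conversion could normalize a single scalar dim into the same canonical I32ArrayAttr instead of widening the TTIR definition. A minimal sketch under that assumption (rewriter plays the same role as in the snippet above; this is not existing conversion code):

// Hypothetical sketch: wrap a scalar reduction dim from the source op into the
// canonical I32ArrayAttr during the FrontendDialect->TTIR conversion.
int32_t dim = -2; // scalar dim taken from the source op (illustrative value)
mlir::ArrayAttr dimArg =
    rewriter.getI32ArrayAttr(llvm::SmallVector<int32_t>{dim});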

Arguably, we could've also changed TTIR to use i64 instead of i32, though I can't imagine a tensor with such a large rank ever existing :)

Successfully merging this pull request may close these issues.

Resnet Avg pool in fails due to reshape getting unexpected input shape