-
Notifications
You must be signed in to change notification settings - Fork 324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pass for decomposing (log)softmax #6287
Open
Erik-Lundell
wants to merge
3
commits into
pytorch:main
Choose a base branch
from
Erik-Lundell:op_softmax
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+250
−120
Open
Changes from 2 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass | ||
|
||
# For BI case | ||
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int) | ||
|
||
# For MI case | ||
edge_softmax = ( | ||
exir_ops.edge.aten._softmax.default, | ||
exir_ops.edge.aten._log_softmax.default, | ||
) | ||
|
||
log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default) | ||
|
||
|
||
def get_logsoftmax_ops(op) -> tuple: | ||
""" | ||
Returns the the (log_op, expo_op, sum_op, reciprocal_op), where the ops depends on if | ||
the logsoftmax op is in exir_ops torch.ops.aten. | ||
""" | ||
if op in edge_softmax: | ||
return ( | ||
exir_ops.edge.aten.log.default, | ||
exir_ops.edge.aten.exp.default, | ||
exir_ops.edge.aten.sum.dim_IntList, | ||
exir_ops.edge.aten.reciprocal.default, | ||
exir_ops.edge.aten.mul.Tensor, | ||
) | ||
if op in torch_softmax: | ||
return ( | ||
torch.ops.aten.log.default, | ||
torch.ops.aten.exp.default, | ||
torch.ops.aten.sum.dim_IntList, | ||
torch.ops.aten.reciprocal.default, | ||
torch.ops.aten.mul.Tensor, | ||
) | ||
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}") | ||
|
||
|
||
class DecomposeSoftmaxesPass(ExportPass): | ||
""" | ||
This pass decomposes log softmax or softmax into more primitive ops. | ||
|
||
Example: | ||
%op1 = exp(x) | ||
%op2 = sum(%op1, dim) | ||
%op3 = reciprocal(%op2) | ||
%op4 = mul(%op1, %op3) | ||
(in logsoftmax case: %op5 = log(%op4)) | ||
""" | ||
|
||
def call_operator(self, op, args, kwargs, meta): | ||
if op not in torch_softmax + edge_softmax: | ||
return super().call_operator(op, args, kwargs, meta) | ||
|
||
log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op) | ||
|
||
_input = args[0] | ||
dim = [args[1]] | ||
|
||
op1 = super().call_operator(exp_op, (_input,), {}, meta) | ||
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta) | ||
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta) | ||
op4 = super().call_operator(mul_op, (op1, op3), {}, meta) | ||
if op in log_softmax: | ||
op4 = super().call_operator(log_op, (op4,), {}, meta) | ||
return op4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,6 @@ | |
op_rsqrt, | ||
op_sigmoid, | ||
op_slice, | ||
op_softmax, | ||
op_squeeze, | ||
op_sub, | ||
op_sum, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
from parameterized import parameterized | ||
|
||
|
||
test_data_suite = [ | ||
# (test_name, test_data, dim) | ||
("zeros", torch.zeros(10, 10, 10, 10), 0), | ||
("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4), | ||
("ones", torch.ones(10, 10), 1), | ||
("rand_neg_dim", torch.rand(10, 10, 10), -1), | ||
("rand", torch.rand(10, 10, 10, 10), 2), | ||
("rand_neg_dim", torch.rand(10, 10, 2, 3), -2), | ||
("randn", torch.randn(10, 10, 5, 10), 3), | ||
("randn_neg_dim", torch.randn(1, 10, 10, 10), -3), | ||
] | ||
|
||
|
||
class TestLogSoftmax(unittest.TestCase): | ||
"""Tests logsoftmax.""" | ||
|
||
class LogSoftmax(torch.nn.Module): | ||
def __init__(self, dim: int = -1): | ||
super().__init__() | ||
self.logsoftmax = torch.nn.LogSoftmax(dim=dim) | ||
|
||
def forward(self, x): | ||
return self.logsoftmax(x) | ||
|
||
def _test_logsoftmax_tosa_MI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(), | ||
) | ||
.export() | ||
.check(["torch.ops.aten.log_softmax.int"]) | ||
.check_not(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_not(["executorch_exir_dialects_edge__ops_aten__logsoftmax_default"]) | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_logsoftmax_tosa_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(), | ||
) | ||
.quantize() | ||
.export() | ||
.check_not(["torch.ops.aten.log_softmax.int"]) | ||
.check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"]) | ||
.to_edge() | ||
.partition() | ||
.check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"]) | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data, qtol=1) | ||
) | ||
|
||
def _test_logsoftmax_tosa_ethos_BI_pipeline( | ||
self, | ||
compile_spec: list[CompileSpec], | ||
module: torch.nn.Module, | ||
test_data: Tuple[torch.tensor], | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=compile_spec, | ||
) | ||
.quantize() | ||
.export() | ||
.check_not(["torch.ops.aten.log_softmax.int"]) | ||
.check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"]) | ||
.to_edge() | ||
.partition() | ||
.check_not(["executorch_exir_dialects_edge__ops_aten__logsoftmax_default"]) | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
) | ||
|
||
def _test_logsoftmax_tosa_u55_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
self._test_logsoftmax_tosa_ethos_BI_pipeline( | ||
common.get_u55_compile_spec(), module, test_data | ||
) | ||
|
||
def _test_logsoftmax_tosa_u85_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: Tuple[torch.tensor] | ||
): | ||
self._test_logsoftmax_tosa_ethos_BI_pipeline( | ||
common.get_u85_compile_spec(), module, test_data | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_logsoftmax_tosa_MI( | ||
self, | ||
test_name: str, | ||
test_data: torch.Tensor, | ||
dim: int, | ||
): | ||
self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,)) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_logsoftmax_tosa_BI( | ||
self, | ||
test_name: str, | ||
test_data: torch.Tensor, | ||
dim: int, | ||
): | ||
self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,)) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_logsoftmax_tosa_u55_BI( | ||
self, | ||
test_name: str, | ||
test_data: torch.Tensor, | ||
dim: int, | ||
): | ||
self._test_logsoftmax_tosa_u55_BI_pipeline( | ||
self.LogSoftmax(dim=dim), (test_data,) | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_logsoftmax_tosa_u85_BI( | ||
self, | ||
test_name: str, | ||
test_data: torch.Tensor, | ||
dim: int, | ||
): | ||
self._test_logsoftmax_tosa_u55_BI_pipeline( | ||
self.LogSoftmax(dim=dim), (test_data,) | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit can we use
getattr
and string concat?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a neat suggestion, I'll keep it in mind when implementing a similar decomposition in the future