Add pass for decomposing (log)softmax #6287

Open · wants to merge 3 commits into base: main

Changes from 2 commits
5 changes: 5 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -19,6 +19,9 @@
    ConvertSplitToSlicePass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
    DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
    InsertSqueezeAfterSumPass,
)
@@ -52,6 +55,7 @@ def transform_to_backend_pipeline(
        self.add_pass(DecomposeDivPass())
        self.add_pass(InsertSqueezeAfterSumPass())
        self.add_pass(ConvertSplitToSlicePass())
        self.add_pass(DecomposeSoftmaxesPass())
        for spec in compile_spec:
            if spec.key == "permute_memory_format":
                memory_format = spec.value.decode()
@@ -63,4 +67,5 @@ def transform_for_annotation_pipeline(
    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
        self.add_pass(DecomposeDivPass())
        self.add_pass(ScalarsToAttributePass())
        self.add_pass(DecomposeSoftmaxesPass())
        return self._transform(graph_module)
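
For readers skimming the diff: add_pass queues a pass and _transform runs the queue over the graph module, so registration order is execution order. A minimal sketch of that pattern, assuming the callable-pass/PassResult convention of ExportPass (illustrative only, not the actual ArmPassManager):

import torch

class SketchPassManager:
    """Illustrative only: runs queued passes in registration order."""

    def __init__(self):
        self._passes = []

    def add_pass(self, p) -> None:
        # DecomposeSoftmaxesPass is queued last above, so it runs after
        # the div/squeeze/slice passes.
        self._passes.append(p)

    def _transform(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for p in self._passes:
            # ExportPass instances are callable and return a PassResult
            # carrying the rewritten graph module.
            graph_module = p(graph_module).graph_module
        return graph_module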
74 changes: 74 additions & 0 deletions backends/arm/_passes/decompose_softmaxes_pass.py
@@ -0,0 +1,74 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# For BI case
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)

# For MI case
edge_softmax = (
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten._log_softmax.default,
)

log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)


def get_logsoftmax_ops(op) -> tuple:
    """
    Returns (log_op, exp_op, sum_op, reciprocal_op, mul_op), where the ops
    depend on whether the (log)softmax op is an exir_ops edge op or a
    torch.ops.aten op.
    """
    if op in edge_softmax:
        return (
            exir_ops.edge.aten.log.default,
(Inline review on the line above)
Contributor: nit: can we use getattr and string concat?
Author: That's a neat suggestion, I'll keep it in mind when implementing a similar decomposition in the future.
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    if op in torch_softmax:
        return (
            torch.ops.aten.log.default,
            torch.ops.aten.exp.default,
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.reciprocal.default,
            torch.ops.aten.mul.Tensor,
        )
    raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")

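The inline review above suggests building these tuples with getattr and string concatenation. A hypothetical variant (illustrative only, not part of the PR; get_decomposition_ops_by_name is an invented name), noting that the overloads still differ per op (default, dim_IntList, Tensor):

def get_decomposition_ops_by_name(op) -> tuple:
    """Sketch of the reviewer's getattr idea; assumes the module-level tuples above."""
    if op in edge_softmax:
        ns = exir_ops.edge.aten
    elif op in torch_softmax:
        ns = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
    # (op name, overload name) pairs; both namespaces resolve ops by attribute.
    specs = (
        ("log", "default"),
        ("exp", "default"),
        ("sum", "dim_IntList"),
        ("reciprocal", "default"),
        ("mul", "Tensor"),
    )
    return tuple(getattr(getattr(ns, name), overload) for name, overload in specs)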

class DecomposeSoftmaxesPass(ExportPass):
    """
    This pass decomposes log softmax or softmax into more primitive ops.

    Example:
        %op1 = exp(x)
        %op2 = sum(%op1, dim)
        %op3 = reciprocal(%op2)
        %op4 = mul(%op1, %op3)
        (in logsoftmax case: %op5 = log(%op4))
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in torch_softmax + edge_softmax:
            return super().call_operator(op, args, kwargs, meta)

        log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op)

        _input = args[0]
        dim = [args[1]]

        op1 = super().call_operator(exp_op, (_input,), {}, meta)
        # keepdim=True so the summed result broadcasts against op1 in the mul.
        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
        if op in log_softmax:
            op4 = super().call_operator(log_op, (op4,), {}, meta)
        return op4
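
The decomposition is easy to sanity-check in plain PyTorch, outside the pass infrastructure. A standalone sketch mirroring the op sequence above:

import torch

x, dim = torch.randn(4, 8), 1

# softmax(x) = exp(x) * reciprocal(sum(exp(x), dim))
op1 = torch.exp(x)
op2 = torch.sum(op1, [dim], keepdim=True)  # keepdim=True, as in the pass
op3 = torch.reciprocal(op2)
op4 = torch.mul(op1, op3)

assert torch.allclose(op4, torch.softmax(x, dim=dim))
# The logsoftmax case appends a final log.
assert torch.allclose(torch.log(op4), torch.log_softmax(x, dim=dim), atol=1e-6)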
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -61,6 +61,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
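Note: adding exir_ops.edge.aten._log_softmax.default to the supported set is what lets the partitioner delegate log_softmax nodes to the Arm backend at all; the new DecomposeSoftmaxesPass then lowers them inside the backend pipeline.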
1 change: 0 additions & 1 deletion backends/arm/operators/__init__.py
@@ -31,7 +31,6 @@
    op_rsqrt,
    op_sigmoid,
    op_slice,
    op_softmax,
    op_squeeze,
    op_sub,
    op_sum,
1 change: 0 additions & 1 deletion backends/arm/operators/op_exp.py
@@ -42,7 +42,6 @@ def define_node(
    ) -> None:

        assert len(node.all_input_nodes) == 1
        assert len(node.users) == 1

        if is_quant_node:
            # Assume quantized input is 8 bit.
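Note: the len(node.users) == 1 assertion is dropped because the decomposition gives the exp node a second user: its result (%op1 in the pass docstring) feeds both the sum and the mul nodes.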
99 changes: 0 additions & 99 deletions backends/arm/operators/op_softmax.py

This file was deleted.

158 changes: 158 additions & 0 deletions backends/arm/test/ops/test_logsoftmax.py
@@ -0,0 +1,158 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data, dim)
    ("zeros", torch.zeros(10, 10, 10, 10), 0),
    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
    ("ones", torch.ones(10, 10), 1),
    ("rand_neg_dim", torch.rand(10, 10, 10), -1),
    ("rand", torch.rand(10, 10, 10, 10), 2),
    ("rand_neg_dim_2", torch.rand(10, 10, 2, 3), -2),
    ("randn", torch.randn(10, 10, 5, 10), 3),
    ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3),
]


class TestLogSoftmax(unittest.TestCase):
    """Tests logsoftmax."""

    class LogSoftmax(torch.nn.Module):
        def __init__(self, dim: int = -1):
            super().__init__()
            self.logsoftmax = torch.nn.LogSoftmax(dim=dim)

        def forward(self, x):
            return self.logsoftmax(x)

    def _test_logsoftmax_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check(["torch.ops.aten.log_softmax.int"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_logsoftmax_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_logsoftmax_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_logsoftmax_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_logsoftmax_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u55_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u85_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )