From 673d825a0605ad80aae5bc2bbde709fa542e634d Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Tue, 10 Sep 2024 14:49:45 +0200
Subject: [PATCH] Add pass for decomposing (log)softmax

Signed-off-by: Erik Lundell
Change-Id: Ifda9b5a557ee29af5abc64569c51ec6bd6421c66
---
 backends/arm/_passes/arm_pass_manager.py     |   5 +
 .../arm/_passes/decompose_softmaxes_pass.py  |  74 ++++++++
 backends/arm/arm_partitioner.py              |   1 +
 backends/arm/operators/__init__.py           |   1 -
 backends/arm/operators/op_exp.py             |   1 -
 backends/arm/operators/op_softmax.py         |  99 -----------
 backends/arm/test/ops/test_logsoftmax.py     | 158 ++++++++++++++++++
 backends/arm/test/ops/test_softmax.py        |  31 ++--
 8 files changed, 250 insertions(+), 120 deletions(-)
 create mode 100644 backends/arm/_passes/decompose_softmaxes_pass.py
 delete mode 100644 backends/arm/operators/op_softmax.py
 create mode 100644 backends/arm/test/ops/test_logsoftmax.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index c4e806a842..62da3fc3b1 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -19,6 +19,9 @@
     ConvertSplitToSlicePass,
 )
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
+from executorch.backends.arm._passes.decompose_softmaxes_pass import (
+    DecomposeSoftmaxesPass,
+)
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,
 )
@@ -52,6 +55,7 @@ def transform_to_backend_pipeline(
         self.add_pass(DecomposeDivPass())
         self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(DecomposeSoftmaxesPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
@@ -63,4 +67,5 @@ def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
         self.add_pass(DecomposeDivPass())
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(DecomposeSoftmaxesPass())
         return self._transform(graph_module)
diff --git a/backends/arm/_passes/decompose_softmaxes_pass.py b/backends/arm/_passes/decompose_softmaxes_pass.py
new file mode 100644
index 0000000000..a5af062e7b
--- /dev/null
+++ b/backends/arm/_passes/decompose_softmaxes_pass.py
@@ -0,0 +1,74 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+# For BI case (aten dialect ops)
+torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
+
+# For MI case (edge dialect ops)
+edge_softmax = (
+    exir_ops.edge.aten._softmax.default,
+    exir_ops.edge.aten._log_softmax.default,
+)
+
+log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)
+
+
+def get_logsoftmax_ops(op) -> tuple:
+    """
+    Returns the (log_op, exp_op, sum_op, reciprocal_op, mul_op) tuple, where the ops
+    depend on whether the (log)softmax op is an exir_ops edge op or a torch.ops.aten op.
+    """
+    if op in edge_softmax:
+        return (
+            exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.exp.default,
+            exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.reciprocal.default,
+            exir_ops.edge.aten.mul.Tensor,
+        )
+    if op in torch_softmax:
+        return (
+            torch.ops.aten.log.default,
+            torch.ops.aten.exp.default,
+            torch.ops.aten.sum.dim_IntList,
+            torch.ops.aten.reciprocal.default,
+            torch.ops.aten.mul.Tensor,
+        )
+    raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
+
+
+class DecomposeSoftmaxesPass(ExportPass):
+    """
+    This pass decomposes softmax and log_softmax into more primitive ops.
+
+    Example:
+        %op1 = exp(x)
+        %op2 = sum(%op1, dim)
+        %op3 = reciprocal(%op2)
+        %op4 = mul(%op1, %op3)
+        (in log_softmax case: %op5 = log(%op4))
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in torch_softmax + edge_softmax:
+            return super().call_operator(op, args, kwargs, meta)
+
+        log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op)
+
+        _input = args[0]
+        dim = [args[1]]
+
+        op1 = super().call_operator(exp_op, (_input,), {}, meta)
+        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
+        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
+        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
+        if op in log_softmax:
+            op4 = super().call_operator(log_op, (op4,), {}, meta)
+        return op4
diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 7db893694b..c3256ddc47 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -61,6 +61,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten._log_softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 855487cf7f..93911e9fff 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -31,7 +31,6 @@
     op_rsqrt,
     op_sigmoid,
     op_slice,
-    op_softmax,
     op_squeeze,
     op_sub,
     op_sum,
diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py
index f98bb3f88c..0e0a75dcc4 100644
--- a/backends/arm/operators/op_exp.py
+++ b/backends/arm/operators/op_exp.py
@@ -42,7 +42,6 @@ def define_node(
     ) -> None:

         assert len(node.all_input_nodes) == 1
-        assert len(node.users) == 1

         if is_quant_node:
             # Assume quantized input is 8 bit.
diff --git a/backends/arm/operators/op_softmax.py b/backends/arm/operators/op_softmax.py
deleted file mode 100644
index 1ac4241318..0000000000
--- a/backends/arm/operators/op_softmax.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-from typing import List
-
-import serializer.tosa_serializer as ts
-import torch
-from executorch.backends.arm.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_utils import tosa_shape
-from serializer.tosa_serializer import TosaOp
-
-
-@register_node_visitor
-class SoftmaxVisitor(NodeVisitor):
-    target = "aten._softmax.default"
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        tosa_graph: ts.TosaSerializer,
-        inputs: List[TosaArg],
-        output: TosaArg,
-        is_quant_node: bool,
-    ) -> None:
-        input_name = inputs[0].name
-        dim_order = inputs[0].dim_order
-        input_shape = tosa_shape(inputs[0].shape, dim_order)
-        dim_value = dim_order.index(inputs[1].number % len(dim_order))
-
-        ## softmax = exp(logits - max(logits)) / reduce_sum(exp(logits - max(logits)), -1)
-        # FP32
-        # reduce_max_res = reducemax(logits)
-        # sub_res = sub(inputs, reduce_max_res)
-        # exp_res = exp(sub_res)
-        # reduce_sum_res = reduce_sum(exp_res, -1)
-        # inverted_reduce_sum = reciprocal(reduce_sum_res)
-        # output = mul(exp_res, inverted_reduce_sum)
-
-        # Max_Reduction
-        attr_axis = ts.TosaSerializerAttribute()
-        attr_axis.AxisAttribute(axis=dim_value)
-        reduced_shape = list(input_shape)
-        reduced_shape[dim_value] = 1
-        reduce_max_res = tosa_graph.addIntermediate(reduced_shape, output.dtype)
-        tosa_graph.addOperator(
-            TosaOp.Op().REDUCE_MAX,
-            [input_name],
-            [reduce_max_res.name],
-            attr_axis,
-        )
-
-        # Subtract max from logits
-        sub_res = tosa_graph.addIntermediate(input_shape, output.dtype)
-        tosa_graph.addOperator(
-            TosaOp.Op().SUB,
-            [input_name, reduce_max_res.name],
-            [sub_res.name],
-        )
-
-        # Raise the subtraction results to exponent
-        exp_res = tosa_graph.addIntermediate(input_shape, output.dtype)
-        tosa_graph.addOperator(TosaOp.Op().EXP, [sub_res.name], [exp_res.name])
-
-        # Reduce_sum of the calculated exponent value
-        reduce_sum_res = tosa_graph.addIntermediate(reduced_shape, output.dtype)
-        tosa_graph.addOperator(
-            TosaOp.Op().REDUCE_SUM,
-            [exp_res.name],
-            [reduce_sum_res.name],
-            attr_axis,
-        )
-
-        # Invert the reduce_sum
-        inverted_reduce_sum = tosa_graph.addIntermediate(reduced_shape, output.dtype)
-        tosa_graph.addOperator(
-            TosaOp.Op().RECIPROCAL,
-            [reduce_sum_res.name],
-            [inverted_reduce_sum.name],
-        )
-
-        # Multiply two parts to get the final results
-        attr_mul = ts.TosaSerializerAttribute()
-        attr_mul.MulAttribute(0)
-        tosa_graph.addOperator(
-            TosaOp.Op().MUL,
-            [exp_res.name, inverted_reduce_sum.name],
-            [output.name],
-            attr_mul,
-        )
diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py
new file mode 100644
index 0000000000..2d51588bb3
--- /dev/null
+++ b/backends/arm/test/ops/test_logsoftmax.py
@@ -0,0 +1,158 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+    # (test_name, test_data, dim)
+    ("zeros", torch.zeros(10, 10, 10, 10), 0),
+    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
+    ("ones", torch.ones(10, 10), 1),
+    ("rand_3d_neg_dim", torch.rand(10, 10, 10), -1),
+    ("rand", torch.rand(10, 10, 10, 10), 2),
+    ("rand_neg_dim", torch.rand(10, 10, 2, 3), -2),
+    ("randn", torch.randn(10, 10, 5, 10), 3),
+    ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3),
+]
+
+
+class TestLogSoftmax(unittest.TestCase):
+    """Tests logsoftmax."""
+
+    class LogSoftmax(torch.nn.Module):
+        def __init__(self, dim: int = -1):
+            super().__init__()
+            self.logsoftmax = torch.nn.LogSoftmax(dim=dim)
+
+        def forward(self, x):
+            return self.logsoftmax(x)
+
+    def _test_logsoftmax_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .export()
+            .check(["torch.ops.aten.log_softmax.int"])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_logsoftmax_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(),
+            )
+            .quantize()
+            .export()
+            .check_not(["torch.ops.aten.log_softmax.int"])
+            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
+
+    def _test_logsoftmax_tosa_ethos_BI_pipeline(
+        self,
+        compile_spec: list[CompileSpec],
+        module: torch.nn.Module,
+        test_data: Tuple[torch.tensor],
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize()
+            .export()
+            .check_not(["torch.ops.aten.log_softmax.int"])
+            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    def _test_logsoftmax_tosa_u55_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        self._test_logsoftmax_tosa_ethos_BI_pipeline(
+            common.get_u55_compile_spec(), module, test_data
+        )
+
+    def _test_logsoftmax_tosa_u85_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        self._test_logsoftmax_tosa_ethos_BI_pipeline(
+            common.get_u85_compile_spec(), module, test_data
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_logsoftmax_tosa_MI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        dim: int,
+    ):
+        self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_logsoftmax_tosa_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        dim: int,
+    ):
+        self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_logsoftmax_tosa_u55_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        dim: int,
+    ):
+        self._test_logsoftmax_tosa_u55_BI_pipeline(
+            self.LogSoftmax(dim=dim), (test_data,)
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_logsoftmax_tosa_u85_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        dim: int,
+    ):
+        self._test_logsoftmax_tosa_u85_BI_pipeline(
+            self.LogSoftmax(dim=dim), (test_data,)
+        )
diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py
index a7d25d266d..954dd201a9 100644
--- a/backends/arm/test/ops/test_softmax.py
+++ b/backends/arm/test/ops/test_softmax.py
@@ -18,14 +18,14 @@

 test_data_suite = [
     # (test_name, test_data, dim)
-    ("zeros", torch.zeros(10, 10, 10, 10), 0),
-    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
-    ("ones", torch.ones(10, 10, 10, 10), 1),
-    ("ones_neg_dim", torch.ones(10, 10, 10, 10), -1),
-    ("rand", torch.rand(10, 10, 10, 10), 2),
-    ("rand_neg_dim", torch.rand(10, 10, 10, 10), -2),
+    ("zeros", torch.zeros(10, 8, 5, 2), 0),
+    ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+    ("ones", torch.ones(10, 10), 1),
+    ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+    ("rand", torch.rand(1, 2, 5, 8), 2),
+    ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
     ("randn", torch.randn(10, 10, 10, 10), 3),
-    ("randn_neg_dim", torch.randn(10, 10, 10, 10), -3),
+    ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
 ]


@@ -71,14 +71,14 @@ def _test_softmax_tosa_BI_pipeline(
             )
             .quantize()
             .export()
-            .check_count({"torch.ops.aten.softmax.int": 1})
-            .check(["torch.ops.quantized_decomposed"])
+            .check_not(["torch.ops.aten.softmax.int"])
+            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
             .to_edge()
             .partition()
             .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+            .run_method_and_compare_outputs(inputs=test_data)
         )

     def _test_softmax_tosa_ethos_BI_pipeline(
@@ -95,8 +95,8 @@ def _test_softmax_tosa_ethos_BI_pipeline(
             )
             .quantize()
             .export()
-            .check_count({"torch.ops.aten.softmax.int": 1})
-            .check(["torch.ops.quantized_decomposed"])
+            .check_not(["torch.ops.aten.softmax.int"])
+            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
             .to_edge()
             .partition()
             .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
@@ -127,10 +127,7 @@ def test_softmax_tosa_MI(
     ):
         self._test_softmax_tosa_MI_pipeline(self.Softmax(dim=dim), (test_data,))

-    # Expected to fail since ArmQuantizer cannot quantize a SoftMax operator
-    # TODO(MLETORCH-92)
     @parameterized.expand(test_data_suite)
-    @unittest.expectedFailure
     def test_softmax_tosa_BI(
         self,
         test_name: str,
@@ -139,10 +136,7 @@ def test_softmax_tosa_BI(
     ):
         self._test_softmax_tosa_BI_pipeline(self.Softmax(dim=dim), (test_data,))

-    # Expected to fail since ArmQuantizer cannot quantize a SoftMax layer
-    # TODO(MLETORCH-92)
     @parameterized.expand(test_data_suite)
-    @unittest.expectedFailure
     def test_softmax_tosa_u55_BI(
         self,
         test_name: str,
@@ -152,7 +146,6 @@ def test_softmax_tosa_u55_BI(
     ):
         self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,))

     @parameterized.expand(test_data_suite)
-    @unittest.expectedFailure
     def test_softmax_tosa_u85_BI(
         self,
         test_name: str,
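
Note on the decomposition: unlike the deleted TOSA SoftmaxVisitor, the pass does not subtract the per-dim max before exp(), so float inputs with very large logits can overflow in the decomposed form. The emitted graph can be checked against PyTorch's reference ops with a minimal standalone sketch (hypothetical script, not part of the patch; it mirrors the op sequence in the pass docstring):

    # check_softmax_decomposition.py -- hypothetical, not part of the patch.
    import torch


    def decomposed_softmax(x: torch.Tensor, dim: int, log: bool = False) -> torch.Tensor:
        # Mirrors DecomposeSoftmaxesPass: exp -> sum(dim, keepdim=True) -> reciprocal -> mul (-> log).
        op1 = torch.exp(x)
        op2 = torch.sum(op1, dim=[dim], keepdim=True)
        op3 = torch.reciprocal(op2)
        op4 = op1 * op3
        return torch.log(op4) if log else op4


    x = torch.randn(10, 10, 5, 10)
    assert torch.allclose(decomposed_softmax(x, 3), torch.softmax(x, dim=3), atol=1e-6)
    assert torch.allclose(
        decomposed_softmax(x, -3, log=True), torch.log_softmax(x, dim=-3), atol=1e-5
    )

The keepdim=True sum matches the (op1, dim, True) arguments the pass hands to sum.dim_IntList, which keeps the reduced dim so the reciprocal broadcasts back over it in the mul.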