From 8e0412b01ba222812d317fcb88d50513b1d1e5cd Mon Sep 17 00:00:00 2001 From: Rob Elliott Date: Thu, 17 Oct 2024 15:05:43 +0100 Subject: [PATCH] Align aot_arm_compiler to latest export flow - Update to use to_edge_transform_and_lower rather than export_to_edge - Fix channels-last handling on some models Signed-off-by: Rob Elliott Change-Id: I0f8e9206aa1ff3004a746955010c7bf01c896347 --- examples/arm/aot_arm_compiler.py | 70 ++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 31 deletions(-) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 29cc0c30c7..3075d992d5 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -25,8 +25,12 @@ from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig -from executorch.extension.export_util.utils import export_to_edge, save_pte_program +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate # Quantize model if required using the standard export quantizaion flow. 
@@ -170,7 +174,9 @@ def forward(self, x): ] -def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder: +def get_compile_spec( + target: str, intermediates: Optional[str] = None +) -> ArmCompileSpecBuilder: spec_builder = None if target == "TOSA": spec_builder = ( @@ -185,7 +191,7 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder: memory_mode="Shared_Sram", extra_flags="--debug-force-regor --output-format=raw", ) - .set_permute_memory_format(args.model_name in MODEL_NAME_TO_MODEL.keys()) + .set_permute_memory_format(True) .set_quantize_io(True) ) elif "ethos-u85" in target: @@ -202,7 +208,7 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder: ) if intermediates is not None: - spec_builder.dump_intermediate_artifacts_to(args.intermediates) + spec_builder.dump_intermediate_artifacts_to(intermediates) return spec_builder.build() @@ -356,40 +362,42 @@ def get_args(): model, example_inputs = get_model_and_inputs_from_name(args.model_name) model = model.eval() + # export_for_training under the assumption we quantize, the exported form also works + # in to_edge if we don't quantize + exported_program = torch.export.export_for_training(model, example_inputs) + model = exported_program.module() model_fp32 = model - # pre-autograd export. eventually this will become torch.export - model = torch.export.export_for_training(model, example_inputs).module() - # Quantize if required model_int8 = None if args.quantize: model = quantize(model, example_inputs) model_int8 = model + # Wrap quantized model back into an exported_program + exported_program = torch.export.export_for_training(model, example_inputs) + + if args.delegate: + # As we can target multiple output encodings from ArmBackend, one must + # be specified. 
+ compile_spec = get_compile_spec(args.target, args.intermediates) + edge = to_edge_transform_and_lower( + exported_program, + partitioner=[ArmPartitioner(compile_spec)], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + else: + edge = to_edge_transform_and_lower( + exported_program, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) - edge = export_to_edge( - model, - example_inputs, - edge_compile_config=EdgeCompileConfig( - _check_ir_validity=False, - ), - ) - - # As we can target multiple output encodings from ArmBackend, one must - # be specified. - compile_spec = ( - get_compile_spec(args.target, args.intermediates) - if args.delegate is True - else None - ) - - logging.debug(f"Exported graph:\n{edge.exported_program().graph}") - if args.delegate is True: - edge = edge.to_backend(ArmPartitioner(compile_spec)) - - dump_delegation_info(edge, args.intermediates) - - logging.debug(f"Lowered graph:\n{edge.exported_program().graph}") + dump_delegation_info(edge, args.intermediates) try: exec_prog = edge.to_executorch(