From 2b6a934ba06fd47438184c2ac9089fc882c76746 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 25 Apr 2024 11:00:10 -0700 Subject: [PATCH] Ensure that globals are always emitted in declaration order at the top. Prior to this change, globals were interleaved with functions as part of the same sequence. IREE prefers that module level entities exist in def-ref order. --- shark_turbine/aot/support/ir_utils.py | 23 +++++++++-- tests/aot/compiled_exported_program_test.py | 43 ++++++++++++++++++++- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/shark_turbine/aot/support/ir_utils.py b/shark_turbine/aot/support/ir_utils.py index 813b7f4b..a662c15c 100644 --- a/shark_turbine/aot/support/ir_utils.py +++ b/shark_turbine/aot/support/ir_utils.py @@ -156,7 +156,7 @@ class ModuleBuilder: "cache", "context", "fx_py_attr_tracker", - "global_ip", + "last_global_op", "ip", "module_op", "symbol_table", @@ -170,7 +170,10 @@ def __init__(self, module_op: Operation): self.context = module_op.context self.body = module_op.regions[0].blocks[0] self.symbol_table = SymbolTable(module_op) - self.global_ip = InsertionPoint.at_block_begin(self.body) + # We organize globals in order of declaration at the top of the module. + # To do so, record the last one emitted so that newly created ones + # can be ordered properly. + self.last_global_op: Optional[Operation] = None self.ip = InsertionPoint(self.body) self.cache = ContextCache(self.context) # Tracks global references to a MaterializedGlobal. @@ -250,7 +253,10 @@ def create_tensor_global( element_type = self.torch_dtype_to_iree_type(t.dtype) external, external_scope, external_name = attrs.infer_external_from_tensor(t) - with self.global_ip, Location.unknown(): + # Always create globals at the top. Then after created, if there was + # a prior one, move the new one to after it to maintain declaration + # order. + with InsertionPoint.at_block_begin(self.body), Location.unknown(): tensor_type = RankedTensorType.get(list(t.shape), element_type) ir_attrs = { "sym_name": StringAttr.get(symbol_name), @@ -297,6 +303,9 @@ def create_tensor_global( global_op = Operation.create("util.global", attributes=ir_attrs) self.symbol_table.insert(global_op) + if self.last_global_op is not None: + global_op.move_after(self.last_global_op) + self.last_global_op = global_op actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value return actual_symbol_name, global_op, tensor_type @@ -308,7 +317,10 @@ def create_typed_global( attrs: GlobalAttributes, logical_name: Optional[str] = None, ) -> Tuple[str, Operation]: - with self.global_ip, Location.unknown(): + # Always create globals at the top. Then after created, if there was + # a prior one, move the new one to after it to maintain declaration + # order. + with InsertionPoint.at_block_begin(self.body), Location.unknown(): ir_attrs = { "sym_name": StringAttr.get(symbol_name), "sym_visibility": StringAttr.get("private"), @@ -332,6 +344,9 @@ def create_typed_global( ) global_op = Operation.create("util.global", attributes=ir_attrs) self.symbol_table.insert(global_op) + if self.last_global_op is not None: + global_op.move_after(self.last_global_op) + self.last_global_op = global_op actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value return actual_symbol_name, global_op diff --git a/tests/aot/compiled_exported_program_test.py b/tests/aot/compiled_exported_program_test.py index 4bf3fa15..baaeb9bb 100644 --- a/tests/aot/compiled_exported_program_test.py +++ b/tests/aot/compiled_exported_program_test.py @@ -90,7 +90,7 @@ class ExportedPublicModule(CompiledModule): self.assertIn("func.func @compute1", module_str) self.assertIn("func.func @compute2", module_str) - def testParametersAsGlobals(self): + def testParametersAsExplicitGlobals(self): fxb = FxProgramsBuilder(SimpleParams()) @fxb.export_program( @@ -119,6 +119,47 @@ class ParamsAsGlobalsModule(CompiledModule): 2, module_str.count("util.global.load @_params.classifier.bias") ) + def testParametersAsGlobalsViaExternalizeModuleParameters(self): + mdl = SimpleParams() + externalize_module_parameters(mdl) + + fxb = FxProgramsBuilder(mdl) + + @fxb.export_program( + args=(torch.empty([128, 20]),), + ) + def _compute1(module, x): + return module.forward(x) + + class ParamsAsGlobalsModule(CompiledModule): + compute1 = _compute1 + compute2 = _compute1 + + inst = ParamsAsGlobalsModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("util.global private @__auto.classifier.weight", module_str) + self.assertIn("util.global private @__auto.classifier.bias", module_str) + + # It's clunky to verify ordering, but we explicitly guarantee that + # implicitly exported globals are emitted in order of declaration, + # preceeding all functions. + g1_index = module_str.index("util.global private @__auto.classifier.weight") + g2_index = module_str.index("util.global private @__auto.classifier.bias") + f_index = module_str.index("func") + self.assertGreater(g2_index, g1_index) + self.assertGreater(f_index, g2_index) + + # Should only be two. + self.assertEqual(2, module_str.count("util.global private")) + # And two loads each loads. + self.assertEqual( + 2, module_str.count("util.global.load @__auto.classifier.weight") + ) + self.assertEqual( + 2, module_str.count("util.global.load @__auto.classifier.bias") + ) + def testBuffersAsGlobals(self): fxb = FxProgramsBuilder(SimpleBuffers())