Skip to content

Commit

Permalink
Ensure that globals are always emitted in declaration order at the to…
Browse files Browse the repository at this point in the history
…p. (#5)

Prior to this change, globals were interleaved with functions as part of
the same sequence. IREE prefers that module level entities exist in
def-ref order.
  • Loading branch information
stellaraccident authored Apr 25, 2024
1 parent cafc812 commit 6d05406
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
23 changes: 19 additions & 4 deletions shark_turbine/aot/support/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class ModuleBuilder:
"cache",
"context",
"fx_py_attr_tracker",
"global_ip",
"last_global_op",
"ip",
"module_op",
"symbol_table",
Expand All @@ -170,7 +170,10 @@ def __init__(self, module_op: Operation):
self.context = module_op.context
self.body = module_op.regions[0].blocks[0]
self.symbol_table = SymbolTable(module_op)
self.global_ip = InsertionPoint.at_block_begin(self.body)
# We organize globals in order of declaration at the top of the module.
# To do so, record the last one emitted so that newly created ones
# can be ordered properly.
self.last_global_op: Optional[Operation] = None
self.ip = InsertionPoint(self.body)
self.cache = ContextCache(self.context)
# Tracks global references to a MaterializedGlobal.
Expand Down Expand Up @@ -250,7 +253,10 @@ def create_tensor_global(
element_type = self.torch_dtype_to_iree_type(t.dtype)
external, external_scope, external_name = attrs.infer_external_from_tensor(t)

with self.global_ip, Location.unknown():
# Always create globals at the top. Then after created, if there was
# a prior one, move the new one to after it to maintain declaration
# order.
with InsertionPoint.at_block_begin(self.body), Location.unknown():
tensor_type = RankedTensorType.get(list(t.shape), element_type)
ir_attrs = {
"sym_name": StringAttr.get(symbol_name),
Expand Down Expand Up @@ -297,6 +303,9 @@ def create_tensor_global(

global_op = Operation.create("util.global", attributes=ir_attrs)
self.symbol_table.insert(global_op)
if self.last_global_op is not None:
global_op.move_after(self.last_global_op)
self.last_global_op = global_op
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, global_op, tensor_type

Expand All @@ -308,7 +317,10 @@ def create_typed_global(
attrs: GlobalAttributes,
logical_name: Optional[str] = None,
) -> Tuple[str, Operation]:
with self.global_ip, Location.unknown():
# Always create globals at the top. Then after created, if there was
# a prior one, move the new one to after it to maintain declaration
# order.
with InsertionPoint.at_block_begin(self.body), Location.unknown():
ir_attrs = {
"sym_name": StringAttr.get(symbol_name),
"sym_visibility": StringAttr.get("private"),
Expand All @@ -332,6 +344,9 @@ def create_typed_global(
)
global_op = Operation.create("util.global", attributes=ir_attrs)
self.symbol_table.insert(global_op)
if self.last_global_op is not None:
global_op.move_after(self.last_global_op)
self.last_global_op = global_op
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, global_op

Expand Down
43 changes: 42 additions & 1 deletion tests/aot/compiled_exported_program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ExportedPublicModule(CompiledModule):
self.assertIn("func.func @compute1", module_str)
self.assertIn("func.func @compute2", module_str)

def testParametersAsGlobals(self):
def testParametersAsExplicitGlobals(self):
fxb = FxProgramsBuilder(SimpleParams())

@fxb.export_program(
Expand Down Expand Up @@ -119,6 +119,47 @@ class ParamsAsGlobalsModule(CompiledModule):
2, module_str.count("util.global.load @_params.classifier.bias")
)

def testParametersAsGlobalsViaExternalizeModuleParameters(self):
mdl = SimpleParams()
externalize_module_parameters(mdl)

fxb = FxProgramsBuilder(mdl)

@fxb.export_program(
args=(torch.empty([128, 20]),),
)
def _compute1(module, x):
return module.forward(x)

class ParamsAsGlobalsModule(CompiledModule):
compute1 = _compute1
compute2 = _compute1

inst = ParamsAsGlobalsModule(context=Context(), import_to="import")
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn("util.global private @__auto.classifier.weight", module_str)
self.assertIn("util.global private @__auto.classifier.bias", module_str)

# It's clunky to verify ordering, but we explicitly guarantee that
# implicitly exported globals are emitted in order of declaration,
# preceeding all functions.
g1_index = module_str.index("util.global private @__auto.classifier.weight")
g2_index = module_str.index("util.global private @__auto.classifier.bias")
f_index = module_str.index("func")
self.assertGreater(g2_index, g1_index)
self.assertGreater(f_index, g2_index)

# Should only be two.
self.assertEqual(2, module_str.count("util.global private"))
# And two loads each loads.
self.assertEqual(
2, module_str.count("util.global.load @__auto.classifier.weight")
)
self.assertEqual(
2, module_str.count("util.global.load @__auto.classifier.bias")
)

def testBuffersAsGlobals(self):
fxb = FxProgramsBuilder(SimpleBuffers())

Expand Down

0 comments on commit 6d05406

Please sign in to comment.