diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/shark_turbine/aot/support/procedural/exported_program.py index 331a7345..bbc431ae 100644 --- a/shark_turbine/aot/support/procedural/exported_program.py +++ b/shark_turbine/aot/support/procedural/exported_program.py @@ -234,6 +234,8 @@ def store_produced_value( raise ValueError(f"Cannot store value to unmapped global for: {info}") logger.debug("Resolved global for store %r", mapping) materialized_global: MaterializedGlobal = mapping.value # type: ignore + assert isinstance(materialized_global.global_op, util_d.GlobalOp) + materialized_global.global_op.is_mutable = True converted_value = Operation.create( "torch_c.to_builtin_tensor", results=[materialized_global.ir_type], @@ -251,7 +253,7 @@ def resolve_literal( return None # See if we know about it. - materialized_global = self._lift_tensor_to_global(literal) + materialized_global = self._lift_tensor_to_global(literal, info) if not materialized_global: # If it is unknown, just let the default importer take it on. return None @@ -269,7 +271,7 @@ def resolve_literal( return converted_value def _lift_tensor_to_global( - self, literal: torch.Tensor + self, literal: torch.Tensor, info: InputInfo | None ) -> Optional[MaterializedGlobal]: module_builder = self.module_builder mapping = module_builder.global_ref_tracker.track(literal) @@ -282,7 +284,7 @@ def _lift_tensor_to_global( # Policy check: Should we auto-import? Generally, we keep "small" # tensors as inline as they can be optimized. external_trait = ExternalTensorTrait.get(literal) - if not self._should_lift_tensor_to_global(literal, external_trait): + if not self._should_lift_tensor_to_global(literal, external_trait, info): return None # If it is a tensor we haven't seen yet, materialize it @@ -304,8 +306,13 @@ def _lift_tensor_to_global( return materialized_global def _should_lift_tensor_to_global( - self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait] + self, + literal: torch.Tensor, + external_trait: Optional[ExternalTensorTrait], + info: InputInfo | None, ) -> bool: + if info is not None and info.store_producer_node: + return True if external_trait is not None: return True volume = math.prod(literal.shape) diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 26bab1a6..607382fd 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -425,6 +425,68 @@ def testUnsupportedCombinations(self): export_global(AbstractF32, external=True, uninitialized=True) +class SimpleCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + + return cache + + +class ReadWriteReadCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + cache_value_0 = self.cache[2].clone() + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + cache_value_1 = cache[2].clone() + return cache, cache_value_0, cache_value_1 + + +class BufferTest(unittest.TestCase): + def testMutableBuffer(self): + max_size = 10 + simple_cache = SimpleCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + def testReadWriteReadMutableBuffer(self): + max_size = 10 + simple_cache = ReadWriteReadCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main()