
fix: buffers need the mutable attribute
Signed-off-by: Christopher McGirr <[email protected]>
maxbartel authored and chrsmcgrr committed Sep 19, 2024
1 parent 512b366 commit 2a7058f
Showing 2 changed files with 73 additions and 4 deletions.
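In short: when an exported FX graph stores a produced value back into a lifted global (here, a torch.nn.Module buffer updated in place), the emitted util.global must carry the mutable attribute; before this fix it was emitted read-only. A minimal sketch of the failing pattern, assuming export is the shark_turbine.aot entry point exercised by the new tests below:

import torch
from shark_turbine.aot import export  # assumed import path, matching the tests below

class SimpleCache(torch.nn.Module):
    def __init__(self, max_size: int = 10):
        super().__init__()
        # Registered buffers are lifted to util.global ops on export.
        self.register_buffer("cache", torch.zeros(max_size))

    def forward(self, input_pos, values):
        # In-place write back into the buffer: the lifted global must be mutable.
        return torch.ops.aten.index_put_(self.cache, [input_pos], values)

fx_graph = torch.export.export(
    SimpleCache(), args=(torch.tensor([2, 5, 7]), torch.tensor([1.0, 2.0, 3.0]))
)
mlir = str(export(fx_graph).mlir_module)
assert "util.global private mutable" in mlir  # read-only before this fix

The two hunks below implement exactly this: the store path now flips the global op to mutable, and the lifting policy lifts any tensor the graph stores back to.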
15 changes: 11 additions & 4 deletions shark_turbine/aot/support/procedural/exported_program.py
@@ -234,6 +234,8 @@ def store_produced_value(
             raise ValueError(f"Cannot store value to unmapped global for: {info}")
         logger.debug("Resolved global for store %r", mapping)
         materialized_global: MaterializedGlobal = mapping.value  # type: ignore
+        assert isinstance(materialized_global.global_op, util_d.GlobalOp)
+        materialized_global.global_op.is_mutable = True
         converted_value = Operation.create(
             "torch_c.to_builtin_tensor",
             results=[materialized_global.ir_type],
@@ -251,7 +253,7 @@ def resolve_literal(
             return None

         # See if we know about it.
-        materialized_global = self._lift_tensor_to_global(literal)
+        materialized_global = self._lift_tensor_to_global(literal, info)
         if not materialized_global:
             # If it is unknown, just let the default importer take it on.
             return None
@@ -269,7 +271,7 @@ def resolve_literal(
         return converted_value

     def _lift_tensor_to_global(
-        self, literal: torch.Tensor
+        self, literal: torch.Tensor, info: InputInfo | None
     ) -> Optional[MaterializedGlobal]:
         module_builder = self.module_builder
         mapping = module_builder.global_ref_tracker.track(literal)
@@ -282,7 +284,7 @@ def _lift_tensor_to_global(
         # Policy check: Should we auto-import? Generally, we keep "small"
         # tensors inline, as they can be optimized.
         external_trait = ExternalTensorTrait.get(literal)
-        if not self._should_lift_tensor_to_global(literal, external_trait):
+        if not self._should_lift_tensor_to_global(literal, external_trait, info):
             return None

         # If it is a tensor we haven't seen yet, materialize it
@@ -304,8 +306,13 @@ def _lift_tensor_to_global(
         return materialized_global

     def _should_lift_tensor_to_global(
-        self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait]
+        self,
+        literal: torch.Tensor,
+        external_trait: Optional[ExternalTensorTrait],
+        info: InputInfo | None,
     ) -> bool:
+        if info is not None and info.store_producer_node:
+            return True
         if external_trait is not None:
             return True
         volume = math.prod(literal.shape)
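Net effect of the changes above, paraphrased as a standalone sketch (not the library's exact code): the InputInfo for a stored-to input is now threaded into the lifting policy, so a tensor is lifted to a global when the graph writes back to it, when it carries an ExternalTensorTrait, or when it is large enough. The INLINE_VOLUME_LIMIT cutoff below is hypothetical, standing in for the elided tail of the real function:

import math

INLINE_VOLUME_LIMIT = 64  # hypothetical cutoff; the real threshold is in the elided code above

def should_lift_tensor_to_global(literal, external_trait, info) -> bool:
    # Buffers the graph stores back to must become (mutable) globals, regardless of size.
    if info is not None and info.store_producer_node:
        return True
    # Externally backed tensors are always lifted.
    if external_trait is not None:
        return True
    # Otherwise, keep small tensors inline so they can be optimized.
    return math.prod(literal.shape) > INLINE_VOLUME_LIMIT

Keying the decision on store_producer_node matters because even a small buffer, like the 10-element cache in the tests below, must be lifted: it has to exist as a mutable global for the store to target.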
62 changes: 62 additions & 0 deletions tests/aot/globals_test.py
@@ -425,6 +425,68 @@ def testUnsupportedCombinations(self):
             export_global(AbstractF32, external=True, uninitialized=True)


+class SimpleCache(torch.nn.Module):
+    def __init__(self, max_size, dtype=torch.float32):
+        super().__init__()
+        self.register_buffer("cache", torch.zeros(max_size, dtype=dtype))
+
+    def forward(self, input_pos, values):
+        # input_pos: [S], values: [S]
+        assert input_pos.shape[0] == values.shape[0]
+
+        # Write the values into the buffer at the specified positions.
+        cache = torch.ops.aten.index_put_(self.cache, [input_pos], values)
+
+        return cache
+
+
+class ReadWriteReadCache(torch.nn.Module):
+    def __init__(self, max_size, dtype=torch.float32):
+        super().__init__()
+        self.register_buffer("cache", torch.zeros(max_size, dtype=dtype))
+
+    def forward(self, input_pos, values):
+        # input_pos: [S], values: [S]
+        assert input_pos.shape[0] == values.shape[0]
+        cache_value_0 = self.cache[2].clone()
+        # Write the values into the buffer at the specified positions.
+        cache = torch.ops.aten.index_put_(self.cache, [input_pos], values)
+        cache_value_1 = cache[2].clone()
+        return cache, cache_value_0, cache_value_1
+
+
+class BufferTest(unittest.TestCase):
+    def testMutableBuffer(self):
+        max_size = 10
+        simple_cache = SimpleCache(max_size)
+
+        input_pos = torch.tensor([2, 5, 7])
+        values = torch.tensor([1.0, 2.0, 3.0])
+        simple_cache(input_pos, values)
+        exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values))
+        exported_program = export(exported_fx_graph)
+        module_str = str(exported_program.mlir_module)
+        self.assertIn(
+            "util.global private mutable @__auto.constant_10_torch.float32",
+            module_str,
+        )
+
+    def testReadWriteReadMutableBuffer(self):
+        max_size = 10
+        read_write_cache = ReadWriteReadCache(max_size)
+
+        input_pos = torch.tensor([2, 5, 7])
+        values = torch.tensor([1.0, 2.0, 3.0])
+        read_write_cache(input_pos, values)
+        exported_fx_graph = torch.export.export(read_write_cache, args=(input_pos, values))
+        exported_program = export(exported_fx_graph)
+        module_str = str(exported_program.mlir_module)
+        self.assertIn(
+            "util.global private mutable @__auto.constant_10_torch.float32",
+            module_str,
+        )
+
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     unittest.main()