Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ExportedProgram] Add mutable attribute to buffer #123

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions shark_turbine/aot/support/procedural/exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def store_produced_value(
raise ValueError(f"Cannot store value to unmapped global for: {info}")
logger.debug("Resolved global for store %r", mapping)
materialized_global: MaterializedGlobal = mapping.value # type: ignore
assert isinstance(materialized_global.global_op, util_d.GlobalOp)
materialized_global.global_op.is_mutable = True
converted_value = Operation.create(
"torch_c.to_builtin_tensor",
results=[materialized_global.ir_type],
Expand All @@ -251,7 +253,7 @@ def resolve_literal(
return None

# See if we know about it.
materialized_global = self._lift_tensor_to_global(literal)
materialized_global = self._lift_tensor_to_global(literal, info)
if not materialized_global:
# If it is unknown, just let the default importer take it on.
return None
Expand All @@ -269,7 +271,7 @@ def resolve_literal(
return converted_value

def _lift_tensor_to_global(
self, literal: torch.Tensor
self, literal: torch.Tensor, info: InputInfo | None
) -> Optional[MaterializedGlobal]:
module_builder = self.module_builder
mapping = module_builder.global_ref_tracker.track(literal)
Expand All @@ -282,7 +284,7 @@ def _lift_tensor_to_global(
# Policy check: Should we auto-import? Generally, we keep "small"
# tensors as inline as they can be optimized.
external_trait = ExternalTensorTrait.get(literal)
if not self._should_lift_tensor_to_global(literal, external_trait):
if not self._should_lift_tensor_to_global(literal, external_trait, info):
return None

# If it is a tensor we haven't seen yet, materialize it
Expand All @@ -304,8 +306,13 @@ def _lift_tensor_to_global(
return materialized_global

def _should_lift_tensor_to_global(
self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait]
self,
literal: torch.Tensor,
external_trait: Optional[ExternalTensorTrait],
info: InputInfo | None,
) -> bool:
if info is not None and info.store_producer_node:
return True
if external_trait is not None:
return True
volume = math.prod(literal.shape)
Expand Down
62 changes: 62 additions & 0 deletions tests/aot/globals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,68 @@ def testUnsupportedCombinations(self):
export_global(AbstractF32, external=True, uninitialized=True)


class SimpleCache(torch.nn.Module):
    """Minimal module holding a persistent buffer that forward() mutates in place."""

    def __init__(self, max_size, dtype=torch.float32):
        super().__init__()
        # Fixed-size backing storage; registered as a buffer so export
        # materializes it as a module-level global.
        self.register_buffer("cache", torch.zeros(max_size, dtype=dtype))

    def forward(self, input_pos, values):
        """Scatter ``values`` into the cache at ``input_pos``; return the cache.

        input_pos: [S] integer positions; values: [S] values to store.
        """
        assert input_pos.shape[0] == values.shape[0]
        # In-place scatter so the buffer mutation is visible to the exporter.
        return torch.ops.aten.index_put_(self.cache, [input_pos], values)


class ReadWriteReadCache(torch.nn.Module):
    """Buffer-backed module that reads, mutates, then re-reads its cache."""

    def __init__(self, max_size, dtype=torch.float32):
        super().__init__()
        # Fixed-size backing storage registered as a buffer.
        self.register_buffer("cache", torch.zeros(max_size, dtype=dtype))

    def forward(self, input_pos, values):
        """Return (cache, cache[2] before the write, cache[2] after the write).

        input_pos: [S] integer positions; values: [S] values to store.
        """
        assert input_pos.shape[0] == values.shape[0]
        before = self.cache[2].clone()
        # In-place scatter of the new values at the requested positions.
        updated = torch.ops.aten.index_put_(self.cache, [input_pos], values)
        after = updated[2].clone()
        return updated, before, after


class BufferTest(unittest.TestCase):
    """Checks that a registered buffer mutated in forward() is exported as a
    mutable ``util.global``."""

    def _assert_exports_mutable_cache_global(self, module, input_pos, values):
        """Export ``module`` and assert its size-10 f32 cache buffer becomes a
        mutable ``util.global`` in the emitted MLIR.

        Runs the module once eagerly first so the in-place buffer mutation
        happens before tracing, mirroring the original test flow.
        """
        module(input_pos, values)
        exported_fx_graph = torch.export.export(module, args=(input_pos, values))
        exported_program = export(exported_fx_graph)
        module_str = str(exported_program.mlir_module)
        # The global name encodes shape/dtype, so this assumes max_size == 10
        # and dtype == torch.float32 in the module under test.
        self.assertIn(
            "util.global private mutable @__auto.constant_10_torch.float32",
            module_str,
        )

    def testMutableBuffer(self):
        # A single in-place write to the buffer must mark the global mutable.
        self._assert_exports_mutable_cache_global(
            SimpleCache(10),
            torch.tensor([2, 5, 7]),
            torch.tensor([1.0, 2.0, 3.0]),
        )

    def testReadWriteReadMutableBuffer(self):
        # Reading before and after the write must not prevent the global
        # from being marked mutable.
        self._assert_exports_mutable_cache_global(
            ReadWriteReadCache(10),
            torch.tensor([2, 5, 7]),
            torch.tensor([1.0, 2.0, 3.0]),
        )


if __name__ == "__main__":
    # DEBUG-level logging surfaces the exporter's global-resolution trace,
    # which helps diagnose failures in these buffer/global tests.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
Loading