
Cannot import timm models #26

Open
maxbartel opened this issue Jun 10, 2024 · 1 comment

Comments

maxbartel (Contributor) commented Jun 10, 2024

timm (https://huggingface.co/docs/timm/index) is a Hugging Face library for computer vision models. Some older models, such as ResNet, are available through it.

At the moment iree-turbine is not able to import those models.

Reproducer:

import timm
import torch
from shark_turbine.aot import export

# Load a pretrained ResNet-50 from the Hugging Face hub via timm.
model = timm.create_model("hf_hub:timm/resnet50.a1_in1k", pretrained=True)

# Exporting to an FX graph succeeds.
exported = torch.export.export(model, (torch.randn((1, 3, 224, 224)),))

# Importing via iree-turbine fails.
test = export(exported)

The model can be exported to an FX graph, but calling export on the resulting program fails.

Error message:

Traceback (most recent call last):
  File "/Users/bartel/Documents/synaptics/phase_by_phase.py", line 12, in <module>
    test = export(exported)
           ^^^^^^^^^^^^^^^^
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 304, in export
    cm = TransformedModule(context=context, import_to="import")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 615, in __new__
    info.shadow_dict[key] = import_exported_program(
                            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/exported_program.py", line 183, in import_exported_program
    entry_func_op = fx_importer.import_program(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/iree/compiler/extras/fx_importer.py", line 687, in import_program
    node_importer.import_nodes(
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/iree/compiler/extras/fx_importer.py", line 1230, in import_nodes
    self._import_torch_op_overload(loc, node, target)
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/iree/compiler/extras/fx_importer.py", line 1476, in _import_torch_op_overload
    self.bind_node_value(node, value, i)
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/iree/compiler/extras/fx_importer.py", line 1081, in bind_node_value
    producer_callback(value)
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/iree/compiler/extras/fx_importer.py", line 1168, in on_produced
    self.fx_importer._hooks.store_produced_value(
  File "/Users/bartel/miniforge3/envs/test_iree/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/exported_program.py", line 234, in store_produced_value
    raise ValueError(f"Cannot store value to unmapped global for: {info}")
ValueError: Cannot store value to unmapped global for: InputInfo(program=<torch.export.exported_program.ExportedProgram object at 0x16ff3f8d0>, input_spec=InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg163_1'), target='bn1.num_batches_tracked', persistent=True), node=arg163_1, ir_type=Type(!torch.vtensor<[],si64>), mutable_producer_node_name=None, store_producer_node='add')

stellaraccident (Collaborator) commented

This uses the (new) mutable buffer support. We haven't put a lot of mileage on it outside of tests, and something isn't lining up.

The hint in the stack trace is:

ValueError: Cannot store value to unmapped global for: InputInfo(program=<torch.export.exported_program.ExportedProgram object at 0x16ff3f8d0>, input_spec=InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg163_1'), target='bn1.num_batches_tracked', persistent=True), node=arg163_1, ir_type=Type(!torch.vtensor<[],si64>), mutable_producer_node_name=None, store_producer_node='add')

With that said, it looks like this model is collecting statistics of some kind and storing them in a buffer, which is likely completely unrelated to inference of the model. It should be fine, but it is wasted work and an unnecessary synchronization point. It may also indicate that the model includes training-time code paths in inference mode.
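For context on where that buffer comes from (my reading of the error, not part of the original report): the failing global, bn1.num_batches_tracked, is a standard BatchNorm buffer that is incremented on every forward pass in training mode and left untouched in eval mode. A minimal sketch with a plain torch.nn.BatchNorm2d, no timm involved:

```python
import torch

# BatchNorm layers keep running statistics in buffers. In training mode,
# each forward pass also increments the num_batches_tracked counter --
# the same buffer named in the ValueError above.
bn = torch.nn.BatchNorm2d(3)
x = torch.randn(2, 3, 8, 8)

bn.train()
bn(x)
print(bn.num_batches_tracked.item())  # 1 -- buffer mutated in train mode

bn.eval()
bn(x)
print(bn.num_batches_tracked.item())  # still 1 -- no mutation in eval mode
```

This is why the exported graph contains a store to that buffer: the model was exported in training mode, so the counter update is captured as a mutable-buffer write.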

So this should work and can be fixed. But based on what I see, the model is likely not particularly well suited for inference.

I'm not sure when I can get to this. Working on a deadline right now.
