Support reconstructing CUDA event object within Dynamo graph (#133635)
Summary:
`torch.cuda.Event` objects differ from `torch.cuda.Stream` objects in that events are not pooled, meaning we can't look up a previously created CUDA event object by ID. This prevents CUDA event objects created outside of the Dynamo graph from being used within the graph, since Dynamo needs a way to emit a `call_function` line in the graph that retrieves the event object for downstream op use. This PR adds a simple object pool to the Dynamo utilities to support looking up a CUDA event object by ID from within the Dynamo graph.
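The lookup mechanism the summary describes can be sketched in plain Python. The pool and function names below mirror the ones added in the diff; `DummyEvent` is a hypothetical stand-in for `torch.cuda.Event` so the sketch runs without a GPU:

```python
# Minimal standalone sketch of the ID-keyed weak-reference pool (assumption:
# simplified from the actual Dynamo utils; DummyEvent stands in for
# torch.cuda.Event so no CUDA device is required).
import weakref
from typing import Dict

user_obj_id_to_weakref: Dict[int, weakref.ReferenceType] = {}


def store_user_object_weakref(obj):
    # Record a weak reference keyed by the object's id(). The pool does not
    # keep the object alive -- its lifetime stays governed by the user.
    user_obj_id_to_weakref[id(obj)] = weakref.ref(obj)


def get_user_object_from_id(obj_id):
    # Dereference the weakref; fail loudly if the user let the object die.
    obj = user_obj_id_to_weakref[obj_id]()
    assert obj is not None, "User object is no longer alive"
    return obj


class DummyEvent:
    """Hypothetical stand-in for torch.cuda.Event."""


event = DummyEvent()
store_user_object_weakref(event)
# Inside the compiled graph, Dynamo can now emit a call_function line that
# retrieves the very same object by its id:
assert get_user_object_from_id(id(event)) is event
```

Because the pool holds only weak references, it never extends an event's lifetime: if the user drops the last strong reference, a later lookup fails with a clear assertion instead of silently resurrecting the object.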

After this PR, if a user creates a CUDA event object outside of the graph and uses that event within the graph, the behavior exactly matches eager mode.

Test commands:
- `pytest -rA test/dynamo/test_ctx_manager.py::CtxManagerTests::test_cuda_event_created_outside_of_graph`
- `pytest -rA test/dynamo/test_ctx_manager.py::CtxManagerTests::test_cuda_event_across_graph_break`

X-link: pytorch/pytorch#133635
Approved by: https://github.com/yifuwang
ghstack dependencies: #133532, #133531, #133636

Reviewed By: jeanschmidt

Differential Revision: D61432589

Pulled By: yf225

fbshipit-source-id: 6d168c84d80b8f086b90b6a506e7778e66b0d80e
yf225 authored and facebook-github-bot committed Aug 18, 2024
1 parent 22a574f commit 64130f1
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
 import atexit
 import collections
 import contextlib
@@ -927,7 +929,7 @@ def create(scope, name, val):

 class CleanupManager(ExactWeakKeyDictionary):
     count = 0
-    instance: ClassVar["CleanupManager"]
+    instance: ClassVar[CleanupManager]

     def _remove_id(self, idx):
         for hook in self.values[idx]:
@@ -1775,7 +1777,7 @@ def disable_cache_limit():
 guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list)

 # Keep a record of graph break reasons for logging
-graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = []
+graph_break_reasons: List[torch._dynamo.output_graph.GraphCompileReason] = []

 # keep record of compiled code, if we are in "error if recompile"
 # to track code that dynamo has compiled previously
@@ -2246,7 +2248,7 @@ def tensor_static_reason_to_message(reason: TensorStaticReason):
 def tensor_always_has_static_shape(
     tensor: Union[torch.Tensor, Any],
     is_tensor: bool,
-    guard_source: "torch._guards.GuardSource",
+    guard_source: torch._guards.GuardSource,
 ) -> Tuple[bool, Optional[TensorStaticReason]]:
     """
     Given a tensor, source, and is_tensor flag, determine if a shape should be static.
@@ -2816,7 +2818,7 @@ def is_torch_function_object(value):
     return hasattr(value, "__torch_function__")


-def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bool:
+def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool:
     from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
     from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable

@@ -3048,3 +3050,20 @@ def _extract_tensor_dict(t):
     }

     return tensor_dict
+
+
+# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
+# whose lifetime is governed by the user.
+# e.g. torch.cuda.Event is a prime example.
+user_obj_id_to_weakref: Dict[int, weakref.ReferenceType[object]] = {}
+
+
+def get_user_object_from_id(obj_id):
+    obj = user_obj_id_to_weakref[obj_id]()
+    assert obj is not None, "User object is no longer alive"
+    return obj
+
+
+def store_user_object_weakref(obj):
+    obj_id = id(obj)
+    user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
