diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 135329166..a46802070 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -341,10 +341,18 @@ def dynamo_timed( remote_cache_time_saved = frame_phase_timing[ compile_id ].get("remote_cache_time_saved", None) + remote_fx_graph_cache_get_time = frame_phase_timing[ + compile_id + ].get("remote_fx_graph_cache_get", None) + remote_fx_graph_cache_put_time = frame_phase_timing[ + compile_id + ].get("remote_fx_graph_cache_put", None) else: inductor_compile_time = None code_gen_time = None remote_cache_time_saved = None + remote_fx_graph_cache_get_time = None + remote_fx_graph_cache_put_time = None structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() ) @@ -356,6 +364,8 @@ def dynamo_timed( fail_reason, remote_cache_time_saved, structured_logging_overhead_s, + to_int_ms(remote_fx_graph_cache_get_time), + to_int_ms(remote_fx_graph_cache_put_time), ) record_compilation_metrics(metrics) @@ -762,6 +772,10 @@ def proxy_args_kwargs(args, kwargs): ) +def to_int_ms(v: Optional[float]) -> Optional[int]: + return None if v is None else int(v * 1000) + + @dataclasses.dataclass class CompilationMetrics: is_forward: bool = dataclasses.field(default=True, init=False) @@ -801,6 +815,8 @@ class CompilationMetrics: config_inline_inbuilt_nn_modules: Optional[bool] specialize_float: Optional[bool] dynamo_config: Optional[str] + remote_fx_graph_cache_get_time_ms: Optional[int] + remote_fx_graph_cache_put_time_ms: Optional[int] @dataclasses.dataclass @@ -813,6 +829,8 @@ class BwdCompilationMetrics: fail_reason: Optional[str] remote_cache_time_saved_s: Optional[float] structured_logging_overhead_s: Optional[float] + remote_fx_graph_cache_get_time_ms: Optional[int] + remote_fx_graph_cache_put_time_ms: Optional[int] DEFAULT_COMPILATION_METRICS_LIMIT = 64