From a0f81e5e474b1ca032402cc8340452490d193169 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Wed, 16 Oct 2024 15:04:00 -0700 Subject: [PATCH] Add peak_bytes_used to the get_memory_info (#8265) --- test/pjrt/test_runtime_tpu.py | 1 + torch_xla/core/xla_model.py | 2 +- torch_xla/csrc/init_python_bindings.cpp | 1 + torch_xla/csrc/runtime/computation_client.h | 1 + torch_xla/csrc/runtime/pjrt_computation_client.cc | 1 + 5 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index da009d49ed9..21fb55c8225 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -259,6 +259,7 @@ def test_memory_usage(self): for usage in results.values(): self.assertIn('bytes_used', usage) self.assertIn('bytes_limit', usage) + self.assertIn('peak_bytes_used', usage) if __name__ == '__main__': diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 0e403fb1982..931115db6d8 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -1542,7 +1542,7 @@ def get_memory_info(device: Optional[torch.device] = None) -> MemoryInfo: Example: >>> xm.get_memory_info() - {'bytes_used': 290816, 'bytes_limit': 34088157184} + {'bytes_used': 290816, 'bytes_limit': 34088157184, 'peak_bytes_used': 500816} """ if device == None: device = xla_device() diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 62c4cc9fc9e..7758cf32ddb 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -825,6 +825,7 @@ py::dict GetMemoryInfo(const std::string& device_str) { auto py_dict = py::dict(); py_dict["bytes_used"] = mem_info.bytes_used; py_dict["bytes_limit"] = mem_info.bytes_limit; + py_dict["peak_bytes_used"] = mem_info.peak_bytes_used; return py_dict; } diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 1b2819656cf..22940ee6595 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -255,6 +255,7 @@ class ComputationClient { struct MemoryInfo { int64_t bytes_used = 0; int64_t bytes_limit = 0; + int64_t peak_bytes_used = 0; }; virtual ~ComputationClient() {} diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 280a733bebe..3bc2f07fa95 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -991,6 +991,7 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo( return { stats.bytes_in_use, *stats.bytes_limit, + stats.peak_bytes_in_use, }; }