Skip to content

Commit

Permalink
Add peak_bytes_used to the get_memory_info (#8265)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored Oct 16, 2024
1 parent 32afdbb commit a0f81e5
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/pjrt/test_runtime_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def test_memory_usage(self):
for usage in results.values():
self.assertIn('bytes_used', usage)
self.assertIn('bytes_limit', usage)
self.assertIn('peak_bytes_used', usage)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,7 +1542,7 @@ def get_memory_info(device: Optional[torch.device] = None) -> MemoryInfo:
Example:
>>> xm.get_memory_info()
{'bytes_used': 290816, 'bytes_limit': 34088157184}
{'bytes_used': 290816, 'bytes_limit': 34088157184, 'peak_bytes_used': 500816}
"""
if device == None:
device = xla_device()
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ py::dict GetMemoryInfo(const std::string& device_str) {
auto py_dict = py::dict();
py_dict["bytes_used"] = mem_info.bytes_used;
py_dict["bytes_limit"] = mem_info.bytes_limit;
py_dict["peak_bytes_used"] = mem_info.peak_bytes_used;
return py_dict;
}

Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class ComputationClient {
struct MemoryInfo {
int64_t bytes_used = 0;
int64_t bytes_limit = 0;
int64_t peak_bytes_used = 0;
};

virtual ~ComputationClient() {}
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,7 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo(
return {
stats.bytes_in_use,
*stats.bytes_limit,
stats.peak_bytes_in_use,
};
}

Expand Down

0 comments on commit a0f81e5

Please sign in to comment.