From 7d65ddf37565adf0fc11de7dab826d4e26a44704 Mon Sep 17 00:00:00 2001
From: Dariush Wahdany <86673488+lsc64@users.noreply.github.com>
Date: Sun, 19 May 2024 07:52:17 -0700
Subject: [PATCH] Fix BatchMemoryManager length (#641)

Summary:
## Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Fixes https://github.com/pytorch/opacus/issues/640 by ceiling the number of batches.

## How Has This Been Tested (if it applies)

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: https://github.com/pytorch/opacus/pull/641

Reviewed By: HuanyuZhang

Differential Revision: D55253377

fbshipit-source-id: 66c8217c016cedb871c95b79fc7ea1d506d5257e
---
 opacus/utils/batch_memory_manager.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/opacus/utils/batch_memory_manager.py b/opacus/utils/batch_memory_manager.py
index c5d6dcc0..feb56bb6 100644
--- a/opacus/utils/batch_memory_manager.py
+++ b/opacus/utils/batch_memory_manager.py
@@ -70,14 +70,16 @@ def __iter__(self):
 
     def __len__(self):
         if isinstance(self.sampler, BatchSampler):
-            return int(
+            return math.ceil(
                 len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
             )
         elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
             self.sampler, DistributedUniformWithReplacementSampler
         ):
             expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
-            return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            return math.ceil(
+                len(self.sampler) * (expected_batch_size / self.max_batch_size)
+            )
         return len(self.sampler)