
Fix BatchMemoryManager length (#641)
Summary:
## Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Fixes #640 by taking the ceiling (rather than truncating) when computing the number of batches.

## How Has This Been Tested (if it applies)

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #641

Reviewed By: HuanyuZhang

Differential Revision: D55253377

fbshipit-source-id: 66c8217c016cedb871c95b79fc7ea1d506d5257e
dwahdany authored and facebook-github-bot committed May 19, 2024
1 parent 32a465b commit 7d65ddf
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions opacus/utils/batch_memory_manager.py
```diff
@@ -70,14 +70,16 @@ def __iter__(self):

     def __len__(self):
         if isinstance(self.sampler, BatchSampler):
-            return int(
+            return math.ceil(
                 len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
             )
         elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
             self.sampler, DistributedUniformWithReplacementSampler
         ):
             expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
-            return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            return math.ceil(
+                len(self.sampler) * (expected_batch_size / self.max_batch_size)
+            )

         return len(self.sampler)
```
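A minimal sketch of why the change matters (illustrative numbers, not taken from Opacus): when the logical batch size is not an exact multiple of `max_batch_size`, `int()` truncates the fractional batch and undercounts, while `math.ceil` counts the final partial batch.

```python
import math

# Hypothetical example: a sampler yielding 10 logical batches of size 64,
# split by BatchMemoryManager into physical batches of at most 48 samples.
num_logical_batches = 10
batch_size = 64
max_batch_size = 48

# 10 * (64 / 48) = 13.33... physical batches on average.
truncated = int(num_logical_batches * (batch_size / max_batch_size))       # drops the partial batch
corrected = math.ceil(num_logical_batches * (batch_size / max_batch_size))  # counts it

print(truncated)  # 13
print(corrected)  # 14
```

With truncation, a training loop that trusts `len(data_loader)` (e.g. for progress bars or scheduler steps) would stop one batch short of the iterator's actual output; ceiling keeps `__len__` an upper bound consistent with iteration.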

