Skip to content

Commit

Permalink
Update SampleGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Sep 17, 2024
1 parent 87a50b1 commit 3539e6e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torch_xla/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import tempfile
import time

from torch.utils.data.dataloader import DataLoader


class Cleaner(object):

Expand Down Expand Up @@ -38,7 +40,7 @@ def __init__(self):
self.cleaner = Cleaner(lambda: shutil.rmtree(self.name))


class SampleGenerator(object):
class SampleGenerator(DataLoader):
"""Iterator which returns multiple samples of a given input data.
Can be used in place of a PyTorch `DataLoader` to generate synthetic data.
Expand All @@ -54,7 +56,7 @@ def __init__(self, data, sample_count):
self._count = 0

def __iter__(self):
return SampleGenerator(self._data, self._sample_count)
return self

def __len__(self):
return self._sample_count
Expand Down

0 comments on commit 3539e6e

Please sign in to comment.