From 3539e6e6917ad772dc6b78a85f59472824c56c9a Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Tue, 17 Sep 2024 20:53:50 +0000 Subject: [PATCH] Update SampleGenerator --- torch_xla/utils/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/utils/utils.py b/torch_xla/utils/utils.py index 04a49cb592f..9c5b759d062 100755 --- a/torch_xla/utils/utils.py +++ b/torch_xla/utils/utils.py @@ -8,6 +8,8 @@ import tempfile import time +from torch.utils.data.dataloader import DataLoader + class Cleaner(object): @@ -38,7 +40,7 @@ def __init__(self): self.cleaner = Cleaner(lambda: shutil.rmtree(self.name)) -class SampleGenerator(object): +class SampleGenerator(DataLoader): """Iterator which returns multiple samples of a given input data. Can be used in place of a PyTorch `DataLoader` to generate synthetic data. @@ -54,7 +56,7 @@ def __init__(self, data, sample_count): self._count = 0 def __iter__(self): - return SampleGenerator(self._data, self._sample_count) + return self def __len__(self): return self._sample_count