Fix lack of reset when iter() is called on the DALI framework iterator (#4048)

- When `iter()` is called on the DALI framework iterator, it checks whether
  any data has been consumed yet; if nothing has been consumed, it skips the
  reset. This prevents improper operation in the case where the FW iterator
  prefetches the first batch right after creation: an invocation such as
  `enumerate(iterator)` would otherwise reset the iterator and extend the
  length of the first epoch (the batch already prefetched by the FW iterator
  would be added on top of the freshly reset one). However, the flag marking
  that data has been consumed was set only in the legacy `next()` method and
  not in `__next__()`. This PR fixes the problem and adjusts the test
  accordingly.
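  A minimal, hypothetical sketch of the pattern being fixed (not DALI's
  actual implementation; the class and `reset()` helper here are
  illustrative assumptions). The iterator prefetches its first batch on
  creation, so `iter()` must skip the reset until real consumption has
  happened, otherwise the prefetched batch would extend the first epoch:

  ```python
  class PrefetchingIterator:
      def __init__(self, batches):
          self._batches = list(batches)
          self._idx = 0
          self._first_batch = None
          self._ever_consumed = False
          self._first_batch = self.__next__()  # prefetch the first batch
          self._ever_consumed = False          # prefetching is not user consumption

      def __next__(self):
          self._ever_consumed = True  # the fix: set the flag in __next__ ...
          if self._first_batch is not None:
              batch, self._first_batch = self._first_batch, None
              return batch
          if self._idx == len(self._batches):
              raise StopIteration
          self._idx += 1
          return self._batches[self._idx - 1]

      next = __next__  # ... not only in the legacy next() alias

      def __iter__(self):
          if self._ever_consumed:  # nothing consumed yet -> keep prefetched state
              self.reset()
          return self

      def reset(self):
          self._idx = 0
          self._first_batch = None
          self._first_batch = self.__next__()  # re-prefetch for the new epoch
          self._ever_consumed = False
  ```

  With this sketch, two consecutive `for batch in it:` loops each yield every
  batch exactly once: the first `iter()` call sees `_ever_consumed == False`
  and keeps the prefetched state, while the second sees `True` and resets.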

Signed-off-by: Janusz Lisiecki <[email protected]>
JanuszL committed Jul 12, 2022
1 parent 3e5819a commit 83da787
Showing 4 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions dali/python/nvidia/dali/plugin/mxnet.py
@@ -332,6 +332,7 @@ def _populate_descriptors(self, data_batch):
         self._descriptors_populated = True
 
     def __next__(self):
+        self._ever_consumed = True
         if self._first_batch is not None:
             batch = self._first_batch
             self._first_batch = None
@@ -726,6 +727,7 @@ def __init__(self,
                "greater than the shard size."
 
     def __next__(self):
+        self._ever_consumed = True
         if self._first_batch is not None:
             batch = self._first_batch
             self._first_batch = None
1 change: 1 addition & 0 deletions dali/python/nvidia/dali/plugin/paddle.py
@@ -267,6 +267,7 @@ def __init__(self,
                "greater than the shard size."
 
     def __next__(self):
+        self._ever_consumed = True
         if self._first_batch is not None:
             batch = self._first_batch
             self._first_batch = None
1 change: 1 addition & 0 deletions dali/python/nvidia/dali/plugin/pytorch.py
@@ -203,6 +203,7 @@ def __init__(self,
                "greater than the shard size."
 
     def __next__(self):
+        self._ever_consumed = True
         if self._first_batch is not None:
             batch = self._first_batch
             self._first_batch = None
4 changes: 3 additions & 1 deletion dali/test/python/test_fw_iterators.py
@@ -1728,7 +1728,9 @@ def BoringPipeline():
 
     loader = fw_iterator(pipeline, reader_name="reader", auto_reset=auto_reset_op)
     for _ in range(2):
-        for i, data in enumerate(loader):
+        loader_iter = iter(loader)
+        for i in range(len(loader_iter)):
+            data = next(loader_iter)
             for j, d in enumerate(extract_data(data[0])):
                 assert d[0] == i * batch_size + j, f"{d[0]} { i * batch_size + j}"
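
For context, a hedged illustration (an assumption-based example, not part of
the commit) of why the flag had to move: Python 3's iteration protocol calls
`__next__` directly, so code placed only in a legacy `next()` alias never
runs during a plain `for` loop or a `next(it)` call:

```python
class Legacy:
    def __iter__(self):
        return self

    def __next__(self):        # this is what for-loops and next() invoke
        raise StopIteration

    def next(self):            # never invoked by Python 3 iteration
        self._ever_consumed = True
        return self.__next__()

for _ in Legacy():             # only __next__ runs; the flag is never set
    pass
```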

