Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataPipe] extract keys #406

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchdata.datapipes.iter import (
BucketBatcher,
Cycler,
ExtractKeys,
Header,
IndexAdder,
InMemoryCacheHolder,
Expand Down Expand Up @@ -951,6 +952,30 @@ def test_mux_longest_iterdatapipe(self):
with self.assertRaises(TypeError):
len(output_dp)

def test_extractor(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_extractor(self):
def test_key_extractor(self):

nit: We used to have a different extractor


# Functional Test: verify that extracting by patterns yields correct output
stage1 = IterableWrapper([
{"1.txt": "1", "1.bin": "1b"},
{"2.txt": "2", "2.bin": "2b"},
])
stage2 = ExtractKeys(stage1, "*.txt", "*.bin", as_tuple=True)
output = list(iter(stage2))
self.assertEqual(output, [("1", "1b"), ("2", "2b")])
stage2 = ExtractKeys(stage1, "*.txt", "*.bin")
output = list(iter(stage2))
self.assertEqual(output, [
{"1.txt": "1", "1.bin": "1b"},
{"2.txt": "2", "2.bin": "2b"},
])
with self.assertRaisesRegex(ValueError, r"(?i)multiple sample keys"):
stage2 = ExtractKeys(stage1, "*")
output = list(iter(stage2))
with self.assertRaisesRegex(ValueError, r"selected twice"):
stage2 = ExtractKeys(stage1, "*.txt", "*t")
output = list(iter(stage2))


tmbdev marked this conversation as resolved.
Show resolved Hide resolved
def test_zip_longest_iterdatapipe(self):

# Functional Test: raises TypeError when an input is not of type `IterDataPipe`
Expand Down
8 changes: 7 additions & 1 deletion torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@
TFRecordLoaderIterDataPipe as TFRecordLoader,
)
from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper
from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset
from torchdata.datapipes.iter.util.webdataset import (
WebDatasetIterDataPipe as WebDataset,
)
from torchdata.datapipes.iter.util.extractkeys import (
ExtractKeysIterDataPipe as ExtractKeys,
)
from torchdata.datapipes.iter.util.xzfileloader import (
XzFileLoaderIterDataPipe as XzFileLoader,
XzFileReaderIterDataPipe as XzFileReader,
Expand Down Expand Up @@ -151,6 +156,7 @@
"Dropper",
"EndOnDiskCacheHolder",
"Enumerator",
"ExtractKeys",
"Extractor",
"FSSpecFileLister",
"FSSpecFileOpener",
Expand Down
73 changes: 73 additions & 0 deletions torchdata/datapipes/iter/util/extractkeys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from fnmatch import fnmatch
from typing import Dict, Iterator, Tuple, Union

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("extract_keys")
class ExtractKeysIterDataPipe(IterDataPipe[Dict]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this to KeyExtractor to follow our naming convention? Thanks.

We can still keep "extract_keys" as the functional name.

r"""
Given a stream of dictionaries, return a stream of tuples by selecting keys using glob patterns.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Given a stream of dictionaries, return a stream of tuples by selecting keys using glob patterns.
Given a stream of dictionaries, return a stream of dicts (or tuples) by selecting keys using glob patterns.


Args:
source_datapipe: a DataPipe yielding a stream of dictionaries.
duplicate_is_error: it is an error if the same key is selected twice (True)
ignore_missing: skip any dictionaries where one or more patterns don't match (False)
Comment on lines +21 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
duplicate_is_error: it is an error if the same key is selected twice (True)
ignore_missing: skip any dictionaries where one or more patterns don't match (False)

Duplicate lines of descriptions

*args: list of glob patterns or list of glob patterns
duplicate_is_error: it is an error if the same key is selected twice (True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
duplicate_is_error: it is an error if the same key is selected twice (True)
duplicate_is_error: it is an error if the same key is selected twice (True), otherwise returns the first matched value

ignore_missing: allow patterns not to match (i.e., incomplete outputs)
as_tuple: return a tuple instead of a dictionary
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
as_tuple: return a tuple instead of a dictionary
as_tuple: return a tuple instead of a dictionary (True or False here)


Returns:
a DataPipe yielding a stream of tuples

Examples:
>>> dp = FileLister(...).load_from_tar().webdataset().decode(...).extract_keys(["*.jpg", "*.png"], "*.gt.txt")
"""
Comment on lines +32 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to the one example with webdataset, please add an example with sample outputs here. Copying from the test cases is totally fine to me.


def __init__(
self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False, as_tuple=False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to default as_tuple=False? Based on the docstring I would've guessed you wanted True instead.

Suggested change
self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False, as_tuple=False
self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error: bool = True, ignore_missing: bool = False, as_tuple: bool = False

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: allow_duplicate might be a better name than duplicate_is_error

) -> None:
super().__init__()
self.source_datapipe: IterDataPipe[Dict] = source_datapipe
self.duplicate_is_error = duplicate_is_error
self.patterns = args
self.ignore_missing = ignore_missing
self.as_tuple = as_tuple

def __iter__(self) -> Union[Iterator[Tuple], Iterator[Dict]]: # type: ignore
for sample in self.source_datapipe:
result = []
used = set()
for pattern in self.patterns:
pattern = [pattern] if not isinstance(pattern, (list, tuple)) else pattern
matches = [x for x in sample.keys() if any(fnmatch(x, p) for p in pattern)]
if len(matches) == 0:
if self.ignore_missing:
continue
else:
raise ValueError(f"extract_keys: cannot find {pattern} in sample keys {sample.keys()}.")
if len(matches) > 1 and self.duplicate_is_error:
tmbdev marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"extract_keys: multiple sample keys {sample.keys()} match {pattern}.")
if matches[0] in used and self.duplicate_is_error:
raise ValueError(f"extract_keys: key {matches[0]} is selected twice.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError(f"extract_keys: key {matches[0]} is selected twice.")
raise ValueError(f"extract_keys: key {matches[0]} is selected twice by multiple patterns.")

nit

used.add(matches[0])
value = sample[matches[0]]
if self.as_tuple:
result.append(value)
else:
result.append((matches[0], value))
if self.as_tuple:
yield tuple(result)
else:
yield {k: v for k, v in result}

def __len__(self) -> int:
return len(self.source_datapipe)
Comment on lines +72 to +73
Copy link
Contributor

@NivekT NivekT Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: A sample will always be yielded even if nothing matches right?