Skip to content

Commit

Permalink
Fix loader to read from a glob pattern (#877)
Browse files Browse the repository at this point in the history
* Fix loader to read from a glob pattern

* Fix to read from general UPath instead of Path

* Update tests to use glob patterns

* Refactor to simplify check for glob pattern
  • Loading branch information
plaguss authored Aug 14, 2024
1 parent 04d0bf0 commit f382f1c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
24 changes: 22 additions & 2 deletions src/distilabel/steps/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,23 @@ class LoadDataFromFileSystem(LoadDataFromHub):
# >>> result
# ([{'type': 'function', 'function':...', False)
```
Load data passing a glob pattern:
```python
from distilabel.steps import LoadDataFromFileSystem
loader = LoadDataFromFileSystem(
data_files="path/to/dataset/*.jsonl",
streaming=True
)
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
"""

data_files: RuntimeParameter[Union[str, Path]] = Field(
Expand Down Expand Up @@ -376,7 +393,7 @@ def load(self) -> None:
self.num_examples = len(self._dataset)

@staticmethod
def _prepare_data_files(
def _prepare_data_files( # noqa: C901
data_path: UPath,
) -> Tuple[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]], str]:
"""Prepare the loading process by setting the `data_files` attribute.
Expand All @@ -394,9 +411,12 @@ def get_filetype(data_path: UPath) -> str:
filetype = "json"
return filetype

if data_path.is_file():
if data_path.is_file() or (
len(str(data_path.parent.glob(data_path.name))) >= 1
):
filetype = get_filetype(data_path)
data_files = str(data_path)

elif data_path.is_dir():
file_sequence = []
file_map = defaultdict(list)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/steps/generators/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_read_from_jsonl_with_folder(self, filetype: Union[str, None]) -> None:

loader = LoadDataFromFileSystem(
filetype=filetype,
data_files=tmpdir,
data_files=str(Path(tmpdir) / "*.jsonl"),
)
loader.load()
generator_step_output = next(loader.process())
Expand All @@ -127,7 +127,7 @@ def test_read_from_jsonl_with_nested_folder(

loader = LoadDataFromFileSystem(
filetype=filetype,
data_files=tmpdir,
data_files=str(Path(tmpdir) / "**/*.jsonl"),
)
loader.load()
generator_step_output = next(loader.process())
Expand Down

0 comments on commit f382f1c

Please sign in to comment.