From c8df5a99e064affa3465a335a408e7b4b0f01976 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?=
Date: Wed, 14 Aug 2024 11:12:27 +0200
Subject: [PATCH] Add `save_artifact` method to `_Step` (#871)

* Add `save_artifact` method

* Upload pipeline generated artifacts

* Fix log file was being saved in different cache

* Update `save_to_disk` to also save artifacts

* Render artifacts in card

* Update unit tests

* Add missing unit tests

* Update src/distilabel/distiset.py

Co-authored-by: Agus

* Add section about saving artifacts

* Add correct `edit_uri`

---------

Co-authored-by: Agus
---
 .../saving_step_generated_artifacts.md        | 123 ++++++++++++++++++
 mkdocs.yml                                    |   2 +
 src/distilabel/constants.py                   |  10 ++
 src/distilabel/distiset.py                    | 123 +++++++++++++++---
 src/distilabel/pipeline/base.py               |  54 +++++---
 src/distilabel/pipeline/constants.py          |  22 ----
 src/distilabel/steps/base.py                  |  75 ++++++++++-
 src/distilabel/steps/constants.py             |  17 ---
 .../steps/embeddings/nearest_neighbour.py     |  22 +++-
 .../utils/card/distilabel_template.md         |  15 +++
 tests/unit/pipeline/test_base.py              |  23 +++-
 tests/unit/pipeline/test_dag.py               |   2 +-
 tests/unit/pipeline/test_write_buffer.py      |  36 +++--
 tests/unit/steps/test_base.py                 |  46 ++++++-
 tests/unit/test_distiset.py                   |  95 ++++++++++++--
 15 files changed, 557 insertions(+), 108 deletions(-)
 create mode 100644 docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md
 delete mode 100644 src/distilabel/pipeline/constants.py
 delete mode 100644 src/distilabel/steps/constants.py

diff --git a/docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md b/docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md
new file mode 100644
index 0000000000..9e89f07491
--- /dev/null
+++ b/docs/sections/how_to_guides/advanced/saving_step_generated_artifacts.md
@@ -0,0 +1,123 @@
+# Saving step generated artifacts
+
+Some `Step`s might need to produce an auxiliary artifact that is not a result of the computation, but is needed in order to perform it. For example, the [`FaissNearestNeighbour`](/distilabel/components-gallery/steps/faissnearestneighbour/) needs to create a Faiss index to compute the output of the step, which is the top `k` nearest neighbours for each input. Generating the Faiss index takes time, and the index could potentially be reused outside of the `distilabel` pipeline, so it would be a shame not to save it.
+
+For this reason, `Step`s have a method called `save_artifact` that allows saving artifacts that will be included along with the outputs of the pipeline in the generated [`Distiset`][distilabel.distiset.Distiset]. The generated artifacts will be uploaded or saved when using `Distiset.push_to_hub` or `Distiset.save_to_disk`, respectively. Let's see how to use it with a simple example.
+
```python
from typing import List, TYPE_CHECKING
from distilabel.steps import GlobalStep, StepInput, StepOutput
import matplotlib.pyplot as plt

if TYPE_CHECKING:
    from distilabel.steps import StepOutput


class CountTextCharacters(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        return ["text_character_count"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        character_counts = []

        for input in inputs:
            text_character_count = len(input["text"])
            input["text_character_count"] = text_character_count
            character_counts.append(text_character_count)

        # Generate plot with the distribution of text character counts
        plt.figure(figsize=(10, 6))
        plt.hist(character_counts, bins=30, edgecolor="black")
        plt.title("Distribution of Text Character Counts")
        plt.xlabel("Character Count")
        plt.ylabel("Frequency")

        # Save the plot as an artifact of the step
        self.save_artifact(
            name="text_character_count_distribution",
            write_function=lambda path: plt.savefig(path / "figure.png"),
            metadata={"type": "image", "library": "matplotlib"},
        )

        plt.close()

        yield inputs
```

As can be seen in the example above, we have created a simple step that counts the number of characters in each input text and generates a histogram with the distribution of the character counts. We save the histogram as an artifact of the step using the `save_artifact` method. The method takes three arguments:

- `name`: The name we want to give to the artifact.
- `write_function`: A function that writes the artifact to the desired path. The function will receive a `path` argument, which is a `pathlib.Path` object pointing to the directory where the artifact should be saved.
- `metadata`: A dictionary with metadata about the artifact. This metadata will be saved along with the artifact.

Let's execute the step with a simple pipeline and push the resulting `Distiset` to the Hugging Face Hub:

??? 
"Example full code" + + ```python + from typing import TYPE_CHECKING, List + + import matplotlib.pyplot as plt + from datasets import load_dataset + from distilabel.pipeline import Pipeline + from distilabel.steps import GlobalStep, StepInput, StepOutput + + if TYPE_CHECKING: + from distilabel.steps import StepOutput + + + class CountTextCharacters(GlobalStep): + @property + def inputs(self) -> List[str]: + return ["text"] + + @property + def outputs(self) -> List[str]: + return ["text_character_count"] + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + character_counts = [] + + for input in inputs: + text_character_count = len(input["text"]) + input["text_character_count"] = text_character_count + character_counts.append(text_character_count) + + # Generate plot with the distribution of text character counts + plt.figure(figsize=(10, 6)) + plt.hist(character_counts, bins=30, edgecolor="black") + plt.title("Distribution of Text Character Counts") + plt.xlabel("Character Count") + plt.ylabel("Frequency") + + # Save the plot as an artifact of the step + self.save_artifact( + name="text_character_count_distribution", + write_function=lambda path: plt.savefig(path / "figure.png"), + metadata={"type": "image", "library": "matplotlib"}, + ) + + plt.close() + + yield inputs + + + with Pipeline() as pipeline: + count_text_characters = CountTextCharacters() + + if __name__ == "__main__": + distiset = pipeline.run( + dataset=load_dataset( + "HuggingFaceH4/instruction-dataset", split="test" + ).rename_column("prompt", "text"), + ) + + distiset.push_to_hub("distilabel-internal-testing/distilabel-artifacts-example") + ``` + +The generated [distilabel-internal-testing/distilabel-artifacts-example](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example) dataset repository has a section in its card [describing the artifacts generated by the pipeline](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example#artifacts) and the generated plot can be seen [here](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example/blob/main/artifacts/count_text_characters_0/text_character_count_distribution/figure.png). 
diff --git a/mkdocs.yml b/mkdocs.yml index 8808f27a46..082c7ef27a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -7,6 +7,7 @@ site_description: Distilabel is an AI Feedback (AIF) framework for building data # Repository repo_name: argilla-io/distilabel repo_url: https://github.com/argilla-io/distilabel +edit_uri: edit/main/docs/ extra: version: @@ -179,6 +180,7 @@ nav: - Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md" - Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md" - Assigning resources to a step: "sections/how_to_guides/advanced/assigning_resources_to_step.md" + - Saving step generated artifacts: "sections/how_to_guides/advanced/saving_step_generated_artifacts.md" - Serving an LLM for sharing it between several tasks: "sections/how_to_guides/advanced/serving_an_llm_for_reuse.md" - Scaling and distributing a pipeline with Ray: "sections/how_to_guides/advanced/scaling_with_ray.md" - Pipeline Samples: diff --git a/src/distilabel/constants.py b/src/distilabel/constants.py index a1400bcd03..bd636f8165 100644 --- a/src/distilabel/constants.py +++ b/src/distilabel/constants.py @@ -25,6 +25,16 @@ CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step" LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent" +# Data paths constants +STEPS_OUTPUTS_PATH = "steps_outputs" +STEPS_ARTIFACTS_PATH = "steps_artifacts" + +# Distiset related constants +DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs" +DISTISET_ARTIFACTS_FOLDER: Final[str] = "artifacts" +PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml" +PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log" + __all__ = [ "STEP_ATTR_NAME", diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 92eedecfca..2263bc1553 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -12,24 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os.path as posixpath import re import sys +from collections import defaultdict from os import PathLike from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union import fsspec import yaml from datasets import Dataset, load_dataset, load_from_disk from datasets.filesystems import is_remote_filesystem -from huggingface_hub import DatasetCardData, HfApi, upload_file +from huggingface_hub import DatasetCardData, HfApi, upload_file, upload_folder from huggingface_hub.file_download import hf_hub_download from pyarrow.lib import ArrowInvalid from typing_extensions import Self -from distilabel.constants import STEP_ATTR_NAME +from distilabel.constants import ( + DISTISET_ARTIFACTS_FOLDER, + DISTISET_CONFIG_FOLDER, + PIPELINE_CONFIG_FILENAME, + PIPELINE_LOG_FILENAME, + STEP_ATTR_NAME, + STEPS_ARTIFACTS_PATH, + STEPS_OUTPUTS_PATH, +) from distilabel.utils.card.dataset_card import ( DistilabelDatasetCard, size_categories_parser, @@ -42,11 +52,6 @@ from distilabel.pipeline._dag import DAG -DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs" -PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml" -PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log" - - class Distiset(dict): """Convenient wrapper around `datasets.Dataset` to push to the Hugging Face Hub. @@ -54,12 +59,18 @@ class Distiset(dict): `DAG` and the values are `datasets.Dataset`. 
    Attributes:
-        pipeline_path: Optional path to the pipeline.yaml file that generated the dataset.
-        log_filename_path: Optional path to the pipeline.log file that generated was written by the
-            pipeline.
+        _pipeline_path: Optional path to the `pipeline.yaml` file that generated the dataset.
+            Defaults to `None`.
+        _artifacts_path: Optional path to the directory containing the artifacts generated by
+            the pipeline steps. Defaults to `None`.
+        _log_filename_path: Optional path to the `pipeline.log` file that was written by
+            the pipeline. Defaults to `None`.
+        _citations: Optional list containing citations that will be included in the dataset
+            card. Defaults to `None`.
     """

     _pipeline_path: Optional[Path] = None
+    _artifacts_path: Optional[Path] = None
     _log_filename_path: Optional[Path] = None
     _citations: Optional[List[str]] = None

@@ -121,6 +132,16 @@ def push_to_hub(
             **kwargs,
         )

+        if self.artifacts_path:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=self.artifacts_path,
+                path_in_repo="artifacts",
+                token=token,
+                repo_type="dataset",
+                commit_message="Include pipeline artifacts",
+            )
+
         if include_script and script_path.exists():
             upload_file(
                 path_or_fileobj=script_path,
@@ -128,7 +149,7 @@ def push_to_hub(
                 repo_id=repo_id,
                 repo_type="dataset",
                 token=token,
-                commit_message="Include pipeline script.",
+                commit_message="Include pipeline script",
             )

         if generate_card:
@@ -185,11 +206,38 @@ def _get_card(
             sample_records=sample_records,
             include_script=include_script,
             filename_py=filename_py,
+            artifacts=self._get_artifacts_metadata(),
             references=self.citations,
         )

         return card

+    def _get_artifacts_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
+        """Gets a dictionary with the metadata of the artifacts generated by the pipeline steps.
+
+        Returns:
+            A dictionary in which the key is the name of the step and the value is a list
+            of dictionaries, each of them containing the name and metadata of the step artifact.
+        """
+        if not self.artifacts_path:
+            return {}
+
+        def iterdir_ignore_hidden(path: Path) -> Generator[Path, None, None]:
+            return (f for f in Path(path).iterdir() if not f.name.startswith("."))
+
+        artifacts_metadata = defaultdict(list)
+        for step_artifacts_dir in iterdir_ignore_hidden(self.artifacts_path):
+            step_name = step_artifacts_dir.stem
+            for artifact_dir in iterdir_ignore_hidden(step_artifacts_dir):
+                artifact_name = artifact_dir.stem
+                metadata_path = artifact_dir / "metadata.json"
+                metadata = json.loads(metadata_path.read_text())
+                artifacts_metadata[step_name].append(
+                    {"name": artifact_name, "metadata": metadata}
+                )
+
+        return dict(artifacts_metadata)
+
     def _extract_readme_metadata(
         self, repo_id: str, token: Optional[str]
     ) -> Dict[str, Any]:
@@ -243,6 +291,7 @@ def _generate_card(
                 repo_type="dataset",
                 token=token,
             )
+
         if self.pipeline_path:
             # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well.
             HfApi().upload_file(
@@ -252,6 +301,7 @@ def _generate_card(
                 repo_type="dataset",
                 token=token,
             )
+
         if self.log_filename_path:
             # The same we had with "pipeline.yaml" but with the log file.
            HfApi().upload_file(
@@ -360,6 +410,12 @@ def save_to_disk(
         )
         fs.makedirs(distiset_config_folder, exist_ok=True)

+        if self.artifacts_path:
+            distiset_artifacts_folder = posixpath.join(
+                distiset_path, DISTISET_ARTIFACTS_FOLDER
+            )
+            fs.copy(str(self.artifacts_path), distiset_artifacts_folder, recursive=True)
+
         if save_card:
             # NOTE: Currently the card is not the same if we write to disk or push to the HF hub,
             # as we aren't generating the README copying/updating the data from the dataset repo.
@@ -415,7 +471,7 @@ def load_from_disk(
         original_distiset_path = str(distiset_path)

         fs: fsspec.AbstractFileSystem
-        fs, _, [distiset_path] = fsspec.get_fs_token_paths(
+        fs, _, [distiset_path] = fsspec.get_fs_token_paths(  # type: ignore
             original_distiset_path, storage_options=storage_options
         )
         dest_distiset_path = distiset_path
@@ -425,6 +481,7 @@ def load_from_disk(
         ), "`distiset_path` must be a `PathLike` object pointing to a folder or a URI of a remote filesystem."

         has_config = False
+        has_artifacts = False
         distiset = cls()

         if is_remote_filesystem(fs):
@@ -432,19 +489,23 @@ def load_from_disk(
             if download_dir:
                 dest_distiset_path = download_dir
             else:
-                dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path)
-            fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True)
+                dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path)  # type: ignore
+            fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True)  # type: ignore

         # Now we should have the distiset locally, so we can read those files
         for folder in Path(dest_distiset_path).iterdir():
             if folder.stem == DISTISET_CONFIG_FOLDER:
                 has_config = True
                 continue
+            elif folder.stem == DISTISET_ARTIFACTS_FOLDER:
+                has_artifacts = True
+                continue
             distiset[folder.stem] = load_from_disk(
                 str(folder),
                 keep_in_memory=keep_in_memory,
             )
-        # From the config folder we just need to point to the files. Once downloaded we set the path
+
+        # From the config folder we just need to point to the files. Once downloaded we set the path
         # to wherever they are.
if has_config: distiset_config_folder = posixpath.join( @@ -463,6 +524,11 @@ def load_from_disk( if Path(log_filename_path).exists(): distiset.log_filename_path = Path(log_filename_path) + if has_artifacts: + distiset.artifacts_path = Path( + posixpath.join(dest_distiset_path, DISTISET_ARTIFACTS_FOLDER) + ) + return distiset @property @@ -474,6 +540,16 @@ def pipeline_path(self) -> Union[Path, None]: def pipeline_path(self, path: PathLike) -> None: self._pipeline_path = Path(path) + @property + def artifacts_path(self) -> Union[Path, None]: + """Returns the path to the directory containing the artifacts generated by the steps + of the pipeline.""" + return self._artifacts_path + + @artifacts_path.setter + def artifacts_path(self, path: PathLike) -> None: + self._artifacts_path = Path(path) + @property def log_filename_path(self) -> Union[Path, None]: """Returns the path to the `pipeline.log` file that generated the `Pipeline`.""" @@ -540,10 +616,10 @@ def create_distiset( # noqa: C901 logger = logging.getLogger("distilabel.distiset") - data_dir = Path(data_dir) + steps_outputs_dir = data_dir / STEPS_OUTPUTS_PATH distiset = Distiset() - for file in data_dir.iterdir(): + for file in steps_outputs_dir.iterdir(): if file.is_file(): continue @@ -569,19 +645,26 @@ def create_distiset( # noqa: C901 if len(distiset.keys()) == 1: distiset["default"] = distiset.pop(list(distiset.keys())[0]) + # If there's any artifact set the `artifacts_path` so they can be uploaded + steps_artifacts_dir = data_dir / STEPS_ARTIFACTS_PATH + if any(steps_artifacts_dir.rglob("*")): + distiset.artifacts_path = steps_artifacts_dir + + # Include `pipeline.yaml` if exists if pipeline_path: distiset.pipeline_path = pipeline_path else: # If the pipeline path is not provided, try to find it in the parent directory # and assume that's the wanted file. - pipeline_path = data_dir.parent / "pipeline.yaml" + pipeline_path = steps_outputs_dir.parent / "pipeline.yaml" if pipeline_path.exists(): distiset.pipeline_path = pipeline_path + # Include `pipeline.log` if exists if log_filename_path: distiset.log_filename_path = log_filename_path else: - log_filename_path = data_dir.parent / "pipeline.log" + log_filename_path = steps_outputs_dir.parent / "pipeline.log" if log_filename_path.exists(): distiset.log_filename_path = log_filename_path diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 5f4f6afe97..911273c2c9 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -46,6 +46,8 @@ RECEIVES_ROUTED_BATCHES_ATTR_NAME, ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME, + STEPS_ARTIFACTS_PATH, + STEPS_OUTPUTS_PATH, ) from distilabel.distiset import create_distiset from distilabel.mixins.requirements import RequirementsMixin @@ -128,6 +130,8 @@ def get_pipeline(cls) -> Union["BasePipeline", None]: _STEP_LOAD_FAILED_CODE = -666 _STEP_NOT_LOADED_CODE = -999 +_ATTRIBUTES_IGNORED_CACHE = ("disable_cuda_device_placement",) + class BasePipeline(ABC, RequirementsMixin, _Serializable): """Base class for a `distilabel` pipeline. 
@@ -257,14 +261,14 @@ def _create_signature(self) -> str: [ f"{str(k)}={str(v)}" for k, v in value.items() - if k not in ("disable_cuda_device_placement",) + if k not in _ATTRIBUTES_IGNORED_CACHE ] ) elif isinstance(value, (list, tuple)): # runtime_parameters_info step_info += "-".join([str(v) for v in value]) elif isinstance(value, (int, str, float, bool)): - if argument != "disable_cuda_device_placement": + if argument not in _ATTRIBUTES_IGNORED_CACHE: # batch_size/name step_info += str(value) else: @@ -340,17 +344,19 @@ def run( # cache when the pipeline is run, so it's important to do it first. self._set_runtime_parameters(parameters or {}) + if dataset is not None: + self._add_dataset_generator_step(dataset) + setup_logging( log_queue=self._log_queue, filename=str(self._cache_location["log_file"]) ) - if dataset is not None: - self._add_dataset_generator_step(dataset) - # Validate the pipeline DAG to check that all the steps are chainable, there are # no missing runtime parameters, batch sizes are correct, etc. self.dag.validate() + self._set_pipeline_artifacts_path_in_steps() + # Set the initial load status for all the steps self._init_steps_load_status() @@ -360,12 +366,8 @@ def run( # Load the `_BatchManager` from cache or create one from scratch self._load_batch_manager(use_cache) - if to_install := self.requirements_to_install(): - # Print the list of requirements like they would appear in a requirements.txt - to_install_list = "\n" + "\n".join(to_install) - msg = f"Please install the following requirements to run the pipeline: {to_install_list}" - self._logger.error(msg) - raise ModuleNotFoundError(msg) + # Check pipeline requirements are installed + self._check_requirements() # Setup the filesystem that will be used to pass the data of the `_Batch`es self._setup_fsspec(storage_parameters) @@ -383,7 +385,7 @@ def run( " Returning `Distiset` from cache data..." ) distiset = create_distiset( - self._cache_location["data"], + data_dir=self._cache_location["data"], pipeline_path=self._cache_location["pipeline"], log_filename_path=self._cache_location["log_file"], enable_metadata=self._enable_metadata, @@ -450,8 +452,9 @@ def _add_dataset_generator_step(self, dataset: "InputDataset") -> None: step = self.dag.get_step(step_name)[STEP_ATTR_NAME] if isinstance(step_name, GeneratorStep): raise ValueError( - "There is already a `GeneratorStep` in the pipeline, you can either pass a `dataset` to the " - f"run method, or create a `GeneratorStep` explictly. `GeneratorStep`: {step}" + "There is already a `GeneratorStep` in the pipeline, you can either" + " pass a `dataset` to the run method, or create a `GeneratorStep` explictly." + f" `GeneratorStep`: {step}" ) loader = make_generator_step(dataset) self.dag.add_root_step(loader) @@ -475,6 +478,27 @@ def _init_steps_load_status(self) -> None: for step_name in self.dag: self._steps_load_status[step_name] = _STEP_NOT_LOADED_CODE + def _set_pipeline_artifacts_path_in_steps(self) -> None: + """Sets the attribute `_pipeline_artifacts_path` in all the `Step`s of the pipeline, + so steps can use it to get the path to save the generated artifacts.""" + artifacts_path = self._cache_location["data"] / STEPS_ARTIFACTS_PATH + for name in self.dag: + step: "_Step" = self.dag.get_step(name)[STEP_ATTR_NAME] + step.set_pipeline_artifacts_path(path=artifacts_path) + + def _check_requirements(self) -> None: + """Checks if the dependencies required to run the pipeline are installed. + + Raises: + ModuleNotFoundError: if one or more requirements are missing. 
+ """ + if to_install := self.requirements_to_install(): + # Print the list of requirements like they would appear in a requirements.txt + to_install_list = "\n" + "\n".join(to_install) + msg = f"Please install the following requirements to run the pipeline: {to_install_list}" + self._logger.error(msg) + raise ModuleNotFoundError(msg) + def _setup_fsspec( self, storage_parameters: Optional[Dict[str, Any]] = None ) -> None: @@ -696,7 +720,7 @@ def _setup_write_buffer(self) -> None: """Setups the `_WriteBuffer` that will store the data of the leaf steps of the pipeline while running, so the `Distiset` can be created at the end. """ - buffer_data_path = self._cache_location["data"] + buffer_data_path = self._cache_location["data"] / STEPS_OUTPUTS_PATH self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'") self._write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps) diff --git a/src/distilabel/pipeline/constants.py b/src/distilabel/pipeline/constants.py deleted file mode 100644 index 3d400e4a1b..0000000000 --- a/src/distilabel/pipeline/constants.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from typing import Final - -STEP_ATTR_NAME: Final[str] = "step" -INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue" -RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches" -ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function" -CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step" -LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent" diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index 940b05d812..1aa6e8bb1d 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -17,7 +17,18 @@ import re from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + overload, +) from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr from typing_extensions import Annotated, Self @@ -27,7 +38,7 @@ RuntimeParameter, RuntimeParametersMixin, ) -from distilabel.utils.serialization import _Serializable +from distilabel.utils.serialization import _Serializable, write_json from distilabel.utils.typing_ import is_parameter_annotated_with if TYPE_CHECKING: @@ -182,6 +193,7 @@ def process(self, inputs: *StepInput) -> StepOutput: input_mappings: Dict[str, str] = {} output_mappings: Dict[str, str] = {} + _pipeline_artifacts_path: Path = PrivateAttr(None) _built_from_decorator: bool = PrivateAttr(default=False) _logger: "Logger" = PrivateAttr(None) @@ -485,6 +497,65 @@ def get_outputs(self) -> List[str]: """ return [self.output_mappings.get(output, output) for output in self.outputs] + def set_pipeline_artifacts_path(self, path: Path) -> None: + """Sets the `_pipeline_artifacts_path` 
attribute. This method is meant to be used + by the `Pipeline` once the cache location is known. + + Args: + path: the path where the artifacts generated by the pipeline steps should be + saved. + """ + self._pipeline_artifacts_path = path + + @property + def artifacts_directory(self) -> Union[Path, None]: + """Gets the path of the directory where the step should save its generated artifacts. + + Returns: + The path of the directory where the step should save the generated artifacts, + or `None` if `_pipeline_artifacts_path` is not set. + """ + if self._pipeline_artifacts_path is None: + return None + return self._pipeline_artifacts_path / self.name # type: ignore + + def save_artifact( + self, + name: str, + write_function: Callable[[Path], None], + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Saves an artifact generated by the `Step`. + + Args: + name: the name of the artifact. + write_function: a function that will receive the path where the artifact should + be saved. + metadata: the artifact metadata. Defaults to `None`. + """ + if self.artifacts_directory is None: + self._logger.warning( + f"Cannot save artifact with '{name}' as `_pipeline_artifacts_path` is not" + " set. This is normal if the `Step` is being executed as a standalone component." + ) + return + + artifact_directory_path = self.artifacts_directory / name + artifact_directory_path.mkdir(parents=True, exist_ok=True) + + self._logger.info(f"🏺 Storing '{name}' generated artifact...") + + self._logger.debug( + f"Calling `write_function` to write artifact in '{artifact_directory_path}'..." + ) + write_function(artifact_directory_path) + + metadata_path = artifact_directory_path / "metadata.json" + self._logger.debug( + f"Calling `write_json` to write artifact metadata in '{metadata_path}'..." + ) + write_json(filename=metadata_path, data=metadata or {}) + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: dump = super()._model_dump(obj, **kwargs) dump["runtime_parameters_info"] = self.get_runtime_parameters_info() diff --git a/src/distilabel/steps/constants.py b/src/distilabel/steps/constants.py deleted file mode 100644 index 259780e9fd..0000000000 --- a/src/distilabel/steps/constants.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Final - -DISTILABEL_METADATA_KEY: Final[str] = "distilabel_metadata" diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py index cba8e293da..6d548bc948 100644 --- a/src/distilabel/steps/embeddings/nearest_neighbour.py +++ b/src/distilabel/steps/embeddings/nearest_neighbour.py @@ -186,6 +186,23 @@ def _build_index(self, inputs: List[Dict[str, Any]]) -> Dataset: ) return dataset + def _save_index(self, dataset: Dataset) -> None: + """Save the generated Faiss index as an artifact of the step. + + Args: + dataset: the dataset with the `faiss` index built. 
+ """ + self.save_artifact( + name="faiss_index", + write_function=lambda path: dataset.save_faiss_index( + index_name="embedding", file=path / "index.faiss" + ), + metadata={ + "num_rows": len(dataset), + "embedding_dim": len(dataset[0]["embedding"]), + }, + ) + def _search(self, dataset: Dataset) -> Dataset: """Search the top `k` nearest neighbours for each row in the dataset. @@ -214,5 +231,6 @@ def add_search_results(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: def process(self, inputs: StepInput) -> "StepOutput": # type: ignore dataset = self._build_index(inputs) - dataset = self._search(dataset) - yield dataset.to_list() + dataset_with_search_results = self._search(dataset) + self._save_index(dataset) + yield dataset_with_search_results.to_list() diff --git a/src/distilabel/utils/card/distilabel_template.md b/src/distilabel/utils/card/distilabel_template.md index c51e9a5953..38daa7f857 100644 --- a/src/distilabel/utils/card/distilabel_template.md +++ b/src/distilabel/utils/card/distilabel_template.md @@ -70,6 +70,21 @@ ds = load_dataset("{{ repo_id }}") {% endfor %} +{% if artifacts %} +## Artifacts + +{% for step_name, artifacts in artifacts.items() %} +* **Step**: `{{ step_name }}` + {% for artifact in artifacts %} + * **Artifact name**: `{{ artifact.name }}` + {% for name, value in artifact.metadata.items() %} + * `{{ name }}`: {{ value }} + {% endfor %} + {% endfor %} +{% endfor %} + +{% endif %} + {% if references %} ## References diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index dc1bd99446..a2a043f737 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -21,6 +21,11 @@ from unittest import mock import pytest +from distilabel.constants import ( + INPUT_QUEUE_ATTR_NAME, + LAST_BATCH_SENT_FLAG, + STEPS_ARTIFACTS_PATH, +) from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.pipeline.base import ( _STEP_LOAD_FAILED_CODE, @@ -30,7 +35,6 @@ ) from distilabel.pipeline.batch import _Batch from distilabel.pipeline.batch_manager import _BatchManager -from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME, LAST_BATCH_SENT_FLAG from distilabel.pipeline.routing_batch_function import ( routing_batch_function, sample_n_steps, @@ -154,6 +158,23 @@ def test_setup_fsspec_raises_value_error(self) -> None: with pytest.raises(ValueError, match="The 'path' key must be present"): pipeline._setup_fsspec({"key": "random"}) + def test_set_pipeline_artifacts_path_in_steps(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._set_pipeline_artifacts_path_in_steps() + + artifacts_directory = pipeline._cache_location["data"] / STEPS_ARTIFACTS_PATH + assert generator.artifacts_directory == artifacts_directory / generator.name # type: ignore + assert step.artifacts_directory == artifacts_directory / step.name # type: ignore + assert step2.artifacts_directory == artifacts_directory / step2.name # type: ignore + assert step3.artifacts_directory == artifacts_directory / step3.name # type: ignore + def test_init_steps_load_status(self) -> None: with DummyPipeline(name="dummy") as pipeline: generator = DummyGeneratorStep() diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py index 6abfe51fd3..b566de2f1b 100644 --- a/tests/unit/pipeline/test_dag.py +++ b/tests/unit/pipeline/test_dag.py @@ -17,9 +17,9 @@ from 
typing import TYPE_CHECKING, Any, Callable, Dict, List

 import pytest
+from distilabel.constants import STEP_ATTR_NAME
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.pipeline._dag import DAG
-from distilabel.pipeline.constants import STEP_ATTR_NAME
 from distilabel.pipeline.local import Pipeline
 from distilabel.pipeline.routing_batch_function import routing_batch_function
 from distilabel.steps.base import GeneratorStep, Step, StepInput, StepResources
diff --git a/tests/unit/pipeline/test_write_buffer.py b/tests/unit/pipeline/test_write_buffer.py
index a7ae64c91e..2fd552f17e 100644
--- a/tests/unit/pipeline/test_write_buffer.py
+++ b/tests/unit/pipeline/test_write_buffer.py
@@ -15,6 +15,7 @@
 import tempfile
 from pathlib import Path

+from distilabel.constants import STEPS_OUTPUTS_PATH
 from distilabel.distiset import Distiset, create_distiset
 from distilabel.pipeline.local import Pipeline
 from distilabel.pipeline.write_buffer import _WriteBuffer
@@ -30,7 +31,8 @@ class TestWriteBuffer:
     def test_create(self) -> None:
         with tempfile.TemporaryDirectory() as tmpdirname:
             folder = Path(tmpdirname) / "data"
+            steps_outputs = folder / STEPS_OUTPUTS_PATH
             with Pipeline(name="unit-test-pipeline") as pipeline:
                 dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
                 dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
@@ -43,7 +45,9 @@ def test_create(self) -> None:
             dummy_step_1.connect(dummy_step_2)
             dummy_step_1.connect(dummy_step_3)

-            write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+            write_buffer = _WriteBuffer(
+                path=steps_outputs, leaf_steps=pipeline.dag.leaf_steps
+            )

             assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []}
             assert write_buffer._buffers_dump_batch_size == {
@@ -59,6 +63,7 @@ def test_create(self) -> None:
     def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
         with tempfile.TemporaryDirectory() as tmpdirname:
             folder = Path(tmpdirname) / "data"
+            steps_outputs = folder / STEPS_OUTPUTS_PATH
             with Pipeline(name="unit-test-pipeline") as pipeline:
                 dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
                 dummy_step_1 = DummyStep1(name="dummy_step_1")
@@ -67,7 +72,9 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
             dummy_generator.connect(dummy_step_1)
             dummy_step_1.connect(dummy_step_2)

-            write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+            write_buffer = _WriteBuffer(
+                path=steps_outputs, leaf_steps=pipeline.dag.leaf_steps
+            )

             # Add one batch with 5 rows, shouldn't write anything 5 < 50
             batch = batch_gen(dummy_step_2.name)  # type: ignore
@@ -78,14 +85,14 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
             batch = batch_gen(dummy_step_2.name)  # type: ignore
             write_buffer.add_batch(batch)

-            assert Path(folder, "dummy_step_2", "00001.parquet").exists()
+            assert Path(steps_outputs, "dummy_step_2", "00001.parquet").exists()

             # Add 50 more rows, we should have a new file
             for _ in range(10):
                 batch = batch_gen(dummy_step_2.name)  # type: ignore
                 write_buffer.add_batch(batch)

-            assert Path(folder, "dummy_step_2", "00002.parquet").exists()
+            assert Path(steps_outputs, "dummy_step_2", "00002.parquet").exists()

             # Add more rows and close the write buffer, we should have a new file
             for _ in range(5):
                 batch = batch_gen(dummy_step_2.name)  # type: ignore
                 write_buffer.add_batch(batch)

             write_buffer.close()

-            assert Path(folder, "dummy_step_2", 
"00003.parquet").exists() + assert Path(steps_outputs, "dummy_step_2", "00003.parquet").exists() - ds = create_distiset(write_buffer._path) + ds = create_distiset(folder) assert isinstance(ds, Distiset) assert len(ds.keys()) == 1 assert len(ds["default"]["train"]) == 125 @@ -104,6 +111,7 @@ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None: def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None: with tempfile.TemporaryDirectory() as tmpdirname: folder = Path(tmpdirname) / "data" + steps_outputs = folder / STEPS_OUTPUTS_PATH with Pipeline(name="unit-test-pipeline") as pipeline: dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") @@ -116,19 +124,21 @@ def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None: dummy_step_1.connect(dummy_step_2) dummy_step_1.connect(dummy_step_3) - write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + write_buffer = _WriteBuffer( + path=steps_outputs, leaf_steps=pipeline.dag.leaf_steps + ) for _ in range(10): batch = batch_gen(dummy_step_2.name) # type: ignore write_buffer.add_batch(batch) - assert Path(folder, "dummy_step_2", "00001.parquet").exists() + assert Path(steps_outputs, "dummy_step_2", "00001.parquet").exists() for _ in range(10): batch = batch_gen(dummy_step_3.name) # type: ignore write_buffer.add_batch(batch) - assert Path(folder, "dummy_step_3", "00001.parquet").exists() + assert Path(steps_outputs, "dummy_step_3", "00001.parquet").exists() for _ in range(5): batch = batch_gen(dummy_step_2.name) # type: ignore @@ -140,10 +150,10 @@ def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None: write_buffer.close() - assert Path(folder, "dummy_step_2", "00002.parquet").exists() - assert Path(folder, "dummy_step_3", "00002.parquet").exists() + assert Path(steps_outputs, "dummy_step_2", "00002.parquet").exists() + assert Path(steps_outputs, "dummy_step_3", "00002.parquet").exists() - ds = create_distiset(write_buffer._path) + ds = create_distiset(folder) assert isinstance(ds, Distiset) assert len(ds.keys()) == 2 assert len(ds["dummy_step_2"]["train"]) == 75 diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py index d66cb927eb..daf95bedb7 100644 --- a/tests/unit/steps/test_base.py +++ b/tests/unit/steps/test_base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import tempfile +from pathlib import Path from typing import List, Optional import pytest +from distilabel.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.pipeline.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME from distilabel.pipeline.local import Pipeline from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput from distilabel.steps.decorator import step @@ -259,6 +261,48 @@ def routing_batch_function(downstream_step_names: List[str]) -> List[str]: == routing_batch_function ) + def test_set_pipeline_artifacts_path(self) -> None: + step = DummyStep() + step.set_pipeline_artifacts_path(Path("/tmp")) + assert step.artifacts_directory == Path(f"/tmp/{step.name}") + + def test_save_artifact(self) -> None: + with tempfile.TemporaryDirectory() as tempdir: + pipeline_artifacts_path = Path(tempdir) + step = DummyStep() + step.load() + step.set_pipeline_artifacts_path(pipeline_artifacts_path) + step.save_artifact( + name="unit-test", + write_function=lambda path: Path(path / "file.txt").write_text( + "unit test" + ), + metadata={"unit-test": True}, + ) + + artifact_path = pipeline_artifacts_path / step.name / "unit-test" # type: ignore + + assert artifact_path.is_dir() + assert (artifact_path / "file.txt").read_text() == "unit test" + assert (artifact_path / "metadata.json").read_text() == '{"unit-test":true}' + + def test_save_artifact_without_setting_path(self) -> None: + with tempfile.TemporaryDirectory() as tempdir: + pipeline_artifacts_path = Path(tempdir) + step = DummyStep() + step.load() + step.save_artifact( + name="unit-test", + write_function=lambda path: Path(path / "file.txt").write_text( + "unit test" + ), + metadata={"unit-test": True}, + ) + + artifact_path = pipeline_artifacts_path / step.name / "unit-test" # type: ignore + + assert not artifact_path.exists() + class TestGeneratorStep: def test_is_generator(self) -> None: diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py index 07e6549d7b..e492f218a1 100644 --- a/tests/unit/test_distiset.py +++ b/tests/unit/test_distiset.py @@ -22,11 +22,12 @@ import yaml from datasets import Dataset, DatasetDict from distilabel.distiset import Distiset +from distilabel.utils.serialization import write_json from upath import UPath @pytest.fixture(scope="function") -def distiset(): +def distiset() -> Distiset: return Distiset( { "leaf_step_1": Dataset.from_dict({"a": [1, 2, 3]}), @@ -42,14 +43,32 @@ def make_fake_file(filename: Path) -> None: def add_config_to_distiset(distiset: Distiset, folder: Path) -> Distiset: - from distilabel.distiset import DISTISET_CONFIG_FOLDER + from distilabel.constants import DISTISET_CONFIG_FOLDER pipeline_yaml = folder / DISTISET_CONFIG_FOLDER / "pipeline.yaml" pipeline_log = folder / DISTISET_CONFIG_FOLDER / "pipeline.log" make_fake_file(pipeline_yaml) make_fake_file(pipeline_log) distiset.pipeline_path = pipeline_yaml - distiset.pipeline_log_path = pipeline_log + distiset.log_filename_path = pipeline_log + return distiset + + +def add_artifacts_to_distiset(distiset: Distiset, folder: Path) -> Distiset: + from distilabel.constants import DISTISET_ARTIFACTS_FOLDER + + artifacts_folder = folder / DISTISET_ARTIFACTS_FOLDER + + for step in ("leaf_step_1", "leaf_step_2"): + step_artifacts_folder = artifacts_folder / step + step_artifacts_folder.mkdir(parents=True) + artifact_folder = step_artifacts_folder / "artifact" + artifact_folder.mkdir() + metadata_file = artifact_folder / "metadata.json" 
+ write_json(metadata_file, {}) + + distiset.artifacts_path = artifacts_folder + return distiset @@ -63,54 +82,77 @@ def test_train_test_split(self, distiset: Distiset) -> None: @pytest.mark.parametrize("storage_options", [None, {"test": "option"}]) @pytest.mark.parametrize("with_config", [False, True]) + @pytest.mark.parametrize("with_artifacts", [False, True]) def test_save_to_disk( self, distiset: Distiset, with_config: bool, + with_artifacts: bool, storage_options: Optional[Dict[str, Any]], ) -> None: full_distiset = copy.deepcopy(distiset) # Distiset with Distiset with tempfile.TemporaryDirectory() as tmpdirname: folder = Path(tmpdirname) / "distiset_folder" + another_folder = Path(tmpdirname) / "another_distiset_folder" + if with_config: full_distiset = add_config_to_distiset(full_distiset, folder) + if with_artifacts: + full_distiset = add_artifacts_to_distiset(full_distiset, folder) + full_distiset.save_to_disk( - folder, + another_folder, save_card=with_config, save_pipeline_config=with_config, save_pipeline_log=with_config, storage_options=storage_options, ) - assert folder.is_dir() - assert len(list(folder.iterdir())) == 3 + assert another_folder.is_dir() + + if with_artifacts: + assert len(list(another_folder.iterdir())) == 4 + else: + assert len(list(another_folder.iterdir())) == 3 full_distiset = copy.deepcopy(distiset) # Distiset with DatasetDict distiset_with_dict = full_distiset.train_test_split(0.8) with tempfile.TemporaryDirectory() as tmpdirname: folder = Path(tmpdirname) / "distiset_folder" + another_folder = Path(tmpdirname) / "another_distiset_folder" + if with_config: distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder) + if with_artifacts: + distiset_with_dict = add_artifacts_to_distiset( + distiset_with_dict, folder + ) + distiset_with_dict.save_to_disk( - folder, + another_folder, save_card=with_config, save_pipeline_config=with_config, save_pipeline_log=with_config, ) - assert folder.is_dir() - assert len(list(folder.iterdir())) == 3 + assert another_folder.is_dir() + if with_artifacts: + assert len(list(another_folder.iterdir())) == 4 + else: + assert len(list(another_folder.iterdir())) == 3 @pytest.mark.parametrize("pathlib_implementation", [Path, UPath]) @pytest.mark.parametrize("storage_options", [None, {"project": "experiments"}]) @pytest.mark.parametrize("with_config", [False, True]) + @pytest.mark.parametrize("with_artifacts", [False, True]) def test_load_from_disk( self, distiset: Distiset, with_config: bool, + with_artifacts: bool, storage_options: Optional[Dict[str, Any]], pathlib_implementation: type, ) -> None: @@ -120,17 +162,25 @@ def test_load_from_disk( # This way we can test also we work with UPath, using FilePath protocol, as it should # do the same as S3Path, GCSPath, etc. 
folder = pathlib_implementation(tmpdirname) / "distiset_folder" + another_folder = ( + pathlib_implementation(tmpdirname) / "another_distiset_folder" + ) + if with_config: full_distiset = add_config_to_distiset(full_distiset, folder) + + if with_artifacts: + full_distiset = add_artifacts_to_distiset(full_distiset, folder) + full_distiset.save_to_disk( - folder, + another_folder, save_card=with_config, save_pipeline_config=with_config, save_pipeline_log=with_config, storage_options=storage_options, ) ds = Distiset.load_from_disk( - folder, + another_folder, storage_options=storage_options, ) assert isinstance(ds, Distiset) @@ -140,24 +190,41 @@ def test_load_from_disk( assert ds.pipeline_path.exists() assert ds.log_filename_path.exists() + if with_artifacts: + assert ds.artifacts_path.exists() + full_distiset = copy.deepcopy(distiset) # Distiset with DatasetDict distiset_with_dict = full_distiset.train_test_split(0.8) with tempfile.TemporaryDirectory() as tmpdirname: folder = pathlib_implementation(tmpdirname) / "distiset_folder" + another_folder = ( + pathlib_implementation(tmpdirname) / "another_distiset_folder" + ) + if with_config: distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder) - distiset_with_dict.save_to_disk(folder) - ds = Distiset.load_from_disk(folder, storage_options=storage_options) + if with_artifacts: + distiset_with_dict = add_artifacts_to_distiset( + distiset_with_dict, folder + ) - assert folder.is_dir() + distiset_with_dict.save_to_disk(another_folder) + ds = Distiset.load_from_disk( + another_folder, storage_options=storage_options + ) + + assert another_folder.is_dir() assert isinstance(ds["leaf_step_1"], DatasetDict) if with_config: assert ds.pipeline_path.exists() assert ds.log_filename_path.exists() + if with_artifacts: + assert ds.artifacts_path.exists() + def test_dataset_card(self, distiset: Distiset) -> None: # Test the the metadata we generate by default without extracting the already generated content from the HF hub. # We parse the content and check it's the same as the one we generate.