Add save_artifact method to _Step (#871)
* Add `save_artifact` method

* Upload pipeline generated artifacts

* Fix log file was being saved in different cache

* Update `save_to_disk` to also save artifacts

* Render artifacts in card

* Update unit tests

* Add missing unit tests

* Update src/distilabel/distiset.py

Co-authored-by: Agus <[email protected]>

* Add section about saving artifacts

* Add correct `edit_uri`

---------

Co-authored-by: Agus <[email protected]>
gabrielmbmb and plaguss authored Aug 14, 2024
1 parent f382f1c commit c8df5a9
Showing 15 changed files with 557 additions and 108 deletions.
@@ -0,0 +1,123 @@
# Saving step generated artifacts

Some `Step`s need to produce an auxiliary artifact that is not itself an output of the step, but is required to compute that output. For example, [`FaissNearestNeighbour`](/distilabel/components-gallery/steps/faissnearestneighbour/) has to build a Faiss index in order to compute its output, the top `k` nearest neighbours for each input. Building the Faiss index takes time, and the index could be reused outside of the `distilabel` pipeline, so it would be a shame not to save it.

For this reason, `Step`s have a `save_artifact` method for saving artifacts, which are then included alongside the outputs of the pipeline in the generated [`Distiset`][distilabel.distiset.Distiset]. The generated artifacts are uploaded when calling `Distiset.push_to_hub` and written to disk when calling `Distiset.save_to_disk`. Let's see how to use it with a simple example.

```python
from typing import List, TYPE_CHECKING
from distilabel.steps import GlobalStep, StepInput, StepOutput
import matplotlib.pyplot as plt

if TYPE_CHECKING:
    from distilabel.steps import StepOutput


class CountTextCharacters(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        return ["text_character_count"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        character_counts = []

        for input in inputs:
            text_character_count = len(input["text"])
            input["text_character_count"] = text_character_count
            character_counts.append(text_character_count)

        # Generate plot with the distribution of text character counts
        plt.figure(figsize=(10, 6))
        plt.hist(character_counts, bins=30, edgecolor="black")
        plt.title("Distribution of Text Character Counts")
        plt.xlabel("Character Count")
        plt.ylabel("Frequency")

        # Save the plot as an artifact of the step
        self.save_artifact(
            name="text_character_count_distribution",
            write_function=lambda path: plt.savefig(path / "figure.png"),
            metadata={"type": "image", "library": "matplotlib"},
        )

        plt.close()

        yield inputs
```

As the example above shows, we have created a simple step that counts the number of characters in each input text and generates a histogram of the character counts. We save the histogram as an artifact of the step using the `save_artifact` method. The method takes three arguments:

- `name`: The name we want to give to the artifact.
- `write_function`: A function that writes the artifact to the desired path. The function will receive a `path` argument which is a `pathlib.Path` object pointing to the directory where the artifact should be saved.
- `metadata`: A dictionary with metadata about the artifact. This metadata will be saved along with the artifact.
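
To make the contract between `save_artifact` and `write_function` more concrete, here is a minimal standard-library sketch. It is not distilabel's implementation: the helper `save_artifact_sketch`, its `base_dir` parameter and the `metadata.json` filename are assumptions made only for illustration. It shows the idea described in the list above: a directory is created for the named artifact, `write_function` receives that directory as a `pathlib.Path`, and the metadata is stored next to whatever files the function wrote.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional


def save_artifact_sketch(
    base_dir: Path,
    name: str,
    write_function: Callable[[Path], None],
    metadata: Optional[Dict[str, Any]] = None,
) -> Path:
    """Hypothetical stand-in for `Step.save_artifact` (illustration only)."""
    # One directory per artifact, named after the artifact.
    artifact_dir = base_dir / name
    artifact_dir.mkdir(parents=True, exist_ok=True)

    # The caller decides which files to write inside that directory.
    write_function(artifact_dir)

    # Assumption: metadata is stored as JSON next to the artifact files.
    (artifact_dir / "metadata.json").write_text(json.dumps(metadata or {}))
    return artifact_dir


def write_counts(path: Path) -> None:
    # `path` is the artifact directory, not a file path.
    (path / "counts.json").write_text(json.dumps([12, 48, 7]))


base_dir = Path(tempfile.mkdtemp())
artifact_dir = save_artifact_sketch(
    base_dir=base_dir,
    name="text_character_counts",
    write_function=write_counts,
    metadata={"type": "json"},
)

written_files = sorted(p.name for p in artifact_dir.iterdir())
print(written_files)  # ['counts.json', 'metadata.json']
```

Because `write_function` receives a directory rather than a file path, a single artifact can consist of several files, and the step stays free to choose the file names.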

Let's execute the step with a simple pipeline and push the resulting `Distiset` to the Hugging Face Hub:

??? "Example full code"

    ```python
    from typing import TYPE_CHECKING, List

    import matplotlib.pyplot as plt
    from datasets import load_dataset
    from distilabel.pipeline import Pipeline
    from distilabel.steps import GlobalStep, StepInput, StepOutput

    if TYPE_CHECKING:
        from distilabel.steps import StepOutput


    class CountTextCharacters(GlobalStep):
        @property
        def inputs(self) -> List[str]:
            return ["text"]

        @property
        def outputs(self) -> List[str]:
            return ["text_character_count"]

        def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
            character_counts = []

            for input in inputs:
                text_character_count = len(input["text"])
                input["text_character_count"] = text_character_count
                character_counts.append(text_character_count)

            # Generate plot with the distribution of text character counts
            plt.figure(figsize=(10, 6))
            plt.hist(character_counts, bins=30, edgecolor="black")
            plt.title("Distribution of Text Character Counts")
            plt.xlabel("Character Count")
            plt.ylabel("Frequency")

            # Save the plot as an artifact of the step
            self.save_artifact(
                name="text_character_count_distribution",
                write_function=lambda path: plt.savefig(path / "figure.png"),
                metadata={"type": "image", "library": "matplotlib"},
            )

            plt.close()

            yield inputs


    with Pipeline() as pipeline:
        count_text_characters = CountTextCharacters()

    if __name__ == "__main__":
        distiset = pipeline.run(
            dataset=load_dataset(
                "HuggingFaceH4/instruction-dataset", split="test"
            ).rename_column("prompt", "text"),
        )

        distiset.push_to_hub("distilabel-internal-testing/distilabel-artifacts-example")
    ```

The generated [distilabel-internal-testing/distilabel-artifacts-example](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example) dataset repository has a section in its card [describing the artifacts generated by the pipeline](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example#artifacts) and the generated plot can be seen [here](https://huggingface.co/datasets/distilabel-internal-testing/distilabel-artifacts-example/blob/main/artifacts/count_text_characters_0/text_character_count_distribution/figure.png).
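
The plot link above suggests a layout of `artifacts/<step name>/<artifact name>/<files>`. As a rough sketch, assuming a local copy of the repository mirrors that layout, the artifacts could be enumerated with the standard library alone. The helper `list_artifacts` is hypothetical and not part of distilabel, and the directory tree is built by the snippet itself so that it runs without any download.

```python
import tempfile
from pathlib import Path
from typing import Dict, List


def list_artifacts(artifacts_root: Path) -> Dict[str, List[str]]:
    """Map "<step>/<artifact>" to the sorted file names stored for it."""
    found: Dict[str, List[str]] = {}
    for step_dir in sorted(p for p in artifacts_root.iterdir() if p.is_dir()):
        for artifact_dir in sorted(p for p in step_dir.iterdir() if p.is_dir()):
            key = f"{step_dir.name}/{artifact_dir.name}"
            found[key] = sorted(
                f.name for f in artifact_dir.iterdir() if f.is_file()
            )
    return found


# Build a throwaway tree that mimics the layout seen in the link above.
root = Path(tempfile.mkdtemp()) / "artifacts"
plot_dir = root / "count_text_characters_0" / "text_character_count_distribution"
plot_dir.mkdir(parents=True)
(plot_dir / "figure.png").write_bytes(b"")  # placeholder, not a real image

artifacts = list_artifacts(root)
print(artifacts)
```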
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -7,6 +7,7 @@ site_description: Distilabel is an AI Feedback (AIF) framework for building data
# Repository
repo_name: argilla-io/distilabel
repo_url: https://github.com/argilla-io/distilabel
edit_uri: edit/main/docs/

extra:
  version:
@@ -179,6 +180,7 @@ nav:
- Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
- Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md"
- Assigning resources to a step: "sections/how_to_guides/advanced/assigning_resources_to_step.md"
- Saving step generated artifacts: "sections/how_to_guides/advanced/saving_step_generated_artifacts.md"
- Serving an LLM for sharing it between several tasks: "sections/how_to_guides/advanced/serving_an_llm_for_reuse.md"
- Scaling and distributing a pipeline with Ray: "sections/how_to_guides/advanced/scaling_with_ray.md"
- Pipeline Samples:
10 changes: 10 additions & 0 deletions src/distilabel/constants.py
@@ -25,6 +25,16 @@
CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent"

# Data paths constants
STEPS_OUTPUTS_PATH = "steps_outputs"
STEPS_ARTIFACTS_PATH = "steps_artifacts"

# Distiset related constants
DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs"
DISTISET_ARTIFACTS_FOLDER: Final[str] = "artifacts"
PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml"
PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log"


__all__ = [
    "STEP_ATTR_NAME",