diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml new file mode 100644 index 0000000000..6da0a611be --- /dev/null +++ b/.github/workflows/codspeed.yml @@ -0,0 +1,42 @@ +name: Benchmarks + +on: + push: + branches: + - "main" + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + benchmarks: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + # Looks like it's not working very well for other people: + # https://github.com/actions/setup-python/issues/436 + # cache: "pip" + # cache-dependency-path: pyproject.toml + + - uses: actions/cache@v3 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-benchmarks-v00 + + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: ./scripts/install_dependencies.sh + + - name: Run benchmarks + uses: CodSpeedHQ/action@v2 + with: + token: ${{ secrets.CODSPEED_TOKEN }} + run: pytest tests/ --codspeed diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 01f1ebcb9a..b80e88d2e8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,12 @@ on: types: - opened - synchronize + workflow_dispatch: + inputs: + tmate_session: + description: Starts the workflow with tmate enabled. + required: false + default: "false" concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -19,7 +25,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] fail-fast: false steps: @@ -42,14 +48,7 @@ jobs: - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' - run: | - python_version=$(python -c "import sys; print(sys.version_info[:2])") - - pip install -e .[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,vllm] - if [ "${python_version}" != "(3, 8)" ]; then - pip install -e .[mistralai] - fi; - pip install git+https://github.com/argilla-io/LLM-Blender.git + run: ./scripts/install_dependencies.sh - name: Lint run: make lint @@ -59,4 +58,3 @@ jobs: - name: Integration Tests run: make integration-tests - timeout-minutes: 5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 59ac8caa8c..24ab3d19a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,11 +11,10 @@ repos: - --fuzzy-match-generates-todo - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.1.4 + rev: v0.4.5 hooks: - id: ruff - args: - - --fix + args: [--fix] - id: ruff-format ci: diff --git a/Makefile b/Makefile index 9634d016f7..16cc0c92c7 100644 --- a/Makefile +++ b/Makefile @@ -2,12 +2,12 @@ sources = src/distilabel tests .PHONY: format format: - ruff --fix $(sources) + ruff check --fix $(sources) ruff format $(sources) .PHONY: lint lint: - ruff $(sources) + ruff check $(sources) ruff format --check $(sources) .PHONY: unit-tests diff --git a/docs/assets/images/sections/examples/knowledge-graph-example.png b/docs/assets/images/sections/examples/knowledge-graph-example.png new file mode 100644 index 0000000000..05c80c6caf Binary files /dev/null and b/docs/assets/images/sections/examples/knowledge-graph-example.png differ diff --git 
a/docs/sections/learn/advanced/distiset.md b/docs/sections/learn/advanced/distiset.md index 17be198ae3..199b607fa5 100644 --- a/docs/sections/learn/advanced/distiset.md +++ b/docs/sections/learn/advanced/distiset.md @@ -70,6 +70,44 @@ distiset.push_to_hub( ) ``` +### Save and load from disk + +Save the [`Distiset`][distilabel.distiset.Distiset] to disk, optionally (and by default) saving the dataset card, the pipeline config file and the logs as well: + +```python +distiset.save_to_disk( + "my-dataset", + save_card=True, + save_pipeline_config=True, + save_pipeline_log=True +) +``` + +A [`Distiset`][distilabel.distiset.Distiset] that was saved using [`Distiset.save_to_disk`][distilabel.distiset.Distiset.save_to_disk] can be loaded back just the same way: + +```python +from distilabel.distiset import Distiset + +distiset = Distiset.load_from_disk("my-dataset") +``` + +or from your cloud provider if that's where it was stored: + +```python +distiset = Distiset.load_from_disk( + "s3://path/to/my_dataset", # gcs:// or any filesystem tolerated by fsspec + storage_options={ + "key": os.environ["S3_ACCESS_KEY"], + "secret": os.environ["S3_SECRET_KEY"], + ... + } +) +``` + +Note that these methods mirror `datasets.load_from_disk` and `datasets.Dataset.save_to_disk`, so the arguments are passed directly to those methods. This means you can also use the `storage_options` argument to save your [`Distiset`][distilabel.distiset.Distiset] in your cloud provider, including the distilabel artifacts (`pipeline.yaml`, `pipeline.log` and the `README.md` with the dataset card). You can read more in the `datasets` documentation [here](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets). + +Take a look at the remaining arguments of [`Distiset.save_to_disk`][distilabel.distiset.Distiset.save_to_disk] and [`Distiset.load_from_disk`][distilabel.distiset.Distiset.load_from_disk]. + ## Dataset card Having this special type of dataset comes with an added advantage when calling [`Distiset.push_to_hub`][distilabel.distiset.Distiset], which is the automatically generated dataset card in the Hugging Face Hub. Note that it is enabled by default, but can be disabled by setting `generate_card=False`: diff --git a/docs/sections/learn/advanced/fs_to_pass_data.md b/docs/sections/learn/advanced/fs_to_pass_data.md new file mode 100644 index 0000000000..2851c3bc3c --- /dev/null +++ b/docs/sections/learn/advanced/fs_to_pass_data.md @@ -0,0 +1,24 @@ +# Using a file system to pass data of batches between steps + +In some situations, a batch can contain so much data that it is faster to write it to disk and read it back in the next step than to pass it through the queue. To solve this, `distilabel` uses [`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/): the `run` method of the `distilabel` pipelines accepts a file system configuration and a flag indicating whether that file system should be used to pass data between steps: + +```python +from distilabel.pipeline import Pipeline + +with Pipeline(name="my-pipeline") as pipeline: + ...
+ +if __name__ == "__main__": + distiset = pipeline.run( + ..., + storage_parameters={"protocol": "gcs", "path": "gcs://my-bucket"}, + use_fs_to_pass_data=True + ) +``` + +The code above sets up a file system (in this case Google Cloud Storage) and sets the `use_fs_to_pass_data` flag to specify that the data of the batches should be passed to the steps using the file system. The `storage_parameters` argument is optional; if it's not provided but `use_fs_to_pass_data==True`, `distilabel` will use the local file system. + +!!! NOTE + + As `GlobalStep`s receive all the data from the previous steps accumulated in one single batch, it's very likely that this batch will be too big to be passed using the queue. In this case, even if `use_fs_to_pass_data==False`, `distilabel` will use the file system to pass the data to the `GlobalStep`. + diff --git a/docs/sections/learn/advanced/structured_generation.md b/docs/sections/learn/advanced/structured_generation.md index c0ba743ad4..579427434d 100644 --- a/docs/sections/learn/advanced/structured_generation.md +++ b/docs/sections/learn/advanced/structured_generation.md @@ -8,6 +8,14 @@ The [`LLM`][distilabel.llms.LLM] has an argument named `structured_output`[^1] that determines how we can generate structured outputs with it; let's see an example using [`LlamaCppLLM`][distilabel.llms.LlamaCppLLM]. +!!! Note + + For the `outlines` integration to work, you may need to install the corresponding dependencies: + + ```bash + pip install distilabel[outlines] + ``` + ### JSON We will start with a JSON example, where we initially define a `pydantic.BaseModel` schema to guide the generation of the structured output. @@ -101,7 +109,7 @@ if match: These were some simple examples, but one can see the options this opens. -!!! NOTE +!!! Tip A full pipeline example can be seen in the following script: [`examples/structured_generation_with_outlines.py`](../../pipeline_samples/examples/index.md#llama-cpp-with-outlines) @@ -119,6 +127,72 @@ These were some simple examples, but one can see the options this opens. curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf ``` +## Instructor + +When working with model providers behind an API, there's no direct way of accessing the internal logit processor the way `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can still generate structured output from LLM providers. We have integrated `instructor` to work with the [`AsyncLLM`][distilabel.llms.AsyncLLM], so you can use the following LLMs: [`OpenAILLM`][distilabel.llms.OpenAILLM], [`AzureOpenAILLM`][distilabel.llms.AzureOpenAILLM], [`CohereLLM`][distilabel.llms.CohereLLM], [`GroqLLM`][distilabel.llms.GroqLLM], [`LiteLLM`][distilabel.llms.LiteLLM] and [`MistralLLM`][distilabel.llms.MistralLLM]. + +`instructor` works with `pydantic.BaseModel` objects internally, but in `distilabel` the generated examples are returned as their string representation, from which the `BaseModel` object can be regenerated. + +!!! Note + For the `instructor` integration to work, you may need to install the corresponding dependencies: + + ```bash + pip install distilabel[instructor] + ``` + +!!! Note + Take a look at [`InstructorStructuredOutputType`][distilabel.steps.tasks.structured_outputs.instructor.InstructorStructuredOutputType] to see the expected format + of the `structured_output` dict variable.
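For reference, the `structured_output` value used with the `instructor` integration is a plain dict. The snippet below is an illustrative sketch only (the authoritative list of keys and accepted values is the `InstructorStructuredOutputType` referenced above), assuming a `User` model like the one defined in the example that follows:

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    last_name: str
    id: int


# Sketch of a `structured_output` configuration for `instructor`:
# "schema" is the pydantic model to enforce, "max_retries" is how many times
# `instructor` retries when validation fails. An optional "mode" key can also
# be provided (see `InstructorStructuredOutputType` for the accepted values).
structured_output = {
    "schema": User,
    "max_retries": 1,
}
```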
+ +The following is the same example shown in the `outlines` `JSON` section, included here for comparison purposes. + +```python +from pydantic import BaseModel + +class User(BaseModel): + name: str + last_name: str + id: int +``` + +And then we provide that schema to the `structured_output` argument of the LLM: + +!!! Note + In this example we are using *open-mixtral-8x22b*; keep in mind that not all models support the function calling functionality required for this example to work. + +```python +from distilabel.llms import MistralLLM + +llm = MistralLLM( + model="open-mixtral-8x22b", + structured_output={"schema": User} +) +llm.load() +``` + +And we are ready to pass our instruction as usual: + +```python +import json + +result = llm.generate( + [[{"role": "user", "content": "Create a user profile for the following marathon"}]], + max_new_tokens=256 +) + +data = json.loads(result[0][0]) +data +# {'name': 'John', 'last_name': 'Doe', 'id': 12345} +User(**data) +# User(name='John', last_name='Doe', id=12345) +``` + +We get back a Python dictionary (formatted as a string) that we can parse using `json.loads`, or validate directly with `User`, which is a `pydantic.BaseModel` subclass. + +!!! Tip + A full pipeline example can be seen in the following script: + [`examples/structured_generation_with_instructor.py`](../../pipeline_samples/examples/index.md#mistralai-with-instructor) + ## OpenAI JSON OpenAI offers a [JSON Mode](https://platform.openai.com/docs/guides/text-generation/json-mode) to deal with structured output via their API; let's see how to make use of it. The JSON mode instructs the model to always return a JSON object following the required instruction. diff --git a/docs/sections/pipeline_samples/examples/index.md b/docs/sections/pipeline_samples/examples/index.md index c07a7f1737..aa74004357 100644 --- a/docs/sections/pipeline_samples/examples/index.md +++ b/docs/sections/pipeline_samples/examples/index.md @@ -2,7 +2,7 @@ This section contains different example pipelines that showcase different tasks; maybe you can take inspiration from them. -### [llama.cpp with outlines](#llama-cpp-with-outlines) +### [llama.cpp with `outlines`](#llama-cpp-with-outlines) Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`. @@ -21,3 +21,42 @@ Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `dis ```python title="structured_generation_with_outlines.py" --8<-- "examples/structured_generation_with_outlines.py" ``` + + +### [MistralAI with `instructor`](#mistralai-with-instructor) + +Answer instructions with knowledge graphs defined as `pydantic.BaseModel` objects using `instructor` in `distilabel`. + +??? Example "See example" + + This script makes use of [`MistralLLM`][distilabel.llms.mistral.MistralLLM] and its structured output capabilities, thanks to [`instructor`](https://python.useinstructor.com/), to generate knowledge graphs from complex topics. + + This example is translated from this [awesome example](https://python.useinstructor.com/examples/knowledge_graph/) from the `instructor` cookbook. + + ??? Run + + ```python + python examples/structured_generation_with_instructor.py + ``` + + ```python title="structured_generation_with_instructor.py" + --8<-- "examples/structured_generation_with_instructor.py" + ``` + + ??? "Visualizing the graphs" + + Want to see how to visualize the graphs? You can test it using the following script. Generate some samples on your own and take a look: + + !!!
NOTE + + This example uses graphviz to render the graph, you can install with `pip` in the following way: + + ```console + pip install graphviz + ``` + + ```python + python examples/draw_kg.py 2 # You can pass 0,1,2 to visualize each of the samples. + ``` + + ![Knowledge graph figure](../../../assets/images/sections/examples/knowledge-graph-example.png) diff --git a/examples/draw_kg.py b/examples/draw_kg.py new file mode 100644 index 0000000000..8d45e40b85 --- /dev/null +++ b/examples/draw_kg.py @@ -0,0 +1,82 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Dict, List, Union + +from graphviz import Digraph +from pydantic import BaseModel, Field + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = "black" + + +class KnowledgeGraph(BaseModel): + nodes: List[Node] = Field(..., default_factory=list) + edges: List[Edge] = Field(..., default_factory=list) + + +def visualize_knowledge_graph(kg: KnowledgeGraph): + dot = Digraph(comment="Knowledge Graph") + + # Add nodes + for node in kg.nodes: + dot.node(str(node.id), node.label, color=node.color) + + # Add edges + for edge in kg.edges: + dot.edge( + str(edge.source), + str(edge.target), + label=edge.label, + color=edge.color or "black", + ) + + # Render the graph + dot.render("knowledge_graph.gv", view=True) + + +def create_knowledge_graph(data: str) -> Union[KnowledgeGraph, None]: + data: Dict[str, Any] = json.loads(data) + + nodes = [Node(**node) for node in data["nodes"]] + edges = [] + for edge in data["edges"]: + if edge.get("color") is None: + edge["color"] = "black" + edges.append(Edge(**edge)) + + return KnowledgeGraph(nodes=nodes, edges=edges) + + +if __name__ == "__main__": + import sys + + args = sys.argv[1:] + + from datasets import load_dataset + + ds = load_dataset("distilabel-internal-testing/knowledge_graphs", split="train") + graphs = [create_knowledge_graph(g) for g in ds["generation"]] + visualize_knowledge_graph(graphs[int(args[0])]) diff --git a/examples/structured_generation_with_instructor.py b/examples/structured_generation_with_instructor.py new file mode 100644 index 0000000000..48082886f4 --- /dev/null +++ b/examples/structured_generation_with_instructor.py @@ -0,0 +1,87 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +from distilabel.llms import MistralLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromDicts +from distilabel.steps.tasks import TextGeneration +from pydantic import BaseModel, Field + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = "black" + + +class KnowledgeGraph(BaseModel): + nodes: List[Node] = Field(..., default_factory=list) + edges: List[Edge] = Field(..., default_factory=list) + + +with Pipeline( + name="Knowledge-Graphs", + description=( + "Generate knowledge graphs to answer questions, this type of dataset can be used to " + "steer a model to answer questions with a knowledge graph." + ), +) as pipeline: + sample_questions = [ + "Teach me about quantum mechanics", + "Who is who in The Simpsons family?", + "Tell me about the evolution of programming languages", + ] + + load_dataset = LoadDataFromDicts( + name="load_instructions", + data=[ + { + "system_prompt": "You are a knowledge graph expert generator. Help me understand by describing everything as a detailed knowledge graph.", + "instruction": f"{question}", + } + for question in sample_questions + ], + ) + + text_generation = TextGeneration( + name="knowledge_graph_generation", + llm=MistralLLM( + model="open-mixtral-8x22b", structured_output={"schema": KnowledgeGraph} + ), + input_batch_size=8, + output_mappings={"model_name": "generation_model"}, + ) + load_dataset >> text_generation + + +if __name__ == "__main__": + distiset = pipeline.run( + parameters={ + text_generation.name: { + "llm": {"generation_kwargs": {"max_new_tokens": 2048}} + } + }, + use_cache=False, + ) + + distiset.push_to_hub("distilabel-internal-testing/knowledge_graphs") diff --git a/mkdocs.yml b/mkdocs.yml index d525d1828e..6bf897561b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -142,6 +142,7 @@ nav: - Caching: "sections/learn/advanced/caching.md" - Distiset: "sections/learn/advanced/distiset.md" - Structured Generation: "sections/learn/advanced/structured_generation.md" + - Using the file system to pass batch data: "sections/learn/advanced/fs_to_pass_data.md" - Pipeline Samples: - "sections/pipeline_samples/index.md" - Examples: "sections/pipeline_samples/examples/index.md" @@ -164,9 +165,9 @@ nav: - GlobalStep: "api/step/global_step.md" - "@step": "api/step/decorator.md" - Step Gallery: - - Argilla: "api/step_gallery/argilla.md" - - Columns: "api/step_gallery/columns.md" - - Extra: "api/step_gallery/extra.md" + - Argilla: "api/step_gallery/argilla.md" + - Columns: "api/step_gallery/columns.md" + - Extra: "api/step_gallery/extra.md" - Task: - "api/task/index.md" - GeneratorTask: "api/task/generator_task.md" diff --git a/pyproject.toml b/pyproject.toml index e22c80a16e..80fe4714ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] @@ -34,6 +35,7 @@ dependencies = [ "typer >= 0.9.0", "tblib >= 3.0.0", "orjson >= 3.10.0", + "universal_pathlib >= 0.2.2", ] dynamic = ["version"] @@ -44,7 +46,7 @@ distilabel = "distilabel.cli.app:app" "distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin" [project.optional-dependencies] -dev = 
["ruff == 0.2.2", "pre-commit >= 3.5.0"] +dev = ["ruff == 0.4.5", "pre-commit >= 3.5.0"] docs = [ "mkdocs-material >= 9.5.0", "mkdocstrings[python] >= 0.24.0", @@ -56,15 +58,22 @@ docs = [ "CairoSVG >= 2.7.1", "mknotebooks >= 0.8.0", ] -tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio", "pytest-timeout"] +tests = [ + "pytest >= 7.4.0", + "pytest-asyncio", + "nest-asyncio", + "pytest-timeout", + "pytest-codspeed", +] # Optional LLMs, integrations, etc anthropic = ["anthropic >= 0.20.0"] -argilla = ["argilla >= 1.23.0"] +argilla = ["argilla >= 1.29.0"] cohere = ["cohere >= 5.2.0"] groq = ["groq >= 0.4.1"] hf-inference-endpoints = ["huggingface_hub >= 0.19.0"] hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"] +instructor = ["instructor >= 1.2.3"] litellm = ["litellm >= 1.30.0"] llama-cpp = ["llama-cpp-python >= 0.2.0"] mistralai = ["mistralai >= 0.1.0"] @@ -72,7 +81,7 @@ ollama = ["ollama >= 0.1.7"] openai = ["openai >= 1.0.0"] outlines = ["outlines >= 0.0.40"] vertexai = ["google-cloud-aiplatform >= 1.38.0"] -vllm = ["vllm >= 0.2.1", "filelock >= 3.13.4"] +vllm = ["vllm >= 0.4.0", "outlines == 0.0.34", "filelock >= 3.13.4"] [project.urls] Documentation = "https://distilabel.argilla.io/" diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh new file mode 100755 index 0000000000..9344ac472c --- /dev/null +++ b/scripts/install_dependencies.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +python_version=$(python -c "import sys; print(sys.version_info[:2])") + +python -m pip install uv + +uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai]" +if [ "${python_version}" != "(3, 8)" ]; then + uv pip install --system -e .[mistralai,instructor] +fi + +uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 98c11009f2..3538bccd37 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -13,13 +13,20 @@ # limitations under the License. import logging +import os.path as posixpath import re +from os import PathLike from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Final, Optional, Union -from datasets import load_dataset +import fsspec +import yaml +from datasets import Dataset, load_dataset, load_from_disk +from datasets.filesystems import is_remote_filesystem from huggingface_hub import DatasetCardData, HfApi +from huggingface_hub.file_download import hf_hub_download from pyarrow.lib import ArrowInvalid +from typing_extensions import Self from distilabel.utils.card.dataset_card import ( DistilabelDatasetCard, @@ -27,6 +34,10 @@ ) from distilabel.utils.files import list_files_in_dir +DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs" +PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml" +PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log" + class Distiset(dict): """Convenient wrapper around `datasets.Dataset` to push to the Hugging Face Hub. @@ -40,8 +51,8 @@ class Distiset(dict): pipeline. 
""" - pipeline_path: Optional[Path] = None - log_filename_path: Optional[Path] = None + _pipeline_path: Optional[Path] = None + _log_filename_path: Optional[Path] = None def push_to_hub( self, @@ -83,14 +94,23 @@ def push_to_hub( if generate_card: self._generate_card(repo_id, token) - def _generate_card(self, repo_id: str, token: Optional[str]) -> None: - """Generates a dataset card and pushes it to the Hugging Face Hub, and - if the `pipeline.yaml` path is available in the `Distiset`, uploads that - to the same repository. + def _get_card( + self, repo_id: str, token: Optional[str] = None + ) -> DistilabelDatasetCard: + """Generates the dataset card for the `Distiset`. + + Note: + If `repo_id` and `token` are provided, it will extract the metadata from the README.md file + on the hub. Args: - repo_id: The ID of the repository to push to, from the `push_to_hub` method. - token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method. + repo_id: Name of the repository to push to, or the path for the distiset if saved to disk. + token: The token to authenticate with the Hugging Face Hub. + We assume that if it's provided, the dataset will be in the Hugging Face Hub, + so the README metadata will be extracted from there. + + Returns: + The dataset card for the `Distiset`. """ sample_records = {} for name, dataset in self.items(): @@ -98,8 +118,12 @@ def _generate_card(self, repo_id: str, token: Optional[str]) -> None: dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + readme_metadata = {} + if repo_id and token: + readme_metadata = self._extract_readme_metadata(repo_id, token) + metadata = { - **self._extract_readme_metadata(repo_id, token), + **readme_metadata, "size_categories": size_categories_parser( max(len(dataset) for dataset in self.values()) ), @@ -111,29 +135,8 @@ def _generate_card(self, repo_id: str, token: Optional[str]) -> None: repo_id=repo_id, sample_records=sample_records, ) - card.push_to_hub( - repo_id, - repo_type="dataset", - token=token, - ) - if self.pipeline_path: - # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well. - HfApi().upload_file( - path_or_fileobj=self.pipeline_path, - path_in_repo="pipeline.yaml", - repo_id=repo_id, - repo_type="dataset", - token=token, - ) - if self.log_filename_path: - # The same we had with "pipeline.yaml" but with the log file. - HfApi().upload_file( - path_or_fileobj=self.log_filename_path, - path_in_repo="pipeline.log", - repo_id=repo_id, - repo_type="dataset", - token=token, - ) + + return card def _extract_readme_metadata( self, repo_id: str, token: Optional[str] @@ -150,11 +153,6 @@ def _extract_readme_metadata( Returns: The metadata extracted from the README.md file of the dataset repository as a dict. """ - import re - - import yaml - from huggingface_hub.file_download import hf_hub_download - readme_path = Path( hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token) ) @@ -163,12 +161,47 @@ def _extract_readme_metadata( metadata = yaml.safe_load(metadata) return metadata + def _generate_card(self, repo_id: str, token: str) -> None: + """Generates a dataset card and pushes it to the Hugging Face Hub, and + if the `pipeline.yaml` path is available in the `Distiset`, uploads that + to the same repository. + + Args: + repo_id: The ID of the repository to push to, from the `push_to_hub` method. + token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method. 
+ """ + card = self._get_card(repo_id=repo_id, token=token) + + card.push_to_hub( + repo_id, + repo_type="dataset", + token=token, + ) + if self.pipeline_path: + # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well. + HfApi().upload_file( + path_or_fileobj=self.pipeline_path, + path_in_repo=PIPELINE_CONFIG_FILENAME, + repo_id=repo_id, + repo_type="dataset", + token=token, + ) + if self.log_filename_path: + # The same we had with "pipeline.yaml" but with the log file. + HfApi().upload_file( + path_or_fileobj=self.log_filename_path, + path_in_repo=PIPELINE_LOG_FILENAME, + repo_id=repo_id, + repo_type="dataset", + token=token, + ) + def train_test_split( self, train_size: float, shuffle: bool = True, seed: Optional[int] = None, - ) -> "Distiset": + ) -> Self: """Return a `Distiset` whose values will be a `datasets.DatasetDict` with two random train and test subsets. Splits are created from the dataset according to `train_size` and `shuffle`. @@ -192,6 +225,198 @@ def train_test_split( ) return self + def save_to_disk( + self, + distiset_path: PathLike, + max_shard_size: Optional[Union[str, int]] = None, + num_shards: Optional[int] = None, + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + save_card: bool = True, + save_pipeline_config: bool = True, + save_pipeline_log: bool = True, + ) -> None: + r""" + Saves a `Distiset` to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. + + In case you want to save the `Distiset` in a remote filesystem, you can pass the `storage_options` parameter + as you would do with `datasets`'s `Dataset.save_to_disk` method: [see example](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets) + + Args: + distiset_path: Path where you want to save the `Distiset`. It can be a local path + (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) + max_shard_size: The maximum size of the dataset shards to be uploaded to the hub. + If expressed as a string, needs to be digits followed by a unit (like `"50MB"`). + Defaults to `None`. + num_shards: Number of shards to write. By default the number of shards depends on + `max_shard_size` and `num_proc`. Defaults to `None`. + num_proc: Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. Defaults to `None`. + storage_options: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + save_card: Whether to save the dataset card. Defaults to `True`. + save_pipeline_config: Whether to save the pipeline configuration file (aka the `pipeline.yaml` file). + Defaults to `True`. + save_pipeline_log: Whether to save the pipeline log file (aka the `pipeline.log` file). + Defaults to `True`. + + Examples: + ```python + # Save your distiset in a local folder: + >>> distiset.save_to_disk(dataset_path="my-distiset") + # Save your distiset in a remote storage: + >>> storage_options = { + ... "key": os.environ["S3_ACCESS_KEY"], + ... "secret": os.environ["S3_SECRET_KEY"], + ... "client_kwargs": { + ... "endpoint_url": os.environ["S3_ENDPOINT_URL"], + ... "region_name": os.environ["S3_REGION"], + ... }, + ... 
} + >>> distiset.save_to_disk(dataset_path="my-distiset", storage_options=storage_options) + ``` + """ + distiset_path = str(distiset_path) + for name, dataset in self.items(): + dataset.save_to_disk( + f"{distiset_path}/{name}", + max_shard_size=max_shard_size, + num_shards=num_shards, + num_proc=num_proc, + storage_options=storage_options, + ) + + distiset_config_folder = posixpath.join(distiset_path, DISTISET_CONFIG_FOLDER) + + fs: fsspec.AbstractFileSystem + fs, _, _ = fsspec.get_fs_token_paths( + distiset_config_folder, storage_options=storage_options + ) + fs.makedirs(distiset_config_folder, exist_ok=True) + + if save_card: + # NOTE: Currently the card is not the same if we write to disk or push to the HF hub, + # as we aren't generating the README copying/updating the data from the dataset repo. + card = self._get_card(repo_id=Path(distiset_path).stem, token=None) + new_filename = posixpath.join(distiset_config_folder, "README.md") + if storage_options: + # Write the card the same way as DatasetCard.save does: + with fs.open(new_filename, "w", newline="", encoding="utf-8") as f: + f.write(str(card)) + else: + card.save(new_filename) + + # Write our internal files to the distiset folder by copying them to the distiset folder. + if save_pipeline_config and self.pipeline_path: + new_filename = posixpath.join( + distiset_config_folder, PIPELINE_CONFIG_FILENAME + ) + if self.pipeline_path.exists() and (not fs.isfile(new_filename)): + data = yaml.safe_load(self.pipeline_path.read_text()) + with fs.open(new_filename, "w", encoding="utf-8") as f: + yaml.dump(data, f, default_flow_style=False) + + if save_pipeline_log and self.log_filename_path: + new_filename = posixpath.join(distiset_config_folder, PIPELINE_LOG_FILENAME) + if self.log_filename_path.exists() and (not fs.isfile(new_filename)): + data = self.log_filename_path.read_text() + with fs.open(new_filename, "w", encoding="utf-8") as f: + f.write(data) + + @classmethod + def load_from_disk( + cls, + distiset_path: PathLike, + keep_in_memory: Optional[bool] = None, + storage_options: Optional[Dict[str, Any]] = None, + download_dir: Optional[PathLike] = None, + ) -> Self: + """Loads a dataset that was previously saved using `Distiset.save_to_disk` from a dataset + directory, or from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. + + Args: + distiset_path: Path ("dataset/train") or remote URI ("s3://bucket/dataset/train"). + keep_in_memory: Whether to copy the dataset in-memory, see `datasets.Dataset.load_from_disk`` + for more information. Defaults to `None`. + storage_options: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + download_dir: Optional directory to download the dataset to. Defaults to None, + in which case it will create a temporary directory. + + Returns: + A `Distiset` loaded from disk, it should be a `Distiset` object created using `Distiset.save_to_disk`. + """ + original_distiset_path = str(distiset_path) + + fs: fsspec.AbstractFileSystem + fs, _, [distiset_path] = fsspec.get_fs_token_paths( + original_distiset_path, storage_options=storage_options + ) + dest_distiset_path = distiset_path + + assert fs.isdir( + original_distiset_path + ), "`distiset_path` must be a `PathLike` object pointing to a folder or a URI of a remote filesystem." 
+ + has_config = False + distiset = cls() + + if is_remote_filesystem(fs): + src_dataset_path = distiset_path + if download_dir: + dest_distiset_path = download_dir + else: + dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path) + fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True) + + # Now we should have the distiset locally, so we can read those files + for folder in Path(dest_distiset_path).iterdir(): + if folder.stem == DISTISET_CONFIG_FOLDER: + has_config = True + continue + distiset[folder.stem] = load_from_disk( + str(folder), + keep_in_memory=keep_in_memory, + ) + # From the config folder we just need to point to the files. Once downloaded we set the path + # to wherever they are. + if has_config: + distiset_config_folder = posixpath.join( + dest_distiset_path, DISTISET_CONFIG_FOLDER + ) + + pipeline_path = posixpath.join( + distiset_config_folder, PIPELINE_CONFIG_FILENAME + ) + if Path(pipeline_path).exists(): + distiset.pipeline_path = Path(pipeline_path) + + log_filename_path = posixpath.join( + distiset_config_folder, PIPELINE_LOG_FILENAME + ) + if Path(log_filename_path).exists(): + distiset.log_filename_path = Path(log_filename_path) + + return distiset + + @property + def pipeline_path(self) -> Union[Path, None]: + """Returns the path to the `pipeline.yaml` file that generated the `Pipeline`.""" + return self._pipeline_path + + @pipeline_path.setter + def pipeline_path(self, path: PathLike) -> None: + self._pipeline_path = Path(path) + + @property + def log_filename_path(self) -> Union[Path, None]: + """Returns the path to the `pipeline.log` file that generated the `Pipeline`.""" + return self._log_filename_path + + @log_filename_path.setter + def log_filename_path(self, path: PathLike) -> None: + self._log_filename_path = Path(path) + def __repr__(self): # Copy from `datasets.DatasetDict.__repr__`. repr = "\n".join([f"{k}: {v}" for k, v in self.items()]) @@ -207,6 +432,9 @@ def create_distiset( # noqa: C901 ) -> Distiset: """Creates a `Distiset` from the buffer folder. + This function is intended to be used as a helper to create a `Distiset` from from the folder + where the cached data was written by the `_WriteBuffer`. + Args: data_dir: Folder where the data buffers were written by the `_WriteBuffer`. It should correspond to `CacheLocation.data`. @@ -222,6 +450,13 @@ def create_distiset( # noqa: C901 Returns: The dataset created from the buffer folder, where the different leaf steps will correspond to different configurations of the dataset. + + Examples: + + ```python + >>> from pathlib import Path + >>> distiset = create_distiset(Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname") + ``` """ from distilabel.steps.constants import DISTILABEL_METADATA_KEY diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py index f472aca664..fb4f6dc03c 100644 --- a/src/distilabel/llms/anthropic.py +++ b/src/distilabel/llms/anthropic.py @@ -59,6 +59,8 @@ class AnthropicLLM(AsyncLLM): to `6`. http_client: if provided, an alternative HTTP client to use for calling Anthropic API. Defaults to `None`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. Defaults to None. _api_key_env_var: the name of the environment variable to use for the API key. It is meant to be used internally. _aclient: the `AsyncAnthropic` client to use for the Anthropic API. 
It is meant @@ -143,6 +145,15 @@ def load(self) -> None: http_client=self.http_client, max_retries=self.max_retries, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="anthropic", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output @property def model_name(self) -> str: @@ -174,22 +185,32 @@ async def agenerate( # type: ignore """ from anthropic._types import NOT_GIVEN - completion = await self._aclient.messages.create( # type: ignore - model=self.model, - system=( + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "system": ( input.pop(0)["content"] if input and input[0]["role"] == "system" else NOT_GIVEN ), - messages=input, # type: ignore - max_tokens=max_tokens, - stream=False, - stop_sequences=NOT_GIVEN if stop_sequences is None else stop_sequences, - temperature=temperature, - top_p=NOT_GIVEN if top_p is None else top_p, - top_k=NOT_GIVEN if top_k is None else top_k, - ) + "max_tokens": max_tokens, + "stream": False, + "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences, + "temperature": temperature, + "top_p": NOT_GIVEN if top_p is None else top_p, + "top_k": NOT_GIVEN if top_k is None else top_k, + } + + if self.structured_output: + kwargs = self._prepare_kwargs(kwargs, self.structured_output) + generations = [] + + completion = await self._aclient.messages.create(**kwargs) # type: ignore + if self.structured_output: + generations.append(completion.model_dump_json()) + return generations + if (content := completion.content[0].text) is None: self._logger.warning( f"Received no response using Anthropic client (model: '{self.model}')." diff --git a/src/distilabel/llms/azure.py b/src/distilabel/llms/azure.py index 58d455d65e..3fa1f7cde4 100644 --- a/src/distilabel/llms/azure.py +++ b/src/distilabel/llms/azure.py @@ -14,6 +14,7 @@ import os from typing import TYPE_CHECKING, Optional +from unittest.mock import patch from pydantic import Field, PrivateAttr, SecretStr from typing_extensions import override @@ -68,7 +69,10 @@ class AzureOpenAILLM(OpenAILLM): @override def load(self) -> None: """Loads the `AsyncAzureOpenAI` client to benefit from async requests.""" - super().load() + # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output + # in the load method before we have the proper client. 
+ with patch("OpenAILLM._prepare_structured_output", lambda x: x): + super().load() try: from openai import AsyncAzureOpenAI @@ -93,3 +97,6 @@ def load(self) -> None: max_retries=self.max_retries, # type: ignore timeout=self.timeout, ) + + if self.structured_output: + self._prepare_structured_output(self.structured_output) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index 8c9cc76a3a..1e38771bc0 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -33,6 +33,9 @@ if TYPE_CHECKING: from distilabel.llms.typing import GenerateOutput, HiddenState from distilabel.mixins.runtime_parameters import RuntimeParametersNames + from distilabel.steps.tasks.structured_outputs.instructor import ( + InstructorStructuredOutputType, + ) from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType from distilabel.steps.tasks.typing import DefaultInput, FormattedInput from distilabel.utils.docstring import Docstring @@ -302,3 +305,83 @@ def __del__(self) -> None: return if self.event_loop is not None: self.event_loop.close() + + @staticmethod + def _prepare_structured_output( + structured_output: "InstructorStructuredOutputType", + client: Any = None, + framework: Optional[str] = None, + ) -> Dict[str, Union[str, Any]]: + """Wraps the client and updates the schema to work store it internally as a json schema. + + Args: + structured_output: The configuration dict to prepare the structured output. + client: The client to wrap to generate structured output. Implemented to work + with `instructor`. + framework: The name of the framework. + + Returns: + A dictionary containing the wrapped client and the schema to update the structured_output + variable in case it is a pydantic model. + """ + from distilabel.steps.tasks.structured_outputs.instructor import ( + prepare_instructor, + ) + + result = {} + client = prepare_instructor( + client, + mode=structured_output.get("mode"), + framework=framework, + ) + result["client"] = client + + schema = structured_output.get("schema") + if not schema: + raise ValueError( + f"The `structured_output` argument must contain a schema: {structured_output}" + ) + if issubclass(schema, BaseModel): + # We want a json schema for the serialization, but instructor wants a pydantic BaseModel. + structured_output["schema"] = schema.model_json_schema() + result["structured_output"] = structured_output + + return result + + @staticmethod + def _prepare_kwargs( + arguments: Dict[str, Any], structured_output: Dict[str, Any] + ) -> Dict[str, Any]: + """Helper method to update the kwargs with the structured output configuration, + used in case they are defined. + + Args: + arguments: The arguments that would be passed to the LLM as **kwargs. + to update with the structured output configuration. + structured_outputs: The structured output configuration to update the arguments. + + Returns: + kwargs updated with the special arguments used by `instructor`. + """ + # We can deal with json schema or BaseModel, but we need to convert it to a BaseModel + # for the Instructor client. 
+ schema = structured_output.get("schema") + if not issubclass(schema, BaseModel): + from distilabel.steps.tasks.structured_outputs.utils import ( + json_schema_to_model, + ) + + try: + schema = json_schema_to_model(schema) + except Exception as e: + raise ValueError( + f"Failed to convert the schema to a pydantic model, the model is too complex currently: {e}" + ) from e + + arguments.update( + **{ + "response_model": schema, + "max_retries": structured_output.get("max_retries", 1), + }, + ) + return arguments diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index b7ecafcace..a49b203f3d 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -38,6 +38,7 @@ from distilabel.llms.typing import GenerateOutput + _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" @@ -54,6 +55,9 @@ class CohereLLM(AsyncLLM): to `120`. client_name: the name of the client to use for the API requests. Defaults to `"distilabel"`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _ChatMessage: the `ChatMessage` class from the `cohere` package. _aclient: the `AsyncClient` client from the `cohere` package. @@ -117,6 +121,16 @@ def load(self) -> None: timeout=self.timeout, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="cohere", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + def _format_chat_to_cohere( self, input: "ChatType" ) -> Tuple[Union[str, None], List["ChatMessage"], str]: @@ -192,21 +206,28 @@ async def agenerate( # type: ignore """ system, chat_history, message = self._format_chat_to_cohere(input) - response = await self._aclient.chat( # type: ignore - message=message, - model=self.model, - preamble=system, - chat_history=chat_history, - temperature=temperature, - max_tokens=max_tokens, - k=k, - p=p, - seed=seed, - stop_sequences=stop_sequences, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - raw_prompting=raw_prompting, - ) + kwargs = { + "message": message, + "model": self.model, + "preamble": system, + "chat_history": chat_history, + "temperature": temperature, + "max_tokens": max_tokens, + "k": k, + "p": p, + "seed": seed, + "stop_sequences": stop_sequences, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "raw_prompting": raw_prompting, + } + if self.structured_output: + kwargs = self._prepare_kwargs(kwargs, self.structured_output) + + response = await self._aclient.chat(**kwargs) # type: ignore + + if self.structured_output: + return response.model_dump_json() if (text := response.text) == "": self._logger.warning( diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py index f7fbda1dc8..4905f82839 100644 --- a/src/distilabel/llms/groq.py +++ b/src/distilabel/llms/groq.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from groq import AsyncGroq + _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" _GROQ_API_KEY_ENV_VAR_NAME = "GROQ_API_KEY" @@ -45,6 +46,9 @@ class GroqLLM(AsyncLLM): failing. Defaults to `2`. timeout: the maximum time in seconds to wait for a response from the API. Defaults to `120`. 
+ structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _api_key_env_var: the name of the environment variable to use for the API key. _aclient: the `AsyncGroq` client from the `groq` package. @@ -109,6 +113,16 @@ def load(self) -> None: timeout=self.timeout, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="groq", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -142,17 +156,25 @@ async def agenerate( # type: ignore References: - https://console.groq.com/docs/text-chat """ - completion = await self._aclient.chat.completions.create( # type: ignore - messages=input, # type: ignore - model=self.model, - seed=seed, # type: ignore - temperature=temperature, - max_tokens=max_new_tokens, - top_p=top_p, - stream=False, - stop=stop, - ) + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "seed": seed, + "temperature": temperature, + "max_tokens": max_new_tokens, + "top_p": top_p, + "stream": False, + "stop": stop, + } + if self.structured_output: + kwargs = self._prepare_kwargs(kwargs, self.structured_output) + generations = [] + completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore + if self.structured_output: + generations.append(completion.model_dump_json()) + return generations + for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( # type: ignore diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py index 1b0add14ac..d3660c5ea0 100644 --- a/src/distilabel/llms/litellm.py +++ b/src/distilabel/llms/litellm.py @@ -33,6 +33,9 @@ class LiteLLM(AsyncLLM): model: the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large", etc. verbose: whether to log the LiteLLM client's logs. Defaults to `False`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. Runtime parameters: - `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`. 
@@ -69,6 +72,16 @@ def load(self) -> None: continue logging.getLogger(key).setLevel(logging.CRITICAL) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="litellm", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -141,34 +154,40 @@ async def agenerate( # type: ignore """ import litellm + kwargs = { + "model": self.model, + "messages": input, + "n": num_generations, + "functions": functions, + "function_call": function_call, + "temperature": temperature, + "top_p": top_p, + "stream": False, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "user": user, + "metadata": metadata, + "api_base": api_base, + "api_version": api_version, + "api_key": api_key, + "model_list": model_list, + "mock_response": mock_response, + "force_timeout": force_timeout, + "custom_llm_provider": custom_llm_provider, + } + if self.structured_output: + kwargs = self._prepare_kwargs(kwargs, self.structured_output) + async def _call_aclient_until_n_choices() -> List["Choices"]: choices = [] while len(choices) < num_generations: - completion = await self._aclient( # type: ignore - model=self.model, - messages=input, - n=num_generations, - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - stream=False, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - metadata=metadata, - api_base=api_base, - api_version=api_version, - api_key=api_key, - model_list=model_list, - mock_response=mock_response, - force_timeout=force_timeout, - custom_llm_provider=custom_llm_provider, - ) - choices.extend(completion.choices) + completion = await self._aclient(**kwargs) # type: ignore + if not self.structured_output: + completion = completion.choices + choices.extend(completion) return choices # litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used @@ -183,6 +202,11 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: raise e generations = [] + + if self.structured_output: + generations.append([choice.model_dump_json() for choice in choices]) + return generations + for choice in choices: if (content := choice.message.content) is None: self._logger.warning( diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py index d05d9d3f65..dd96cae91f 100644 --- a/src/distilabel/llms/mistral.py +++ b/src/distilabel/llms/mistral.py @@ -45,6 +45,9 @@ class MistralLLM(AsyncLLM): timeout: the maximum time in seconds to wait for a response. Defaults to `120`. max_concurrent_requests: the maximum number of concurrent requests to send. Defaults to `64`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _api_key_env_var: the name of the environment variable to use for the API key. It is meant to be used internally. _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally. 
@@ -107,6 +110,16 @@ def load(self) -> None: max_concurrent_requests=self.max_concurrent_requests, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="mistral", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -134,14 +147,26 @@ async def agenerate( # type: ignore Returns: A list of lists of strings containing the generated responses for each input. """ - completion = await self._aclient.chat( # type: ignore - messages=input, - model=self.model, - temperature=temperature, - max_tokens=max_new_tokens, - top_p=top_p, - ) + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + } generations = [] + if self.structured_output: + kwargs = self._prepare_kwargs(kwargs, self.structured_output) + # TODO: This should work just with the _aclient.chat method, but it's not working. + # We need to check instructor and see if we can create a PR. + completion = await self._aclient.chat.completions.create(**kwargs) + else: + completion = await self._aclient.chat(**kwargs) + + if self.structured_output: + generations.append(completion.model_dump_json()) + return generations + for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index 7314f7c74d..6dedc2387c 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -45,8 +45,9 @@ class OpenAILLM(AsyncLLM): failing. Defaults to `6`. timeout: the maximum time in seconds to wait for a response from the API. Defaults to `120`. - structured_output: a dictionary containing the structured output configuration or if more - fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. Runtime parameters: - `base_url`: the base URL to use for the OpenAI API requests. Defaults to `None`. @@ -110,6 +111,16 @@ def load(self) -> None: timeout=self.timeout, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="openai", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -162,20 +173,29 @@ async def agenerate( # type: ignore f"Invalid response format '{response_format}'. Must be either 'text' or 'json'." 
) - completion = await self._aclient.chat.completions.create( # type: ignore - messages=input, # type: ignore - model=self.model, - max_tokens=max_new_tokens, - n=num_generations, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop, - timeout=50, - response_format={"type": response_format}, - ) + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "max_tokens": max_new_tokens, + "n": num_generations, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "temperature": temperature, + "top_p": top_p, + "stop": stop, + "timeout": 50, + "response_format": {"type": response_format}, + } + if self.structured_output: + kwargs = self._prepare_kwargs(kwargs, self.structured_output) + generations = [] + completion = await self._aclient.chat.completions.create(**kwargs) + + if self.structured_output: + generations.append(completion.model_dump_json()) + return generations + for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( # type: ignore diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 83c10a6067..bc77a53536 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -32,11 +32,14 @@ Union, ) +import fsspec import pyarrow as pa import pyarrow.parquet as pq from typing_extensions import Self +from upath import UPath from distilabel import __version__ +from distilabel.distiset import create_distiset from distilabel.pipeline._dag import DAG from distilabel.pipeline.constants import ( RECEIVES_ROUTED_BATCHES_ATTR_NAME, @@ -45,6 +48,7 @@ ) from distilabel.utils.dicts import flatten_dict from distilabel.utils.files import list_files_in_dir +from distilabel.utils.logging import setup_logging, stop_logging from distilabel.utils.serialization import ( TYPE_INFO_KEY, _check_is_dir, @@ -77,6 +81,7 @@ class _CacheLocation(TypedDict): pipeline: Path batch_manager: Path data: Path + batch_input_data: Path log_file: Path @@ -119,14 +124,31 @@ class BasePipeline(_Serializable): _cache_dir: The directory where the pipeline will be cached. _logger: The logger instance that will be used by the pipeline. _batch_manager: The batch manager that will manage the batches received from the - steps while running the pipeline. + steps while running the pipeline. It will be created when the pipeline is run, + from scratch or from cache. Defaults to `None`. + _write_buffer: The buffer that will store the data of the leaf steps of the pipeline + while running, so the `Distiset` can be created at the end. It will be created + when the pipeline is run. Defaults to `None`. + _logging_parameters: A dictionary containing the parameters that will passed to + `setup_logging` function to initialize the logging. Defaults to `{}`. + _fs: The `fsspec` filesystem to be used to store the data of the `_Batch`es passed + between the steps. It will be set when the pipeline is run. Defaults to `None`. + _storage_base_path: The base path where the data of the `_Batch`es passed between + the steps will be stored. It will be set then the pipeline is run. Defaults + to `None`. + _use_fs_to_pass_data: Whether to use the file system to pass the data of the + `_Batch`es between the steps. Even if this parameter is `False`, the `Batch`es + received by `GlobalStep`s will always use the file system to pass the data. + Defaults to `False`. + _dry_run: A flag to indicate if the pipeline is running in dry run mode. 
Defaults + to `False`. """ def __init__( self, name: str, description: Optional[str] = None, - cache_dir: Optional["PathLike"] = None, + cache_dir: Optional[Union[str, "PathLike"]] = None, enable_metadata: bool = False, ) -> None: """Initialize the `BasePipeline` instance. @@ -154,8 +176,15 @@ def __init__( self._logger = logging.getLogger("distilabel.pipeline") - # It's set to None here, will be created in the call to run self._batch_manager: Optional["_BatchManager"] = None + self._write_buffer: Optional["_WriteBuffer"] = None + self._logging_parameters: Dict[str, Any] = { + "filename": self._cache_location["log_file"] + } + + self._fs: Optional[fsspec.AbstractFileSystem] = None + self._storage_base_path: Optional[str] = None + self._use_fs_to_pass_data: bool = False self._dry_run: bool = False def __enter__(self) -> Self: @@ -225,10 +254,22 @@ def _create_signature(self) -> str: return hasher.hexdigest() + def _set_logging_parameters(self, parameters: Dict[str, Any]) -> None: + """Set the parameters that will be passed to the `setup_logging` function to + initialize the logging. + + Args: + parameters: A dictionary with the parameters that will be passed to the + `setup_logging` function. + """ + self._logging_parameters = parameters + def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, use_cache: bool = True, + storage_parameters: Optional[Dict[str, Any]] = None, + use_fs_to_pass_data: bool = False, ) -> "Distiset": # type: ignore """Run the pipeline. It will set the runtime parameters for the steps and validate the pipeline. @@ -241,15 +282,59 @@ def run( the runtime parameters for the step as the value. Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. + storage_parameters: A dictionary with the storage parameters (`fsspec` and path) + that will be used to store the data of the `_Batch`es passed between the + steps if `use_fs_to_pass_data` is `True` (for the batches received by a + `GlobalStep` it will be always used). It must have at least the "path" key, + and it can contain additional keys depending on the protocol. By default, + it will use the local file system and a directory in the cache directory. + Defaults to `None`. + use_fs_to_pass_data: Whether to use the file system to pass the data of + the `_Batch`es between the steps. Even if this parameter is `False`, the + `Batch`es received by `GlobalStep`s will always use the file system to + pass the data. Defaults to `False`. Returns: The `Distiset` created by the pipeline. """ - if use_cache: - self._load_from_cache() + + setup_logging(**self._logging_parameters) + + # Set the runtime parameters that will be used during the pipeline execution self._set_runtime_parameters(parameters or {}) + + # Validate the pipeline DAG to check that all the steps are chainable, there are + # no missing runtime parameters, batch sizes are correct, etc. self.dag.validate() + # Load the `_BatchManager` from cache or create one from scratch + self._load_batch_manager(use_cache) + + # Setup the filesystem that will be used to pass the data of the `_Batch`es + self._setup_fsspec(storage_parameters) + self._use_fs_to_pass_data = use_fs_to_pass_data + + if self._dry_run: + self._logger.info("🌵 Dry run mode") + + # If the batch manager is not able to generate batches, that means that the loaded + # `_BatchManager` from cache didn't have any remaining batches to process i.e. + # the previous pipeline execution was completed successfully. 
+ if not self._batch_manager.can_generate(): # type: ignore + self._logger.info( + "💾 Loaded batch manager from cache doesn't contain any remaining data." + " Returning `Distiset` from cache data..." + ) + stop_logging() + return create_distiset( + self._cache_location["data"], + pipeline_path=self._cache_location["pipeline"], + log_filename_path=self._cache_location["log_file"], + enable_metadata=self._enable_metadata, + ) + + self._setup_write_buffer() + def dry_run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, @@ -298,6 +383,40 @@ def get_runtime_parameters_info(self) -> Dict[str, List[Dict[str, Any]]]: runtime_parameters[step_name] = step.get_runtime_parameters_info() return runtime_parameters + def _setup_fsspec( + self, storage_parameters: Optional[Dict[str, Any]] = None + ) -> None: + """Setups the `fsspec` filesystem to be used to store the data of the `_Batch`es + passed between the steps. + + Args: + storage_parameters: A dictionary with the storage parameters (`fsspec` and path) + that will be used to store the data of the `_Batch`es passed between the + steps if `use_fs_to_pass_data` is `True` (for the batches received by a + `GlobalStep` it will be always used). It must have at least the "path" key, + and it can contain additional keys depending on the protocol. By default, + it will use the local file system and a directory in the cache directory. + Defaults to `None`. + """ + if not storage_parameters: + self._fs = fsspec.filesystem("file") + self._storage_base_path = ( + f"file://{self._cache_location['batch_input_data']}" + ) + return + + if "path" not in storage_parameters: + raise ValueError( + "The 'path' key must be present in the `storage_parameters` dictionary" + " if it's not `None`." + ) + + path = storage_parameters.pop("path") + protocol = UPath(path).protocol + + self._fs = fsspec.filesystem(protocol, **storage_parameters) + self._storage_base_path = path + def _add_step(self, step: "_Step") -> None: """Add a step to the pipeline. @@ -412,11 +531,15 @@ def _cache_location(self) -> _CacheLocation: "pipeline": folder / "pipeline.yaml", "batch_manager": folder / "batch_manager.json", "data": folder / "data", + "batch_input_data": folder / "batch_input_data", "log_file": folder / "pipeline.log", } def _cache(self) -> None: """Saves the `BasePipeline` using the `_cache_filename`.""" + if self._dry_run: + return + self.save( path=self._cache_location["pipeline"], format=self._cache_location["pipeline"].suffix.replace(".", ""), # type: ignore @@ -425,17 +548,54 @@ def _cache(self) -> None: self._batch_manager.cache(self._cache_location["batch_manager"]) self._logger.debug("Pipeline and batch manager saved to cache.") - def _load_from_cache(self) -> None: - """Will try to load the `BasePipeline` from the cache dir if found, updating - the internal `DAG` and `_BatchManager`. + def _load_batch_manager(self, use_cache: bool = True) -> None: + """Will try to load the `_BatchManager` from the cache dir if found. Otherwise, + it will create one from scratch. 
""" - cache_loc = self._cache_location - if cache_loc["pipeline"].exists(): - if cache_loc["batch_manager"].exists(): - self._batch_manager = _BatchManager.load_from_cache( - cache_loc["batch_manager"] - ) - self._logger.info("💾 Load pipeline from cache") + batch_manager_cache_loc = self._cache_location["batch_manager"] + if use_cache and batch_manager_cache_loc.exists(): + self._logger.info( + f"💾 Loading `_BatchManager` from cache: '{batch_manager_cache_loc}'" + ) + self._batch_manager = _BatchManager.load_from_cache(batch_manager_cache_loc) + else: + self._batch_manager = _BatchManager.from_dag(self.dag) + + def _setup_write_buffer(self) -> None: + """Setups the `_WriteBuffer` that will store the data of the leaf steps of the + pipeline while running, so the `Distiset` can be created at the end. + """ + buffer_data_path = self._cache_location["data"] + self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'") + self._write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps) + + def _send_batch_to_step(self, batch: "_Batch") -> None: + """Sends a batch to the input queue of a step, writing the data of the batch + to the filesystem and setting `batch.data_path` with the path where the data + was written (if requiered i.e. the step is a global step or `use_fs_to_pass_data`) + + This method should be extended by the specific pipeline implementation, adding + the logic to send the batch to the step. + + Args: + batch: The batch to send. + """ + self._logger.debug( + f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}" + ) + self._batch_manager.set_last_batch_sent(batch) # type: ignore + + step: "_Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] + if not step.is_generator and (step.is_global or self._use_fs_to_pass_data): + base_path = UPath(self._storage_base_path) / step.name # type: ignore + self._logger.debug( + f"Writing {batch.seq_no} batch for '{batch.step_name}' step to filesystem: {base_path}" + ) + batch.write_batch_data_to_fs(self._fs, base_path) # type: ignore + + self._logger.debug( + f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" + ) @dataclass @@ -447,6 +607,8 @@ class _Batch(_Serializable): step_name: The name of the step that will process the batch. last_batch: A flag to indicate if the batch is the last one. data: The data to be processed. + data_hash: The hash of the data. Defaults to `None`. + data_path: The path where the data of the batch is stored. Defaults to `None`. accumulated: A flag to indicate if the batch is accumulated. created_from: A dictionary containing the `seq_no` of the batches of the steps that were used to create this batch. @@ -458,10 +620,12 @@ class _Batch(_Serializable): last_batch: bool data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False) data_hash: Optional[str] = None + data_path: Optional[str] = None accumulated: bool = False created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict) batch_routed_to: List[str] = field(default_factory=list) size: int = 0 + _fs: Optional[fsspec.AbstractFileSystem] = None def next_batch(self) -> "_Batch": """Create a new `_Batch` instance with the next batch of data. @@ -498,6 +662,9 @@ def get_data(self, num_rows: Union[int, None] = None) -> List[Dict[str, Any]]: A list with the data taken from the batch. 
""" + if self.data == [] and self.data_path is not None: + pass + if num_rows is None: data = self.data[0] self.data = [] @@ -571,6 +738,73 @@ def copy(self) -> "_Batch": """ return copy.deepcopy(self) + def write_batch_data_to_fs( + self, + fs: Optional[fsspec.AbstractFileSystem] = None, + base_path: Optional[UPath] = None, + ) -> None: + """Writes the content of the batch to the filesystem. + + Args + fs: The `fsspec` filesystem to be used to write the data. If not provided, the + one set in the `_fs` attribute will be used. Defaults to `None`. + base_path: The base path where the data of the batch will be stored. If not + provided, the one set in the `data_path` attribute will be used. Defaults + to `None`. + + Raises: + ValueError: If `fs` is not provided and the `_fs` attribute is not set. + """ + + if not fs and not self._fs: + raise ValueError( + "The `fs` parameter must be provided if the `_fs` attribute is not set." + ) + + if fs: + self._fs = fs + + if not base_path and not self.data_path: + raise ValueError( + "The `base_path` parameter must be provided if the `data_path` attribute" + " is not set." + ) + + seq_no_dir = ( + base_path / f"seq_no_{self.seq_no}" if base_path else UPath(self.data_path) + ) + seq_no_dir._fs_cached = self._fs # type: ignore + seq_no_dir.mkdir(parents=True, exist_ok=True) + + for i, data in enumerate(self.data): + table = pa.Table.from_pylist(data) + with self._fs.open(seq_no_dir / f"data_index_{i}.parquet", "wb") as f: # type: ignore + pq.write_table(table, f) + + self.data = [] + self.data_path = str(seq_no_dir) + + def read_batch_data_from_fs(self) -> None: + """Reads the content of the batch from the filesystem.""" + if not self.data_path: + raise ValueError( + "`data_path` attribute must be set to read the data from the filesystem." + " Use `write_batch_data_to_fs` method to set the `data_path` attribute." + ) + + if not self._fs: + raise ValueError( + "`_fs` attribute must be set to read the data from the filesystem." + " Use `write_batch_data_to_fs` method to set the `_fs` attribute." + ) + + for file in self._fs.ls(self.data_path): + with self._fs.open(file, "rb") as f: + table = pq.read_table(f) + self.data.append(table.to_pylist()) + + self._fs.rm(self.data_path, recursive=True) + @dataclass class _BatchManagerStep(_Serializable): diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 98c5f00a37..36f41e16b9 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import multiprocessing as mp import signal +import sys import threading import time import traceback @@ -28,8 +28,6 @@ LAST_BATCH_SENT_FLAG, BasePipeline, _Batch, - _BatchManager, - _WriteBuffer, ) from distilabel.pipeline.constants import ( CONVERGENCE_STEP_ATTR_NAME, @@ -68,9 +66,14 @@ _SUBPROCESS_EXCEPTION: Union[Exception, None] = None -def _init_worker(queue: "Queue[Any]") -> None: +def _init_worker(log_queue: "Queue[Any]") -> None: + """Init function for the child processes that will execute the `Step`s of the `Pipeline`. + + Args: + log_queue: The queue to send the logs to the main process. 
+ """ signal.signal(signal.SIGINT, signal.SIG_IGN) - setup_logging(queue) + setup_logging(log_queue) class Pipeline(BasePipeline): @@ -80,6 +83,8 @@ def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, use_cache: bool = True, + storage_parameters: Optional[Dict[str, Any]] = None, + use_fs_to_pass_data: bool = False, ) -> "Distiset": """Runs the pipeline. @@ -88,6 +93,17 @@ def run( the runtime parameters for the step as the value. Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. + storage_parameters: A dictionary with the storage parameters (`fsspec` and path) + that will be used to store the data of the `_Batch`es passed between the + steps if `use_fs_to_pass_data` is `True` (for the batches received by a + `GlobalStep` it will be always used). It must have at least the "path" key, + and it can contain additional keys depending on the protocol. By default, + it will use the local file system and a directory in the cache directory. + Defaults to `None`. + use_fs_to_pass_data: Whether to use the file system to pass the data of + the `_Batch`es between the steps. Even if this parameter is `False`, the + `Batch`es received by `GlobalStep`s will always use the file system to + pass the data. Defaults to `False`. Returns: The `Distiset` created by the pipeline. @@ -96,37 +112,15 @@ def run( RuntimeError: If the pipeline fails to load all the steps. """ log_queue = mp.Queue() - # We must place the runtime parameters before calling setup_logging to ensure consistency - super().run(parameters, use_cache) - setup_logging(log_queue, filename=str(self._cache_location["log_file"])) # type: ignore - self._logger = logging.getLogger("distilabel.pipeline.local") - - if self._dry_run: - # This message is placed here to ensure we are using the already setup logger. - self._logger.info("🌵 Dry run mode") - - if self._batch_manager is None: - self._batch_manager = _BatchManager.from_dag(self.dag) - - # If the batch manager is not able to generate batches, that means that the loaded - # `_BatchManager` from cache didn't have any remaining batches to process i.e. - # the previous pipeline execution was completed successfully. - if not self._batch_manager.can_generate(): - self._logger.info( - "💾 Loaded batch manager from cache doesn't have any remaining data. Returning" - " `Distiset` from cache data..." 
- ) - stop_logging() - return create_distiset( - self._cache_location["data"], - pipeline_path=self._cache_location["pipeline"], - log_filename_path=self._cache_location["log_file"], - enable_metadata=self._enable_metadata, - ) - buffer_data_path = self._cache_location["data"] - self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'") - write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps) + self._set_logging_parameters( + {"log_queue": log_queue, "filename": self._cache_location["log_file"]} + ) + + if distiset := super().run( + parameters, use_cache, storage_parameters, use_fs_to_pass_data + ): + return distiset num_processes = len(self.dag) ctx = mp.get_context() # type: ignore @@ -144,7 +138,7 @@ def run( # Wait for all the steps to be loaded correctly if not self._all_steps_loaded(): - write_buffer.close() + self._write_buffer.close() # type: ignore self._batch_manager = None stop_logging() raise RuntimeError( @@ -156,15 +150,17 @@ def run( self._request_initial_batches() # Start a loop to receive the output batches from the steps - self._run_output_queue_loop_in_thread(write_buffer) + self._run_output_queue_loop_in_thread() # Send `None` to steps `input_queue`s just in case some step is still waiting self._notify_steps_to_stop() - pool.close() - pool.join() + # `Pool.__exit__` has already called `terminate`, `join` the pool to make sure + # all the processes have finished + pool.join() + manager.join() - write_buffer.close() + self._write_buffer.close() # type: ignore distiset = create_distiset( self._cache_location["data"], pipeline_path=self._cache_location["pipeline"], @@ -174,15 +170,11 @@ def run( stop_logging() return distiset - def _run_output_queue_loop_in_thread(self, write_buffer: "_WriteBuffer") -> None: + def _run_output_queue_loop_in_thread(self) -> None: """Runs the output queue loop in a separate thread to receive the output batches from the steps. This is done to avoid the signal handler to block the loop, which - would prevent the pipeline from stopping correctly. - - Args: - write_buffer: The write buffer to write the data from the leaf steps to disk. - """ - thread = threading.Thread(target=self._output_queue_loop, args=(write_buffer,)) + would prevent the pipeline from stopping correctly.""" + thread = threading.Thread(target=self._output_queue_loop) thread.start() thread.join() @@ -193,21 +185,28 @@ def _notify_steps_to_stop(self) -> None: if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): input_queue.put(None) - def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: + def _output_queue_loop(self) -> None: """Loop to receive the output batches from the steps and manage the flow of the - batches through the pipeline. - - Args: - write_buffer: The write buffer to write the data from the leaf steps to disk. - """ + batches through the pipeline.""" while self._batch_manager.can_generate() and not _STOP_CALLED: # type: ignore self._logger.debug("Waiting for output batch from step...") if (batch := self.output_queue.get()) is None: self._logger.debug("Received `None` from output queue. 
Breaking loop.") break + self._logger.debug( + f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'" + f" from output queue: {batch}" + ) + + if batch.data_path: + self._logger.debug( + f"Reading {batch.seq_no} batch data from '{batch.step_name}': '{batch.data_path}'" + ) + batch.read_batch_data_from_fs() + if batch.step_name in self.dag.leaf_steps: - write_buffer.add_batch(batch) + self._write_buffer.add_batch(batch) # type: ignore # If `_STOP_CALLED` was set to `True` while waiting for the output queue, then # we need to handle the stop of the pipeline and break the loop to avoid @@ -217,15 +216,10 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: self._handle_batch_on_stop(batch) break - self._logger.debug( - f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'" - f" from output queue: {batch}" - ) - self._manage_batch_flow(batch) if _STOP_CALLED: - self._handle_stop(write_buffer) + self._handle_stop() def _manage_batch_flow(self, batch: "_Batch") -> None: """Checks if the step that generated the batch has more data in its buffer to @@ -357,14 +351,10 @@ def _request_more_batches_if_needed(self, step: "Step") -> None: ) self._send_batch_to_step(last_batch.next_batch()) - def _handle_stop(self, write_buffer: "_WriteBuffer") -> None: + def _handle_stop(self) -> None: """Handles the stop of the pipeline execution, which will stop the steps from processing more batches and wait for the output queue to be empty, to not lose - any data that was already processed by the steps before the stop was called. - - Args: - write_buffer: The write buffer to write the data from the leaf steps to disk. - """ + any data that was already processed by the steps before the stop was called.""" self._logger.debug("Handling stop of the pipeline execution...") # Add the remaining batches in the input queues back to the batch manager @@ -399,7 +389,7 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None: continue if batch.step_name in self.dag.leaf_steps: - write_buffer.add_batch(batch) + self._write_buffer.add_batch(batch) # type: ignore self._handle_batch_on_stop(batch) @@ -522,14 +512,7 @@ def _send_batch_to_step(self, batch: "_Batch") -> None: Args: batch: The batch to send. """ - self._logger.debug( - f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}" - ) - self._batch_manager.set_last_batch_sent(batch) # type: ignore - - self._logger.debug( - f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" - ) + super()._send_batch_to_step(batch) input_queue = self.dag.get_step(batch.step_name)[INPUT_QUEUE_ATTR_NAME] input_queue.put(batch) @@ -582,7 +565,7 @@ def _run_steps_in_loop( for step_name in self.dag: step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME] input_queue = manager.Queue() - self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue) + self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue) # type: ignore # Set `pipeline` to `None` as in some Python environments the pipeline is not # picklable and it will raise an error when trying to send the step to the process. 
@@ -623,20 +606,21 @@ def _error_callback(self, e: BaseException) -> None: with self.shared_info[_STEPS_LOADED_LOCK_KEY]: self.shared_info[_STEPS_LOADED_KEY] = [_STEPS_LOADED_ERROR_CODE] _SUBPROCESS_EXCEPTION = e.subprocess_exception - _SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string( + _SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string( # type: ignore e.formatted_traceback ).as_traceback() return # If the step is global, is not in the last trophic level and has no successors, # then we can ignore the error and continue executing the pipeline + step_name: str = e.step.name # type: ignore if ( e.step.is_global - and not self.dag.step_in_last_trophic_level(e.step.name) - and list(self.dag.get_step_successors(e.step.name)) == [] + and not self.dag.step_in_last_trophic_level(step_name) + and list(self.dag.get_step_successors(step_name)) == [] ): self._logger.error( - f"✋ An error occurred when running global step '{e.step.name}' with no" + f"✋ An error occurred when running global step '{step_name}' with no" " successors and not in the last trophic level. Pipeline execution can" f" continue. Error will be ignored." ) @@ -644,7 +628,7 @@ def _error_callback(self, e: BaseException) -> None: return # Global step with successors failed - self._logger.error(f"An error occurred in global step '{e.step.name}'") + self._logger.error(f"An error occurred in global step '{step_name}'") self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}") self._cache() self._stop() @@ -691,28 +675,22 @@ def _stop( if _STOP_CALLED: global _STOP_CALLS _STOP_CALLS += 1 - # if _STOP_CALLS == 1: - # self._logger.warning( - # "🛑 Stop has already been called. Ignoring subsequent calls and waiting" - # " for the pipeline to finish..." - # ) if _STOP_CALLS == 1: self._logger.warning( "🛑 Press again to force the pipeline to stop." ) elif _STOP_CALLS > 1: self._logger.warning("🛑 Forcing pipeline interruption.") - import gc - import sys - - if manager: - manager.shutdown() if pool: - pool.close() pool.terminate() + pool.join() + + if manager: + manager.shutdown() + manager.join() - gc.collect() + stop_logging() sys.exit(1) @@ -958,6 +936,11 @@ def _non_generator_process_loop(self) -> None: self.step._logger.info( f"📦 Processing batch {batch.seq_no} in '{batch.step_name}'" ) + + if batch.data_path is not None: + self.step._logger.debug(f"Reading batch data from '{batch.data_path}'") + batch.read_batch_data_from_fs() + result = [] try: if self.step.has_multiple_inputs: @@ -969,11 +952,7 @@ def _non_generator_process_loop(self) -> None: raise _ProcessWrapperException(str(e), self.step, 2, e) from e # Impute step outputs columns with `None` - for row in batch.data[0]: - data = row.copy() - for output in self.step.outputs: - data[output] = None - result.append(data) + result = self._impute_step_outputs(batch) # if the step is not global then we can skip the batch which means sending # an empty batch to the output queue @@ -991,8 +970,26 @@ def _non_generator_process_loop(self) -> None: if batch.last_batch: break + def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]: + """Imputes the step outputs columns with `None` in the batch data. + + Args: + batch: The batch to impute. 
+ """ + result = [] + for row in batch.data[0]: + data = row.copy() + for output in self.step.outputs: + data[output] = None + result.append(data) + return result + def _send_batch(self, batch: _Batch) -> None: """Sends a batch to the `output_queue`.""" + if batch.data_path is not None: + self.step._logger.debug(f"Writing batch data to '{batch.data_path}'") + batch.write_batch_data_to_fs() + self.step._logger.info( f"📨 Step '{batch.step_name}' sending batch {batch.seq_no} to output queue" ) diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py index 1460e55f99..57cd1863bf 100644 --- a/src/distilabel/steps/argilla/base.py +++ b/src/distilabel/steps/argilla/base.py @@ -137,9 +137,7 @@ def load(self) -> None: @property @abstractmethod - def inputs(self) -> List[str]: - ... + def inputs(self) -> List[str]: ... @abstractmethod - def process(self, *inputs: StepInput) -> "StepOutput": - ... + def process(self, *inputs: StepInput) -> "StepOutput": ... diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index fcac454447..9aab121815 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -220,18 +220,15 @@ def _set_routing_batch_function( routing_batch_function._step = self @overload - def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": - ... + def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": ... @overload def __rshift__( self, other: List["DownstreamConnectableSteps"] - ) -> List["DownstreamConnectableSteps"]: - ... + ) -> List["DownstreamConnectableSteps"]: ... @overload - def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": - ... + def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": ... def __rshift__( self, diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py index 1d7c1853cb..da2cbb8dcc 100644 --- a/src/distilabel/steps/decorator.py +++ b/src/distilabel/steps/decorator.py @@ -53,8 +53,7 @@ def step( inputs: Union[List[str], None] = None, outputs: Union[List[str], None] = None, step_type: Literal["normal"] = "normal", -) -> Callable[..., Type["Step"]]: - ... +) -> Callable[..., Type["Step"]]: ... @overload @@ -62,8 +61,7 @@ def step( inputs: Union[List[str], None] = None, outputs: Union[List[str], None] = None, step_type: Literal["global"] = "global", -) -> Callable[..., Type["GlobalStep"]]: - ... +) -> Callable[..., Type["GlobalStep"]]: ... @overload @@ -71,8 +69,7 @@ def step( inputs: None = None, outputs: Union[List[str], None] = None, step_type: Literal["generator"] = "generator", -) -> Callable[..., Type["GeneratorStep"]]: - ... +) -> Callable[..., Type["GeneratorStep"]]: ... 
def step( diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index c69e51bc2c..7e2cfc2520 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -25,6 +25,7 @@ StepInput, _Step, ) +from distilabel.steps.constants import DISTILABEL_METADATA_KEY from distilabel.utils.dicts import combine_dicts if TYPE_CHECKING: @@ -33,9 +34,6 @@ from distilabel.steps.typing import StepOutput -DISTILABEL_METADATA_KEY = "distilabel_metadata" - - class _Task(_Step, ABC): """_Task is an abstract class that implements the `_Step` interface and adds the `format_input` and `format_output` methods to format the inputs and outputs of the diff --git a/src/distilabel/steps/tasks/structured_outputs/instructor.py b/src/distilabel/steps/tasks/structured_outputs/instructor.py new file mode 100644 index 0000000000..e9ec1ea431 --- /dev/null +++ b/src/distilabel/steps/tasks/structured_outputs/instructor.py @@ -0,0 +1,140 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +from typing import ( + TYPE_CHECKING, + Callable, + Literal, + Optional, + Tuple, + Type, + TypeAlias, + TypedDict, + Union, + get_args, +) + +from pydantic import BaseModel + +if TYPE_CHECKING: + import instructor + from anthropic import AsyncAnthropic + from cohere import AsyncClient as AsyncCohere + from groq import AsyncGroq + from mistralai.async_client import MistralAsyncClient + from openai import AsyncAzureOpenAI, AsyncOpenAI + + +InstructorFrameworks = Literal[ + "openai", "azure_openai", "anthropic", "cohere", "groq", "litellm", "mistral" +] +"""Available frameworks for the structured output configuration with `instructor`. """ + +InstructorAvailableClients: TypeAlias = Union[ + "AsyncAnthropic", + "AsyncAzureOpenAI", + "AsyncCohere", + "AsyncGroq", + "AsyncOpenAI", + "MistralAsyncClient", +] +"""Available clients that can be wrapped with `instructor`. """ + + +class InstructorStructuredOutputType(TypedDict): + """TypedDict to represent the structured output configuration from `instructor`.""" + + schema: Type[BaseModel] + """The schema to use for the structured output, a `pydantic.BaseModel` class. """ + mode: Optional["instructor.Mode"] + """Generation mode. Take a look at `instructor.Mode` for more information, if not informed it will + be determined automatically. """ + max_retries: int + """Number of times to reask the model in case of error, if not set will default to the model's default. """ + + +def _client_patcher(framework: InstructorFrameworks) -> Tuple[Callable, str]: + """Helper function to return the appropriate instructor client for the given framework. + + Args: + framework: The framework to use for the instructor client. + + Raises: + ValueError: If the framework is not one of the available frameworks. + + Returns: + Tuple of Callable and string, with the builder of the client patch and the + default mode to use. 
+ """ + import instructor + + if framework in {"openai", "azure_openai"}: + patch = instructor.from_openai, instructor.Mode.TOOLS + elif framework == "anthropic": + patch = instructor.from_anthropic, instructor.Mode.ANTHROPIC_JSON + elif framework == "litellm": + patch = instructor.from_litellm, instructor.Mode.TOOLS + elif framework == "mistral": + patch = instructor.from_mistral, instructor.Mode.MISTRAL_TOOLS + elif framework == "cohere": + patch = instructor.from_cohere, instructor.Mode.COHERE_TOOLS + elif framework == "groq": + patch = instructor.from_groq, instructor.Mode.TOOLS + else: + raise ValueError( + f"Invalid framework '{framework}'. Must be one of {get_args(InstructorFrameworks)}" + ) + + return patch + + +def prepare_instructor( + client: InstructorAvailableClients, + mode: Optional["instructor.Mode"] = None, + framework: Optional[InstructorFrameworks] = None, +) -> "instructor.AsyncInstructor": + """Wraps the given client with the instructor client for the given framework. + + Args: + client: The client to wrap with the instructor client, corresponds to the internal + client we wrap on `LLM`, and one of the implemented in `instructor`. + mode: One of the `instructor.Mode` values. Defaults to None. + framework: The framework corresponding to the client. Defaults to None. + + Raises: + ImportError: If `instructor` is not installed. + ValueError: If the mode is not one of the available modes. + + Returns: + patched_client: The instructor wrapping the original client to be used for + structured generation. + """ + if not importlib.util.find_spec("instructor"): + raise ImportError( + "`instructor` is not installed. Please install it using `pip install instructor`." + ) + import instructor + + builder, default_mode = _client_patcher(framework) + + mode = mode or default_mode + if mode.value not in [m.value for m in instructor.mode.Mode]: + raise ValueError( + f"Invalid mode '{mode}'. Must be one of {[m.value for m in instructor.mode.Mode]}" + ) + + patched_client: instructor.AsyncInstructor = builder(client, mode=mode) + + return patched_client diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index 087f4913bc..7bc53623de 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -31,6 +31,8 @@ from pydantic import BaseModel +from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict + Frameworks = Literal["transformers", "llamacpp", "vllm"] """Available frameworks for the structured output configuration. 
""" @@ -59,15 +61,6 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]: return json.dumps(schema.model_json_schema()) -def _schema_as_dict(schema: Union[str, Type[BaseModel]]) -> Dict[str, Any]: - """Helper function to obtain the schema and simplify serialization.""" - if type(schema) == type(BaseModel): - return schema.model_json_schema() - elif isinstance(schema, str): - return json.loads(schema) - return schema - - def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]: """Helper function to return the appropriate logits processor for the given framework.""" if framework == "transformers": @@ -137,7 +130,7 @@ def prepare_guided_output( llm, whitespace_pattern=structured_output.get("whitespace_pattern"), ), - "schema": _schema_as_dict(schema), + "schema": schema_as_dict(schema), } if format == "regex": diff --git a/src/distilabel/steps/tasks/structured_outputs/utils.py b/src/distilabel/steps/tasks/structured_outputs/utils.py new file mode 100644 index 0000000000..8bcebcb819 --- /dev/null +++ b/src/distilabel/steps/tasks/structured_outputs/utils.py @@ -0,0 +1,157 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Dict, List, Optional, Type, Union + +from pydantic import BaseModel, Field, create_model + + +def schema_as_dict(schema: Union[str, Type[BaseModel]]) -> Dict[str, Any]: + """Helper function to obtain the schema and simplify serialization.""" + if type(schema) == type(BaseModel): + return schema.model_json_schema() + elif isinstance(schema, str): + return json.loads(schema) + return schema + + +# NOTE: The following functions were copied from: +# https://github.com/pydantic/pydantic/issues/643#issuecomment-1999755873 +# and slightly modified to work with nested models. +# It would be nice to find the original source of this code to give credit. +# Other option would be working with this library: https://github.com/c32168/dyntamic + + +def json_schema_to_model(json_schema: Dict[str, Any]) -> Type[BaseModel]: + """Converts a JSON schema to a `pydantic.BaseModel` class. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A `pydantic.BaseModel` class. + """ + + # Extract the model name from the schema title. + model_name = json_schema.get("title") + if defs := json_schema.get("$defs", None): + # This is done to grab the content of nested classes that need to dereference + # the objects (those should be in a higher level). + pass + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field( + name, prop, json_schema.get("required", []), defs=defs + ) + for name, prop in json_schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). 
+ return create_model(model_name, **field_definitions) + + +def json_schema_to_pydantic_field( + name: str, + json_schema: Dict[str, Any], + required: List[str], + defs: Optional[Dict[str, Any]] = None, +) -> Any: + """Converts a JSON schema property to a `pydantic.Field`. + + Args: + name: The field name. + json_schema: The JSON schema property. + required: The list of required fields. + defs: The definitions of the JSON schema. It's used to dereference nested classes, + so we can grab the original definition from the json schema (it won't + work out of the box with just the reference). + + Returns: + A `pydantic.Field`. + """ + + # NOTE(plaguss): This needs more testing, nested classes need extra work to be converted + # here if we pass a reference to another class it will crash, we have to find the original + # definition and insert it here + # This takes into account single items referred to other classes + if ref := json_schema.get("$ref"): + json_schema = defs.get(ref.split("/")[-1]) + + # This takes into account lists of items referred to other classes + if "items" in json_schema and (ref := json_schema["items"].get("$ref")): + json_schema["items"] = defs.get(ref.split("/")[-1]) + + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema) + + # Get the field description. + description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The "required" flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: Dict[str, Any]) -> Any: + """Converts a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. + """ + type_ = json_schema.get("type") + + if type_ == "string": + type_val = str + elif type_ == "integer": + type_val = int + elif type_ == "number": + type_val = float + elif type_ == "boolean": + type_val = bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema) + type_val = List[item_type] + else: + type_val = List + elif type_ == "object": + # Handle nested models. 
+ properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + type_val = nested_model + else: + type_val = Dict + elif type_ == "null": + type_val = Optional[Any] # Use Optional[Any] for nullable fields + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") + + return type_val diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index 41eb4444dc..ece5344caf 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -62,14 +62,10 @@ def format_input(self, input: Dict[str, Any]) -> ChatType: is the first interaction from the user within a conversation.""" if is_openai_format(input["instruction"]): - warnings.warn( + raise ValueError( "Providing `instruction` formatted as an OpenAI chat / conversation is" - " about to be deprecated in `distilabel v1.2.0`, please make sure to use" - " `ChatTextGeneration` with `messages` as input instead.", - DeprecationWarning, - stacklevel=2, + " deprecated, you should use `ChatGeneration` with `messages` as input instead.", ) - return input["instruction"] if not isinstance(input["instruction"], str): raise ValueError( diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py index 15a737f448..af1b26a18b 100644 --- a/src/distilabel/utils/logging.py +++ b/src/distilabel/utils/logging.py @@ -42,7 +42,9 @@ queue_listener: Union[QueueListener, None] = None -def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> None: +def setup_logging( + log_queue: Optional["Queue[Any]"] = None, filename: Optional[str] = None +) -> None: """Sets up logging to use a queue across all processes.""" global queue_listener @@ -53,7 +55,7 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No # If the current process is the main process, set up a `QueueListener` # to handle logs from all subprocesses - if mp.current_process().name == "MainProcess": + if mp.current_process().name == "MainProcess" and filename: formatter = logging.Formatter("['%(name)s'] %(message)s") handler = RichHandler(rich_tracebacks=True) handler.setFormatter(formatter) @@ -66,10 +68,11 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No ) file_handler.setFormatter(file_formatter) - queue_listener = QueueListener( - log_queue, handler, file_handler, respect_handler_level=True - ) - queue_listener.start() + if log_queue is not None: + queue_listener = QueueListener( + log_queue, handler, file_handler, respect_handler_level=True + ) + queue_listener.start() log_level = os.environ.get("DISTILABEL_LOG_LEVEL", "INFO").upper() if log_level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: @@ -80,9 +83,15 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No log_level = "INFO" root_logger = logging.getLogger() - root_logger.handlers.clear() + + running_test = "PYTEST_CURRENT_TEST" in os.environ + if not running_test: + root_logger.handlers.clear() + + if log_queue is not None: + root_logger.addHandler(QueueHandler(log_queue)) + root_logger.setLevel(log_level) - root_logger.addHandler(QueueHandler(log_queue)) def stop_logging() -> None: @@ -90,4 +99,5 @@ def stop_logging() -> None: global queue_listener if queue_listener is not None: queue_listener.stop() + queue_listener.queue.close() queue_listener = None diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 
0000000000..8337c9aaa9 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,27 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from typing import Generator + +import pytest + + +@pytest.fixture(autouse=True) +def temp_cache_dir() -> Generator[None, None, None]: + """Set the cache directory to a temporary directory for all tests.""" + with tempfile.TemporaryDirectory() as tmpdirname: + os.environ["DISTILABEL_CACHE_DIR"] = tmpdirname + yield diff --git a/tests/integration/test_cache.py b/tests/integration/test_cache.py new file mode 100644 index 0000000000..6eddd6f7ca --- /dev/null +++ b/tests/integration/test_cache.py @@ -0,0 +1,55 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, List + +import numpy as np +import pytest +from distilabel.pipeline import Pipeline +from distilabel.steps import GeneratorStep, StepInput, step + +if TYPE_CHECKING: + from distilabel.steps import GeneratorStepOutput, StepOutput + + +class NumpyBigArrayGenerator(GeneratorStep): + num_batches: int + + @property + def outputs(self) -> List[str]: + return ["array"] + + def process(self, offset: int = 0) -> "GeneratorStepOutput": + for i in range(self.num_batches): + yield ( + [{"array": np.random.randn(256)} for _ in range(self.batch_size)], # type: ignore + i == self.num_batches - 1, + ) # type: ignore + + +@step(step_type="global") +def ReceiveArrays(inputs: StepInput) -> "StepOutput": + yield inputs + + +@pytest.mark.benchmark +def test_cache_time() -> None: + with Pipeline(name="dummy") as pipeline: + numpy_generator = NumpyBigArrayGenerator(num_batches=2, batch_size=100) + + receive_arrays = ReceiveArrays() + + numpy_generator >> receive_arrays + + pipeline.run(use_cache=False) diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py index 8ab1fff29a..31a624b15f 100644 --- a/tests/integration/test_pipe_simple.py +++ b/tests/integration/test_pipe_simple.py @@ -166,15 +166,8 @@ def run_pipeline(): ) -def test_pipeline_cached(): - ds = run_pipeline() - print() - print("----- RUNNING PIPELINE AGAIN -----") - print() +def test_pipeline_cached() -> None: + run_pipeline() ds = run_pipeline() assert isinstance(ds, Distiset) assert len(ds["default"]["train"]) == 80 - - -if __name__ == "__main__": - test_pipeline_cached() diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py index 0ea2ee3cdc..228fb1c43e 100644 --- a/tests/integration/test_routing_batch_function.py +++ b/tests/integration/test_routing_batch_function.py @@ -74,7 +74,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput": yield combined_list -@pytest.mark.timeout(120) +@pytest.mark.timeout(240) def test_routing_batch_function() -> None: with Pipeline(name="test") as pipeline: load_dataset = LoadDataFromDicts( @@ -95,7 +95,7 @@ def test_routing_batch_function() -> None: assert len(row["generations"]) == 2 -@pytest.mark.timeout(120) +@pytest.mark.timeout(240) def test_routing_batch_function_irregular_batch_sizes() -> None: with Pipeline(name="test") as pipeline: load_dataset = LoadDataFromDicts( @@ -120,7 +120,7 @@ def test_routing_batch_function_irregular_batch_sizes() -> None: assert len(row["generations"]) == 2 -@pytest.mark.timeout(120) +@pytest.mark.timeout(240) def test_multiple_routing_batch_function() -> None: batch_size = 200 diff --git a/tests/integration/test_using_fs_to_pass_data.py b/tests/integration/test_using_fs_to_pass_data.py new file mode 100644 index 0000000000..811885e356 --- /dev/null +++ b/tests/integration/test_using_fs_to_pass_data.py @@ -0,0 +1,68 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, List + +import numpy as np +from distilabel.pipeline import Pipeline +from distilabel.steps import GeneratorStep, StepInput, step + +if TYPE_CHECKING: + from distilabel.steps import GeneratorStepOutput, StepOutput + + +class NumpyBigArrayGenerator(GeneratorStep): + num_batches: int + + @property + def outputs(self) -> List[str]: + return ["array"] + + def process(self, offset: int = 0) -> "GeneratorStepOutput": + for i in range(self.num_batches): + yield ( + [{"array": np.random.randn(128)} for _ in range(self.batch_size)], # type: ignore + i == self.num_batches - 1, + ) # type: ignore + + +@step(step_type="global") +def ReceiveArrays(inputs: StepInput) -> "StepOutput": + yield inputs + + +def test_passing_data_through_fs_only_global_steps() -> None: + with Pipeline(name="dummy") as pipeline: + numpy_generator = NumpyBigArrayGenerator(num_batches=5, batch_size=100) + + receive_arrays = ReceiveArrays() + + numpy_generator >> receive_arrays + + distiset = pipeline.run(use_fs_to_pass_data=False, use_cache=False) + + assert len(distiset["default"]["train"]) == 500 + + +def test_passing_data_through_fs() -> None: + with Pipeline(name="dummy") as pipeline: + numpy_generator = NumpyBigArrayGenerator(num_batches=2, batch_size=200) + + receive_arrays = ReceiveArrays() + + numpy_generator >> receive_arrays + + distiset = pipeline.run(use_fs_to_pass_data=True, use_cache=False) + + assert len(distiset["default"]["train"]) == 400 diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py index 28e486756b..75e7bcbf62 100644 --- a/tests/unit/llms/test_anthropic.py +++ b/tests/unit/llms/test_anthropic.py @@ -13,12 +13,16 @@ # limitations under the License. import os +import sys +from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio import pytest from distilabel.llms.anthropic import AnthropicLLM +from .utils import DummyUserDetail + @patch("anthropic.AsyncAnthropic") class TestAnthropicLLM: @@ -47,6 +51,37 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: ] ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: + llm = AnthropicLLM( + model="claude-3-opus-20240229", + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_openai + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.messages.create = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) @pytest.mark.asyncio async def test_generate(self, mock_anthropic: MagicMock) -> None: llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore @@ -71,7 +106,52 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: ] ) - def test_serialization(self, _: MagicMock) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "base_url": "https://api.anthropic.com", + "generation_kwargs": {}, + "max_retries": 6, + "model": "claude-3-opus-20240229", + "timeout": 600.0, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.anthropic", + "name": 
"AnthropicLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "base_url": "https://api.anthropic.com", + "generation_kwargs": {}, + "max_retries": 6, + "model": "claude-3-opus-20240229", + "timeout": 600.0, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.anthropic", + "name": "AnthropicLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: os.environ["ANTHROPIC_API_KEY"] = "api.key" llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore diff --git a/tests/unit/llms/test_azure.py b/tests/unit/llms/test_azure.py index a5208da95f..e8af5d7b8f 100644 --- a/tests/unit/llms/test_azure.py +++ b/tests/unit/llms/test_azure.py @@ -13,10 +13,14 @@ # limitations under the License. import os +from typing import Any, Dict from unittest import mock +import pytest from distilabel.llms.azure import AzureOpenAILLM +from .utils import DummyUserDetail + class TestAzureOpenAILLM: model_id: str = "gpt-4" @@ -56,20 +60,70 @@ def test_azure_openai_llm_env_vars(self) -> None: assert llm.api_key.get_secret_value() == "another.api.key" # type: ignore assert llm.api_version == self.api_version - def test_serialization(self) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "gpt-4", + "api_version": "preview", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://example-resource.azure.openai.com/", + "timeout": 120, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.azure", + "name": "AzureOpenAILLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "gpt-4", + "api_version": "preview", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://example-resource.azure.openai.com/", + "timeout": 120, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.azure", + "name": "AzureOpenAILLM", + }, + }, + ), + ], + ) + def test_serialization( + self, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: llm = AzureOpenAILLM( - model=self.model_id, base_url=self.base_url, api_version=self.api_version + model=self.model_id, + base_url=self.base_url, + api_version=self.api_version, + structured_output=structured_output, ) - _dump = { - "generation_kwargs": {}, - "model": "gpt-4", - "base_url": "https://example-resource.azure.openai.com/", - "max_retries": 6, - "timeout": 120, - "api_version": "preview", - "structured_output": None, - "type_info": {"module": "distilabel.llms.azure", "name": "AzureOpenAILLM"}, - } - assert llm.dump() == _dump - assert isinstance(AzureOpenAILLM.from_dict(_dump), AzureOpenAILLM) + # _dump = { + # "generation_kwargs": {}, + # "model": "gpt-4", + # "base_url": "https://example-resource.azure.openai.com/", + # "max_retries": 6, + # "timeout": 120, + # "api_version": "preview", + # "structured_output": None, + # "type_info": {"module": "distilabel.llms.azure", "name": "AzureOpenAILLM"}, + # } + assert llm.dump() == dump + assert isinstance(AzureOpenAILLM.from_dict(dump), AzureOpenAILLM) diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py index 0c2e2e213c..3cba9611d8 100644 
--- a/tests/unit/llms/test_cohere.py +++ b/tests/unit/llms/test_cohere.py @@ -13,12 +13,16 @@ # limitations under the License. import os +import sys +from typing import Any, Dict from unittest import mock import nest_asyncio import pytest from distilabel.llms.cohere import CohereLLM +from .utils import DummyUserDetail + @mock.patch("cohere.AsyncClient") class TestCohereLLM: @@ -64,6 +68,38 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: ] ) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) + @pytest.mark.asyncio + async def test_agenerate_structured( + self, mock_async_client: mock.MagicMock + ) -> None: + llm = CohereLLM( + model="command-r", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) + llm._aclient = mock_async_client # type: ignore + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat = mock.AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation == sample_user.model_dump_json() + @pytest.mark.asyncio async def test_generate(self, mock_async_client: mock.MagicMock) -> None: llm = CohereLLM(model="command-r") @@ -92,21 +128,53 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None: ] ) - def test_serialization(self, _: mock.MagicMock) -> None: - llm = CohereLLM(model="command-r") - - dump = { - "model": "command-r", - "generation_kwargs": {}, - "base_url": "https://api.cohere.ai/v1", - "timeout": 120, - "client_name": "distilabel", - "structured_output": None, - "type_info": { - "module": "distilabel.llms.cohere", - "name": "CohereLLM", - }, - } + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "command-r", + "generation_kwargs": {}, + "base_url": "https://api.cohere.ai/v1", + "timeout": 120, + "client_name": "distilabel", + "structured_output": None, + "type_info": { + "module": "distilabel.llms.cohere", + "name": "CohereLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "command-r", + "generation_kwargs": {}, + "base_url": "https://api.cohere.ai/v1", + "timeout": 120, + "client_name": "distilabel", + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.cohere", + "name": "CohereLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: mock.MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: + llm = CohereLLM(model="command-r", structured_output=structured_output) assert llm.dump() == dump assert isinstance(CohereLLM.from_dict(dump), CohereLLM) diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index e75166ce97..7607ab2cb2 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -13,12 +13,16 @@ # limitations under the License. 
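Outside the mocked tests, using the new `structured_output` argument might look like the sketch below, which adapts the test setup to a standalone call. The model name and prompt are illustrative, an `OPENAI_API_KEY` is assumed to be available in the environment, and the same `schema`/`mode`/`max_retries` dictionary applies to the other `instructor`-backed LLMs:

```python
import asyncio

from pydantic import BaseModel

from distilabel.llms.openai import OpenAILLM


class User(BaseModel):
    name: str
    age: int


llm = OpenAILLM(
    model="gpt-4",  # illustrative model choice
    structured_output={"schema": User, "mode": "tool_call", "max_retries": 1},
)
# `load()` creates the async client and patches it with `instructor`.
llm.load()

generations = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Extract: John Doe is 30 years old."}]
    )
)
# With `structured_output` set, the generation is the JSON dump of the schema instance.
print(generations[0])
```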
import os +import sys +from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio import pytest from distilabel.llms.groq import GroqLLM +from .utils import DummyUserDetail + @patch("groq._client.AsyncGroq") class TestGroqLLM: @@ -47,6 +51,37 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: ] ) == [" Aenean hendrerit aliquam velit. ..."] + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: + llm = GroqLLM( + model="llama3-70b-8192", + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_openai + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + @pytest.mark.asyncio async def test_generate(self, mock_groq: MagicMock) -> None: llm = GroqLLM(model="llama3-70b-8192", api_key="api.key") # type: ignore @@ -71,22 +106,54 @@ async def test_generate(self, mock_groq: MagicMock) -> None: ] ) == [(" Aenean hendrerit aliquam velit. ...",)] - def test_serialization(self, mock_groq: MagicMock) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "llama3-70b-8192", + "base_url": "https://api.groq.com", + "generation_kwargs": {}, + "max_retries": 2, + "timeout": 120, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.groq", + "name": "GroqLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "llama3-70b-8192", + "base_url": "https://api.groq.com", + "generation_kwargs": {}, + "max_retries": 2, + "timeout": 120, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.groq", + "name": "GroqLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: os.environ["GROQ_API_KEY"] = "api.key" - llm = GroqLLM(model="llama3-70b-8192") - - _dump = { - "model": "llama3-70b-8192", - "base_url": "https://api.groq.com", - "generation_kwargs": {}, - "max_retries": 2, - "timeout": 120, - "structured_output": None, - "type_info": { - "module": "distilabel.llms.groq", - "name": "GroqLLM", - }, - } + llm = GroqLLM(model="llama3-70b-8192", structured_output=structured_output) - assert llm.dump() == _dump - assert isinstance(GroqLLM.from_dict(_dump), GroqLLM) # type: ignore + assert llm.dump() == dump + assert isinstance(GroqLLM.from_dict(dump), GroqLLM) # type: ignore diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py index f31e903d3e..5bb2337481 100644 --- a/tests/unit/llms/test_mistral.py +++ b/tests/unit/llms/test_mistral.py @@ -14,11 +14,14 @@ import os import sys +from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio import pytest +from .utils import DummyUserDetail + try: from distilabel.llms.mistral import MistralLLM except ImportError: @@ 
-55,6 +58,37 @@ async def test_agenerate(self, mock_mistral: MagicMock) -> None: ] ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: + llm = MistralLLM( + model="mistral-tiny", + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_mistral + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) + # This should work just with the _aclient.chat method once it's fixed in instructor, and + # then in our code. + # llm._aclient.chat = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + @pytest.mark.asyncio async def test_generate(self, mock_mistral: MagicMock) -> None: llm = MistralLLM(model="mistral-tiny", api_key="api.key") # type: ignore @@ -79,7 +113,54 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: ] ) - def test_serialization(self, mock_mistral: MagicMock) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "mistral-tiny", + "endpoint": "https://api.mistral.ai", + "generation_kwargs": {}, + "max_retries": 6, + "timeout": 120, + "max_concurrent_requests": 64, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.mistral", + "name": "MistralLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "mistral-tiny", + "endpoint": "https://api.mistral.ai", + "generation_kwargs": {}, + "max_retries": 6, + "timeout": 120, + "max_concurrent_requests": 64, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.mistral", + "name": "MistralLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: os.environ["MISTRAL_API_KEY"] = "api.key" llm = MistralLLM(model="mistral-tiny") # type: ignore diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index 3562b6588b..7f90f513a2 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -13,6 +13,8 @@ # limitations under the License. 
import os +import sys +from typing import Any, Dict from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -20,6 +22,8 @@ import pytest from distilabel.llms.openai import OpenAILLM +from .utils import DummyUserDetail + @patch("openai.AsyncOpenAI") class TestOpenAILLM: @@ -63,6 +67,37 @@ async def test_agenerate(self, mock_openai: MagicMock) -> None: ] ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: + llm = OpenAILLM( + model=self.model_id, + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_openai + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) @pytest.mark.asyncio async def test_generate(self, mock_openai: MagicMock) -> None: llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore @@ -101,21 +136,53 @@ async def test_generate(self, mock_openai: MagicMock) -> None: response_format="unkown_format", ) - def test_serialization(self, _: MagicMock) -> None: - llm = OpenAILLM(model=self.model_id) - - _dump = { - "model": self.model_id, - "generation_kwargs": {}, - "max_retries": 6, - "base_url": "https://api.openai.com/v1", - "timeout": 120, - "structured_output": None, - "type_info": { - "module": "distilabel.llms.openai", - "name": "OpenAILLM", - }, - } - - assert llm.dump() == _dump - assert isinstance(OpenAILLM.from_dict(_dump), OpenAILLM) + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "gpt-4", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://api.openai.com/v1", + "timeout": 120, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.openai", + "name": "OpenAILLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "gpt-4", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://api.openai.com/v1", + "timeout": 120, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.openai", + "name": "OpenAILLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: + llm = OpenAILLM(model=self.model_id, structured_output=structured_output) + + assert llm.dump() == dump + assert isinstance(OpenAILLM.from_dict(dump), OpenAILLM) diff --git a/tests/unit/llms/utils.py b/tests/unit/llms/utils.py new file mode 100644 index 0000000000..7b899253bb --- /dev/null +++ b/tests/unit/llms/utils.py @@ -0,0 +1,20 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel + + +class DummyUserDetail(BaseModel): + name: str + age: int diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 60a624c9a0..8ae456a319 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -32,10 +32,19 @@ ) from distilabel.pipeline.local import Pipeline from distilabel.steps.base import GlobalStep, Step, StepInput +from distilabel.steps.typing import StepOutput from distilabel.utils.serialization import TYPE_INFO_KEY +from fsspec.implementations.local import LocalFileSystem from pydantic import Field - -from .utils import DummyGeneratorStep, DummyStep1, DummyStep2, batch_gen +from upath import UPath + +from .utils import ( + DummyGeneratorStep, + DummyGlobalStep, + DummyStep1, + DummyStep2, + batch_gen, +) if TYPE_CHECKING: from distilabel.steps.base import GeneratorStep @@ -70,6 +79,134 @@ def test_context_manager(self) -> None: assert _GlobalPipelineManager.get_pipeline() is None + @pytest.mark.parametrize("use_cache", [False, True]) + def test_load_batch_manager(self, use_cache: bool) -> None: + pipeline = BasePipeline(name="unit-test-pipeline") + pipeline._load_batch_manager(use_cache=True) + pipeline._cache() + + with mock.patch( + "distilabel.pipeline.base._BatchManager.load_from_cache" + ) as mock_load_from_cache, mock.patch( + "distilabel.pipeline.base._BatchManager.from_dag" + ) as mock_from_dag: + pipeline._load_batch_manager(use_cache=use_cache) + + if use_cache: + mock_load_from_cache.assert_called_once_with( + pipeline._cache_location["batch_manager"] + ) + mock_from_dag.assert_not_called() + else: + mock_load_from_cache.assert_not_called() + mock_from_dag.assert_called_once_with(pipeline.dag) + + def test_setup_write_buffer(self) -> None: + pipeline = BasePipeline(name="unit-test-pipeline") + + pipeline._setup_write_buffer() + assert isinstance(pipeline._write_buffer, _WriteBuffer) + + def test_set_logging_parameters(self) -> None: + pipeline = BasePipeline(name="unit-test-pipeline") + pipeline._set_logging_parameters({"unit-test": "yes"}) + + assert pipeline._logging_parameters == {"unit-test": "yes"} + + def test_setup_fsspec(self) -> None: + pipeline = BasePipeline(name="unit-test-pipeline") + + with mock.patch("fsspec.filesystem") as mock_filesystem: + pipeline._setup_fsspec({"path": "gcs://my-bucket", "extra": "stuff"}) + + mock_filesystem.assert_called_once_with("gcs", **{"extra": "stuff"}) + + def test_setup_fsspec_default(self) -> None: + pipeline = BasePipeline(name="unit-test-pipeline") + pipeline._setup_fsspec() + + assert isinstance(pipeline._fs, LocalFileSystem) + assert ( + pipeline._storage_base_path + == f"file://{pipeline._cache_location['batch_input_data']}" + ) + + def test_setup_fsspec_raises_value_error(self) -> None: + pipeline = BasePipeline(name="unit-test-pipeline") + + with pytest.raises(ValueError, match="The 'path' key must be present"): + pipeline._setup_fsspec({"key": "random"}) + + def test_send_batch_to_step(self) -> None: + with BasePipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = 
DummyStep1() + global_step = DummyGlobalStep() + + generator >> [step, global_step] + + pipeline._batch_manager = mock.MagicMock() + pipeline._setup_fsspec() + + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + pipeline._send_batch_to_step(batch) + pipeline._batch_manager.set_last_batch_sent.assert_called_once_with(batch) + + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) + + mock_write.assert_not_called() + + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore + ) + + mock_write.assert_called_once_with( + pipeline._fs, + UPath(pipeline._storage_base_path) / global_step.name, + ) + + pipeline._use_fs_to_pass_data = True + + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + ) + + mock_write.assert_not_called() + + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore + ) + + mock_write.assert_has_calls( + [ + mock.call( + pipeline._fs, + UPath(pipeline._storage_base_path) / step.name, + ), + mock.call( + pipeline._fs, + UPath(pipeline._storage_base_path) / global_step.name, + ), + ] + ) + def test_get_runtime_parameters_info(self) -> None: class DummyStep1(Step): runtime_param1: RuntimeParameter[str] = Field( @@ -180,8 +317,8 @@ class DummyStep1(Step): default=None, description="runtime_param2 description" ) - def process(self, inputs: StepInput) -> None: - pass + def process(self, inputs: StepInput) -> StepOutput: # type: ignore + yield [{}] class DummyStep2(Step): runtime_param3: RuntimeParameter[str] = Field( @@ -191,8 +328,8 @@ class DummyStep2(Step): default=None, description="runtime_param4 description" ) - def process(self, inputs: StepInput) -> None: - pass + def process(self, inputs: StepInput) -> StepOutput: # type: ignore + yield [{}] with BasePipeline(name="unit-test-pipeline") as pipeline: gen_step = DummyGeneratorStep(name="dummy_generator_step") @@ -205,7 +342,8 @@ def process(self, inputs: StepInput) -> None: if expected: assert expected in caplog.text else: - assert caplog.text == expected + assert "Did you mean any of:" not in caplog.text + assert "Available runtime parameters for the step" not in caplog.text def test_cache_dir_env_variable(self) -> None: with mock.patch.dict(os.environ, clear=True): @@ -2509,62 +2647,6 @@ def test_base_pipeline_signature(self): signature = pipeline._create_signature() assert signature == "a11ac46253598e6fe126420b23b9ad31c6422c92" - @pytest.mark.parametrize("use_cache", [True, False]) - def test_run_pipe_and_load_from_cache(self, use_cache: bool): - # Maybe not the best place for this test, but does the work for now - from distilabel.pipeline.base import BasePipeline - from distilabel.pipeline.routing_batch_function import sample_n_steps - - from tests.unit.pipeline.utils import DummyGeneratorStep, DummyStep1, DummyStep2 - - sample_two_steps = sample_n_steps(2) - - with 
tempfile.TemporaryDirectory() as tmpdirname: - with BasePipeline( - name="unit-test-pipeline", cache_dir=tmpdirname - ) as pipeline: - dummy_generator = DummyGeneratorStep() - dummy_step_1_0 = DummyStep1() - dummy_step_1_1 = DummyStep1() - dummy_step_1_2 = DummyStep1() - dummy_step_2 = DummyStep2() - - ( - dummy_generator - >> sample_two_steps - >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2] - >> dummy_step_2 - ) - - pipeline.run({}, use_cache=use_cache) - - assert not pipeline._cache_location["pipeline"].exists() - # Set the _BatchManager to the pipeline to check it exists afterwards - pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) - pipeline._cache() - - assert pipeline._cache_location["pipeline"].exists() - - with BasePipeline(name="unit-test-pipeline", cache_dir=tmpdirname) as pipe: - dummy_generator = DummyGeneratorStep() - dummy_step_1_0 = DummyStep1() - dummy_step_1_1 = DummyStep1() - dummy_step_1_2 = DummyStep1() - dummy_step_2 = DummyStep2() - - ( - dummy_generator - >> sample_two_steps - >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2] - >> dummy_step_2 - ) - - pipe.run({}, use_cache=use_cache) - if use_cache: - assert pipe._batch_manager - else: - assert not pipe._batch_manager - def test_binary_rshift_operator(self) -> None: # Tests the steps can be connected using the >> operator. from distilabel.pipeline.local import Pipeline diff --git a/tests/unit/pipeline/test_local.py b/tests/unit/pipeline/test_local.py index 3c4a15b534..511f8f5040 100644 --- a/tests/unit/pipeline/test_local.py +++ b/tests/unit/pipeline/test_local.py @@ -58,8 +58,7 @@ def test_send_batch_to_step(self, dummy_generator_step: "GeneratorStep") -> None ) pipeline._send_batch_to_step(batch=batch) # type: ignore - batch_manager_mock.set_last_batch_sent.assert_called_once_with(batch) - get_step_mock.assert_called_once_with(dummy_generator_step.name) + get_step_mock.assert_has_calls([mock.call(dummy_generator_step.name)]) input_queue.put.assert_called_once_with(batch) @mock.patch("distilabel.pipeline.local._ProcessWrapper") diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py index c816c8bcac..dbb8773923 100644 --- a/tests/unit/steps/argilla/test_base.py +++ b/tests/unit/steps/argilla/test_base.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import sys from typing import TYPE_CHECKING, List import pytest @@ -83,7 +84,9 @@ def test_with_errors(self, caplog) -> None: with pytest.raises( TypeError, - match="Can't instantiate abstract class Argilla with abstract methods inputs, process", + match="Can't instantiate abstract class Argilla with abstract methods inputs, process" + if sys.version_info < (3, 12) + else "Can't instantiate abstract class Argilla without an implementation for abstract methods 'inputs', 'process'", ): Argilla(name="step", pipeline=Pipeline(name="unit-test-pipeline")) # type: ignore diff --git a/tests/unit/steps/tasks/structured_outputs/test_utils.py b/tests/unit/steps/tasks/structured_outputs/test_utils.py new file mode 100644 index 0000000000..6238c8567f --- /dev/null +++ b/tests/unit/steps/tasks/structured_outputs/test_utils.py @@ -0,0 +1,75 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import List + +from distilabel.steps.tasks.structured_outputs.utils import json_schema_to_model +from pydantic import BaseModel, Field, StringConstraints, conint +from typing_extensions import Annotated + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = "black" + + +class KnowledgeGraph(BaseModel): + nodes: List[Node] = Field(..., default_factory=list) + edges: List[Edge] = Field(..., default_factory=list) + + +class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace = "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + +class Armor(str, Enum): + leather = "leather" + chainmail = "chainmail" + plate = "plate" + mithril = "mithril" + + +class Character(BaseModel): + name: Annotated[str, StringConstraints(max_length=30)] + age: conint(gt=1, lt=3000) + armor: Armor + weapon: Weapon + + +def test_json_schema_to_model(): + assert type(json_schema_to_model(Node.model_json_schema())) == type(Node) + + +def test_json_schema_to_model_with_enum(): + assert type(json_schema_to_model(Character.model_json_schema())) == type(Character) + + +def test_json_schema_to_model_nested(): + assert type(json_schema_to_model(KnowledgeGraph.model_json_schema())) == type( + KnowledgeGraph + ) diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index ed1fd956cf..0cccbd5c9d 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys from dataclasses import field from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -77,7 +78,9 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: with pytest.raises( TypeError, - match="Can't instantiate abstract class Task with abstract methods format_input, format_output", + match="Can't instantiate abstract class Task with abstract methods format_input, format_output" + if sys.version_info < (3, 12) + else "Can't instantiate abstract class Task without an implementation for abstract methods 'format_input', 'format_output'", ): Task(name="task", llm=DummyLLM()) # type: ignore diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index ecff0e1d90..d07ba464a3 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -53,6 +53,12 @@ def test_format_input_errors(self) -> None: name="task", llm=llm, pipeline=pipeline, use_system_prompt=True ) + with pytest.raises( + ValueError, + match=r"Providing \`instruction\` formatted as an OpenAI chat / conversation is deprecated", + ): + task.format_input({"instruction": [{"role": "user", "content": "test"}]}) + with pytest.raises( ValueError, match=r"Input \`instruction\` must be a string. Got: 1." 
): @@ -79,23 +85,6 @@ def test_process(self) -> None: } ] - def test_deprecation_warning(self) -> None: - pipeline = Pipeline(name="unit-test-pipeline") - llm = DummyLLM() - task = TextGeneration(name="task", llm=llm, pipeline=pipeline) - - with pytest.warns( - DeprecationWarning, - match=r"Providing \`instruction\` formatted as an OpenAI chat \/ conversation is about to be deprecated in \`distilabel v1.2.0\`", - ): - task.format_input( - { - "instruction": [ - {"role": "user", "content": "Tell me a joke."}, - ] - } - ) - class TestChatGeneration: def test_format_input(self) -> None: diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py index 18de3f8769..07e6549d7b 100644 --- a/tests/unit/test_distiset.py +++ b/tests/unit/test_distiset.py @@ -12,9 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import re +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional + import pytest +import yaml from datasets import Dataset, DatasetDict from distilabel.distiset import Distiset +from upath import UPath @pytest.fixture(scope="function") @@ -27,6 +35,24 @@ def distiset(): ) +def make_fake_file(filename: Path) -> None: + if not filename.parent.exists(): + filename.parent.mkdir(parents=True) + filename.touch() + + +def add_config_to_distiset(distiset: Distiset, folder: Path) -> Distiset: + from distilabel.distiset import DISTISET_CONFIG_FOLDER + + pipeline_yaml = folder / DISTISET_CONFIG_FOLDER / "pipeline.yaml" + pipeline_log = folder / DISTISET_CONFIG_FOLDER / "pipeline.log" + make_fake_file(pipeline_yaml) + make_fake_file(pipeline_log) + distiset.pipeline_path = pipeline_yaml + distiset.pipeline_log_path = pipeline_log + return distiset + + class TestDistiset: def test_train_test_split(self, distiset: Distiset) -> None: assert isinstance(distiset["leaf_step_1"], Dataset) @@ -34,3 +60,111 @@ def test_train_test_split(self, distiset: Distiset) -> None: assert isinstance(ds, Distiset) assert len(ds) == 2 assert isinstance(ds["leaf_step_1"], DatasetDict) + + @pytest.mark.parametrize("storage_options", [None, {"test": "option"}]) + @pytest.mark.parametrize("with_config", [False, True]) + def test_save_to_disk( + self, + distiset: Distiset, + with_config: bool, + storage_options: Optional[Dict[str, Any]], + ) -> None: + full_distiset = copy.deepcopy(distiset) + # Distiset with Distiset + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "distiset_folder" + if with_config: + full_distiset = add_config_to_distiset(full_distiset, folder) + + full_distiset.save_to_disk( + folder, + save_card=with_config, + save_pipeline_config=with_config, + save_pipeline_log=with_config, + storage_options=storage_options, + ) + assert folder.is_dir() + assert len(list(folder.iterdir())) == 3 + + full_distiset = copy.deepcopy(distiset) + # Distiset with DatasetDict + distiset_with_dict = full_distiset.train_test_split(0.8) + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "distiset_folder" + if with_config: + distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder) + + distiset_with_dict.save_to_disk( + folder, + save_card=with_config, + save_pipeline_config=with_config, + save_pipeline_log=with_config, + ) + + assert folder.is_dir() + assert len(list(folder.iterdir())) == 3 + + @pytest.mark.parametrize("pathlib_implementation", [Path, UPath]) + @pytest.mark.parametrize("storage_options", [None, {"project": "experiments"}]) + 
@pytest.mark.parametrize("with_config", [False, True]) + def test_load_from_disk( + self, + distiset: Distiset, + with_config: bool, + storage_options: Optional[Dict[str, Any]], + pathlib_implementation: type, + ) -> None: + full_distiset = copy.deepcopy(distiset) + # Distiset with Distiset + with tempfile.TemporaryDirectory() as tmpdirname: + # This way we also test that we work with UPath, using the FilePath protocol, as it should + # behave the same as S3Path, GCSPath, etc. + folder = pathlib_implementation(tmpdirname) / "distiset_folder" + if with_config: + full_distiset = add_config_to_distiset(full_distiset, folder) + full_distiset.save_to_disk( + folder, + save_card=with_config, + save_pipeline_config=with_config, + save_pipeline_log=with_config, + storage_options=storage_options, + ) + ds = Distiset.load_from_disk( + folder, + storage_options=storage_options, + ) + assert isinstance(ds, Distiset) + assert isinstance(ds["leaf_step_1"], Dataset) + + if with_config: + assert ds.pipeline_path.exists() + assert ds.log_filename_path.exists() + + full_distiset = copy.deepcopy(distiset) + # Distiset with DatasetDict + distiset_with_dict = full_distiset.train_test_split(0.8) + with tempfile.TemporaryDirectory() as tmpdirname: + folder = pathlib_implementation(tmpdirname) / "distiset_folder" + if with_config: + distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder) + + distiset_with_dict.save_to_disk(folder) + ds = Distiset.load_from_disk(folder, storage_options=storage_options) + + assert folder.is_dir() + assert isinstance(ds["leaf_step_1"], DatasetDict) + + if with_config: + assert ds.pipeline_path.exists() + assert ds.log_filename_path.exists() + + def test_dataset_card(self, distiset: Distiset) -> None: + # Test the metadata we generate by default, without extracting the already generated content from the HF Hub. + # We parse the content and check it's the same as the one we generate. + distiset_card = distiset._get_card("repo_name_or_path") + metadata = re.findall(r"---\n(.*?)\n---", str(distiset_card), re.DOTALL)[0] + metadata = yaml.safe_load(metadata) + assert metadata == { + "size_categories": "n<1K", + "tags": ["synthetic", "distilabel", "rlaif"], + }