diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml new file mode 100644 index 0000000000..6da0a611be --- /dev/null +++ b/.github/workflows/codspeed.yml @@ -0,0 +1,42 @@ +name: Benchmarks + +on: + push: + branches: + - "main" + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + benchmarks: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + # Looks like it's not working very well for other people: + # https://github.com/actions/setup-python/issues/436 + # cache: "pip" + # cache-dependency-path: pyproject.toml + + - uses: actions/cache@v3 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-benchmarks-v00 + + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: ./scripts/install_dependencies.sh + + - name: Run benchmarks + uses: CodSpeedHQ/action@v2 + with: + token: ${{ secrets.CODSPEED_TOKEN }} + run: pytest tests/ --codspeed diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 46e55a7e0c..c5abc04ca1 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -45,6 +45,10 @@ jobs: - run: mike deploy dev --push if: github.ref == 'refs/heads/develop' + env: + GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} - run: mike deploy ${{ github.ref_name }} latest --update-aliases --push if: startsWith(github.ref, 'refs/tags/') + env: + GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 01f1ebcb9a..b80e88d2e8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,12 @@ on: types: - opened - synchronize + workflow_dispatch: + inputs: + tmate_session: + description: Starts the workflow with tmate enabled. 
+ required: false + default: "false" concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -19,7 +25,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] fail-fast: false steps: @@ -42,14 +48,7 @@ jobs: - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' - run: | - python_version=$(python -c "import sys; print(sys.version_info[:2])") - - pip install -e .[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,vllm] - if [ "${python_version}" != "(3, 8)" ]; then - pip install -e .[mistralai] - fi; - pip install git+https://github.com/argilla-io/LLM-Blender.git + run: ./scripts/install_dependencies.sh - name: Lint run: make lint @@ -59,4 +58,3 @@ jobs: - name: Integration Tests run: make integration-tests - timeout-minutes: 5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 59ac8caa8c..24ab3d19a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,11 +11,10 @@ repos: - --fuzzy-match-generates-todo - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.1.4 + rev: v0.4.5 hooks: - id: ruff - args: - - --fix + args: [--fix] - id: ruff-format ci: diff --git a/Makefile b/Makefile index 9634d016f7..16cc0c92c7 100644 --- a/Makefile +++ b/Makefile @@ -2,12 +2,12 @@ sources = src/distilabel tests .PHONY: format format: - ruff --fix $(sources) + ruff check --fix $(sources) ruff format $(sources) .PHONY: lint lint: - ruff $(sources) + ruff check $(sources) ruff format --check $(sources) .PHONY: unit-tests diff --git a/README.md b/README.md index 4e071df69d..786b6ad523 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ Compute is expensive and output quality is important. We help you **focus on dat Synthesize and judge data with **latest research papers** while ensuring **flexibility, scalability and fault tolerance**. So you can focus on improving your data and training your models. -## 🏘️ Community +## Community We are an open-source community-driven project and we love to hear from you. Here are some ways to get involved: @@ -68,7 +68,7 @@ Distilabel is a tool that can be used to **synthesize data and provide AI feedba - Our [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) and the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B),, show how we **improve model performance by filtering out 50%** of the original dataset through **AI feedback**. - The [haiku DPO data](https://github.com/davanstrien/haiku-dpo) outlines how anyone can create a **dataset for a specific task** and **the latest research papers** to improve the quality of the dataset. 
-## 👨🏽‍💻 Installation +## Installation ```sh pip install distilabel --upgrade @@ -116,7 +116,7 @@ with Pipeline( generate_with_openai = TextGeneration(llm=OpenAILLM(model="gpt-3.5-turbo")) - load_dataset.connect(generate_with_openai) + load_dataset >> generate_with_openai if __name__ == "__main__": distiset = pipeline.run( @@ -153,3 +153,15 @@ If you build something cool with `distilabel` consider adding one of these badge To directly contribute with `distilabel`, check our [good first issues](https://github.com/argilla-io/distilabel/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) or [open a new one](https://github.com/argilla-io/distilabel/issues/new/choose). +## Citation + +```bibtex +@misc{distilabel-argilla-2024, + author = {Álvaro Bartolomé Del Canto and Gabriel Martín Blázquez and Agustín Piqueres Lajarín and Daniel Vila Suero}, + title = {Distilabel: An AI Feedback (AIF) framework for building datasets with and for LLMs}, + year = {2024}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/argilla-io/distilabel}} +} +``` diff --git a/docs/api/cli.md b/docs/api/cli.md index acdc2bef06..4757cb1470 100644 --- a/docs/api/cli.md +++ b/docs/api/cli.md @@ -1,6 +1,6 @@ # Command Line Interface (CLI) -This section contains the API reference for the CLI. For more information on how to use the CLI, see [Tutorial - CLI](../sections/learn/tutorial/cli/index.md). +This section contains the API reference for the CLI. For more information on how to use the CLI, see [Tutorial - CLI](../sections/how_to_guides/advanced/cli/index.md). ## Utility functions for the `distilabel pipeline` sub-commands diff --git a/docs/api/distiset.md b/docs/api/distiset.md new file mode 100644 index 0000000000..71b57c43ca --- /dev/null +++ b/docs/api/distiset.md @@ -0,0 +1,6 @@ +# Distiset + +This section contains the API reference for the Distiset. For more information on how to use the CLI, see [Tutorial - CLI](../sections/how_to_guides/advanced/distiset.md). + +:::distilabel.distiset.Distiset +:::distilabel.distiset.create_distiset diff --git a/docs/api/llm/cohere.md b/docs/api/llm/cohere.md new file mode 100644 index 0000000000..c7064b7a75 --- /dev/null +++ b/docs/api/llm/cohere.md @@ -0,0 +1,3 @@ +# CohereLLM + +::: distilabel.llms.cohere diff --git a/docs/api/llm/index.md b/docs/api/llm/index.md index 27628e034a..fe58a65384 100644 --- a/docs/api/llm/index.md +++ b/docs/api/llm/index.md @@ -2,6 +2,6 @@ This section contains the API reference for the `distilabel` LLMs, both for the [`LLM`][distilabel.llms.LLM] synchronous implementation, and for the [`AsyncLLM`][distilabel.llms.AsyncLLM] asynchronous one. -For more information and examples on how to use existing LLMs or create custom ones, please refer to [Tutorial - LLM](../../sections/learn/tutorial/llm/index.md). +For more information and examples on how to use existing LLMs or create custom ones, please refer to [Tutorial - LLM](../../sections/how_to_guides/basic/llm/index.md). ::: distilabel.llms.base diff --git a/docs/api/pipeline/index.md b/docs/api/pipeline/index.md index f727ecea81..40c40426a0 100644 --- a/docs/api/pipeline/index.md +++ b/docs/api/pipeline/index.md @@ -1,6 +1,6 @@ # Pipeline -This section contains the API reference for the `distilabel` pipelines. For an example on how to use the pipelines, see the [Tutorial - Pipeline](../../sections/learn/tutorial/pipeline/index.md). +This section contains the API reference for the `distilabel` pipelines. 
For an example on how to use the pipelines, see the [Tutorial - Pipeline](../../sections/how_to_guides/basic/pipeline/index.md). ::: distilabel.pipeline.base ::: distilabel.pipeline.local diff --git a/docs/api/step/decorator.md b/docs/api/step/decorator.md index 98e6c6eefe..73844910e5 100644 --- a/docs/api/step/decorator.md +++ b/docs/api/step/decorator.md @@ -2,6 +2,6 @@ This section contains the reference for the `@step` decorator, used to create new [`Step`][distilabel.steps.Step] subclasses without having to manually define the class. -For more information check the [Tutorial - Step](../../sections/learn/tutorial/step/index.md) page. +For more information check the [Tutorial - Step](../../sections/how_to_guides/basic/step/index.md) page. ::: distilabel.steps.decorator diff --git a/docs/api/step/generator_step.md b/docs/api/step/generator_step.md index 5055b51545..949202eefd 100644 --- a/docs/api/step/generator_step.md +++ b/docs/api/step/generator_step.md @@ -2,6 +2,6 @@ This section contains the API reference for the [`GeneratorStep`][distilabel.steps.base.GeneratorStep] class. -For more information and examples on how to use existing generator steps or create custom ones, please refer to [Tutorial - Step - GeneratorStep](../../sections/learn/tutorial/step/generator_step.md). +For more information and examples on how to use existing generator steps or create custom ones, please refer to [Tutorial - Step - GeneratorStep](../../sections/how_to_guides/basic/step/generator_step.md). ::: distilabel.steps.base.GeneratorStep diff --git a/docs/api/step/global_step.md b/docs/api/step/global_step.md index e2a4d9f8e1..c2fba5ac38 100644 --- a/docs/api/step/global_step.md +++ b/docs/api/step/global_step.md @@ -2,6 +2,6 @@ This section contains the API reference for the [`GlobalStep`][distilabel.steps.base.GlobalStep] class. -For more information and examples on how to use existing global steps or create custom ones, please refer to [Tutorial - Step - GlobalStep](../../sections/learn/tutorial/step/global_step.md). +For more information and examples on how to use existing global steps or create custom ones, please refer to [Tutorial - Step - GlobalStep](../../sections/how_to_guides/basic/step/global_step.md). ::: distilabel.steps.base.GlobalStep diff --git a/docs/api/step/index.md b/docs/api/step/index.md index ac5de3bdab..cc49224bb3 100644 --- a/docs/api/step/index.md +++ b/docs/api/step/index.md @@ -2,7 +2,7 @@ This section contains the API reference for the `distilabel` step, both for the [`_Step`][distilabel.steps.base._Step] base class and the [`Step`][distilabel.steps.Step] class. -For more information and examples on how to use existing steps or create custom ones, please refer to [Tutorial - Step](../../sections/learn/tutorial/step/index.md). +For more information and examples on how to use existing steps or create custom ones, please refer to [Tutorial - Step](../../sections/how_to_guides/basic/step/index.md). ::: distilabel.steps.base options: diff --git a/docs/api/step_gallery/columns.md b/docs/api/step_gallery/columns.md index 2d30080f09..80fa7adfee 100644 --- a/docs/api/step_gallery/columns.md +++ b/docs/api/step_gallery/columns.md @@ -1,6 +1,6 @@ # Columns -This section contains the existing steps intended to be used for commong column operations to apply to the batches. +This section contains the existing steps intended to be used for common column operations to apply to the batches. 
::: distilabel.steps.combine ::: distilabel.steps.expand diff --git a/docs/api/step_gallery/extra.md b/docs/api/step_gallery/extra.md index 4eecb44e5d..e310e45d4b 100644 --- a/docs/api/step_gallery/extra.md +++ b/docs/api/step_gallery/extra.md @@ -1,5 +1,6 @@ # Extra +::: distilabel.steps.generators.data ::: distilabel.steps.deita ::: distilabel.steps.formatting ::: distilabel.steps.typing diff --git a/docs/api/step_gallery/hugging_face.md b/docs/api/step_gallery/hugging_face.md new file mode 100644 index 0000000000..42fb85e795 --- /dev/null +++ b/docs/api/step_gallery/hugging_face.md @@ -0,0 +1,7 @@ +# Hugging Face + +This section contains the existing steps integrated with `Hugging Face` so as to easily push the generated datasets to Hugging Face. + +::: distilabel.steps.LoadDataFromDisk +::: distilabel.steps.LoadDataFromFileSystem +::: distilabel.steps.LoadDataFromHub diff --git a/docs/api/task/generator_task.md b/docs/api/task/generator_task.md index 309aba2a16..31748034df 100644 --- a/docs/api/task/generator_task.md +++ b/docs/api/task/generator_task.md @@ -2,6 +2,6 @@ This section contains the API reference for the `distilabel` generator tasks. -For more information on how the [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] works and see some examples, check the [Tutorial - Task - GeneratorTask](../../sections/learn/tutorial/task/generator_task.md) page. +For more information on how the [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] works and see some examples, check the [Tutorial - Task - GeneratorTask](../../sections/how_to_guides/basic/task/generator_task.md) page. ::: distilabel.steps.tasks.base.GeneratorTask diff --git a/docs/api/task/index.md b/docs/api/task/index.md index f280ab02bd..ee32b63aae 100644 --- a/docs/api/task/index.md +++ b/docs/api/task/index.md @@ -2,7 +2,7 @@ This section contains the API reference for the `distilabel` tasks. -For more information on how the [`Task`][distilabel.steps.tasks.Task] works and see some examples, check the [Tutorial - Task](../../sections/learn/tutorial/task/index.md) page. +For more information on how the [`Task`][distilabel.steps.tasks.Task] works and see some examples, check the [Tutorial - Task](../../sections/how_to_guides/basic/task/index.md) page. 
::: distilabel.steps.tasks.base options: diff --git a/docs/api/task/typing.md b/docs/api/task/typing.md new file mode 100644 index 0000000000..818ad070b6 --- /dev/null +++ b/docs/api/task/typing.md @@ -0,0 +1,3 @@ +# Task Typing + +::: distilabel.steps.tasks.typing \ No newline at end of file diff --git a/docs/api/task_gallery/index.md b/docs/api/task_gallery/index.md index 96d38ac958..4cf90c479d 100644 --- a/docs/api/task_gallery/index.md +++ b/docs/api/task_gallery/index.md @@ -9,3 +9,4 @@ This section contains the existing [`Task`][distilabel.steps.tasks.Task] subclas - "!_Task" - "!GeneratorTask" - "!ChatType" + - "!typing" \ No newline at end of file diff --git a/docs/assets/distilabel-badge-light.png b/docs/assets/distilabel-badge-light.png index dafd37d76f..3d9c7f7ee0 100644 Binary files a/docs/assets/distilabel-badge-light.png and b/docs/assets/distilabel-badge-light.png differ diff --git a/docs/assets/distilabel-black.svg b/docs/assets/distilabel-black.svg new file mode 100644 index 0000000000..02992ff20b --- /dev/null +++ b/docs/assets/distilabel-black.svg @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/assets/distilabel-white.svg b/docs/assets/distilabel-white.svg new file mode 100644 index 0000000000..f4bd5f10a8 --- /dev/null +++ b/docs/assets/distilabel-white.svg @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/assets/images/sections/examples/knowledge-graph-example.png b/docs/assets/images/sections/examples/knowledge-graph-example.png new file mode 100644 index 0000000000..05c80c6caf Binary files /dev/null and b/docs/assets/images/sections/examples/knowledge-graph-example.png differ diff --git a/docs/assets/images/sections/learn/steps/argilla/preference.png b/docs/assets/images/sections/how_to_guides/steps/argilla/preference.png similarity index 100% rename from docs/assets/images/sections/learn/steps/argilla/preference.png rename to docs/assets/images/sections/how_to_guides/steps/argilla/preference.png diff --git a/docs/assets/images/sections/learn/steps/argilla/text_generation.png b/docs/assets/images/sections/how_to_guides/steps/argilla/text_generation.png similarity index 100% rename from docs/assets/images/sections/learn/steps/argilla/text_generation.png rename to docs/assets/images/sections/how_to_guides/steps/argilla/text_generation.png diff --git a/docs/assets/images/sections/pipeline/pipeline-ctrlc.png b/docs/assets/images/sections/pipeline/pipeline-ctrlc.png deleted file mode 100644 index 33b5b171ae..0000000000 Binary files a/docs/assets/images/sections/pipeline/pipeline-ctrlc.png and /dev/null differ diff --git a/docs/index.md b/docs/index.md index 097d3b5c80..de4bc67be8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,15 +1,15 @@ --- description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs. -hide: - - toc +hide: + - navigation ---
- - Distilabel Logo + Distilabel Logo + Distilabel Logo
@@ -36,11 +36,9 @@ hide:

- Distilabel is the **framework for synthetic data and AI feedback for AI engineers** that require **high-quality outputs, full data ownership, and overall efficiency**. If you just want to get started, we recommend you check the [documentation](http://distilabel.argilla.io/). Curious, and want to know more? Keep reading! - ## Why use Distilabel? @@ -58,106 +56,10 @@ Compute is expensive and output quality is important. We help you **focus on dat Synthesize and judge data with **latest research papers** while ensuring **flexibility, scalability and fault tolerance**. So you can focus on improving your data and training your models. -## 🏘️ Community - -We are an open-source community-driven project and we love to hear from you. Here are some ways to get involved: - -- [Community Meetup](https://lu.ma/embed-checkout/evt-IQtRiSuXZCIW6FB): listen in or present during one of our bi-weekly events. - -- [Slack](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g): get direct support from the community. - -- [Roadmap](https://github.com/orgs/argilla-io/projects/10/views/1): plans change but we love to discuss those with our community so feel encouraged to participate. - ## What do people build with Distilabel? Distilabel is a tool that can be used to **synthesize data and provide AI feedback**. Our community uses Distilabel to create amazing [datasets](https://huggingface.co/datasets?other=distilabel) and [models](https://huggingface.co/models?other=distilabel), and **we love contributions to open-source** ourselves too. - The [1M OpenHermesPreference](https://huggingface.co/datasets/argilla/OpenHermesPreferences) is a dataset of ~1 million AI preferences derived from teknium/OpenHermes-2.5. It shows how we can use Distilabel to **synthesize data on an immense scale**. - Our [distilabeled Intel Orca DPO dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs) and the [improved OpenHermes model](https://huggingface.co/argilla/distilabeled-OpenHermes-2.5-Mistral-7B),, show how we **improve model performance by filtering out 50%** of the original dataset through **AI feedback**. -- The [haiku DPO data](https://github.com/davanstrien/haiku-dpo) outlines how anyone can create a **dataset for a specific task** and **the latest research papers** to improve the quality of the dataset. - -## 👨🏽‍💻 Installation - -```sh -pip install distilabel --upgrade -``` - -Requires Python 3.8+ - -In addition, the following extras are available: - -- `anthropic`: for using models available in [Anthropic API](https://www.anthropic.com/api) via the `AnthropicLLM` integration. -- `cohere`: for using models available in [Cohere](https://cohere.ai/) via the `CohereLLM` integration. -- `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/). -- `groq`: for using models available in [Groq](https://groq.com/) using [`groq`](https://github.com/groq/groq-python) Python client via the `GroqLLM` integration. -- `hf-inference-endpoints`: for using the [Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) via the `InferenceEndpointsLLM` integration. -- `hf-transformers`: for using models available in [transformers](https://github.com/huggingface/transformers) package via the `TransformersLLM` integration. -- `litellm`: for using [`LiteLLM`](https://github.com/BerriAI/litellm) to call any LLM using OpenAI format via the `LiteLLM` integration. 
-- `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) Python bindings for `llama.cpp` via the `LlamaCppLLM` integration. -- `mistralai`: for using models available in [Mistral AI API](https://mistral.ai/news/la-plateforme/) via the `MistralAILLM` integration. -- `ollama`: for using [Ollama](https://ollama.com/) and their available models via `OllamaLLM` integration. -- `openai`: for using [OpenAI API](https://openai.com/blog/openai-api) models via the `OpenAILLM` integration, or the rest of the integrations based on OpenAI and relying on its client as `AnyscaleLLM`, `AzureOpenAILLM`, and `TogetherLLM`. -- `vertexai`: for using [Google Vertex AI](https://cloud.google.com/vertex-ai) proprietary models via the `VertexAILLM` integration. -- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration. - -### Example - -To run the following example you must install `distilabel` with both `openai` extra: - -```sh -pip install "distilabel[openai]" --upgrade -``` - -Then run: - -```python -from distilabel.llms import OpenAILLM -from distilabel.pipeline import Pipeline -from distilabel.steps import LoadHubDataset -from distilabel.steps.tasks import TextGeneration - -with Pipeline( - name="simple-text-generation-pipeline", - description="A simple text generation pipeline", -) as pipeline: - load_dataset = LoadHubDataset(output_mappings={"prompt": "instruction"}) - - generate_with_openai = TextGeneration(llm=OpenAILLM(model="gpt-3.5-turbo")) - - load_dataset.connect(generate_with_openai) - -if __name__ == "__main__": - distiset = pipeline.run( - parameters={ - load_dataset.name: { - "repo_id": "distilabel-internal-testing/instruction-dataset-mini", - "split": "test", - }, - generate_with_openai.name: { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - }, - ) -``` - -## Badges - -If you build something cool with `distilabel` consider adding one of these badges to your dataset or model card. - - [Built with Distilabel](https://github.com/argilla-io/distilabel) - -[Built with Distilabel](https://github.com/argilla-io/distilabel) - - [Built with Distilabel](https://github.com/argilla-io/distilabel) - -[Built with Distilabel](https://github.com/argilla-io/distilabel) - -## Contribute - -To directly contribute with `distilabel`, check our [good first issues](https://github.com/argilla-io/distilabel/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) or [open a new one](https://github.com/argilla-io/distilabel/issues/new/choose). - +- The [haiku DPO data](https://github.com/davanstrien/haiku-dpo) outlines how anyone can create a **dataset for a specific task** and **the latest research papers** to improve the quality of the dataset. \ No newline at end of file diff --git a/docs/scripts/gen_popular_issues.py b/docs/scripts/gen_popular_issues.py new file mode 100644 index 0000000000..aff095214e --- /dev/null +++ b/docs/scripts/gen_popular_issues.py @@ -0,0 +1,180 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import datetime +from typing import List, Union + +import pandas as pd +import requests +import mkdocs_gen_files + + +REPOSITORY = "argilla-io/distilabel" +DATA_PATH = "sections/community/popular_issues.md" + +GITHUB_ACCESS_TOKEN = os.getenv( + "GH_ACCESS_TOKEN" +) # public_repo and read:org scopes are required + + +def fetch_issues_from_github_repository( + repository: str, auth_token: Union[str, None] = None +) -> pd.DataFrame: + if auth_token is None: + return pd.DataFrame( + { + "Issue": [], + "State": [], + "Created at": [], + "Closed at": [], + "Last update": [], + "Labels": [], + "Milestone": [], + "Reactions": [], + "Comments": [], + "URL": [], + "Repository": [], + "Author": [], + } + ) + + headers = { + "Authorization": f"token {auth_token}", + "Accept": "application/vnd.github.v3+json", + } + issues_data = [] + + print(f"Fetching issues from '{repository}'...") + with requests.Session() as session: + session.headers.update(headers) + + owner, repo_name = repository.split("/") + issues_url = ( + f"https://api.github.com/repos/{owner}/{repo_name}/issues?state=all" + ) + + while issues_url: + response = session.get(issues_url) + issues = response.json() + + for issue in issues: + issues_data.append( + { + "Issue": f"{issue['number']} - {issue['title']}", + "State": issue["state"], + "Created at": issue["created_at"], + "Closed at": issue.get("closed_at", None), + "Last update": issue["updated_at"], + "Labels": [label["name"] for label in issue["labels"]], + "Milestone": (issue.get("milestone") or {}).get("title"), + "Reactions": issue["reactions"]["total_count"], + "Comments": issue["comments"], + "URL": issue["html_url"], + "Repository": repo_name, + "Author": issue["user"]["login"], + } + ) + + issues_url = response.links.get("next", {}).get("url", None) + + return pd.DataFrame(issues_data) + + +def get_org_members(auth_token: Union[str, None] = None) -> List[str]: + if auth_token is None: + return [] + + headers = { + "Authorization": f"token {auth_token}", + "Accept": "application/vnd.github.v3+json", + } + members_list = [] + + members_url = "https://api.github.com/orgs/argilla-io/members" + + while members_url: + response = requests.get(members_url, headers=headers) + members = response.json() + + for member in members: + members_list.append(member["login"]) + + members_list.extend(["pre-commit-ci[bot]"]) + + members_url = response.links.get("next", {}).get("url", None) + + return members_list + + +with mkdocs_gen_files.open(DATA_PATH, "w") as f: + df = fetch_issues_from_github_repository(REPOSITORY, GITHUB_ACCESS_TOKEN) + + open_issues = df.loc[df["State"] == "open"] + engagement_df = ( + open_issues[["URL", "Issue", "Repository", "Reactions", "Comments"]] + .sort_values(by=["Reactions", "Comments"], ascending=False) + .head(10) + .reset_index() + ) + + members = get_org_members(GITHUB_ACCESS_TOKEN) + community_issues = df.loc[~df["Author"].isin(members)] + community_issues_df = ( + community_issues[ + ["URL", "Issue", "Repository", "Created at", "Author", "State"] + ] + .sort_values(by=["Created at"], ascending=False) + .head(10) + .reset_index() + ) + + planned_issues = df.loc[df["Milestone"].notna()] + planned_issues_df = ( + planned_issues[ + ["URL", "Issue", "Repository", "Created at", "Milestone", "State"] + ] + .sort_values(by=["Milestone"], ascending=False) + .head(10) + .reset_index() + ) + + f.write('=== "Most engaging open 
issues"\n\n') + f.write(" | Rank | Issue | Reactions | Comments |\n") + f.write(" |------|-------|:---------:|:--------:|\n") + for ix, row in engagement_df.iterrows(): + f.write( + f" | {ix+1} | [{row['Issue']}]({row['URL']}) | 👍 {row['Reactions']} | 💬 {row['Comments']} |\n" + ) + + f.write('\n=== "Latest issues open by the community"\n\n') + f.write(" | Rank | Issue | Author |\n") + f.write(" |------|-------|:------:|\n") + for ix, row in community_issues_df.iterrows(): + state = "🟢" if row["State"] == "open" else "🟣" + f.write( + f" | {ix+1} | {state} [{row['Issue']}]({row['URL']}) | by **{row['Author']}** |\n" + ) + + f.write('\n=== "Planned issues for upcoming releases"\n\n') + f.write(" | Rank | Issue | Milestone |\n") + f.write(" |------|-------|:------:|\n") + for ix, row in planned_issues_df.iterrows(): + state = "🟢" if row["State"] == "open" else "🟣" + f.write( + f" | {ix+1} | {state} [{row['Issue']}]({row['URL']}) | **{row['Milestone']}** |\n" + ) + + today = datetime.today().date() + f.write(f"\nLast update: {today}\n") diff --git a/docs/sections/community/index.md b/docs/sections/community/index.md new file mode 100644 index 0000000000..ed7f6cdd42 --- /dev/null +++ b/docs/sections/community/index.md @@ -0,0 +1,60 @@ +--- +hide: + - toc + - footer +--- + +We are an open-source community-driven project not only focused on building a great product but also on building a great community, where you can get support, share your experiences, and contribute to the project! We would love to hear from you and help you get started with distilabel. + +
+ +- __Slack__ + + --- + + In our Slack you can get direct support from the community. + + + [:octicons-arrow-right-24: Slack ↗](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) + +- __Community Meetup__ + + --- + + We host bi-weekly community meetups where you can listen in or present your work. + + [:octicons-arrow-right-24: Community Meetup ↗](https://lu.ma/argilla-event-calendar) + +- __Changelog__ + + --- + + The changelog is where you can find the latest updates and changes to the distilabel project. + + [:octicons-arrow-right-24: Changelog ↗](https://github.com/argilla-io/distilabel/releases) + +- __Roadmap__ + + --- + + We love to discuss our plans with the community. Feel encouraged to participate in our roadmap discussions. + + [:octicons-arrow-right-24: Roadmap ↗](https://github.com/orgs/argilla-io/projects/15) + +
+ +## Badges + +If you build something cool with `distilabel` consider adding one of these badges to your dataset or model card. + + [Built with Distilabel](https://github.com/argilla-io/distilabel) + +[Built with Distilabel](https://github.com/argilla-io/distilabel) + + [Built with Distilabel](https://github.com/argilla-io/distilabel) + +[Built with Distilabel](https://github.com/argilla-io/distilabel) + +## Contribute + +To directly contribute with `distilabel`, check our [good first issues](https://github.com/argilla-io/distilabel/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) or [open a new one](https://github.com/argilla-io/distilabel/issues/new/choose). \ No newline at end of file diff --git a/docs/sections/faq.md b/docs/sections/getting_started/faq.md similarity index 82% rename from docs/sections/faq.md rename to docs/sections/getting_started/faq.md index 58e7c13a07..27768a3c6f 100644 --- a/docs/sections/faq.md +++ b/docs/sections/getting_started/faq.md @@ -1,3 +1,9 @@ +--- +description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs. +hide: + - toc +--- + # Frequent Asked Questions (FAQ) ??? faq "How can I rename the columns in a batch?" @@ -30,6 +36,9 @@ All the data will be stored in `.cache/distilabel`, but the only data that will persist at the end of the `Pipeline.run` execution is the one from the leaf step/s, so bear that in mind. - For more information on the caching mechanism in `distilabel`, you can check the [Learn - Advanced - Caching](./learn/advanced/caching.md) section. + For more information on the caching mechanism in `distilabel`, you can check the [Learn - Advanced - Caching](../how_to_guides/advanced/caching.md) section. Also note that when running a [`Step`][distilabel.steps.base.Step] or a [`Task`][distilabel.steps.tasks.Task] standalone, the cache mechanism won't be used, so if you want to use that, you should use the `Pipeline` context manager. + +??? faq "How can I use the same `LLM` across several tasks without having to load it several times?" + You can serve the LLM using a solution like TGI or vLLM, and then connect to it using an `AsyncLLM` client like `InferenceEndpointsLLM` or `OpenAILLM`. Please refer to [Serving LLMs guide](../how_to_guides/advanced/serving_an_llm_for_reuse.md) for more information. diff --git a/docs/sections/installation.md b/docs/sections/getting_started/installation.md similarity index 97% rename from docs/sections/installation.md rename to docs/sections/getting_started/installation.md index 3244622762..07e473795a 100644 --- a/docs/sections/installation.md +++ b/docs/sections/getting_started/installation.md @@ -1,3 +1,9 @@ +--- +description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs. +hide: + - toc +--- + # Installation !!! NOTE diff --git a/docs/sections/how_to_guide.md b/docs/sections/getting_started/quickstart.md similarity index 79% rename from docs/sections/how_to_guide.md rename to docs/sections/getting_started/quickstart.md index f12b18a63a..f982ae319f 100644 --- a/docs/sections/how_to_guide.md +++ b/docs/sections/getting_started/quickstart.md @@ -1,20 +1,26 @@ -# How to Guide +--- +description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs. 
+hide: + - toc +--- + +# Quickstart To start off, `distilabel` is a framework for building pipelines for generating synthetic data using LLMs, that defines a [`Pipeline`][distilabel.pipeline.Pipeline] which orchestrates the execution of the [`Step`][distilabel.steps.base.Step] subclasses, and those will be connected as nodes in a Direct Acyclic Graph (DAG). -This being said, in this guide we will walk you through the process of creating a simple pipeline that uses the [`OpenAILLM`][distilabel.llms.OpenAILLM] class to generate text.å The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadHubDataset`][distilabel.steps.LoadHubDataset] and then use the [`OpenAILLM`][distilabel.llms.OpenAILLM] class to generate text based on the dataset using the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task. +That being said, in this guide we will walk you through the process of creating a simple pipeline that uses the [`OpenAILLM`][distilabel.llms.OpenAILLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`OpenAILLM`][distilabel.llms.OpenAILLM] class to generate text based on the dataset using the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task. ```python from distilabel.llms import OpenAILLM from distilabel.pipeline import Pipeline -from distilabel.steps import LoadHubDataset +from distilabel.steps import LoadDataFromHub from distilabel.steps.tasks import TextGeneration with Pipeline( # (1) name="simple-text-generation-pipeline", description="A simple text generation pipeline", ) as pipeline: # (2) - load_dataset = LoadHubDataset( # (3) + load_dataset = LoadDataFromHub( # (3) name="load_dataset", output_mappings={"prompt": "instruction"}, ) @@ -50,7 +56,7 @@ if __name__ == "__main__": 2. We are using the [`Pipeline`][distilabel.pipeline.Pipeline] context manager, meaning that every [`Step`][distilabel.steps.base.Step] subclass that is defined within the context manager will be added to the pipeline automatically. -3. We define a [`LoadHubDataset`][distilabel.steps.LoadHubDataset] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will basically produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field. +3. We define a [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will basically produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field. 4. We define a [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`OpenAILLM`][distilabel.llms.OpenAILLM] class with the model `gpt-3.5-turbo`. 
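As a short, hedged follow-up to the quickstart above (not part of the quickstart file itself): once `pipeline.run(...)` returns a `Distiset`, that object can be pushed to the Hugging Face Hub or saved to disk using the `Distiset` methods documented further down in this changeset. A minimal sketch, assuming the `distiset` variable from the quickstart snippet; the repository id is a made-up placeholder and the full argument list should be checked against `Distiset.push_to_hub`:

```python
import os

# `distiset` is the object returned by `pipeline.run(...)` in the quickstart snippet above.
# Each leaf step of the pipeline becomes a separate configuration (subset) of the dataset.
distiset.push_to_hub(
    "my-org/simple-text-generation",  # placeholder repository id
    private=True,
    token=os.getenv("HF_TOKEN"),
)

# Alternatively, keep a local copy together with the pipeline config, log and dataset card.
distiset.save_to_disk(
    "simple-text-generation",
    save_card=True,
    save_pipeline_config=True,
    save_pipeline_log=True,
)
```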
diff --git a/docs/sections/learn/advanced/argilla.md b/docs/sections/how_to_guides/advanced/argilla.md similarity index 70% rename from docs/sections/learn/advanced/argilla.md rename to docs/sections/how_to_guides/advanced/argilla.md index 23f9649ded..2d8c047960 100644 --- a/docs/sections/learn/advanced/argilla.md +++ b/docs/sections/how_to_guides/advanced/argilla.md @@ -1,20 +1,17 @@ -# Argilla +# Export data to Argilla -As an additional step, besides being able to restore the dataset from the [`Pipeline`][distilabel.pipeline.Pipeline] output as a [`Distiset`][distilabel.distiset.Distiset] (which is a `datasets.DatasetDict` with multiple configurations depending on the leaf nodes of the [`Pipeline`][distilabel.pipeline.Pipeline]), one can also include a [`Step`][distilabel.steps.Step] within the [`Pipeline`][distilabel.pipeline.Pipeline] to easily export the datasets to Argilla with a pre-defined configuration, suiting the annotation purposes. +Being able to export the generated synthetic datasets to Argilla, is a core feature within `distilabel`. We believe in the potential of synthetic data, but without removing the impact a human annotator or group of annotators can bring. So on, the Argilla integration makes it straightforward to push a dataset to Argilla while the [`Pipeline`][distilabel.pipeline.Pipeline] is running, to be able to follow along the generation process in Argilla's UI, as well as annotating the records on the fly. One can include a [`Step`][distilabel.steps.Step] within the [`Pipeline`][distilabel.pipeline.Pipeline] to easily export the datasets to Argilla with a pre-defined configuration, suiting the annotation purposes. -Being able to export the generated synthetic datasets to Argilla, was one of the core features we wanted to have integrated within `distilabel` because we believe in the potential of synthetic data, but without removing the impact a human annotator or group of annotators can bring. So on, the Argilla integration makes it straightforward to push a dataset to Argilla while the [`Pipeline`][distilabel.pipeline.Pipeline] is running, to be able to follow along the generation process in Argilla's UI, as well as annotating the records on the fly. - -Before using any of the steps about to be described below, you should first have an Argilla instance up and running, so that you can successfully upload the data to Argilla. In order to deploy Argilla, the easiest and most straight forward way is to deploy it via the [Argilla Template in Hugging Face Spaces](https://docs.argilla.io/en/latest/getting_started/installation/deployments/huggingface-spaces.html) as simply as following the steps there, or just via the following button: +Before using any of the steps about to be described below, you should first have an Argilla instance up and running, so that you can successfully upload the data to Argilla. In order to deploy Argilla, the easiest and most straightforward way is to deploy it via the [Argilla Template in Hugging Face Spaces](https://huggingface.co/docs/hub/en/spaces-sdks-docker-argilla) as simply as following the steps there, or just via the following button: -Additionally, Argilla offer multiple deployment options listed in the [Argilla Documentation - Installation](https://docs.argilla.io/en/latest/getting_started/installation/deployments/deployments.html) page. ### Text Generation -For text generation scenarios, i.e. 
when the [`Pipeline`][distilabel.pipeline.Pipeline] contains a [`TextGeneration`][distilabel.steps.tasks.TextGeneration] step, we have designed the task [`TextGenerationToArgilla`][distilabel.steps.TextGenerationToArgilla], which will seamlessly push the generated data to Argilla, and allow the annotator to review the records. +For text generation scenarios, i.e. when the [`Pipeline`][distilabel.pipeline.Pipeline] contains a single [`TextGeneration`][distilabel.steps.tasks.TextGeneration] step, we have designed the task [`TextGenerationToArgilla`][distilabel.steps.TextGenerationToArgilla], which will seamlessly push the generated data to Argilla, and allow the annotator to review the records. The dataset will be pushed with the following configuration: @@ -58,7 +55,7 @@ with Pipeline(name="my-pipeline") as pipeline: pipeline.run() ``` -![Text Generation to Argilla](../../../assets/images/sections/learn/steps/argilla/text_generation.png) +![Text Generation to Argilla](../../../assets/images/sections/how_to_guides/steps/argilla/text_generation.png) ### Preference @@ -74,7 +71,7 @@ The dataset will be pushed with the following configuration: The [`PreferenceToArgilla`][distilabel.steps.PreferenceToArgilla] step will only work if the [`Pipeline`][distilabel.pipeline.Pipeline] contains multiple [`TextGeneration`][distilabel.steps.tasks.TextGeneration] steps, or if the columns `instruction` and `generations` are available within the batch data. Otherwise, the variable `input_mappings` will need to be set so that either both or one of `instruction` and `generations` are mapped to one of the existing columns in the batch data. !!! NOTE - Additionally, if the [`Pipeline`][distilabel.pipeline.Pipeline] contains an [`UltraFeedback`][distilabel.steps.tasks.UltraFeedback] step, the `ratings` and `rationales` will also be available, so if that's the case, those will be automatically injected as suggestions to the existing dataset so that the annotator only needs to review those, instead of fulfilling those by themselves. + Additionally, if the [`Pipeline`][distilabel.pipeline.Pipeline] contains an [`UltraFeedback`][distilabel.steps.tasks.UltraFeedback] step, the `ratings` and `rationales` will also be available and be automatically injected as suggestions to the existing dataset. ```python from distilabel.llms import OpenAILLM @@ -109,10 +106,8 @@ with Pipeline(name="my-pipeline") as pipeline: load_dataset >> text_generation >> to_argilla -pipeline.run() +if __name__ == "__main__": + pipeline.run() ``` -![Preference to Argilla](../../../assets/images/sections/learn/steps/argilla/preference.png) - -!!! NOTE - If you are willing to also add the suggestions, feel free to check ["UltraFeedback: Boosting Language Models with High-quality Feedback"](../../examples/papers/ultrafeedback.md) where the [`UltraFeedback`][distilabel.steps.tasks.UltraFeedback] task is used to generate both ratings and rationales for each of the generations of a given instruction. 
+![Preference to Argilla](../../../assets/images/sections/how_to_guides/steps/argilla/preference.png) diff --git a/docs/sections/learn/advanced/caching.md b/docs/sections/how_to_guides/advanced/caching.md similarity index 98% rename from docs/sections/learn/advanced/caching.md rename to docs/sections/how_to_guides/advanced/caching.md index 16137f9a72..1fc9414940 100644 --- a/docs/sections/learn/advanced/caching.md +++ b/docs/sections/how_to_guides/advanced/caching.md @@ -1,6 +1,6 @@ -# Caching +# Cache and recover pipeline executions -Distilabel `Pipelines` automatically save all the intermediate steps to to avoid losing any data in case of error. +Distilabel `Pipelines` automatically save all the intermediate steps to avoid losing any data in case of error. ## Cache directory @@ -131,5 +131,5 @@ ds !!! Note Internally, the function will try to inject the `pipeline_path` variable if it's not passed via argument, assuming it's in the parent directory of the current one, called `pipeline.yaml`. If the file doesn't exist, it won't raise any error, but take into account that if the `Distiset` is pushed to the Hugging Face Hub, the `pipeline.yaml` won't be generated. The same happens with the `pipeline.log` file, it can be passed via `log_filename_path`, but it will try to locate it automatically. - + Lastly, there is the option of including the `distilabel_metadata` column in the final dataset. This column can contain custom metadata generated automatically by the pipeline, like the raw output from an `LLM` without formatting in case of failure, and we can decide whether to include it using the `enable_metadata` argument. diff --git a/docs/sections/learn/tutorial/cli/index.md b/docs/sections/how_to_guides/advanced/cli/index.md similarity index 100% rename from docs/sections/learn/tutorial/cli/index.md rename to docs/sections/how_to_guides/advanced/cli/index.md diff --git a/docs/sections/learn/advanced/distiset.md b/docs/sections/how_to_guides/advanced/distiset.md similarity index 54% rename from docs/sections/learn/advanced/distiset.md rename to docs/sections/how_to_guides/advanced/distiset.md index 17be198ae3..0893599c43 100644 --- a/docs/sections/learn/advanced/distiset.md +++ b/docs/sections/how_to_guides/advanced/distiset.md @@ -1,10 +1,10 @@ -# Distiset +# Using the Distiset dataset object -A [`Pipeline`][distilabel.pipeline.Pipeline] in `distilabel` returns a special type of Hugging Face [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict) which is called [`Distiset`][distilabel.distiset.Distiset], as a combination of `distilabel` and dataset. This object is a wrapper around [`datasets.Dataset`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset) which comes with some extra functionality to easily deal with the dataset pieces that a [`Pipeline`][distilabel.pipeline.Pipeline] can generate. +A [`Pipeline`][distilabel.pipeline.Pipeline] in `distilabel` returns a special type of Hugging Face [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict) which is called [`Distiset`][distilabel.distiset.Distiset]. -The [`Distiset`][distilabel.distiset.Distiset] is a dictionary-like object that contains the different configurations generated by the [`Pipeline`][distilabel.pipeline.Pipeline], where each configuration corresponds to each leaf step in the DAG built by the [`Pipeline`][distilabel.pipeline.Pipeline]. 
Each configuration corresponds to a different subset of the dataset, which is a concept taken from 🤗 `datasets` that lets you upload different configurations of the same dataset within the same repository and can contain different columns i.e. different configurations, which can be seamlessly pushed to the Hugging Face Hub straight away. +The [`Distiset`][distilabel.distiset.Distiset] is a dictionary-like object that contains the different configurations generated by the [`Pipeline`][distilabel.pipeline.Pipeline], where each configuration corresponds to each leaf step in the DAG built by the [`Pipeline`][distilabel.pipeline.Pipeline]. Each configuration corresponds to a different subset of the dataset. This is a concept taken from 🤗 `datasets` that lets you upload different configurations of the same dataset within the same repository and can contain different columns i.e. different configurations, which can be seamlessly pushed to the Hugging Face Hub. -Below you can find an example on how to create a [`Distiset`][distilabel.distiset.Distiset] object, similarly as a [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict), which is not required in `distilabel` since that's internally handled by the [`Pipeline`][distilabel.pipeline.Pipeline] as part of the output of the `run` method: +Below you can find an example of how to create a [`Distiset`][distilabel.distiset.Distiset] object that resembles a [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.DatasetDict): ```python from datasets import Dataset @@ -29,7 +29,7 @@ We can interact with the different pieces generated by the [`Pipeline`][distilab ### Train/Test split -Which easily does the train/test split partition of the dataset for the different configurations or subsets. +Create a train/test split partition of the dataset for the different configurations or subsets. ```python >>> distiset.train_test_split(train_size=0.9) @@ -59,7 +59,7 @@ Distiset({ ### Push to Hugging Face Hub -Pushes the [`Distiset`][distilabel.distiset.Distiset] to a Hugging Face repository, where each one of the subsets will correspond to a different configuration: +Push the [`Distiset`][distilabel.distiset.Distiset] to a Hugging Face repository, where each one of the subsets will correspond to a different configuration: ```python distiset.push_to_hub( @@ -70,6 +70,48 @@ distiset.push_to_hub( ) ``` +### Save and load from disk + +Take into account that these methods work as `datasets.load_from_disk` and `datasets.Dataset.save_to_disk` so the arguments are directly passed to those methods. This means you can also make use of `storage_options` argument to save your [`Distiset`][distilabel.distiset.Distiset] in your cloud provider, including the distilabel artifacts (`pipeline.yaml`, `pipeline.log` and the `README.md` with the dataset card). You can read more in `datasets` documentation [here](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets). 
+ +=== "Save to disk" + + Save the [`Distiset`][distilabel.distiset.Distiset] to disk, and optionally (will be done by default) saves the dataset card, the pipeline config file and logs: + + ```python + distiset.save_to_disk( + "my-dataset", + save_card=True, + save_pipeline_config=True, + save_pipeline_log=True + ) + ``` + +=== "Load from disk (local)" + + Load a [`Distiset`][distilabel.distiset.Distiset] that was saved using [`Distiset.save_to_disk`][distilabel.distiset.Distiset.save_to_disk] just the same way: + + ```python + distiset = Distiset.load_from_disk("my-dataset") + ``` + +=== "Load from disk (cloud)" + + Load a [`Distiset`][distilabel.distiset.Distiset] from a remote location, like S3, GCS. You can pass the `storage_options` argument to authenticate with the cloud provider: + + ```python + distiset = Distiset.load_from_disk( + "s3://path/to/my_dataset", # gcs:// or any filesystem tolerated by fsspec + storage_options={ + "key": os.environ["S3_ACCESS_KEY"], + "secret": os.environ["S3_SECRET_KEY"], + ... + } + ) + ``` + +Take a look at the remaining arguments at [`Distiset.save_to_disk`][distilabel.distiset.Distiset.save_to_disk] and [`Distiset.load_from_disk`][distilabel.distiset.Distiset.load_from_disk]. + ## Dataset card Having this special type of dataset comes with an added advantage when calling [`Distiset.push_to_hub`][distilabel.distiset.Distiset], which is the automatically generated dataset card in the Hugging Face Hub. Note that it is enabled by default, but can be disabled by setting `generate_card=False`: diff --git a/docs/sections/how_to_guides/advanced/fs_to_pass_data.md b/docs/sections/how_to_guides/advanced/fs_to_pass_data.md new file mode 100644 index 0000000000..178b3e5eac --- /dev/null +++ b/docs/sections/how_to_guides/advanced/fs_to_pass_data.md @@ -0,0 +1,34 @@ +# Using a file system to pass data of batches between steps + +In some situations, it can happen that the batches contains so much data that is faster to write it to disk and read it back in the next step, instead of passing it using the queue. To solve this issue, `distilabel` uses [`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/) to allow providing a file system configuration and whether if this file system should be used to pass data between steps in the `run` method of the `distilabel` pipelines: + +!!! WARNING + + In order to use a specific file system/cloud storage, you will need to install the specific package providing the `fsspec` implementation for that file system. For instance, to use Google Cloud Storage you will need to install `gcsfs`: + + ```bash + pip install gcsfs + ``` + + Check the available implementations: [fsspec - Other known implementations](https://filesystem-spec.readthedocs.io/en/latest/api.html#other-known-implementations) + +```python +from distilabel.pipeline import Pipeline + +with Pipeline(name="my-pipeline") as pipeline: + ... + +if __name__ == "__main__": + distiset = pipeline.run( + ..., + storage_parameters={"path": "gcs://my-bucket"}, + use_fs_to_pass_data=True + ) +``` + +The code above setups a file system (in this case Google Cloud Storage) and sets the flag `use_fs_to_pass_data` to specify that the data of the batches should be passed to the steps using the file system. The `storage_parameters` argument is optional, and in the case it's not provided but `use_fs_to_pass_data==True`, `distilabel` will use the local file system. + +!!! 
NOTE + + As `GlobalStep`s receives all the data from the previous steps in one single batch accumulating all the data, it's very likely that the data of the batch will be too big to be passed using the queue. In this case and even if `use_fs_to_pass_data==False`, `distilabel` will use the file system to pass the data to the `GlobalStep`. + diff --git a/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md b/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md new file mode 100644 index 0000000000..eedba82e3d --- /dev/null +++ b/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md @@ -0,0 +1,92 @@ +# Serving an `LLM` for sharing it between several `Task`s + +It's very common to want to use the same `LLM` for several `Task`s in a pipeline. To avoid loading the `LLM` as many times as the number of `Task`s and avoid wasting resources, it's recommended to serve the model using solutions like [`text-generation-inference`](https://huggingface.co/docs/text-generation-inference/quicktour#launching-tgi) or [`vLLM`](https://docs.vllm.ai/en/stable/serving/deploying_with_docker.html), and then use an `AsyncLLM` compatible client like `InferenceEndpointsLLM` or `OpenAILLM` to communicate with the server respectively. + +## Serving `meta-llama/Meta-Llama-3-8B-Instruct` using `text-generation-inference` + +```bash +model=meta-llama/Meta-Llama-3-8B-Instruct +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ + -e HUGGING_FACE_HUB_TOKEN= \ + ghcr.io/huggingface/text-generation-inference:2.0.4 \ + --model-id $model +``` + +!!! NOTE + + The bash command above has been copy-pasted from the official docs [text-generation-inference](https://huggingface.co/docs/text-generation-inference/quicktour#launching-tgi). Please refer to the official docs for more information. + +And then we can use `InferenceEndpointsLLM` with `base_url=http://localhost:8080` (pointing to our `TGI` local deployment): + +```python +from distilabel.llms import InferenceEndpointsLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromDicts +from distilabel.steps.tasks import TextGeneration, UltraFeedback + +with Pipeline(name="serving-llm") as pipeline: + load_data = LoadDataFromDicts( + data=[{"instruction": "Write a poem about the sun and moon."}] + ) + + # `base_url` points to the address of the `TGI` serving the LLM + llm = InferenceEndpointsLLM(base_url="http://192.168.1.138:8080") + + text_generation = TextGeneration( + llm=llm, + num_generations=3, + group_generations=True, + output_mappings={"generation": "generations"}, + ) + + ultrafeedback = UltraFeedback(aspect="overall-rating", llm=llm) + + load_data >> text_generation >> ultrafeedback +``` + + +## Serving `meta-llama/Meta-Llama-3-8B-Instruct` using `vLLM` + +```bash +docker run --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HUGGING_FACE_HUB_TOKEN=" \ + -p 8000:8000 \ + --ipc=host \ + vllm/vllm-openai:latest \ + --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +!!! NOTE + + The bash command above has been copy-pasted from the official docs [vLLM](https://docs.vllm.ai/en/stable/serving/deploying_with_docker.html). Please refer to the official docs for more information. 
+ +And then we can use `OpenAILLM` with `base_url=http://localhost:8000` (pointing to our `vLLM` local deployment): + +```python +from distilabel.llms import OpenAILLM +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromDicts +from distilabel.steps.tasks import TextGeneration, UltraFeedback + +with Pipeline(name="serving-llm") as pipeline: + load_data = LoadDataFromDicts( + data=[{"instruction": "Write a poem about the sun and moon."}] + ) + + # `base_url` points to the address of the `vLLM` serving the LLM + llm = OpenAILLM(base_url="http://192.168.1.138:8000", model="") + + text_generation = TextGeneration( + llm=llm, + num_generations=3, + group_generations=True, + output_mappings={"generation": "generations"}, + ) + + ultrafeedback = UltraFeedback(aspect="overall-rating", llm=llm) + + load_data >> text_generation >> ultrafeedback +``` diff --git a/docs/sections/learn/advanced/structured_generation.md b/docs/sections/how_to_guides/advanced/structured_generation.md similarity index 62% rename from docs/sections/learn/advanced/structured_generation.md rename to docs/sections/how_to_guides/advanced/structured_generation.md index c9d86c2816..ee7565272c 100644 --- a/docs/sections/learn/advanced/structured_generation.md +++ b/docs/sections/how_to_guides/advanced/structured_generation.md @@ -1,4 +1,4 @@ -# Structured Generation +# Structured data generation `Distilabel` has integrations with relevant libraries to generate structured text i.e. to guide the [`LLM`][distilabel.llms.LLM] towards the generation of structured outputs following a JSON schema, a regex, etc. @@ -8,6 +8,14 @@ The [`LLM`][distilabel.llms.LLM] has an argument named `structured_output`[^1] that determines how we can generate structured outputs with it, let's see an example using [`LlamaCppLLM`][distilabel.llms.LlamaCppLLM]. +!!! Note + + For `outlines` integration to work you may need to install the corresponding dependencies: + + ```bash + pip install distilabel[outlines] + ``` + ### JSON We will start with a JSON example, where we initially define a `pydantic.BaseModel` schema to guide the generation of the structured output. @@ -101,9 +109,9 @@ if match: These were some simple examples, but one can see the options this opens. -!!! NOTE +!!! Tip A full pipeline example can be seen in the following script: - [`examples/structured_generation_with_outlines.py`](../../examples/index.md/#structured_generation_with_outlines) + [`examples/structured_generation_with_outlines.py`](../../pipeline_samples/examples/#llama-cpp-with-outlines) [^1]: You can check the variable type by importing it from: @@ -119,14 +127,78 @@ These were some simple examples, but one can see the options this opens. curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf ``` +## Instructor + +When working with model providers behind an API, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. 
We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.llms.AsyncLLM], so you can work with the following LLMs: [`OpenAILLM`][distilabel.llms.OpenAILLM], [`AzureOpenAILLM`][distilabel.llms.AzureOpenAILLM], [`CohereLLM`][distilabel.llms.CohereLLM], [`GroqLLM`][distilabel.llms.GroqLLM], [`LiteLLM`][distilabel.llms.LiteLLM] and [`MistralLLM`][distilabel.llms.MistralLLM]. + +!!! Note + For `instructor` integration to work you may need to install the corresponding dependencies: + + ```bash + pip install distilabel[instructor] + ``` + +!!! Note + Take a look at [`InstructorStructuredOutputType`][distilabel.steps.tasks.typing.InstructorStructuredOutputType] to see the expected format + of the `structured_output` dict variable. + +The following is the same example you can see with `outlines`'s `JSON` section for comparison purposes. + +```python +from pydantic import BaseModel + +class User(BaseModel): + name: str + last_name: str + id: int +``` + +And then we provide that schema to the `structured_output` argument of the LLM: + +!!! Note + In this example we are using *open-mixtral-8x22b*, keep in mind not all the models work with the function calling functionality required for this example to work. + +```python +from distilabel.llms import MistralLLM + +llm = MistralLLM( + model="open-mixtral-8x22b", + structured_output={"schema": User} +) +llm.load() +``` + +And we are ready to pass our instructions as usual: + +```python +import json + +result = llm.generate( + [[{"role": "user", "content": "Create a user profile for the following marathon"}]], + max_new_tokens=256 +) + +data = json.loads(result[0][0]) +data +# {'name': 'John', 'last_name': 'Doe', 'id': 12345} +User(**data) +# User(name='John', last_name='Doe', id=12345) +``` + +We get back a Python dictionary (formatted as a string) that we can parse using `json.loads`, or validate it directly using the `User`, which is a `pydantic.BaseModel` instance. + +!!! Tip + A full pipeline example can be seen in the following script: + [`examples/structured_generation_with_instructor.py`](../../pipeline_samples/examples/index.md#mistralai-with-instructor) + ## OpenAI JSON OpenAI offers a [JSON Mode](https://platform.openai.com/docs/guides/text-generation/json-mode) to deal with structured output via their API, let's see how to make use of them. The JSON mode instructs the model to always return a JSON object following the instruction required. !!! WARNING - Bear in mind, that in order for this to work, you must instruct the model in some way to generate JSON, either in the `system message` or in the instruction, as can be seen in the [API reference](https://platform.openai.com/docs/guides/text-generation/json-mode). + Bear in mind, for this to work, you must instruct the model in some way to generate JSON, either in the `system message` or in the instruction, as can be seen in the [API reference](https://platform.openai.com/docs/guides/text-generation/json-mode). -Contrary to what we have via `outlines`, JSON mode will not guarantee the output matches any specific schema, only that it is valid and parses without errors. More information can be found the OpenAI documentation. +Contrary to what we have via `outlines`, JSON mode will not guarantee the output matches any specific schema, only that it is valid and parses without errors. More information can be found in the OpenAI documentation. 
Other than the reference to generating JSON, to ensure the model generates parseable JSON we can pass the argument `response_format="json"`[^3]: diff --git a/docs/sections/how_to_guides/basic/llm/index.md b/docs/sections/how_to_guides/basic/llm/index.md new file mode 100644 index 0000000000..4bd5f9de2b --- /dev/null +++ b/docs/sections/how_to_guides/basic/llm/index.md @@ -0,0 +1,145 @@ +# Define LLMs as local or remote models + +## Working with LLMs + +LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Task], but they can also be used standalone. + +```python +from distilabel.llms import OpenAILLM + +llm = OpenAILLM(model="gpt-4") +llm.load() + +llm.generate( + inputs=[ + [{"role": "user", "content": "What's the capital of Spain?"}], + ], +) +# "The capital of Spain is Madrid." +``` + +!!! NOTE + Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`. + +### Within a Task + +Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and the task will handle the rest. + +```python +from distilabel.llms import OpenAILLM +from distilabel.steps.tasks import TextGeneration + +llm = OpenAILLM(model="gpt-4") +task = TextGeneration(name="text_generation", llm=llm) + +task.load() + +next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}])) +# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid."}] +``` + +### Runtime Parameters + +LLMs can have runtime parameters, such as `generation_kwargs`, provided via the `Pipeline.run()` method using the `params` argument. + +!!! NOTE + Runtime parameters can differ between LLM subclasses, caused by the different functionalities offered by the LLM providers. + +```python +from distilabel.pipeline import Pipeline +from distilabel.llms import OpenAILLM +from distilabel.steps import LoadDataFromDicts +from distilabel.steps.tasks import TextGeneration + +with Pipeline(name="text-generation-pipeline") as pipeline: + load_dataset = LoadDataFromDicts( + name="load_dataset", + data=[{"instruction": "Write a short story about a dragon that saves a princess from a tower."}], + ) + + text_generation = TextGeneration( + name="text_generation", + llm=OpenAILLM(model="gpt-4"), + ) + + load_dataset >> text_generation + +if __name__ == "__main__": + pipeline.run( + parameters={ + text_generation.name: {"llm": {"generation_kwargs": {"temperature": 0.3}}}, + }, + ) +``` + +## Creating custom LLMs + +To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchronous or [`AsyncLLM`][distilabel.llms.AsyncLLM] for asynchronous LLMs. Implement the following methods: + +* `model_name`: A property containing the model's name. + +* `generate`: A method that takes a list of prompts and returns generated texts. + +* `agenerate`: A method that takes a single prompt and returns generated texts. This method is used within the `generate` method of the `AsyncLLM` class. +* +* (optional) `get_last_hidden_state`: is a method that will take a list of prompts and return a list of hidden states. This method is optional and will be used by some tasks such as the [`GenerateEmbeddings`][distilabel.steps.tasks.GenerateEmbeddings] task. 
+
+
+=== "Custom LLM"
+
+    ```python
+    from typing import Any, List
+
+    from pydantic import validate_call
+
+    from distilabel.llms import LLM
+    from distilabel.llms.typing import GenerateOutput, HiddenState
+    from distilabel.steps.tasks.typing import ChatType
+
+    class CustomLLM(LLM):
+        @property
+        def model_name(self) -> str:
+            return "my-model"
+
+        @validate_call
+        def generate(self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any) -> List[GenerateOutput]:
+            for _ in range(num_generations):
+                ...
+
+        def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]:
+            ...
+    ```
+
+=== "Custom AsyncLLM"
+
+    ```python
+    from typing import Any, List
+
+    from pydantic import validate_call
+
+    from distilabel.llms import AsyncLLM
+    from distilabel.llms.typing import GenerateOutput, HiddenState
+    from distilabel.steps.tasks.typing import ChatType
+
+    class CustomAsyncLLM(AsyncLLM):
+        @property
+        def model_name(self) -> str:
+            return "my-model"
+
+        @validate_call
+        async def agenerate(self, input: ChatType, num_generations: int = 1, **kwargs: Any) -> GenerateOutput:
+            for _ in range(num_generations):
+                ...
+
+        def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]:
+            ...
+    ```
+
+The keyword arguments of `generate` and `agenerate` (except `input` and `num_generations`) are considered `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method.
+
+!!! NOTE
+    The `validate_call` decorator is used to automatically coerce the arguments of `generate` and `agenerate` to the expected types, raising an error if the types are not correct. This is especially useful when providing a value for an argument of `generate` or `agenerate` from the CLI, since the CLI will always provide the arguments as strings.
+
+## Available LLMs
+
+[Our LLM gallery](/distilabel/components-gallery/llms/) shows a list of the available LLMs that can be used within the `distilabel` library.
\ No newline at end of file
diff --git a/docs/sections/how_to_guides/basic/pipeline/index.md b/docs/sections/how_to_guides/basic/pipeline/index.md
new file mode 100644
index 0000000000..f4abcfee63
--- /dev/null
+++ b/docs/sections/how_to_guides/basic/pipeline/index.md
@@ -0,0 +1,433 @@
+# Execute Steps and Tasks in a Pipeline
+
+## How to create a pipeline
+
+A [`Pipeline`][distilabel.pipeline.Pipeline] organises the Steps and Tasks in a sequence, where the output of one step is the input of the next one.
+A [`Pipeline`][distilabel.pipeline.Pipeline] should be created by making use of the context manager along with passing a **name**, and optionally a **description**.
+
+```python
+from distilabel.pipeline import Pipeline
+
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    ...
+```
+
+### Connecting steps with the `Step.connect` method
+
+Now, we can define the steps of our [`Pipeline`][distilabel.pipeline.Pipeline].
+
+!!! NOTE
+    Steps without predecessors (i.e. root steps) need to be [`GeneratorStep`][distilabel.steps.GeneratorStep]s such as [`LoadDataFromDicts`][distilabel.steps.LoadDataFromDicts] or [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]. After this, other steps can be defined.
+
+
+```python
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromHub
+
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    load_dataset = LoadDataFromHub(name="load_dataset")
+    ...
+``` + +Next, we will use `prompt` column from the dataset obtained through `LoadDataFromHub` and use several `LLM`s to execute a `TextGeneration` task. We will also use the `Task.connect()` method to connect the steps, so the output of one step is the input of the next one. + +!!! NOTE + The order of the execution of the steps will be determined by the connections of the steps. In this case, the `TextGeneration` tasks will be executed after the `LoadDataFromHub` step. + +```python +from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromHub +from distilabel.steps.tasks import TextGeneration + +with Pipeline("pipe-name", description="My first pipe") as pipeline: + load_dataset = LoadDataFromHub(name="load_dataset") + + for llm in ( + OpenAILLM(model="gpt-4-0125-preview"), + MistralLLM(model="mistral-large-2402"), + VertexAILLM(model="gemini-1.5-pro"), + ): + task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) + task.connect(load_dataset) + + ... +``` + +For each row of the dataset, the `TextGeneration` task will generate a text based on the `instruction` column and the `LLM` model, and store the result (a single string) in a new column called `generation`. Because we need to have the `response`s in the same column, we will add `CombineColumns` to combine them all in the same column as a list of strings. + +!!! NOTE + In this case, the `CombineColumns` tasks will be executed after all `TextGeneration` steps. + +```python +from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM +from distilabel.pipeline import Pipeline +from distilabel.steps import CombineColumns, LoadDataFromHub +from distilabel.steps.tasks import TextGeneration + +with Pipeline("pipe-name", description="My first pipe") as pipeline: + load_dataset = LoadDataFromHub(name="load_dataset") + + combine_generations = CombineColumns( + name="combine_generations", + columns=["generation", "model_name"], + output_columns=["generations", "model_names"], + ) + + for llm in ( + OpenAILLM(model="gpt-4-0125-preview"), + MistralLLM(model="mistral-large-2402"), + VertexAILLM(model="gemini-1.5-pro"), + ): + task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) + load_dataset.connect(task) + task.connect(combine_generations) +``` + +### Connecting steps with the `>>` operator + +Besides the `Step.connect` method: `step1.connect(step2)`, there's an alternative way by making use of the `>>` operator. We can connect steps in a more readable way, and it's also possible to connect multiple steps at once. + +=== "Step per step" + + Each call to `step1.connect(step2)` has been exchanged by `step1 >> step2` within the loop. 
+ + ```python + from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM + from distilabel.pipeline import Pipeline + from distilabel.steps import CombineColumns, LoadDataFromHub + from distilabel.steps.tasks import TextGeneration + + with Pipeline("pipe-name", description="My first pipe") as pipeline: + load_dataset = LoadDataFromHub(name="load_dataset") + + combine_generations = CombineColumns( + name="combine_generations", + columns=["generation", "model_name"], + output_columns=["generations", "model_names"], + ) + + for llm in ( + OpenAILLM(model="gpt-4-0125-preview"), + MistralLLM(model="mistral-large-2402"), + VertexAILLM(model="gemini-1.5-pro"), + ): + task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) + load_dataset >> task >> combine_generations + ``` + +=== "Multiple steps at once" + + Each task is first appended to a list, and then all the calls to connections are done in a single call. + + ```python + from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM + from distilabel.pipeline import Pipeline + from distilabel.steps import CombineColumns, LoadDataFromHub + from distilabel.steps.tasks import TextGeneration + + with Pipeline("pipe-name", description="My first pipe") as pipeline: + load_dataset = LoadDataFromHub(name="load_dataset") + + combine_generations = CombineColumns( + name="combine_generations", + columns=["generation", "model_name"], + output_columns=["generations", "model_names"], + ) + + tasks = [] + for llm in ( + OpenAILLM(model="gpt-4-0125-preview"), + MistralLLM(model="mistral-large-2402"), + VertexAILLM(model="gemini-1.5-pro"), + ): + tasks.append( + TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) + ) + + load_dataset >> tasks >> combine_generations + ``` + +### Routing batches to specific downstream steps + +In some pipelines, you may want to send batches from a single upstream step to specific downstream steps based on certain conditions. To achieve this, you can use a `routing_batch_function`. This function takes a list of downstream steps and returns a list of step names to which each batch should be routed. + +Let's update the example above to route the batches loaded by the `LoadDataFromHub` step to just 2 of the `TextGeneration` tasks. 
First, we will create our custom [`routing_batch_function`][distilabel.pipeline.routing_batch_function.routing_batch_function], and then we will update the pipeline to use it:
+
+```python
+import random
+from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.pipeline import Pipeline, routing_batch_function
+from distilabel.steps import CombineColumns, LoadDataFromHub
+from distilabel.steps.tasks import TextGeneration
+
+@routing_batch_function
+def sample_two_steps(steps: list[str]) -> list[str]:
+    return random.sample(steps, 2)
+
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    load_dataset = LoadDataFromHub(
+        name="load_dataset",
+        output_mappings={"prompt": "instruction"},
+    )
+
+    tasks = []
+    for llm in (
+        OpenAILLM(model="gpt-4-0125-preview"),
+        MistralLLM(model="mistral-large-2402"),
+        VertexAILLM(model="gemini-1.0-pro"),
+    ):
+        tasks.append(
+            TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
+        )
+
+    combine_generations = CombineColumns(
+        name="combine_generations",
+        columns=["generation", "model_name"],
+        output_columns=["generations", "model_names"],
+    )
+
+    load_dataset >> sample_two_steps >> tasks >> combine_generations
+```
+
+The `routing_batch_function` that we just built is a common one, so `distilabel` comes with a builtin function that can be used to achieve the same behavior:
+
+```python
+from distilabel.pipeline import sample_n_steps
+
+sample_two_steps = sample_n_steps(2)
+```
+
+## Running the pipeline
+
+### Pipeline.dry_run
+
+Before running the `Pipeline` we can check if the pipeline is valid using the `Pipeline.dry_run()` method. It takes the same parameters as the `run` method, which we will discuss in the following section, plus the `batch_size` we want the dry run to use (by default set to 1).
+
+```python
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    ...
+
+if __name__ == "__main__":
+    distiset = pipeline.dry_run(parameters=..., batch_size=1)
+```
+
+### Pipeline.run
+
+After testing, we can now execute the full `Pipeline` using the `Pipeline.run()` method.
+
+```python
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    ...
+
+if __name__ == "__main__":
+    distiset = pipeline.run(
+        parameters={
+            "load_dataset": {
+                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
+                "split": "test",
+            },
+            "text_generation_with_gpt-4-0125-preview": {
+                "llm": {
+                    "generation_kwargs": {
+                        "temperature": 0.7,
+                        "max_new_tokens": 512,
+                    }
+                }
+            },
+            "text_generation_with_mistral-large-2402": {
+                "llm": {
+                    "generation_kwargs": {
+                        "temperature": 0.7,
+                        "max_new_tokens": 512,
+                    }
+                }
+            },
+            "text_generation_with_gemini-1.0-pro": {
+                "llm": {
+                    "generation_kwargs": {
+                        "temperature": 0.7,
+                        "max_new_tokens": 512,
+                    }
+                }
+            },
+        },
+    )
+```
+
+But if we run the pipeline above, we will see that the `run` method will fail:
+
+```
+ValueError: Step 'text_generation_with_gpt-4-0125-preview' requires inputs ['instruction'], but only the inputs=['prompt', 'completion', 'meta'] are available, which means that the inputs=['instruction'] are missing or not available
+when the step gets to be executed in the pipeline. Please make sure previous steps to 'text_generation_with_gpt-4-0125-preview' are generating the required inputs.
+```
+
+This is because, before actually running the pipeline, we must ensure each step has the necessary input columns to be executed.
In this case, the `TextGeneration` task requires the `instruction` column, but the `LoadDataFromHub` step generates the `prompt` column. To solve this, we can use the `output_mappings` or `input_mappings` arguments of individual `Step`s to map columns from one step to another.
+
+```python
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    load_dataset = LoadDataFromHub(
+        name="load_dataset",
+        output_mappings={"prompt": "instruction"}
+    )
+
+    ...
+```
+
+If we execute the pipeline again, it will run successfully and we will have a `Distiset` with the outputs of all the leaf steps of the pipeline, which we can push to the Hugging Face Hub.
+
+```python
+if __name__ == "__main__":
+    distiset = pipeline.run(...)
+    distiset.push_to_hub("distilabel-internal-testing/instruction-dataset-mini-with-generations")
+```
+
+### Stopping the pipeline
+
+In case you want to stop the pipeline while it's running, you can press ++ctrl+c++ or ++cmd+c++ depending on your OS (or send a `SIGINT` to the main process), and the outputs will be stored in the cache. Pressing an additional time will force the pipeline to stop its execution, but this can lead to losing the generated outputs for certain batches.
+
+## Cache
+
+If for some reason the pipeline execution stops (for example by pressing `Ctrl+C`), the state of the pipeline and the outputs will be stored in the cache, so we can resume the pipeline execution from the point where it was stopped.
+
+If we want to force the pipeline to run again without using the cache, then we can use the `use_cache` argument of the `Pipeline.run()` method:
+
+```python
+if __name__ == "__main__":
+    distiset = pipeline.run(parameters={...}, use_cache=False)
+```
+
+!!! NOTE
+    For more information on caching, we refer the reader to the [caching](../../advanced/caching.md) section.
+
+## Adjusting the batch size for each step
+
+Memory issues can arise when processing large datasets or when using large models. To avoid this, we can adjust the `input_batch_size` argument of individual tasks. In the example below, each `TextGeneration` task will receive 5 dictionaries per batch, while the `LoadDataFromHub` step will send batches of 10 dictionaries:
+
+```python
+from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import CombineColumns, LoadDataFromHub
+from distilabel.steps.tasks import TextGeneration
+
+with Pipeline("pipe-name", description="My first pipe") as pipeline:
+    load_dataset = LoadDataFromHub(
+        name="load_dataset",
+        output_mappings={"prompt": "instruction"},
+        batch_size=10
+    )
+
+    for llm in (
+        OpenAILLM(model="gpt-4-0125-preview"),
+        MistralLLM(model="mistral-large-2402"),
+        VertexAILLM(model="gemini-1.5-pro"),
+    ):
+        task = TextGeneration(
+            name=f"text_generation_with_{llm.model_name}",
+            llm=llm,
+            input_batch_size=5,
+        )
+
+    ...
+```
+
+## Serializing the pipeline
+
+Sharing a pipeline with others is very easy, as we can serialize the pipeline object using the `save` method.
We can save the pipeline in different formats, such as `yaml` or `json`: + +=== "yaml" + ```python + if __name__ == "__main__": + pipeline.save("pipeline.yaml", format="yaml") + ``` + +=== "json" + ```python + if __name__ == "__main__": + pipeline.save("pipeline.json", format="json") + ``` + +To load the pipeline, we can use the `from_yaml` or `from_json` methods: + +=== "yaml" + ```python + pipeline = Pipeline.from_yaml("pipeline.yaml") + ``` + +=== "json" + ```python + pipeline = Pipeline.from_json("pipeline.json") + ``` + +Serializing the pipeline is very useful when we want to share the pipeline with others, or when we want to store the pipeline for future use. It can even be hosted online, so the pipeline can be executed directly using the [CLI](../../advanced/cli/index.md). + +## Fully working example + +To sum up, here is the full code of the pipeline we have created in this section. Note that you will need to change the name of the Hugging Face repository where the resulting will be pushed, set `OPENAI_API_KEY` environment variable, set `MISTRAL_API_KEY` and have `gcloud` installed and configured: + +??? Code + ```python + from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM + from distilabel.pipeline import Pipeline + from distilabel.steps import CombineColumns, LoadDataFromHub + from distilabel.steps.tasks import TextGeneration + + with Pipeline("pipe-name", description="My first pipe") as pipeline: + load_dataset = LoadDataFromHub( + name="load_dataset", + output_mappings={"prompt": "instruction"}, + ) + + combine_generations = CombineColumns( + name="combine_generations", + columns=["generation", "model_name"], + output_columns=["generations", "model_names"], + ) + + for llm in ( + OpenAILLM(model="gpt-4-0125-preview"), + MistralLLM(model="mistral-large-2402"), + VertexAILLM(model="gemini-1.0-pro"), + ): + task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) + load_dataset.connect(task) + task.connect(combine_generations) + + if __name__ == "__main__": + distiset = pipeline.run( + parameters={ + "load_dataset": { + "repo_id": "distilabel-internal-testing/instruction-dataset-mini", + "split": "test", + }, + "text_generation_with_gpt-4-0125-preview": { + "llm": { + "generation_kwargs": { + "temperature": 0.7, + "max_new_tokens": 512, + } + } + }, + "text_generation_with_mistral-large-2402": { + "llm": { + "generation_kwargs": { + "temperature": 0.7, + "max_new_tokens": 512, + } + } + }, + "text_generation_with_gemini-1.0-pro": { + "llm": { + "generation_kwargs": { + "temperature": 0.7, + "max_new_tokens": 512, + } + } + }, + }, + ) + distiset.push_to_hub( + "distilabel-internal-testing/instruction-dataset-mini-with-generations" + ) + ``` + diff --git a/docs/sections/how_to_guides/basic/step/generator_step.md b/docs/sections/how_to_guides/basic/step/generator_step.md new file mode 100644 index 0000000000..c5b665a82e --- /dev/null +++ b/docs/sections/how_to_guides/basic/step/generator_step.md @@ -0,0 +1,110 @@ +# GeneratorStep + +The [`GeneratorStep`][distilabel.steps.GeneratorStep] is a subclass of [`Step`][distilabel.steps.Step] that is intended to be used as the first step within a [`Pipeline`][distilabel.pipeline.Pipeline], because it doesn't require input and generates data that can be used by other steps. Alternatively, it can also be used as a standalone. 
+ +```python +from typing import List +from typing_extensions import override + +from distilabel.steps import GeneratorStep +from distilabel.steps.typing import GeneratorStepOutput + +class MyGeneratorStep(GeneratorStep): + instructions: List[str] + + @override + def process(self, offset: int = 0) -> GeneratorStepOutput: + if offset: + self.instructions = self.instructions[offset:] + + while self.instructions: + batch = [ + { + "instruction": instruction + } for instruction in self.instructions[: self.batch_size] + ] + self.instructions = self.instructions[self.batch_size :] + yield ( + batch, + True if len(self.instructions) == 0 else False, + ) + + @property + def outputs(self) -> List[str]: + return ["instruction"] +``` + +Then we can use it as follows: + +```python +step = MyGeneratorStep( + name="my-generator-step", + instructions=["Tell me a joke.", "Tell me a story."], + batch_size=1, +) +step.load() + +next(step.process(offset=0)) +# ([{'instruction': 'Tell me a joke.'}], False) +next(step.process(offset=1)) +# ([{'instruction': 'Tell me a story.'}], True) +``` + +!!! NOTE + The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution. + +## Defining custom GeneratorSteps + +We can define a custom generator step by creating a new subclass of the [`GeneratorStep`][distilabel.steps.GeneratorStep] and defining the following: + +- `outputs`: is a property that returns a list of strings with the names of the output fields. + +- `process`: is a method that yields output data and a boolean flag indicating whether that's the last batch to be generated. + +!!! NOTE + The default signature for the `process` method is `process(self, offset: int = 0) -> GeneratorStepOutput`. The argument `offset` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. + +!!! WARNING + For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. + +=== "Inherit from `GeneratorStep`" + + We can inherit from the `GeneratorStep` class and define the `outputs`, and `process` methods as follows: + + + ```python + from typing import List + from typing_extensions import override + + from distilabel.steps import GeneratorStep + from distilabel.steps.typing import GeneratorStepOutput + + class MyGeneratorStep(GeneratorStep): + instructions: List[str] + + @override + def process(self, offset: int = 0) -> GeneratorStepOutput: + ... + + @property + def outputs(self) -> List[str]: + ... + ``` + +=== "Using the `@step` decorator" + + The `@step` decorator will take care of the boilerplate code, and will allow to define the `outputs`, and `process` methods in a more straightforward way. 
One downside is that it won't let you access or set the `self` attributes, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`GeneratorStep`][distilabel.steps.GeneratorStep] subclass.
+
+    ```python
+    from distilabel.steps import step
+    from distilabel.steps.typing import GeneratorStepOutput
+
+    @step(outputs=[...], step_type="generator")
+    def CustomGeneratorStep(offset: int = 0) -> GeneratorStepOutput:
+        yield (
+            ...,
+            True if offset == 10 else False,
+        )
+
+    step = CustomGeneratorStep(name="my-step")
+    ```
\ No newline at end of file
diff --git a/docs/sections/how_to_guides/basic/step/global_step.md b/docs/sections/how_to_guides/basic/step/global_step.md
new file mode 100644
index 0000000000..c9044a87d2
--- /dev/null
+++ b/docs/sections/how_to_guides/basic/step/global_step.md
@@ -0,0 +1,67 @@
+# GlobalStep
+
+The [`GlobalStep`][distilabel.steps.GlobalStep] is a subclass of [`Step`][distilabel.steps.Step] that waits until all the input batches from the previous steps have been received before running, so it requires all of its predecessors to finish first. It is useful when a step needs the whole input dataset at once. Alternatively, it can also be used as a standalone.
+
+## Defining custom GlobalSteps
+
+We can define a custom step by creating a new subclass of the [`GlobalStep`][distilabel.steps.GlobalStep] and defining the following:
+
+- `inputs`: is a property that returns a list of strings with the names of the required input fields.
+
+- `outputs`: is a property that returns a list of strings with the names of the output fields.
+
+- `process`: is a method that receives the input data and returns the output data, and it should be a generator, meaning that it should `yield` the output data.
+
+!!! NOTE
+    The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`. The argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one.
+
+!!! WARNING
+    For the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail.
+
+=== "Inherit from `GlobalStep`"
+
+    We can inherit from the `GlobalStep` class and define the `inputs`, `outputs`, and `process` methods as follows:
+
+    ```python
+    from typing import List
+
+    from distilabel.steps import GlobalStep, StepInput
+    from distilabel.steps.typing import StepOutput
+
+    class CustomStep(GlobalStep):
+        @property
+        def inputs(self) -> List[str]:
+            ...
+
+        @property
+        def outputs(self) -> List[str]:
+            ...
+
+        def process(self, *inputs: StepInput) -> StepOutput:
+            for input in inputs:
+                for item in input:
+                    ...
+ yield item + + # When overridden (ideally under the `typing_extensions.override` decorator) + # @typing_extensions.override + # def process(self, inputs: StepInput) -> StepOutput: + # for input in inputs: + # ... + # yield inputs + ``` + +=== "Using the `@step` decorator" + + The `@step` decorator will take care of the boilerplate code, and will allow to define the `inputs`, `outputs`, and `process` methods in a more straightforward way. One downside is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclass. + + ```python + from distilabel.steps import StepInput, step + from distilabel.steps.typing import StepOutput + + @step(inputs=[...], outputs=[...], step_type="global") + def CustomStep(inputs: StepInput) -> StepOutput: + for input in inputs: + ... + yield inputs + + step = CustomStep(name="my-step") + ``` \ No newline at end of file diff --git a/docs/sections/how_to_guides/basic/step/index.md b/docs/sections/how_to_guides/basic/step/index.md new file mode 100644 index 0000000000..e3b19e5334 --- /dev/null +++ b/docs/sections/how_to_guides/basic/step/index.md @@ -0,0 +1,132 @@ +# Define Steps for your Pipeline + +## Working with Steps + +The [`Step`][distilabel.steps.Step] is intended to be used within the scope of a [`Pipeline`][distilabel.pipeline.Pipeline], which will orchestrate the different steps defined but can also be used standalone. + +Assuming that we have a [`Step`][distilabel.steps.Step] already defined as it follows: + +```python +class MyStep(Step): + @property + def inputs(self) -> List[str]: + return ["input_field"] + + @property + def outputs(self) -> List[str]: + return ["output_field"] + + def process(self, inputs: StepInput) -> "StepOutput": + for input in inputs: + input["output_field"] = input["input_field"] + yield inputs +``` + +Then we can use it as follows: + +```python +step = MyStep(name="my-step") +step.load() + +next(step.process([{"input_field": "value"}])) +# [{'input_field': 'value', 'output_field': 'value'}] +``` +!!! NOTE + The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution. + +### Arguments + +- `input_mappings`, is a dictionary that maps keys from the input dictionaries to the keys expected by the step. For example, if `input_mappings={"instruction": "prompt"}`, means that the input key `prompt` will be used as the key `instruction` for current step. + +- `output_mappings`, is a dictionary that can be used to map the outputs of the step to other names. For example, if `output_mappings={"conversation": "prompt"}`, means that output key `conversation` will be renamed to `prompt` for the next step. + +- `input_batch_size` (by default set to 50), is independent for every step and will determine how many input dictionaries will process at once. + +### Runtime parameters + +`Step`s can also have `RuntimeParameter`, which are parameters that can only used after the pipeline initialisation when calling the `Pipeline.run`. 
+ +```python +from distilabel.mixins.runtime_parameters import RuntimeParameter + +class Step(...): + input_batch_size: RuntimeParameter[PositiveInt] = Field( + default=DEFAULT_INPUT_BATCH_SIZE, + description="The number of rows that will contain the batches processed by the" + " step.", + ) +``` + +## Types of Steps + +There are two special types of [`Step`][distilabel.steps.Step] in `distilabel`: + +* [`GeneratorStep`][distilabel.steps.GeneratorStep]: is a step that only generates data, and it doesn't need any input data from previous steps and normally is the first node in a [`Pipeline`][distilabel.pipeline.Pipeline]. More information: [Components -> Step - GeneratorStep](./generator_step.md). + +* [`GlobalStep`][distilabel.steps.GlobalStep]: is a step with the standard interface i.e. receives inputs and generates outputs, but it processes all the data at once, and often is the final step in the [`Pipeline`][distilabel.pipeline.Pipeline]. The fact that a [`GlobalStep`][distilabel.steps.GlobalStep] requires the previous steps to finish before being able to start. More information: [Components - Step - GlobalStep](global_step.md). + +* [`Task`][distilabel.steps.tasks.Task], is essentially the same as a default [`Step`][distilabel.steps.Step], but it relies on an [`LLM`][distilabel.llms.LLM] as an attribute, and the `process` method will be in charge of calling that LLM. More information: [Components - Task](../task/index.md). + +## Defining custom Steps + +We can define a custom step by creating a new subclass of the [`Step`][distilabel.steps.Step] and defining the following: + +- `inputs`: is a property that returns a list of strings with the names of the required input fields. + +- `outputs`: is a property that returns a list of strings with the names of the output fields. + +- `process`: is a method that receives the input data and returns the output data, and it should be a generator, meaning that it should `yield` the output data. + +!!! NOTE + The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`. The argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. + +!!! WARNING + For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. + +=== "Inherit from `Step`" + + We can inherit from the `Step` class and define the `inputs`, `outputs`, and `process` methods as follows: + + ```python + from distilabel.steps import Step, StepInput + from distilabel.steps.typing import StepOutput + + class CustomStep(Step): + @property + def inputs(self) -> List[str]: + ... + + @property + def outputs(self) -> List[str]: + ... + + def process(self, *inputs: StepInput) -> StepOutput: + for input in inputs: + ... 
+ yield item + + # When overridden (ideally under the `typing_extensions.override` decorator) + # @typing_extensions.override + # def process(self, inputs: StepInput) -> StepOutput: + # for input in inputs: + # ... + # yield inputs + ``` + +=== "Using the `@step` decorator" + + The `@step` decorator will take care of the boilerplate code, and will allow to define the `inputs`, `outputs`, and `process` methods in a more straightforward way. One downside is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`Step`][distilabel.steps.Step] subclass. + + + ```python + from distilabel.steps import StepInput, step + from distilabel.steps.typing import StepOutput + + @step(inputs=[...], outputs=[...]) + def CustomStep(inputs: StepInput) -> StepOutput: + for input in inputs: + ... + yield inputs + + step = CustomStep(name="my-step") + ``` \ No newline at end of file diff --git a/docs/sections/learn/tutorial/task/generator_task.md b/docs/sections/how_to_guides/basic/task/generator_task.md similarity index 55% rename from docs/sections/learn/tutorial/task/generator_task.md rename to docs/sections/how_to_guides/basic/task/generator_task.md index c9943154af..040af877d9 100644 --- a/docs/sections/learn/tutorial/task/generator_task.md +++ b/docs/sections/how_to_guides/basic/task/generator_task.md @@ -1,15 +1,11 @@ # GeneratorTask -The [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] is a custom implementation of a [`Task`][distilabel.steps.tasks.Task], but based on [`GeneratorStep`][distilabel.steps.GeneratorStep]; which means that will essentially be similar to the standard [`Task`][distilabel.steps.tasks.Task], but without the need of providing any input data, as the data will be generated as part of the [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] execution. - -!!! WARNING - This task is still experimental and may be subject to changes in the future, since apparently it's not the most efficient way to generate data, but it's a good way to generate data on the fly without the need of providing any input data. - ## Working with GeneratorTasks -The subclasses of [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] are intended to be used within the scope of a [`Pipeline`][distilabel.pipeline.Pipeline], which will orchestrate the different tasks defined; but nonetheless, they can be used standalone if needed too. +The [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] is a custom implementation of a [`Task`][distilabel.steps.tasks.Task] based on the [`GeneratorStep`][distilabel.steps.GeneratorStep]. As with a [`Task`][distilabel.steps.tasks.Task], it is normally used within a [`Pipeline`][distilabel.pipeline.Pipeline] but can also be used standalone. -These tasks will basically expect no input data, but generate data as part of the `process` method of the parent class. Say you have a [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask] that generates text from a pre-defined instruction: +!!! WARNING + This task is still experimental and may be subject to changes in the future. 
```python from typing import Any, Dict, List, Union @@ -48,7 +44,7 @@ class MyCustomTask(GeneratorTask): return {"output_field": output} ``` -To then use it as: +We can then use it as follows: ```python task = MyCustomTask( @@ -66,21 +62,17 @@ next(task.process()) Most of the times you would need to override the default `process` method, as it's suited for the standard [`Task`][distilabel.steps.tasks.Task] and not for the [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask]. But within the context of the `process` function you can freely use the `llm` to generate data in any way. !!! NOTE - The `load` method needs to be called ALWAYS if using the tasks as standalone, otherwise, if the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used, there's no need to call that method, since it will be automatically called on `Pipeline.run`; but in any other case the method `load` needs to be called from the parent class e.g. a [`GeneratorTask`][distilabel.steps.tasks.Task] with an [`LLM`][distilabel.llms.LLM] will need to call `GeneratorTask.load` to load both the task and the LLM. + The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution. ## Defining custom GeneratorTasks -In order to define custom tasks, we need to inherit from the [`Task`][distilabel.steps.tasks.Task] class and implement the `format_input` and `format_output` methods, as well as setting the properties `inputs` and `outputs`, as for [`Step`][distilabel.steps.Step] subclasses. - -So on, the following will need to be defined: - -- `process`: is a method that generates the data based on the [`LLM`][distilabel.llms.LLM] and the `instruction` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that the `inputs` argument is not allowed in this function since this is not a [`Task`][distilabel.steps.tasks.Task] but a [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask], so no input data is expected; so the signature only expects the `offset` argument, which is used to keep track of the current iteration in the generator. +We can define a custom generator task by creating a new subclass of the [`GeneratorTask`][distilabel.steps.tasks.Task] and defining the following: -- `outputs`: is a property that returns a list of strings with the names of the output fields. Note that since all the [`Task`][distilabel.steps.tasks.Task] subclasses are designed to work with a single [`LLM`][distilabel.llms.LLM], this property should always include `model_name` as one of the outputs, since that's automatically injected from the LLM. +- `process`: is a method that generates the data based on the [`LLM`][distilabel.llms.LLM] and the `instruction` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that the `inputs` argument is not allowed in this function since this is a [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask]. The signature only expects the `offset` argument, which is used to keep track of the current iteration in the generator. -- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. 
with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output, since that's automatically injected from the LLM in the `process` method of the [`Task`][distilabel.steps.tasks.Task]. +- `outputs`: is a property that returns a list of strings with the names of the output fields, this property should always include `model_name` as one of the outputs since that's automatically injected from the LLM. -Once those methods have been implemented, the task can be used as any other task, and it will be able to generate text based on the input data. +- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output. ```python from typing import Any, Dict, List, Union diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md new file mode 100644 index 0000000000..70b118c3ea --- /dev/null +++ b/docs/sections/how_to_guides/basic/task/index.md @@ -0,0 +1,151 @@ +# Define Tasks that rely on LLMs + +## Working with Tasks + +The [`Task`][distilabel.steps.tasks.Task] is a special kind of [`Step`][distilabel.steps.Step] that includes the [`LLM`][distilabel.llms.LLM] as a mandatory argument. As with a [`Step`][distilabel.steps.Step], it is normally used within a [`Pipeline`][distilabel.pipeline.Pipeline] but can also be used standalone. + +For example, the most basic task is the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task, which generates text based on a given instruction. + +```python +from distilabel.llms import InferenceEndpointsLLM +from distilabel.steps.tasks import TextGeneration + +task = TextGeneration( + name="text-generation", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), +) +task.load() + +next(task.process([{"instruction": "What's the capital of Spain?"}])) +# [ +# { +# 'instruction': "What's the capital of Spain?", +# 'generation': 'The capital of Spain is Madrid.', +# 'distilabel_metadata': {'raw_output_text-generation': 'The capital of Spain is Madrid.'}, +# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' +# } +# ] +``` + +!!! NOTE + The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution. + +As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`. Additionally, it provides some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task. + +## Specifying the number of generations and grouping generations + +All the `Task`s have a `num_generations` attribute that allows defining the number of generations that we want to have per input. 
We can update the example above to generate 3 completions per input: + +```python +from distilabel.llms import InferenceEndpointsLLM +from distilabel.steps.tasks import TextGeneration + +task = TextGeneration( + name="text-generation", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + num_generations=3, +) +task.load() + +next(task.process([{"instruction": "What's the capital of Spain?"}])) +# [ +# { +# 'instruction': "What's the capital of Spain?", +# 'generation': 'The capital of Spain is Madrid.', +# 'distilabel_metadata': {'raw_output_text-generation': 'The capital of Spain is Madrid.'}, +# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' +# }, +# { +# 'instruction': "What's the capital of Spain?", +# 'generation': 'The capital of Spain is Madrid.', +# 'distilabel_metadata': {'raw_output_text-generation': 'The capital of Spain is Madrid.'}, +# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' +# }, +# { +# 'instruction': "What's the capital of Spain?", +# 'generation': 'The capital of Spain is Madrid.', +# 'distilabel_metadata': {'raw_output_text-generation': 'The capital of Spain is Madrid.'}, +# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' +# } +# ] +``` + +In addition, we might want to group the generations in a single output row as maybe one downstream step expects a single row with multiple generations. We can achieve this by setting the `group_generations` attribute to `True`: + +```python +from distilabel.llms import InferenceEndpointsLLM +from distilabel.steps.tasks import TextGeneration + +task = TextGeneration( + name="text-generation", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + num_generations=3, + group_generations=True +) +task.load() + +next(task.process([{"instruction": "What's the capital of Spain?"}])) +# [ +# { +# 'instruction': "What's the capital of Spain?", +# 'generation': ['The capital of Spain is Madrid.', 'The capital of Spain is Madrid.', 'The capital of Spain is Madrid.'], +# 'distilabel_metadata': [ +# {'raw_output_text-generation': 'The capital of Spain is Madrid.'}, +# {'raw_output_text-generation': 'The capital of Spain is Madrid.'}, +# {'raw_output_text-generation': 'The capital of Spain is Madrid.'} +# ], +# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' +# } +# ] +``` + +## Defining custom Tasks + +We can define a custom step by creating a new subclass of the [`Task`][distilabel.steps.tasks.Task] and defining the following: + +- `inputs`: is a property that returns a list of strings with the names of the required input fields. + +- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.steps.tasks.ChatType] following [the chat-completion OpenAI message formatting](https://platform.openai.com/docs/guides/text-generation). + +- `outputs`: is a property that returns a list of strings with the names of the output fields, this property should always include `model_name` as one of the outputs since that's automatically injected from the LLM. + +- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. 
Note that there's no need to include the `model_name` in the output. + +```python +from typing import Any, Dict, List, Union + +from distilabel.steps.tasks.base import Task +from distilabel.steps.tasks.typing import ChatType + + +class MyCustomTask(Task): + @property + def inputs(self) -> List[str]: + return ["input_field"] + + def format_input(self, input: Dict[str, Any]) -> ChatType: + return [ + { + "role": "user", + "content": input["input_field"], + }, + ] + + @property + def outputs(self) -> List[str]: + return ["output_field", "model_name"] + + def format_output( + self, output: Union[str, None], input: Dict[str, Any] + ) -> Dict[str, Any]: + return {"output_field": output} +``` diff --git a/docs/sections/how_to_guides/index.md b/docs/sections/how_to_guides/index.md new file mode 100644 index 0000000000..3d6cb3e82d --- /dev/null +++ b/docs/sections/how_to_guides/index.md @@ -0,0 +1,93 @@ +# How-to guides + +Welcome to the how-to guides section! Here you will find a collection of guides that will help you get started with Distilabel. We have divided the guides into two categories: basic and advanced. The basic guides will help you get started with the core concepts of Distilabel, while the advanced guides will help you explore more advanced features. + +## Basic + +
+ +- __Define Steps for your Pipeline__ + + --- + + Steps are the building blocks of your pipeline. They can be used to generate data, evaluate models, manipulate data, or any other general task. + + [:octicons-arrow-right-24: Define Steps](basic/step/index.md) + +- __Define Tasks that rely on LLMs__ + + --- + + Tasks are a specific type of step that rely on Language Models (LLMs) to generate data. + + [:octicons-arrow-right-24: Define Tasks](basic/task/index.md) + +- __Define LLMs as local or remote models__ + + --- + + LLMs are the core of your tasks. They are used to integrate with local models or remote APIs. + + [:octicons-arrow-right-24: Define LLMs](basic/llm/index.md) + +- __Execute Steps and Tasks in a Pipeline__ + + --- + + Pipeline is where you put all your steps and tasks together to create a workflow. + + [:octicons-arrow-right-24: Execute Pipeline](basic/pipeline/index.md) + +
+ +## Advanced + +
+- __Using the Distiset dataset object__
+
+    ---
+
+    Distiset is a dataset object based on the datasets library that can be used to store and manipulate data.
+
+    [:octicons-arrow-right-24: Distiset](advanced/distiset.md)
+
+- __Export data to Argilla__
+
+    ---
+
+    Argilla is a platform that can be used to store, search, and apply feedback to datasets.
+
+    [:octicons-arrow-right-24: Argilla](advanced/argilla.md)
+
+- __Using a file system to pass data of batches between steps__
+
+    ---
+
+    A file system can be used to pass data between steps in a pipeline.
+
+    [:octicons-arrow-right-24: File System](advanced/fs_to_pass_data.md)
+
+- __Using CLI to explore and re-run existing Pipelines__
+
+    ---
+
+    The CLI can be used to explore and re-run existing pipelines through the command line.
+
+    [:octicons-arrow-right-24: CLI](advanced/cli/index.md)
+
+- __Cache and recover pipeline executions__
+
+    ---
+
+    Caching can be used to recover pipeline executions to avoid losing data and precious LLM calls.
+
+    [:octicons-arrow-right-24: Caching](advanced/caching.md)
+
+- __Structured data generation__
+
+    ---
+
+    Structured data generation can be used to generate data with a specific structure like JSON, function calls, etc.
+
+    [:octicons-arrow-right-24: Structured Generation](advanced/structured_generation.md)
+
+</div>
\ No newline at end of file diff --git a/docs/sections/learn/advanced/index.md b/docs/sections/learn/advanced/index.md deleted file mode 100644 index eff42bbe24..0000000000 --- a/docs/sections/learn/advanced/index.md +++ /dev/null @@ -1,3 +0,0 @@ -# Advanced - -This subsection will cover the advanced components of `distilabel` which are either internal specifications on how `distilabel` works or components used to create more complex and robust pipelines. diff --git a/docs/sections/learn/index.md b/docs/sections/learn/index.md deleted file mode 100644 index 92e11b95fd..0000000000 --- a/docs/sections/learn/index.md +++ /dev/null @@ -1,3 +0,0 @@ -# Learn - -Here is a step by step guide to all the components of `distilabel` in a Tutorial form and a special section for more advanced topics. diff --git a/docs/sections/learn/tutorial/index.md b/docs/sections/learn/tutorial/index.md deleted file mode 100644 index 78bfd2629e..0000000000 --- a/docs/sections/learn/tutorial/index.md +++ /dev/null @@ -1,5 +0,0 @@ -# Tutorial - -`Distilabel` builds a [`Pipeline`][distilabel.pipeline.Pipeline] with steps that can be thought of as nodes in a graph, as the [`Pipeline`][distilabel.pipeline.Pipeline] will orchestrate the execution of the [`Step`][distilabel.steps.base.Step] subclasses, and those will be connected as nodes in a Direct Acyclic Graph (DAG). - -This guide can be considered a tutorial, which will guide you through the different components of `distilabel`. diff --git a/docs/sections/learn/tutorial/llm/index.md b/docs/sections/learn/tutorial/llm/index.md deleted file mode 100644 index 68a1a9b0d9..0000000000 --- a/docs/sections/learn/tutorial/llm/index.md +++ /dev/null @@ -1,164 +0,0 @@ -# LLM - -The LLMs are implemented as subclasses of either [`LLM`][distilabel.llms.LLM] or [`AsyncLLM`][distilabel.llms.AsyncLLM], and are only in charge of running the text generation for a given prompt or conversation. The LLMs are intended to be used together with the [`Task`][distilabel.steps.tasks.Task] and any of its subclasses, via the `llm` argument, this means that any of the implemented LLMs can be easily plugged seamlessly into any task. - -## Working with LLMs - -The subclasses of both [`LLM`][distilabel.llms.LLM] or [`AsyncLLM`][distilabel.llms.AsyncLLM] are intended to be used within the scope of a [`Task`][distilabel.steps.tasks.Task], since those are seamlessly integrated within the different tasks; but nonetheless, they can be used standalone if needed. - -```python -from distilabel.llms import OpenAILLM - -llm = OpenAILLM(model="gpt-4") -llm.load() - -llm.generate( - inputs=[ - [ - {"role": "user", "content": "What's the capital of Spain?"}, - ], - ], -) -# "The capital of Spain is Madrid." -``` - -!!! NOTE - The `load` method needs to be called ALWAYS if using the LLMs as standalone or as part of a task, otherwise, if the `Pipeline` context manager is used, there's no need to call that method, since it will be automatically called on `Pipeline.run`; but in any other case the method `load` needs to be called from the parent class e.g. a `Task` with an `LLM` will need to call `Task.load` to load both the task and the LLM. - -### Within a Task - -Now, in order to use the LLM within a [`Task`][distilabel.steps.tasks.Task], we need to pass it as an argument to the task, and the task will take care of the rest. 
- -```python -from distilabel.llms import OpenAILLM -from distilabel.steps.tasks import TextGeneration - - -llm = OpenAILLM(model="gpt-4") -task = TextGeneration(name="text_generation", llm=llm) - -task.load() - -next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}])) -# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid."}] -``` - -### Runtime Parameters - -Additionally, besides the runtime parameters that can / need to be provided to the [`Task`][distilabel.steps.tasks], the LLMs can also define their own runtime parameters such as the `generation_kwargs`, and those need to be provided within the `Pipeline.run` method via the argument `params`. - -!!! NOTE - Each LLM subclass may have its own runtime parameters and those can differ between the different implementations, as those are not aligned, since the LLM engines offer different functionalities. - -```python -from distilabel.pipeline import Pipeline -from distilabel.llms import OpenAILLM -from distilabel.steps import LoadDataFromDicts -from distilabel.steps.tasks import TextGeneration - - -with Pipeline(name="text-generation-pipeline") as pipeline: - load_dataset = LoadDataFromDicts( - name="load_dataset", - data=[ - { - "instruction": "Write a short story about a dragon that saves a princess from a tower.", - }, - ], - ) - - text_generation = TextGeneration( - name="text_generation", - llm=OpenAILLM(model="gpt-4"), - ) - - load_dataset >> text_generation - -if __name__ == "__main__": - pipeline.run( - parameters={ - text_generation.name: {"llm": {"generation_kwargs": {"temperature": 0.3}}}, - }, - ) -``` - -## Defining custom LLMs - -In order to define custom LLMs, one must subclass either [`LLM`][distilabel.llms.LLM] or [`AsyncLLM`][distilabel.llms.AsyncLLM], to define a synchronous or asynchronous LLM, respectively. - -One can either extend any of the existing LLMs to override the default behaviour if needed, but also to define a new one from scratch, that could be potentially contributed to the `distilabel` codebase. - -In order to define a new LLM, one must define the following methods: - -* `model_name`: is a property that contains the name of the model to be used, which means that it needs to be retrieved from the LLM using the LLM-specific approach i.e. for [`TransformersLLM`][distilabel.llms.TransformersLLM] the `model_name` will be the `model_name_or_path` provided as an argument, or in [`OpenAILLM`][distilabel.llms.OpenAILLM] the `model_name` will be the `model` provided as an argument. - -* `generate`: is a method that will take a list of prompts and return a list of generated texts. This method will be called by the [`Task`][distilabel.steps.tasks.Task] to generate the texts, so it's the most important method to define. This method will be implemented in the subclass of the [`LLM`][distilabel.llms.LLM] i.e. the synchronous LLM. - -* `agenerate`: is a method that will take a single prompt and return a list of generated texts, since the rest of the behaviour will be controlled by the `generate` method that cannot be overwritten when subclassing [`AsyncLLM`][distilabel.llms.AsyncLLM]. This method will be called by the [`Task`][distilabel.steps.tasks.Task] to generate the texts, so it's the most important method to define. This method will be implemented in the subclass of the [`AsyncLLM`][distilabel.llms.AsyncLLM] i.e. the asynchronous LLM. - -* (optional) `get_last_hidden_state`: is a method that will take a list of prompts and return a list of hidden states. 
This method is optional and will be used by some tasks such as the [`GenerateEmbeddings`][distilabel.steps.tasks.GenerateEmbeddings] task. - -Once those methods have been implemented, then the custom LLM will be ready to be integrated within either any of the existing or a new task. - -```python -from typing import Any - -from pydantic import validate_call - -from distilabel.llms import AsyncLLM, LLM -from distilabel.llms.typing import GenerateOutput, HiddenState -from distilabel.steps.tasks.typing import ChatType - - -class CustomLLM(LLM): - @property - def model_name(self) -> str: - return "my-model" - - @validate_call - def generate(self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any) -> List[GenerateOutput]: - for _ in range(num_generations): - ... - - def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]: - ... - - -class CustomAsyncLLM(AsyncLLM): - @property - def model_name(self) -> str: - return "my-model" - - @validate_call - async def agenerate(self, input: ChatType, num_generations: int = 1, **kwargs: Any) -> GenerateOutput: - for _ in range(num_generations): - ... - - def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]: - ... -``` - -`generate` and `agenerate` keyword arguments (but `input` and `num_generations`) are considered as `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method. - -!!! NOTE - To have the arguments of the `generate` and `agenerate` coerced to the expected types, the `validate_call` decorator is used, which will automatically coerce the arguments to the expected types, and raise an error if the types are not correct. This is specially useful when providing a value for an argument of `generate` or `agenerate` from the CLI, since the CLI will always provide the arguments as strings. - -## Available LLMs - -Here's a list with the available LLMs that can be used within the `distilabel` library: - -* [AnthropicLLM][distilabel.llms.AnthropicLLM] -* [AnyscaleLLM][distilabel.llms.AnyscaleLLM] -* [AzureOpenAILLM][distilabel.llms.AzureOpenAILLM] -* [CohereLLM][distilabel.llms.CohereLLM] -* [GroqLLM][distilabel.llms.GroqLLM] -* [InferenceEndpointsLLM][distilabel.llms.huggingface.InferenceEndpointsLLM] -* [LiteLLM][distilabel.llms.LiteLLM] -* [LlamaCppLLM][distilabel.llms.LlamaCppLLM] -* [MistralLLM][distilabel.llms.MistralLLM] -* [OllamaLLM][distilabel.llms.OllamaLLM] -* [OpenAILLM][distilabel.llms.OpenAILLM] -* [TogetherLLM][distilabel.llms.TogetherLLM] -* [TransformersLLM][distilabel.llms.huggingface.TransformersLLM] -* [VertexAILLM][distilabel.llms.VertexAILLM] -* [vLLM][distilabel.llms.vLLM] diff --git a/docs/sections/learn/tutorial/pipeline/index.md b/docs/sections/learn/tutorial/pipeline/index.md deleted file mode 100644 index cd4720d5cb..0000000000 --- a/docs/sections/learn/tutorial/pipeline/index.md +++ /dev/null @@ -1,459 +0,0 @@ -# Pipeline - -The [`Pipeline`][distilabel.pipeline.Pipeline] is the central point in `distilabel`, the way to organize the steps to create your datasets. Up to this point we've seen how we can define different [`Step`][distilabel.steps.Step] and [`Task`][distilabel.steps.tasks.Task] subclasses in [Tutorial - Step](../step/index.md) and [Tutorial - Task](../task/index.md), respectively; which together with an [`LLM`][distilabel.llms.LLM] are the building blocks of our datasets, in this section we will take a look at how all these blocks are organized inside a [`Pipeline`][distilabel.pipeline.Pipeline]. - -!!! 
Note - Currently `distilabel` implements a *local* version of a [`Pipeline`][distilabel.pipeline.Pipeline], and will assume that's the only definition, but this can be extended in the future to include remote execution of the [`Pipeline`][distilabel.pipeline.Pipeline]. - -## How to create a pipeline - -The most common way a [`Pipeline`][distilabel.pipeline.Pipeline] should be created is by making use of the context manager, we just need to give our [`Pipeline`][distilabel.pipeline.Pipeline] a **name**, and optionally a **description**, and that's it[^1]: - -```python -from distilabel.pipeline import Pipeline - -with Pipeline("pipe-name", description="My first pipe") as pipeline: # (1) - ... - -``` - -1. Use the context manager to create a [`Pipeline`][distilabel.pipeline.Pipeline] with a name and an optional description. - -This way, we ensure all the steps we define there are connected with each other under the same [`Pipeline`][distilabel.pipeline.Pipeline]. The next step is to define the steps of our [`Pipeline`][distilabel.pipeline.Pipeline]. It's mandatory that the root steps of the pipeline i.e. the ones that doesn't have any predecessors, are [`GeneratorStep`][distilabel.steps.GeneratorStep]s such as [`LoadDataFromDicts`][distilabel.steps.LoadDataFromDicts] or [`LoadHubDataset`][distilabel.steps.LoadHubDataset]. - -```python -from distilabel.pipeline import Pipeline -from distilabel.steps import LoadHubDataset - -with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset(name="load_dataset") # (1) - ... - -``` - -1. Define the first step of the pipeline, in this case `LoadHubDataset`, a `GeneratorStep` used to load a dataset from the Hugging Face Hub. - -Once we have a source of data, we can create another [`Step`][distilabel.steps.Step]s that will consume and process the data generated by the `GeneratorStep`s. Let's assume that the dataset we're going to load from the Hub contains a `prompt` column and that we want to generate texts based on this prompt. We also want to use several `LLM`s for this task. To do so, we will create several `TextGeneration` tasks, each with a different `LLM`. - -```python -from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM -from distilabel.pipeline import Pipeline -from distilabel.steps import LoadHubDataset -from distilabel.steps.tasks import TextGeneration - -with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset(name="load_dataset") - - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.5-pro"), - ): - task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) # (1) - task.connect(load_dataset) # (2) - - ... -``` - -1. Create a `TextGeneration` task for each `LLM` we want to use. -2. Connect the `TextGeneration` task with the `LoadHubDataset` step, so the output data from the dataset is passed to the task. - -!!! NOTE - The order of the execution of the steps will be determined by the connections of the steps. In this case, the `TextGeneration` tasks will be executed after the `LoadHubDataset` step. - -For each row of the dataset, the `TextGeneration` task will generate a text based on the `instruction` column and the `LLM` model, and store the result (a single string) in a new column called `generation`. 
As we would like to have all the `response`s in the same column, we will add an extra step to combine them all in the same column, so the value of this column is a list of strings or responses. - -```python -from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM -from distilabel.pipeline import Pipeline -from distilabel.steps import CombineColumns, LoadHubDataset -from distilabel.steps.tasks import TextGeneration - -with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset(name="load_dataset") - - combine_generations = CombineColumns( # (1) - name="combine_generations", - columns=["generation", "model_name"], - output_columns=["generations", "model_names"], - ) - - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.5-pro"), - ): - task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) - load_dataset.connect(task) - task.connect(combine_generations) # (2) -``` - -1. Create a `CombineColumns` step to combine all the `generation` columns into a single column called `generations` and the `model_name` columns into a single column called `model_names`. -2. Connect the `TextGeneration` task with the `CombineColumns` step, so the output data from the task is passed to the step that will combine all the `generation` columns. - -As the [`CombineColumns`][distilabel.steps.CombineColumns] is the last step or it's a leaf step of the pipeline because it doesn't have any successors, that means that the outputs of this step will be included in the returned [`Distiset`][distilabel.distiset.Distiset] (more information about it in [Advanced - Distiset](../../advanced/distiset.md)). - -!!! NOTE - One pipeline can have several leaf steps, which means that the outputs of all the leaf steps will be included in the returned `Distiset`, which will contain several subsets, one for each leaf step. - -### Connecting steps - -In the previous example we saw how to create a `Pipeline` and connect different steps using the `Step.connect` method: `step1.connect(step2)`, but there's an alternative way by making use of the `>>` operator, let's see how using the previous `Pipeline` as an example: - -=== "Step per step" - - Each call to `step1.connect(step2)` has been exchanged by `step1 >> step2`: - - ```python - from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM - from distilabel.pipeline import Pipeline - from distilabel.steps import CombineColumns, LoadHubDataset - from distilabel.steps.tasks import TextGeneration - - with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset(name="load_dataset") - - combine_generations = CombineColumns( - name="combine_generations", - columns=["generation", "model_name"], - output_columns=["generations", "model_names"], - ) - - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.5-pro"), - ): - task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) - load_dataset >> task >> combine_generations # (1) - ``` - - 1. Here `load_dataset >> task >> combine_generations` was exchanged with `load_dataset.connect(task).connect(combine_generations)`. 
- - -=== "Multiple steps at once" - - All the calls to connections from the `load_dataset` step to the different `task` objects are done in a single call: - - ```python - from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM - from distilabel.pipeline import Pipeline - from distilabel.steps import CombineColumns, LoadHubDataset - from distilabel.steps.tasks import TextGeneration - - with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset(name="load_dataset") - - combine_generations = CombineColumns( - name="combine_generations", - columns=["generation", "model_name"], - output_columns=["generations", "model_names"], - ) - - tasks = [] - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.5-pro"), - ): - tasks.append( - TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) - ) - - load_dataset >> tasks >> combine_generations # (1) - ``` - - 1. Notice how `tasks` is a list of different `Tasks`. In a single call to the operator we are connecting `load_dataset` with all the `tasks`, and all of those again to `combine_generations`. - - -### Routing batches to specific downstream steps - -In some pipelines, it's likely that you will need to have a list of downstream steps receiving batches from the same upstream step, but you would like to route the batches to specific downstream steps based on some condition. To do so, you can use a `routing_batch_function`, which is a simple function that receives a list of the downstream steps to which a batch can be routed, and returns a list containing the names of steps to which the batch should be routed. Let's update the example above to route the batches loaded by the `LoadHubDataset` step to just 2 of the `TextGeneration` tasks. First, we will create our custom [`routing_batch_function`][distilabel.pipeline.routing_batch_function.routing_batch_function], and then we will update the pipeline to use it: - -```python -import random -from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM -from distilabel.pipeline import Pipeline, routing_batch_function -from distilabel.steps import CombineColumns, LoadHubDataset -from distilabel.steps.tasks import TextGeneration - -@routing_batch_function -def sample_two_steps(steps: list[str]) -> list[str]: - return random.sample(steps, 2) - -with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset( - name="load_dataset", - output_mappings={"prompt": "instruction"}, - ) - - tasks = [] - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.0-pro"), - ): - tasks.append( - TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) - ) - - combine_generations = CombineColumns( - name="combine_generations", - columns=["generation", "model_name"], - output_columns=["generations", "model_names"], - ) - - load_dataset >> sample_two_steps >> tasks >> combine_generations -``` - -As it can be seen, the `routing_batch_function` can be used with the `>>` operator to route the batches to specific downstream steps. In this case, each batch yielded by the `load_dataset` step will be routed to just 2 of the `TextGeneration` tasks, and then all the outputs of the tasks will be combined in the `CombineColumns` step so each row of the final dataset will contain generations of 2 `LLM`s at most. 
The `routing_batch_function` that we just built is a common one, so `distilabel` comes with an auxiliary function that can be used to achieve the same behavior: - -```python -from distilable.pipeline import sample_n_steps - -sample_two_steps = sample_n_steps(2) -``` - -## Running the pipeline - -### Pipeline.dry_run - -Before running the `Pipeline` we may want to check all the components behave as expected. We can do a `dry_run` for this case: - -```python -with Pipeline("pipe-name", description="My first pipe") as pipeline: - ... - -if __name__ == "__main__": - distiset = pipeline.dry_run(parameters=..., batch_size=1) -``` - -It takes the same parameters as the `run` method we will see in the following section, plus the `batch_size` we want the dry run to use (by default set to 1). In this case, the `Pipeline` would select a single example from our generator steps and pass through all the steps. Assuming the `dry_run` runs successfully, we are ready to run our pipeline. - -### Pipeline.run - -Once we have created the pipeline, we can run it. To do so, we need to call the `run` method of the `Pipeline`, and specify the runtime parameters for each step: - -```python -with Pipeline("pipe-name", description="My first pipe") as pipeline: - ... - -if __name__ == "__main__": - distiset = pipeline.run( - parameters={ - "load_dataset": { - "repo_id": "distilabel-internal-testing/instruction-dataset-mini", - "split": "test", - }, - "text_generation_with_gpt-4-0125-preview": { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - "text_generation_with_mistral-large-2402": { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - "text_generation_with_gemini-1.0-pro": { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - }, - ) -``` - -But if we run it, we will see that the `run` method will fail: - -``` -ValueError: Step 'text_generation_with_gpt-4-0125-preview' requires inputs ['instruction'], but only the inputs=['prompt', 'completion', 'meta'] are available, which means that the inputs=['instruction'] are missing or not available -when the step gets to be executed in the pipeline. Please make sure previous steps to 'text_generation_with_gpt-4-0125-preview' are generating the required inputs. -``` - -This is because, before actually running the pipeline, the pipeline is validated to verify that everything is correct and all the steps in the pipeline are chainable i.e. each step has the necessary inputs to be executed. In this case, the `TextGeneration` task requires the `instruction` column, but the `LoadHubDataset` step generates the `prompt` column. To solve this, we can use the `output_mappings` argument that every `Step` has, to map or rename the output columns of a step to the required input columns of another step: - -```python -with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset( - name="load_dataset", - output_mappings={"prompt": "instruction"}, # (1) - ) - - ... -``` - -1. Use the `output_mappings` argument to map the `prompt` column generated by the `LoadHubDataset` step to the `instruction` column required by the `TextGeneration` task. - -If we execute the pipeline again, it will run successfully and we will have a `Distiset` with the outputs of all the leaf steps of the pipeline which we can push to the Hugging Face Hub. - -```python -if __name__ == "__main__": - distiset = pipeline.run(...) 
- distiset.push_to_hub("distilabel-internal-testing/instruction-dataset-mini-with-generations") -``` - -### Stopping the pipeline - -In case you want to stop the pipeline while it's running using the `Ctrl+c` (`Cmd+c` in macos), we automatically catch the signal and try to finish whatever steps are currently running. If it got hang by some reason, repeating the command 2 times it will force the pipeline close. - -!!! Note - - When pushing sending the signal to kill the process, you can expect to see the following log messages: - - ![Pipeline ctrl+c](../../../../assets/images/sections/pipeline/pipeline-ctrlc.png) - -## Cache - -If we try to execute the pipeline again, the pipeline won't execute as it will load the dataset from the cache, and the outputs of the pipeline will be the same as the previous run. If for some reason, we decide to stop the pipeline execution in the middle of the process pressing CTRL + C, the pipeline will stop and the state of the pipeline and the outputs will be stored in the cache, so we can resume the pipeline execution from the point where it was stopped. - -If we want to force the pipeline to run again, then we can use the `use_cache` argument of the `run` method and set it to `False`: - -```python -if __name__ == "__main__": - distiset = pipeline.run(parameters={...}, use_cache=False) -``` - -## Adjusting the batch size for each step - -It's very likely that in some pipelines the batch size of the steps (the number of dictionaries that will receive every `Step.process` method when called) will need to be adjusted in order to avoid memory issues or a more efficient processing. To do so, we can use the `input_batch_size` argument of the `run` method: - -```python -from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM -from distilabel.pipeline import Pipeline -from distilabel.steps import CombineColumns, LoadHubDataset -from distilabel.steps.tasks import TextGeneration - -with Pipeline("pipe-name", description="My first pipe") as pipeline: - ... - - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.5-pro"), - ): - task = TextGeneration( - name=f"text_generation_with_{llm.model_name}", - llm=llm, - input_batch_size=5, # (1) - ) - - ... -``` - -1. Use the `input_batch_size` argument to set the batch size of the `TextGeneration` task to 5. - -When we run the pipeline, the `TextGeneration` task will receive 5 dictionaries in every call to the `process` method. In addition, we can also adjust the batch size of the generated batches by the `GeneratorStep`s using the `batch_size` argument: - -```python -with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset( - name="load_dataset", - output_mappings={"prompt": "instruction"}, - batch_size=10 # (1) - ) - - ... -``` - -1. Use the `batch_size` argument to set the batch size of the `LoadHubDataset` step to 10. - -By default, both arguments have a value of `50`. - -## Serializing the pipeline - -Sharing a pipeline with others is very easy, as we can serialize the pipeline object using the `save` method. 
We can save the pipeline in different formats, such as `yaml` or `json`: - -```python -if __name__ == "__main__": - pipeline.save("pipeline.yaml", format="yaml") -``` - -To load the pipeline, we can use the `from_yaml` or `from_json` methods: - -```python -pipeline = Pipeline.from_yaml("pipeline.yaml") -``` - -Serializing the pipeline is very useful when we want to share the pipeline with others, or when we want to store the pipeline for future use. It can even be hosted online, so the pipeline can be executed directly using the [CLI](../cli/index.md) knowing the URL of the pipeline. - -## Fully working example - -To sump up, here is the full code of the pipeline we have created in this section. Note that you will need to change the name of the Hugging Face repository where the resulting will be pushed, set `OPENAI_API_KEY` environment variable, set `MISTRAL_API_KEY` and have `gcloud` installed and configured: - -??? Code - ```python - from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM - from distilabel.pipeline import Pipeline - from distilabel.steps import CombineColumns, LoadHubDataset - from distilabel.steps.tasks import TextGeneration - - with Pipeline("pipe-name", description="My first pipe") as pipeline: - load_dataset = LoadHubDataset( - name="load_dataset", - output_mappings={"prompt": "instruction"}, - ) - - combine_generations = CombineColumns( - name="combine_generations", - columns=["generation", "model_name"], - output_columns=["generations", "model_names"], - ) - - for llm in ( - OpenAILLM(model="gpt-4-0125-preview"), - MistralLLM(model="mistral-large-2402"), - VertexAILLM(model="gemini-1.0-pro"), - ): - task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm) - load_dataset.connect(task) - task.connect(combine_generations) - - if __name__ == "__main__": - distiset = pipeline.run( - parameters={ - "load_dataset": { - "repo_id": "distilabel-internal-testing/instruction-dataset-mini", - "split": "test", - }, - "text_generation_with_gpt-4-0125-preview": { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - "text_generation_with_mistral-large-2402": { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - "text_generation_with_gemini-1.0-pro": { - "llm": { - "generation_kwargs": { - "temperature": 0.7, - "max_new_tokens": 512, - } - } - }, - }, - ) - distiset.push_to_hub( - "distilabel-internal-testing/instruction-dataset-mini-with-generations" - ) - ``` - -[^1]: We also have the *cache_dir* argument to pass, for more information on this parameter, we refer the reader to the [caching](../../advanced/caching.md) section. diff --git a/docs/sections/learn/tutorial/step/generator_step.md b/docs/sections/learn/tutorial/step/generator_step.md deleted file mode 100644 index a5599174ab..0000000000 --- a/docs/sections/learn/tutorial/step/generator_step.md +++ /dev/null @@ -1,117 +0,0 @@ -# GeneratorStep - -The [`GeneratorStep`][distilabel.steps.GeneratorStep] is a subclass of [`Step`][distilabel.steps.Step] that only produces outputs, but doesn't receive any input. The [`GeneratorStep`][distilabel.steps.GeneratorStep] is intended to be used the first step within a [`Pipeline`][distilabel.pipeline.Pipeline], since it doesn't require any input to run and will generate data that can be potentially used by the follow up steps. 
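To make that concrete, here is a minimal sketch (with illustrative step names and model) in which a built-in generator step, [`LoadDataFromDicts`][distilabel.steps.LoadDataFromDicts], acts as the root of a [`Pipeline`][distilabel.pipeline.Pipeline] and feeds its batches to a follow-up [`Task`][distilabel.steps.tasks.Task]:

```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="generator-step-example") as pipeline:
    # `LoadDataFromDicts` is a `GeneratorStep`: it has no predecessors and
    # yields batches built from the provided list of dictionaries.
    load_dataset = LoadDataFromDicts(
        name="load_dataset",
        data=[
            {"instruction": "Tell me a joke."},
            {"instruction": "Tell me a story."},
        ],
    )

    # The generated batches are consumed by a follow-up step, in this case a task.
    text_generation = TextGeneration(
        name="text_generation",
        llm=OpenAILLM(model="gpt-4"),
    )

    load_dataset >> text_generation

if __name__ == "__main__":
    distiset = pipeline.run()
```

Being the root step, `load_dataset` only produces data, and the connected `text_generation` step consumes the batches it generates.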
- -## Working with GeneratorSteps - -The [`GeneratorStep`][distilabel.steps.GeneratorStep] is intended to be used within the scope of a [`Pipeline`][distilabel.pipeline.Pipeline] before any other [`Step`][distilabel.steps.Step]. Alternatively, in can also be used as a standalone [`Step`][distilabel.steps.Step] i.e. not within the context of a [`Pipeline`][distilabel.pipeline.Pipeline]. - -For example, the following code snippet shows how to use the [`GeneratorStep`][distilabel.steps.GeneratorStep] as a standalone [`Step`][distilabel.steps.Step], to generate data out of a provided list of strings. - -```python -from typing import List -from typing_extensions import override - -from distilabel.steps import GeneratorStep -from distilabel.steps.typing import GeneratorStepOutput - -class MyGeneratorStep(GeneratorStep): - instructions: List[str] - - @override - def process(self, offset: int = 0) -> GeneratorStepOutput: - if offset: - self.instructions = self.instructions[offset:] - - while self.instructions: - batch = [ - { - "instruction": instruction - } for instruction in self.instructions[: self.batch_size] - ] - self.instructions = self.instructions[self.batch_size :] - yield ( - batch, - True if len(self.instructions) == 0 else False, - ) - - @property - def outputs(self) -> List[str]: - return ["instruction"] -``` - -Then we can use / instantiate it as follows: - -```python -step = MyGeneratorStep( - name="my-generator-step", - instructions=["Tell me a joke.", "Tell me a story."], - batch_size=1, -) -step.load() - -next(step.process(offset=0)) -# ([{'instruction': 'Tell me a joke.'}], False) -next(step.process(offset=1)) -# ([{'instruction': 'Tell me a story.'}], True) -``` - -!!! NOTE - The `load` method needs to be called ALWAYS if using the steps and any [`Step`][distilabel.steps.Step] subclass as standalone, unless the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used, meaning that there will be no need to call the `load` method, since it will be automatically called on `Pipeline.run`; but in any other case the method `load` needs to be called from the parent class. - -Anyway, most of the times we'll end up using pre-defined steps in `distilabel`, so that there's no need to create custom steps, but anyway, we'll cover that later in this page. - -## Defining custom GeneratorSteps - -In order to define a custom [`GeneratorStep`][distilabel.steps.GeneratorStep], we need to subclass it, and set the `outputs` property, and define the `process` method. In this case, the `process` method signature differs from the `process` method signature of the [`Step`][distilabel.steps.Step], since it won't receive any `inputs` but generate those, so the only argument of `process` is `offset` which is automatically handled by the [`Pipeline`][distilabel.pipeline.Pipeline] shifting it until all the batches are generated. - -So on, the following will need to be defined: - -- `outputs`: is a property that returns a list of strings with the names of the output fields. - -- `process`: is a method that yields output data and a boolean flag indicating whether that's the last batch to be generated. 
It's important to override the default signature of the [`Step.process`][distilabel.steps.Step] method `def process(self, *inputs: StepInput) -> StepOutput`, to be set to `def process(self, offset: int = 0) -> GeneratorStepOutput` instead, since that's the one that will be used by the [`Pipeline`][distilabel.pipeline.Pipeline] to orchestrate the steps, meaning that the argument `offset` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too. - -!!! NOTE - The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`, but since in this case we're defining a [`GeneratorStep`][distilabel.steps.GeneratorStep], we will need to override that (ideally under the `typing_extensions.override` decorator) with `process(self, offset: int = 0) -> GeneratorStepOutput`, so that the `process` method only receives the `offset` argument, and the return type-hints should be respected too. The `offset` argument is automatically handled by the [`Pipeline`][distilabel.pipeline.Pipeline] shifting it until all the batches are generated, and there's no need to default it to 0, since it will be set to 0 by default anyway. - -!!! WARNING - For the custom [`GeneratorStep`][distilabel.steps.GeneratorStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for [`GeneratorStepOutput`][distilabel.steps.typing.GeneratorStepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. - -```python -from typing import List -from typing_extensions import override - -from distilabel.steps import GeneratorStep -from distilabel.steps.typing import GeneratorStepOutput - -class MyGeneratorStep(GeneratorStep): - instructions: List[str] - - @override - def process(self, offset: int = 0) -> GeneratorStepOutput: - ... - - @property - def outputs(self) -> List[str]: - ... -``` - -Alternatively, a simpler and more suitable way of defining custom [`GeneratorStep`][distilabel.steps.GeneratorStep] subclasses is via the `@step` decorator with the `step_type="generator"`, which will take care of the boilerplate code, and will allow to define the `outputs`, and `process` methods in a more straightforward way. - -```python -from distilabel.steps import step -from distilabel.steps.typing import GeneratorStepOutput - -@step(outputs=[...], step_type="generator") -def CustomGeneratorStep(offset: int = 0) -> GeneratorStepOutput: - yield ( - ..., - True if offset == 10 else False, - ) - -step = CustomGeneratorStep(name="my-step") -``` - -!!! WARNING - One downside of the `@step` decorator is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`GeneratorStep`][distilabel.steps.GeneratorStep] subclass. 
- diff --git a/docs/sections/learn/tutorial/step/global_step.md b/docs/sections/learn/tutorial/step/global_step.md deleted file mode 100644 index 1b08884f44..0000000000 --- a/docs/sections/learn/tutorial/step/global_step.md +++ /dev/null @@ -1,70 +0,0 @@ -# GlobalStep - -The [`GlobalStep`][distilabel.steps.GlobalStep] is a subclass of [`Step`][distilabel.steps.Step] that is used to define a step that requires the previous steps to be completed to run, since it will wait until all the input batches are received before running. This step is useful when you need to run a step that requires all the input data to be processed before running. - -## Working with GlobalSteps - -The [`GlobalStep`][distilabel.steps.GlobalStep] is intended to be used within the scope of a [`Pipeline`][distilabel.pipeline.Pipeline] and after some previous steps have been defined. Alternatively, it can also be used as a standalone [`Step`][distilabel.steps.Step] if needed, but then using [`Step`][distilabel.steps.Step] instead would be more appropriate. - -## Defining custom GlobalSteps - -In order to define custom steps, we need to create a new subclass of the [`GlobalStep`][distilabel.steps.GlobalStep] class, and set both the `inputs` and `outputs` property, as well as the `process` method. - -So on, the following will need to be defined: - -- `inputs`: is a property that returns a list of strings with the names of the required input fields. - -- `outputs`: is a property that returns a list of strings with the names of the output fields. - -- `process`: is a method that receives the input data and returns the output data, and it should be a generator, meaning that it should `yield` the output data. It's important to preserve the default signature within the method `def process(self, *inputs: StepInput) -> StepOutput`, since that's the one that will be used by the [`Pipeline`][distilabel.pipeline.Pipeline] to orchestrate the steps, meaning that the argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too. - -!!! NOTE - The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`, meaning that it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. Anyway, when defining custom steps, that can be overridden with `process(self, inputs: StepInput) -> StepOutput`, so that the `process` method only receives the outputs from one previous [`Step`][distilabel.steps.Step] connected to it. - -!!! WARNING - For the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. - -```python -from distilabel.steps import GlobalStep, StepInput -from distilabel.steps.typing import StepOutput - -class CustomStep(Step): - @property - def inputs(self) -> List[str]: - ... - - @property - def outputs(self) -> List[str]: - ... - - def process(self, *inputs: StepInput) -> StepOutput: - for input in inputs: - for item in input: - ... 
- yield item - - # When overridden (ideally under the `typing_extensions.override` decorator) - # @typing_extensions.override - # def process(self, inputs: StepInput) -> StepOutput: - # for input in inputs: - # ... - # yield inputs -``` - -Alternatively, a simpler and more suitable way of defining custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses is via the `@step` decorator with the `step_type="global"`, which will take care of the boilerplate code, and will allow to define the `inputs`, `outputs`, and `process` methods in a more straightforward way. - -```python -from distilabel.steps import StepInput, step -from distilabel.steps.typing import StepOutput - -@step(inputs=[...], outputs=[...], step_type="global") -def CustomStep(inputs: StepInput) -> StepOutput: - for input in inputs: - ... - yield inputs - -step = CustomStep(name="my-step") -``` - -!!! WARNING - One downside of the `@step` decorator is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclass. diff --git a/docs/sections/learn/tutorial/step/index.md b/docs/sections/learn/tutorial/step/index.md deleted file mode 100644 index df9d1bbc74..0000000000 --- a/docs/sections/learn/tutorial/step/index.md +++ /dev/null @@ -1,142 +0,0 @@ -# Step - -The [`Step`][distilabel.steps.Step] is an abstract class which defines the interface for the building blocks to be defined within the context of a [`Pipeline`][distilabel.pipeline.Pipeline], a [`Step`][distilabel.steps.Step] can be seen as a node within a Direct Acyclic Graph (DAG) which execution is orchestrated by the [`Pipeline`][distilabel.pipeline.Pipeline]. - -## Working with Steps - -The [`Step`][distilabel.steps.Step] is intended to be used within the scope of a [`Pipeline`][distilabel.pipeline.Pipeline], which will orchestrate the different steps defined; but nonetheless, they can be used standalone if needed too. - -Assuming that we have a [`Step`][distilabel.steps.Step] already defined as it follows: - -```python -class MyStep(Step): - @property - def inputs(self) -> List[str]: - return ["input_field"] - - @property - def outputs(self) -> List[str]: - return ["output_field"] - - def process(self, inputs: StepInput) -> "StepOutput": - for input in inputs: - input["output_field"] = input["input_field"] - yield inputs -``` - -Then we can use / instantiate it as follows: - -```python -step = MyStep(name="my-step") -step.load() - -next(step.process([{"input_field": "value"}])) -# [{'input_field': 'value', 'output_field': 'value'}] -``` -!!! NOTE - The `load` method needs to be called ALWAYS if using the steps and any [`Step`][distilabel.steps.Step] subclass as standalone, unless the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used, meaning that there will be no need to call the `load` method, since it will be automatically called on `Pipeline.run`; but in any other case the method `load` needs to be called from the parent class. - -Anyway, most of the times we'll end up using pre-defined steps in `distilabel`, so that there's no need to create custom steps, but anyway, we'll cover that later in this page. - -Let's see now a set of arguments that can be used to map fields across steps, or to set the batch size specific for the step: - -- `input_mappings`, which is a dictionary that can be useful to map keys from the input dictionaries to the keys expected by the step. 
For example, if `input_mappings={"instruction": "prompt"}`, that means that the key prompt from the input dictionaries will be used as the key instruction for the step. - -- `output_mappings`, which is a dictionary that can be used to map the outputs of the step to other names. For example, if `output_mappings={"conversation": "prompt"}`, that means that the key conversation generated by the step will be renamed to prompt and the output dictionaries of this step will contain a key called prompt instead of conversation. - -- `input_batch_size` (by default set to 50), which is independent for every step and will determine how many input dictionaries will process at once. If won't matter that much in this step, but as we will see later, other types of steps will come with an LLM, so having this flexibility will be really useful. - -### Runtime parameters - -Finally, let's introduce at a special type of argument that we will find when dealing with the `Steps`, the `Runtime parameters`. For example, the `input_batch_size` is of type `RuntimeParameter`: - -```python -from distilabel.mixins.runtime_parameters import RuntimeParameter - -class Step(...): - input_batch_size: RuntimeParameter[PositiveInt] = Field( - default=DEFAULT_INPUT_BATCH_SIZE, - description="The number of rows that will contain the batches processed by the" - " step.", - ) -``` - -We can interact with these types of arguments when we call the `Pipeline.run` method as we will see in the `Pipeline` section. These types of arguments can be really useful to insert info to the steps after the pipeline has been defined. - -## Types of Steps - -Besides the default [`Step`][distilabel.steps.Step] already described, in `distilabel` we find the following abstract subclasses on top of the [`Step`][distilabel.steps.Step]. - -* [`GeneratorStep`][distilabel.steps.GeneratorStep]: is a step that only produces / generates data, and it doesn't need any input data from previous steps, is in most of the cases a parent node of the graph i.e. the first [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline]. - - More information about it at [Components -> Step - GeneratorStep](./generator_step.md). - -* [`GlobalStep`][distilabel.steps.GlobalStep]: is a step with the standard interface i.e. receives inputs and generates outputs, but it processes all the data at once, is in most of the cases a leaf node of the graph i.e. the last [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline]. The fact that a [`GlobalStep`][distilabel.steps.GlobalStep] requires the outputs from the previous steps, means that the previous steps needs to finish for this step to start, and the connected outputs steps, if any, will need to wait until this step is done. - - More information about it at [Components - Step - GlobalStep](global_step.md). - -Additionally, `distilabel` also defines another type of [`Step`][distilabel.steps.Step], which is the [`Task`][distilabel.steps.tasks.Task], which is essentially the same, besides the fact that the task will expect an [`LLM`][distilabel.llms.LLM] as an attribute, and the `process` method will be in charge of calling that LLM. So one could say that the [`Task`][distilabel.steps.tasks.Task] is a [`Step`][distilabel.steps.Step] to work with an [`LLM`][distilabel.llms.LLM]. - -More information about it at [Components - Task](../task/index.md). 
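Putting the arguments above together, here is a minimal, illustrative sketch of how `input_mappings`, `output_mappings` and `input_batch_size` might be set when instantiating a step (the `prompt` and `response` column names are assumptions made for the sake of the example):

```python
from distilabel.llms import OpenAILLM
from distilabel.steps.tasks import TextGeneration

# `TextGeneration` expects an `instruction` column; assuming the incoming data
# has a `prompt` column instead, `input_mappings` bridges the two, while
# `output_mappings` renames the produced `generation` column to `response`.
text_generation = TextGeneration(
    name="text_generation",
    llm=OpenAILLM(model="gpt-4"),
    input_mappings={"instruction": "prompt"},
    output_mappings={"generation": "response"},
    input_batch_size=5,  # process 5 input dictionaries per call to `process`
)
```

And since `input_batch_size` is a `RuntimeParameter`, its value can also be provided later on through the `parameters` argument of `Pipeline.run`, as covered in the `Pipeline` section.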
- -## Defining custom Steps - -In order to define custom steps, we need to create a new subclass of the [`Step`][distilabel.steps.Step] class, and set both the `inputs` and `outputs` property, as well as the `process` method. - -So on, the following will need to be defined: - -- `inputs`: is a property that returns a list of strings with the names of the required input fields. - -- `outputs`: is a property that returns a list of strings with the names of the output fields. - -- `process`: is a method that receives the input data and returns the output data, and it should be a generator, meaning that it should `yield` the output data. It's important to preserve the default signature within the method `def process(self, *inputs: StepInput) -> StepOutput`, since that's the one that will be used by the [`Pipeline`][distilabel.pipeline.Pipeline] to orchestrate the steps, meaning that the argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too. - -!!! NOTE - The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`, meaning that it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. Anyway, when defining custom steps, that can be overridden with `process(self, inputs: StepInput) -> StepOutput`, so that the `process` method only receives the outputs from one previous [`Step`][distilabel.steps.Step] connected to it. - -!!! WARNING - For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. - -```python -from distilabel.steps import Step, StepInput -from distilabel.steps.typing import StepOutput - -class CustomStep(Step): - @property - def inputs(self) -> List[str]: - ... - - @property - def outputs(self) -> List[str]: - ... - - def process(self, *inputs: StepInput) -> StepOutput: - for input in inputs: - ... - yield item - - # When overridden (ideally under the `typing_extensions.override` decorator) - # @typing_extensions.override - # def process(self, inputs: StepInput) -> StepOutput: - # for input in inputs: - # ... - # yield inputs -``` - -Alternatively, a simpler and more suitable way of defining custom [`Step`][distilabel.steps.Step] subclasses is via the `@step` decorator, which will take care of the boilerplate code, and will allow to define the `inputs`, `outputs`, and `process` methods in a more straightforward way. - -```python -from distilabel.steps import StepInput, step -from distilabel.steps.typing import StepOutput - -@step(inputs=[...], outputs=[...]) -def CustomStep(inputs: StepInput) - StepOutput: - for input in inputs: - ... - yield inputs - -step = CustomStep(name="my-step") -``` - -!!! WARNING - One downside of the `@step` decorator is that it won't let you access the `self` attributes if any, neither set those, so if you need to access or set any attribute, you should go with the first approach of defining the custom [`Step`][distilabel.steps.Step] subclass. 
diff --git a/docs/sections/learn/tutorial/task/index.md b/docs/sections/learn/tutorial/task/index.md deleted file mode 100644 index 6f6f1259bd..0000000000 --- a/docs/sections/learn/tutorial/task/index.md +++ /dev/null @@ -1,71 +0,0 @@ -# Task - -The [`Task`][distilabel.steps.tasks.Task] is an implementation on top of [`Step`][distilabel.steps.Step] that includes the [`LLM`][distilabel.llms.LLM] as a mandatory argument, so that the [`Task`][distilabel.steps.tasks.Task] defines both the input and output format via the `format_input` and `format_output` abstract methods, respectively; and calls the [`LLM`][distilabel.llms.LLM] to generate the text. We can see the [`Task`][distilabel.steps.tasks.Task] as an [`LLM`][distilabel.llms.LLM] powered [`Step`][distilabel.steps.Step]. - -## Working with Tasks - -The subclasses of [`Task`][distilabel.steps.tasks.Task] are intended to be used within the scope of a [`Pipeline`][distilabel.pipeline.Pipeline], which will orchestrate the different tasks defined; but nonetheless, they can be used standalone if needed too. - -For example, the most basic task is the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task, which generates text based on a given instruction, and it can be used standalone as well as within a [`Pipeline`][distilabel.pipeline.Pipeline]. - -```python -from distilabel.steps.tasks import TextGeneration - -task = TextGeneration( - name="text-generation", - llm=OpenAILLM(model="gpt-4"), -) -task.load() - -next(task.process([{"instruction": "What's the capital of Spain?"}])) -# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid.", "model_name": "gpt-4"}] -``` - -!!! NOTE - The `load` method needs to be called ALWAYS if using the tasks as standalone, otherwise, if the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used, there's no need to call that method, since it will be automatically called on `Pipeline.run`; but in any other case the method `load` needs to be called from the parent class e.g. a [`Task`][distilabel.steps.tasks.Task] with an [`LLM`][distilabel.llms.LLM] will need to call `Task.load` to load both the task and the LLM. - -## Defining custom Tasks - -In order to define custom tasks, we need to inherit from the [`Task`][distilabel.steps.tasks.Task] class and implement the `format_input` and `format_output` methods, as well as setting the properties `inputs` and `outputs`, as for [`Step`][distilabel.steps.Step] subclasses. - -So on, the following will need to be defined: - -- `inputs`: is a property that returns a list of strings with the names of the required input fields. - -- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.steps.tasks.ChatType], which is basically a list of dictionaries with the input data formatted for the [`LLM`][distilabel.llms.LLM] following [the chat-completion OpenAI formatting](https://platform.openai.com/docs/guides/text-generation). It's important to note that the [`ChatType`][distilabel.steps.tasks.ChatType] is a list of dictionaries, where each dictionary represents a turn in the conversation, and it must contain the keys `role` and `content`, and this is done like this since the [`LLM`][distilabel.llms.LLM] subclasses will format that according to the LLM used, since it's the most standard formatting. - -- `outputs`: is a property that returns a list of strings with the names of the output fields. 
Note that since all the [`Task`][distilabel.steps.tasks.Task] subclasses are designed to work with a single [`LLM`][distilabel.llms.LLM], this property should always include `model_name` as one of the outputs, since that's automatically injected from the LLM. - -- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output, since that's automatically injected from the LLM in the `process` method of the [`Task`][distilabel.steps.tasks.Task]. - -Once those methods have been implemented, the task can be used as any other task, and it will be able to generate text based on the input data. - -```python -from typing import Any, Dict, List, Union - -from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType - - -class MyCustomTask(Task): - @property - def inputs(self) -> List[str]: - return ["input_field"] - - def format_input(self, input: Dict[str, Any]) -> ChatType: - return [ - { - "role": "user", - "content": input["input_field"], - }, - ] - - @property - def outputs(self) -> List[str]: - return ["output_field", "model_name"] - - def format_output( - self, output: Union[str, None], input: Dict[str, Any] - ) -> Dict[str, Any]: - return {"output_field": output} -``` diff --git a/docs/sections/pipeline_samples/examples/index.md b/docs/sections/pipeline_samples/examples/index.md index c6117e3ddb..ffcadb3199 100644 --- a/docs/sections/pipeline_samples/examples/index.md +++ b/docs/sections/pipeline_samples/examples/index.md @@ -2,7 +2,7 @@ This section contains different example pipelines that showcase different tasks, maybe you can take inspiration from them. -### llama.cpp with outlines +### [llama.cpp with `outlines`](#llama-cpp-with-outlines) Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`. @@ -10,7 +10,7 @@ Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `dis This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema. - It makes use of a local model which can be downlaoded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM]. + It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM]. ??? Run @@ -21,3 +21,58 @@ Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `dis ```python title="structured_generation_with_outlines.py" --8<-- "examples/structured_generation_with_outlines.py" ``` + + +### [MistralAI with `instructor`](#mistralai-with-instructor) + +Answer instructions with knowledge graphs defined as `pydantic.BaseModel` objects using `instructor` in `distilabel`. + +??? Example "See example" + + This script makes use of [`MistralLLM`][distilabel.llms.mistral.MistralLLM] and the structured output capabilities thanks to [`instructor`](https://python.useinstructor.com/) to generate knowledge graphs from complex topics. 
+ + This example is translated from this [awesome example](https://python.useinstructor.com/examples/knowledge_graph/) from `instructor` cookbook. + + ??? Run + + ```python + python examples/structured_generation_with_instructor.py + ``` + + ```python title="structured_generation_with_instructor.py" + --8<-- "examples/structured_generation_with_instructor.py" + ``` + + ??? "Visualizing the graphs" + + Want to see how to visualize the graphs? You can test it using the following script. Generate some samples on your own and take a look: + + !!! NOTE + + This example uses graphviz to render the graph, you can install with `pip` in the following way: + + ```console + pip install graphviz + ``` + + ```python + python examples/draw_kg.py 2 # You can pass 0,1,2 to visualize each of the samples. + ``` + + ![Knowledge graph figure](../../../assets/images/sections/examples/knowledge-graph-example.png) + + +### [Benchmarking with `distilabel`: Arena Hard](#benchmarking-with-distilabel-arena-hard) + +Benchmark LLMs with `distilabel`: reproducing the Arena Hard benchmark. + +??? Example "See example" + + The script below first defines both the `ArenaHard` and the `ArenaHardResults` tasks, so as to generate responses for a given collection of prompts/questions with up to two LLMs, and then calculate the results as per the original implementation, respectively. Additionally, the second part of the example builds a `Pipeline` to run the generation on top of the prompts with `InferenceEndpointsLLM` while streaming the rest of the generations from a pre-computed set of GPT-4 generations, and then evaluate one against the other with `OpenAILLM` generating an alternate response, a comparison between the responses, and a result as A>>B, A>B, B>A, B>>A, or tie. + + To run this example you will first need to install the Arena Hard optional dependencies, being `pandas`, `scikit-learn`, and `numpy`. + + ```python title="arena_hard.py" + --8<-- "examples/arena_hard.py" + ``` + diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md deleted file mode 100644 index e2fa065098..0000000000 --- a/docs/sections/pipeline_samples/index.md +++ /dev/null @@ -1,3 +0,0 @@ -# Pipeline Samples - -Take a look at this section to see some [`Examples`](./examples/index.md) of pipelines ready to run or go visit the [`Papers`](./papers/index.md) section for more structured implementations of pipelines seen in the literature. \ No newline at end of file diff --git a/docs/sections/pipeline_samples/papers/deita.md b/docs/sections/pipeline_samples/papers/deita.md index f4d3014464..5c9036d756 100644 --- a/docs/sections/pipeline_samples/papers/deita.md +++ b/docs/sections/pipeline_samples/papers/deita.md @@ -38,7 +38,7 @@ Import distilabel: ```python from distilabel.llms import TransformersLLM, OpenAILLM from distilabel.pipeline import Pipeline -from distilabel.steps import ConversationTemplate, DeitaFiltering, ExpandColumns, LoadHubDataset +from distilabel.steps import ConversationTemplate, DeitaFiltering, ExpandColumns, LoadDataFromHub from distilabel.steps.tasks import ComplexityScorer, EvolInstruct, EvolQuality, GenerateEmbeddings, QualityScorer ``` @@ -47,7 +47,7 @@ Define the distilabel Pipeline and load the dataset from the Hugging Face Hub. 
```python
pipeline = Pipeline(name="DEITA")

-load_data = LoadHubDataset(
+load_data = LoadDataFromHub(
     name="load_data", batch_size=100, output_mappings={"prompt": "instruction"}, pipeline=pipeline
 )
```
diff --git a/docs/sections/pipeline_samples/papers/instruction_backtranslation.md b/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
index c5c1968712..cd05c9aafe 100644
--- a/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
+++ b/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
@@ -30,7 +30,7 @@ And since we will be using [`InferenceEndpointsLLM`][distilabel.llms.InferenceEn

 #### Building blocks

-- [`LoadHubDataset`][distilabel.steps.LoadHubDataset]: Generator Step to load a dataset from the Hugging Face Hub.
+- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: Generator Step to load a dataset from the Hugging Face Hub.
 - [`TextGeneration`][distilabel.steps.tasks.TextGeneration]: Task to generate responses for a given instruction using an LLM.
 - [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM]: LLM that runs a model from an Inference Endpoint in the Hugging Face Hub.
 - [`InstructionBacktranslation`][distilabel.steps.tasks.InstructionBacktranslation]: Task that generates a score and a reason for a response for a given instruction using the Self Alignment with Instruction Backtranslation prompt.
@@ -43,12 +43,12 @@ As mentioned before, we will put the previously mentioned building blocks togeth
 ```python
 from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
 from distilabel.pipeline import Pipeline
-from distilabel.steps import LoadHubDataset
+from distilabel.steps import LoadDataFromHub
 from distilabel.steps.tasks import InstructionBacktranslation, TextGeneration


 with Pipeline(name="self-alignment-with-instruction-backtranslation") as pipeline:
-    load_hub_dataset = LoadHubDataset(
+    load_hub_dataset = LoadDataFromHub(
         name="load_dataset",
         output_mappings={"prompt": "instruction"},
     )
diff --git a/docs/sections/pipeline_samples/papers/prometheus.md b/docs/sections/pipeline_samples/papers/prometheus.md
new file mode 100644
index 0000000000..7f7b1d19d5
--- /dev/null
+++ b/docs/sections/pipeline_samples/papers/prometheus.md
@@ -0,0 +1,121 @@
+# Prometheus 2
+
+["Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models"](https://arxiv.org/pdf/2405.01535) presents Prometheus 2, a new and more powerful evaluator LLM compared to Prometheus (its predecessor), presented in ["Prometheus: Inducing Fine-grained Evaluation Capability in Language Models"](https://arxiv.org/abs/2310.08491). GPT-4 and other proprietary LLMs are commonly used to assess the quality of the responses generated by various LLMs, but concerns about transparency, controllability, and affordability motivate the need for open-source LLMs specialized in evaluation.
+
+Existing open evaluator LMs exhibit critical shortcomings:
+
+1. They issue scores that significantly diverge from those assigned by humans.
+2. They lack the flexibility to perform both direct assessment and pairwise ranking, the two most prevalent forms of assessment.
+
+Additionally, they do not possess the ability to evaluate based on custom evaluation criteria, focusing instead on general attributes like helpfulness and harmlessness. Prometheus 2 is capable of processing both the direct assessment and the pairwise ranking formats, combined with user-defined evaluation criteria.
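+
+As a rough sketch of what these two formats look like in practice (assuming the `PrometheusEval` task and the `vLLM` integration introduced later in this page; the `chat_template` tweak shown below is omitted here for brevity), the `mode` parameter switches between direct assessment (`"absolute"`) and pairwise ranking (`"relative"`):
+
+```python
+from distilabel.llms import vLLM
+from distilabel.steps.tasks import PrometheusEval
+
+# Direct assessment: grade a single response against a rubric on a 1-5 scale.
+direct_assessment = PrometheusEval(
+    llm=vLLM(model="prometheus-eval/prometheus-7b-v2.0"),
+    mode="absolute",
+    rubric="factual-validity",
+)
+
+# Pairwise ranking: decide which of two candidate responses is better (A or B).
+pairwise_ranking = PrometheusEval(
+    llm=vLLM(model="prometheus-eval/prometheus-7b-v2.0"),
+    mode="relative",
+    rubric="factual-validity",
+)
+```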
+
+Prometheus 2 was released in two variants:
+
+- [`prometheus-eval/prometheus-7b-v2.0`](https://hf.co/prometheus-eval/prometheus-7b-v2.0): fine-tuned on top of [`mistralai/Mistral-7B-Instruct-v0.2`](https://hf.co/mistralai/Mistral-7B-Instruct-v0.2)
+- [`prometheus-eval/prometheus-8x7b-v2.0`](https://hf.co/prometheus-eval/prometheus-8x7b-v2.0): fine-tuned on top of [`mistralai/Mixtral-8x7B-Instruct-v0.1`](https://hf.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+
+Both models have been fine-tuned for both direct assessment and pairwise ranking, i.e. assessing the quality of a single isolated response for a given instruction (with or without a reference answer) and assessing the quality of one response against another for a given instruction (with or without a reference answer), respectively.
+
+On four direct assessment benchmarks and four pairwise ranking benchmarks, Prometheus 2 scores the highest correlation and agreement with humans and proprietary LM judges among all tested open evaluator LMs. Their models, code, and data are all publicly available at [`prometheus-eval/prometheus-eval`](https://github.com/prometheus-eval/prometheus-eval).
+
+### Replication
+
+!!! NOTE
+    The section is named `Replication`, but in this case we're not replicating the Prometheus 2 paper per se; rather, we show how to use the [`PrometheusEval`][distilabel.steps.tasks.PrometheusEval] task implemented within `distilabel` to evaluate the quality of the responses to a given instruction using the Prometheus 2 model.
+
+To showcase Prometheus 2 we will be using the [`PrometheusEval`][distilabel.steps.tasks.PrometheusEval] task implemented in `distilabel` and a small dataset created by the Hugging Face H4 team, [`HuggingFaceH4/instruction-dataset`](https://hf.co/datasets/HuggingFaceH4/instruction-dataset), for testing purposes.
+
+#### Installation
+
+To reproduce the code below, one will need to install `distilabel` as follows:
+
+```bash
+pip install "distilabel[vllm]>=1.1.0"
+```
+
+Additionally, it's recommended to install [`Dao-AILab/flash-attention`](https://github.com/Dao-AILab/flash-attention) to benefit from Flash Attention 2 speed-ups during inference via `vllm`.
+
+```bash
+pip install flash-attn --no-build-isolation
+```
+
+!!! NOTE
+    The installation notes above assume that you are using a VM with one GPU accelerator with at least the VRAM required to fit [`prometheus-eval/prometheus-7b-v2.0`](https://hf.co/prometheus-eval/prometheus-7b-v2.0) in bfloat16 (28GB); if you have enough VRAM to fit their 8x7B model in bfloat16 (~90GB), you can use [`prometheus-eval/prometheus-8x7b-v2.0`](https://hf.co/prometheus-eval/prometheus-8x7b-v2.0) instead.
+
+#### Building blocks
+
+- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: [`GeneratorStep`][distilabel.steps.GeneratorStep] to load a dataset from the Hugging Face Hub.
+
+- [`PrometheusEval`][distilabel.steps.tasks.PrometheusEval]: [`Task`][distilabel.steps.tasks.Task] that assesses the quality of a response for a given instruction using any of the Prometheus 2 models.
+    - [`vLLM`][distilabel.llms.vLLM]: [`LLM`][distilabel.llms.LLM] that loads a model from the Hugging Face Hub via [vllm-project/vllm](https://github.com/vllm-project/vllm).
+
+    !!!
NOTE + Since the Prometheus 2 models use a slightly different chat template than [`mistralai/Mistral-7B-Instruct-v0.2`](https://hf.co/mistralai/Mistral-7B-Instruct-v0.2), we need to set the `chat_template` parameter to `[INST] {{ messages[0]['content'] }}\n{{ messages[1]['content'] }}[/INST]` so as to properly format the input for Prometheus 2. + +- (Optional) [`KeepColumns`][distilabel.steps.KeepColumns]: [`Task`][distilabel.steps.tasks.Task] that keeps only the specified columns in the dataset, used to remove the undesired columns. + +#### Code + +As mentioned before, we will put the previously mentioned building blocks together to see how Prometheus 2 can be used via `distilabel`. + +```python +from distilabel.llms import vLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import KeepColumns, LoadDataFromHub +from distilabel.steps.tasks import PrometheusEval + +if __name__ == "__main__": + with Pipeline(name="prometheus") as pipeline: + load_dataset = LoadDataFromHub( + name="load_dataset", + repo_id="HuggingFaceH4/instruction-dataset", + split="test", + output_mappings={"prompt": "instruction", "completion": "generation"}, + ) + + task = PrometheusEval( + name="task", + llm=vLLM( + model="prometheus-eval/prometheus-7b-v2.0", + chat_template="[INST] {{ messages[0]['content'] }}\n{{ messages[1]['content'] }}[/INST]", + ), + mode="absolute", + rubric="factual-validity", + reference=False, + num_generations=1, + group_generations=False, + ) + + keep_columns = KeepColumns( + name="keep_columns", + columns=["instruction", "generation", "feedback", "result", "model_name"], + ) + + load_dataset >> task >> keep_columns +``` + +Then we need to call `pipeline.run` with the runtime parameters so that the pipeline can be launched. + +```python +distiset = pipeline.run( + parameters={ + task.name: { + "llm": { + "generation_kwargs": { + "max_new_tokens": 1024, + "temperature": 0.7, + }, + }, + }, + }, +) +``` + +Finally, we can optionally push the generated dataset, named [`Distiset`][distilabel.distiset.Distiset], to the Hugging Face Hub via the `push_to_hub` method, so that each subset generated in the leaf steps is pushed to the Hub. + +```python +distiset.push_to_hub( + "instruction-dataset-prometheus", + private=True, +) +``` diff --git a/docs/sections/pipeline_samples/papers/ultrafeedback.md b/docs/sections/pipeline_samples/papers/ultrafeedback.md index b63702ef67..df36e0345b 100644 --- a/docs/sections/pipeline_samples/papers/ultrafeedback.md +++ b/docs/sections/pipeline_samples/papers/ultrafeedback.md @@ -24,7 +24,7 @@ And since we will be using `vllm` we will need to use a VM with at least 6 NVIDI #### Building blocks -- [`LoadHubDataset`][distilabel.steps.LoadHubDataset]: Generator Step to load a dataset from the Hugging Face Hub. +- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: Generator Step to load a dataset from the Hugging Face Hub. - [`sample_n_steps`][distilabel.pipeline.sample_n_steps]: Function to create a `routing_batch_function` that samples `n` downstream steps for each batch generated by the upstream step. This is the key to replicate the LLM pooling mechanism described in the paper. - [`TextGeneration`][distilabel.steps.tasks.TextGeneration]: Task to generate responses for a given instruction using an LLM. - [`vLLM`][distilabel.llms.vLLM]: LLM that loads a model from the Hugging Face Hub using `vllm`. 
@@ -44,7 +44,7 @@ from distilabel.pipeline import Pipeline, sample_n_steps from distilabel.steps import ( CombineColumns, KeepColumns, - LoadHubDataset, + LoadDataFromHub, PreferenceToArgilla, ) from distilabel.steps.tasks import TextGeneration, UltraFeedback @@ -53,7 +53,7 @@ sample_three_llms = sample_n_steps(n=3) with Pipeline(name="ultrafeedback-pipeline") as pipeline: - load_hub_dataset = LoadHubDataset( + load_hub_dataset = LoadDataFromHub( name="load_dataset", output_mappings={"prompt": "instruction"}, batch_size=2, diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index e32447d5d9..538a4b4776 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -1,15 +1,22 @@ +@import url('https://fonts.googleapis.com/css2?family=Inter:wght@100..600&display=swap'); + :root { - --md-primary-fg-color: #FF675F; - --md-primary-fg-color--light: #FF675F; - --md-primary-fg-color--dark: #FF675F; + --md-primary-fg-color: #f2a8ff; + --md-primary-fg-color--light: #f2a8ff; + --md-primary-fg-color--dark: #f2a8ff; + --md-text-font: "Inter"; } [data-md-color-scheme="default"] { --md-primary-fg-color: #000000; - --md-typeset-a-color: #FF675F; - --md-accent-fg-color: #F7A399; + --md-typeset-a-color: #9c50c2; + --md-accent-fg-color: #c57fed; } [data-md-color-scheme="slate"] { --md-primary-fg-color: #000000; - --md-typeset-a-color: #F7A399; - --md-accent-fg-color: #FF675F; -} \ No newline at end of file + --md-typeset-a-color: #ca77d8; + --md-accent-fg-color: #f2a8ff; +} + +.md-sidebar__scrollwrap:focus-within, .md-sidebar__scrollwrap:hover { + scrollbar-color: var(--md-default-fg-color--lighter) #0000; +} diff --git a/docs/stylesheets/fonts/FontAwesome.otf b/docs/stylesheets/fonts/FontAwesome.otf new file mode 100644 index 0000000000..401ec0f36e Binary files /dev/null and b/docs/stylesheets/fonts/FontAwesome.otf differ diff --git a/docs/stylesheets/fonts/fontawesome-webfont.eot b/docs/stylesheets/fonts/fontawesome-webfont.eot new file mode 100644 index 0000000000..e9f60ca953 Binary files /dev/null and b/docs/stylesheets/fonts/fontawesome-webfont.eot differ diff --git a/docs/stylesheets/fonts/fontawesome-webfont.svg b/docs/stylesheets/fonts/fontawesome-webfont.svg new file mode 100644 index 0000000000..52c0773359 --- /dev/null +++ b/docs/stylesheets/fonts/fontawesome-webfont.svg @@ -0,0 +1,2671 @@ + + + + +Created by FontForge 20120731 at Mon Oct 24 17:37:40 2016 + By ,,, +Copyright Dave Gandy 2016. All rights reserved. 
+[fontawesome-webfont.svg glyph path data omitted]
diff --git a/docs/stylesheets/fonts/fontawesome-webfont.ttf b/docs/stylesheets/fonts/fontawesome-webfont.ttf
new file mode 100644
index 0000000000..35acda2fa1
Binary files /dev/null and b/docs/stylesheets/fonts/fontawesome-webfont.ttf differ
diff --git a/docs/stylesheets/fonts/fontawesome-webfont.woff b/docs/stylesheets/fonts/fontawesome-webfont.woff
new file mode 100644
index 0000000000..400014a4b0
Binary files /dev/null and b/docs/stylesheets/fonts/fontawesome-webfont.woff differ
diff --git a/docs/stylesheets/fonts/fontawesome-webfont.woff2 b/docs/stylesheets/fonts/fontawesome-webfont.woff2
new file mode 100644
index 0000000000..4d13fc6040
Binary files /dev/null and b/docs/stylesheets/fonts/fontawesome-webfont.woff2 differ
diff --git a/examples/arena_hard.py b/examples/arena_hard.py
new file mode 100644
index 0000000000..81bec55ace
--- /dev/null
+++ b/examples/arena_hard.py
@@ -0,0 +1,458 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any, Dict, List, Optional, Union
+
+from distilabel.steps import GlobalStep, StepInput
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.typing import StepOutput
+from typing_extensions import override
+
+
+class ArenaHard(Task):
+    """Evaluates two assistant responses using an LLM as judge.
+
+    This `Task` is based on the "From Live Data to High-Quality Benchmarks: The
+    Arena-Hard Pipeline" paper that presents Arena Hard, which is a benchmark for
+    instruction-tuned LLMs that contains 500 challenging user queries.
GPT-4 is used + as the judge to compare the model responses against a baseline model, which defaults + to `gpt-4-0314`. + + Note: + Arena-Hard-Auto has the highest correlation and separability to Chatbot Arena + among popular open-ended LLM benchmarks. + + Input columns: + - instruction (`str`): The instruction to evaluate the responses. + - generations (`List[str]`): The responses generated by two, and only two, LLMs. + + Output columns: + - evaluation (`str`): The evaluation of the responses generated by the LLMs. + - score (`str`): The score extracted from the evaluation. + - model_name (`str`): The model name used to generate the evaluation. + + Categories: + - benchmark + + References: + - [From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline](https://lmsys.org/blog/2024-04-19-arena-hard/) + - [`arena-hard-auto`](https://github.com/lm-sys/arena-hard-auto/tree/main) + + Examples: + + Evaluate two assistant responses for a given instruction using Arean Hard prompts: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps import CombineColumns, LoadDataFromDicts + from distilabel.steps.tasks import ArenaHard, TextGeneration + + with Pipeline() as pipeline: + load_data = LoadDataFromDicts( + data=[{"instruction": "What is the capital of France?"}], + ) + + text_generation_a = TextGeneration( + llm=..., # LLM instance + output_mappings={"model_name": "generation_model"}, + ) + + text_generation_b = TextGeneration( + llm=..., # LLM instance + output_mappings={"model_name": "generation_model"}, + ) + + combine = CombineColumns( + columns=["generation", "generation_model"], + output_columns=["generations", "generation_models"], + ) + + arena_hard = ArenaHard( + llm=..., # LLM instance + ) + + load_data >> [text_generation_a, text_generation_b] >> combine >> arena_hard + ``` + """ + + @property + def inputs(self) -> List[str]: + """The inputs required by this task are the `instruction` and the `generations`, + which are the responses generated by two, and only two, LLMs.""" + return ["instruction", "generations"] + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """This method formats the input data as a `ChatType` using the prompt defined + by the Arena Hard benchmark, which consists on a `system_prompt` plus a template + for the user first message that contains the `instruction` and both `generations`. + """ + return [ + { + "role": "system", + "content": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.\n\nBegin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any answers.\n\nWhen evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information.\n\nThen consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. 
Concise means the response is clear and not verbose or excessive.\n\nThen consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt.\n\nAfter providing your explanation, you must output only one of the following choices as your final verdict with a label:\n\n1. Assistant A is significantly better: [[A>>B]]\n2. Assistant A is slightly better: [[A>B]]\n3. Tie, relatively the same: [[A=B]]\n4. Assistant B is slightly better: [[B>A]]\n5. Assistant B is significantly better: [[B>>A]]\n\nExample output: \"My final verdict is tie: [[A=B]]\".", + }, + { + "role": "user", + "content": f"<|User Prompt|>\n{input['instruction']}\n\n<|The Start of Assistant A's Answer|>\n{input['generations'][0]}\n<|The End of Assistant A's Answer|>\n\n<|The Start of Assistant B's Answer|>\n{input['generations'][1]}\n<|The End of Assistant B's Answer|>", + }, + ] + + @property + def outputs(self) -> List[str]: + """The outputs generated by this task are the `evaluation`, the `score` and + the `model_name` (which is automatically injected within the `process` method + of the parent task).""" + return ["evaluation", "score", "model_name"] + + def format_output( + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, + ) -> Dict[str, Any]: + """This method formats the output generated by the LLM as a Python dictionary + containing the `evaluation` which is the raw output generated by the LLM (consisting + of the judge LLM alternate generation for the given instruction, plus an explanation + on the evaluation of the given responses; plus the `score` extracted from the output. + + Args: + output: the raw output of the LLM. + input: the input to the task. Is provided in case it needs to be used to enrich + the output if needed. + + Returns: + A dict with the keys `evaluation` with the raw output which contains the LLM + evaluation and the extracted `score` if possible. + """ + if output is None: + return {"evaluation": None, "score": None} + pattern = re.compile(r"\[\[([AB<>=]+)\]\]") + match = pattern.search(output) + if match is None: + return {"evaluation": output, "score": None} + return {"evaluation": output, "score": match.group(1)} + + +class ArenaHardResults(GlobalStep): + """Process Arena Hard results to calculate the ELO scores. + + This `Step` is based on the "From Live Data to High-Quality Benchmarks: The + Arena-Hard Pipeline" paper that presents Arena Hard, which is a benchmark for + instruction-tuned LLMs that contains 500 challenging user queries. This step is + a `GlobalStep` that should run right after the `ArenaHard` task to calculate the + ELO scores for the evaluated models. + + Note: + Arena-Hard-Auto has the highest correlation and separability to Chatbot Arena + among popular open-ended LLM benchmarks. + + Input columns: + - evaluation (`str`): The evaluation of the responses generated by the LLMs. + - score (`str`): The score extracted from the evaluation. 
+ + References: + - [From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline](https://lmsys.org/blog/2024-04-19-arena-hard/) + - [`arena-hard-auto`](https://github.com/lm-sys/arena-hard-auto/tree/main) + + Examples: + + Rate the ELO scores for two assistant responses for a given an evaluation / comparison between both using Arean Hard prompts: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps import CombineColumns, LoadDataFromDicts + from distilabel.steps.tasks import ArenaHard, TextGeneration + + with Pipeline() as pipeline: + load_data = LoadDataFromDicts( + data=[{"instruction": "What is the capital of France?"}], + ) + + text_generation_a = TextGeneration( + llm=..., # LLM instance + output_mappings={"model_name": "generation_model"}, + ) + + text_generation_b = TextGeneration( + llm=..., # LLM instance + output_mappings={"model_name": "generation_model"}, + ) + + combine = CombineColumns( + columns=["generation", "generation_model"], + output_columns=["generations", "generation_models"], + ) + + arena_hard = ArenaHard( + llm=..., # LLM instance + ) + + arena_hard_results = ArenaHardResults( + custom_model_column="generation_models", + custom_weights={"A>B": 1, "A>>B": 3, "B>A": 1, "B>>A": 3}, + ) + + load_data >> [text_generation_a, text_generation_b] >> combine >> arena_hard >> arena_hard_results + ``` + + """ + + custom_model_column: Optional[str] = None + custom_weights: Dict[str, int] = {"A>B": 1, "A>>B": 3, "B>A": 1, "B>>A": 3} + + def load(self) -> None: + """Ensures that the required dependencies are installed.""" + super().load() + + try: + import numpy as np # noqa: F401 + import pandas as pd # noqa: F401 + from sklearn.linear_model import LogisticRegression # noqa: F401 + except ImportError as e: + raise ImportError( + "In order to run `ArenaHardResults`, the `arena-hard` extra dependencies" + " must be installed i.e. `numpy`, `pandas`, and `scikit-learn`.\n" + "Please install the dependencies by running `pip install distilabel[arena-hard]`." + ) from e + + # TODO: the `evaluation` is not really required as an input, so it could be removed, since + # only `score` is used / required + @property + def inputs(self) -> List[str]: + """The inputs required by this step are the `evaluation` and the `score` generated + by the `ArenaHard` task. Since this step does use the identifiers `model_a` and `model_b`, + optionally one can set `custom_model_column` to use the model names if existing within + the input data, ideally this value should be `model_name` if connected from the `ArenaHard` + step.""" + columns = ["evaluation", "score"] + if self.custom_model_column: + columns.append(self.custom_model_column) + return columns + + @override + def process(self, inputs: StepInput) -> StepOutput: # type: ignore + """This method processes the inputs generated by the `ArenaHard` task to calculate the + win rates for each of the models to evaluate. Since this step inherits from the `GlobalStep`, + it will wait for all the input batches to be processed, and then the output will be yielded in + case there's a follow up step, since this step won't modify the received inputs. + + Args: + inputs: A list of Python dictionaries with the inputs of the task. + + Yields: + A list of Python dictionaries with the outputs of the task. 
+ + References: + - https://github.com/lm-sys/arena-hard-auto/blob/main/show_result.py + """ + import numpy as np + import pandas as pd + from sklearn.linear_model import LogisticRegression + + models = ["A", "B"] + if self.custom_model_column: + models = inputs[0][self.custom_model_column] + + # TODO: the battles are only calculated for the first game, even though the official + # implementation also covers the possibility of a second game (not within the released + # dataset yet) + battles = pd.DataFrame() + for input in inputs: + output = { + # TODO: "question_id": input["question_id"], + "model_a": models[0], + "model_b": models[1], + } + if input["score"] in ["A>B", "A>>B"]: + output["winner"] = models[0] + rows = [output] * self.custom_weights[input["score"]] + elif input["score"] in ["B>A", "B>>A"]: + output["winner"] = models[1] + rows = [output] * self.custom_weights[input["score"]] + elif input["score"] == "A=B": + output["winner"] = "tie" + rows = [output] + else: + continue + + battles = pd.concat([battles, pd.DataFrame(rows)]) + + models = pd.concat([battles["model_a"], battles["model_b"]]).unique() + models = pd.Series(np.arange(len(models)), index=models) + + battles = pd.concat([battles, battles], ignore_index=True) + p = len(models.index) + n = battles.shape[0] + + X = np.zeros([n, p]) + X[np.arange(n), models[battles["model_a"]]] = +np.log(10) + X[np.arange(n), models[battles["model_b"]]] = -np.log(10) + + Y = np.zeros(n) + Y[battles["winner"] == "model_a"] = 1.0 + + tie_idx = battles["winner"] == "tie" + tie_idx[len(tie_idx) // 2 :] = False + Y[tie_idx] = 1.0 + + lr = LogisticRegression(fit_intercept=False, penalty=None, tol=1e-8) # type: ignore + lr.fit(X, Y) + + # The ELO scores are calculated assuming that the reference is `gpt-4-0314` + # with an starting ELO of 1000, so that the evaluated models are compared with + # `gtp-4-0314` only if it's available within the models + elo_scores = 400 * lr.coef_[0] + 1000 + # TODO: we could parametrize the reference / anchor model, but left as is to be faithful to the + # original implementation + if "gpt-4-0314" in models.index: + elo_scores += 1000 - elo_scores[models["gpt-4-0314"]] + + output = pd.Series(elo_scores, index=models.index).sort_values(ascending=False) + self._logger.info(f"Arena Hard ELO: {output}") + + # Here only so that if follow up steps are connected the inputs are preserved, + # since this step doesn't modify nor generate new inputs + yield inputs + + +if __name__ == "__main__": + import json + + from distilabel.llms import InferenceEndpointsLLM, OpenAILLM + from distilabel.pipeline import Pipeline + from distilabel.steps import ( + CombineColumns, + KeepColumns, + LoadHubDataset, + StepInput, + step, + ) + from distilabel.steps.tasks import TextGeneration + from distilabel.steps.typing import StepOutput + + @step(inputs=["turns"], outputs=["system_prompt", "instruction"]) + def PrepareForTextGeneration(*inputs: StepInput) -> StepOutput: + for input in inputs: + for item in input: + item["system_prompt"] = "You are a helpful assistant." 
+ item["instruction"] = item["turns"][0]["content"] + yield input + + @step( + inputs=["question_id"], + outputs=["generation", "generation_model"], + step_type="global", + ) + def LoadReference(*inputs: StepInput) -> StepOutput: + # File downloaded from https://raw.githubusercontent.com/lm-sys/arena-hard-auto/e0a8ea1df42c1df76451a6cd04b14e31ff992b87/data/arena-hard-v0.1/model_answer/gpt-4-0314.jsonl + lines = open("gpt-4-0314.jsonl", mode="r").readlines() + for input in inputs: + for item in input: + for line in lines: + data = json.loads(line) + if data["question_id"] == item["question_id"]: + item["generation"] = data["choices"][0]["turns"][0]["content"] + item["generation_model"] = data["model_id"] + break + yield input + + with Pipeline(name="arena-hard-v0.1") as pipeline: + load_dataset = LoadHubDataset( + name="load_dataset", + repo_id="alvarobartt/lmsys-arena-hard-v0.1", + split="test", + num_examples=5, + ) + + load_reference = LoadReference(name="load_reference") + + prepare = PrepareForTextGeneration(name="prepare") + + text_generation_cohere = TextGeneration( + name="text_generation_cohere", + llm=InferenceEndpointsLLM( + model_id="CohereForAI/c4ai-command-r-plus", + tokenizer_id="CohereForAI/c4ai-command-r-plus", + ), + use_system_prompt=True, + input_batch_size=10, + output_mappings={"model_name": "generation_model"}, + ) + + combine_columns = CombineColumns( + name="combine_columns", + columns=["generation", "generation_model"], + output_columns=["generations", "generation_models"], + ) + + arena_hard = ArenaHard( + name="arena_hard", + llm=OpenAILLM(model="gpt-4-1106-preview"), + output_mappings={"model_name": "evaluation_model"}, + ) + + keep_columns = KeepColumns( + name="keep_columns", + columns=[ + "question_id", + "category", + "cluster", + "system_prompt", + "instruction", + "generations", + "generation_models", + "evaluation", + "score", + "evaluation_model", + ], + ) + + win_rates = ArenaHardResults( + name="win_rates", custom_model_column="generation_models" + ) + + load_dataset >> load_reference # type: ignore + load_dataset >> prepare >> text_generation_cohere # type: ignore + ( # type: ignore + [load_reference, text_generation_cohere] + >> combine_columns + >> arena_hard + >> keep_columns + >> win_rates + ) + + distiset = pipeline.run( + parameters={ # type: ignore + text_generation_cohere.name: { + "llm": { + "generation_kwargs": { + "temperature": 0.7, + "max_new_tokens": 4096, + "stop_sequences": ["", "<|END_OF_TURN_TOKEN|>"], + } + } + }, + arena_hard.name: { + "llm": { + "generation_kwargs": { + "temperature": 0.0, + "max_new_tokens": 4096, + } + } + }, + }, + ) + if distiset is not None: + distiset.push_to_hub("arena-hard-results") diff --git a/examples/draw_kg.py b/examples/draw_kg.py new file mode 100644 index 0000000000..8d45e40b85 --- /dev/null +++ b/examples/draw_kg.py @@ -0,0 +1,82 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
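+
+# This helper script renders one of the knowledge graphs generated by the
+# `structured_generation_with_instructor.py` example: it loads the generations
+# from the Hugging Face Hub and draws the selected sample with `graphviz`.
+#
+# Usage (requires `pip install graphviz`):
+#   python examples/draw_kg.py 2  # pass 0, 1 or 2 to pick the sample to visualize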
+ +import json +from typing import Any, Dict, List, Union + +from graphviz import Digraph +from pydantic import BaseModel, Field + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = "black" + + +class KnowledgeGraph(BaseModel): + nodes: List[Node] = Field(..., default_factory=list) + edges: List[Edge] = Field(..., default_factory=list) + + +def visualize_knowledge_graph(kg: KnowledgeGraph): + dot = Digraph(comment="Knowledge Graph") + + # Add nodes + for node in kg.nodes: + dot.node(str(node.id), node.label, color=node.color) + + # Add edges + for edge in kg.edges: + dot.edge( + str(edge.source), + str(edge.target), + label=edge.label, + color=edge.color or "black", + ) + + # Render the graph + dot.render("knowledge_graph.gv", view=True) + + +def create_knowledge_graph(data: str) -> Union[KnowledgeGraph, None]: + data: Dict[str, Any] = json.loads(data) + + nodes = [Node(**node) for node in data["nodes"]] + edges = [] + for edge in data["edges"]: + if edge.get("color") is None: + edge["color"] = "black" + edges.append(Edge(**edge)) + + return KnowledgeGraph(nodes=nodes, edges=edges) + + +if __name__ == "__main__": + import sys + + args = sys.argv[1:] + + from datasets import load_dataset + + ds = load_dataset("distilabel-internal-testing/knowledge_graphs", split="train") + graphs = [create_knowledge_graph(g) for g in ds["generation"]] + visualize_knowledge_graph(graphs[int(args[0])]) diff --git a/examples/structured_generation_with_instructor.py b/examples/structured_generation_with_instructor.py new file mode 100644 index 0000000000..48082886f4 --- /dev/null +++ b/examples/structured_generation_with_instructor.py @@ -0,0 +1,87 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from distilabel.llms import MistralLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromDicts +from distilabel.steps.tasks import TextGeneration +from pydantic import BaseModel, Field + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = "black" + + +class KnowledgeGraph(BaseModel): + nodes: List[Node] = Field(..., default_factory=list) + edges: List[Edge] = Field(..., default_factory=list) + + +with Pipeline( + name="Knowledge-Graphs", + description=( + "Generate knowledge graphs to answer questions, this type of dataset can be used to " + "steer a model to answer questions with a knowledge graph." + ), +) as pipeline: + sample_questions = [ + "Teach me about quantum mechanics", + "Who is who in The Simpsons family?", + "Tell me about the evolution of programming languages", + ] + + load_dataset = LoadDataFromDicts( + name="load_instructions", + data=[ + { + "system_prompt": "You are a knowledge graph expert generator. 
Help me understand by describing everything as a detailed knowledge graph.", + "instruction": f"{question}", + } + for question in sample_questions + ], + ) + + text_generation = TextGeneration( + name="knowledge_graph_generation", + llm=MistralLLM( + model="open-mixtral-8x22b", structured_output={"schema": KnowledgeGraph} + ), + input_batch_size=8, + output_mappings={"model_name": "generation_model"}, + ) + load_dataset >> text_generation + + +if __name__ == "__main__": + distiset = pipeline.run( + parameters={ + text_generation.name: { + "llm": {"generation_kwargs": {"max_new_tokens": 2048}} + } + }, + use_cache=False, + ) + + distiset.push_to_hub("distilabel-internal-testing/knowledge_graphs") diff --git a/mkdocs.yml b/mkdocs.yml index 5665237fbf..da01807064 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,5 +1,5 @@ # Project information -site_name: distilabel +site_name: Distilabel Docs site_url: https://argilla-io.github.io/distilabel site_author: Argilla, Inc. site_description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs. @@ -11,6 +11,15 @@ repo_url: https://github.com/argilla-io/distilabel extra: version: provider: mike + social: + - icon: fontawesome/brands/linkedin + link: https://www.linkedin.com/company/argilla-io + - icon: fontawesome/brands/x-twitter + link: https://twitter.com/argilla_io + - icon: fontawesome/brands/youtube + link: https://www.youtube.com/channel/UCAIz8TmvQQrLqbD7sd-5S2A + - icon: fontawesome/brands/slack + link: https://join.slack.com/t/rubrixworkspace/shared_invite/zt-20wllqq29-Z11~kp2SeFYjJ0qevJRiPg extra_css: - stylesheets/extra.css @@ -27,12 +36,23 @@ theme: icon: repo: fontawesome/brands/github-alt features: - - navigation.sections # Sections are included in the navigation on the left. - # - toc.integrate # # Table of contents is integrated on the left; does not appear separately on the right. 
+ - navigation.instant + - navigation.sections + - navigation.tabs + - navigation.footer + - navigation.top + - navigation.tracking + - navigation.path - header.autohide # header disappears as you scroll - content.code.copy - content.code.annotate - content.tabs.link + - content.action.edit + - toc.follow + - search.suggest + - search.highlight + - search.share + palette: - media: "(prefers-color-scheme)" primary: white @@ -70,6 +90,7 @@ markdown_extensions: line_spans: __span pygments_lang_class: true - pymdownx.inlinehilite + - pymdownx.keys - pymdownx.superfences: custom_fences: - name: mermaid @@ -93,9 +114,7 @@ plugins: - autorefs # Cross-links to headings - gen-files: scripts: - - docs/scripts/gen_ref_pages.py - - literate-nav: - nav_file: SUMMARY.md + - docs/scripts/gen_popular_issues.py - section-index - mkdocstrings: handlers: @@ -115,68 +134,65 @@ plugins: heading_level: 4 - social - distilabel/components-gallery: - add_after_page: Learn + add_after_page: How-to guides nav: - - Introduction: "index.md" + - Distilabel: "index.md" - Getting started: - - Installation: "sections/installation.md" - - How-to-Guide: "sections/how_to_guide.md" - - Learn: - - "sections/learn/index.md" - - Tutorial: - - "sections/learn/tutorial/index.md" - - Step: - - "sections/learn/tutorial/step/index.md" - - GeneratorStep: "sections/learn/tutorial/step/generator_step.md" - - GlobalStep: "sections/learn/tutorial/step/global_step.md" - - Task: - - "sections/learn/tutorial/task/index.md" - - GeneratorTask: "sections/learn/tutorial/task/generator_task.md" - - LLM: "sections/learn/tutorial/llm/index.md" - - Pipeline: "sections/learn/tutorial/pipeline/index.md" - - CLI: "sections/learn/tutorial/cli/index.md" + - Installation: "sections/getting_started/installation.md" + - Quickstart: "sections/getting_started/quickstart.md" + - FAQ: "sections/getting_started/faq.md" + - How-to guides: + - "sections/how_to_guides/index.md" + - Basic: + - Define Steps for your Pipeline: + - "sections/how_to_guides/basic/step/index.md" + - GeneratorStep: "sections/how_to_guides/basic/step/generator_step.md" + - GlobalStep: "sections/how_to_guides/basic/step/global_step.md" + - Define Tasks that rely on LLMs: + - "sections/how_to_guides/basic/task/index.md" + - GeneratorTask: "sections/how_to_guides/basic/task/generator_task.md" + - Define LLMs as local or remote models: "sections/how_to_guides/basic/llm/index.md" + - Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md" - Advanced: - - "sections/learn/advanced/index.md" - - Argilla: "sections/learn/advanced/argilla.md" - - Caching: "sections/learn/advanced/caching.md" - - Distiset: "sections/learn/advanced/distiset.md" - - Structured Generation: "sections/learn/advanced/structured_generation.md" + - Using the Distiset dataset object: "sections/how_to_guides/advanced/distiset.md" + - Export data to Argilla: "sections/how_to_guides/advanced/argilla.md" + - Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md" + - Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md" + - Cache and recover pipeline executions: "sections/how_to_guides/advanced/caching.md" + - Structured data generation: "sections/how_to_guides/advanced/structured_generation.md" + - Serving an LLM for sharing it between several tasks: "sections/how_to_guides/advanced/serving_an_llm_for_reuse.md" - Pipeline Samples: - - "sections/pipeline_samples/index.md" - Examples: 
"sections/pipeline_samples/examples/index.md" - Papers: - "sections/pipeline_samples/papers/index.md" - DEITA: "sections/pipeline_samples/papers/deita.md" - Instruction Backtranslation: "sections/pipeline_samples/papers/instruction_backtranslation.md" - # - Prometheus: "sections/examples/papers/prometheus.md" + - Prometheus 2: "sections/pipeline_samples/papers/prometheus.md" - UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md" - - FAQ: "sections/faq.md" - API Reference: - - Pipeline: - - "api/pipeline/index.md" - - Routing Batch Function: "api/pipeline/routing_batch_function.md" - - Typing: "api/pipeline/typing.md" - - Utils: "api/pipeline/utils.md" - Step: - "api/step/index.md" - GeneratorStep: "api/step/generator_step.md" - GlobalStep: "api/step/global_step.md" - "@step": "api/step/decorator.md" - Step Gallery: - - Argilla: "api/step_gallery/argilla.md" - - Columns: "api/step_gallery/columns.md" - - Extra: "api/step_gallery/extra.md" + - Argilla: "api/step_gallery/argilla.md" + - Hugging Face: "api/step_gallery/hugging_face.md" + - Columns: "api/step_gallery/columns.md" + - Extra: "api/step_gallery/extra.md" - Task: - "api/task/index.md" - GeneratorTask: "api/task/generator_task.md" - Task Gallery: "api/task_gallery/index.md" + - Typing: "api/task/typing.md" - LLM: - "api/llm/index.md" - LLM Gallery: - Anthropic: "api/llm/anthropic.md" - Anyscale: "api/llm/anyscale.md" - Azure (via OpenAI): "api/llm/azure.md" + - Cohere: "api/llm/cohere.md" - Groq: "api/llm/groq.md" - Hugging Face: "api/llm/huggingface.md" - LiteLLM: "api/llm/litellm.md" @@ -187,4 +203,13 @@ nav: - Together AI: "api/llm/together.md" - Google Vertex AI: "api/llm/vertexai.md" - vLLM: "api/llm/vllm.md" + - Pipeline: + - "api/pipeline/index.md" + - Routing Batch Function: "api/pipeline/routing_batch_function.md" + - Typing: "api/pipeline/typing.md" + - Utils: "api/pipeline/utils.md" + - Distiset: "api/distiset.md" - CLI: "api/cli.md" + - Community: + - sections/community/index.md + - Issue dashboard: sections/community/popular_issues.md diff --git a/pyproject.toml b/pyproject.toml index c3317bf5b5..b94387ff51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,11 +17,14 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "datasets >= 2.14.0", + # Bump `datasets` to support `load_dataset` from cache + # Ref https://github.com/huggingface/datasets/releases/tag/2.16.0 + "datasets >= 2.16.0", "httpx >= 0.25.2", "importlib-resources >= 6.1.1; python_version < '3.9'", "Jinja2 >= 3.1.2", @@ -33,6 +36,9 @@ dependencies = [ "scipy >= 1.10.0", "typer >= 0.9.0", "tblib >= 3.0.0", + "orjson >= 3.10.0", + "universal_pathlib >= 0.2.2", + "portalocker >= 2.8.2", ] dynamic = ["version"] @@ -43,9 +49,9 @@ distilabel = "distilabel.cli.app:app" "distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin" [project.optional-dependencies] -dev = ["ruff == 0.2.2", "pre-commit >= 3.5.0"] +dev = ["ruff == 0.4.5", "pre-commit >= 3.5.0"] docs = [ - "mkdocs-material >= 9.5.0", + "mkdocs-material >=9.5.17", "mkdocstrings[python] >= 0.24.0", "mkdocs-literate-nav >= 0.6.1", "mkdocs-section-index >= 0.3.8", @@ -54,16 +60,24 @@ docs = [ "Pillow >= 9.5.0", "CairoSVG >= 2.7.1", "mknotebooks >= 0.8.0", + "pandas >= 2.0", +] 
+tests = [ + "pytest >= 7.4.0", + "pytest-asyncio", + "nest-asyncio", + "pytest-timeout", + "pytest-codspeed", ] -tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio", "pytest-timeout"] # Optional LLMs, integrations, etc anthropic = ["anthropic >= 0.20.0"] -argilla = ["argilla >= 1.23.0"] +argilla = ["argilla >= 1.29.0"] cohere = ["cohere >= 5.2.0"] groq = ["groq >= 0.4.1"] hf-inference-endpoints = ["huggingface_hub >= 0.19.0"] hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"] +instructor = ["instructor >= 1.2.3"] litellm = ["litellm >= 1.30.0"] llama-cpp = ["llama-cpp-python >= 0.2.0"] mistralai = ["mistralai >= 0.1.0"] @@ -71,7 +85,7 @@ ollama = ["ollama >= 0.1.7"] openai = ["openai >= 1.0.0"] outlines = ["outlines >= 0.0.40"] vertexai = ["google-cloud-aiplatform >= 1.38.0"] -vllm = ["vllm >= 0.2.1", "filelock >= 3.13.4"] +vllm = ["vllm >= 0.4.0", "outlines == 0.0.34", "filelock >= 3.13.4"] [project.urls] Documentation = "https://distilabel.argilla.io/" diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh new file mode 100755 index 0000000000..4da6ad9dd4 --- /dev/null +++ b/scripts/install_dependencies.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -e + +python_version=$(python -c "import sys; print(sys.version_info[:2])") + +python -m pip install uv + +uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai]" +if [ "${python_version}" != "(3, 8)" ]; then + uv pip install --system -e .[mistralai,instructor] +fi + +uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index be49c91677..574194b287 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -14,6 +14,6 @@ from rich import traceback as rich_traceback -__version__ = "1.1.1" +__version__ = "1.2.0" rich_traceback.install(show_locals=True) diff --git a/src/distilabel/cli/pipeline/utils.py b/src/distilabel/cli/pipeline/utils.py index 59d006b523..869e0e6db1 100644 --- a/src/distilabel/cli/pipeline/utils.py +++ b/src/distilabel/cli/pipeline/utils.py @@ -120,7 +120,8 @@ def get_pipeline(config: str) -> "BasePipeline": FileNotFoundError: If the configuration file does not exist. 
""" if valid_http_url(config): - return Pipeline.from_dict(get_config_from_url(config)) + data = get_config_from_url(config) + return Pipeline.from_dict(data) if Path(config).is_file(): return Pipeline.from_file(config) @@ -200,9 +201,15 @@ def _build_steps_panel(pipeline: "BasePipeline") -> "Panel": from rich.table import Table def _add_rows( - table: Table, runtime_params: List[Dict[str, Any]], prefix: str = "" + table: Table, + runtime_params: List[Dict[str, Any]], + prefix: str = "", ) -> None: for param in runtime_params: + if isinstance(param, str): + _add_rows(table, runtime_params[param], f"{prefix}{param}.") + continue + # nested (for example `LLM` in `Task`) if "runtime_parameters_info" in param: _add_rows( @@ -210,22 +217,22 @@ def _add_rows( runtime_params=param["runtime_parameters_info"], prefix=f"{prefix}{param['name']}.", ) - continue - # `LLM` special case - if "keys" in param: + elif "keys" in param: _add_rows( table=table, runtime_params=param["keys"], prefix=f"{prefix}{param['name']}.", ) - continue - - optional = param.get("optional", "") - if optional != "": - optional = "Yes" if optional else "No" + return + else: + optional = param.get("optional", "") + if optional != "": + optional = "Yes" if optional else "No" - table.add_row(prefix + param["name"], param.get("description"), optional) + table.add_row( + prefix + param["name"], param.get("description"), optional + ) steps = [] for step_name, runtime_params in pipeline.get_runtime_parameters_info().items(): @@ -239,7 +246,7 @@ def _add_rows( expand=True, ) - table.add_column("Runtime parameter", style="dim", width=50) + table.add_column("Runtime parameter", style="dim", width=60) table.add_column("Description", width=100) table.add_column("Optional", justify="right") _add_rows(table, runtime_params) diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index b9bd1b0e03..f5e8cc7b3a 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -13,20 +13,31 @@ # limitations under the License. import logging +import os.path as posixpath import re +from os import PathLike from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Final, Optional, Union -from datasets import load_dataset +import fsspec +import yaml +from datasets import Dataset, load_dataset, load_from_disk +from datasets.filesystems import is_remote_filesystem from huggingface_hub import DatasetCardData, HfApi +from huggingface_hub.file_download import hf_hub_download from pyarrow.lib import ArrowInvalid +from typing_extensions import Self -from distilabel.steps.tasks.base import DISTILABEL_METADATA_KEY from distilabel.utils.card.dataset_card import ( DistilabelDatasetCard, size_categories_parser, ) from distilabel.utils.files import list_files_in_dir +from distilabel.utils.huggingface import get_hf_token + +DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs" +PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml" +PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log" class Distiset(dict): @@ -41,8 +52,8 @@ class Distiset(dict): pipeline. """ - pipeline_path: Optional[Path] = None - log_filename_path: Optional[Path] = None + _pipeline_path: Optional[Path] = None + _log_filename_path: Optional[Path] = None def push_to_hub( self, @@ -71,7 +82,13 @@ def push_to_hub( Whether to generate a dataset card or not. Defaults to True. **kwargs: Additional keyword arguments to pass to the `push_to_hub` method of the `datasets.Dataset` object. 
+ + Raises: + ValueError: If no token is provided and couldn't be retrieved automatically. """ + if token is None: + token = get_hf_token(self.__class__.__name__, "token") + for name, dataset in self.items(): dataset.push_to_hub( repo_id=repo_id, @@ -84,14 +101,23 @@ def push_to_hub( if generate_card: self._generate_card(repo_id, token) - def _generate_card(self, repo_id: str, token: Optional[str]) -> None: - """Generates a dataset card and pushes it to the Hugging Face Hub, and - if the `pipeline.yaml` path is available in the `Distiset`, uploads that - to the same repository. + def _get_card( + self, repo_id: str, token: Optional[str] = None + ) -> DistilabelDatasetCard: + """Generates the dataset card for the `Distiset`. + + Note: + If `repo_id` and `token` are provided, it will extract the metadata from the README.md file + on the hub. Args: - repo_id: The ID of the repository to push to, from the `push_to_hub` method. - token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method. + repo_id: Name of the repository to push to, or the path for the distiset if saved to disk. + token: The token to authenticate with the Hugging Face Hub. + We assume that if it's provided, the dataset will be in the Hugging Face Hub, + so the README metadata will be extracted from there. + + Returns: + The dataset card for the `Distiset`. """ sample_records = {} for name, dataset in self.items(): @@ -99,8 +125,12 @@ def _generate_card(self, repo_id: str, token: Optional[str]) -> None: dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + readme_metadata = {} + if repo_id and token: + readme_metadata = self._extract_readme_metadata(repo_id, token) + metadata = { - **self._extract_readme_metadata(repo_id, token), + **readme_metadata, "size_categories": size_categories_parser( max(len(dataset) for dataset in self.values()) ), @@ -112,29 +142,8 @@ def _generate_card(self, repo_id: str, token: Optional[str]) -> None: repo_id=repo_id, sample_records=sample_records, ) - card.push_to_hub( - repo_id, - repo_type="dataset", - token=token, - ) - if self.pipeline_path: - # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well. - HfApi().upload_file( - path_or_fileobj=self.pipeline_path, - path_in_repo="pipeline.yaml", - repo_id=repo_id, - repo_type="dataset", - token=token, - ) - if self.log_filename_path: - # The same we had with "pipeline.yaml" but with the log file. - HfApi().upload_file( - path_or_fileobj=self.log_filename_path, - path_in_repo="pipeline.log", - repo_id=repo_id, - repo_type="dataset", - token=token, - ) + + return card def _extract_readme_metadata( self, repo_id: str, token: Optional[str] @@ -151,11 +160,6 @@ def _extract_readme_metadata( Returns: The metadata extracted from the README.md file of the dataset repository as a dict. """ - import re - - import yaml - from huggingface_hub.file_download import hf_hub_download - readme_path = Path( hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token) ) @@ -164,12 +168,47 @@ def _extract_readme_metadata( metadata = yaml.safe_load(metadata) return metadata + def _generate_card(self, repo_id: str, token: str) -> None: + """Generates a dataset card and pushes it to the Hugging Face Hub, and + if the `pipeline.yaml` path is available in the `Distiset`, uploads that + to the same repository. + + Args: + repo_id: The ID of the repository to push to, from the `push_to_hub` method. + token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method. 
+ """ + card = self._get_card(repo_id=repo_id, token=token) + + card.push_to_hub( + repo_id, + repo_type="dataset", + token=token, + ) + if self.pipeline_path: + # If the pipeline.yaml is available, upload it to the Hugging Face Hub as well. + HfApi().upload_file( + path_or_fileobj=self.pipeline_path, + path_in_repo=PIPELINE_CONFIG_FILENAME, + repo_id=repo_id, + repo_type="dataset", + token=token, + ) + if self.log_filename_path: + # The same we had with "pipeline.yaml" but with the log file. + HfApi().upload_file( + path_or_fileobj=self.log_filename_path, + path_in_repo=PIPELINE_LOG_FILENAME, + repo_id=repo_id, + repo_type="dataset", + token=token, + ) + def train_test_split( self, train_size: float, shuffle: bool = True, seed: Optional[int] = None, - ) -> "Distiset": + ) -> Self: """Return a `Distiset` whose values will be a `datasets.DatasetDict` with two random train and test subsets. Splits are created from the dataset according to `train_size` and `shuffle`. @@ -193,6 +232,198 @@ def train_test_split( ) return self + def save_to_disk( + self, + distiset_path: PathLike, + max_shard_size: Optional[Union[str, int]] = None, + num_shards: Optional[int] = None, + num_proc: Optional[int] = None, + storage_options: Optional[dict] = None, + save_card: bool = True, + save_pipeline_config: bool = True, + save_pipeline_log: bool = True, + ) -> None: + r""" + Saves a `Distiset` to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. + + In case you want to save the `Distiset` in a remote filesystem, you can pass the `storage_options` parameter + as you would do with `datasets`'s `Dataset.save_to_disk` method: [see example](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets) + + Args: + distiset_path: Path where you want to save the `Distiset`. It can be a local path + (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`) + max_shard_size: The maximum size of the dataset shards to be uploaded to the hub. + If expressed as a string, needs to be digits followed by a unit (like `"50MB"`). + Defaults to `None`. + num_shards: Number of shards to write. By default the number of shards depends on + `max_shard_size` and `num_proc`. Defaults to `None`. + num_proc: Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. Defaults to `None`. + storage_options: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + save_card: Whether to save the dataset card. Defaults to `True`. + save_pipeline_config: Whether to save the pipeline configuration file (aka the `pipeline.yaml` file). + Defaults to `True`. + save_pipeline_log: Whether to save the pipeline log file (aka the `pipeline.log` file). + Defaults to `True`. + + Examples: + ```python + # Save your distiset in a local folder: + >>> distiset.save_to_disk(dataset_path="my-distiset") + # Save your distiset in a remote storage: + >>> storage_options = { + ... "key": os.environ["S3_ACCESS_KEY"], + ... "secret": os.environ["S3_SECRET_KEY"], + ... "client_kwargs": { + ... "endpoint_url": os.environ["S3_ENDPOINT_URL"], + ... "region_name": os.environ["S3_REGION"], + ... }, + ... 
} + >>> distiset.save_to_disk(dataset_path="my-distiset", storage_options=storage_options) + ``` + """ + distiset_path = str(distiset_path) + for name, dataset in self.items(): + dataset.save_to_disk( + f"{distiset_path}/{name}", + max_shard_size=max_shard_size, + num_shards=num_shards, + num_proc=num_proc, + storage_options=storage_options, + ) + + distiset_config_folder = posixpath.join(distiset_path, DISTISET_CONFIG_FOLDER) + + fs: fsspec.AbstractFileSystem + fs, _, _ = fsspec.get_fs_token_paths( + distiset_config_folder, storage_options=storage_options + ) + fs.makedirs(distiset_config_folder, exist_ok=True) + + if save_card: + # NOTE: Currently the card is not the same if we write to disk or push to the HF hub, + # as we aren't generating the README copying/updating the data from the dataset repo. + card = self._get_card(repo_id=Path(distiset_path).stem, token=None) + new_filename = posixpath.join(distiset_config_folder, "README.md") + if storage_options: + # Write the card the same way as DatasetCard.save does: + with fs.open(new_filename, "w", newline="", encoding="utf-8") as f: + f.write(str(card)) + else: + card.save(new_filename) + + # Write our internal files to the distiset folder by copying them to the distiset folder. + if save_pipeline_config and self.pipeline_path: + new_filename = posixpath.join( + distiset_config_folder, PIPELINE_CONFIG_FILENAME + ) + if self.pipeline_path.exists() and (not fs.isfile(new_filename)): + data = yaml.safe_load(self.pipeline_path.read_text()) + with fs.open(new_filename, "w", encoding="utf-8") as f: + yaml.dump(data, f, default_flow_style=False) + + if save_pipeline_log and self.log_filename_path: + new_filename = posixpath.join(distiset_config_folder, PIPELINE_LOG_FILENAME) + if self.log_filename_path.exists() and (not fs.isfile(new_filename)): + data = self.log_filename_path.read_text() + with fs.open(new_filename, "w", encoding="utf-8") as f: + f.write(data) + + @classmethod + def load_from_disk( + cls, + distiset_path: PathLike, + keep_in_memory: Optional[bool] = None, + storage_options: Optional[Dict[str, Any]] = None, + download_dir: Optional[PathLike] = None, + ) -> Self: + """Loads a dataset that was previously saved using `Distiset.save_to_disk` from a dataset + directory, or from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. + + Args: + distiset_path: Path ("dataset/train") or remote URI ("s3://bucket/dataset/train"). + keep_in_memory: Whether to copy the dataset in-memory, see `datasets.Dataset.load_from_disk`` + for more information. Defaults to `None`. + storage_options: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + download_dir: Optional directory to download the dataset to. Defaults to None, + in which case it will create a temporary directory. + + Returns: + A `Distiset` loaded from disk, it should be a `Distiset` object created using `Distiset.save_to_disk`. + """ + original_distiset_path = str(distiset_path) + + fs: fsspec.AbstractFileSystem + fs, _, [distiset_path] = fsspec.get_fs_token_paths( + original_distiset_path, storage_options=storage_options + ) + dest_distiset_path = distiset_path + + assert fs.isdir( + original_distiset_path + ), "`distiset_path` must be a `PathLike` object pointing to a folder or a URI of a remote filesystem." 
+ + has_config = False + distiset = cls() + + if is_remote_filesystem(fs): + src_dataset_path = distiset_path + if download_dir: + dest_distiset_path = download_dir + else: + dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path) + fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True) + + # Now we should have the distiset locally, so we can read those files + for folder in Path(dest_distiset_path).iterdir(): + if folder.stem == DISTISET_CONFIG_FOLDER: + has_config = True + continue + distiset[folder.stem] = load_from_disk( + str(folder), + keep_in_memory=keep_in_memory, + ) + # From the config folder we just need to point to the files. Once downloaded we set the path + # to wherever they are. + if has_config: + distiset_config_folder = posixpath.join( + dest_distiset_path, DISTISET_CONFIG_FOLDER + ) + + pipeline_path = posixpath.join( + distiset_config_folder, PIPELINE_CONFIG_FILENAME + ) + if Path(pipeline_path).exists(): + distiset.pipeline_path = Path(pipeline_path) + + log_filename_path = posixpath.join( + distiset_config_folder, PIPELINE_LOG_FILENAME + ) + if Path(log_filename_path).exists(): + distiset.log_filename_path = Path(log_filename_path) + + return distiset + + @property + def pipeline_path(self) -> Union[Path, None]: + """Returns the path to the `pipeline.yaml` file that generated the `Pipeline`.""" + return self._pipeline_path + + @pipeline_path.setter + def pipeline_path(self, path: PathLike) -> None: + self._pipeline_path = Path(path) + + @property + def log_filename_path(self) -> Union[Path, None]: + """Returns the path to the `pipeline.log` file that generated the `Pipeline`.""" + return self._log_filename_path + + @log_filename_path.setter + def log_filename_path(self, path: PathLike) -> None: + self._log_filename_path = Path(path) + def __repr__(self): # Copy from `datasets.DatasetDict.__repr__`. repr = "\n".join([f"{k}: {v}" for k, v in self.items()]) @@ -208,6 +439,9 @@ def create_distiset( # noqa: C901 ) -> Distiset: """Creates a `Distiset` from the buffer folder. + This function is intended to be used as a helper to create a `Distiset` from from the folder + where the cached data was written by the `_WriteBuffer`. + Args: data_dir: Folder where the data buffers were written by the `_WriteBuffer`. It should correspond to `CacheLocation.data`. @@ -223,7 +457,16 @@ def create_distiset( # noqa: C901 Returns: The dataset created from the buffer folder, where the different leaf steps will correspond to different configurations of the dataset. 
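For context, the typical end-to-end flow ties `create_distiset` to the new `save_to_disk` / `load_from_disk` pair added above. The sketch below is a minimal, non-authoritative example assuming both names are importable from `distilabel.distiset` (where this diff defines them) and using a hypothetical cache path:

```python
from pathlib import Path

from distilabel.distiset import Distiset, create_distiset

# Hypothetical cache folder written by the pipeline's `_WriteBuffer`
# (i.e. `CacheLocation.data`).
data_dir = Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname"

distiset = create_distiset(data_dir)

# Persist the datasets plus the config folder (dataset card, `pipeline.yaml`,
# `pipeline.log`) with the new `save_to_disk`, then restore everything later
# with `load_from_disk`.
distiset.save_to_disk("my-distiset")
distiset = Distiset.load_from_disk("my-distiset")
```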
+ + Examples: + + ```python + >>> from pathlib import Path + >>> distiset = create_distiset(Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname") + ``` """ + from distilabel.steps.constants import DISTILABEL_METADATA_KEY + logger = logging.getLogger("distilabel.distiset") data_dir = Path(data_dir) diff --git a/src/distilabel/llms/__init__.py b/src/distilabel/llms/__init__.py index 73009795a2..3e50ddefaa 100644 --- a/src/distilabel/llms/__init__.py +++ b/src/distilabel/llms/__init__.py @@ -23,6 +23,7 @@ from distilabel.llms.llamacpp import LlamaCppLLM from distilabel.llms.mistral import MistralLLM from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.moa import MixtureOfAgentsLLM from distilabel.llms.ollama import OllamaLLM from distilabel.llms.openai import OpenAILLM from distilabel.llms.together import TogetherLLM @@ -43,6 +44,7 @@ "LlamaCppLLM", "MistralLLM", "CudaDevicePlacementMixin", + "MixtureOfAgentsLLM", "OllamaLLM", "OpenAILLM", "TogetherLLM", diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py index f472aca664..843b14b21f 100644 --- a/src/distilabel/llms/anthropic.py +++ b/src/distilabel/llms/anthropic.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import os from typing import ( TYPE_CHECKING, - Any, List, Literal, Optional, @@ -28,13 +26,14 @@ from httpx import AsyncClient from pydantic import Field, PrivateAttr, SecretStr, validate_call -from typing_extensions import override from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType -from distilabel.utils.itertools import grouper +from distilabel.steps.tasks.typing import ( + FormattedInput, + InstructorStructuredOutputType, +) if TYPE_CHECKING: from anthropic import AsyncAnthropic @@ -59,6 +58,9 @@ class AnthropicLLM(AsyncLLM): to `6`. http_client: if provided, an alternative HTTP client to use for calling Anthropic API. Defaults to `None`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _api_key_env_var: the name of the environment variable to use for the API key. It is meant to be used internally. _aclient: the `AsyncAnthropic` client to use for the Anthropic API. It is meant @@ -71,6 +73,46 @@ class AnthropicLLM(AsyncLLM): - `timeout`: the maximum time in seconds to wait for a response. Defaults to `600.0`. - `max_retries`: the maximum number of times to retry the request before failing. Defaults to `6`. 
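As a reference for the `structured_output` attribute introduced for the instructor-backed clients, the configuration is a plain dictionary whose keys (`schema`, `mode`, `max_retries`) are the ones read by `_prepare_structured_output` and `_prepare_kwargs` elsewhere in this diff. The sketch below only illustrates how the schema round-trips between a pydantic model and its serialized form; the concrete values are illustrative:

```python
import json

from pydantic import BaseModel

class User(BaseModel):
    name: str
    last_name: str
    id: int

# `_prepare_structured_output` stores the schema internally as its JSON schema,
# so the configuration can be serialized with the rest of the pipeline ...
as_saved = {"schema": User.model_json_schema(), "max_retries": 2}

# ... and `_prepare_kwargs` accepts it back either as that dict or as a JSON
# string (the form it may take when read from a saved dataset), rebuilding a
# pydantic model before handing it to instructor as `response_model`.
as_string = {"schema": json.dumps(User.model_json_schema()), "max_retries": 2}
```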
+ + Examples: + + Generate text: + + ```python + from distilabel.llms import AnthropicLLM + + llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key") + + llm.load() + + # Synchronous request + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + # Asynchronous request + output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}]) + ``` + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import AnthropicLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = AnthropicLLM( + model="claude-3-opus-20240229", + api_key="api.key", + structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model: str @@ -94,6 +136,14 @@ class AnthropicLLM(AsyncLLM): " failing.", ) http_client: Optional[AsyncClient] = Field(default=None, exclude=True) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _num_generations_param_supported = False _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME) _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...) @@ -143,6 +193,15 @@ def load(self) -> None: http_client=self.http_client, max_retries=self.max_retries, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="anthropic", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output @property def model_name(self) -> str: @@ -152,7 +211,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: FormattedInput, max_tokens: int = 128, stop_sequences: Union[List[str], None] = None, temperature: float = 1.0, @@ -174,22 +233,45 @@ async def agenerate( # type: ignore """ from anthropic._types import NOT_GIVEN - completion = await self._aclient.messages.create( # type: ignore - model=self.model, - system=( + structured_output = None + if isinstance(input, tuple): + input, structured_output = input + result = self._prepare_structured_output( + structured_output=structured_output, + client=self._aclient, + framework="anthropic", + ) + self._aclient = result.get("client") + + if structured_output is None and self.structured_output is not None: + structured_output = self.structured_output + + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "system": ( input.pop(0)["content"] if input and input[0]["role"] == "system" else NOT_GIVEN ), - messages=input, # type: ignore - max_tokens=max_tokens, - stream=False, - stop_sequences=NOT_GIVEN if stop_sequences is None else stop_sequences, - temperature=temperature, - top_p=NOT_GIVEN if top_p is None else top_p, - top_k=NOT_GIVEN if top_k is None else top_k, - ) + "max_tokens": max_tokens, + "stream": False, + "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences, + "temperature": temperature, + "top_p": NOT_GIVEN if top_p is None else top_p, + "top_k": NOT_GIVEN if top_k is None else top_k, + } + + if structured_output: + kwargs = self._prepare_kwargs(kwargs, structured_output) + generations = [] + + completion = await self._aclient.messages.create(**kwargs) # type: ignore + if 
structured_output: + generations.append(completion.model_dump_json()) + return generations + if (content := completion.content[0].text) is None: self._logger.warning( f"Received no response using Anthropic client (model: '{self.model}')." @@ -197,29 +279,3 @@ async def agenerate( # type: ignore ) generations.append(content) return generations - - # TODO: remove this function once Anthropic client allows `n` parameter - @override - def generate( - self, - inputs: List["ChatType"], - num_generations: int = 1, - **kwargs: Any, - ) -> List["GenerateOutput"]: - """Method to generate a list of responses asynchronously, returning the output - synchronously awaiting for the response of each input sent to `agenerate`. - """ - - async def agenerate( - inputs: List["ChatType"], **kwargs: Any - ) -> "GenerateOutput": - """Internal function to parallelize the asynchronous generation of responses.""" - tasks = [ - asyncio.create_task(self.agenerate(input=input, **kwargs)) - for input in inputs - for _ in range(num_generations) - ] - return [outputs[0] for outputs in await asyncio.gather(*tasks)] - - outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs)) - return list(grouper(outputs, n=num_generations, incomplete="ignore")) diff --git a/src/distilabel/llms/anyscale.py b/src/distilabel/llms/anyscale.py index d7eff02043..54b777b8a8 100644 --- a/src/distilabel/llms/anyscale.py +++ b/src/distilabel/llms/anyscale.py @@ -38,6 +38,20 @@ class AnyscaleLLM(OpenAILLM): `None` if not set. _api_key_env_var: the name of the environment variable to use for the API key. It is meant to be used internally. + + Examples: + + Generate text: + + ```python + from distilabel.llms import AnyscaleLLM + + llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key") + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` """ base_url: Optional[RuntimeParameter[str]] = Field( diff --git a/src/distilabel/llms/azure.py b/src/distilabel/llms/azure.py index 58d455d65e..ebcb5ef9ea 100644 --- a/src/distilabel/llms/azure.py +++ b/src/distilabel/llms/azure.py @@ -14,6 +14,7 @@ import os from typing import TYPE_CHECKING, Optional +from unittest.mock import patch from pydantic import Field, PrivateAttr, SecretStr from typing_extensions import override @@ -45,6 +46,65 @@ class AzureOpenAILLM(OpenAILLM): Icon: `:simple-microsoftazure:` + + Examples: + + Generate text: + + ```python + from distilabel.llms import AzureOpenAILLM + + llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key") + + llm.load() + + # Synchrounous request + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + # Asynchronous request + output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}]) + ``` + + Generate text from a custom endpoint following the OpenAI API: + + ```python + from distilabel.llms import AzureOpenAILLM + + llm = AzureOpenAILLM( + model="prometheus-eval/prometheus-7b-v2.0", + base_url=r"http://localhost:8080/v1" + ) + + llm.load() + + # Synchronous request + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + # Asynchronous request + output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}]) + ``` + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import AzureOpenAILLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = AzureOpenAILLM( + model="gpt-4-turbo", + api_key="api.key", + 
structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ base_url: Optional[RuntimeParameter[str]] = Field( @@ -68,7 +128,12 @@ class AzureOpenAILLM(OpenAILLM): @override def load(self) -> None: """Loads the `AsyncAzureOpenAI` client to benefit from async requests.""" - super().load() + # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output + # in the load method before we have the proper client. + with patch( + "distilabel.llms.openai.OpenAILLM._prepare_structured_output", lambda x: x + ): + super().load() try: from openai import AsyncAzureOpenAI @@ -93,3 +158,6 @@ def load(self) -> None: max_retries=self.max_retries, # type: ignore timeout=self.timeout, ) + + if self.structured_output: + self._prepare_structured_output(self.structured_output) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index 5bab1c603c..2a64e77847 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -14,6 +14,7 @@ import asyncio import inspect +import json import logging import sys from abc import ABC, abstractmethod @@ -27,14 +28,22 @@ RuntimeParametersMixin, ) from distilabel.utils.docstring import parse_google_docstring +from distilabel.utils.itertools import grouper from distilabel.utils.notebook import in_notebook from distilabel.utils.serialization import _Serializable if TYPE_CHECKING: from distilabel.llms.typing import GenerateOutput, HiddenState - from distilabel.mixins.runtime_parameters import RuntimeParametersNames + from distilabel.mixins.runtime_parameters import ( + RuntimeParameterInfo, + RuntimeParametersNames, + ) from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType - from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.tasks.typing import ( + FormattedInput, + InstructorStructuredOutputType, + StandardInput, + ) from distilabel.utils.docstring import Docstring if in_notebook(): @@ -55,8 +64,6 @@ class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC): Attributes: generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`. - structured_output: a dictionary containing the structured output configuration or if more - fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. _logger: the logger to be used for the `LLM`. It will be initialized when the `load` method is called. """ @@ -74,7 +81,6 @@ class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC): description="The kwargs to be propagated to either `generate` or `agenerate`" " methods within each `LLM`.", ) - structured_output: Optional[Any] = None _logger: Union[logging.Logger, None] = PrivateAttr(...) @@ -82,16 +88,29 @@ def load(self) -> None: """Method to be called to initialize the `LLM`, its logger and optionally the structured output generator.""" self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}") + def unload(self) -> None: + """Method to be called to unload the `LLM` and release any resources.""" + pass + @property @abstractmethod def model_name(self) -> str: """Returns the model name used for the LLM.""" pass + def get_generation_kwargs(self) -> Dict[str, Any]: + """Returns the generation kwargs to be used for the generation. This method can + be overridden to provide a more complex logic for the generation kwargs. 
+ + Returns: + The kwargs to be used for the generation. + """ + return self.generation_kwargs # type: ignore + @abstractmethod def generate( self, - inputs: List["ChatType"], + inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -146,7 +165,7 @@ def runtime_parameters_names(self) -> "RuntimeParametersNames": return runtime_parameters - def get_runtime_parameters_info(self) -> List[Dict[str, Any]]: + def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]: """Gets the information of the runtime parameters of the `LLM` such as the name and the description. This function is meant to include the information of the runtime parameters in the serialized data of the `LLM`. @@ -157,21 +176,27 @@ def get_runtime_parameters_info(self) -> List[Dict[str, Any]]: runtime_parameters_info = super().get_runtime_parameters_info() generation_kwargs_info = next( - runtime_parameter_info - for runtime_parameter_info in runtime_parameters_info - if runtime_parameter_info["name"] == "generation_kwargs" + ( + runtime_parameter_info + for runtime_parameter_info in runtime_parameters_info + if runtime_parameter_info["name"] == "generation_kwargs" + ), + None, ) - generate_docstring_args = self.generate_parsed_docstring["args"] + # If `generation_kwargs` attribute is present, we need to include the `generate` + # method arguments as the information for this attribute. + if generation_kwargs_info: + generate_docstring_args = self.generate_parsed_docstring["args"] - generation_kwargs_info["keys"] = [] - for key, value in generation_kwargs_info["optional"].items(): - info = {"name": key, "optional": value} - if description := generate_docstring_args.get(key): - info["description"] = description - generation_kwargs_info["keys"].append(info) + generation_kwargs_info["keys"] = [] + for key, value in generation_kwargs_info["optional"].items(): + info = {"name": key, "optional": value} + if description := generate_docstring_args.get(key): + info["description"] = description + generation_kwargs_info["keys"].append(info) - generation_kwargs_info.pop("optional") + generation_kwargs_info.pop("optional") return runtime_parameters_info @@ -184,7 +209,9 @@ def generate_parsed_docstring(self) -> "Docstring": """ return parse_google_docstring(self.generate) - def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]: + def get_last_hidden_states( + self, inputs: List["StandardInput"] + ) -> List["HiddenState"]: """Method to get the last hidden states of the model for a list of inputs. Args: @@ -227,7 +254,9 @@ class AsyncLLM(LLM): _event_loop: the event loop to be used for the asynchronous generation of responses. 
""" + _num_generations_param_supported = True _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None) + _new_event_loop: bool = PrivateAttr(default=False) @property def generate_parameters(self) -> List[inspect.Parameter]: @@ -254,34 +283,36 @@ def event_loop(self) -> "asyncio.AbstractEventLoop": self._event_loop = asyncio.get_running_loop() if self._event_loop.is_closed(): self._event_loop = asyncio.new_event_loop() # type: ignore + self._new_event_loop = True except RuntimeError: self._event_loop = asyncio.new_event_loop() + self._new_event_loop = True asyncio.set_event_loop(self._event_loop) return self._event_loop @abstractmethod async def agenerate( - self, input: "ChatType", num_generations: int = 1, **kwargs: Any + self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any ) -> List[Union[str, None]]: """Method to generate a `num_generations` responses for a given input asynchronously, and executed concurrently in `generate` method. """ pass - def generate( - self, - inputs: List["ChatType"], - num_generations: int = 1, - **kwargs: Any, + async def _agenerate( + self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any ) -> List["GenerateOutput"]: - """Method to generate a list of responses asynchronously, returning the output - synchronously awaiting for the response of each input sent to `agenerate`. - """ + """Internal function to concurrently generate responses for a list of inputs. - async def agenerate( - inputs: List["ChatType"], **kwargs: Any - ) -> List[List[Union[str, None]]]: - """Internal function to parallelize the asynchronous generation of responses.""" + Args: + inputs: the list of inputs to generate responses for. + num_generations: the number of generations to generate per input. + **kwargs: the additional kwargs to be used for the generation. + + Returns: + A list containing the generations for each input. + """ + if self._num_generations_param_supported: tasks = [ asyncio.create_task( self.agenerate( @@ -292,11 +323,125 @@ async def agenerate( ] return await asyncio.gather(*tasks) - return self.event_loop.run_until_complete(agenerate(inputs, **kwargs)) + tasks = [ + asyncio.create_task(self.agenerate(input=input, **kwargs)) + for input in inputs + for _ in range(num_generations) + ] + outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)] + return list(grouper(outputs, n=num_generations, incomplete="ignore")) + + def generate( + self, + inputs: List["FormattedInput"], + num_generations: int = 1, + **kwargs: Any, + ) -> List["GenerateOutput"]: + """Method to generate a list of responses asynchronously, returning the output + synchronously awaiting for the response of each input sent to `agenerate`. + + Args: + inputs: the list of inputs to generate responses for. + num_generations: the number of generations to generate per input. + **kwargs: the additional kwargs to be used for the generation. + + Returns: + A list containing the generations for each input. 
+ """ + return self.event_loop.run_until_complete( + self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs) + ) def __del__(self) -> None: """Closes the event loop when the object is deleted.""" if sys.meta_path is None: return - if self.event_loop is not None: - self.event_loop.close() + + if self._new_event_loop: + if self._event_loop.is_running(): + self._event_loop.stop() + self._event_loop.close() + + @staticmethod + def _prepare_structured_output( # type: ignore + structured_output: "InstructorStructuredOutputType", + client: Any = None, + framework: Optional[str] = None, + ) -> Dict[str, Union[str, Any]]: + """Wraps the client and updates the schema to work store it internally as a json schema. + + Args: + structured_output: The configuration dict to prepare the structured output. + client: The client to wrap to generate structured output. Implemented to work + with `instructor`. + framework: The name of the framework. + + Returns: + A dictionary containing the wrapped client and the schema to update the structured_output + variable in case it is a pydantic model. + """ + from distilabel.steps.tasks.structured_outputs.instructor import ( + prepare_instructor, + ) + + result = {} + client = prepare_instructor( + client, + mode=structured_output.get("mode"), + framework=framework, # type: ignore + ) + result["client"] = client + + schema = structured_output.get("schema") + if not schema: + raise ValueError( + f"The `structured_output` argument must contain a schema: {structured_output}" + ) + if inspect.isclass(schema) and issubclass(schema, BaseModel): + # We want a json schema for the serialization, but instructor wants a pydantic BaseModel. + structured_output["schema"] = schema.model_json_schema() # type: ignore + result["structured_output"] = structured_output + + return result + + @staticmethod + def _prepare_kwargs( + arguments: Dict[str, Any], structured_output: Dict[str, Any] + ) -> Dict[str, Any]: + """Helper method to update the kwargs with the structured output configuration, + used in case they are defined. + + Args: + arguments: The arguments that would be passed to the LLM as **kwargs. + to update with the structured output configuration. + structured_outputs: The structured output configuration to update the arguments. + + Returns: + kwargs updated with the special arguments used by `instructor`. + """ + # We can deal with json schema or BaseModel, but we need to convert it to a BaseModel + # for the Instructor client. + schema = structured_output.get("schema", {}) + if not issubclass(schema, BaseModel): + from distilabel.steps.tasks.structured_outputs.utils import ( + json_schema_to_model, + ) + + if isinstance(schema, str): + # In case it was saved in the dataset as a string. + schema = json.loads(schema) + + try: + schema = json_schema_to_model(schema) + except Exception as e: + raise ValueError( + f"Failed to convert the schema to a pydantic model, the model is too complex currently: {e}" + ) from e + + arguments.update( + **{ + "response_model": schema, + "max_retries": structured_output.get("max_retries", 1), + }, + ) + return arguments diff --git a/src/distilabel/llms/chat_templates.py b/src/distilabel/llms/chat_templates.py index 47b96b33b0..7edba0132c 100644 --- a/src/distilabel/llms/chat_templates.py +++ b/src/distilabel/llms/chat_templates.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-CHATML_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +CHATML_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message[\"content\"] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index b7ecafcace..a1295a9ba8 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import os from typing import ( TYPE_CHECKING, - Any, List, Optional, Sequence, @@ -26,17 +24,18 @@ ) from pydantic import Field, PrivateAttr, SecretStr, validate_call -from typing_extensions import override from distilabel.llms.base import AsyncLLM +from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType -from distilabel.utils.itertools import grouper +from distilabel.steps.tasks.typing import ( + FormattedInput, + InstructorStructuredOutputType, +) if TYPE_CHECKING: from cohere import AsyncClient, ChatMessage - from distilabel.llms.typing import GenerateOutput _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" @@ -54,6 +53,9 @@ class CohereLLM(AsyncLLM): to `120`. client_name: the name of the client to use for the API requests. Defaults to `"distilabel"`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _ChatMessage: the `ChatMessage` class from the `cohere` package. _aclient: the `AsyncClient` client from the `cohere` package. @@ -66,6 +68,42 @@ class CohereLLM(AsyncLLM): to `120`. - `client_name`: the name of the client to use for the API requests. Defaults to `"distilabel"`. + + Examples: + + Generate text: + + ```python + from distilabel.llms import CohereLLM + + llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus") + + llm.load() + + # Call the model + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import CohereLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = CohereLLM( + model="CohereForAI/c4ai-command-r-plus", + api_key="api.key", + structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model: str @@ -87,6 +125,14 @@ class CohereLLM(AsyncLLM): default="distilabel", description="The name of the client to use for the API requests.", ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _num_generations_param_supported = False _ChatMessage: Type["ChatMessage"] = PrivateAttr(...) _aclient: "AsyncClient" = PrivateAttr(...) 
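`CohereLLM`, like the Anthropic and Groq clients in this diff, sets `_num_generations_param_supported = False`, so the shared `AsyncLLM._agenerate` added in `base.py` falls back to launching one `agenerate` task per generation and regrouping the flat results. The following self-contained sketch reproduces that fan-out-and-regroup behaviour with a hypothetical `fake_agenerate` standing in for a real client:

```python
import asyncio
from itertools import islice

async def fake_agenerate(prompt: str) -> list:
    # Stand-in for `AsyncLLM.agenerate` on a client whose API has no `n`
    # parameter: it always returns a single-element list.
    await asyncio.sleep(0)
    return [f"completion for: {prompt}"]

async def fan_out(prompts: list, num_generations: int) -> list:
    # One coroutine per (input, generation), mirroring the fallback branch of
    # `_agenerate` used when `_num_generations_param_supported` is False.
    results = await asyncio.gather(
        *(fake_agenerate(p) for p in prompts for _ in range(num_generations))
    )
    flat = [outputs[0] for outputs in results]
    # Regroup so each input ends up with `num_generations` outputs, the same
    # shape `grouper(..., incomplete="ignore")` produces in the diff.
    it = iter(flat)
    return [list(islice(it, num_generations)) for _ in prompts]

print(asyncio.run(fan_out(["hi", "bye"], num_generations=2)))
```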
@@ -117,8 +163,18 @@ def load(self) -> None: timeout=self.timeout, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="cohere", + ) + self._aclient = result.get("client") # type: ignore + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + def _format_chat_to_cohere( - self, input: "ChatType" + self, input: "FormattedInput" ) -> Tuple[Union[str, None], List["ChatMessage"], str]: """Formats the chat input to the Cohere Chat API conversational format. @@ -144,7 +200,7 @@ def _format_chat_to_cohere( "An assistant message but be preceded by a user message." ) chat_history.append(self._ChatMessage(role="USER", message=message)) # type: ignore - chat_history.append(self._ChatMessage(role="CHATBOT", message=content)) + chat_history.append(self._ChatMessage(role="CHATBOT", message=content)) # type: ignore message = None if message is None: @@ -155,7 +211,7 @@ def _format_chat_to_cohere( @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: FormattedInput, temperature: Optional[float] = None, max_tokens: Optional[int] = None, k: Optional[int] = None, @@ -165,7 +221,7 @@ async def agenerate( # type: ignore frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, raw_prompting: Optional[bool] = None, - ) -> Union[str, None]: + ) -> GenerateOutput: """Generates a response from the LLM given an input. Args: @@ -190,53 +246,49 @@ async def agenerate( # type: ignore Returns: The generated response from the Cohere API model. """ + structured_output = None + if isinstance(input, tuple): + input, structured_output = input + result = self._prepare_structured_output( + structured_output=structured_output, # type: ignore + client=self._aclient, + framework="cohere", + ) + self._aclient = result.get("client") # type: ignore + + if structured_output is None and self.structured_output is not None: + structured_output = self.structured_output + system, chat_history, message = self._format_chat_to_cohere(input) - response = await self._aclient.chat( # type: ignore - message=message, - model=self.model, - preamble=system, - chat_history=chat_history, - temperature=temperature, - max_tokens=max_tokens, - k=k, - p=p, - seed=seed, - stop_sequences=stop_sequences, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - raw_prompting=raw_prompting, - ) + kwargs = { + "message": message, + "model": self.model, + "preamble": system, + "chat_history": chat_history, + "temperature": temperature, + "max_tokens": max_tokens, + "k": k, + "p": p, + "seed": seed, + "stop_sequences": stop_sequences, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "raw_prompting": raw_prompting, + } + if structured_output: + kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore + + response = await self._aclient.chat(**kwargs) # type: ignore + + if structured_output: + return [response.model_dump_json()] if (text := response.text) == "": - self._logger.warning( + self._logger.warning( # type: ignore f"Received no response using Cohere client (model: '{self.model}')." 
f" Finish reason was: {response.finish_reason}" ) - return None - - return text + return [None] - @override - def generate( - self, - inputs: List["ChatType"], - num_generations: int = 1, - **kwargs: Any, - ) -> List["GenerateOutput"]: - """Method to generate a list of responses asynchronously, returning the output - synchronously awaiting for the response of each input sent to `agenerate`.""" - - async def agenerate( - inputs: List["ChatType"], **kwargs: Any - ) -> "GenerateOutput": - """Internal function to parallelize the asynchronous generation of responses.""" - tasks = [ - asyncio.create_task(self.agenerate(input=input, **kwargs)) - for input in inputs - for _ in range(num_generations) - ] - return await asyncio.gather(*tasks) - - outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs)) - return list(grouper(outputs, n=num_generations, incomplete="ignore")) + return [text] diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py index f7fbda1dc8..3a362951ec 100644 --- a/src/distilabel/llms/groq.py +++ b/src/distilabel/llms/groq.py @@ -12,22 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import os -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Optional from pydantic import Field, PrivateAttr, SecretStr, validate_call -from typing_extensions import override from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.steps.base import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType -from distilabel.utils.itertools import grouper +from distilabel.steps.tasks.typing import ( + FormattedInput, + InstructorStructuredOutputType, +) if TYPE_CHECKING: from groq import AsyncGroq + _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" _GROQ_API_KEY_ENV_VAR_NAME = "GROQ_API_KEY" @@ -45,6 +46,9 @@ class GroqLLM(AsyncLLM): failing. Defaults to `2`. timeout: the maximum time in seconds to wait for a response from the API. Defaults to `120`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _api_key_env_var: the name of the environment variable to use for the API key. _aclient: the `AsyncGroq` client from the `groq` package. @@ -57,6 +61,42 @@ class GroqLLM(AsyncLLM): failing. Defaults to `2`. - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults to `120`. 
+ + Examples: + + Generate text: + + ```python + from distilabel.llms import GroqLLM + + llm = GroqLLM(model="llama3-70b-8192") + + llm.load() + + # Call the model + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import GroqLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = GroqLLM( + model="llama3-70b-8192", + api_key="api.key", + structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model: str @@ -80,6 +120,14 @@ class GroqLLM(AsyncLLM): default=120, description="The maximum time in seconds to wait for a response from the API.", ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _num_generations_param_supported = False _api_key_env_var: str = PrivateAttr(_GROQ_API_KEY_ENV_VAR_NAME) _aclient: Optional["AsyncGroq"] = PrivateAttr(...) @@ -109,6 +157,16 @@ def load(self) -> None: timeout=self.timeout, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="groq", + ) + self._aclient = result.get("client") # type: ignore + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -117,7 +175,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: FormattedInput, seed: Optional[int] = None, max_new_tokens: int = 128, temperature: float = 1.0, @@ -142,17 +200,38 @@ async def agenerate( # type: ignore References: - https://console.groq.com/docs/text-chat """ - completion = await self._aclient.chat.completions.create( # type: ignore - messages=input, # type: ignore - model=self.model, - seed=seed, # type: ignore - temperature=temperature, - max_tokens=max_new_tokens, - top_p=top_p, - stream=False, - stop=stop, - ) + structured_output = None + if isinstance(input, tuple): + input, structured_output = input + result = self._prepare_structured_output( + structured_output=structured_output, + client=self._aclient, + framework="groq", + ) + self._aclient = result.get("client") + + if structured_output is None and self.structured_output is not None: + structured_output = self.structured_output + + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "seed": seed, + "temperature": temperature, + "max_tokens": max_new_tokens, + "top_p": top_p, + "stream": False, + "stop": stop, + } + if structured_output: + kwargs = self._prepare_kwargs(kwargs, structured_output) + generations = [] + completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore + if structured_output: + generations.append(completion.model_dump_json()) + return generations + for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( # type: ignore @@ -161,29 +240,3 @@ async def agenerate( # type: ignore ) generations.append(content) return generations - - # TODO: remove this function once Groq client allows `n` parameter - @override - def generate( - self, - inputs: List["ChatType"], - num_generations: int = 1, - **kwargs: Any, - ) 
-> List["GenerateOutput"]: - """Method to generate a list of responses asynchronously, returning the output - synchronously awaiting for the response of each input sent to `agenerate`. - """ - - async def agenerate( - inputs: List["ChatType"], **kwargs: Any - ) -> "GenerateOutput": - """Internal function to parallelize the asynchronous generation of responses.""" - tasks = [ - asyncio.create_task(self.agenerate(input=input, **kwargs)) - for input in inputs - for _ in range(num_generations) - ] - return [outputs[0] for outputs in await asyncio.gather(*tasks)] - - outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs)) - return list(grouper(outputs, n=num_generations, incomplete="ignore")) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 4570b93d1d..6d4d3d1a5e 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import os import random import warnings -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from pydantic import ( Field, @@ -31,8 +30,15 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType -from distilabel.utils.itertools import grouper +from distilabel.steps.tasks.typing import ( + FormattedInput, + StandardInput, + StructuredOutputType, +) +from distilabel.utils.huggingface import ( + _INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME, + get_hf_token, +) if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient @@ -40,9 +46,6 @@ from transformers import PreTrainedTokenizer -_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME = "HF_TOKEN" - - class InferenceEndpointsLLM(AsyncLLM): """InferenceEndpoints LLM implementation running the async API client. 
@@ -78,11 +81,7 @@ class InferenceEndpointsLLM(AsyncLLM): llm.load() - # Synchrounous request output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) - - # Asynchronous request - output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}]) ``` Dedicated Inference Endpoints: @@ -98,11 +97,7 @@ class InferenceEndpointsLLM(AsyncLLM): llm.load() - # Synchrounous request output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) - - # Asynchronous request - output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}]) ``` Dedicated Inference Endpoints or TGI: @@ -117,11 +112,7 @@ class InferenceEndpointsLLM(AsyncLLM): llm.load() - # Synchrounous request output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) - - # Asynchronous request - output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}]) ``` """ @@ -148,6 +139,13 @@ class InferenceEndpointsLLM(AsyncLLM): model_display_name: Optional[str] = None use_openai_client: bool = False + structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + + _num_generations_param_supported = False + _model_name: Optional[str] = PrivateAttr(default=None) _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) @@ -210,10 +208,7 @@ def load(self) -> None: # noqa: C901 ) from ie if self.api_key is None: - raise ValueError( - f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`" - f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`." - ) + self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key")) if self.model_id is not None: client = InferenceClient() @@ -290,7 +285,7 @@ def model_name(self) -> Union[str, None]: # type: ignore async def _openai_agenerate( self, - input: "ChatType", + input: "StandardInput", max_new_tokens: int = 128, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, @@ -318,11 +313,10 @@ async def _openai_agenerate( ) return [completion.choices[0].message.content] - # TODO: add `num_generations` parameter once either TGI or `AsyncInferenceClient` allows `n` parameter @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: FormattedInput, max_new_tokens: int = 128, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, @@ -336,7 +330,7 @@ async def agenerate( # type: ignore return_full_text: bool = False, seed: Optional[int] = None, watermark: bool = False, - ) -> "GenerateOutput": + ) -> GenerateOutput: """Generates completions for the given input using the OpenAI async client. Args: @@ -379,6 +373,30 @@ async def agenerate( # type: ignore ) stop_sequences = stop_sequences[:4] + structured_output = None + if isinstance(input, tuple): + input, structured_output = input + structured_output = { + "type": structured_output["format"], + "value": structured_output["schema"], + } + + # NOTE: `self.structured_output` applies to all the generations, while `structured_output` i.e. the + # value included within the tuple provided as `input` to this method, is intended to be different per + # each input, so those should not be used together. Meaning that it should be either provided at attribute + # level i.e. self, or via a column within each input i.e. row. 
+ if structured_output is None and self.structured_output is not None: + try: + structured_output = { + "type": self.structured_output["format"], + "value": self.structured_output["schema"], + } + except KeyError as e: + raise ValueError( + "To use the structured output you have to inform the `format` and `schema` in " + "the `structured_output` attribute." + ) from e + if self.use_openai_client: return await self._openai_agenerate( input=input, @@ -400,6 +418,7 @@ async def agenerate( # type: ignore # TODO: should we apply a default chat template here instead? e.g. ChatML prompt = "\n".join([message["content"] for message in input]) + completion = None try: completion = await self._aclient.text_generation( # type: ignore prompt=prompt, # type: ignore @@ -413,40 +432,15 @@ async def agenerate( # type: ignore stop_sequences=stop_sequences, return_full_text=return_full_text, watermark=watermark, + grammar=structured_output, # type: ignore # NOTE: here to ensure that the cache is not used and a different response is # generated every time seed=seed or random.randint(0, 2147483647), ) - return [completion] except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return [None] - - # TODO: remove this function once `AsyncInferenceClient` allows `n` parameter - @override - def generate( - self, - inputs: List["ChatType"], - num_generations: int = 1, - **kwargs: Any, - ) -> List["GenerateOutput"]: - """Method to generate a list of responses asynchronously, returning the output - synchronously awaiting for the response of each input sent to `agenerate`. - """ - async def agenerate( - inputs: List["ChatType"], **kwargs: Any - ) -> "GenerateOutput": - """Internal function to parallelize the asynchronous generation of responses.""" - tasks = [ - asyncio.create_task(self.agenerate(input=input, **kwargs)) - for input in inputs - for _ in range(num_generations) - ] - return [outputs[0] for outputs in await asyncio.gather(*tasks)] - - outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs)) - return list(grouper(outputs, n=num_generations, incomplete="ignore")) + return [completion] diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 19ac43b41b..6e7736d006 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -15,13 +15,14 @@ import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union -from pydantic import PrivateAttr, validate_call +from pydantic import Field, PrivateAttr, validate_call from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE from distilabel.llms.mixins import CudaDevicePlacementMixin from distilabel.llms.typing import GenerateOutput -from distilabel.steps.tasks.typing import ChatType +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput if TYPE_CHECKING: from transformers import Pipeline @@ -29,7 +30,6 @@ from transformers.tokenization_utils import PreTrainedTokenizer from distilabel.llms.typing import HiddenState - from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType class TransformersLLM(LLM, CudaDevicePlacementMixin): @@ -62,9 +62,26 @@ class TransformersLLM(LLM, CudaDevicePlacementMixin): token: the Hugging Face Hub token that will be 
used to authenticate to the Hugging Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package local configuration will be used. Defaults to `None`. + structured_output: a dictionary containing the structured output configuration or if more + fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. Icon: `:hugging:` + + Examples: + + Generate text: + + ```python + from distilabel.llms import TransformersLLM + + llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct") + + llm.load() + + # Call the model + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` """ model: str @@ -78,6 +95,10 @@ class TransformersLLM(LLM, CudaDevicePlacementMixin): device: Optional[Union[str, int]] = None device_map: Optional[Union[str, Dict[str, Any]]] = None token: Optional[str] = None + structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) _pipeline: Optional["Pipeline"] = PrivateAttr(...) _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None) @@ -125,12 +146,17 @@ def load(self) -> None: super().load() + def unload(self) -> None: + """Unloads the `vLLM` model.""" + CudaDevicePlacementMixin.unload(self) + super().unload() + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def prepare_input(self, input: "ChatType") -> str: + def prepare_input(self, input: "StandardInput") -> str: """Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt. """ @@ -143,7 +169,7 @@ def prepare_input(self, input: "ChatType") -> str: @validate_call def generate( # type: ignore self, - inputs: List[ChatType], + inputs: List[StandardInput], num_generations: int = 1, max_new_tokens: int = 128, temperature: float = 0.1, @@ -189,7 +215,9 @@ def generate( # type: ignore for output in outputs ] - def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]: + def get_last_hidden_states( + self, inputs: List["StandardInput"] + ) -> List["HiddenState"]: """Gets the last `hidden_states` of the model for the given inputs. It doesn't execute the task head. @@ -222,7 +250,7 @@ def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState" ] def _prepare_structured_output( - self, structured_output: Optional["StructuredOutputType"] = None + self, structured_output: Optional[OutlinesStructuredOutputType] = None ) -> Union[Callable, None]: """Creates the appropriate function to filter tokens to generate structured outputs. diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py index 1b0add14ac..71a73365bb 100644 --- a/src/distilabel/llms/litellm.py +++ b/src/distilabel/llms/litellm.py @@ -20,7 +20,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType if TYPE_CHECKING: from litellm import Choices @@ -33,15 +33,60 @@ class LiteLLM(AsyncLLM): model: the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large", etc. verbose: whether to log the LiteLLM client's logs. Defaults to `False`. 
+ structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. Runtime parameters: - `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`. + + Examples: + + Generate text: + + ```python + from distilabel.llms import LiteLLM + + llm = LiteLLM(model="gpt-3.5-turbo") + + llm.load() + + # Call the model + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import LiteLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = LiteLLM( + model="gpt-3.5-turbo", + api_key="api.key", + structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model: str verbose: RuntimeParameter[bool] = Field( default=False, description="Whether to log the LiteLLM client's logs." ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) _aclient: Optional[Callable] = PrivateAttr(...) @@ -69,15 +114,25 @@ def load(self) -> None: continue logging.getLogger(key).setLevel(logging.CRITICAL) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="litellm", + ) + self._aclient = result.get("client") + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model @validate_call - async def agenerate( # type: ignore + async def agenerate( # type: ignore # noqa: C901 self, - input: ChatType, + input: FormattedInput, num_generations: int = 1, functions: Optional[List] = None, function_call: Optional[str] = None, @@ -141,34 +196,53 @@ async def agenerate( # type: ignore """ import litellm + structured_output = None + if isinstance(input, tuple): + input, structured_output = input + result = self._prepare_structured_output( + structured_output=structured_output, + client=self._aclient, + framework="litellm", + ) + self._aclient = result.get("client") + + if structured_output is None and self.structured_output is not None: + structured_output = self.structured_output + + kwargs = { + "model": self.model, + "messages": input, + "n": num_generations, + "functions": functions, + "function_call": function_call, + "temperature": temperature, + "top_p": top_p, + "stream": False, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "user": user, + "metadata": metadata, + "api_base": api_base, + "api_version": api_version, + "api_key": api_key, + "model_list": model_list, + "mock_response": mock_response, + "force_timeout": force_timeout, + "custom_llm_provider": custom_llm_provider, + } + if structured_output: + kwargs = self._prepare_kwargs(kwargs, structured_output) + async def _call_aclient_until_n_choices() -> List["Choices"]: choices = [] while len(choices) < num_generations: - completion = await self._aclient( # type: ignore - model=self.model, - messages=input, - 
n=num_generations, - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - stream=False, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - metadata=metadata, - api_base=api_base, - api_version=api_version, - api_key=api_key, - model_list=model_list, - mock_response=mock_response, - force_timeout=force_timeout, - custom_llm_provider=custom_llm_provider, - ) - choices.extend(completion.choices) + completion = await self._aclient(**kwargs) # type: ignore + if not self.structured_output: + completion = completion.choices + choices.extend(completion) return choices # litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used @@ -183,9 +257,14 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: raise e generations = [] + + if self.structured_output: + generations.append([choice.model_dump_json() for choice in choices]) + return generations + for choice in choices: if (content := choice.message.content) is None: - self._logger.warning( + self._logger.warning( # type: ignore f"Received no response using LiteLLM client (model: '{self.model}')." f" Finish reason was: {choice.finish_reason}" ) diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py index f8f50ff154..f66eb214b0 100644 --- a/src/distilabel/llms/llamacpp.py +++ b/src/distilabel/llms/llamacpp.py @@ -19,13 +19,11 @@ from distilabel.llms.base import LLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType if TYPE_CHECKING: from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList - from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType - class LlamaCppLLM(LLM): """llama.cpp LLM implementation running the Python bindings for the C++ code. 
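Several of the `generate`/`agenerate` hunks in this diff (LiteLLM above, and LlamaCppLLM, MistralLLM, OpenAILLM and vLLM below) now unpack inputs that arrive as a `(messages, structured_output)` tuple, the new `FormattedInput` shape. The following is only an illustrative sketch of that input form, assuming the `{"schema": ...}` dictionary used by the docstring examples in this diff; in a pipeline these tuples are normally produced by a task rather than written by hand:

```python
from pydantic import BaseModel

from distilabel.llms import LiteLLM


class User(BaseModel):
    name: str
    last_name: str
    id: int


llm = LiteLLM(model="gpt-3.5-turbo")
llm.load()

chat = [{"role": "user", "content": "Create a user profile for the following marathon"}]

# A plain chat is still accepted; a (chat, structured_output) tuple additionally carries a
# per-input structured output configuration, which `agenerate` unpacks as in the hunk above.
output = llm.generate(inputs=[(chat, {"schema": User})])
```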
@@ -59,6 +57,58 @@ class LlamaCppLLM(LLM):
     References:
         - [`llama.cpp`](https://github.com/ggerganov/llama.cpp)
         - [`llama-cpp-python`](https://github.com/abetlen/llama-cpp-python)
+
+    Examples:
+
+        Generate text:
+
+        ```python
+        from pathlib import Path
+        from distilabel.llms import LlamaCppLLM
+
+        # You can follow along with this example by downloading the model with the following
+        # command in the terminal; it will download the model to the `Downloads` folder:
+        # curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf
+
+        model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"
+
+        llm = LlamaCppLLM(
+            model_path=str(Path.home() / model_path),
+            n_gpu_layers=-1,  # To use the GPU if available
+            n_ctx=1024,  # Set the context size
+        )
+
+        llm.load()
+
+        # Call the model
+        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+        ```
+
+        Generate structured data:
+
+        ```python
+        from pathlib import Path
+        from pydantic import BaseModel
+        from distilabel.llms import LlamaCppLLM
+
+        model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"
+
+        class User(BaseModel):
+            name: str
+            last_name: str
+            id: int
+
+        llm = LlamaCppLLM(
+            model_path=str(Path.home() / model_path),  # type: ignore
+            n_gpu_layers=-1,
+            n_ctx=1024,
+            structured_output={"format": "json", "schema": User},
+        )
+
+        llm.load()
+
+        # Call the model
+        output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+        ```
     """

     model_path: RuntimeParameter[FilePath] = Field(
@@ -76,7 +126,6 @@ class LlamaCppLLM(LLM):
     n_ctx: int = 512
     n_batch: int = 512
     seed: int = 4294967295
-
     verbose: RuntimeParameter[bool] = Field(
         default=False,
         description="Whether to print verbose output from llama.cpp library.",
@@ -87,6 +136,10 @@ class LlamaCppLLM(LLM):
         " `Llama` class of `llama_cpp` library. See all the supported arguments at: "
         "https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__",
     )
+    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
+        default=None,
+        description="The structured output format to use across all the generations.",
+    )

     _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
     _model: Optional["Llama"] = PrivateAttr(...)
@@ -128,7 +181,7 @@ def model_name(self) -> str:
     @validate_call
     def generate(  # type: ignore
         self,
-        inputs: List[ChatType],
+        inputs: List[FormattedInput],
         num_generations: int = 1,
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
@@ -158,18 +211,23 @@ def generate(  # type: ignore
         Returns:
             A list of lists of strings containing the generated responses for each input.
         """
-
+        structured_output = None
         batch_outputs = []
         for input in inputs:
+            if isinstance(input, tuple):
+                input, structured_output = input
+            elif self.structured_output:
+                structured_output = self.structured_output
+
             outputs = []
             for _ in range(num_generations):
                 # NOTE(plaguss): There seems to be a bug in how the logits processor
                 # is used. Basically it consumes the FSM internally, and it isn't reinitialized
                 # after each generation, so subsequent calls yield nothing. This is a workaround
                 # until is fixed in the `llama_cpp` or `outlines` libraries.
- if self.structured_output: + if structured_output: self._logits_processor = self._prepare_structured_output( - self.structured_output + structured_output ) chat_completions: "CreateChatCompletionResponse" = ( self._model.create_chat_completion( # type: ignore @@ -188,7 +246,7 @@ def generate( # type: ignore return batch_outputs def _prepare_structured_output( - self, structured_output: Optional["StructuredOutputType"] = None + self, structured_output: Optional[OutlinesStructuredOutputType] = None ) -> Union["LogitsProcessorList", None]: """Creates the appropriate function to filter tokens to generate structured outputs. @@ -203,6 +261,6 @@ def _prepare_structured_output( ) result = prepare_guided_output(structured_output, "llamacpp", self._model) - if schema := result.get("schema"): + if (schema := result.get("schema")) and self.structured_output: self.structured_output["schema"] = schema return result["processor"] diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py index d05d9d3f65..ed1c3af7d5 100644 --- a/src/distilabel/llms/mistral.py +++ b/src/distilabel/llms/mistral.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import os -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Optional from pydantic import Field, PrivateAttr, SecretStr, validate_call -from typing_extensions import override from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType -from distilabel.utils.itertools import grouper +from distilabel.steps.tasks.typing import ( + FormattedInput, + InstructorStructuredOutputType, +) if TYPE_CHECKING: from mistralai.async_client import MistralAsyncClient @@ -45,6 +45,9 @@ class MistralLLM(AsyncLLM): timeout: the maximum time in seconds to wait for a response. Defaults to `120`. max_concurrent_requests: the maximum number of concurrent requests to send. Defaults to `64`. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. _api_key_env_var: the name of the environment variable to use for the API key. It is meant to be used internally. _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally. @@ -57,6 +60,42 @@ class MistralLLM(AsyncLLM): - `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`. - `max_concurrent_requests`: the maximum number of concurrent requests to send. Defaults to `64`. 
+ + Examples: + + Generate text: + + ```python + from distilabel.llms import MistralLLM + + llm = MistralLLM(model="open-mixtral-8x22b") + + llm.load() + + # Call the model + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import MistralLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = MistralLLM( + model="open-mixtral-8x22b", + api_key="api.key", + structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model: str @@ -77,6 +116,14 @@ class MistralLLM(AsyncLLM): max_concurrent_requests: RuntimeParameter[int] = Field( default=64, description="The maximum number of concurrent requests to send." ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _num_generations_param_supported = False _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME) _aclient: Optional["MistralAsyncClient"] = PrivateAttr(...) @@ -102,11 +149,21 @@ def load(self) -> None: self._aclient = MistralAsyncClient( api_key=self.api_key.get_secret_value(), endpoint=self.endpoint, - max_retries=self.max_retries, - timeout=self.timeout, - max_concurrent_requests=self.max_concurrent_requests, + max_retries=self.max_retries, # type: ignore + timeout=self.timeout, # type: ignore + max_concurrent_requests=self.max_concurrent_requests, # type: ignore ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="mistral", + ) + self._aclient = result.get("client") # type: ignore + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -116,7 +173,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: FormattedInput, max_new_tokens: Optional[int] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, @@ -134,45 +191,44 @@ async def agenerate( # type: ignore Returns: A list of lists of strings containing the generated responses for each input. """ - completion = await self._aclient.chat( # type: ignore - messages=input, - model=self.model, - temperature=temperature, - max_tokens=max_new_tokens, - top_p=top_p, - ) + structured_output = None + if isinstance(input, tuple): + input, structured_output = input + result = self._prepare_structured_output( + structured_output=structured_output, + client=self._aclient, + framework="mistral", + ) + self._aclient = result.get("client") + + if structured_output is None and self.structured_output is not None: + structured_output = self.structured_output + + kwargs = { + "messages": input, # type: ignore + "model": self.model, + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + } generations = [] + if structured_output: + kwargs = self._prepare_kwargs(kwargs, structured_output) + # TODO: This should work just with the _aclient.chat method, but it's not working. + # We need to check instructor and see if we can create a PR. 
+ completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore + else: + completion = await self._aclient.chat(**kwargs) # type: ignore + + if structured_output: + generations.append(completion.model_dump_json()) + return generations + for choice in completion.choices: if (content := choice.message.content) is None: - self._logger.warning( + self._logger.warning( # type: ignore f"Received no response using MistralAI client (model: '{self.model}')." f" Finish reason was: {choice.finish_reason}" ) generations.append(content) return generations - - # TODO: remove this function once Mistral client allows `n` parameter - @override - def generate( - self, - inputs: List["ChatType"], - num_generations: int = 1, - **kwargs: Any, - ) -> List["GenerateOutput"]: - """Method to generate a list of responses asynchronously, returning the output - synchronously awaiting for the response of each input sent to `agenerate`. - """ - - async def agenerate( - inputs: List["ChatType"], **kwargs: Any - ) -> "GenerateOutput": - """Internal function to parallelize the asynchronous generation of responses.""" - tasks = [ - asyncio.create_task(self.agenerate(input=input, **kwargs)) - for input in inputs - for _ in range(num_generations) - ] - return [outputs[0] for outputs in await asyncio.gather(*tasks)] - - outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs)) - return list(grouper(outputs, n=num_generations, incomplete="ignore")) diff --git a/src/distilabel/llms/mixins.py b/src/distilabel/llms/mixins.py index 9146a0ef4f..1d2e8b35a0 100644 --- a/src/distilabel/llms/mixins.py +++ b/src/distilabel/llms/mixins.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Generator, List, Literal, Union +import portalocker from pydantic import BaseModel, Field, PrivateAttr -if TYPE_CHECKING: - from multiprocessing.managers import DictProxy - from multiprocessing.synchronize import Lock +_CUDA_DEVICE_PLACEMENT_MIXIN_FILE = ( + Path(tempfile.gettempdir()) / "distilabel_cuda_device_placement_mixin.json" +) class CudaDevicePlacementMixin(BaseModel): @@ -44,11 +49,7 @@ class CudaDevicePlacementMixin(BaseModel): cuda_devices: Union[List[int], Literal["auto"]] = Field(default="auto") _llm_identifier: Union[str, None] = PrivateAttr(default=None) - _device_llm_placement_map: Union["DictProxy[str, Any]", None] = PrivateAttr( - default=None - ) - _device_llm_placement_lock: Union["Lock", None] = PrivateAttr(default=None) - _available_cuda_devices: Union[List[int], None] = PrivateAttr(default=None) + _available_cuda_devices: List[int] = PrivateAttr(default_factory=list) _can_check_cuda_devices: bool = PrivateAttr(default=False) def load(self) -> None: @@ -77,29 +78,40 @@ def load(self) -> None: self._assign_cuda_devices() - def set_device_placement_info( - self, - llm_identifier: str, - device_llm_placement_map: "DictProxy[str, Any]", - device_llm_placement_lock: "Lock", - ) -> None: - """Sets the value of `_device_llm_placement_map` to be used to assign CUDA devices - to the LLM. 
+ def unload(self) -> None: + """Unloads the LLM and removes the CUDA devices assigned to it from the device + placement information provided in `_device_llm_placement_map`.""" + with self._device_llm_placement_map() as device_map: + if self._llm_identifier in device_map: + self._logger.debug( + f"Removing '{self._llm_identifier}' from the CUDA device map file" + f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'." + ) + del device_map[self._llm_identifier] - Args: - llm_identifier: the identifier of the LLM to be used as key in the device - placement information. - device_llm_placement_map: a dictionary with the device placement information for - each LLM. It should have two keys. The first key is "lock" and its value is - a lock object to be used to synchronize the access to the device placement - information. The second key is "value" and its value is a dictionary with the - device placement information for each LLM. - device_llm_placement_lock: a lock object to be used to synchronize the access to - `_device_llm_placement_map`. + @contextmanager + def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]: + """Reads the content of the device placement file of the node with a lock, yields + the content, and writes the content back to the file after the context manager is + closed. If the file doesn't exist, an empty dictionary will be yielded. + + Yields: + The content of the device placement file. """ - self._llm_identifier = llm_identifier - self._device_llm_placement_map = device_llm_placement_map - self._device_llm_placement_lock = device_llm_placement_lock + _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch() + with portalocker.Lock( + _CUDA_DEVICE_PLACEMENT_MIXIN_FILE, + "r+", + flags=portalocker.LockFlags.EXCLUSIVE, + ) as f: + try: + content = json.load(f) + except json.JSONDecodeError: + content = {} + yield content + f.seek(0) + f.truncate() + f.write(json.dumps(content)) def _assign_cuda_devices(self) -> None: """Assigns CUDA devices to the LLM based on the device placement information provided @@ -109,16 +121,14 @@ def _assign_cuda_devices(self) -> None: checked if the devices are available to be used by the LLM. If not, a warning will be logged.""" - if self._device_llm_placement_map is not None: - with self._device_llm_placement_lock: # type: ignore - if self.cuda_devices == "auto": - self.cuda_devices = [ - self._get_cuda_device(self._device_llm_placement_map) - ] - else: - self._check_cuda_devices(self._device_llm_placement_map) + # Take the lock and read the device placement information for each LLM. + with self._device_llm_placement_map() as device_map: + if self.cuda_devices == "auto": + self.cuda_devices = [self._get_cuda_device(device_map)] + else: + self._check_cuda_devices(device_map) - self._device_llm_placement_map[self._llm_identifier] = self.cuda_devices # type: ignore + device_map[self._llm_identifier] = self.cuda_devices # type: ignore # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices` # attribute. In this case, the `cuda_devices` attribute will be set to an empty list. diff --git a/src/distilabel/llms/moa.py b/src/distilabel/llms/moa.py new file mode 100644 index 0000000000..d139da87e3 --- /dev/null +++ b/src/distilabel/llms/moa.py @@ -0,0 +1,275 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import itertools
+from typing import TYPE_CHECKING, Any, Dict, List, Union, cast
+
+from pydantic import Field
+
+from distilabel.llms.base import LLM, AsyncLLM
+from distilabel.steps.tasks.typing import StandardInput
+
+if TYPE_CHECKING:
+    from distilabel.llms.typing import GenerateOutput
+    from distilabel.mixins.runtime_parameters import RuntimeParametersNames
+    from distilabel.steps.tasks.typing import FormattedInput
+
+# Mixture-of-Agents system prompt from the paper with the addition instructing the LLM
+# to not mention that it used responses from previous models to avoid having texts like
+# "Based on the previous responses..." in the completion.
+MOA_SYSTEM_PROMPT = (
+    "You have been provided with a set of responses from various open-source models to the"
+    " latest user query. Your task is to synthesize these responses into a single, high-quality"
+    " response. It is crucial to critically evaluate the information provided in these responses,"
+    " recognizing that some of it may be biased or incorrect. Your response should not simply"
+    " replicate the given answers but should offer a refined, accurate, and comprehensive"
+    " reply to the instruction. Ensure your response is well-structured, coherent, and adheres"
+    " to the highest standards of accuracy and reliability. Do not mention that you have used"
+    " the responses from previous models."
+    "\nResponses from models:"
+)
+
+
+class MixtureOfAgentsLLM(AsyncLLM):
+    """`Mixture-of-Agents` implementation.
+
+    An `LLM` class that leverages the collective strengths of several `LLM`s to generate a response,
+    as described in the "Mixture-of-Agents Enhances Large Language Model Capabilities"
+    paper. There is a list of `LLM`s proposing/generating outputs that `LLM`s from the next
+    round/layer can use as auxiliary information. Finally, there is an `LLM` that aggregates
+    the outputs to generate the final response.
+
+    Attributes:
+        aggregator_llm: The `LLM` that aggregates the outputs of the proposer `LLM`s.
+        proposers_llms: The list of `LLM`s that propose outputs to be aggregated.
+        rounds: The number of layers or rounds that the `proposers_llms` will generate
+            outputs. Defaults to `1`.
+ + References: + - [Mixture-of-Agents Enhances Large Language Model Capabilities](https://arxiv.org/abs/2406.04692) + + Examples: + + Generate text: + + ```python + from distilabel.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM + + llm = MixtureOfAgentsLLM( + aggregator_llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + proposers_llms=[ + InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + InferenceEndpointsLLM( + model_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", + tokenizer_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", + ), + InferenceEndpointsLLM( + model_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + tokenizer_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + ), + ], + rounds=2, + ) + + llm.load() + + output = llm.generate( + inputs=[ + [ + { + "role": "user", + "content": "My favorite witty review of The Rings of Power series is this: Input:", + } + ] + ] + ) + ``` + """ + + aggregator_llm: LLM + proposers_llms: List[AsyncLLM] = Field(default_factory=list) + rounds: int = 1 + + @property + def runtime_parameters_names(self) -> "RuntimeParametersNames": + """Returns the runtime parameters of the `LLM`, which are a combination of the + `RuntimeParameter`s of the `LLM`, the `aggregator_llm` and the `proposers_llms`. + + Returns: + The runtime parameters of the `LLM`. + """ + runtime_parameters_names = super().runtime_parameters_names + del runtime_parameters_names["generation_kwargs"] + return runtime_parameters_names + + def load(self) -> None: + """Loads all the `LLM`s in the `MixtureOfAgents`.""" + super().load() + + for llm in self.proposers_llms: + self._logger.debug(f"Loading proposer LLM in MoA: {llm}") # type: ignore + llm.load() + + self._logger.debug(f"Loading aggregator LLM in MoA: {self.aggregator_llm}") # type: ignore + self.aggregator_llm.load() + + @property + def model_name(self) -> str: + """Returns the aggregated model name.""" + return f"moa-{self.aggregator_llm.model_name}-{'-'.join([llm.model_name for llm in self.proposers_llms])}" + + def get_generation_kwargs(self) -> Dict[str, Any]: + """Returns the generation kwargs of the `MixtureOfAgents` as a dictionary. + + Returns: + The generation kwargs of the `MixtureOfAgents`. + """ + return { + "aggregator_llm": self.aggregator_llm.get_generation_kwargs(), + "proposers_llms": [ + llm.get_generation_kwargs() for llm in self.proposers_llms + ], + } + + # `abstractmethod`, had to be implemented but not used + async def agenerate( + self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any + ) -> List[Union[str, None]]: + raise NotImplementedError( + "`agenerate` method is not implemented for `MixtureOfAgents`" + ) + + def _build_moa_system_prompt(self, prev_outputs: List[str]) -> str: + """Builds the Mixture-of-Agents system prompt. + + Args: + prev_outputs: The list of previous outputs to use as references. + + Returns: + The Mixture-of-Agents system prompt. + """ + moa_system_prompt = MOA_SYSTEM_PROMPT + for i, prev_output in enumerate(prev_outputs): + if prev_output is not None: + moa_system_prompt += f"\n{i + 1}. {prev_output}" + return moa_system_prompt + + def _inject_moa_system_prompt( + self, input: "StandardInput", prev_outputs: List[str] + ) -> "StandardInput": + """Injects the Mixture-of-Agents system prompt into the input. + + Args: + input: The input to inject the system prompt into. 
+ prev_outputs: The list of previous outputs to use as references. + + Returns: + The input with the Mixture-of-Agents system prompt injected. + """ + if len(prev_outputs) == 0: + return input + + moa_system_prompt = self._build_moa_system_prompt(prev_outputs) + + system = next((item for item in input if item["role"] == "system"), None) + if system: + original_system_prompt = system["content"] + system["content"] = f"{moa_system_prompt}\n\n{original_system_prompt}" + else: + input.insert(0, {"role": "system", "content": moa_system_prompt}) + + return input + + async def _agenerate( + self, + inputs: List["FormattedInput"], + num_generations: int = 1, + **kwargs: Any, + ) -> List["GenerateOutput"]: + """Internal function to concurrently generate responses for a list of inputs. + + Args: + inputs: the list of inputs to generate responses for. + num_generations: the number of generations to generate per input. + **kwargs: the additional kwargs to be used for the generation. + + Returns: + A list containing the generations for each input. + """ + aggregator_llm_kwargs: Dict[str, Any] = kwargs.get("aggregator_llm", {}) + proposers_llms_kwargs: List[Dict[str, Any]] = kwargs.get( + "proposers_llms", [{}] * len(self.proposers_llms) + ) + + prev_outputs = [] + for round in range(self.rounds): + self._logger.debug(f"Generating round {round + 1}/{self.rounds} in MoA") # type: ignore + + # Generate `num_generations` with each proposer LLM for each input + tasks = [ + asyncio.create_task( + llm._agenerate( + inputs=[ + self._inject_moa_system_prompt( + cast("StandardInput", input), prev_input_outputs + ) + for input, prev_input_outputs in itertools.zip_longest( + inputs, prev_outputs, fillvalue=[] + ) + ], + num_generations=1, + **generation_kwargs, + ) + ) + for llm, generation_kwargs in zip( + self.proposers_llms, proposers_llms_kwargs + ) + ] + + # Group generations per input + outputs: List[List["GenerateOutput"]] = await asyncio.gather(*tasks) + prev_outputs = [ + list(itertools.chain(*input_outputs)) for input_outputs in zip(*outputs) + ] + + self._logger.debug("Aggregating outputs in MoA") # type: ignore + if isinstance(self.aggregator_llm, AsyncLLM): + return await self.aggregator_llm._agenerate( + inputs=[ + self._inject_moa_system_prompt( + cast("StandardInput", input), prev_input_outputs + ) + for input, prev_input_outputs in zip(inputs, prev_outputs) + ], + num_generations=num_generations, + **aggregator_llm_kwargs, + ) + + return self.aggregator_llm.generate( + inputs=[ + self._inject_moa_system_prompt( + cast("StandardInput", input), prev_input_outputs + ) + for input, prev_input_outputs in zip(inputs, prev_outputs) + ], + num_generations=num_generations, + **aggregator_llm_kwargs, + ) diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/llms/ollama.py index fb06f1eed3..bd664b30db 100644 --- a/src/distilabel/llms/ollama.py +++ b/src/distilabel/llms/ollama.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
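To make the prompt-injection logic in `moa.py` above more concrete, here is a small, self-contained sketch that mirrors what `_build_moa_system_prompt` and `_inject_moa_system_prompt` do with the outputs of a previous round (the previous responses and the user message are made up for illustration, and the system prompt is shortened):

```python
# Mirrors `_build_moa_system_prompt` / `_inject_moa_system_prompt` from `moa.py` above:
# outputs of the previous round are numbered, appended to the MoA system prompt, and the
# result is prepended to the conversation as a system message.
MOA_SYSTEM_PROMPT = "You have been provided with a set of responses..."  # shortened here

def build_moa_system_prompt(prev_outputs):
    prompt = MOA_SYSTEM_PROMPT
    for i, prev_output in enumerate(prev_outputs):
        if prev_output is not None:
            prompt += f"\n{i + 1}. {prev_output}"
    return prompt

prev_outputs = ["Answer from proposer 1.", "Answer from proposer 2."]
chat = [{"role": "user", "content": "What is Mixture-of-Agents?"}]

# `chat` has no system message, so one is inserted at the front.
chat.insert(0, {"role": "system", "content": build_moa_system_prompt(prev_outputs)})
print(chat[0]["content"])
```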
-from typing import TYPE_CHECKING, List, Literal, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union from pydantic import Field, PrivateAttr, validate_call from typing_extensions import TypedDict from distilabel.llms.base import AsyncLLM +from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput if TYPE_CHECKING: from ollama import AsyncClient @@ -88,6 +89,14 @@ class OllamaLLM(AsyncLLM): default=120, description="The timeout for the Ollama API." ) follow_redirects: bool = True + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _num_generations_param_supported = False _aclient: Optional["AsyncClient"] = PrivateAttr(...) @@ -117,19 +126,17 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, - num_generations: int = 1, + input: StandardInput, format: Literal["", "json"] = "", # TODO: include relevant options from `Options` in `agenerate` method. options: Union[Options, None] = None, keep_alive: Union[bool, None] = None, - ) -> List[str]: + ) -> GenerateOutput: """ Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python). Args: input: the input to use for the generation. - num_generations: the number of generations to produce. Defaults to `1`. format: the format to use for the generation. Defaults to `""`. options: the options to use for the generation. Defaults to `None`. keep_alive: whether to keep the connection alive. Defaults to `None`. @@ -137,10 +144,9 @@ async def agenerate( # type: ignore Returns: A list of strings as completion for the given input. """ - generations = [] - # TODO: remove this for-loop and override the `generate` method - for _ in range(num_generations): - completion = await self._aclient.chat( # type: ignore + text = None + try: + completion: Dict[str, Any] = await self._aclient.chat( # type: ignore model=self.model, messages=input, # type: ignore stream=False, @@ -148,7 +154,11 @@ async def agenerate( # type: ignore options=options, keep_alive=keep_alive, ) - # TODO: improve error handling - generations.append(completion["message"]["content"]) + text = completion["message"]["content"] + except Exception as e: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Ollama client (model: '{self.model_name}')." + f" Finish reason was: {e}" + ) - return generations + return [text] diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index 7314f7c74d..42555b3649 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -20,7 +20,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType if TYPE_CHECKING: from openai import AsyncOpenAI @@ -45,8 +45,9 @@ class OpenAILLM(AsyncLLM): failing. Defaults to `6`. timeout: the maximum time in seconds to wait for a response from the API. Defaults to `120`. 
- structured_output: a dictionary containing the structured output configuration or if more - fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + structured_output: a dictionary containing the structured output configuration configuration + using `instructor`. You can take a look at the dictionary structure in + `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`. Runtime parameters: - `base_url`: the base URL to use for the OpenAI API requests. Defaults to `None`. @@ -59,6 +60,57 @@ class OpenAILLM(AsyncLLM): Icon: `:simple-openai:` + + Examples: + + Generate text: + + ```python + from distilabel.llms import OpenAILLM + + llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key") + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` + + Generate text from a custom endpoint following the OpenAI API: + + ```python + from distilabel.llms import OpenAILLM + + llm = OpenAILLM( + model="prometheus-eval/prometheus-7b-v2.0", + base_url=r"http://localhost:8080/v1" + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import OpenAILLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = OpenAILLM( + model="gpt-4-turbo", + api_key="api.key", + structured_output={"schema": User} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model: str @@ -81,6 +133,12 @@ class OpenAILLM(AsyncLLM): default=120, description="The maximum time in seconds to wait for a response from the API.", ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...) @@ -110,6 +168,16 @@ def load(self) -> None: timeout=self.timeout, ) + if self.structured_output: + result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="openai", + ) + self._aclient = result.get("client") # type: ignore + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -118,7 +186,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: FormattedInput, num_generations: int = 1, max_new_tokens: int = 128, frequency_penalty: float = 0.0, @@ -162,20 +230,42 @@ async def agenerate( # type: ignore f"Invalid response format '{response_format}'. Must be either 'text' or 'json'." 
            )

-        completion = await self._aclient.chat.completions.create(  # type: ignore
-            messages=input,  # type: ignore
-            model=self.model,
-            max_tokens=max_new_tokens,
-            n=num_generations,
-            frequency_penalty=frequency_penalty,
-            presence_penalty=presence_penalty,
-            temperature=temperature,
-            top_p=top_p,
-            stop=stop,
-            timeout=50,
-            response_format={"type": response_format},
-        )
+        structured_output = None
+        if isinstance(input, tuple):
+            input, structured_output = input
+            result = self._prepare_structured_output(
+                structured_output=structured_output,
+                client=self._aclient,
+                framework="openai",
+            )
+            self._aclient = result.get("client")
+
+        if structured_output is None and self.structured_output is not None:
+            structured_output = self.structured_output
+
+        kwargs = {
+            "messages": input,  # type: ignore
+            "model": self.model,
+            "max_tokens": max_new_tokens,
+            "n": num_generations,
+            "frequency_penalty": frequency_penalty,
+            "presence_penalty": presence_penalty,
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop": stop,
+            "timeout": 50,
+            "response_format": {"type": response_format},
+        }
+        if structured_output:
+            kwargs = self._prepare_kwargs(kwargs, structured_output)
+        generations = []
+        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
+
+        if structured_output:
+            generations.append(completion.model_dump_json())
+            return generations
+
         for choice in completion.choices:
             if (content := choice.message.content) is None:
                 self._logger.warning(  # type: ignore
diff --git a/src/distilabel/llms/together.py b/src/distilabel/llms/together.py
index 28b9d23571..aa63ae1ad5 100644
--- a/src/distilabel/llms/together.py
+++ b/src/distilabel/llms/together.py
@@ -37,6 +37,20 @@ class TogetherLLM(OpenAILLM):
         used, or `None` if not set.
         _api_key_env_var: the name of the environment variable to use for the API key. It is
             meant to be used internally.
+
+    Examples:
+
+        Generate text:
+
+        ```python
+        from distilabel.llms import TogetherLLM
+
+        llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")
+
+        llm.load()
+
+        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+        ```
     """

     base_url: Optional[RuntimeParameter[str]] = Field(
diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/llms/vertexai.py
index 28ceee3a0c..f89a7b0912 100644
--- a/src/distilabel/llms/vertexai.py
+++ b/src/distilabel/llms/vertexai.py
@@ -18,22 +18,10 @@

 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput

 if TYPE_CHECKING:
-    from vertexai.generative_models import Content, GenerativeModel
-
-
-def _is_gemini_model(model: str) -> bool:
-    """Returns `True` if the model is a model from the Vertex AI Gemini API.
-
-    Args:
-        model (str): the model name to be checked.
-
-    Returns:
-        bool: `True` if the model is a model from the Vertex AI Gemini API.
-    """
-    return "gemini" in model
+    from vertexai.generative_models import Content, GenerationResponse, GenerativeModel


 class VertexAILLM(AsyncLLM):
@@ -59,6 +47,8 @@ class VertexAILLM(AsyncLLM):

     model: str

+    _num_generations_param_supported = False
+
     _aclient: Optional["GenerativeModel"] = PrivateAttr(...)
def load(self) -> None: @@ -87,7 +77,7 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def _chattype_to_content(self, input: "ChatType") -> List["Content"]: + def _chattype_to_content(self, input: "StandardInput") -> List["Content"]: """Converts a chat type to a list of content items expected by the API. Args: @@ -114,8 +104,7 @@ def _chattype_to_content(self, input: "ChatType") -> List["Content"]: @validate_call async def agenerate( # type: ignore self, - input: ChatType, - num_generations: int = 1, + input: StandardInput, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -128,8 +117,6 @@ async def agenerate( # type: ignore Args: input: a single input in chat format to generate responses for. - num_generations: the number of generations to create per input. Defaults to - `1`. temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`. top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`. top_k: If specified, top-k sampling will be used. Defaults to `None`. @@ -143,33 +130,40 @@ async def agenerate( # type: ignore """ from vertexai.generative_models import GenerationConfig - contents = self._chattype_to_content(input) - generations = [] - # TODO: remove this for-loop and override `generate` - for _ in range(num_generations): - content = await self._aclient.generate_content_async( # type: ignore - contents=contents, - generation_config=GenerationConfig( - candidate_count=1, # only one candidate allowed per call - temperature=temperature, - top_k=top_k, - top_p=top_p, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - ), - safety_settings=safety_settings, - tools=tools, - stream=False, + content: "GenerationResponse" = await self._aclient.generate_content_async( # type: ignore + contents=self._chattype_to_content(input), + generation_config=GenerationConfig( + candidate_count=1, # only one candidate allowed per call + temperature=temperature, + top_k=top_k, + top_p=top_p, + max_output_tokens=max_output_tokens, + stop_sequences=stop_sequences, + ), + safety_settings=safety_settings, # type: ignore + tools=tools, # type: ignore + stream=False, + ) + + text = None + try: + text = content.candidates[0].text + except ValueError: + self._logger.warning( # type: ignore + f"Received no response using VertexAI client (model: '{self.model}')." + f" Finish reason was: '{content.candidates[0].finish_reason}'." ) - text = None - try: - text = content.candidates[0].text - except ValueError: - self._logger.warning( - f"Received no response using VertexAI client (model: '{self.model}')." - f" Finish reason was: '{content.candidates[0].finish_reason}'." - ) - generations.append(text) + return [text] - return generations + +def _is_gemini_model(model: str) -> bool: + """Returns `True` if the model is a model from the Vertex AI Gemini API. + + Args: + model (str): the model name to be checked. + + Returns: + bool: `True` if the model is a model from the Vertex AI Gemini API. + """ + return "gemini" in model diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 00b3807465..ee124c7dfa 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -12,8 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
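`VertexAILLM` does not get a docstring example in this diff; for completeness, here is a usage sketch in the same style as the other examples. The model name is illustrative and Google Cloud credentials are assumed to be already configured:

```python
from distilabel.llms import VertexAILLM

# Gemini model name is illustrative; authentication is handled by the Google Cloud SDK.
llm = VertexAILLM(model="gemini-1.5-pro")

llm.load()

# `agenerate` now returns a single-element list; the base `AsyncLLM` takes care of
# `num_generations` because `_num_generations_param_supported` is set to `False`.
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
```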
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
-
+import json
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import numpy as np
 from pydantic import Field, PrivateAttr, validate_call

 from distilabel.llms.base import LLM
@@ -21,14 +33,12 @@
 from distilabel.llms.mixins import CudaDevicePlacementMixin
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
     from vllm import LLM as _vLLM

-    from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType
-

 SamplingParams = None

@@ -72,6 +82,47 @@ class vLLM(LLM, CudaDevicePlacementMixin):
     Runtime parameters:
         - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to
           the `LLM` class of `vllm` library.
+
+    Examples:
+
+        Generate text:
+
+        ```python
+        from distilabel.llms import vLLM
+
+        # You can pass a custom chat_template to the model
+        llm = vLLM(
+            model="prometheus-eval/prometheus-7b-v2.0",
+            chat_template="[INST] {{ messages[0]['content'] }}\\n{{ messages[1]['content'] }}[/INST]",
+        )
+
+        llm.load()
+
+        # Call the model
+        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+        ```
+
+        Generate structured data:
+
+        ```python
+        from pydantic import BaseModel
+        from distilabel.llms import vLLM
+
+        class User(BaseModel):
+            name: str
+            last_name: str
+            id: int
+
+        llm = vLLM(
+            model="prometheus-eval/prometheus-7b-v2.0",
+            structured_output={"format": "json", "schema": User},
+        )
+
+        llm.load()
+
+        # Call the model
+        output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
+        ```
     """

     model: str
@@ -94,6 +145,10 @@ class vLLM(LLM, CudaDevicePlacementMixin):
         " `vLLM` class of `vllm` library. See all the supported arguments at: "
         "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
     )
+    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
+        default=None,
+        description="The structured output format to use across all the generations.",
+    )

     _model: Optional["_vLLM"] = PrivateAttr(...)
     _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(...)
@@ -148,12 +203,17 @@ def load(self) -> None:
             self.structured_output
         )

+    def unload(self) -> None:
+        """Unloads the `vLLM` model."""
+        CudaDevicePlacementMixin.unload(self)
+        super().unload()
+
     @property
     def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model

-    def prepare_input(self, input: "ChatType") -> str:
+    def prepare_input(self, input: "FormattedInput") -> str:
         """Prepares the input by applying the chat template to the input, which is formatted
         as an OpenAI conversation, and adding the generation prompt.
         """
@@ -163,10 +223,52 @@ def prepare_input(self, input: "ChatType") -> str:
             add_generation_prompt=True,  # type: ignore
         )

+    def _prepare_batches(
+        self, inputs: List[FormattedInput]
+    ) -> Tuple[List[List[FormattedInput]], List[int]]:
+        """Prepares the inputs by grouping them by the structured output.
+
+        When we generate structured outputs with schemas obtained from a dataset, we need to
+        prepare the data to try to send batches of inputs instead of single inputs to the model
+        to take advantage of the engine.
So we group the inputs by the structured output to be + passed in the `generate` method. + + Args: + inputs: The batch of inputs passed to the generate method. As we expect to be generating + structured outputs, each element will be a tuple containing the instruction and the + structured output. + + Returns: + The prepared batches (sub-batches let's say) to be passed to the `generate` method. + Each new tuple will contain instead of the single instruction, a list of instructions + """ + instruction_order = {} + batches = {} + for i, (instruction, structured_output) in enumerate(inputs): + instruction = self.prepare_input(instruction) + instruction_order[instruction] = i + structured_output = json.dumps(structured_output) + if structured_output not in batches: + batches[structured_output] = [instruction] + else: + batches[structured_output].append(instruction) + + # Flatten the instructions in prepared_data + flat_instructions = [ + instruction for _, group in batches.items() for instruction in group + ] + # Generate the list of indices based on the original order + sorted_indices = [ + instruction_order[instruction] for instruction in flat_instructions + ] + return [ + (batch, json.loads(schema)) for schema, batch in batches.items() + ], sorted_indices + @validate_call def generate( # type: ignore self, - inputs: List[ChatType], + inputs: List[FormattedInput], num_generations: int = 1, max_new_tokens: int = 128, frequency_penalty: float = 0.0, @@ -198,36 +300,62 @@ def generate( # type: ignore Returns: A list of lists of strings containing the generated responses for each input. """ - prepared_inputs = [self.prepare_input(input) for input in inputs] if extra_sampling_params is None: extra_sampling_params = {} + structured_output = None + needs_sorting = False + + if isinstance(inputs[0], tuple): + prepared_batches, sorted_indices = self._prepare_batches(inputs) + needs_sorting = True + else: + # Simulate a batch without the structured output content + prepared_batches = [([self.prepare_input(input) for input in inputs], None)] + + # In case we have a single structured output for the dataset, we can + logits_processors = None + if self._logits_processor: + logits_processors = [self._logits_processor] + + batched_outputs = [] + + for prepared_inputs, structured_output in prepared_batches: + if structured_output: + logits_processors = [self._prepare_structured_output(structured_output)] + + sampling_params = SamplingParams( # type: ignore + n=num_generations, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_tokens=max_new_tokens, + logits_processors=logits_processors, + **extra_sampling_params, + ) - sampling_params = SamplingParams( # type: ignore - n=num_generations, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - max_tokens=max_new_tokens, - logits_processors=( - [self._logits_processor] if self._logits_processor else None - ), - **extra_sampling_params, - ) + batch_outputs = self._model.generate( # type: ignore + prepared_inputs, + sampling_params, + use_tqdm=False, # type: ignore + ) - batch_outputs = self._model.generate( # type: ignore - prepared_inputs, - sampling_params, - use_tqdm=False, # type: ignore - ) - return [ - [output.text for output in outputs.outputs] for outputs in batch_outputs - ] + batched_outputs += [ + [output.text for output in outputs.outputs] for outputs in batch_outputs + ] + + # If 
logits_processor is set, we need to sort the outputs back to the original order + # (would be needed only if we have multiple structured outputs in the dataset) + if needs_sorting: + batched_outputs = _sort_batches( + batched_outputs, sorted_indices, num_generations=num_generations + ) + return batched_outputs def _prepare_structured_output( - self, structured_output: Optional["StructuredOutputType"] = None + self, structured_output: Optional[OutlinesStructuredOutputType] = None ) -> Union[Callable, None]: """Creates the appropriate function to filter tokens to generate structured outputs. @@ -242,6 +370,51 @@ def _prepare_structured_output( ) result = prepare_guided_output(structured_output, "vllm", self._model) - if schema := result.get("schema"): + if (schema := result.get("schema")) and self.structured_output: self.structured_output["schema"] = schema return result["processor"] + + +def _sort_batches( + batches: List[List[FormattedInput]], indices: List[int], num_generations: int = 1 +) -> List[str]: + """Helper function to sort back the mini-batches generated by the model. + + It must take into account the number of `num_generations` to repeat the indices + accordingly. + + Args: + batches: The mini-batches generated by the model. + indices: The indices that would sort the mini-batches back to the original order. + num_generations: The number of generations requested to vLLM. Defaults to 1. + + Returns: + Sorted batched_outputs. + """ + batch_sizes = [len(batch) for batch in batches] + flattened_batches = np.array([b for batch in batches for b in batch]) + sorted_batches = np.take_along_axis( + flattened_batches, + np.argsort(np.repeat(indices, num_generations)), + axis=0, + ).tolist() + sorted_batches = _batchify(sorted_batches, batch_sizes) + return sorted_batches + + +def _batchify(sorted_batches: List[str], batch_sizes: List[int]) -> List[List[str]]: + """Helper function to regenerate the sorted batches from the flattened sorted ones. + + Args: + sorted_batches: Output obtained from the `_sort_batches` function. + batch_sizes: The batch sizes to be used to split the sorted batches. + + Returns: + Batched sorted batches in the original shape. + """ + batches = [] + idx = 0 + for bs in batch_sizes: + batches.append(sorted_batches[idx : idx + bs]) + idx += bs + return batches diff --git a/src/distilabel/mixins/runtime_parameters.py b/src/distilabel/mixins/runtime_parameters.py index 5959244803..a7dd848f17 100644 --- a/src/distilabel/mixins/runtime_parameters.py +++ b/src/distilabel/mixins/runtime_parameters.py @@ -31,6 +31,10 @@ """Used to mark the attributes of a `Step` as a runtime parameter.""" RuntimeParametersNames = Dict[str, Union[bool, "RuntimeParametersNames"]] +"""Alias for the names of the runtime parameters of a `Step`.""" + +RuntimeParameterInfo = Dict[str, Any] +"""Alias for the information of the runtime parameters of a `Step`.""" class RuntimeParametersMixin(BaseModel): @@ -45,7 +49,7 @@ class RuntimeParametersMixin(BaseModel): _runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict) @property - def runtime_parameters_names(self) -> RuntimeParametersNames: + def runtime_parameters_names(self) -> "RuntimeParametersNames": """Returns a dictionary containing the name of the runtime parameters of the class as keys and whether the parameter is required or not as values. 
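As a sanity check for the batching helpers added to `vllm.py` above, this tiny sketch shows how `_batchify` restores the original batch shapes from the flattened, re-sorted outputs (the output strings and batch sizes are made up for illustration):

```python
# Illustrates `_batchify` from the `vllm.py` hunk above: a flat, re-sorted list of
# outputs is sliced back into sub-lists using the recorded batch sizes.
def batchify(sorted_batches, batch_sizes):
    batches, idx = [], 0
    for bs in batch_sizes:
        batches.append(sorted_batches[idx : idx + bs])
        idx += bs
    return batches

flat = ["out-0", "out-1", "out-2", "out-3", "out-4"]
assert batchify(flat, [2, 3]) == [["out-0", "out-1"], ["out-2", "out-3", "out-4"]]
```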
@@ -57,18 +61,27 @@ def runtime_parameters_names(self) -> RuntimeParametersNames: runtime_parameters = {} for name, field_info in self.model_fields.items(): # type: ignore + # `field: RuntimeParameter[Any]` or `field: Optional[RuntimeParameter[Any]]` is_runtime_param, is_optional = _is_runtime_parameter(field_info) if is_runtime_param: runtime_parameters[name] = is_optional continue attr = getattr(self, name) + + # `field: RuntimeParametersMixin` if isinstance(attr, RuntimeParametersMixin): runtime_parameters[name] = attr.runtime_parameters_names + # `field: List[RuntiemParametersMixin]` + if isinstance(attr, list) and isinstance(attr[0], RuntimeParametersMixin): + runtime_parameters[name] = { + str(i): item.runtime_parameters_names for i, item in enumerate(attr) + } + return runtime_parameters - def get_runtime_parameters_info(self) -> List[Dict[str, Any]]: + def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]: """Gets the information of the runtime parameters of the class such as the name and the description. This function is meant to include the information of the runtime parameters in the serialized data of the class. @@ -82,6 +95,8 @@ def get_runtime_parameters_info(self) -> List[Dict[str, Any]]: continue attr = getattr(self, name) + + # Get runtime parameters info for `RuntimeParametersMixin` field if isinstance(attr, RuntimeParametersMixin): runtime_parameters_info.append( { @@ -91,6 +106,19 @@ def get_runtime_parameters_info(self) -> List[Dict[str, Any]]: ) continue + # Get runtime parameters info for `List[RuntimeParametersMixin]` field + if isinstance(attr, list) and isinstance(attr[0], RuntimeParametersMixin): + runtime_parameters_info.append( + { + "name": name, + "runtime_parameters_info": { + str(i): item.get_runtime_parameters_info() + for i, item in enumerate(attr) + }, + } + ) + continue + info = {"name": name, "optional": self.runtime_parameters_names[name]} if field_info.description is not None: info["description"] = field_info.description @@ -115,27 +143,38 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None: name, runtime_parameters_names, cutoff=0.5 ) msg = ( - f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'." + f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'." # type: ignore ) if closest: msg += f" Did you mean any of: {closest}" else: msg += f" Available runtime parameters for the step: {runtime_parameters_names}." 
- self.pipeline._logger.warning(msg) + self.pipeline._logger.warning(msg) # type: ignore continue attr = getattr(self, name) + + # Set runtime parameters for `RuntimeParametersMixin` field if isinstance(attr, RuntimeParametersMixin): attr.set_runtime_parameters(value) self._runtime_parameters[name] = value continue + # Set runtime parameters for `List[RuntimeParametersMixin]` field + if isinstance(attr, list) and isinstance(attr[0], RuntimeParametersMixin): + for i, item in enumerate(attr): + item_value = value.get(str(i), {}) + item.set_runtime_parameters(item_value) + self._runtime_parameters[name] = value + continue + # Handle settings values for `_SecretField` field_info = self.model_fields[name] inner_type = _extract_runtime_parameter_inner_type(field_info.annotation) if inspect.isclass(inner_type) and issubclass(inner_type, _SecretField): value = inner_type(value) + # Set the value of the runtime parameter setattr(self, name, value) self._runtime_parameters[name] = value diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index f67158dfc4..455ee7ab60 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -12,47 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import hashlib import logging import os -from collections import defaultdict -from dataclasses import asdict, dataclass, field +import signal +import threading +import time +from abc import ABC, abstractmethod from pathlib import Path from typing import ( TYPE_CHECKING, Any, + Callable, Dict, - Iterable, List, Optional, - Set, Tuple, TypedDict, Union, ) -import pyarrow as pa -import pyarrow.parquet as pq +import fsspec from typing_extensions import Self +from upath import UPath from distilabel import __version__ +from distilabel.distiset import create_distiset from distilabel.pipeline._dag import DAG +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager from distilabel.pipeline.constants import ( + CONVERGENCE_STEP_ATTR_NAME, + INPUT_QUEUE_ATTR_NAME, + LAST_BATCH_SENT_FLAG, RECEIVES_ROUTED_BATCHES_ATTR_NAME, ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME, ) -from distilabel.utils.files import list_files_in_dir -from distilabel.utils.serialization import TYPE_INFO_KEY, _Serializable +from distilabel.pipeline.write_buffer import _WriteBuffer +from distilabel.utils.logging import setup_logging, stop_logging +from distilabel.utils.serialization import ( + TYPE_INFO_KEY, + _Serializable, +) if TYPE_CHECKING: from os import PathLike + from queue import Queue from distilabel.distiset import Distiset from distilabel.pipeline.routing_batch_function import RoutingBatchFunction - from distilabel.steps.base import _Step - from distilabel.utils.serialization import SaveFormats, StrOrPath + from distilabel.pipeline.typing import PipelineRuntimeParametersInfo, StepLoadStatus + from distilabel.steps.base import Step, _Step BASE_CACHE_DIR = Path.home() / ".cache" / "distilabel" / "pipelines" @@ -71,6 +82,7 @@ class _CacheLocation(TypedDict): pipeline: Path batch_manager: Path data: Path + batch_input_data: Path log_file: Path @@ -103,7 +115,11 @@ def get_pipeline(cls) -> Union["BasePipeline", None]: return cls._context_global_pipeline -class BasePipeline(_Serializable): +_STEP_LOAD_FAILED_CODE = -666 +_STEP_NOT_LOADED_CODE = -999 + + +class BasePipeline(ABC, _Serializable): """Base class for a `distilabel` pipeline. 
     Attributes:
@@ -113,14 +129,37 @@ class BasePipeline(_Serializable):
         _cache_dir: The directory where the pipeline will be cached.
         _logger: The logger instance that will be used by the pipeline.
         _batch_manager: The batch manager that will manage the batches received from the
-            steps while running the pipeline.
+            steps while running the pipeline. It will be created when the pipeline is run,
+            from scratch or from cache. Defaults to `None`.
+        _write_buffer: The buffer that will store the data of the leaf steps of the pipeline
+            while running, so the `Distiset` can be created at the end. It will be created
+            when the pipeline is run. Defaults to `None`.
+        _logging_parameters: A dictionary containing the parameters that will be passed to
+            the `setup_logging` function to initialize the logging. Defaults to `{}`.
+        _fs: The `fsspec` filesystem to be used to store the data of the `_Batch`es passed
+            between the steps. It will be set when the pipeline is run. Defaults to `None`.
+        _storage_base_path: The base path where the data of the `_Batch`es passed between
+            the steps will be stored. It will be set when the pipeline is run. Defaults
+            to `None`.
+        _use_fs_to_pass_data: Whether to use the file system to pass the data of the
+            `_Batch`es between the steps. Even if this parameter is `False`, the `Batch`es
+            received by `GlobalStep`s will always use the file system to pass the data.
+            Defaults to `False`.
+        _dry_run: A flag to indicate if the pipeline is running in dry run mode. Defaults
+            to `False`.
+        output_queue: A queue to store the output of the steps while running the pipeline.
+        load_queue: A queue used by each `Step` to notify the main process that it has
+            finished loading or that the step has been unloaded.
     """
 
+    _output_queue: "Queue[Any]"
+    _load_queue: "Queue[Union[StepLoadStatus, None]]"
+
     def __init__(
         self,
         name: str,
         description: Optional[str] = None,
-        cache_dir: Optional["PathLike"] = None,
+        cache_dir: Optional[Union[str, "PathLike"]] = None,
         enable_metadata: bool = False,
     ) -> None:
         """Initialize the `BasePipeline` instance.
@@ -148,9 +187,24 @@ def __init__(
 
         self._logger = logging.getLogger("distilabel.pipeline")
 
-        # It's set to None here, will be created in the call to run
         self._batch_manager: Optional["_BatchManager"] = None
-        self._dry_run: bool = False
+        self._write_buffer: Optional["_WriteBuffer"] = None
+        self._logging_parameters: Dict[str, Any] = {
+            "filename": self._cache_location["log_file"]
+        }
+
+        self._steps_load_status: Dict[str, int] = {}
+        self._steps_load_status_lock = threading.Lock()
+
+        self._stop_called = False
+        self._stop_called_lock = threading.Lock()
+        self._stop_calls = 0
+
+        self._fs: Optional[fsspec.AbstractFileSystem] = None
+        self._storage_base_path: Optional[str] = None
+        self._use_fs_to_pass_data: bool = False
+
+        self._dry_run = False
 
     def __enter__(self) -> Self:
         """Set the global pipeline instance when entering a pipeline context."""
@@ -219,10 +273,22 @@ def _create_signature(self) -> str:
 
         return hasher.hexdigest()
 
+    def _set_logging_parameters(self, parameters: Dict[str, Any]) -> None:
+        """Set the parameters that will be passed to the `setup_logging` function to
+        initialize the logging.
+
+        Args:
+            parameters: A dictionary with the parameters that will be passed to the
+                `setup_logging` function.
+ """ + self._logging_parameters = parameters + def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, use_cache: bool = True, + storage_parameters: Optional[Dict[str, Any]] = None, + use_fs_to_pass_data: bool = False, ) -> "Distiset": # type: ignore """Run the pipeline. It will set the runtime parameters for the steps and validate the pipeline. @@ -235,15 +301,68 @@ def run( the runtime parameters for the step as the value. Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. + storage_parameters: A dictionary with the storage parameters (`fsspec` and path) + that will be used to store the data of the `_Batch`es passed between the + steps if `use_fs_to_pass_data` is `True` (for the batches received by a + `GlobalStep` it will be always used). It must have at least the "path" key, + and it can contain additional keys depending on the protocol. By default, + it will use the local file system and a directory in the cache directory. + Defaults to `None`. + use_fs_to_pass_data: Whether to use the file system to pass the data of + the `_Batch`es between the steps. Even if this parameter is `False`, the + `Batch`es received by `GlobalStep`s will always use the file system to + pass the data. Defaults to `False`. Returns: The `Distiset` created by the pipeline. """ - if use_cache: - self._load_from_cache() + + # Set the runtime parameters that will be used during the pipeline execution. + # They are used to generate the signature of the pipeline that is used to hit the + # cache when the pipeline is run, so it's important to do it first. self._set_runtime_parameters(parameters or {}) + + setup_logging( + **{ + **self._logging_parameters, + "filename": str(self._cache_location["log_file"]), + } + ) + + self._init_steps_load_status() + + # Validate the pipeline DAG to check that all the steps are chainable, there are + # no missing runtime parameters, batch sizes are correct, etc. self.dag.validate() + # Load the `_BatchManager` from cache or create one from scratch + self._load_batch_manager(use_cache) + + # Setup the filesystem that will be used to pass the data of the `_Batch`es + self._setup_fsspec(storage_parameters) + self._use_fs_to_pass_data = use_fs_to_pass_data + + if self._dry_run: + self._logger.info("🌵 Dry run mode") + + # If the batch manager is not able to generate batches, that means that the loaded + # `_BatchManager` from cache didn't have any remaining batches to process i.e. + # the previous pipeline execution was completed successfully. + if not self._batch_manager.can_generate(): # type: ignore + self._logger.info( + "💾 Loaded batch manager from cache doesn't contain any remaining data." + " Returning `Distiset` from cache data..." + ) + stop_logging() + return create_distiset( + self._cache_location["data"], + pipeline_path=self._cache_location["pipeline"], + log_filename_path=self._cache_location["log_file"], + enable_metadata=self._enable_metadata, + ) + + self._setup_write_buffer() + def dry_run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, @@ -252,14 +371,14 @@ def dry_run( """Do a dry run to test the pipeline runs as expected. Running a `Pipeline` in dry run mode will set all the `batch_size` of generator steps - to the specified batch_size, and run just with a single batch, effectively - running the whole pipeline with a single example. The cache will be set to False. 
+            to the specified `batch_size`, and run just with a single batch, effectively
+            running the whole pipeline with a single example. The cache will be set to `False`.
 
         Args:
-            parameters: The same parameters variable from `BasePipeline.run`. Defaults to None.
-                Will be passed to the parent method, but with the batch_size of the generator steps
-                fixed to 1.
-            batch_size: The batch size to test the pipeline. Defaults to 1.
+            parameters: A dictionary with the step name as the key and a dictionary with
+                the runtime parameters for the step as the value. Defaults to `None`.
+            batch_size: The batch size of the unique batch generated by the generator
+                steps of the pipeline. Defaults to `1`.
 
         Returns:
             Will return the `Distiset` as the main run method would do.
@@ -279,7 +398,7 @@ def dry_run(
         self._dry_run = False
         return distiset
 
-    def get_runtime_parameters_info(self) -> Dict[str, List[Dict[str, Any]]]:
+    def get_runtime_parameters_info(self) -> "PipelineRuntimeParametersInfo":
         """Get the runtime parameters for the steps in the pipeline.
 
         Returns:
@@ -292,6 +411,46 @@ def get_runtime_parameters_info(self) -> Dict[str, List[Dict[str, Any]]]:
             runtime_parameters[step_name] = step.get_runtime_parameters_info()
         return runtime_parameters
 
+    def _init_steps_load_status(self) -> None:
+        """Initialize the `_steps_load_status` dictionary assigning 0 to every step of
+        the pipeline."""
+        for step_name in self.dag:
+            self._steps_load_status[step_name] = _STEP_NOT_LOADED_CODE
+
+    def _setup_fsspec(
+        self, storage_parameters: Optional[Dict[str, Any]] = None
+    ) -> None:
+        """Sets up the `fsspec` filesystem to be used to store the data of the `_Batch`es
+        passed between the steps.
+
+        Args:
+            storage_parameters: A dictionary with the storage parameters (`fsspec` and path)
+                that will be used to store the data of the `_Batch`es passed between the
+                steps if `use_fs_to_pass_data` is `True` (for the batches received by a
+                `GlobalStep` it will always be used). It must have at least the "path" key,
+                and it can contain additional keys depending on the protocol. By default,
+                it will use the local file system and a directory in the cache directory.
+                Defaults to `None`.
+        """
+        if not storage_parameters:
+            self._fs = fsspec.filesystem("file")
+            self._storage_base_path = (
+                f"file://{self._cache_location['batch_input_data']}"
+            )
+            return
+
+        if "path" not in storage_parameters:
+            raise ValueError(
+                "The 'path' key must be present in the `storage_parameters` dictionary"
+                " if it's not `None`."
+            )
+
+        path = storage_parameters.pop("path")
+        protocol = UPath(path).protocol
+
+        self._fs = fsspec.filesystem(protocol, **storage_parameters)
+        self._storage_base_path = path
+
     def _add_step(self, step: "_Step") -> None:
         """Add a step to the pipeline.
 
@@ -319,6 +478,14 @@ def _add_edge(self, from_step: str, to_step: str) -> None:
             value=routing_batch_function is not None,
         )
 
+    def _is_convergence_step(self, step_name: str) -> bool:
+        """Checks if a step is a convergence step.
+
+        Args:
+            step_name: The name of the step.
+
+        Returns:
+            Whether the step is a convergence step (i.e. all its predecessors receive
+            routed batches).
+ """ + return self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME) + def _add_routing_batch_function( self, step_name: str, routing_batch_function: "RoutingBatchFunction" ) -> None: @@ -406,1043 +573,559 @@ def _cache_location(self) -> _CacheLocation: "pipeline": folder / "pipeline.yaml", "batch_manager": folder / "batch_manager.json", "data": folder / "data", + "batch_input_data": folder / "batch_input_data", "log_file": folder / "pipeline.log", } def _cache(self) -> None: """Saves the `BasePipeline` using the `_cache_filename`.""" + if self._dry_run: + return + self.save( path=self._cache_location["pipeline"], format=self._cache_location["pipeline"].suffix.replace(".", ""), # type: ignore ) if self._batch_manager is not None: - self._batch_manager.save( - self._cache_location["batch_manager"], - format=self._cache_location["batch_manager"].suffix.replace(".", ""), # type: ignore - ) + self._batch_manager.cache(self._cache_location["batch_manager"]) self._logger.debug("Pipeline and batch manager saved to cache.") - def _load_from_cache(self) -> None: - """Will try to load the `BasePipeline` from the cache dir if found, updating - the internal `DAG` and `_BatchManager`. + def _load_batch_manager(self, use_cache: bool = True) -> None: + """Will try to load the `_BatchManager` from the cache dir if found. Otherwise, + it will create one from scratch. """ - cache_loc = self._cache_location - if cache_loc["pipeline"].exists(): - if cache_loc["batch_manager"].exists(): - self._batch_manager = _BatchManager.from_json( - cache_loc["batch_manager"] - ) - self._logger.info("💾 Load pipeline from cache") - - -@dataclass -class _Batch(_Serializable): - """Dataclass to represent a batch of data to be processed by a `_Step`. - - Attributes: - seq_no: The sequence number of the batch. - step_name: The name of the step that will process the batch. - last_batch: A flag to indicate if the batch is the last one. - data: The data to be processed. - accumulated: A flag to indicate if the batch is accumulated. - created_from: A dictionary containing the `seq_no` of the batches of the steps that - were used to create this batch. - size: The size of the batch. - """ - - seq_no: int - step_name: str - last_batch: bool - data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False) - accumulated: bool = False - created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict) - batch_routed_to: List[str] = field(default_factory=list) - size: int = 0 - - def next_batch(self) -> "_Batch": - """Create a new `_Batch` instance with the next batch of data. - - Args: - data: The data to be processed. - - Returns: - A `_Batch` instance. - """ - return _Batch( - seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch - ) - - def set_data(self, data: List[List[Dict[str, Any]]]) -> None: - """Sets the data of the batch and updates the size of the batch. - - Args: - data: The data of the batch. - """ - self.data = data - self.size = len(data[0]) - - @classmethod - def accumulate(cls, step_name: str, batches: List[List["_Batch"]]) -> "_Batch": - """Creates a `_Batch` instance using the data from the list of batches that - were received from another steps. The batches will be accumulated in a single - list of data. - - Args: - step_name: The name of the step that will process the batch. - batches: a list containing the list of batches received from the predecessors. - - Returns: - A `_Batch` instance. 
- """ - - data = [] - for step_batches in batches: - accumulated_data = [row for batch in step_batches for row in batch.data[0]] - data.append(accumulated_data) - return cls( - seq_no=0, step_name=step_name, last_batch=True, data=data, accumulated=True - ) - - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: - """Dumps the content of the `_Batch` to a dictionary, using the `dataclass` helper function. - - Args: - obj: Unused, just kept to match the signature of the parent method. - kwargs: Additional arguments that are kept to match the signature of the parent method. - - Returns: - A `dict` containing the internal representation of the `_Batch`. - """ - return asdict(self) - - def copy(self) -> "_Batch": - """Creates a copy of the `_Batch` instance. + batch_manager_cache_loc = self._cache_location["batch_manager"] + if use_cache and batch_manager_cache_loc.exists(): + self._logger.info( + f"💾 Loading `_BatchManager` from cache: '{batch_manager_cache_loc}'" + ) + self._batch_manager = _BatchManager.load_from_cache(batch_manager_cache_loc) + else: + self._batch_manager = _BatchManager.from_dag(self.dag) - Returns: - A copy of the `_Batch` instance. + def _setup_write_buffer(self) -> None: + """Setups the `_WriteBuffer` that will store the data of the leaf steps of the + pipeline while running, so the `Distiset` can be created at the end. """ - return copy.deepcopy(self) + buffer_data_path = self._cache_location["data"] + self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'") + self._write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps) + + def _run_output_queue_loop_in_thread(self) -> threading.Thread: + """Runs the output queue loop in a separate thread to receive the output batches + from the steps. This is done to avoid the signal handler to block the loop, which + would prevent the pipeline from stopping correctly.""" + thread = threading.Thread(target=self._output_queue_loop) + thread.start() + return thread + + def _output_queue_loop(self) -> None: + """Loop to receive the output batches from the steps and manage the flow of the + batches through the pipeline.""" + while self._batch_manager.can_generate() and not self._stop_called: # type: ignore + self._logger.debug("Waiting for output batch from step...") + if (batch := self._output_queue.get()) is None: + self._logger.debug("Received `None` from output queue. Breaking loop.") + break + self._logger.debug( + f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'" + f" from output queue: {batch}" + ) -@dataclass -class _BatchManagerStep(_Serializable): - """A class that will accumulate data for a step from the predecessors and create - batches for the step to process when there is enough data. + if batch.data_path: + self._logger.debug( + f"Reading {batch.seq_no} batch data from '{batch.step_name}': '{batch.data_path}'" + ) + batch.read_batch_data_from_fs() - Attributes: - step_name: The name of the step that will process the data. - accumulate: A flag to indicate if the data should be accumulated and create a - batch with all the data received from the predecessors instead of creating - batches with the `input_batch_size`. - input_batch_size: The size of the batch to be created for the step to process. - If `None`, then `accumulate` must be `True`. Defaults to `None`. - data: A dictionary with the predecessor step name as the key and a list of - dictionaries (rows) as the value. - seq_no: The sequence number of the next batch to be created. 
It will be - incremented for each batch created. - last_batch_received: A list with the names of the steps that sent the last - batch of data. - convergence_step: A flag to indicate if the step is a convergence step. An - `Step` is a convergence step if all its predecessors are receiving routed - batches. Defaults to `False`. - convergence_step_batches_consumed: A dictionary in which the key is the `seq_no` - of the batch created by step A, that was used by step B and C and obtained from - the `created_from` of the batches created by them. It's used to know if all - the batches from B and C steps created from batches of A have been consumed - by D, in order to not mess up the order of the batches. Only used if `convergence_step=True`. - Defaults to an empty dictionary. - next_expected_created_from_batch_seq_no: The next expected sequence number of the - batch from step A used by steps B and C and obtained from the `created_from` - of the batches created by them. It's used to avoid messing up the order of the - batches. Only used if `convergence_step=True`. Defaults to `0`. - """ + if batch.step_name in self.dag.leaf_steps: + self._write_buffer.add_batch(batch) # type: ignore - step_name: str - accumulate: bool - input_batch_size: Union[int, None] = None - data: Dict[str, List[_Batch]] = field(default_factory=dict) - seq_no: int = 0 - last_batch_received: List[str] = field(default_factory=list) - convergence_step: bool = False - convergence_step_batches_consumed: Dict[int, Dict[str, int]] = field( - default_factory=dict - ) - next_expected_created_from_batch_seq_no: int = 0 - - def add_batch(self, batch: _Batch, prepend: bool = False) -> None: - """Add a batch of data from `batch.step_name` to the step. It will accumulate the - data and keep track of the last batch received from the predecessors. + # If `_stop_called` was set to `True` while waiting for the output queue, then + # we need to handle the stop of the pipeline and break the loop to avoid + # propagating the batches through the pipeline and making the stop process + # slower. + if self._stop_called: + self._handle_batch_on_stop(batch) + break - Args: - batch: The output batch of an step to be processed by the step. - prepend: If `True`, the content of the batch will be added at the start of - the buffer. - """ - from_step = batch.step_name + self._manage_batch_flow(batch) - if prepend: - self.data[from_step].insert(0, batch) - else: - self.data[from_step].append(batch) + if self._stop_called: + self._handle_stop() - if batch.last_batch: - self.last_batch_received.append(from_step) + self._cache() - def get_batch(self) -> Union[_Batch, None]: - """Create a new batch of data for the step to process. It will return `None` if - there is not enough data to create a batch. + def _run_load_queue_loop_in_thread(self) -> threading.Thread: + """Runs a background thread that reads from the `load_queue` to update the status + of the number of workers loaded for each step. Returns: - A `_Batch` instance if there is enough data to create a batch. Otherwise, - `None`. + The thread that was started. """ - if not self._ready_to_create_batch(): - return None + thread = threading.Thread(target=self._run_load_queue_loop) + thread.start() + return thread + + def _run_load_queue_loop(self) -> None: + """Runs a loop that reads from the `load_queue` to update the status of the number + of workers loaded for each step.""" + while True: + if (load_info := self._load_queue.get()) is None: + self._logger.debug("Received `None` from load queue. 
Breaking loop.") + break + + with self._steps_load_status_lock: + step_name, status = load_info["name"], load_info["status"] + if status == "loaded": + if self._steps_load_status[step_name] == _STEP_NOT_LOADED_CODE: + self._steps_load_status[step_name] = 1 + else: + self._steps_load_status[step_name] += 1 + elif status == "unloaded": + self._steps_load_status[step_name] -= 1 + else: + # load failed + self._steps_load_status[step_name] = _STEP_LOAD_FAILED_CODE - # `_last_batch` must be called before `_get_data`, as `_get_data` will update the - # list of data which is used to determine if the batch to be created is the last one. - # TODO: remove `_last_batch` method and integrate logic in `_get_data` - last_batch = self._last_batch() - data, created_from, batch_routed_to = self._get_data() - - return _Batch( - seq_no=self._get_seq_no(), - step_name=self.step_name, - last_batch=last_batch, - data=data, - accumulated=self.accumulate, - created_from=created_from, - batch_routed_to=batch_routed_to, - ) + self._logger.debug( + f"Step '{step_name}' loaded workers: {self._steps_load_status[step_name]}" + ) - def empty_buffers(self) -> List[str]: - """Checks if the input buffer for the step is empty. + def _all_steps_loaded(self) -> bool: + """Waits for all the steps to load. Returns: - The name of the previous steps for which the input buffer for this step is - empty. + `True` if all the steps have been loaded correctly, `False` otherwise. """ - if self.accumulate: - return [ - previous_step - for previous_step in self.data.keys() - if previous_step not in self.last_batch_received - ] - return [ - previous_step - for previous_step, batches in self.data.items() - if previous_step not in self.last_batch_received - and sum(len(batch.data[0]) for batch in batches) < self.input_batch_size # type: ignore - ] - - @classmethod - def from_step( - cls, step: "_Step", predecessors: Iterable[str], convergence_step: bool = False - ) -> "_BatchManagerStep": - """Creates a `_BatchManagerStep` instance from a `_Step` instance and its - predecessors. + self._logger.info("⏳ Waiting for all the steps to load...") + previous_message = None + while not self._stop_called: + with self._steps_load_status_lock: + self._logger.debug(f"Steps loaded: {self._steps_load_status}") + + if any( + num_workers_loaded == _STEP_LOAD_FAILED_CODE + for num_workers_loaded in self._steps_load_status.values() + ): + self._logger.error("❌ Failed to load all the steps") + return False + + num_steps_loaded = 0 + workers_message = "" + for step_name, num_workers_loaded in self._steps_load_status.items(): + # TODO: update condition once we allow more than one worker per step + if num_workers_loaded == 1: + num_steps_loaded += 1 + workers_message += ( + f"\n * '{step_name}' workers: {max(0, num_workers_loaded)}" + ) - Args: - step: The `_Step` instance. - predecessors: The names of the predecessors of the step. - convergence_step: A flag to indicate if the step is a convergence step. An - `Step` is a convergence step if all its predecessors are receiving routed - batches. Defaults to `False`. + message = f"⏳ Steps loaded: {num_steps_loaded}/{len(self.dag)}{workers_message}" + if num_steps_loaded > 0 and message != previous_message: + self._logger.info(message) + previous_message = message - Returns: - A `_BatchManagerStep` instance. 
- """ - return cls( - step_name=step.name, # type: ignore - accumulate=step.is_global, - input_batch_size=getattr(step, "input_batch_size", None), - data={predecessor: [] for predecessor in predecessors}, - convergence_step=convergence_step, - ) + if num_steps_loaded == len(self.dag): + self._logger.info("✅ All the steps have been loaded!") + return True - def _get_seq_no(self) -> int: - """Gets the sequence number for the next batch to be created and increments it. + time.sleep(2.5) - Returns: - The sequence number for the next batch to be created. - """ - seq_no = self.seq_no - self.seq_no += 1 - return seq_no + return not self._stop_called - def _get_data( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: - """Gets the data needed to create a batch for the step to process. If the step is - accumulating data, then it will return a list with all the data received from the - predecessors. Otherwise, it will return a list of data with the `input_batch_size` - for each predecessor. In addition, it will remove the data used to create the - batch from the step's data. + def _handle_stop(self) -> None: + """Handles the stop of the pipeline execution, which will stop the steps from + processing more batches and wait for the output queue to be empty, to not lose + any data that was already processed by the steps before the stop was called.""" + self._logger.debug("Handling stop of the pipeline execution...") - Returns: - A tuple containing the list of data needed to create a batch for the step to - process, a dictionary with the sequence numbers of the batches that were used - to create the batch and the list of steps to which the batch was routed to if - the step is a normal step. - """ - if self.accumulate: - # Steps accumulating cannot receive routed batches - return self._get_data_for_accumulate() + ([],) + self._add_batches_back_to_batch_manager() - if self.convergence_step: - # Convergence steps will receive routed batches, but we need to clean the - # `batch_routed_to` list - return self._get_data_for_convergence_step() + ([],) + # Wait for the input queue to be empty, which means that all the steps finished + # processing the batches that were sent before the stop flag. + for step_name in self.dag: + self._wait_step_input_queue_empty(step_name) - return self._get_data_normal() + self._consume_output_queue() - def _get_data_for_accumulate( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: - """Gets the data needed to create a batch for the step to process when the step - is accumulating data. It will return a list with all the data received from the - predecessors. In addition, it will remove the data used to create the batch from - the step's data. + def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]: + """Waits for the input queue of a step to be empty. - Returns: - A tuple containing the list of data needed to create a batch for the step to - process and a dictionary with the sequence numbers of the batches that were - used to create the batch. 
- """ - data = [] - batches_used = {} - for step_name, batches in self.data.items(): - batches_used[step_name] = [] - for batch in batches: - batches_used[step_name].append((batch.seq_no, batch.size)) - data.append([row for batch in batches for row in batch.data[0]]) - # Reset the data buffer - self.data = {step_name: [] for step_name in self.data} - return data, batches_used - - def _get_data_for_convergence_step( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: - """Gets the data needed to create a batch for the step to process when the step is - a convergence step. + Args: + step_name: The name of the step. Returns: - A tuple containing the list of data needed to create a batch for the step to - process and a dictionary with the sequence numbers of the batches that were - used to create the batch. + The input queue of the step if it's not loaded or finished, `None` otherwise. """ - grouped_batches = self._group_batches_by_created_from() - seq_no, batches = grouped_batches[0] - - remaining_rows_per_step = { - step_name: self.input_batch_size for step_name in self.data - } - batches_used = defaultdict(list) - data = defaultdict(list) - for batch, batch_size in batches: - batch_data = batch.data[0] - remaining_rows = remaining_rows_per_step[batch.step_name] - selected_data = batch_data[:remaining_rows] - data[batch.step_name].extend(selected_data) - - # If A -> [B, C] -> D, then in D (this step) we keep track of the remaining - # rows from the batches of A that B and C used to create the `batches`. - batch_size = self.convergence_step_batches_consumed.setdefault( - seq_no, {} - ).get(batch.step_name, batch_size) - remaining_rows_in_batch = batch_size - len(selected_data) - self.convergence_step_batches_consumed[seq_no].update( - {batch.step_name: remaining_rows_in_batch} - ) - - # Update the remaining rows - num_rows = len(selected_data) - remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore + if self._check_step_not_loaded_or_finished(step_name): + return None - # Keep track of the batches used to create the batch - batches_used[batch.step_name].append((batch.seq_no, batch.size)) + if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): + while input_queue.qsize() != 0: + pass + return input_queue - # If the batch was entirely consumed, then remove it from the buffer - if num_rows >= len(batch_data): - self.data[batch.step_name].remove(batch) - continue - - # The batch was not entirely consumed. so we need to update the batch - # with the remaining data - batch_idx = self.data[batch.step_name].index(batch) - batch_ref = self.data[batch.step_name][batch_idx] - batch_ref.data[0] = batch_data[len(selected_data) :] - - # If all the batches grouped by the `seq_no` in `created_from` were consumed, then - # we can update the `next_expected_created_from_batch_seq_no` to the next one - # to avoid skipping batches. - no_remaining_rows = all( - count == 0 - for count in self.convergence_step_batches_consumed[seq_no].values() - ) - if no_remaining_rows: - self.next_expected_created_from_batch_seq_no += 1 + def _check_step_not_loaded_or_finished(self, step_name: str) -> bool: + """Checks if a step is not loaded or already finished. - return list(data.values()), dict(batches_used) - - def _get_data_normal( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: - """Gets the data needed to create a batch for the step to process when the step is - not accumulating data. 
It will return a list of data with the `input_batch_size` - for each predecessor. In addition, it will remove the data used to create the batch - from the step's data. + Args: + step_name: The name of the step. Returns: - A tuple containing the list of data needed to create a batch for the step to - process, a dictionary with the sequence numbers of the batches that were used - to create the batch and the list of steps to which the batch was routed to if - the step is a convergence step. + `True` if the step is not loaded or already finished, `False` otherwise. """ - data = [] - batches_used = defaultdict(list) - batch_routed_to = [] - for step_name in self.data: - # For each step batches buffer, we will create a batch with the `input_batch_size` - # using the data from the buffer. We will remove the consumed batches (no data - # left) and update the batch data with the remaining data. - step_data = [] - idx_drop_batches = [] - remaining_rows: int = self.input_batch_size # type: ignore - for idx, batch in enumerate(self.data[step_name]): - if remaining_rows == 0: - break - - # Get `remaining_rows` or the remaining rows in the batch and add it to - # the step data that will be used to create the batch - batch_data = batch.data[0] - selected_data = batch_data[:remaining_rows] - step_data.extend(selected_data) - batch_routed_to = batch.batch_routed_to - - # Update the remaining rows - num_rows = len(selected_data) - remaining_rows -= num_rows - - # Keep track of the batches used to create the batch - batches_used[step_name].append((batch.seq_no, batch.size)) - - # If the batch was entirely consumed, then remove it from the buffer - if num_rows >= len(batch_data): - idx_drop_batches.append(idx) - continue - - # The batch was not entirely consumed. so we need to update the batch - # with the remaining data - batch.data[0] = batch_data[len(selected_data) :] - - # Remove the batches that were entirely consumed - idx_drop_batches.reverse() - for idx in idx_drop_batches: - self.data[step_name].pop(idx) - - data.append(step_data) - - return data, dict(batches_used), batch_routed_to + with self._steps_load_status_lock: + num_workers = self._steps_load_status[step_name] - def _ready_to_create_batch(self) -> bool: - """Checks if there is enough data to create a batch for the step. + # The step has finished (workers = 0) or it has failed to load + if num_workers in [0, _STEP_LOAD_FAILED_CODE]: + return True - Returns: - `True` if there is enough data to create a batch for the step. Otherwise, - `False`. - """ - if self.accumulate: - return self._ready_to_create_batch_accumulate() + return False - if self.convergence_step: - return self._ready_to_create_batch_convergence_step() + @property + @abstractmethod + def QueueClass(self) -> Callable: + """The class of the queue to use in the pipeline.""" + pass - return self._ready_to_create_batch_normal() + def _create_step_input_queue(self, step_name: str) -> "Queue[Any]": + """Creates an input queue for a step. - def _ready_to_create_batch_accumulate(self) -> bool: - """Checks if there is enough data for an step accumulating data. It will return - `True` if the last batch was received from all the predecessors. + Args: + step_name: The name of the step. Returns: - `True` if ready to create a batch, `False` otherwise. + The input queue created. 
""" - return all( - step in self.last_batch_received - and sum(len(batch.data[0]) for batch in batches) >= 0 - for step, batches in self.data.items() - ) + input_queue = self.QueueClass() + self.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, input_queue) + return input_queue - def _ready_to_create_batch_convergence_step(self) -> bool: - """Checks if there is enough data for creating a batch for an step in which output - batches that were generated by steps that received routed batches are received. - It will return `True`, if all the output batches that were generated from a routed - batch have been received. + @abstractmethod + def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None: + """Runs the `Step` instance. - Returns: - `True` if ready to create a batch, `False` otherwise. + Args: + step: The `Step` instance to run. + input_queue: The input queue where the step will receive the batches. """ - grouped_batches = self._group_batches_by_created_from() - if not grouped_batches: - return False - seq_no, batches = grouped_batches[0] - - # If the `seq_no` from the `created_from` field is not the expected one, then - # we cannot create a batch yet or the order will be messed up - if seq_no != self.next_expected_created_from_batch_seq_no: - return False - - # Not all output batches to which the input batch was routed to haven't been - # received - batch_routed_to = batches[0][0].batch_routed_to - batches_received_from = {batch.step_name for batch, _ in batches} - if any(step_name not in batches_received_from for step_name in batch_routed_to): - return False - - # There are output batches to which the input batch was routed to from all - # the steps. Check if there is enough data for creating a batch with `input_batch_size` - rows_per_step = defaultdict(lambda: 0) - for batch, _ in batches: - num_rows = len(batch.data[0]) - rows_per_step[batch.step_name] += num_rows - - # If there aren't at least `input_batch_size` rows from each step, then there - # isn't enough data to create a batch - if not all( - num_rows >= self.input_batch_size or step_name in self.last_batch_received # type: ignore - for step_name, num_rows in rows_per_step.items() - ): - return False - - return True - - def _ready_to_create_batch_normal(self) -> bool: - """Checks if there is enough data for creating a batch for a normal step. It will - be `True` it there are at least `input_batch_size` rows from each predecessor step. + pass - Returns: - `True` if ready to create a batch, `False` otherwise. + def _run_steps(self) -> None: + """Runs the `Step`s of the pipeline, creating first an input queue for each step + that will be used to send the batches. 
""" - for step_name, batches in self.data.items(): - num_rows = sum(len(batch.data[0]) for batch in batches) - - # If there are now rows but the last batch was already received, then there - # are no more batch to be created - if num_rows == 0 and step_name in self.last_batch_received: - return False - - # If there are not enough rows and the last batch was not received yet, then - # there is not enough data yet to creata a batch - if ( - self.input_batch_size - and num_rows < self.input_batch_size - and step_name not in self.last_batch_received - ): - return False + for step_name in self.dag: + step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME] + input_queue = self._create_step_input_queue(step_name=step_name) - return True + # Set `pipeline` to `None` as in some Python environments the pipeline is not + # picklable and it will raise an error when trying to send the step to the process. + # `TypeError: cannot pickle 'code' object` + step.pipeline = None - def _last_batch(self) -> bool: - """Checks if the batch to be created is the last one i.e. if the last batch was - received from all the predecessors. + self._logger.debug(f"Running 1 instance of step '{step.name}'...") + self._run_step(step=step, input_queue=input_queue) - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. - """ - if self.accumulate: - return self._last_batch_accumulate() + def _add_batches_back_to_batch_manager(self) -> None: + """Add the `Batch`es that were sent to a `Step` back to the `_BatchManager`. This + method should be used when the pipeline has been stopped prematurely.""" + for step_name in self.dag: + node = self.dag.get_step(step_name) + step: "_Step" = node[STEP_ATTR_NAME] + if step.is_generator: + continue + if input_queue := node.get(INPUT_QUEUE_ATTR_NAME): + while not input_queue.empty(): + batch = input_queue.get() + if batch is None: + continue + self._batch_manager.add_batch( # type: ignore + to_step=step_name, batch=batch, prepend=True + ) + self._logger.debug( + f"Adding batch back to the batch manager: {batch}" + ) + input_queue.put(None) + + def _consume_output_queue(self) -> None: + """Consumes the `Batch`es from the output queue until it's empty. This method should + be used when the pipeline has been stopped prematurely to consume and to not lose + the `Batch`es that were processed by the leaf `Step`s before stopping the pipeline.""" + while not self._output_queue.empty(): + batch = self._output_queue.get() + if batch is None: + continue - if self.convergence_step: - return self._last_batch_convergence_step() + if batch.step_name in self.dag.leaf_steps: + self._write_buffer.add_batch(batch) # type: ignore - return self._last_batch_normal() + self._handle_batch_on_stop(batch) - def _last_batch_accumulate(self) -> bool: - """Checks if the batch to be created is the last one for an step accumulating data. - `True` if the last batch was received from all the predecessors. + def _manage_batch_flow(self, batch: "_Batch") -> None: + """Checks if the step that generated the batch has more data in its buffer to + generate a new batch. If there's data, then a new batch is sent to the step. If + the step has no data in its buffer, then the predecessors generator steps are + requested to send a new batch. - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. + Args: + batch: The batch that was processed. 
""" - return all(step in self.last_batch_received for step in self.data.keys()) + assert self._batch_manager, "Batch manager is not set" - def _last_batch_convergence_step(self) -> bool: - """Checks if the batch to be created is the last one for a convergence step. `True` - if the last batch of all the steps (`batch_routed_to`) in the last routed batch - have been received. + # Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence + # step if the batch is the last one, so they stop their processing loop even if + # they haven't received the last batch because of the routing function. + if self._is_convergence_step(batch.step_name) and batch.last_batch: + for step_name in self.dag.get_step_predecessors(batch.step_name): + self._send_last_batch_flag_to_step(step_name) - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. - """ - grouped_batches = self._group_batches_by_created_from() - if not grouped_batches: - return False - _, batches = grouped_batches[0] + route_to, routed = self._get_successors(batch) - for batch, _ in batches: - if not batch.last_batch: - return False + # Keep track of the steps that the batch was routed to + if routed: + batch.batch_routed_to = route_to - if len(batch.data[0]) > self.input_batch_size: # type: ignore - return False + self._register_batch(batch) - return True + step = self._get_step_from_batch(batch) - def _last_batch_normal(self) -> bool: - """Checks if the batch to be created is the last one for a normal step. `True` if - there is no more data to be received from the predecessors. + # Add the batch to the successors input buffers + for successor in route_to: + # Copy batch to avoid modifying the same reference in the batch manager + batch_to_add = batch.copy() if len(route_to) > 1 else batch - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. - """ - for step_name, batches in self.data.items(): - if step_name not in self.last_batch_received: - return False - - num_rows = sum(len(batch.data[0]) for batch in batches) + self._batch_manager.add_batch(successor, batch_to_add) + # Check if the step is a generator and if there are successors that need data + # from this step. This usually happens when the generator `batch_size` is smaller + # than the `input_batch_size` of the successor steps. if ( - self.input_batch_size - and num_rows > self.input_batch_size - and step_name in self.last_batch_received + step.is_generator + and step.name in self._batch_manager.step_empty_buffers(successor) ): - return False - - return True - - def _group_batches_by_created_from( - self, - ) -> List[Tuple[int, List[Tuple["_Batch", int]]]]: - """Group the batches by the first key of `created_from` field. This method is - meant to be used only with a `convergence_step`. + last_batch_sent = self._batch_manager.get_last_batch_sent(step.name) + self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore + + # If successor step has enough data in its buffer to create a new batch, then + # send the batch to the step. 
+ if new_batch := self._batch_manager.get_batch(successor): + self._send_batch_to_step(new_batch) + + if not step.is_generator: + # Step ("this", the one from which the batch was received) has enough data on its + # buffers to create a new batch + if new_batch := self._batch_manager.get_batch(step.name): # type: ignore + self._send_batch_to_step(new_batch) + else: + self._request_more_batches_if_needed(step) - Returns: - A list of the batches grouped by the `seq_no` of the first step name in `created_from`. - The list is sorted by the `seq_no`. - """ - grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list) - for batches in self.data.values(): - for batch in batches: - first_key = next(iter(batch.created_from)) - batch_seq_no, batch_size = batch.created_from[first_key][0] - grouped_batches[batch_seq_no].append((batch, batch_size)) - return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items()) + self._cache() - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: - """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function. + def _send_to_step(self, step_name: str, to_send: Any) -> None: + """Sends something to the input queue of a step. Args: - obj: Unused, just kept to match the signature of the parent method. - kwargs: Additional arguments that are kept to match the signature of the parent method. - - Returns: - Internal representation of the `_BatchManagerStep`. + step_name: The name of the step. + to_send: The object to send. """ - return asdict(self) - - -LAST_BATCH_SENT_FLAG = "last_batch_sent" + input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME] + input_queue.put(to_send) + def _send_batch_to_step(self, batch: "_Batch") -> None: + """Sends a batch to the input queue of a step, writing the data of the batch + to the filesystem and setting `batch.data_path` with the path where the data + was written (if requiered i.e. the step is a global step or `use_fs_to_pass_data`) -class _BatchManager(_Serializable): - """Class to manage the batches received from the steps. It keeps track of the - received batches and returns new batches for the steps to process based on their - input batch size and the batches received from the predecessors. - - Attributes: - steps: A dictionary with the step name as the key and a `_BatchManagerStep` - instance as the value. - last_batch_received: A dictionary with the step name as the key and a flag to - indicate whether we received the last batch from the step. - """ - - def __init__( - self, - steps: Dict[str, _BatchManagerStep], - last_batch_received: Dict[str, Union[_Batch, None]], - last_batch_sent: Dict[str, Union[_Batch, None]], - last_batch_flag_sent_to: List[str], - ) -> None: - """Initialize the `_BatchManager` instance. + This method should be extended by the specific pipeline implementation, adding + the logic to send the batch to the step. Args: - steps: A dictionary with the step name as the key and a dictionary with the - predecessor step name as the key and a list of batches as the value. - last_batch_received: A dictionary with the step name as the key and a the last - `_Batch` received from the step. - last_batch_sent: A dictionary with the step name as the key and a the last - `_Batch` sent to the step. - last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG` - was sent. + batch: The batch to send. 
""" + self._logger.debug( + f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}" + ) + self._batch_manager.set_last_batch_sent(batch) # type: ignore - self._steps = steps - self._last_batch_received = last_batch_received - self._last_batch_sent = last_batch_sent - self._last_batch_flag_sent_to = last_batch_flag_sent_to - - def can_generate(self) -> bool: - """Checks if there are still batches to be processed by the steps. - - Returns: - `True` if there are still batches to be processed by the steps. Otherwise, - `False`. - """ - - for step_name, batch in self._last_batch_received.items(): - if step_name not in self._last_batch_flag_sent_to: - if not batch: - return True - - if not batch.last_batch: - return True + step: "_Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] + if not step.is_generator and (step.is_global or self._use_fs_to_pass_data): + base_path = UPath(self._storage_base_path) / step.name # type: ignore + self._logger.debug( + f"Writing {batch.seq_no} batch for '{batch.step_name}' step to filesystem: {base_path}" + ) + batch.write_batch_data_to_fs(self._fs, base_path) # type: ignore - if not self.get_last_batch_sent(step_name): - return True + self._logger.debug( + f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" + ) - return False + self._send_to_step(batch.step_name, batch) - def register_batch(self, batch: _Batch) -> None: - """Method to register a batch received from a step. It will keep track of the - sequence number and the last batch received from the step in the internal maps. + def _register_batch(self, batch: "_Batch") -> None: + """Registers a batch in the batch manager. Args: - batch: _Batch from which we will register the sequence number and the last batch received. + batch: The batch to register. """ - self._last_batch_received[batch.step_name] = batch + self._batch_manager.register_batch(batch) # type: ignore + self._logger.debug( + f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch" + " manager" + ) - def get_last_batch(self, step_name: str) -> Union[_Batch, None]: - """Gets the last batch received from a step. + def _send_last_batch_flag_to_step(self, step_name: str) -> None: + """Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches. Args: step_name: The name of the step. - - Returns: - The last batch received from the step or `None` if no batch was received. """ - return self._last_batch_received.get(step_name) - - def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None: - """Add an output batch from `batch.step_name` to `to_step`. + batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore + if batch and batch.last_batch: + return - Args: - to_step: The name of the step that will process the batch. - batch: The output batch of an step to be processed by `to_step`. - prepend: If `True`, the content of the batch will be added at the start of - the buffer. - - Raises: - ValueError: If `to_step` is not found in the batch manager. - """ - if to_step not in self._steps: - raise ValueError(f"Step '{to_step}' not found in the batch manager.") - - step = self._steps[to_step] - step.add_batch(batch, prepend) + self._logger.debug( + f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing" + " batches..." + ) - def get_batch(self, step_name: str) -> Union[_Batch, None]: - """Get the next batch to be processed by the step. 
+ self._send_to_step(step_name, LAST_BATCH_SENT_FLAG) + self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore - Args: - step_name: The name of the step that will process the batch. + def _request_initial_batches(self) -> None: + """Requests the initial batches to the generator steps.""" + assert self._batch_manager, "Batch manager is not set" - Returns: - A `_Batch` instance if there is a batch to be processed by the step. Otherwise, - `None`. - """ - if step_name not in self._steps: - raise ValueError(f"Step '{step_name}' not found in the batch manager.") + for step in self._batch_manager._steps.values(): + if batch := step.get_batch(): + self._logger.debug( + f"Sending initial batch to '{step.step_name}' step: {batch}" + ) + self._send_batch_to_step(batch) - return self._steps[step_name].get_batch() + for step_name in self.dag.root_steps: + seq_no = 0 + if last_batch := self._batch_manager.get_last_batch(step_name): + seq_no = last_batch.seq_no + 1 + batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=self._dry_run) + self._logger.debug( + f"Requesting initial batch to '{step_name}' generator step: {batch}" + ) + self._send_batch_to_step(batch) - def step_empty_buffers(self, step_name: str) -> List[str]: - """Checks if the input buffer for a step is empty. + def _request_more_batches_if_needed(self, step: "Step") -> None: + """Request more batches to the predecessors steps of `step` if needed. Args: - step_name: The name of the step. - - Returns: - The name of the previous steps for which the input buffer for this step is - empty. + step: The step of which it has to be checked if more batches are needed from + its predecessors. """ - return self._steps[step_name].empty_buffers() + empty_buffers = self._batch_manager.step_empty_buffers(step.name) # type: ignore + for previous_step_name in empty_buffers: + # Only more batches can be requested to the `GeneratorStep`s as they are the + # only kind of steps that lazily generate batches. + if previous_step_name not in self.dag.root_steps: + continue - def set_last_batch_sent(self, batch: "_Batch") -> None: - """Set the last batch sent to a step. + # Get the last batch that the previous step sent to generate the next batch + # (next `seq_no`). + last_batch = self._batch_manager.get_last_batch_sent(previous_step_name) # type: ignore + if last_batch is None: + continue - Args: - batch: The last batch sent to a step. - """ - self._last_batch_sent[batch.step_name] = batch + self._logger.debug( + f"Step '{step.name}' input buffer for step '{previous_step_name}' is" + " empty. Requesting new batch..." + ) + self._send_batch_to_step(last_batch.next_batch()) - def get_last_batch_sent(self, step_name: str) -> Union["_Batch", None]: - """Get the last batch sent to a step. + def _handle_batch_on_stop(self, batch: "_Batch") -> None: + """Handles a batch that was received from the output queue when the pipeline was + stopped. It will add and register the batch in the batch manager. Args: - step_name: The name of the step. - - Returns: - The last batch sent to a step or `None` if no batch was sent. + batch: The batch to handle. """ - return self._last_batch_sent.get(step_name, None) + assert self._batch_manager, "Batch manager is not set" - def set_last_batch_flag_sent_to(self, step_name: str) -> None: - """Set the flag to indicate that the last batch was sent to a step. 
+ self._batch_manager.register_batch(batch) + step: "Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] + for successor in self.dag.get_step_successors(step.name): # type: ignore + self._batch_manager.add_batch(successor, batch) - Args: - step_name: The name of the step. - """ - self._last_batch_flag_sent_to.append(step_name) - - @classmethod - def from_dag(cls, dag: "DAG") -> "_BatchManager": - """Create a `_BatchManager` instance from a `DAG` instance. + def _get_step_from_batch(self, batch: "_Batch") -> "Step": + """Gets the `Step` instance from a batch. Args: - dag: The `DAG` instance. + batch: The batch to get the step from. Returns: - A `_BatchManager` instance. + The `Step` instance. """ - steps = {} - last_batch_received = {} - last_batch_sent = {} - for step_name in dag: - step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME] - last_batch_received[step.name] = None - last_batch_sent[step.name] = None - if step.is_generator: - continue - predecessors = list(dag.get_step_predecessors(step_name)) - convergence_step = all( - dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False) - for predecessor in predecessors - ) - batch_manager_step = _BatchManagerStep.from_step( - step=step, - predecessors=predecessors, - convergence_step=convergence_step, - ) - steps[step_name] = batch_manager_step - return cls(steps, last_batch_received, last_batch_sent, []) + return self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: - """Dumps the content of the `_BatchManager` to a dictionary. - - Args: - obj (Any): Unused, just kept to match the signature of the parent method. - kwargs (Any): Additional arguments that are kept to match the signature of the parent method. - - Returns: - Dict[str, Any]: Internal representation of the `_BatchManager`. - """ - return { - "steps": {name: step.dump() for name, step in self._steps.items()}, - "last_batch_received": { - step_name: batch.dump() if batch is not None else None - for step_name, batch in self._last_batch_received.items() - }, - "last_batch_sent": { - step_name: batch.dump() if batch is not None else None - for step_name, batch in self._last_batch_sent.items() - }, - "last_batch_flag_sent_to": self._last_batch_flag_sent_to, - } + def _notify_steps_to_stop(self) -> None: + """Notifies the steps to stop their infinite running loop by sending `None` to + their input queues.""" + for step_name in self.dag: + self._send_to_step(step_name, None) - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "_BatchManager": - """Loads a `_BatchManager` from its serialized content in a dictionary. + def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]: + """Gets the successors and the successors to which the batch has to be routed. Args: - data: The serialized batch manager. + batch: The batch to which the successors will be determined. Returns: - A `_BatchManager` instance. 
- """ - # Remove the type info, we already know its a `_BatchManager`, and there aren't subclasses of it - data.pop(TYPE_INFO_KEY) - # Also there is only one type of `_BatchManagerStep`, so we can call it directly instead of generically - # via `_get_module_attr` - return cls( - { - name: _BatchManagerStep.from_file(step_path) - for name, step_path in data["steps"].items() - }, - { - step_name: _Batch.from_dict(batch) if batch is not None else None - for step_name, batch in data["last_batch_received"].items() - }, - { - step_name: _Batch.from_dict(batch) if batch is not None else None - for step_name, batch in data["last_batch_sent"].items() - }, - data["last_batch_flag_sent_to"], - ) - - def save( - self, - path: Union["StrOrPath", None] = None, - format: "SaveFormats" = "json", - dump: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - """Overrides the parent method to save the each `_BatchManagerStep` to a file, and the contents - keep in the `_BatchManager` dump the paths to those files. - - Note: - Not expected to be used directly, but through the `Pipeline._cache` class. - - Args: - path: filename of the object to save. If a folder is given, will create the object - inside. If None is given, the file will be created at the current - working directory. Defaults to None. - format: the format to use when saving the file. Valid options are 'json' and - 'yaml'. Defaults to `"json"`. - dump: the serialized object to save. If None, the object will be serialized using - the default self.dump. This variable is here to allow extra customization, in - general should be set as None. + The successors to route the batch to and whether the batch was routed using + a routing function. """ - path = Path(path) - dump = self.dump() - batch_manager_step_files = {} - # Do this to avoid modifying the dictionary while iterating over it - batch_manager_steps = set(dump["steps"].keys()) - for step_name in batch_manager_steps: - step_dump = dump["steps"].pop(step_name) - filename = str(path.parent / f"batch_manager_steps/{step_name}.json") - batch_manager_step_files[step_name] = filename - super().save(path=filename, format=format, dump=step_dump) - dump["steps"] = batch_manager_step_files - super().save(path=path, format=format, dump=dump) - - -class _WriteBuffer: - """Class in charge of sending the batched contents to a buffer and writing - those to files under a given folder. - - As batches are received, they are added to the buffer and once each buffer - is full, the content is written to a parquet file. - """ + node = self.dag.get_step(batch.step_name) + step: "Step" = node[STEP_ATTR_NAME] + successors = list(self.dag.get_step_successors(step.name)) # type: ignore + route_to = successors + + # Check if the step has a routing function to send the batch to specific steps + if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME): + route_to = routing_batch_function(batch, successors) + successors_str = ", ".join(f"'{successor}'" for successor in route_to) + self._logger.info( + f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}" + ) - def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None: - """ - Args: - path: Folder where the files will be written, the idea - is for this path to be in the cache folder under /data. - leaf_steps: Leaf steps from either the DAG of the Pipeline. + return route_to, route_to != successors - Raises: - ValueError: If the path is not a directory. 
- """ - self._path = Path(path) - if not self._path.exists(): - self._path.mkdir(parents=True, exist_ok=True) - for step in leaf_steps: - (self._path / step).mkdir(parents=True, exist_ok=True) + @abstractmethod + def _stop(self) -> None: + """Stops the pipeline in a controlled way.""" + pass - if not self._path.is_dir(): - raise ValueError(f"The path should be a directory, not a file: {path}") + def _stop_load_queue_loop(self) -> None: + """Stops the `_load_queue` loop sending a `None`.""" + self._logger.debug("Sending `None` to the load queue to notify stop...") + self._load_queue.put(None) - self._buffers: Dict[str, List[Dict[str, Any]]] = { - step: [] for step in leaf_steps - } - # TODO: make this configurable - self._buffers_dump_batch_size: Dict[str, int] = { - step: 50 for step in leaf_steps - } - self._buffer_last_schema = {} - self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps} - self._logger = logging.getLogger("distilabel.write_buffer") + def _stop_output_queue_loop(self) -> None: + """Stops the `_output_queue` loop sending a `None`.""" + self._logger.debug("Sending `None` to the output queue to notify stop...") + self._output_queue.put(None) - def _get_filename(self, step_name: str) -> Path: - """Creates the filename for the step. + def _handle_keyboard_interrupt(self) -> Any: + """Handles KeyboardInterrupt signal sent during the Pipeline.run method. - Args: - step_name: Name of the step to which the data belongs to. - - Returns: - Filename for the step. - """ - return self._path / f"{step_name}.parquet" - - def is_full(self, step_name: str) -> bool: - """Checks the buffers that are full so that those can be written to the file. + It will try to call self._stop (if the pipeline didn't started yet, it won't + have any effect), and if the pool is already started, will close it before exiting + the program. Returns: - Whether the buffer is full. + The original `signal.SIGINT` handler. """ - return len(self._buffers[step_name]) >= self._buffers_dump_batch_size[step_name] - def add_batch(self, batch: "_Batch") -> None: - """Adds a batch to the buffer and writes the buffer to the file if it's full. + def signal_handler(signumber: int, frame: Any) -> None: + self._stop() - Args: - batch: batch to add to the buffer. - """ - step_name = batch.step_name - data = batch.data[0] - self._buffers[step_name].extend(data) - self._logger.debug( - f"Added batch to write buffer for step '{step_name}' with {len(data)} rows." - ) - if self.is_full(step_name): - self._logger.debug( - f"Buffer for step '{step_name}' is full (rows: {len(self._buffers[step_name])}," - f" full: {self._buffers_dump_batch_size[step_name]}), writing to file..." - ) - self._write(step_name) - - def _write(self, step_name: str) -> None: - """Writes the content to the file and cleans the buffer. - - Args: - step_name (str): Name of the step to which the data pertains. - """ - step_parquet_dir = Path(self._path, step_name) - if not step_parquet_dir.exists(): - self._logger.debug( - f"Creating directory for step '{step_name}' parquet files..." 
- ) - step_parquet_dir.mkdir() - - table = pa.Table.from_pylist(self._buffers[step_name]) - - last_schema = self._buffer_last_schema.get(step_name) - if last_schema is None: - self._buffer_last_schema[step_name] = table.schema - else: - if not last_schema.equals(table.schema): - new_schema = pa.unify_schemas([last_schema, table.schema]) - self._buffer_last_schema[step_name] = new_schema - table = table.cast(new_schema) - - next_file_number = self._buffers_last_file[step_name] - self._buffers_last_file[step_name] = next_file_number + 1 - - parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet" - pq.write_table(table, parquet_file) - self._logger.debug(f"Written to file '{parquet_file}'") - - self._clean_buffer(step_name) - - def _clean_buffer(self, step_name: str) -> None: - """Cleans the buffer by setting it's content to `None`. - - Args: - step_name: The name of the buffer to clean. - """ - self._buffers[step_name] = [] - - def close(self) -> None: - """Closes the buffer by writing the remaining content to the file.""" - for step_name in self._buffers: - if self._buffers[step_name]: - self._write(step_name) - - # We need to read the parquet files and write them again to ensure the schema - # is correct. Otherwise, the first parquets won't have the last schema and - # then we will have issues when reading them. - for file in list_files_in_dir(self._path / step_name): - table = pq.read_table(file, schema=self._buffer_last_schema[step_name]) - pq.write_table(table, file) + return signal.signal(signal.SIGINT, signal_handler) diff --git a/src/distilabel/pipeline/batch.py b/src/distilabel/pipeline/batch.py new file mode 100644 index 0000000000..d8ad4312ae --- /dev/null +++ b/src/distilabel/pipeline/batch.py @@ -0,0 +1,233 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import hashlib +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +import fsspec +import pyarrow as pa +import pyarrow.parquet as pq +from upath import UPath + +from distilabel.utils.serialization import _Serializable + + +@dataclass +class _Batch(_Serializable): + """Dataclass to represent a batch of data to be processed by a `_Step`. + + Attributes: + seq_no: The sequence number of the batch. + step_name: The name of the step that will process the batch. + last_batch: A flag to indicate if the batch is the last one. + data: The data to be processed. + data_hash: The hash of the data. Defaults to `None`. + data_path: The path where the data of the batch is stored. Defaults to `None`. + accumulated: A flag to indicate if the batch is accumulated. + created_from: A dictionary containing the `seq_no` of the batches of the steps that + were used to create this batch. + size: The size of the batch. 
+ """ + + seq_no: int + step_name: str + last_batch: bool + data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False) + data_hash: Optional[str] = None + data_path: Optional[str] = None + accumulated: bool = False + created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict) + batch_routed_to: List[str] = field(default_factory=list) + size: int = 0 + _fs: Optional[fsspec.AbstractFileSystem] = None + + def next_batch(self) -> "_Batch": + """Create a new `_Batch` instance with the next batch of data. + + Args: + data: The data to be processed. + + Returns: + A `_Batch` instance. + """ + return _Batch( + seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch + ) + + def set_data(self, data: List[List[Dict[str, Any]]]) -> None: + """Sets the data of the batch and updates the size of the batch. + + Args: + data: The data of the batch. + """ + self.data = data + self.size = len(data[0]) + self._update_data_hash() + + def get_data(self, num_rows: Union[int, None] = None) -> List[Dict[str, Any]]: + """Takes `num_rows` from the data of the batch and returns it. This method will + also remove the data from the batch and update the hash of the data. + + Args: + num_rows: The number of rows to take from the data. If `None`, then all the + data will be taken. Defaults to `None`. + + Returns: + A list with the data taken from the batch. + """ + + if self.data == [] and self.data_path is not None: + pass + + if num_rows is None: + data = self.data[0] + self.data = [] + else: + data = self.data[0][:num_rows] + self.data[0] = self.data[0][num_rows:] + + self._update_data_hash() + return data + + def _update_data_hash(self) -> None: + """Updates the hash of the data of the batch.""" + self.data_hash = hashlib.sha1(str(self.data).encode()).hexdigest() + + @classmethod + def accumulate(cls, step_name: str, batches: List[List["_Batch"]]) -> "_Batch": + """Creates a `_Batch` instance using the data from the list of batches that + were received from another steps. The batches will be accumulated in a single + list of data. + + Args: + step_name: The name of the step that will process the batch. + batches: a list containing the list of batches received from the predecessors. + + Returns: + A `_Batch` instance. + """ + + data = [] + for step_batches in batches: + accumulated_data = [row for batch in step_batches for row in batch.data[0]] + data.append(accumulated_data) + return cls( + seq_no=0, step_name=step_name, last_batch=True, data=data, accumulated=True + ) + + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: + """Dumps the content of the `_Batch` to a dictionary, using the `dataclass` helper function. + + Args: + obj: Unused, just kept to match the signature of the parent method. + kwargs: Additional arguments that are kept to match the signature of the parent method. + + Returns: + A `dict` containing the internal representation of the `_Batch`. + """ + + include_batch_data = kwargs.get("include_batch_data", True) + + dump = { + "seq_no": self.seq_no, + "step_name": self.step_name, + "last_batch": self.last_batch, + "data_hash": self.data_hash, + "accumulated": self.accumulated, + "created_from": self.created_from, + "batch_routed_to": self.batch_routed_to, + "size": self.size, + } + + if include_batch_data: + dump["data"] = self.data + + return dump + + def copy(self) -> "_Batch": + """Creates a copy of the `_Batch` instance. + + Returns: + A copy of the `_Batch` instance. 
+ """ + return copy.deepcopy(self) + + def write_batch_data_to_fs( + self, + fs: Optional[fsspec.AbstractFileSystem] = None, + base_path: Optional[UPath] = None, + ) -> None: + """Writes the content of the batch to the filesystem. + + Args + fs: The `fsspec` filesystem to be used to write the data. If not provided, the + one set in the `_fs` attribute will be used. Defaults to `None`. + base_path: The base path where the data of the batch will be stored. If not + provided, the one set in the `data_path` attribute will be used. Defaults + to `None`. + + Raises: + ValueError: If `fs` is not provided and the `_fs` attribute is not set. + """ + + if not fs and not self._fs: + raise ValueError( + "The `fs` parameter must be provided if the `_fs` attribute is not set." + ) + + if fs: + self._fs = fs + + if not base_path and not self.data_path: + raise ValueError( + "The `base_path` parameter must be provided if the `data_path` attribute" + " is not set." + ) + + seq_no_dir = ( + base_path / f"seq_no_{self.seq_no}" if base_path else UPath(self.data_path) + ) + seq_no_dir._fs_cached = self._fs # type: ignore + seq_no_dir.mkdir(parents=True, exist_ok=True) + + for i, data in enumerate(self.data): + table = pa.Table.from_pylist(data) + with self._fs.open(seq_no_dir / f"data_index_{i}.parquet", "wb") as f: # type: ignore + pq.write_table(table, f) + + self.data = [] + self.data_path = str(seq_no_dir) + + def read_batch_data_from_fs(self) -> None: + """Reads the content of the batch from the filesystem.""" + if not self.data_path: + raise ValueError( + "`data_path` attribute must be set to read the data from the filesystem." + " Use `write_batch_data_to_fs` method to set the `data_path` attribute." + ) + + if not self._fs: + raise ValueError( + "`_fs` attribute must be set to read the data from the filesystem." + " Use `write_batch_data_to_fs` method to set the `_fs` attribute." + ) + + for file in self._fs.ls(self.data_path): + with self._fs.open(file, "rb") as f: + table = pq.read_table(f) + self.data.append(table.to_pylist()) + + self._fs.rm(self.data_path, recursive=True) diff --git a/src/distilabel/pipeline/batch_manager.py b/src/distilabel/pipeline/batch_manager.py new file mode 100644 index 0000000000..cc14f0dd21 --- /dev/null +++ b/src/distilabel/pipeline/batch_manager.py @@ -0,0 +1,896 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union + +from distilabel.pipeline._dag import DAG +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.constants import ( + RECEIVES_ROUTED_BATCHES_ATTR_NAME, + STEP_ATTR_NAME, +) +from distilabel.steps.base import _Step +from distilabel.utils.files import list_files_in_dir +from distilabel.utils.serialization import ( + StrOrPath, + _check_is_dir, + _Serializable, + read_json, +) + +if TYPE_CHECKING: + from distilabel.utils.serialization import StrOrPath + + +@dataclass +class _BatchManagerStep(_Serializable): + """A class that will accumulate data for a step from the predecessors and create + batches for the step to process when there is enough data. + + Attributes: + step_name: The name of the step that will process the data. + accumulate: A flag to indicate if the data should be accumulated and create a + batch with all the data received from the predecessors instead of creating + batches with the `input_batch_size`. + input_batch_size: The size of the batch to be created for the step to process. + If `None`, then `accumulate` must be `True`. Defaults to `None`. + data: A dictionary with the predecessor step name as the key and a list of + dictionaries (rows) as the value. + built_batches: A list with the batches that were built and sent to the step queue, + but the step was stopped before processing the batch, so the batch doesn't get + lost. Defaults to an empty list. + seq_no: The sequence number of the next batch to be created. It will be + incremented for each batch created. + last_batch_received: A list with the names of the steps that sent the last + batch of data. + convergence_step: A flag to indicate if the step is a convergence step. An + `Step` is a convergence step if all its predecessors are receiving routed + batches. Defaults to `False`. + convergence_step_batches_consumed: A dictionary in which the key is the `seq_no` + of the batch created by step A, that was used by step B and C and obtained from + the `created_from` of the batches created by them. It's used to know if all + the batches from B and C steps created from batches of A have been consumed + by D, in order to not mess up the order of the batches. Only used if `convergence_step=True`. + Defaults to an empty dictionary. + next_expected_created_from_batch_seq_no: The next expected sequence number of the + batch from step A used by steps B and C and obtained from the `created_from` + of the batches created by them. It's used to avoid messing up the order of the + batches. Only used if `convergence_step=True`. Defaults to `0`. + """ + + step_name: str + accumulate: bool + input_batch_size: Union[int, None] = None + data: Dict[str, List[_Batch]] = field(default_factory=dict) + built_batches: List[_Batch] = field(default_factory=list) + seq_no: int = 0 + last_batch_received: List[str] = field(default_factory=list) + convergence_step: bool = False + convergence_step_batches_consumed: Dict[str, Dict[str, int]] = field( + default_factory=dict + ) + next_expected_created_from_batch_seq_no: int = 0 + + def add_batch(self, batch: _Batch, prepend: bool = False) -> None: + """Add a batch of data from `batch.step_name` to the step. It will accumulate the + data and keep track of the last batch received from the predecessors. + + Args: + batch: The output batch of an step to be processed by the step. 
+ prepend: If `True`, the content of the batch will be added to the `built_batches` + list. This is done so if a `_Batch` was already built and send to the step + queue, and the step is stopped before processing the batch, the batch doesn't + get lost. Defaults to `False`. + """ + from_step = batch.step_name + + if prepend: + self.built_batches.append(batch) + else: + self.data[from_step].append(batch) + + if batch.last_batch: + self.last_batch_received.append(from_step) + + def get_batch(self) -> Union[_Batch, None]: + """Create a new batch of data for the step to process. It will return `None` if + there is not enough data to create a batch. + + Returns: + A `_Batch` instance if there is enough data to create a batch. Otherwise, + `None`. + """ + if not self._ready_to_create_batch(): + return None + + # If there are batches in the `built_batches` list, then return the first one + # and remove it from the list. + if self.built_batches: + return self.built_batches.pop(0) + + # `_last_batch` must be called before `_get_data`, as `_get_data` will update the + # list of data which is used to determine if the batch to be created is the last one. + # TODO: remove `_last_batch` method and integrate logic in `_get_data` + last_batch = self._last_batch() + data, created_from, batch_routed_to = self._get_data() + + return _Batch( + seq_no=self._get_seq_no(), + step_name=self.step_name, + last_batch=last_batch, + data=data, + accumulated=self.accumulate, + created_from=created_from, + batch_routed_to=batch_routed_to, + ) + + def empty_buffers(self) -> List[str]: + """Checks if the input buffer for the step is empty. + + Returns: + The name of the previous steps for which the input buffer for this step is + empty. + """ + if self.accumulate: + return [ + previous_step + for previous_step in self.data.keys() + if previous_step not in self.last_batch_received + ] + + return [ + previous_step + for previous_step, batches in self.data.items() + if previous_step not in self.last_batch_received + and sum(len(batch.data[0]) for batch in batches) < self.input_batch_size # type: ignore + ] + + @classmethod + def from_step( + cls, step: "_Step", predecessors: Iterable[str], convergence_step: bool = False + ) -> "_BatchManagerStep": + """Creates a `_BatchManagerStep` instance from a `_Step` instance and its + predecessors. + + Args: + step: The `_Step` instance. + predecessors: The names of the predecessors of the step. + convergence_step: A flag to indicate if the step is a convergence step. An + `Step` is a convergence step if all its predecessors are receiving routed + batches. Defaults to `False`. + + Returns: + A `_BatchManagerStep` instance. + """ + return cls( + step_name=step.name, # type: ignore + accumulate=step.is_global, + input_batch_size=getattr(step, "input_batch_size", None), + data={predecessor: [] for predecessor in predecessors}, + convergence_step=convergence_step, + ) + + def _get_seq_no(self) -> int: + """Gets the sequence number for the next batch to be created and increments it. + + Returns: + The sequence number for the next batch to be created. + """ + seq_no = self.seq_no + self.seq_no += 1 + return seq_no + + def _get_data( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: + """Gets the data needed to create a batch for the step to process. If the step is + accumulating data, then it will return a list with all the data received from the + predecessors. 
Otherwise, it will return a list of data with the `input_batch_size` + for each predecessor. In addition, it will remove the data used to create the + batch from the step's data. + + Returns: + A tuple containing the list of data needed to create a batch for the step to + process, a dictionary with the sequence numbers of the batches that were used + to create the batch and the list of steps to which the batch was routed to if + the step is a normal step. + """ + if self.accumulate: + # Steps accumulating cannot receive routed batches + return self._get_data_for_accumulate() + ([],) + + if self.convergence_step: + # Convergence steps will receive routed batches, but we need to clean the + # `batch_routed_to` list + return self._get_data_for_convergence_step() + ([],) + + return self._get_data_normal() + + def _get_data_for_accumulate( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: + """Gets the data needed to create a batch for the step to process when the step + is accumulating data. It will return a list with all the data received from the + predecessors. In addition, it will remove the data used to create the batch from + the step's data. + + Returns: + A tuple containing the list of data needed to create a batch for the step to + process and a dictionary with the sequence numbers of the batches that were + used to create the batch. + """ + data = [] + batches_used = {} + for step_name, batches in self.data.items(): + batches_used[step_name] = [] + for batch in batches: + batches_used[step_name].append((batch.seq_no, batch.size)) + data.append([row for batch in batches for row in batch.get_data()]) + # Reset the data buffer + self.data = {step_name: [] for step_name in self.data} + return data, batches_used + + def _get_data_for_convergence_step( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: + """Gets the data needed to create a batch for the step to process when the step is + a convergence step. + + Returns: + A tuple containing the list of data needed to create a batch for the step to + process and a dictionary with the sequence numbers of the batches that were + used to create the batch. + """ + grouped_batches = self._group_batches_by_created_from() + seq_no, batches = grouped_batches[0] + str_seq_no = str(seq_no) + + remaining_rows_per_step: Dict[str, int] = { + step_name: self.input_batch_size + for step_name in self.data # type: ignore + } + batches_used = defaultdict(list) + data = defaultdict(list) + for batch, batch_size in batches: + remaining_rows = remaining_rows_per_step[batch.step_name] + selected_data = batch.get_data(remaining_rows) + data[batch.step_name].extend(selected_data) + + # If A -> [B, C] -> D, then in D (this step) we keep track of the remaining + # rows from the batches of A that B and C used to create the `batches`. 
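+            # Illustrative numbers (not from the diff): if step A's batch 0 had `size=5`
+            # and this convergence step now takes 3 rows that step B derived from it, the
+            # counter under `convergence_step_batches_consumed["0"]["B"]` goes from 5 to 2.
+            # Only once every counter for that seq_no reaches 0 is
+            # `next_expected_created_from_batch_seq_no` advanced further below.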
+ batch_size = self.convergence_step_batches_consumed.setdefault( + str_seq_no, {} + ).get(batch.step_name, batch_size) + remaining_rows_in_batch = batch_size - len(selected_data) + self.convergence_step_batches_consumed[str_seq_no].update( + {batch.step_name: remaining_rows_in_batch} + ) + + # Update the remaining rows + num_rows = len(selected_data) + remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore + + # Keep track of the batches used to create the batch + batches_used[batch.step_name].append((batch.seq_no, batch.size)) + + # If the batch was entirely consumed, then remove it from the buffer + if len(batch.data[0]) == 0: + self.data[batch.step_name].remove(batch) + continue + + # If all the batches grouped by the `seq_no` in `created_from` were consumed, then + # we can update the `next_expected_created_from_batch_seq_no` to the next one + # to avoid skipping batches. + no_remaining_rows = all( + count == 0 + for count in self.convergence_step_batches_consumed[str_seq_no].values() + ) + if no_remaining_rows: + self.next_expected_created_from_batch_seq_no += 1 + + return list(data.values()), dict(batches_used) + + def _get_data_normal( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: + """Gets the data needed to create a batch for the step to process when the step is + not accumulating data. It will return a list of data with the `input_batch_size` + for each predecessor. In addition, it will remove the data used to create the batch + from the step's data. + + Returns: + A tuple containing the list of data needed to create a batch for the step to + process, a dictionary with the sequence numbers of the batches that were used + to create the batch and the list of steps to which the batch was routed to if + the step is a convergence step. + """ + data = [] + batches_used = defaultdict(list) + batch_routed_to = [] + for step_name in self.data: + # For each step batches buffer, we will create a batch with the `input_batch_size` + # using the data from the buffer. We will remove the consumed batches (no data + # left) and update the batch data with the remaining data. + step_data = [] + idx_drop_batches = [] + remaining_rows: int = self.input_batch_size # type: ignore + for idx, batch in enumerate(self.data[step_name]): + if remaining_rows == 0: + break + + # Get `remaining_rows` or the remaining rows in the batch and add it to + # the step data that will be used to create the batch + selected_data = batch.get_data(remaining_rows) + step_data.extend(selected_data) + batch_routed_to = batch.batch_routed_to + + # Update the remaining rows + num_rows = len(selected_data) + remaining_rows -= num_rows + + # Keep track of the batches used to create the batch + batches_used[step_name].append((batch.seq_no, batch.size)) + + # If the batch was entirely consumed, then remove it from the buffer + if len(batch.data[0]) == 0: + idx_drop_batches.append(idx) + continue + + # Remove the batches that were entirely consumed + idx_drop_batches.reverse() + for idx in idx_drop_batches: + self.data[step_name].pop(idx) + + data.append(step_data) + + return data, dict(batches_used), batch_routed_to + + def _ready_to_create_batch(self) -> bool: + """Checks if there is enough data to create a batch for the step. + + Returns: + `True` if there is enough data to create a batch for the step. Otherwise, + `False`. 
+ """ + if self.accumulate: + return self._ready_to_create_batch_accumulate() + + if self.convergence_step: + return self._ready_to_create_batch_convergence_step() + + return self._ready_to_create_batch_normal() + + def _ready_to_create_batch_accumulate(self) -> bool: + """Checks if there is enough data for an step accumulating data. It will return + `True` if the last batch was received from all the predecessors. + + Returns: + `True` if ready to create a batch, `False` otherwise. + """ + return all( + step in self.last_batch_received + and sum(len(batch.data[0]) for batch in batches) >= 0 + for step, batches in self.data.items() + ) + + def _ready_to_create_batch_convergence_step(self) -> bool: + """Checks if there is enough data for creating a batch for an step in which output + batches that were generated by steps that received routed batches are received. + It will return `True`, if all the output batches that were generated from a routed + batch have been received. + + Returns: + `True` if ready to create a batch, `False` otherwise. + """ + grouped_batches = self._group_batches_by_created_from() + if not grouped_batches: + return False + seq_no, batches = grouped_batches[0] + + # If the `seq_no` from the `created_from` field is not the expected one, then + # we cannot create a batch yet or the order will be messed up + if seq_no != self.next_expected_created_from_batch_seq_no: + return False + + # Not all output batches to which the input batch was routed to haven't been + # received + batch_routed_to = batches[0][0].batch_routed_to + batches_received_from = {batch.step_name for batch, _ in batches} + if any(step_name not in batches_received_from for step_name in batch_routed_to): + return False + + # There are output batches to which the input batch was routed to from all + # the steps. Check if there is enough data for creating a batch with `input_batch_size` + rows_per_step = defaultdict(lambda: 0) + for batch, _ in batches: + num_rows = len(batch.data[0]) + rows_per_step[batch.step_name] += num_rows + + # If there aren't at least `input_batch_size` rows from each step, then there + # isn't enough data to create a batch + if not all( + num_rows >= self.input_batch_size or step_name in self.last_batch_received # type: ignore + for step_name, num_rows in rows_per_step.items() + ): + return False + + return True + + def _ready_to_create_batch_normal(self) -> bool: + """Checks if there is enough data for creating a batch for a normal step. It will + be `True` it there are at least `input_batch_size` rows from each predecessor step. + + Returns: + `True` if ready to create a batch, `False` otherwise. + """ + for step_name, batches in self.data.items(): + num_rows = sum(len(batch.data[0]) for batch in batches) + + # If there are now rows but the last batch was already received, then there + # are no more batch to be created + if num_rows == 0 and step_name in self.last_batch_received: + return False + + # If there are not enough rows and the last batch was not received yet, then + # there is not enough data yet to creata a batch + if ( + self.input_batch_size + and num_rows < self.input_batch_size + and step_name not in self.last_batch_received + ): + return False + + return True + + def _last_batch(self) -> bool: + """Checks if the batch to be created is the last one i.e. if the last batch was + received from all the predecessors. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. 
+ """ + if self.accumulate: + return self._last_batch_accumulate() + + if self.convergence_step: + return self._last_batch_convergence_step() + + return self._last_batch_normal() + + def _last_batch_accumulate(self) -> bool: + """Checks if the batch to be created is the last one for an step accumulating data. + `True` if the last batch was received from all the predecessors. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + return all(step in self.last_batch_received for step in self.data.keys()) + + def _last_batch_convergence_step(self) -> bool: + """Checks if the batch to be created is the last one for a convergence step. `True` + if the last batch of all the steps (`batch_routed_to`) in the last routed batch + have been received. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + grouped_batches = self._group_batches_by_created_from() + if not grouped_batches: + return False + _, batches = grouped_batches[0] + + for batch, _ in batches: + if not batch.last_batch: + return False + + if len(batch.data[0]) > self.input_batch_size: # type: ignore + return False + + return True + + def _last_batch_normal(self) -> bool: + """Checks if the batch to be created is the last one for a normal step. `True` if + there is no more data to be received from the predecessors. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + for step_name, batches in self.data.items(): + if step_name not in self.last_batch_received: + return False + + num_rows = sum(len(batch.data[0]) for batch in batches) + + if ( + self.input_batch_size + and num_rows > self.input_batch_size + and step_name in self.last_batch_received + ): + return False + + return True + + def _group_batches_by_created_from( + self, + ) -> List[Tuple[int, List[Tuple["_Batch", int]]]]: + """Group the batches by the first key of `created_from` field. This method is + meant to be used only with a `convergence_step`. + + Returns: + A list of the batches grouped by the `seq_no` of the first step name in `created_from`. + The list is sorted by the `seq_no`. + """ + grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list) + for batches in self.data.values(): + for batch in batches: + first_key = next(iter(batch.created_from)) + batch_seq_no, batch_size = batch.created_from[first_key][0] + grouped_batches[batch_seq_no].append((batch, batch_size)) + return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items()) + + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: + """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function. + + Args: + obj: Unused, just kept to match the signature of the parent method. + kwargs: Additional arguments that are kept to match the signature of the parent method. + + Returns: + Internal representation of the `_BatchManagerStep`. 
+ """ + return { + "step_name": self.step_name, + "accumulate": self.accumulate, + "input_batch_size": self.input_batch_size, + "data": { + step_name: [batch.dump(**kwargs) for batch in batches] + for step_name, batches in self.data.items() + }, + "built_batches": [batch.dump(**kwargs) for batch in self.built_batches], + "seq_no": self.seq_no, + "last_batch_received": self.last_batch_received, + "convergence_step": self.convergence_step, + "convergence_step_batches_consumed": self.convergence_step_batches_consumed, + "next_expected_created_from_batch_seq_no": self.next_expected_created_from_batch_seq_no, + } + + +class _BatchManager(_Serializable): + """Class to manage the batches received from the steps. It keeps track of the + received batches and returns new batches for the steps to process based on their + input batch size and the batches received from the predecessors. + + Attributes: + steps: A dictionary with the step name as the key and a `_BatchManagerStep` + instance as the value. + last_batch_received: A dictionary with the step name as the key and a flag to + indicate whether we received the last batch from the step. + """ + + def __init__( + self, + steps: Dict[str, _BatchManagerStep], + last_batch_received: Dict[str, Union[_Batch, None]], + last_batch_sent: Dict[str, Union[_Batch, None]], + last_batch_flag_sent_to: List[str], + ) -> None: + """Initialize the `_BatchManager` instance. + + Args: + steps: A dictionary with the step name as the key and a dictionary with the + predecessor step name as the key and a list of batches as the value. + last_batch_received: A dictionary with the step name as the key and a the last + `_Batch` received from the step. + last_batch_sent: A dictionary with the step name as the key and a the last + `_Batch` sent to the step. + last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG` + was sent. + """ + + self._steps = steps + self._last_batch_received = last_batch_received + self._last_batch_sent = last_batch_sent + self._last_batch_flag_sent_to = last_batch_flag_sent_to + + def can_generate(self) -> bool: + """Checks if there are still batches to be processed by the steps. + + Returns: + `True` if there are still batches to be processed by the steps. Otherwise, + `False`. + """ + + for step_name, batch in self._last_batch_received.items(): + if step_name not in self._last_batch_flag_sent_to: + if not batch: + return True + + if not batch.last_batch: + return True + + if not self.get_last_batch_sent(step_name): + return True + + return False + + def register_batch(self, batch: _Batch) -> None: + """Method to register a batch received from a step. It will keep track of the + sequence number and the last batch received from the step in the internal maps. + + Args: + batch: _Batch from which we will register the sequence number and the last batch received. + """ + self._last_batch_received[batch.step_name] = batch + + def get_last_batch(self, step_name: str) -> Union[_Batch, None]: + """Gets the last batch received from a step. + + Args: + step_name: The name of the step. + + Returns: + The last batch received from the step or `None` if no batch was received. + """ + return self._last_batch_received.get(step_name) + + def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None: + """Add an output batch from `batch.step_name` to `to_step`. + + Args: + to_step: The name of the step that will process the batch. + batch: The output batch of an step to be processed by `to_step`. 
+ prepend: If `True`, the content of the batch will be added at the start of + the buffer. + + Raises: + ValueError: If `to_step` is not found in the batch manager. + """ + if to_step not in self._steps: + raise ValueError(f"Step '{to_step}' not found in the batch manager.") + + step = self._steps[to_step] + step.add_batch(batch, prepend) + + def get_batch(self, step_name: str) -> Union[_Batch, None]: + """Get the next batch to be processed by the step. + + Args: + step_name: The name of the step that will process the batch. + + Returns: + A `_Batch` instance if there is a batch to be processed by the step. Otherwise, + `None`. + """ + if step_name not in self._steps: + raise ValueError(f"Step '{step_name}' not found in the batch manager.") + + return self._steps[step_name].get_batch() + + def step_empty_buffers(self, step_name: str) -> List[str]: + """Checks if the input buffer for a step is empty. + + Args: + step_name: The name of the step. + + Returns: + The name of the previous steps for which the input buffer for this step is + empty. + """ + return self._steps[step_name].empty_buffers() + + def set_last_batch_sent(self, batch: "_Batch") -> None: + """Set the last batch sent to a step. + + Args: + batch: The last batch sent to a step. + """ + self._last_batch_sent[batch.step_name] = batch + + def get_last_batch_sent(self, step_name: str) -> Union["_Batch", None]: + """Get the last batch sent to a step. + + Args: + step_name: The name of the step. + + Returns: + The last batch sent to a step or `None` if no batch was sent. + """ + return self._last_batch_sent.get(step_name, None) + + def set_last_batch_flag_sent_to(self, step_name: str) -> None: + """Set the flag to indicate that the last batch was sent to a step. + + Args: + step_name: The name of the step. + """ + self._last_batch_flag_sent_to.append(step_name) + + @classmethod + def from_dag(cls, dag: "DAG") -> "_BatchManager": + """Create a `_BatchManager` instance from a `DAG` instance. + + Args: + dag: The `DAG` instance. + + Returns: + A `_BatchManager` instance. + """ + steps = {} + last_batch_received = {} + last_batch_sent = {} + for step_name in dag: + step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME] + last_batch_received[step.name] = None + last_batch_sent[step.name] = None + if step.is_generator: + continue + predecessors = list(dag.get_step_predecessors(step_name)) + convergence_step = all( + dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False) + for predecessor in predecessors + ) + batch_manager_step = _BatchManagerStep.from_step( + step=step, + predecessors=predecessors, + convergence_step=convergence_step, + ) + steps[step_name] = batch_manager_step + return cls(steps, last_batch_received, last_batch_sent, []) + + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: + """Dumps the content of the `_BatchManager` to a dictionary. + + Args: + obj (Any): Unused, just kept to match the signature of the parent method. + kwargs (Any): Additional arguments that are kept to match the signature of the parent method. + + Returns: + Dict[str, Any]: Internal representation of the `_BatchManager`. 
+ """ + return { + "steps": {name: step.dump(**kwargs) for name, step in self._steps.items()}, + "last_batch_received": { + step_name: batch.dump(**kwargs) if batch is not None else None + for step_name, batch in self._last_batch_received.items() + }, + "last_batch_sent": { + step_name: batch.dump(**kwargs) if batch is not None else None + for step_name, batch in self._last_batch_sent.items() + }, + "last_batch_flag_sent_to": self._last_batch_flag_sent_to, + } + + def cache(self, path: "StrOrPath") -> None: + """Cache the `_BatchManager` to a file. + + Args: + path: The path to the file where the `_BatchManager` will be cached. If `None`, + then the `_BatchManager` will be cached in the default cache folder. + """ + + def save_batch( + batches_dir: Path, batch_dump: Dict[str, Any], batch_list: List[_Batch] + ) -> Path: + seq_no = batch_dump["seq_no"] + data_hash = batch_dump["data_hash"] + batch_file = batches_dir / f"batch_{seq_no}_{data_hash}.json" + + # Save the batch if it doesn't exist + if not batch_file.exists(): + # Get the data of the batch before saving it + batch = next(batch for batch in batch_list if batch.seq_no == seq_no) + batch_dump["data"] = batch.data + self.save(path=batch_file, format="json", dump=batch_dump) + + return batch_file + + def remove_files(keep_files: List[str], dir: Path) -> None: + files = list_files_in_dir(dir, key=None) + remove = set(files) - {Path(file) for file in keep_files} + for file in remove: + file.unlink() + + path = Path(path) + + # Do not include `_Batch` data so `dump` is fast + dump = self.dump(include_batch_data=False) + batch_manager_step_files = {} + + # Do this to avoid modifying the dictionary while iterating over it + batch_manager_steps = set(dump["steps"].keys()) + for step_name in batch_manager_steps: + step_dump = dump["steps"].pop(step_name) + + # Create a directory for each batch manager step to store their batches + batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name + batch_manager_step_dir.mkdir(parents=True, exist_ok=True) + + # Store each built `_Batch` in a separate file + built_batches_dir = batch_manager_step_dir / "built_batches" + built_batches_dir.mkdir(parents=True, exist_ok=True) + step_dump["built_batches"] = [ + str( + save_batch( + batches_dir=built_batches_dir, + batch_dump=batch_dump, + batch_list=self._steps[step_name].built_batches, + ) + ) + for batch_dump in step_dump["built_batches"] + ] + # Remove built `_Batch`es that were consumed from cache + remove_files(step_dump["built_batches"], built_batches_dir) + + # Store each `_BatchManagerStep` `_Batch`es in a separate file + for buffered_step_name in step_dump["data"]: + step_batches_dir = batch_manager_step_dir / buffered_step_name + step_batches_dir.mkdir(parents=True, exist_ok=True) + + # Store each `_Batch` in a separate file + step_dump["data"][buffered_step_name] = [ + str( + save_batch( + batches_dir=step_batches_dir, + batch_dump=batch_dump, + batch_list=self._steps[step_name].data[buffered_step_name], + ) + ) + for batch_dump in step_dump["data"][buffered_step_name] + ] + + # Remove `_Batch`es that were consumed from cache + remove_files(step_dump["data"][buffered_step_name], step_batches_dir) + + # Store the `_BatchManagerStep` info + batch_manager_step_file = str( + path.parent / f"batch_manager_steps/{step_name}/batch_manager_step.json" + ) + self.save(path=batch_manager_step_file, format="json", dump=step_dump) + + # Store the path to the `_BatchManagerStep` file + batch_manager_step_files[step_name] = 
batch_manager_step_file + + dump["steps"] = batch_manager_step_files + self.save(path=path, format="json", dump=dump) + + @classmethod + def load_from_cache(cls, path: "StrOrPath") -> "_BatchManager": + """Loads the `_BatchManager` from a cache file. + + Args: + path: The path to the cache file. + """ + _check_is_dir(path) + content = read_json(path) + + # Read each `_BatchManagerStep` from file + steps = {} + for step_name, step_file in content["steps"].items(): + steps[step_name] = read_json(step_file) + + # Read each `_Batch` from file + steps[step_name]["built_batches"] = [ + read_json(batch) for batch in steps[step_name]["built_batches"] + ] + + for buffered_step_name, batch_files in steps[step_name]["data"].items(): + steps[step_name]["data"][buffered_step_name] = [ + read_json(batch_file) for batch_file in batch_files + ] + + content["steps"] = steps + return cls.from_dict(content) diff --git a/src/distilabel/pipeline/constants.py b/src/distilabel/pipeline/constants.py index 450ef0ed6d..3d400e4a1b 100644 --- a/src/distilabel/pipeline/constants.py +++ b/src/distilabel/pipeline/constants.py @@ -19,3 +19,4 @@ RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches" ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function" CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step" +LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent" diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 5bb7fd7991..51986c031d 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -12,65 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import multiprocessing as mp import signal -import threading -import time +import sys import traceback -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast import tblib from distilabel.distiset import create_distiset from distilabel.llms.mixins import CudaDevicePlacementMixin from distilabel.pipeline.base import ( - LAST_BATCH_SENT_FLAG, BasePipeline, - _Batch, - _BatchManager, - _WriteBuffer, ) +from distilabel.pipeline.batch import _Batch from distilabel.pipeline.constants import ( - CONVERGENCE_STEP_ATTR_NAME, - INPUT_QUEUE_ATTR_NAME, - ROUTING_BATCH_FUNCTION_ATTR_NAME, - STEP_ATTR_NAME, + LAST_BATCH_SENT_FLAG, ) -from distilabel.steps.base import Step +from distilabel.steps.tasks.base import Task from distilabel.utils.logging import setup_logging, stop_logging if TYPE_CHECKING: - from multiprocessing.managers import DictProxy, SyncManager - from multiprocessing.pool import Pool from queue import Queue from distilabel.distiset import Distiset - from distilabel.steps.base import GeneratorStep, _Step - - -_STEPS_LOADED_KEY = "steps_loaded" -_STEPS_LOADED_LOCK_KEY = "steps_loaded_lock" -_STEPS_LOADED_ERROR_CODE = -1 -_CUDA_LLM_DEVICE_PLACEMENT_KEY = "cuda_llm_device_placement" -_CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY = "cuda_llm_device_placement_lock" - -_STOP_CALLED = False -_STOP_CALLED_LOCK = threading.Lock() -_STOP_CALLS = 0 - -_STEPS_LOADED = set() -_STEPS_LOADED_LOCK = threading.Lock() + from distilabel.pipeline.typing import StepLoadStatus + from distilabel.steps.base import GeneratorStep, Step, _Step -_STEPS_FINISHED = set() -_STEPS_FINISHED_LOCK = threading.Lock() _SUBPROCESS_EXCEPTION: Union[Exception, None] = None -def _init_worker(queue: "Queue[Any]") -> None: +def 
_init_worker(log_queue: "Queue[Any]") -> None: + """Init function for the child processes that will execute the `Step`s of the `Pipeline`. + + Args: + log_queue: The queue to send the logs to the main process. + """ signal.signal(signal.SIGINT, signal.SIG_IGN) - setup_logging(queue) + setup_logging(log_queue) class Pipeline(BasePipeline): @@ -80,6 +60,8 @@ def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, use_cache: bool = True, + storage_parameters: Optional[Dict[str, Any]] = None, + use_fs_to_pass_data: bool = False, ) -> "Distiset": """Runs the pipeline. @@ -88,6 +70,17 @@ def run( the runtime parameters for the step as the value. Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. + storage_parameters: A dictionary with the storage parameters (`fsspec` and path) + that will be used to store the data of the `_Batch`es passed between the + steps if `use_fs_to_pass_data` is `True` (for the batches received by a + `GlobalStep` it will be always used). It must have at least the "path" key, + and it can contain additional keys depending on the protocol. By default, + it will use the local file system and a directory in the cache directory. + Defaults to `None`. + use_fs_to_pass_data: Whether to use the file system to pass the data of + the `_Batch`es between the steps. Even if this parameter is `False`, the + `Batch`es received by `GlobalStep`s will always use the file system to + pass the data. Defaults to `False`. Returns: The `Distiset` created by the pipeline. @@ -96,37 +89,15 @@ def run( RuntimeError: If the pipeline fails to load all the steps. """ log_queue = mp.Queue() - # We must place the runtime parameters before calling setup_logging to ensure consistency - super().run(parameters, use_cache) - setup_logging(log_queue, filename=str(self._cache_location["log_file"])) # type: ignore - self._logger = logging.getLogger("distilabel.pipeline.local") - - if self._dry_run: - # This message is placed here to ensure we are using the already setup logger. - self._logger.info("🌵 Dry run mode") - - if self._batch_manager is None: - self._batch_manager = _BatchManager.from_dag(self.dag) - - # If the batch manager is not able to generate batches, that means that the loaded - # `_BatchManager` from cache didn't have any remaining batches to process i.e. - # the previous pipeline execution was completed successfully. - if not self._batch_manager.can_generate(): - self._logger.info( - "💾 Loaded batch manager from cache doesn't have any remaining data. Returning" - " `Distiset` from cache data..." 
- ) - stop_logging() - return create_distiset( - self._cache_location["data"], - pipeline_path=self._cache_location["pipeline"], - log_filename_path=self._cache_location["log_file"], - enable_metadata=self._enable_metadata, - ) - buffer_data_path = self._cache_location["data"] - self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'") - write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps) + self._set_logging_parameters( + {"log_queue": log_queue, "filename": self._cache_location["log_file"]} + ) + + if distiset := super().run( + parameters, use_cache, storage_parameters, use_fs_to_pass_data + ): + return distiset num_processes = len(self.dag) ctx = mp.get_context() # type: ignore @@ -135,17 +106,24 @@ def run( initializer=_init_worker, initargs=(log_queue,), ) as pool: - self.output_queue: "Queue[Any]" = manager.Queue() - self.shared_info = self._create_shared_info_dict(manager) - self._handle_keyboard_interrupt(manager=manager, pool=pool) + self._manager = manager + self._pool = pool + self._output_queue = self.QueueClass() + self._load_queue = self.QueueClass() + self._handle_keyboard_interrupt() # Run the steps using the pool of processes - self._run_steps_in_loop(pool, manager, self.output_queue, self.shared_info) + self._run_steps() + + # Run the loop for receiving the load status of each step + self._load_steps_thread = self._run_load_queue_loop_in_thread() # Wait for all the steps to be loaded correctly if not self._all_steps_loaded(): - write_buffer.close() + self._write_buffer.close() # type: ignore self._batch_manager = None + self._stop_load_queue_loop() + self._load_steps_thread.join() stop_logging() raise RuntimeError( "Failed to load all the steps. Could not run pipeline." @@ -156,15 +134,22 @@ def run( self._request_initial_batches() # Start a loop to receive the output batches from the steps - self._run_output_queue_loop_in_thread(write_buffer) + self._output_queue_thread = self._run_output_queue_loop_in_thread() + self._output_queue_thread.join() # Send `None` to steps `input_queue`s just in case some step is still waiting self._notify_steps_to_stop() - pool.close() - pool.join() + # Stop the load queue loop + self._stop_load_queue_loop() + + # `Pool.__exit__` has already called `terminate`, `join` the pool to make sure + # all the processes have finished + self._load_steps_thread.join() + self._pool.join() + self._manager.join() - write_buffer.close() + self._write_buffer.close() # type: ignore distiset = create_distiset( self._cache_location["data"], pipeline_path=self._cache_location["pipeline"], @@ -174,436 +159,35 @@ def run( stop_logging() return distiset - def _run_output_queue_loop_in_thread(self, write_buffer: "_WriteBuffer") -> None: - """Runs the output queue loop in a separate thread to receive the output batches - from the steps. This is done to avoid the signal handler to block the loop, which - would prevent the pipeline from stopping correctly. - - Args: - write_buffer: The write buffer to write the data from the leaf steps to disk. 
- """ - thread = threading.Thread(target=self._output_queue_loop, args=(write_buffer,)) - thread.start() - thread.join() - - def _notify_steps_to_stop(self) -> None: - """Notifies the steps to stop their infinite running loop by sending `None` to - their input queues.""" - for step_name in self.dag: - if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): - input_queue.put(None) - - def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: - """Loop to receive the output batches from the steps and manage the flow of the - batches through the pipeline. - - Args: - write_buffer: The write buffer to write the data from the leaf steps to disk. - """ - while self._batch_manager.can_generate() and not _STOP_CALLED: # type: ignore - self._logger.debug("Waiting for output batch from step...") - if (batch := self.output_queue.get()) is None: - self._logger.debug("Received `None` from output queue. Breaking loop.") - break - - if batch.step_name in self.dag.leaf_steps: - write_buffer.add_batch(batch) - - # If `_STOP_CALLED` was set to `True` while waiting for the output queue, then - # we need to handle the stop of the pipeline and break the loop to avoid - # propagating the batches through the pipeline and making the stop process - # slower. - if _STOP_CALLED: - self._handle_batch_on_stop(batch) - break - - self._logger.debug( - f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'" - f" from output queue: {batch}" - ) - - self._manage_batch_flow(batch) - - if _STOP_CALLED: - self._handle_stop(write_buffer) - - def _manage_batch_flow(self, batch: "_Batch") -> None: - """Checks if the step that generated the batch has more data in its buffer to - generate a new batch. If there's data, then a new batch is sent to the step. If - the step has no data in its buffer, then the predecessors generator steps are - requested to send a new batch. - - Args: - batch: The batch that was processed. - """ - assert self._batch_manager, "Batch manager is not set" - - # Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence - # step if the batch is the last one, so they stop their processing loop even if - # they haven't received the last batch because of the routing function. - if self._is_convergence_step(batch.step_name) and batch.last_batch: - for step_name in self.dag.get_step_predecessors(batch.step_name): - self._send_last_batch_flag_to_step(step_name) - - route_to, routed = self._get_successors(batch) - - # Keep track of the steps that the batch was routed to - if routed: - batch.batch_routed_to = route_to - - self._register_batch(batch) - - step = self._get_step_from_batch(batch) - - # Add the batch to the successors input buffers - for successor in route_to: - # Copy batch to avoid modifying the same reference in the batch manager - batch_to_add = batch.copy() if len(route_to) > 1 else batch - - self._batch_manager.add_batch(successor, batch_to_add) - - # Check if the step is a generator and if there are successors that need data - # from this step. This usually happens when the generator `batch_size` is smaller - # than the `input_batch_size` of the successor steps. - if ( - step.is_generator - and step.name in self._batch_manager.step_empty_buffers(successor) - ): - last_batch_sent = self._batch_manager.get_last_batch_sent(step.name) - self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore - - # If successor step has enough data in its buffer to create a new batch, then - # send the batch to the step. 
- if new_batch := self._batch_manager.get_batch(successor): - self._send_batch_to_step(new_batch) - - if step.is_generator: - return - - # Step ("this", the one from which the batch was received) has enough data on its - # buffers to create a new batch - if new_batch := self._batch_manager.get_batch(step.name): # type: ignore - self._send_batch_to_step(new_batch) - else: - self._request_more_batches_if_needed(step) - - self._cache() - - def _register_batch(self, batch: "_Batch") -> None: - """Registers a batch in the batch manager. - - Args: - batch: The batch to register. - """ - self._batch_manager.register_batch(batch) # type: ignore - self._logger.debug( - f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch" - " manager" - ) - - def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]: - """Gets the successors and the successors to which the batch has to be routed. - - Args: - batch: The batch to which the successors will be determined. - - Returns: - The successors to route the batch to and whether the batch was routed using - a routing function. - """ - node = self.dag.get_step(batch.step_name) - step: "Step" = node[STEP_ATTR_NAME] - successors = list(self.dag.get_step_successors(step.name)) # type: ignore - route_to = successors - - # Check if the step has a routing function to send the batch to specific steps - if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME): - route_to = routing_batch_function(batch, successors) - successors_str = ", ".join(f"'{successor}'" for successor in route_to) - self._logger.info( - f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}" - ) - - return route_to, route_to != successors - - def _get_step_from_batch(self, batch: "_Batch") -> "Step": - """Gets the `Step` instance from a batch. - - Args: - batch: The batch to get the step from. - - Returns: - The `Step` instance. - """ - return self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - - def _request_more_batches_if_needed(self, step: "Step") -> None: - """Request more batches to the predecessors steps of `step` if needed. - - Args: - step: The step of which it has to be checked if more batches are needed from - its predecessors. - """ - empty_buffers = self._batch_manager.step_empty_buffers(step.name) # type: ignore - for previous_step_name in empty_buffers: - if previous_step_name not in self.dag.root_steps: - continue - - last_batch = self._batch_manager.get_last_batch_sent(previous_step_name) # type: ignore - if last_batch is None: - continue - - self._logger.debug( - f"Step '{step.name}' input buffer for step '{previous_step_name}' is" - " empty. Requesting new batch..." - ) - self._send_batch_to_step(last_batch.next_batch()) - - def _handle_stop(self, write_buffer: "_WriteBuffer") -> None: - """Handles the stop of the pipeline execution, which will stop the steps from - processing more batches and wait for the output queue to be empty, to not lose - any data that was already processed by the steps before the stop was called. - - Args: - write_buffer: The write buffer to write the data from the leaf steps to disk. 
- """ - self._logger.debug("Handling stop of the pipeline execution...") - - # Add the remaining batches in the input queues back to the batch manager - for step_name in self.dag: - node = self.dag.get_step(step_name) - step: "_Step" = node[STEP_ATTR_NAME] - if step.is_generator: - continue - if input_queue := node.get(INPUT_QUEUE_ATTR_NAME): - while not input_queue.empty(): - batch = input_queue.get() - if batch is None: - continue - self._batch_manager.add_batch( # type: ignore - to_step=step_name, batch=batch, prepend=True - ) - self._logger.debug( - f"Adding batch back to the batch manager: {batch}" - ) - input_queue.put(None) - - # Wait for the input queue to be empty, which means that all the steps finished - # processing the batches that were sent before the stop flag. - for step_name in self.dag: - self._wait_step_input_queue_empty(step_name) - - # Consume the output queue until it's empty to not lose any data that was already - # processed by the steps before stop was called. - while not self.output_queue.empty(): - batch = self.output_queue.get() - if batch is None: - continue - - if batch.step_name in self.dag.leaf_steps: - write_buffer.add_batch(batch) - - self._handle_batch_on_stop(batch) - - self._cache() - - def _handle_batch_on_stop(self, batch: "_Batch") -> None: - """Handles a batch that was received from the output queue when the pipeline was - stopped. It will add and register the batch in the batch manager. - - Args: - batch: The batch to handle. - """ - self._batch_manager.register_batch(batch) # type: ignore - step: "Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - for successor in self.dag.get_step_successors(step.name): # type: ignore - self._batch_manager.add_batch(successor, batch) # type: ignore - - def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]: - """Waits for the input queue of a step to be empty. - - Args: - step_name: The name of the step. - - Returns: - The input queue of the step if it's not loaded or finished, `None` otherwise. - """ - if self._check_step_not_loaded_or_finished(step_name): - return None - - if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): - while input_queue.qsize() != 0: - pass - return input_queue - - def _create_shared_info_dict(self, manager: "SyncManager") -> "DictProxy[str, Any]": - """Creates the shared information dictionary to be used by the processes. - - Args: - manager: The manager to create the shared information. - - Returns: - The shared information dictionary. - """ - # TODO: not very important, but we could use a different lock for each matter - return manager.dict( - **{ - _STEPS_LOADED_KEY: manager.list(), - _STEPS_LOADED_LOCK_KEY: manager.Lock(), - _CUDA_LLM_DEVICE_PLACEMENT_KEY: manager.dict(**{}), - _CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY: manager.Lock(), - } - ) - - def _all_steps_loaded(self) -> bool: - """Waits for all the steps to load. + @property + def QueueClass(self) -> Callable: + """The callable used to create the input and output queues. Returns: - `True` if all the steps have been loaded correctly, `False` otherwise. 
- """ - - def _update_all_steps_loaded(steps_loaded: List[str]) -> None: - with _STEPS_LOADED_LOCK: - _STEPS_LOADED.update(steps_loaded) - - self._logger.info("⏳ Waiting for all the steps to load...") - previous_message = None - while not _STOP_CALLED: - with self.shared_info[_STEPS_LOADED_LOCK_KEY]: - steps_loaded = self.shared_info[_STEPS_LOADED_KEY] - num_steps_loaded = ( - len(steps_loaded) - if steps_loaded != [_STEPS_LOADED_ERROR_CODE] - else 0 - ) - self._logger.debug(f"Steps loaded: {steps_loaded}") - - message = f"⏳ Steps loaded: {num_steps_loaded}/{len(self.dag)}" - if num_steps_loaded > 0 and message != previous_message: - self._logger.info(message) - previous_message = message - - if num_steps_loaded == len(self.dag): - self._logger.info("✅ All the steps have been loaded!") - _update_all_steps_loaded(steps_loaded) - return True - - if steps_loaded == [_STEPS_LOADED_ERROR_CODE]: - self._logger.error("❌ Failed to load all the steps") - _update_all_steps_loaded(steps_loaded) - return False - - time.sleep(2.5) - - return not _STOP_CALLED - - def _request_initial_batches(self) -> None: - """Requests the initial batches to the generator steps.""" - assert self._batch_manager, "Batch manager is not set" - - for step in self._batch_manager._steps.values(): - if batch := step.get_batch(): - self._logger.debug( - f"Sending initial batch to '{step.step_name}' step: {batch}" - ) - self._send_batch_to_step(batch) - - for step_name in self.dag.root_steps: - seq_no = 0 - if last_batch := self._batch_manager.get_last_batch(step_name): - seq_no = last_batch.seq_no + 1 - batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=self._dry_run) - self._logger.debug( - f"Requesting initial batch to '{step_name}' generator step: {batch}" - ) - self._send_batch_to_step(batch) - - def _send_batch_to_step(self, batch: "_Batch") -> None: - """Sends a batch to the input queue of a step. - - Args: - batch: The batch to send. - """ - self._logger.debug( - f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}" - ) - self._batch_manager.set_last_batch_sent(batch) # type: ignore - - self._logger.debug( - f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" - ) - input_queue = self.dag.get_step(batch.step_name)[INPUT_QUEUE_ATTR_NAME] - input_queue.put(batch) - - def _is_convergence_step(self, step_name: str) -> None: - """Checks if a step is a convergence step. - - Args: - step_name: The name of the step. + The callable to create a `Queue`. """ - return self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME) + assert self._manager, "Manager is not initialized" + return self._manager.Queue - def _send_last_batch_flag_to_step(self, step_name: str) -> None: - """Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches. + def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None: + """Runs the `Step` wrapped in a `_ProcessWrapper` in a separate process of the + `Pool`. Args: - step_name: The name of the step. + step: The step to run. + input_queue: The input queue to send the data to the step. """ - batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore - if batch and batch.last_batch: - return - - self._logger.debug( - f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing" - " batches..." 
+ assert self._pool, "Pool is not initialized" + + process_wrapper = _ProcessWrapper( + step=step, + input_queue=input_queue, + output_queue=self._output_queue, + load_queue=self._load_queue, + dry_run=self._dry_run, ) - input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME] - input_queue.put(LAST_BATCH_SENT_FLAG) - self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore - def _run_steps_in_loop( - self, - pool: "Pool", - manager: "SyncManager", - output_queue: "Queue[_Batch]", - shared_info: "DictProxy[str, Any]", - ) -> None: - """Using the `pool`, runs the steps in the DAG in an infinite loop waiting for - input batches and sending the output batches to the `output_queue`. - - Each `Step` is wrapped in a `_ProcessWrapper`, which will handle the lifecycle of - the `Step` and the communication with the `input_queue` and `output_queue`. The - `_ProcessWrapper.run` method is the target function of the process. - - Args: - pool: The pool of processes. - manager: The manager to create the queues. - output_queue: The queue to send the output batches. - shared_info: The shared information between the processes. - """ - for step_name in self.dag: - step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME] - input_queue = manager.Queue() - self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue) - - # Set `pipeline` to `None` as in some Python environments the pipeline is not - # picklable and it will raise an error when trying to send the step to the process. - # `TypeError: cannot pickle 'code' object` - step.pipeline = None - - process_wrapper = _ProcessWrapper( - step=step, - input_queue=input_queue, - output_queue=output_queue, - shared_info=shared_info, - dry_run=self._dry_run, - ) - - pool.apply_async( - process_wrapper.run, - callback=self._finished_callback, - error_callback=self._error_callback, - ) # type: ignore + self._pool.apply_async(process_wrapper.run, error_callback=self._error_callback) def _error_callback(self, e: BaseException) -> None: """Error callback that will be called when an error occurs in a `Step` process. @@ -622,23 +206,22 @@ def _error_callback(self, e: BaseException) -> None: if e.is_load_error: self._logger.error(f"❌ Failed to load step '{e.step.name}': {e.message}") - with self.shared_info[_STEPS_LOADED_LOCK_KEY]: - self.shared_info[_STEPS_LOADED_KEY] = [_STEPS_LOADED_ERROR_CODE] _SUBPROCESS_EXCEPTION = e.subprocess_exception - _SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string( + _SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string( # type: ignore e.formatted_traceback ).as_traceback() return # If the step is global, is not in the last trophic level and has no successors, # then we can ignore the error and continue executing the pipeline + step_name: str = e.step.name # type: ignore if ( e.step.is_global - and not self.dag.step_in_last_trophic_level(e.step.name) - and list(self.dag.get_step_successors(e.step.name)) == [] + and not self.dag.step_in_last_trophic_level(step_name) + and list(self.dag.get_step_successors(step_name)) == [] ): self._logger.error( - f"✋ An error occurred when running global step '{e.step.name}' with no" + f"✋ An error occurred when running global step '{step_name}' with no" " successors and not in the last trophic level. Pipeline execution can" f" continue. Error will be ignored." 
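The new `_run_step` above submits each step's wrapper to a `multiprocessing.Pool` with `apply_async` and relies on `error_callback` to surface exceptions raised in the child process back in the parent. A minimal, self-contained illustration of that standard-library pattern; the worker function and callback are placeholders, not the pipeline's real ones:

```python
import multiprocessing as mp


def run_step(name: str) -> str:
    # Stand-in for `_ProcessWrapper.run`: raise to exercise the error callback.
    if name == "broken_step":
        raise RuntimeError(f"step '{name}' failed to process its batch")
    return name


def on_error(exc: BaseException) -> None:
    # Called in the parent process with the exception raised in the child.
    print(f"error callback received: {exc!r}")


if __name__ == "__main__":
    with mp.Pool(processes=2) as pool:
        for name in ("ok_step", "broken_step"):
            pool.apply_async(run_step, args=(name,), error_callback=on_error)
        pool.close()
        pool.join()
```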
) @@ -646,102 +229,53 @@ def _error_callback(self, e: BaseException) -> None: return # Global step with successors failed - self._logger.error(f"An error occurred in global step '{e.step.name}'") + self._logger.error(f"An error occurred in global step '{step_name}'") self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}") - self._cache() - self._stop() - - def _finished_callback(self, step_name: str) -> None: - """Callback that will be called when a `Step` process finishes. - Args: - step_name: The name of the step that finished. - """ - with _STEPS_FINISHED_LOCK: - _STEPS_FINISHED.add(step_name) - - def _check_step_not_loaded_or_finished(self, step_name: str) -> bool: - """Checks if a step is not loaded or already finished. - - Args: - step_name: The name of the step. - - Returns: - `True` if the step is not loaded or already finished, `False` otherwise. - """ - with _STEPS_LOADED_LOCK: - if step_name not in _STEPS_LOADED: - return True - - with _STEPS_FINISHED_LOCK: - if step_name in _STEPS_FINISHED: - return True - - return False + self._stop() - def _stop( - self, manager: Optional["SyncManager"] = None, pool: Optional["Pool"] = None - ) -> None: + def _stop(self) -> None: """Stops the pipeline execution. It will first send `None` to the input queues of all the steps and then wait until the output queue is empty i.e. all the steps finished processing the batches that were sent before the stop flag. Then it will send `None` to the output queue to notify the pipeline to stop.""" - global _STOP_CALLED - - with _STOP_CALLED_LOCK: - if _STOP_CALLED: - global _STOP_CALLS - _STOP_CALLS += 1 - # if _STOP_CALLS == 1: - # self._logger.warning( - # "🛑 Stop has already been called. Ignoring subsequent calls and waiting" - # " for the pipeline to finish..." - # ) - if _STOP_CALLS == 1: + with self._stop_called_lock: + if self._stop_called: + self._stop_calls += 1 + if self._stop_calls == 1: self._logger.warning( "🛑 Press again to force the pipeline to stop." ) - elif _STOP_CALLS > 1: + elif self._stop_calls > 1: self._logger.warning("🛑 Forcing pipeline interruption.") - import gc - import sys - if manager: - manager.shutdown() + if self._pool: + self._pool.terminate() + self._pool.join() + self._pool = None - if pool: - pool.close() - pool.terminate() + if self._manager: + self._manager.shutdown() + self._manager.join() + self._manager = None - gc.collect() + stop_logging() sys.exit(1) return - _STOP_CALLED = True + self._stop_called = True - self._logger.debug(f"Steps loaded before calling `stop`: {_STEPS_LOADED}") + self._logger.debug( + f"Steps loaded before calling `stop`: {self._steps_load_status}" + ) self._logger.info( "🛑 Stopping pipeline. Waiting for steps to finish processing batches..." ) - self._logger.debug("Sending `None` to the output queue to notify stop...") - self.output_queue.put(None) - def _handle_keyboard_interrupt( - self, manager: Optional["SyncManager"] = None, pool: Optional["Pool"] = None - ) -> None: - """Handles KeyboardInterrupt signal sent during the Pipeline.run method. - - It will try to call self._stop (if the pipeline didn't started yet, it won't - have any effect), and if the pool is already started, will close it before exiting - the program. 
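The refactored `_stop` above distinguishes the first stop request, which is graceful and lets steps finish their in-flight batches, from repeated requests, which terminate the pool, shut down the manager and exit. A stripped-down sketch of that two-stage Ctrl+C behaviour using only the standard library; the polling loop is a stand-in for the pipeline's main loop:

```python
import signal
import sys
import threading
import time

stop_event = threading.Event()
stop_calls = 0
lock = threading.Lock()


def handle_sigint(signum, frame):
    global stop_calls
    with lock:
        stop_calls += 1
        if stop_calls == 1:
            print("🛑 Stopping gracefully. Press Ctrl+C again to force.")
            stop_event.set()  # workers check this flag and drain their work
        else:
            print("🛑 Forcing interruption.")
            sys.exit(1)


signal.signal(signal.SIGINT, handle_sigint)

while not stop_event.is_set():  # stand-in for the pipeline's main loop
    time.sleep(0.1)
print("Finished in-flight work, exiting cleanly.")
```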
- """ - - def signal_handler(signumber: int, frame: Any) -> None: - self._stop(manager=manager, pool=pool) - - signal.signal(signal.SIGINT, signal_handler) + self._stop_load_queue_loop() + self._stop_output_queue_loop() class _ProcessWrapperException(Exception): @@ -805,15 +339,16 @@ class _ProcessWrapper: step: The step to run. input_queue: The queue to receive the input data. output_queue: The queue to send the output data. - shared_info: The shared information between the processes. + load_queue: The queue used to notify the main process that the step has been loaded, + has been unloaded or has failed to load. """ def __init__( self, - step: "Step", + step: "_Step", input_queue: "Queue[_Batch]", output_queue: "Queue[_Batch]", - shared_info: "DictProxy[str, Any]", + load_queue: "Queue[Union[StepLoadStatus, None]]", dry_run: bool = False, ) -> None: """Initializes the `_ProcessWrapper`. @@ -822,29 +357,22 @@ def __init__( step: The step to run. input_queue: The queue to receive the input data. output_queue: The queue to send the output data. - shared_info: The shared information between the processes. + load_queue: The queue used to notify the main process that the step has been + loaded, has been unloaded or has failed to load. dry_run: Flag to ensure we are forcing to run the last batch. """ self.step = step self.input_queue = input_queue self.output_queue = output_queue - self.shared_info = shared_info + self.load_queue = load_queue self._dry_run = dry_run - # If step is a task, and it's using a `CUDALLM`, then set the CUDA device map - # and the lock for that map. - if hasattr(self.step, "llm") and isinstance( - self.step.llm, CudaDevicePlacementMixin + if ( + isinstance(self.step, Task) + and hasattr(self.step, "llm") + and isinstance(self.step.llm, CudaDevicePlacementMixin) ): - self.step.llm.set_device_placement_info( - llm_identifier=self.step.name, - device_llm_placement_map=self.shared_info[ - _CUDA_LLM_DEVICE_PLACEMENT_KEY - ], - device_llm_placement_lock=self.shared_info[ - _CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY - ], - ) + self.step.llm._llm_identifier = self.step.name def run(self) -> str: """The target function executed by the process. This function will also handle @@ -860,6 +388,8 @@ def run(self) -> str: self.step.load() self.step._logger.debug(f"Step '{self.step.name}' loaded!") except Exception as e: + self.step.unload() + self._notify_load_failed() raise _ProcessWrapperException.create_load_error( str(e), self.step, e ) from e @@ -877,14 +407,25 @@ def run(self) -> str: except Exception: pass + self.step.unload() + + self._notify_unload() + self.step._logger.info(f"🏁 Finished running step '{self.step.name}'") return self.step.name # type: ignore def _notify_load(self) -> None: """Notifies that the step has finished executing its `load` function successfully.""" - with self.shared_info[_STEPS_LOADED_LOCK_KEY]: - self.shared_info[_STEPS_LOADED_KEY].append(self.step.name) + self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore + + def _notify_unload(self) -> None: + """Notifies that the step has been unloaded.""" + self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore + + def _notify_load_failed(self) -> None: + """Notifies that the step failed to load.""" + self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore def _generator_step_process_loop(self) -> None: """Runs the process loop for a generator step. 
It will call the `process` method @@ -960,6 +501,11 @@ def _non_generator_process_loop(self) -> None: self.step._logger.info( f"📦 Processing batch {batch.seq_no} in '{batch.step_name}'" ) + + if batch.data_path is not None: + self.step._logger.debug(f"Reading batch data from '{batch.data_path}'") + batch.read_batch_data_from_fs() + result = [] try: if self.step.has_multiple_inputs: @@ -971,11 +517,7 @@ def _non_generator_process_loop(self) -> None: raise _ProcessWrapperException(str(e), self.step, 2, e) from e # Impute step outputs columns with `None` - for row in batch.data[0]: - data = row.copy() - for output in self.step.outputs: - data[output] = None - result.append(data) + result = self._impute_step_outputs(batch) # if the step is not global then we can skip the batch which means sending # an empty batch to the output queue @@ -993,8 +535,26 @@ def _non_generator_process_loop(self) -> None: if batch.last_batch: break + def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]: + """Imputes the step outputs columns with `None` in the batch data. + + Args: + batch: The batch to impute. + """ + result = [] + for row in batch.data[0]: + data = row.copy() + for output in self.step.outputs: + data[output] = None + result.append(data) + return result + def _send_batch(self, batch: _Batch) -> None: """Sends a batch to the `output_queue`.""" + if batch.data_path is not None: + self.step._logger.debug(f"Writing batch data to '{batch.data_path}'") + batch.write_batch_data_to_fs() + self.step._logger.info( f"📨 Step '{batch.step_name}' sending batch {batch.seq_no} to output queue" ) diff --git a/src/distilabel/pipeline/routing_batch_function.py b/src/distilabel/pipeline/routing_batch_function.py index c66a2d82b0..9d074d9fb6 100644 --- a/src/distilabel/pipeline/routing_batch_function.py +++ b/src/distilabel/pipeline/routing_batch_function.py @@ -26,7 +26,7 @@ ) if TYPE_CHECKING: - from distilabel.pipeline.base import _Batch + from distilabel.pipeline.batch import _Batch from distilabel.pipeline.typing import DownstreamConnectableSteps from distilabel.steps.base import _Step diff --git a/src/distilabel/pipeline/typing.py b/src/distilabel/pipeline/typing.py index 2ebb9b4643..e73d20e8ab 100644 --- a/src/distilabel/pipeline/typing.py +++ b/src/distilabel/pipeline/typing.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
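The hunks above give batches an optional `data_path`: when it is set, a step reads the batch's rows from the filesystem before processing and writes them back before putting the now lightweight batch on the output queue, instead of pushing large payloads through the multiprocessing queues themselves. A generic sketch of that idea, assuming JSON files on local disk; distilabel's actual serialization format and `_Batch` API may differ:

```python
import json
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional


@dataclass
class FileBackedBatch:
    seq_no: int
    step_name: str
    data: List[Dict[str, Any]] = field(default_factory=list)
    data_path: Optional[str] = None

    def write_to_fs(self, directory: Path) -> None:
        path = directory / f"{self.step_name}_{self.seq_no}.json"
        path.write_text(json.dumps(self.data))
        self.data_path, self.data = str(path), []  # the queue only carries the path

    def read_from_fs(self) -> None:
        assert self.data_path is not None
        self.data = json.loads(Path(self.data_path).read_text())


with tempfile.TemporaryDirectory() as tmp:
    batch = FileBackedBatch(0, "text_generation", [{"instruction": "hi"}] * 3)
    batch.write_to_fs(Path(tmp))  # before putting the batch on a queue
    batch.read_from_fs()          # in the consumer process
    print(len(batch.data))        # 3
```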
-from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, TypeVar, Union if TYPE_CHECKING: + from distilabel.mixins.runtime_parameters import RuntimeParameterInfo from distilabel.steps.base import GeneratorStep, GlobalStep, Step DownstreamConnectable = Union["Step", "GlobalStep"] @@ -32,3 +33,17 @@ covariant=True, ) """Type for the `Step` types that can be connected as downstream steps.""" + + +class StepLoadStatus(TypedDict): + """Dict containing information about if one step was loaded/unloaded or if it's load + failed""" + + name: str + status: Literal["loaded", "unloaded", "load_failed"] + + +PipelineRuntimeParametersInfo = Dict[ + str, Union[List["RuntimeParameterInfo"], Dict[str, "RuntimeParameterInfo"]] +] +"""Alias for the information of the runtime parameters of a `Pipeline`.""" diff --git a/src/distilabel/pipeline/write_buffer.py b/src/distilabel/pipeline/write_buffer.py new file mode 100644 index 0000000000..a71ffdd9b2 --- /dev/null +++ b/src/distilabel/pipeline/write_buffer.py @@ -0,0 +1,168 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Set + +import pyarrow as pa +import pyarrow.parquet as pq + +from distilabel.pipeline.batch import _Batch +from distilabel.utils.dicts import flatten_dict +from distilabel.utils.files import list_files_in_dir + + +class _WriteBuffer: + """Class in charge of sending the batched contents to a buffer and writing + those to files under a given folder. + + As batches are received, they are added to the buffer and once each buffer + is full, the content is written to a parquet file. + """ + + def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None: + """ + Args: + path: Folder where the files will be written, the idea + is for this path to be in the cache folder under /data. + leaf_steps: Leaf steps from either the DAG of the Pipeline. + + Raises: + ValueError: If the path is not a directory. + """ + self._path = Path(path) + if not self._path.exists(): + self._path.mkdir(parents=True, exist_ok=True) + for step in leaf_steps: + (self._path / step).mkdir(parents=True, exist_ok=True) + + if not self._path.is_dir(): + raise ValueError(f"The path should be a directory, not a file: {path}") + + self._buffers: Dict[str, List[Dict[str, Any]]] = { + step: [] for step in leaf_steps + } + # TODO: make this configurable + self._buffers_dump_batch_size: Dict[str, int] = { + step: 50 for step in leaf_steps + } + self._buffer_last_schema = {} + self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps} + self._logger = logging.getLogger("distilabel.write_buffer") + + def _get_filename(self, step_name: str) -> Path: + """Creates the filename for the step. + + Args: + step_name: Name of the step to which the data belongs to. + + Returns: + Filename for the step. 
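With the `StepLoadStatus` `TypedDict` defined above, each worker reports `"loaded"`, `"unloaded"` or `"load_failed"` for its step over a shared queue, and the parent tallies these messages instead of polling a shared dictionary. A compact sketch of consuming such messages; the tally logic is illustrative, not the pipeline's exact bookkeeping:

```python
from queue import Queue
from typing import Dict, Literal, TypedDict, Union


class StepLoadStatus(TypedDict):
    name: str
    status: Literal["loaded", "unloaded", "load_failed"]


def wait_until_loaded(load_queue: "Queue[Union[StepLoadStatus, None]]", num_steps: int) -> bool:
    loaded: Dict[str, bool] = {}
    while sum(loaded.values()) < num_steps:
        msg = load_queue.get()
        if msg is None:  # sentinel: stop waiting
            return False
        if msg["status"] == "load_failed":
            print(f"❌ step '{msg['name']}' failed to load")
            return False
        loaded[msg["name"]] = msg["status"] == "loaded"
    print("✅ All the steps have been loaded!")
    return True


q: "Queue[Union[StepLoadStatus, None]]" = Queue()
q.put({"name": "load_data", "status": "loaded"})
q.put({"name": "text_generation", "status": "loaded"})
print(wait_until_loaded(q, num_steps=2))  # True
```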
+ """ + return self._path / f"{step_name}.parquet" + + def is_full(self, step_name: str) -> bool: + """Checks the buffers that are full so that those can be written to the file. + + Returns: + Whether the buffer is full. + """ + return len(self._buffers[step_name]) >= self._buffers_dump_batch_size[step_name] + + def add_batch(self, batch: "_Batch") -> None: + """Adds a batch to the buffer and writes the buffer to the file if it's full. + + Args: + batch: batch to add to the buffer. + """ + step_name = batch.step_name + data = batch.data[0] + self._buffers[step_name].extend(data) + self._logger.debug( + f"Added batch to write buffer for step '{step_name}' with {len(data)} rows." + ) + if self.is_full(step_name): + self._logger.debug( + f"Buffer for step '{step_name}' is full (rows: {len(self._buffers[step_name])}," + f" full: {self._buffers_dump_batch_size[step_name]}), writing to file..." + ) + self._write(step_name) + + def _write(self, step_name: str) -> None: + """Writes the content to the file and cleans the buffer. + + Args: + step_name (str): Name of the step to which the data pertains. + """ + step_parquet_dir = Path(self._path, step_name) + if not step_parquet_dir.exists(): + self._logger.debug( + f"Creating directory for step '{step_name}' parquet files..." + ) + step_parquet_dir.mkdir() + + try: + table = pa.Table.from_pylist(self._buffers[step_name]) + except pa.lib.ArrowInvalid as pae: + if ( + repr(pae) + != "ArrowInvalid('cannot mix struct and non-struct, non-null values')" + ): + raise pae + flattened_buffers = [flatten_dict(buf) for buf in self._buffers[step_name]] + table = pa.Table.from_pylist(flattened_buffers) + + last_schema = self._buffer_last_schema.get(step_name) + if last_schema is None: + self._buffer_last_schema[step_name] = table.schema + else: + if not last_schema.equals(table.schema): + new_schema = pa.unify_schemas([last_schema, table.schema]) + self._buffer_last_schema[step_name] = new_schema + table = table.cast(new_schema) + + next_file_number = self._buffers_last_file[step_name] + self._buffers_last_file[step_name] = next_file_number + 1 + + parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet" + pq.write_table(table, parquet_file) + self._logger.debug(f"Written to file '{parquet_file}'") + + self._clean_buffer(step_name) + + def _clean_buffer(self, step_name: str) -> None: + """Cleans the buffer by setting it's content to `None`. + + Args: + step_name: The name of the buffer to clean. + """ + self._buffers[step_name] = [] + + def close(self) -> None: + """Closes the buffer by writing the remaining content to the file.""" + for step_name in self._buffers: + if self._buffers[step_name]: + self._write(step_name) + + # We need to read the parquet files and write them again to ensure the schema + # is correct. Otherwise, the first parquets won't have the last schema and + # then we will have issues when reading them. 
+ for file in list_files_in_dir(self._path / step_name): + if step_name in self._buffer_last_schema: + table = pq.read_table( + file, schema=self._buffer_last_schema[step_name] + ) + pq.write_table(table, file) diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index efba7b3938..77c8818442 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -29,7 +29,12 @@ FormatTextGenerationSFT, ) from distilabel.steps.generators.data import LoadDataFromDicts -from distilabel.steps.generators.huggingface import LoadHubDataset +from distilabel.steps.generators.huggingface import ( + LoadDataFromDisk, + LoadDataFromFileSystem, + LoadDataFromHub, + LoadHubDataset, +) from distilabel.steps.globals.huggingface import PushToHub from distilabel.steps.keep import KeepColumns from distilabel.steps.typing import GeneratorStepOutput, StepOutput @@ -49,6 +54,9 @@ "GlobalStep", "KeepColumns", "LoadDataFromDicts", + "LoadDataFromDisk", + "LoadDataFromFileSystem", + "LoadDataFromHub", "LoadHubDataset", "PushToHub", "Step", diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py index 1460e55f99..57cd1863bf 100644 --- a/src/distilabel/steps/argilla/base.py +++ b/src/distilabel/steps/argilla/base.py @@ -137,9 +137,7 @@ def load(self) -> None: @property @abstractmethod - def inputs(self) -> List[str]: - ... + def inputs(self) -> List[str]: ... @abstractmethod - def process(self, *inputs: StepInput) -> "StepOutput": - ... + def process(self, *inputs: StepInput) -> "StepOutput": ... diff --git a/src/distilabel/steps/argilla/preference.py b/src/distilabel/steps/argilla/preference.py index 7a1a5f7a15..fa674010b3 100644 --- a/src/distilabel/steps/argilla/preference.py +++ b/src/distilabel/steps/argilla/preference.py @@ -73,6 +73,62 @@ class PreferenceToArgilla(Argilla): generated ratings won't be pushed to Argilla. - rationales (`List[str]`, optional): The rationales for the ratings. If not provided, the generated rationales won't be pushed to Argilla. 
+ + Examples: + + Push a preference dataset to an Argilla instance: + + ```python + from distilabel.steps import PreferenceToArgilla + + to_argilla = PreferenceToArgilla( + num_generations=2, + api_url="https://dibt-demo-argilla-space.hf.space/", + api_key="api.key", + dataset_name="argilla_dataset", + dataset_workspace="my_workspace", + ) + to_argilla.load() + + result = next( + to_argilla.process( + [ + { + "instruction": "instruction", + "generations": ["first_generation", "second_generation"], + } + ], + ) + ) + # >>> result + # [{'instruction': 'instruction', 'generations': ['first_generation', 'second_generation']}] + ``` + + It can also include ratings and rationales: + + ```python + result = next( + to_argilla.process( + [ + { + "instruction": "instruction", + "generations": ["first_generation", "second_generation"], + "ratings": ["4", "5"], + "rationales": ["rationale for 4", "rationale for 5"], + } + ], + ) + ) + # >>> result + # [ + # { + # 'instruction': 'instruction', + # 'generations': ['first_generation', 'second_generation'], + # 'ratings': ['4', '5'], + # 'rationales': ['rationale for 4', 'rationale for 5'] + # } + # ] + ``` """ num_generations: int diff --git a/src/distilabel/steps/argilla/text_generation.py b/src/distilabel/steps/argilla/text_generation.py index 1e259b7e0e..7e60f3b278 100644 --- a/src/distilabel/steps/argilla/text_generation.py +++ b/src/distilabel/steps/argilla/text_generation.py @@ -58,6 +58,36 @@ class TextGenerationToArgilla(Argilla): Input columns: - instruction (`str`): The instruction that was used to generate the completion. - generation (`str` or `List[str]`): The completions that were generated based on the input instruction. + + Examples: + + Push a text generation dataset to an Argilla instance: + + ```python + from distilabel.steps import PreferenceToArgilla + + to_argilla = TextGenerationToArgilla( + num_generations=2, + api_url="https://dibt-demo-argilla-space.hf.space/", + api_key="api.key", + dataset_name="argilla_dataset", + dataset_workspace="my_workspace", + ) + to_argilla.load() + + result = next( + to_argilla.process( + [ + { + "instruction": "instruction", + "generation": "generation", + } + ], + ) + ) + # >>> result + # [{'instruction': 'instruction', 'generation': 'generation'}] + ``` """ _id: str = PrivateAttr(default="id") diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index fcac454447..5db81a5666 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -16,7 +16,6 @@ import logging import re from abc import ABC, abstractmethod -from enum import Enum from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload @@ -27,7 +26,7 @@ RuntimeParameter, RuntimeParametersMixin, ) -from distilabel.utils.serialization import TYPE_INFO_KEY, _Serializable +from distilabel.utils.serialization import _Serializable from distilabel.utils.typing_ import is_parameter_annotated_with if TYPE_CHECKING: @@ -220,18 +219,15 @@ def _set_routing_batch_function( routing_batch_function._step = self @overload - def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": - ... + def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": ... @overload def __rshift__( self, other: List["DownstreamConnectableSteps"] - ) -> List["DownstreamConnectableSteps"]: - ... + ) -> List["DownstreamConnectableSteps"]: ... @overload - def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": - ... 
+ def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": ... def __rshift__( self, @@ -306,6 +302,12 @@ def load(self) -> None: """ self._logger = logging.getLogger(f"distilabel.step.{self.name}") + def unload(self) -> None: + """Method to perform any cleanup logic after the `process` method is called. For + example, to close a connection to a database, etc. + """ + self._logger.debug("Executing step unload logic.") + @property def is_generator(self) -> bool: """Whether the step is a generator step or not. @@ -450,51 +452,6 @@ def get_outputs(self) -> List[str]: """ return [self.output_mappings.get(output, output) for output in self.outputs] - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "_Step": - """Create a Step from a dict containing the serialized data. - - Needs the information from the step and the Pipeline it belongs to. - - Note: - It's intended for internal use. - - Args: - data: dictionary containing the serialized data from a `Step` and the - `Pipeline` it belongs to. - - Returns: - A `Step` instance. - """ - # Remove the "type_info" to avoid errors on instantiation - _data = data.copy() - if TYPE_INFO_KEY in _data.keys(): - _data.pop(TYPE_INFO_KEY) - - # Before passing the data to instantiate the general step, we have to instantiate - # some of the internal objects. For the moment we only take into account the LLM, - # we should take care if we update any of the objects. - if llm := _data.get("llm"): - from distilabel.utils.serialization import _get_module_attr - - nested_cls = _get_module_attr(**llm.pop(TYPE_INFO_KEY)) - # Load the LLM and update the _data inplace - nested_cls = nested_cls(**llm) - _data.update({"llm": nested_cls}) - - # Enums need a specific restoring process - for k, v in _data.items(): - if isinstance(v, dict) and "_type" in v and v["_type"] == "enum": - _data[k] = Enum(v["_name"], v["_values"], type=eval(v["_enum_type"])) - - # Skip `runtime_parameters_info` since extras are not allowed - _data.pop("runtime_parameters_info", None) - - # Every step needs the pipeline, and the remaining arguments are general - step = cls(**_data) - - return step - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: dump = super()._model_dump(obj, **kwargs) dump["runtime_parameters_info"] = self.get_runtime_parameters_info() diff --git a/src/distilabel/steps/combine.py b/src/distilabel/steps/combine.py index cc77d4091e..f0c5b49647 100644 --- a/src/distilabel/steps/combine.py +++ b/src/distilabel/steps/combine.py @@ -41,6 +41,51 @@ class CombineColumns(Step): Output columns: - dynamic (determined by `columns` and `output_columns` attributes): The columns that were merged. 
+ + Examples: + + Combine columns of a dataset: + + ```python + from distilabel.steps import CombineColumns + + combine_columns = CombineColumns( + name="combine_columns", + columns=["generation", "model_name"], + ) + combine_columns.load() + + result = next( + combine_columns.process( + [{"generation": "AI generated text"}, {"model_name": "my_model"}], + [{"generation": "Other generated text", "model_name": "my_model"}] + ) + ) + # >>> result + # [{'merged_generation': ['AI generated text', 'Other generated text'], 'merged_model_name': ['my_model']}] + ``` + + Specify the name of the output columns: + + ```python + from distilabel.steps import CombineColumns + + combine_columns = CombineColumns( + name="combine_columns", + columns=["generation", "model_name"], + output_columns=["generations", "generation_models"] + ) + combine_columns.load() + + result = next( + combine_columns.process( + [{"generation": "AI generated text"}, {"model_name": "my_model"}], + [{"generation": "Other generated text", "model_name": "my_model"}] + ) + ) + # >>> result + #[{'generations': ['AI generated text', 'Other generated text'], 'generation_models': ['my_model']}] + ``` """ columns: List[str] diff --git a/src/distilabel/steps/constants.py b/src/distilabel/steps/constants.py new file mode 100644 index 0000000000..8d50ae4774 --- /dev/null +++ b/src/distilabel/steps/constants.py @@ -0,0 +1,15 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DISTILABEL_METADATA_KEY = "distilabel_metadata" diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py index 1d7c1853cb..da2cbb8dcc 100644 --- a/src/distilabel/steps/decorator.py +++ b/src/distilabel/steps/decorator.py @@ -53,8 +53,7 @@ def step( inputs: Union[List[str], None] = None, outputs: Union[List[str], None] = None, step_type: Literal["normal"] = "normal", -) -> Callable[..., Type["Step"]]: - ... +) -> Callable[..., Type["Step"]]: ... @overload @@ -62,8 +61,7 @@ def step( inputs: Union[List[str], None] = None, outputs: Union[List[str], None] = None, step_type: Literal["global"] = "global", -) -> Callable[..., Type["GlobalStep"]]: - ... +) -> Callable[..., Type["GlobalStep"]]: ... @overload @@ -71,8 +69,7 @@ def step( inputs: None = None, outputs: Union[List[str], None] = None, step_type: Literal["generator"] = "generator", -) -> Callable[..., Type["GeneratorStep"]]: - ... +) -> Callable[..., Type["GeneratorStep"]]: ... def step( diff --git a/src/distilabel/steps/deita.py b/src/distilabel/steps/deita.py index e1c258f87a..6b08abcbf7 100644 --- a/src/distilabel/steps/deita.py +++ b/src/distilabel/steps/deita.py @@ -64,6 +64,42 @@ class DeitaFiltering(GlobalStep): References: - [`What Makes Good Data for Alignment? 
A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685) + + Examples: + + Filter the dataset based on the DEITA score and the cosine distance between the embeddings: + + ```python + from distilabel.steps import DeitaFiltering + + deita_filtering = DeitaFiltering(data_budget=1) + + deita_filtering.load() + + result = next( + deita_filtering.process( + [ + { + "evol_instruction_score": 0.5, + "evol_response_score": 0.5, + "embedding": [-8.12729941, -5.24642847, -6.34003029], + }, + { + "evol_instruction_score": 0.6, + "evol_response_score": 0.6, + "embedding": [2.99329242, 0.7800932, 0.7799726], + }, + { + "evol_instruction_score": 0.7, + "evol_response_score": 0.7, + "embedding": [10.29041806, 14.33088073, 13.00557506], + }, + ], + ) + ) + # >>> result + # [{'evol_instruction_score': 0.5, 'evol_response_score': 0.5, 'embedding': [-8.12729941, -5.24642847, -6.34003029], 'deita_score': 0.25, 'deita_score_computed_with': ['evol_instruction_score', 'evol_response_score'], 'nearest_neighbor_distance': 1.9042812683723933}] + ``` """ data_budget: RuntimeParameter[int] = Field( diff --git a/src/distilabel/steps/expand.py b/src/distilabel/steps/expand.py index 3286d3c305..7312e1a4fd 100644 --- a/src/distilabel/steps/expand.py +++ b/src/distilabel/steps/expand.py @@ -41,6 +41,31 @@ class ExpandColumns(Step): Output columns: - dynamic (determined by `columns` attribute): The expanded columns. + + Examples: + + Expand the selected columns into multiple rows: + + ```python + from distilabel.steps import ExpandColumns + + expand_columns = ExpandColumns( + columns=["generation"], + ) + expand_columns.load() + + result = next( + expand_columns.process( + [ + { + "instruction": "instruction 1", + "generation": ["generation 1", "generation 2"]} + ], + ) + ) + # >>> result + # [{'instruction': 'instruction 1', 'generation': 'generation 1'}, {'instruction': 'instruction 1', 'generation': 'generation 2'}] + ``` """ columns: Union[Dict[str, str], List[str]] diff --git a/src/distilabel/steps/formatting/conversation.py b/src/distilabel/steps/formatting/conversation.py index 43d62de527..22cc582c54 100644 --- a/src/distilabel/steps/formatting/conversation.py +++ b/src/distilabel/steps/formatting/conversation.py @@ -34,6 +34,30 @@ class ConversationTemplate(Step): - format - chat - template + + Examples: + + Create a conversation from an instruction and a response: + + ```python + from distilabel.steps import ConversationTemplate + + conv_template = ConversationTemplate() + conv_template.load() + + result = next( + conv_template.process( + [ + { + "instruction": "Hello", + "response": "Hi", + } + ], + ) + ) + # >>> result + # [{'instruction': 'Hello', 'response': 'Hi', 'conversation': [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi'}]}] + ``` """ @property diff --git a/src/distilabel/steps/formatting/dpo.py b/src/distilabel/steps/formatting/dpo.py index 72c4b1e440..9402436ee9 100644 --- a/src/distilabel/steps/formatting/dpo.py +++ b/src/distilabel/steps/formatting/dpo.py @@ -63,6 +63,43 @@ class FormatTextGenerationDPO(Step): - preference - instruction - generations + + Examples: + + Format your dataset for DPO fine tuning: + + ```python + from distilabel.steps import FormatTextGenerationDPO + + format_dpo = FormatTextGenerationDPO() + format_dpo.load() + + # NOTE: Both "system_prompt" and "generation_models" can be added optionally. 
+ result = next( + format_dpo.process( + [ + { + "instruction": "What's 2+2?", + "generations": ["4", "5", "6"], + "ratings": [1, 0, -1], + } + ] + ) + ) + # >>> result + # [ + # { 'instruction': "What's 2+2?", + # 'generations': ['4', '5', '6'], + # 'ratings': [1, 0, -1], + # 'prompt': "What's 2+2?", + # 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29', + # 'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}], + # 'chosen_rating': 1, + # 'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}], + # 'rejected_rating': -1 + # } + # ] + ``` """ @property @@ -194,6 +231,44 @@ class FormatChatGenerationDPO(Step): - preference - messages - generations + + Examples: + + Format your dataset for DPO fine tuning: + + ```python + from distilabel.steps import FormatChatGenerationDPO + + format_dpo = FormatChatGenerationDPO() + format_dpo.load() + + # NOTE: "generation_models" can be added optionally. + result = next( + format_dpo.process( + [ + { + "messages": [{"role": "user", "content": "What's 2+2?"}], + "generations": ["4", "5", "6"], + "ratings": [1, 0, -1], + } + ] + ) + ) + # >>> result + # [ + # { + # 'messages': [{'role': 'user', 'content': "What's 2+2?"}], + # 'generations': ['4', '5', '6'], + # 'ratings': [1, 0, -1], + # 'prompt': "What's 2+2?", + # 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29', + # 'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}], + # 'chosen_rating': 1, + # 'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}], + # 'rejected_rating': -1 + # } + # ] + ``` """ @property diff --git a/src/distilabel/steps/formatting/sft.py b/src/distilabel/steps/formatting/sft.py index 0c6f51e2a9..ec93aadf79 100644 --- a/src/distilabel/steps/formatting/sft.py +++ b/src/distilabel/steps/formatting/sft.py @@ -48,6 +48,39 @@ class FormatTextGenerationSFT(Step): - text-generation - instruction - generation + + Examples: + + Format your dataset for SFT fine tuning: + + ```python + from distilabel.steps import FormatTextGenerationSFT + + format_sft = FormatTextGenerationSFT() + format_sft.load() + + # NOTE: "system_prompt" can be added optionally. + result = next( + format_sft.process( + [ + { + "instruction": "What's 2+2?", + "generation": "4" + } + ] + ) + ) + # >>> result + # [ + # { + # 'instruction': 'What's 2+2?', + # 'generation': '4', + # 'prompt': 'What's 2+2?', + # 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29', + # 'messages': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}] + # } + # ] + ``` """ @property @@ -133,6 +166,38 @@ class FormatChatGenerationSFT(Step): - chat-generation - instruction - generation + + Examples: + + Format your dataset for Supervised Fine Tuning (SFT): + + ```python + from distilabel.steps import FormatChatGenerationSFT + + format_sft = FormatChatGenerationSFT() + format_sft.load() + + # NOTE: "system_prompt" can be added optionally. 
+ result = next( + format_sft.process( + [ + { + "messages": [{"role": "user", "content": "What's 2+2?"}], + "generation": "4" + } + ] + ) + ) + # >>> result + # [ + # { + # 'messages': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}], + # 'generation': '4', + # 'prompt': 'What's 2+2?', + # 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29', + # } + # ] + ``` """ @property diff --git a/src/distilabel/steps/generators/data.py b/src/distilabel/steps/generators/data.py index e51791006d..fbf29ec7fe 100644 --- a/src/distilabel/steps/generators/data.py +++ b/src/distilabel/steps/generators/data.py @@ -40,6 +40,24 @@ class LoadDataFromDicts(GeneratorStep): Categories: - load + + Examples: + + Load data from a list of dictionaries: + + ```python + from distilabel.steps import LoadDataFromDicts + + loader = LoadDataFromDicts( + data=[{"instruction": "What are 2+2?"}] * 5, + batch_size=2 + ) + loader.load() + + result = next(loader.process()) + # >>> result + # ([{'instruction': 'What are 2+2?'}, {'instruction': 'What are 2+2?'}], False) + ``` """ data: List[Dict[str, Any]] diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py index da4a6d7f52..04f8b7c312 100644 --- a/src/distilabel/steps/generators/huggingface.py +++ b/src/distilabel/steps/generators/huggingface.py @@ -12,15 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union - -import requests -from datasets import DatasetInfo, IterableDataset, load_dataset +import warnings +from collections import defaultdict +from functools import cached_property +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +from datasets import ( + Dataset, + DatasetInfo, + IterableDataset, + get_dataset_infos, + load_dataset, + load_from_disk, +) from pydantic import Field, PrivateAttr -from requests.exceptions import ConnectionError +from upath import UPath +from distilabel.distiset import Distiset from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.base import GeneratorStep @@ -28,7 +47,7 @@ from distilabel.steps.typing import GeneratorStepOutput -class LoadHubDataset(GeneratorStep): +class LoadDataFromHub(GeneratorStep): """Loads a dataset from the Hugging Face Hub. `GeneratorStep` that loads a dataset from the Hugging Face Hub using the `datasets` @@ -50,6 +69,8 @@ class LoadHubDataset(GeneratorStep): `False`. - `num_examples`: The number of examples to load from the dataset. By default will load all examples. + - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. Output columns: - dynamic (`all`): The columns that will be generated by this step, based on the @@ -57,6 +78,26 @@ class LoadHubDataset(GeneratorStep): Categories: - load + + Examples: + + Load data from a dataset in Hugging Face Hub: + + ```python + from distilabel.steps import LoadDataFromHub + + loader = LoadDataFromHub( + repo_id="distilabel-internal-testing/instruction-dataset-mini", + split="test", + batch_size=2 + ) + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. 
+ result = next(loader.process()) + # >>> result + # ([{'prompt': 'Arianna has 12...', False) + ``` """ repo_id: RuntimeParameter[str] = Field( @@ -80,8 +121,12 @@ class LoadHubDataset(GeneratorStep): default=None, description="The number of examples to load from the dataset. By default will load all examples.", ) + storage_options: Optional[Dict[str, Any]] = Field( + default=None, + description="The storage options to use when loading the dataset.", + ) - _dataset: Union[IterableDataset, None] = PrivateAttr(...) + _dataset: Union[IterableDataset, Dataset, None] = PrivateAttr(...) def load(self) -> None: """Load the dataset from the Hugging Face Hub""" @@ -155,11 +200,11 @@ def _get_dataset_num_examples(self) -> int: Returns: The number of examples in the dataset. """ - dataset_info = self._get_dataset_info() - split = self.split - if self.config: - return dataset_info["splits"][split]["num_examples"] - return dataset_info["default"]["splits"][split]["num_examples"] + return ( + self._dataset_info[self.config if self.config else "default"] + .splits[self.split] + .num_examples + ) def _get_dataset_columns(self) -> List[str]: """Get the columns of the dataset, based on the `config` runtime parameter provided. @@ -167,18 +212,14 @@ def _get_dataset_columns(self) -> List[str]: Returns: The columns of the dataset. """ - dataset_info = self._get_dataset_info() - - if isinstance(dataset_info, DatasetInfo): - if self.config: - return list(self._dataset[self.config].info.features.keys()) - return list(self._dataset.info.features.keys()) - - if self.config: - return list(dataset_info["features"].keys()) - return list(dataset_info["default"]["features"].keys()) + return list( + self._dataset_info[ + self.config if self.config else "default" + ].features.keys() + ) - def _get_dataset_info(self) -> Dict[str, Any]: + @cached_property + def _dataset_info(self) -> Dict[str, DatasetInfo]: """Calls the Datasets Server API from Hugging Face to obtain the dataset information. Returns: @@ -188,47 +229,355 @@ def _get_dataset_info(self) -> Dict[str, Any]: config = self.config try: - return _get_hf_dataset_info(repo_id, config) - except ConnectionError: + return get_dataset_infos(repo_id) + except Exception as e: # The previous could fail in case of a internet connection issues. # Assuming the dataset is already loaded and we can get the info from the loaded dataset, otherwise it will fail anyway. - self.load() + self._logger.warning( + f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}" + ) + ds = load_dataset(repo_id, config=self.config, split=self.split) if config: - return self._dataset[config].info - return self._dataset.info + return ds[config].info + return ds.info + + +class LoadHubDataset(LoadDataFromHub): + def __init__(self, **data: Any) -> None: + warnings.warn( + "`LoadHubDataset` is deprecated and will be removed in version 1.3.0, use `LoadDataFromHub` instead.", + DeprecationWarning, + stacklevel=2, + ) + return super().__init__(**data) + +class LoadDataFromFileSystem(LoadDataFromHub): + """Loads a dataset from a file in your filesystem. -@lru_cache -def _get_hf_dataset_info( - repo_id: str, config: Union[str, None] = None -) -> Dict[str, Any]: - """Calls the Datasets Server API from Hugging Face to obtain the dataset information. - The results are cached to avoid making multiple requests to the server. + `GeneratorStep` that creates a dataset from a file in the filesystem, uses Hugging Face `datasets` + library. 
Take a look at [Hugging Face Datasets](https://huggingface.co/docs/datasets/loading) + for more information of the supported file types. + + Attributes: + data_files: The path to the file, or directory containing the files that conform + the dataset. + split: The split of the dataset to load (typically will be `train`, `test` or `validation`). + + Runtime parameters: + - `batch_size`: The batch size to use when processing the data. + - `data_files`: The path to the file, or directory containing the files that conform + the dataset. + - `split`: The split of the dataset to load. Defaults to 'train'. + - `streaming`: Whether to load the dataset in streaming mode or not. Defaults to + `False`. + - `num_examples`: The number of examples to load from the dataset. + By default will load all examples. + - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + - `filetype`: The expected filetype. If not provided, it will be inferred from the file extension. + For more than one file, it will be inferred from the first file. + + Output columns: + - dynamic (`all`): The columns that will be generated by this step, based on the + datasets loaded from the Hugging Face Hub. + + Categories: + - load - Args: - repo_id: The Hugging Face Hub repository ID of the dataset. - config: The configuration of the dataset. This is optional and only needed if the - dataset has multiple configurations. + Examples: - Returns: - The dataset information. + Load data from a Hugging Face dataset in your file system: + + ```python + from distilabel.steps import LoadDataFromFileSystem + + loader = LoadDataFromFileSystem(data_files="path/to/dataset.jsonl") + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. + result = next(loader.process()) + # >>> result + # ([{'type': 'function', 'function':...', False) + ``` + + Specify a filetype if the file extension is not expected: + + ```python + from distilabel.steps import LoadDataFromFileSystem + + loader = LoadDataFromFileSystem(filetype="csv", data_files="path/to/dataset.txtr") + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. + result = next(loader.process()) + # >>> result + # ([{'type': 'function', 'function':...', False) + ``` + + Load data from a file in your cloud provider: + + ```python + from distilabel.steps import LoadDataFromFileSystem + + loader = LoadDataFromFileSystem( + data_files="gcs://path/to/dataset", + storage_options={"project": "experiments-0001"} + ) + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. + result = next(loader.process()) + # >>> result + # ([{'type': 'function', 'function':...', False) + ``` """ - params = {"dataset": repo_id} - if config is not None: - params["config"] = config + data_files: RuntimeParameter[Union[str, Path]] = Field( + default=None, + description="The data files, or directory containing the data files, to generate the dataset from.", + ) + filetype: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The expected filetype. 
If not provided, it will be inferred from the file extension.", + ) + + def load(self) -> None: + """Load the dataset from the file/s in disk.""" + super(GeneratorStep, self).load() + + data_path = UPath(self.data_files, storage_options=self.storage_options) + + (data_files, self.filetype) = self._prepare_data_files(data_path) + + self._dataset = load_dataset( + self.filetype, + data_files=data_files, + split=self.split, + streaming=self.streaming, + storage_options=self.storage_options, + ) + + if not self.streaming and self.num_examples: + self._dataset = self._dataset.select(range(self.num_examples)) + if not self.num_examples: + if self.streaming: + # There's no better way to get the number of examples in a streaming dataset, + # load it again for the moment. + self.num_examples = len( + load_dataset( + self.filetype, data_files=self.data_files, split=self.split + ) + ) + else: + self.num_examples = len(self._dataset) + + @staticmethod + def _prepare_data_files( + data_path: UPath, + ) -> Tuple[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]], str]: + """Prepare the loading process by setting the `data_files` attribute. + + Args: + data_path: The path to the data files, or directory containing the data files. + + Returns: + Tuple with the data files and the filetype. + """ + + def get_filetype(data_path: UPath) -> str: + filetype = data_path.suffix.lstrip(".") + if filetype == "jsonl": + filetype = "json" + return filetype + + if data_path.is_file(): + filetype = get_filetype(data_path) + data_files = str(data_path) + elif data_path.is_dir(): + file_sequence = [] + file_map = defaultdict(list) + for file_or_folder in data_path.iterdir(): + if file_or_folder.is_file(): + file_sequence.append(str(file_or_folder)) + elif file_or_folder.is_dir(): + for file in file_or_folder.iterdir(): + file_sequence.append(str(file)) + file_map[str(file_or_folder)].append(str(file)) + + data_files = file_sequence or file_map + # Try to obtain the filetype from any of the files, assuming all files have the same type. + if file_sequence: + filetype = get_filetype(UPath(file_sequence[0])) + else: + filetype = get_filetype(UPath(file_map[list(file_map.keys())[0]][0])) + return data_files, filetype + + @property + def outputs(self) -> List[str]: + """The columns that will be generated by this step, based on the datasets from a file + in disk. + + Returns: + The columns that will be generated by this step. + """ + # We assume there are Dataset/IterableDataset, not it's ...Dict counterparts + if self._dataset is Ellipsis: + raise ValueError( + "Dataset not loaded yet, you must call `load` method first." + ) + + return self._dataset.column_names + + +class LoadDataFromDisk(LoadDataFromHub): + """Load a dataset that was previously saved to disk. + + If you previously saved your dataset using the `save_to_disk` method, or + `Distiset.save_to_disk` you can load it again to build a new pipeline using this class. + + Attributes: + dataset_path: The path to the dataset or distiset. + split: The split of the dataset to load (typically will be `train`, `test` or `validation`). + config: The configuration of the dataset to load. This is optional and only needed + if the dataset has multiple configurations. + + Runtime parameters: + - `batch_size`: The batch size to use when processing the data. + - `dataset_path`: The path to the dataset or distiset. + - `is_distiset`: Whether the dataset to load is a `Distiset` or not. Defaults to False. + - `split`: The split of the dataset to load. 
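`_prepare_data_files` above infers which `datasets` builder to use from the file extension, treating `.jsonl` as the `json` builder, unless `filetype` is given explicitly, and collects files from a directory into a sequence or a per-folder mapping. The suffix-based inference on its own looks roughly like this simplified sketch, with `pathlib.Path` standing in for `universal_pathlib`'s `UPath`:

```python
from pathlib import Path
from typing import Optional


def infer_filetype(data_file: Path, explicit: Optional[str] = None) -> str:
    """Map a file extension to a Hugging Face `datasets` builder name."""
    if explicit is not None:
        return explicit
    filetype = data_file.suffix.lstrip(".")
    return "json" if filetype == "jsonl" else filetype


print(infer_filetype(Path("data/dataset.jsonl")))                 # json
print(infer_filetype(Path("data/dataset.csv")))                   # csv
print(infer_filetype(Path("data/dataset.txtr"), explicit="csv"))  # csv
```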
Defaults to 'train'. + - `config`: The configuration of the dataset to load. This is optional and only + needed if the dataset has multiple configurations. + - `num_examples`: The number of examples to load from the dataset. + By default will load all examples. + - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any. + Defaults to `None`. + + Output columns: + - dynamic (`all`): The columns that will be generated by this step, based on the + datasets loaded from the Hugging Face Hub. + + Categories: + - load + + Examples: + + Load data from a Hugging Face Dataset: + + ```python + from distilabel.steps import LoadDataFromDisk + + loader = LoadDataFromDisk(dataset_path="path/to/dataset") + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. + result = next(loader.process()) + # >>> result + # ([{'type': 'function', 'function':...', False) + ``` + + Load data from a distilabel Distiset: - if "HF_TOKEN" in os.environ: - headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"} - else: - headers = None + ```python + from distilabel.steps import LoadDataFromDisk - response = requests.get( - "https://datasets-server.huggingface.co/info", params=params, headers=headers + # Specify the configuration to load. + loader = LoadDataFromDisk( + dataset_path="path/to/dataset", + is_distiset=True, + config="leaf_step_1" + ) + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. + result = next(loader.process()) + # >>> result + # ([{'a': 1}, {'a': 2}, {'a': 3}], True) + ``` + + Load data from a Hugging Face Dataset or Distiset in your cloud provider: + + ```python + from distilabel.steps import LoadDataFromDisk + + loader = LoadDataFromDisk( + dataset_path="gcs://path/to/dataset", + storage_options={"project": "experiments-0001"} + ) + loader.load() + + # Just like we saw with LoadDataFromDicts, the `process` method will yield batches. + result = next(loader.process()) + # >>> result + # ([{'type': 'function', 'function':...', False) + ``` + """ + + dataset_path: RuntimeParameter[Union[str, Path]] = Field( + default=None, + description="Path to the dataset or distiset.", + ) + config: RuntimeParameter[str] = Field( + default=None, + description="The configuration of the dataset to load. This is optional and only" + " needed if the dataset has multiple configurations.", + ) + is_distiset: Optional[RuntimeParameter[bool]] = Field( + default=False, + description="Whether the dataset to load is a `Distiset` or not. Defaults to False.", + ) + keep_in_memory: Optional[RuntimeParameter[bool]] = Field( + default=None, + description="Whether to copy the dataset in-memory, see `datasets.Dataset.load_from_disk` " + " for more information. Defaults to `None`.", + ) + split: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The split of the dataset to load. By default will load the whole Dataset/Distiset.", ) - assert ( - response.status_code == 200 - ), f"Failed to get '{repo_id}' dataset info. Make sure you have set the HF_TOKEN environment variable if it is a private dataset." 
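Taken together with the docstring above, the `save_to_disk` / `LoadDataFromDisk` round trip can be sketched end to end. This is a minimal, hypothetical example: the column names, rows and the `my-dataset` path are placeholders, and only the calls already shown in the docstrings (`load()` and `process()`) are used.

```python
from datasets import Dataset

from distilabel.steps import LoadDataFromDisk

# Save a plain `datasets.Dataset` to disk (placeholder data and path).
Dataset.from_dict(
    {
        "instruction": ["a question", "another question"],
        "generation": ["an answer", "another answer"],
    }
).save_to_disk("my-dataset")

# Load it back as a generator step for a new pipeline.
loader = LoadDataFromDisk(dataset_path="my-dataset", batch_size=2)
loader.load()

# As in the docstring examples, `process` yields (batch, is_last_batch) tuples.
batch, is_last_batch = next(loader.process())
# batch
# [{'instruction': 'a question', 'generation': 'an answer'},
#  {'instruction': 'another question', 'generation': 'another answer'}]
```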
+ def load(self) -> None: + """Load the dataset from the file/s in disk.""" + super(GeneratorStep, self).load() + if self.is_distiset: + ds = Distiset.load_from_disk( + self.dataset_path, + keep_in_memory=self.keep_in_memory, + storage_options=self.storage_options, + ) + if self.config: + ds = ds[self.config] + + else: + ds = load_from_disk( + self.dataset_path, + keep_in_memory=self.keep_in_memory, + storage_options=self.storage_options, + ) + + if self.split: + ds = ds[self.split] + + self._dataset = ds + + if self.num_examples: + self._dataset = self._dataset.select(range(self.num_examples)) + else: + self.num_examples = len(self._dataset) + + @property + def outputs(self) -> List[str]: + """The columns that will be generated by this step, based on the datasets from a file + in disk. + + Returns: + The columns that will be generated by this step. + """ + # We assume there are Dataset/IterableDataset, not it's ...Dict counterparts + if self._dataset is Ellipsis: + raise ValueError( + "Dataset not loaded yet, you must call `load` method first." + ) - return response.json()["dataset_info"] + return self._dataset.column_names diff --git a/src/distilabel/steps/globals/huggingface.py b/src/distilabel/steps/globals/huggingface.py index c5b95a3b7d..28ef3932bd 100644 --- a/src/distilabel/steps/globals/huggingface.py +++ b/src/distilabel/steps/globals/huggingface.py @@ -56,6 +56,30 @@ class PushToHub(GlobalStep): - save - dataset - huggingface + + Examples: + + Push batches of your dataset to the Hugging Face Hub repository: + + ```python + from distilabel.steps import PushToHub + + push = PushToHub(repo_id="path_to/repo") + push.load() + + result = next( + push.process( + [ + { + "instruction": "instruction ", + "generation": "generation" + } + ], + ) + ) + # >>> result + # [{'instruction': 'instruction ', 'generation': 'generation'}] + ``` """ repo_id: RuntimeParameter[str] = Field( diff --git a/src/distilabel/steps/keep.py b/src/distilabel/steps/keep.py index 1f6f9bab88..58380660fa 100644 --- a/src/distilabel/steps/keep.py +++ b/src/distilabel/steps/keep.py @@ -43,6 +43,27 @@ class KeepColumns(Step): Output columns: - dynamic (determined by `columns` attribute): The columns that were kept. 
+ + Examples: + + Select the columns to keep: + + ```python + from distilabel.steps import KeepColumns + + keep_columns = KeepColumns( + columns=["instruction", "generation"], + ) + keep_columns.load() + + result = next( + keep_columns.process( + [{"instruction": "What's the brightest color?", "generation": "white", "model_name": "my_model"}], + ) + ) + # >>> result + # [{'instruction': "What's the brightest color?", 'generation': 'white'}] + ``` """ columns: List[str] diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 9fcb882c15..b2456d7824 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -23,6 +23,15 @@ from distilabel.steps.tasks.evol_quality.base import EvolQuality from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings from distilabel.steps.tasks.genstruct import Genstruct +from distilabel.steps.tasks.improving_text_embeddings import ( + BitextRetrievalGenerator, + EmbeddingTaskGenerator, + GenerateLongTextMatchingData, + GenerateShortTextMatchingData, + GenerateTextClassificationData, + GenerateTextRetrievalData, + MonolingualTripletGenerator, +) from distilabel.steps.tasks.instruction_backtranslation import ( InstructionBacktranslation, ) @@ -30,16 +39,15 @@ from distilabel.steps.tasks.prometheus_eval import PrometheusEval from distilabel.steps.tasks.quality_scorer import QualityScorer from distilabel.steps.tasks.self_instruct import SelfInstruct +from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair +from distilabel.steps.tasks.structured_generation import StructuredGeneration from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback __all__ = [ - "Task", "GeneratorTask", - "ChatGeneration", - "ChatItem", - "ChatType", + "Task", "ComplexityScorer", "EvolInstruct", "EvolComplexity", @@ -48,11 +56,23 @@ "EvolQuality", "GenerateEmbeddings", "Genstruct", + "BitextRetrievalGenerator", + "EmbeddingTaskGenerator", + "GenerateLongTextMatchingData", + "GenerateShortTextMatchingData", + "GenerateTextClassificationData", + "GenerateTextRetrievalData", + "MonolingualTripletGenerator", "InstructionBacktranslation", "PairRM", "PrometheusEval", "QualityScorer", "SelfInstruct", + "GenerateSentencePair", + "StructuredGeneration", + "ChatGeneration", "TextGeneration", + "ChatItem", + "ChatType", "UltraFeedback", ] diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index b2a4734879..06a6fecd06 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from pydantic import Field +from typing_extensions import override from distilabel.llms.base import LLM from distilabel.mixins.runtime_parameters import RuntimeParameter @@ -25,17 +26,15 @@ StepInput, _Step, ) +from distilabel.steps.constants import DISTILABEL_METADATA_KEY from distilabel.utils.dicts import combine_dicts if TYPE_CHECKING: from distilabel.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.tasks.typing import FormattedInput from distilabel.steps.typing import StepOutput -DISTILABEL_METADATA_KEY = "distilabel_metadata" - - class _Task(_Step, ABC): """_Task is an abstract class that implements the `_Step` interface and adds the `format_input` and `format_output` 
methods to format the inputs and outputs of the @@ -54,19 +53,33 @@ class _Task(_Step, ABC): llm: LLM group_generations: bool = False - add_raw_output: bool = False + add_raw_output: RuntimeParameter[bool] = Field( + default=True, + description=( + "Whether to include the raw output of the LLM in the key `raw_output_`" + " of the `distilabel_metadata` dictionary output column" + ), + ) num_generations: RuntimeParameter[int] = Field( default=1, description="The number of generations to be produced per input." ) def load(self) -> None: - """Loads the LLM via the `LLM.load()` method (done for safer serialization).""" + """Loads the LLM via the `LLM.load()` method.""" super().load() self.llm.load() + @override + def unload(self) -> None: + """Unloads the LLM.""" + self._logger.debug("Executing task unload logic.") + self.llm.unload() + @abstractmethod def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, ) -> Dict[str, Any]: """Abstract method to format the outputs of the task. It needs to receive an output as a string, and generates a Python dictionary with the outputs of the task. In @@ -76,7 +89,9 @@ def format_output( pass def _format_outputs( - self, outputs: "GenerateOutput", inputs: List[Dict[str, Any]] + self, + outputs: "GenerateOutput", + inputs: Union[List[Dict[str, Any]], None] = None, ) -> List[Dict[str, Any]]: """Formats the outputs of the task using the `format_output` method. If the output is `None` (i.e. the LLM failed to generate a response), then the outputs will be @@ -89,12 +104,17 @@ def _format_outputs( Returns: A list containing a dictionary with the outputs of the task for each input. """ + if inputs is None: + inputs = [None] # type: ignore + formatted_outputs = [] - for output, input in zip(outputs, inputs * len(outputs)): + for output, input in zip(outputs, inputs * len(outputs)): # type: ignore try: formatted_output = self.format_output(output, input) formatted_output = self._maybe_add_raw_output( - formatted_output, output, add_raw_output=self.add_raw_output + formatted_output, + output, + add_raw_output=self.add_raw_output, # type: ignore ) formatted_outputs.append(formatted_output) except Exception as e: @@ -105,16 +125,18 @@ def _format_outputs( return formatted_outputs def _output_on_failure( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """In case of failure to format the output, this method will return a dictionary including a new field `distilabel_meta` with the raw output of the LLM. """ # Create a dictionary with the outputs of the task (every output set to None) outputs = {output: None for output in self.outputs} - outputs["model_name"] = self.llm.model_name + outputs["model_name"] = self.llm.model_name # type: ignore outputs = self._maybe_add_raw_output( - outputs, output, add_raw_output=self.add_raw_output + outputs, + output, + add_raw_output=self.add_raw_output, # type: ignore ) return outputs @@ -144,12 +166,12 @@ class Task(_Task, Step): """ @abstractmethod - def format_input(self, input: Dict[str, Any]) -> "ChatType": + def format_input(self, input: Dict[str, Any]) -> "FormattedInput": """Abstract method to format the inputs of the task. 
It needs to receive an input as a Python dictionary, and generates an OpenAI chat-like list of dicts.""" pass - def _format_inputs(self, inputs: List[Dict[str, Any]]) -> List["ChatType"]: + def _format_inputs(self, inputs: List[Dict[str, Any]]) -> List["FormattedInput"]: """Formats the inputs of the task using the `format_input` method. Args: @@ -172,10 +194,11 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore """ formatted_inputs = self._format_inputs(inputs) + outputs = self.llm.generate( inputs=formatted_inputs, num_generations=self.num_generations, # type: ignore - **self.llm.generation_kwargs, # type: ignore + **self.llm.get_generation_kwargs(), # type: ignore ) task_outputs = [] @@ -185,14 +208,14 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore if self.group_generations: combined = combine_dicts(*formatted_outputs) task_outputs.append( - {**input, "model_name": self.llm.model_name, **combined} + {**input, **combined, "model_name": self.llm.model_name} ) continue # Create a row per generation for formatted_output in formatted_outputs: task_outputs.append( - {**input, "model_name": self.llm.model_name, **formatted_output} + {**input, **formatted_output, "model_name": self.llm.model_name} ) yield task_outputs diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py index 758aaf05d6..b20909383a 100644 --- a/src/distilabel/steps/tasks/complexity_scorer.py +++ b/src/distilabel/steps/tasks/complexity_scorer.py @@ -59,6 +59,32 @@ class ComplexityScorer(Task): References: - [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685) + + Examples: + + Evaluate the complexity of your instructions: + + ```python + from distilabel.steps.tasks import ComplexityScorer + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + scorer = ComplexityScorer( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ) + ) + + scorer.load() + + result = next( + scorer.process( + [{"instructions": ["plain instruction", "highly complex instruction"]}] + ) + ) + # result + # [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 5], 'distilabel_metadata': {'raw_output_complexity_scorer_0': 'output'}}] + ``` """ _template: Union[Template, None] = PrivateAttr(...) diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py index b7c4e2f65a..e0071bf5d9 100644 --- a/src/distilabel/steps/tasks/evol_instruct/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/base.py @@ -69,6 +69,86 @@ class EvolInstruct(Task): References: - [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244) - [GitHub: h2oai/h2o-wizardlm](https://github.com/h2oai/h2o-wizardlm) + + Examples: + + Evolve an instruction using an LLM: + + ```python + from distilabel.steps.tasks import EvolInstruct + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. 
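+    # `num_evolutions=2` below applies two evolution rounds in sequence to each
+    # input instruction; any other distilabel `LLM` could be used in place of the endpoint.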
+ evol_instruct = EvolInstruct( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_evolutions=2, + ) + + evol_instruct.load() + + result = next(evol_instruct.process([{"instruction": "common instruction"}])) + # result + # [{'instruction': 'common instruction', 'evolved_instruction': 'evolved instruction', 'model_name': 'model_name'}] + ``` + + Keep the iterations of the evolutions: + + ```python + from distilabel.steps.tasks import EvolInstruct + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + evol_instruct = EvolInstruct( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_evolutions=2, + store_evolutions=True, + ) + + evol_instruct.load() + + result = next(evol_instruct.process([{"instruction": "common instruction"}])) + # result + # [ + # { + # 'instruction': 'common instruction', + # 'evolved_instructions': ['initial evolution', 'final evolution'], + # 'model_name': 'model_name' + # } + # ] + ``` + + Generate answers for the instructions in a single step: + + ```python + from distilabel.steps.tasks import EvolInstruct + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + evol_instruct = EvolInstruct( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_evolutions=2, + generate_answers=True, + ) + + evol_instruct.load() + + result = next(evol_instruct.process([{"instruction": "common instruction"}])) + # result + # [ + # { + # 'instruction': 'common instruction', + # 'evolved_instruction': 'evolved instruction', + # 'answer': 'answer to the instruction', + # 'model_name': 'model_name' + # } + # ] + ``` """ num_evolutions: int diff --git a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py index 75161eb70c..c07f49d621 100644 --- a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py @@ -24,7 +24,7 @@ class EvolComplexity(EvolInstruct): """Evolve instructions to make them more complex using an `LLM`. `EvolComplexity` is a task that evolves instructions to make them more complex, - and it is based in the EvolInstruct task, but using slight different prompts, but the + and it is based in the EvolInstruct task, using slight different prompts, but the exact same evolutionary approach. Attributes: @@ -61,6 +61,29 @@ class EvolComplexity(EvolInstruct): References: - [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) - [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244) + + Examples: + + Evolve an instruction using an LLM: + + ```python + from distilabel.steps.tasks import EvolComplexity + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. 
+ evol_complexity = EvolComplexity( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_evolutions=2, + ) + + evol_complexity.load() + + result = next(evol_complexity.process([{"instruction": "common instruction"}])) + # result + # [{'instruction': 'common instruction', 'evolved_instruction': 'evolved instruction', 'model_name': 'model_name'}] + ``` """ mutation_templates: Dict[str, str] = MUTATION_TEMPLATES diff --git a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py index 8fc1f35db8..c4cd051190 100644 --- a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py @@ -59,6 +59,29 @@ class EvolComplexityGenerator(EvolInstructGenerator): References: - [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) - [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244) + + Examples: + + Generate evolved instructions without initial instructions: + + ```python + from distilabel.steps.tasks import EvolComplexityGenerator + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + evol_complexity_generator = EvolComplexityGenerator( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_instructions=2, + ) + + evol_complexity_generator.load() + + result = next(scorer.process()) + # result + # [{'instruction': 'generated instruction', 'model_name': 'test'}] + ``` """ mutation_templates: Dict[str, str] = GENERATION_MUTATION_TEMPLATES diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index 8d4adc38b5..bc8c5d2eff 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -75,6 +75,29 @@ class EvolInstructGenerator(GeneratorTask): References: - [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244) - [GitHub: h2oai/h2o-wizardlm](https://github.com/h2oai/h2o-wizardlm) + + Examples: + + Generate evolved instructions without initial instructions: + + ```python + from distilabel.steps.tasks import EvolInstructGenerator + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + evol_instruct_generator = EvolInstructGenerator( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_instructions=2, + ) + + evol_instruct_generator.load() + + result = next(scorer.process()) + # result + # [{'instruction': 'generated instruction', 'model_name': 'test'}] + ``` """ num_instructions: int diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py index 931d994912..b245a879ff 100644 --- a/src/distilabel/steps/tasks/evol_quality/base.py +++ b/src/distilabel/steps/tasks/evol_quality/base.py @@ -65,6 +65,42 @@ class EvolQuality(Task): References: - [`What Makes Good Data for Alignment? 
A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685) + + Examples: + + Evolve the quality of the responses given a prompt: + + ```python + from distilabel.steps.tasks import EvolQuality + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + evol_quality = EvolQuality( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_evolutions=2, + ) + + evol_quality.load() + + result = next( + evol_quality.process( + [ + {"instruction": "common instruction", "response": "a response"}, + ] + ) + ) + # result + # [ + # { + # 'instruction': 'common instruction', + # 'response': 'a response', + # 'evolved_response': 'evolved response', + # 'model_name': '"mistralai/Mistral-7B-Instruct-v0.2"' + # } + # ] + ``` """ num_evolutions: int @@ -149,7 +185,7 @@ def _apply_random_mutation(self, instruction: str, response: str) -> str: return ( self.mutation_templates[mutation] .replace("", instruction) - .replace("", response[-1]) + .replace("", response) ) def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: diff --git a/src/distilabel/steps/tasks/generate_embeddings.py b/src/distilabel/steps/tasks/generate_embeddings.py index 39c17f016e..1b0df634c6 100644 --- a/src/distilabel/steps/tasks/generate_embeddings.py +++ b/src/distilabel/steps/tasks/generate_embeddings.py @@ -47,6 +47,33 @@ class GenerateEmbeddings(Step): References: - [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) + + Examples: + + Rank LLM candidates: + + ```python + from distilabel.steps.tasks import GenerateEmbeddings + from distilabel.llms.huggingface import TransformersLLM + + # Consider this as a placeholder for your actual LLM. + embedder = GenerateEmbeddings( + llm=TransformersLLM( + model="TaylorAI/bge-micro-v2", + model_kwargs={"is_decoder": True}, + cuda_devices=[], + ) + ) + embedder.load() + + result = next( + embedder.process( + [ + {"text": "Hello, how are you?"}, + ] + ) + ) + ``` """ llm: LLM diff --git a/src/distilabel/steps/tasks/genstruct.py b/src/distilabel/steps/tasks/genstruct.py index 550e1220d5..1e80fcb429 100644 --- a/src/distilabel/steps/tasks/genstruct.py +++ b/src/distilabel/steps/tasks/genstruct.py @@ -67,6 +67,42 @@ class Genstruct(Task): References: - [Genstruct 7B by Nous Research](https://huggingface.co/NousResearch/Genstruct-7B) - [Ada-Instruct: Adapting Instruction Generators for Complex Reasoning](https://arxiv.org/abs/2310.04484) + + Examples: + + Generate instructions from raw documents using the title and content: + + ```python + from distilabel.steps.tasks import Genstruct + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + genstruct = Genstruct( + llm=InferenceEndpointsLLM( + model_id="NousResearch/Genstruct-7B", + ), + ) + + genstruct.load() + + result = next( + genstruct.process( + [ + {"title": "common instruction", "content": "content of the document"}, + ] + ) + ) + # result + # [ + # { + # 'title': 'An instruction', + # 'content': 'content of the document', + # 'model_name': 'test', + # 'user': 'An instruction', + # 'assistant': 'content of the document', + # } + # ] + ``` """ _template: Union[Template, None] = PrivateAttr(...) 
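The `Genstruct` snippet above runs the task standalone; inside a `Pipeline` it would typically follow a loader that supplies the `title` and `content` columns. The following is a minimal sketch under the same assumptions as the docstring (the model ID and the document rows are placeholders):

```python
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import Genstruct

with Pipeline("genstruct-sketch") as pipeline:
    # Placeholder documents: each row needs a `title` and a `content` column.
    load_documents = LoadDataFromDicts(
        data=[
            {"title": "A document title", "content": "The content of the document."},
        ]
    )
    genstruct = Genstruct(
        llm=InferenceEndpointsLLM(model_id="NousResearch/Genstruct-7B"),
    )
    # `Genstruct` adds `user` and `assistant` columns with the generated
    # instruction/response pair for each input document.
    load_documents >> genstruct

distiset = pipeline.run()
```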
diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py new file mode 100644 index 0000000000..0e91354274 --- /dev/null +++ b/src/distilabel/steps/tasks/improving_text_embeddings.py @@ -0,0 +1,941 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import re +import sys +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional, Union + +if sys.version_info < (3, 9): + import importlib_resources +else: + import importlib.resources as importlib_resources + +from jinja2 import Template +from pydantic import Field, PrivateAttr +from typing_extensions import override + +from distilabel.steps.tasks.base import GeneratorTask, Task +from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.typing import GeneratorStepOutput + + +# BASE CLASSES +class _JSONFormatter(ABC): + """Abstract class that sets the `outputs` property and `format_output` method, assuming + that the output is a JSON string with the keys specified in the `keys` property. So on, + this class is intended to be used whenever we get a JSON string as the `LLM` output with + a set of `keys` we know are there. + + Note: + At the moment this abstract class is only intended to be used for the tasks defined + below based on the output generated by those. Also note that this is not a replacement + for neither the `StructuredGeneration` task nor for the `structured_output` argument + of an `LLM` subclass. + """ + + @property + @abstractmethod + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + ... + + @property + def outputs(self) -> List[str]: + """Contains the output columns produced by the `process` method of the task. In this + case, it consists of the `keys` (i.e. the JSON keys) and the `model_name`. + """ + return self.keys + ["model_name"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + """Method to parse the JSON output into a Python dictionary based on the `keys` property. + + Args: + output: The JSON string output produced by the `LLM`. + input: The input dictionary that was used to generate the output. + + Returns: + A Python dictionary with the parsed output based on the `keys` property. + """ + if output is None: + return {key: None for key in self.keys} + + def escape_backslashes_in_values(s): + # Regular expression to match the key-value pairs in the dictionary + pattern = re.compile(r'(".*?":\s*")(.*?)(",?)', re.DOTALL) + + def replace_backslashes(match): + return ( + match.group(1) + + re.sub( + r"(? 
None: + """Loads the Jinja2 template and sets the random seed.""" + super().load() + + random.seed(self.seed) + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "improving_text_embeddings" + / f"{self._template_name}.jinja2" # type: ignore + ) + + self._template = Template(open(_path).read()) + + @property + def inputs(self) -> List[str]: + """Contains the input columns expected by the `process` method of the task. In this + case, it consists of the `task`; ideally produced in a previous task which should be + preferrably `EmbeddingTaskGenerator` (as per the original implementation).""" + return ["task"] + + +class _EmbeddingDataGenerator(_JSONFormatter, GeneratorTask, ABC): + """Base class for the subtasks related to embedding data generation as presented in the + paper "Improving Text Embeddings with Large Language Models" that generate data without + an input i.e. `GeneratorStep` or `GeneratorTask`. This class includes a pre-defined `load` + method to load a Jinja2 template based on the `_template_name` private attribute (to be set + in each of the subclasses), assuming that the `prompt` property only expects the `task`, while + keeping the `format_input` as an abstract method to be implemented in the subclasses. + + Attributes: + seed: The random seed to be set in case there's any sampling within the `format_input` method. + _template: The Jinja2 template to be rendered within the `format_input` method with the + provided arguments. + _template_name: The name of the Jinja2 template file within the + `distilabel/steps/tasks/templates/improving_text_embeddings` directory. + """ + + seed: int = 42 + + _template: Union[Template, None] = PrivateAttr(...) + _template_name: str = PrivateAttr(...) + + def load(self) -> None: + """Loads the Jinja2 template and sets the random seed.""" + super().load() + + random.seed(self.seed) + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "improving_text_embeddings" + / f"{self._template_name}.jinja2" # type: ignore + ) + + self._template = Template(open(_path).read()) + + @property + @abstractmethod + def prompt(self) -> ChatType: + """The prompt to be used for the generation step, ideally rendering the `_template`.""" + ... + + @override + def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore + """Method to run the `LLM` generation with the `prompt`, as well as formatting the + outputs accordingly for the task i.e. via the `_JSONFormatter` inheritance. So on, the + `LLM` ideally will be prompted to produce JSON content and then the `format_output` + method will parse it into a Python dictionary based on the `keys` property. + + Args: + offset: The offset to start the generation from. Defaults to 0. + + Yields: + The output rows and a boolean indicating if it's the last batch or not. + """ + formatted_inputs = [self.prompt] + outputs = self.llm.generate( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.generation_kwargs, # type: ignore + ) + + task_outputs = [] + for input_outputs in outputs: + formatted_outputs = self._format_outputs(input_outputs) # type: ignore + for formatted_output in formatted_outputs: + task_outputs.append( + { + **formatted_output, + "model_name": self.llm.model_name, + } + ) + yield task_outputs, True + + +# IMPLEMENTED TASKS +class EmbeddingTaskGenerator(GeneratorTask): + """Generate task descriptions for embedding-related tasks using an `LLM`. 
+ + `EmbeddingTaskGenerator` is a `GeneratorTask` that doesn't receieve any input besides the + provided attributes that generates task descriptions for embedding-related tasks using a + pre-defined prompt based on the `category` attribute. The `category` attribute should be + one of the following: + + - `text-retrieval`: Generate task descriptions for text retrieval tasks. + - `text-matching-short`: Generate task descriptions for short text matching tasks. + - `text-matching-long`: Generate task descriptions for long text matching tasks. + - `text-classification`: Generate task descriptions for text classification tasks. + + Attributes: + category: The category of the task to be generated, which can either be `text-retrieval`, + `text-matching-short`, `text-matching-long`, or `text-classification`. + flatten_tasks: Whether to flatten the tasks i.e. since a list of tasks is generated by the + `LLM`, this attribute indicates whether to flatten the list or not. Defaults to `False`, + meaning that running this task with `num_generations=1` will return a `distilabel.Distiset` + with one row only containing a list with around 20 tasks; otherwise, if set to `True`, it + will return a `distilabel.Distiset` with around 20 rows, each containing one task. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate embedding tasks for text retrieval: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-retrieval", + flatten_tasks=True, + llm=..., # LLM instance + ) + + ... + + task >> ... + ``` + """ + + category: Literal[ + "text-retrieval", + "text-matching-short", + "text-matching-long", + "text-classification", + ] + flatten_tasks: bool = False + + _template: Union[Template, None] = PrivateAttr(...) + + def load(self) -> None: + """Loads the Jinja2 template.""" + super().load() + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "improving_text_embeddings" + / "brainstorming" + / f"{self.category}.jinja2" + ) + + self._template = Template(open(_path).read()) + + @property + def prompt(self) -> ChatType: # type: ignore + """The prompt to be used in the `process` method, rendering the `_template` with the + provided args / attributes. + """ + return [{"role": "user", "content": self._template.render().strip()}] # type: ignore + + @override + def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore + """Method to run the `LLM` generation with the `prompt`, as well as formatting the + outputs accordingly for the task i.e. via the `_JSONFormatter` inheritance. So on, the + `LLM` ideally will be prompted to produce JSON content and then the `format_output` + method will parse it into a Python dictionary based on the `keys` property. + + Args: + offset: The offset to start the generation from. Defaults to 0. + + Yields: + The output rows and a boolean indicating if it's the last batch or not. 
+ """ + formatted_inputs = [self.prompt] + outputs = self.llm.generate( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.generation_kwargs, # type: ignore + ) + + task_outputs = [] + for input_outputs in outputs: + formatted_outputs = self._format_outputs(input_outputs) # type: ignore + for formatted_output in formatted_outputs: + if isinstance(formatted_output["tasks"], list) and self.flatten_tasks: + tasks = formatted_output.pop("tasks") + task_outputs.extend( + [ + { + "task": task, + **formatted_output, + "model_name": self.llm.model_name, + } + for task in tasks + ] + ) + else: + if self.flatten_tasks: + formatted_output["task"] = formatted_output.pop("tasks") + task_outputs.append( + {**formatted_output, "model_name": self.llm.model_name} + ) + yield task_outputs, True + + @property + def outputs(self) -> List[str]: + """Contains the output columns produced by the `process` method of the task. In this + case, it consists of the `tasks` or `task` (depending on the `flatten_tasks` attribute) + and the `model_name`. + """ + return ["tasks" if not self.flatten_tasks else "task", "model_name"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + """Method to parse the JSON output into a Python dictionary based on the `keys` property. + + Args: + output: The JSON string output produced by the `LLM`. + input: The input dictionary that was used to generate the output. + + Returns: + A Python dictionary with the parsed output based on the `keys` property. + """ + try: + if output is not None: + output = eval(output) + except Exception: + pass + return {"tasks": output} + + +class GenerateTextRetrievalData(_EmbeddingDataGeneration): + """Generate text retrieval data with an `LLM` to later on train an embedding model. + + `GenerateTextRetrievalData` is a `Task` that generates text retrieval data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-retrieval"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-retrieval category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + query_type: The type of query to be generated, which can be `extremely long-tail`, `long-tail`, + or `common`. Defaults to `None`, meaning that it will be randomly sampled. + query_length: The length of the query to be generated, which can be `less than 5 words`, `5 to 15 words`, + or `at least 10 words`. Defaults to `None`, meaning that it will be randomly sampled. + difficulty: The difficulty of the query to be generated, which can be `high school`, `college`, or `PhD`. + Defaults to `None`, meaning that it will be randomly sampled. + clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`, + or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled. + num_words: The number of words in the query to be generated, which can be `50`, `100`, `200`, `300`, `400`, or `500`. 
+ Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic text retrieval data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextRetrievalData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-retrieval", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateTextRetrievalData( + language="English", + query_type="common", + query_length="5 to 15 words", + difficulty="high school", + clarity="clear", + num_words=100, + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + query_type: Optional[Literal["extremely long-tail", "long-tail", "common"]] = None + query_length: Optional[ + Literal["less than 5 words", "5 to 15 words", "at least 10 words"] + ] = None + difficulty: Optional[Literal["high school", "college", "PhD"]] = None + clarity: Optional[ + Literal["clear", "understandable with some effort", "ambiguous"] + ] = None + num_words: Optional[Literal[50, 100, 200, 300, 400, 500]] = None + + _template_name: str = PrivateAttr(default="text-retrieval") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + query_type=self.query_type + or random.choice(["extremely long-tail", "long-tail", "common"]), + query_length=self.query_length + or random.choice( + ["less than 5 words", "5 to 15 words", "at least 10 words"] + ), + difficulty=self.difficulty + or random.choice(["high school", "college", "PhD"]), + clarity=self.clarity + or random.choice( + ["clear", "understandable with some effort", "ambiguous"] + ), + num_words=self.num_words + or random.choice([50, 100, 200, 300, 400, 500]), + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return [ + "user_query", + "positive_document", + "hard_negative_document", + ] + + +class GenerateShortTextMatchingData(_EmbeddingDataGeneration): + """Generate short text matching data with an `LLM` to later on train an embedding model. + + `GenerateShortTextMatchingData` is a `Task` that generates short text matching data with an + `LLM` to later on train an embedding model. 
The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-matching-short"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-matching-short category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + Note that in this task the `seed` has no effect since there are no sampling params. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic short text matching data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateShortTextMatchingData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-matching-short", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateShortTextMatchingData( + language="English", + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + _template_name: str = PrivateAttr(default="short-text-matching") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["input", "positive_document"] + + +class GenerateLongTextMatchingData(_EmbeddingDataGeneration): + """Generate long text matching data with an `LLM` to later on train an embedding model. + + `GenerateLongTextMatchingData` is a `Task` that generates long text matching data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-matching-long"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-matching-long category. 
+ + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + Note that in this task the `seed` has no effect since there are no sampling params. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic long text matching data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateLongTextMatchingData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-matching-long", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateLongTextMatchingData( + language="English", + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + _template_name: str = PrivateAttr(default="long-text-matching") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["input", "positive_document"] + + +class GenerateTextClassificationData(_EmbeddingDataGeneration): + """Generate text classification data with an `LLM` to later on train an embedding model. + + `GenerateTextClassificationData` is a `Task` that generates text classification data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-classification"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-classification category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + difficulty: The difficulty of the query to be generated, which can be `high school`, `college`, or `PhD`. + Defaults to `None`, meaning that it will be randomly sampled. + clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`, + or `ambiguous`. 
Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic text classification data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextClassificationData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-classification", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateTextClassificationData( + language="English", + difficulty="high school", + clarity="clear", + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + difficulty: Optional[Literal["high school", "college", "PhD"]] = None + clarity: Optional[ + Literal["clear", "understandable with some effort", "ambiguous"] + ] = None + + _template_name: str = PrivateAttr(default="text-classification") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + difficulty=self.difficulty + or random.choice(["high school", "college", "PhD"]), + clarity=self.clarity + or random.choice( + ["clear", "understandable with some effort", "ambiguous"] + ), + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["input_text", "label", "misleading_label"] + + +class MonolingualTripletGenerator(_EmbeddingDataGenerator): + """Generate monolingual triplets with an `LLM` to later on train an embedding model. + + `MonolingualTripletGenerator` is a `GeneratorTask` that generates monolingual triplets with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + unit: The unit of the data to be generated, which can be `sentence`, `phrase`, or `passage`. + Defaults to `None`, meaning that it will be randomly sampled. + difficulty: The difficulty of the query to be generated, which can be `elementary school`, `high school`, or `college`. + Defaults to `None`, meaning that it will be randomly sampled. 
+ high_score: The high score of the query to be generated, which can be `4`, `4.5`, or `5`. + Defaults to `None`, meaning that it will be randomly sampled. + low_score: The low score of the query to be generated, which can be `2.5`, `3`, or `3.5`. + Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + Examples: + + Generate monolingual triplets for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import MonolingualTripletGenerator + + with Pipeline("my-pipeline") as pipeline: + task = MonolingualTripletGenerator( + language="English", + unit="sentence", + difficulty="elementary school", + high_score="4", + low_score="2.5", + llm=..., + ) + + ... + + task >> ... + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + unit: Optional[Literal["sentence", "phrase", "passage"]] = None + difficulty: Optional[Literal["elementary school", "high school", "college"]] = None + high_score: Optional[Literal["4", "4.5", "5"]] = None + low_score: Optional[Literal["2.5", "3", "3.5"]] = None + + _template_name: str = PrivateAttr(default="monolingual-triplet") + + @property + def prompt(self) -> ChatType: + """Contains the `prompt` to be used in the `process` method, rendering the `_template`; and + formatted as an OpenAI formatted chat i.e. a `ChatType`, assuming that there's only one turn, + being from the user with the content being the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + language=self.language, + unit=self.unit or random.choice(["sentence", "phrase", "passage"]), + difficulty=self.difficulty + or random.choice(["elementary school", "high school", "college"]), + high_score=self.high_score or random.choice(["4", "4.5", "5"]), + low_score=self.low_score or random.choice(["2.5", "3", "3.5"]), + ).strip(), + } + ] # type: ignore + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["S1", "S2", "S3"] + + +class BitextRetrievalGenerator(_EmbeddingDataGenerator): + """Generate bitext retrieval data with an `LLM` to later on train an embedding model. + + `BitextRetrievalGenerator` is a `GeneratorTask` that generates bitext retrieval data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Attributes: + source_language: The source language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + target_language: The target language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + unit: The unit of the data to be generated, which can be `sentence`, `phrase`, or `passage`. + Defaults to `None`, meaning that it will be randomly sampled. + difficulty: The difficulty of the query to be generated, which can be `elementary school`, `high school`, or `college`. 
+ Defaults to `None`, meaning that it will be randomly sampled. + high_score: The high score of the query to be generated, which can be `4`, `4.5`, or `5`. + Defaults to `None`, meaning that it will be randomly sampled. + low_score: The low score of the query to be generated, which can be `2.5`, `3`, or `3.5`. + Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + Examples: + + Generate bitext retrieval data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import BitextRetrievalGenerator + + with Pipeline("my-pipeline") as pipeline: + task = BitextRetrievalGenerator( + source_language="English", + target_language="Spanish", + unit="sentence", + difficulty="elementary school", + high_score="4", + low_score="2.5", + llm=..., + ) + + ... + + task >> ... + ``` + """ + + source_language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + target_language: str = Field( + default=..., + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + unit: Optional[Literal["sentence", "phrase", "passage"]] = None + difficulty: Optional[Literal["elementary school", "high school", "college"]] = None + high_score: Optional[Literal["4", "4.5", "5"]] = None + low_score: Optional[Literal["2.5", "3", "3.5"]] = None + + _template_name: str = PrivateAttr(default="bitext-retrieval") + + @property + def prompt(self) -> ChatType: + """Contains the `prompt` to be used in the `process` method, rendering the `_template`; and + formatted as an OpenAI formatted chat i.e. a `ChatType`, assuming that there's only one turn, + being from the user with the content being the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + source_language=self.source_language, + target_language=self.target_language, + unit=self.unit or random.choice(["sentence", "phrase", "passage"]), + difficulty=self.difficulty + or random.choice(["elementary school", "high school", "college"]), + high_score=self.high_score or random.choice(["4", "4.5", "5"]), + low_score=self.low_score or random.choice(["2.5", "3", "3.5"]), + ).strip(), + } + ] # type: ignore + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["S1", "S2", "S3"] diff --git a/src/distilabel/steps/tasks/pair_rm.py b/src/distilabel/steps/tasks/pair_rm.py index 3c1ecc7a56..be23a38699 100644 --- a/src/distilabel/steps/tasks/pair_rm.py +++ b/src/distilabel/steps/tasks/pair_rm.py @@ -49,6 +49,37 @@ class PairRM(Step): Note: This step differs to other tasks as there is a single implementation of this model currently, and we will use a specific `LLM`. + + Examples: + + Rank LLM candidates: + + ```python + from distilabel.steps.tasks import PairRM + + # Consider this as a placeholder for your actual LLM. 
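+    # Unlike the LLM-backed tasks, `PairRM` needs no `llm` argument: it loads the
+    # `llm-blender/PairRM` ranker itself via its `model` attribute.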
+        pair_rm = PairRM()
+
+        pair_rm.load()
+
+        result = next(
+            pair_rm.process(
+                [
+                    {"input": "Hello, how are you?", "candidates": ["fine", "good", "bad"]},
+                ]
+            )
+        )
+        # result
+        # [
+        #     {
+        #         'input': 'Hello, how are you?',
+        #         'candidates': ['fine', 'good', 'bad'],
+        #         'ranks': [2, 1, 3],
+        #         'ranked_candidates': ['good', 'fine', 'bad'],
+        #         'model_name': 'llm-blender/PairRM',
+        #     }
+        # ]
+        ```
     """
 
     model: str = "llm-blender/PairRM"
diff --git a/src/distilabel/steps/tasks/prometheus_eval.py b/src/distilabel/steps/tasks/prometheus_eval.py
index 0edde308df..294f9f4d0e 100644
--- a/src/distilabel/steps/tasks/prometheus_eval.py
+++ b/src/distilabel/steps/tasks/prometheus_eval.py
@@ -135,6 +135,165 @@ class PrometheusEval(Task):
     References:
         - [Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models](https://arxiv.org/abs/2405.01535)
         - [prometheus-eval: Evaluate your LLM's response with Prometheus 💯](https://github.com/prometheus-eval/prometheus-eval)
+
+    Examples:
+
+        Critique and evaluate LLM generation quality using Prometheus 2.0:
+
+        ```python
+        from distilabel.steps.tasks import PrometheusEval
+        from distilabel.llms import vLLM
+
+        # Consider this as a placeholder for your actual LLM.
+        prometheus = PrometheusEval(
+            llm=vLLM(
+                model="prometheus-eval/prometheus-7b-v2.0",
+                chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+            ),
+            mode="absolute",
+            rubric="factual-validity"
+        )
+
+        prometheus.load()
+
+        result = next(
+            prometheus.process(
+                [
+                    {"instruction": "make something", "generation": "something done"},
+                ]
+            )
+        )
+        # result
+        # [
+        #     {
+        #         'instruction': 'make something',
+        #         'generation': 'something done',
+        #         'model_name': 'prometheus-eval/prometheus-7b-v2.0',
+        #         'feedback': 'the feedback',
+        #         'result': 6,
+        #     }
+        # ]
+        ```
+
+        Critique for relative evaluation:
+
+        ```python
+        from distilabel.steps.tasks import PrometheusEval
+        from distilabel.llms import vLLM
+
+        # Consider this as a placeholder for your actual LLM.
+        prometheus = PrometheusEval(
+            llm=vLLM(
+                model="prometheus-eval/prometheus-7b-v2.0",
+                chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+            ),
+            mode="relative",
+            rubric="honesty"
+        )
+
+        prometheus.load()
+
+        result = next(
+            prometheus.process(
+                [
+                    {"instruction": "make something", "generations": ["something done", "other thing"]},
+                ]
+            )
+        )
+        # result
+        # [
+        #     {
+        #         'instruction': 'make something',
+        #         'generations': ['something done', 'other thing'],
+        #         'model_name': 'prometheus-eval/prometheus-7b-v2.0',
+        #         'feedback': 'the feedback',
+        #         'result': 'something done',
+        #     }
+        # ]
+        ```
+
+        Critique with a custom rubric:
+
+        ```python
+        from distilabel.steps.tasks import PrometheusEval
+        from distilabel.llms import vLLM
+
+        # Consider this as a placeholder for your actual LLM.
+ prometheus = PrometheusEval( + llm=vLLM( + model="prometheus-eval/prometheus-7b-v2.0", + chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]", + ), + mode="absolute", + rubric="custom", + rubrics={ + "custom": "[A]\nScore 1: A\nScore 2: B\nScore 3: C\nScore 4: D\nScore 5: E" + } + ) + + prometheus.load() + + result = next( + prometheus.process( + [ + {"instruction": "make something", "generation": "something done"}, + ] + ) + ) + # result + # [ + # { + # 'instruction': 'make something', + # 'generation': 'something done', + # 'model_name': 'prometheus-eval/prometheus-7b-v2.0', + # 'feedback': 'the feedback', + # 'result': 6, + # } + # ] + ``` + + Critique using a reference answer: + + ```python + from distilabel.steps.tasks import PrometheusEval + from distilabel.llms import vLLM + + # Consider this as a placeholder for your actual LLM. + prometheus = PrometheusEval( + llm=vLLM( + model="prometheus-eval/prometheus-7b-v2.0", + chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]", + ), + mode="absolute", + rubric="helpfulness", + reference=True, + ) + + prometheus.load() + + result = next( + prometheus.process( + [ + { + "instruction": "make something", + "generation": "something done", + "reference": "this is a reference answer", + }, + ] + ) + ) + # result + # [ + # { + # 'instruction': 'make something', + # 'generation': 'something done', + # 'reference': 'this is a reference answer', + # 'model_name': 'prometheus-eval/prometheus-7b-v2.0', + # 'feedback': 'the feedback', + # 'result': 6, + # } + # ] + ``` """ mode: Literal["absolute", "relative"] @@ -202,7 +361,7 @@ def inputs(self) -> List[str]: if self.reference: return ["instruction", "generation", "reference"] return ["instruction", "generation"] - else: # self.mode == "relative" + else: if self.reference: return ["instruction", "generations", "reference"] return ["instruction", "generations"] diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py index f805c91e38..a93c2a399a 100644 --- a/src/distilabel/steps/tasks/quality_scorer.py +++ b/src/distilabel/steps/tasks/quality_scorer.py @@ -59,6 +59,43 @@ class QualityScorer(Task): References: - [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685) + + Examples: + + Evaluate the quality of your instructions: + + ```python + from distilabel.steps.tasks import QualityScorer + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + scorer = QualityScorer( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ) + ) + + scorer.load() + + result = next( + scorer.process( + [ + { + "instruction": "instruction", + "responses": ["good response", "weird response", "bad response"] + } + ] + ) + ) + # result + [ + { + 'instructions': 'instruction', + 'model_name': 'test', + 'scores': [5, 3, 1], + } + ] + ``` """ _template: Union[Template, None] = PrivateAttr(...) 
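The `QualityScorer` example above runs the task standalone; as a hedged sketch (not part of this diff), this is roughly how such a scoring task is wired into a `Pipeline`, with placeholder data and names:

```python
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import QualityScorer

with Pipeline("quality-scoring") as pipeline:
    # Placeholder rows with the `instruction` and `responses` columns the task expects.
    load_data = LoadDataFromDicts(
        data=[
            {
                "instruction": "instruction",
                "responses": ["good response", "weird response", "bad response"],
            }
        ]
    )
    scorer = QualityScorer(
        llm=InferenceEndpointsLLM(model_id="mistralai/Mistral-7B-Instruct-v0.2"),
    )
    load_data >> scorer

if __name__ == "__main__":
    # Guarding `run()` is recommended because the pipeline spawns subprocesses.
    distiset = pipeline.run()
```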
diff --git a/src/distilabel/steps/tasks/self_instruct.py b/src/distilabel/steps/tasks/self_instruct.py index 6bb673b8cf..34d3ffee06 100644 --- a/src/distilabel/steps/tasks/self_instruct.py +++ b/src/distilabel/steps/tasks/self_instruct.py @@ -60,6 +60,34 @@ class SelfInstruct(Task): Reference: - [`Self-Instruct: Aligning Language Models with Self-Generated Instructions`](https://arxiv.org/abs/2212.10560) + + Examples: + + Generate instructions based on a given input: + + ```python + from distilabel.steps.tasks import SelfInstruct + from distilabel.llms.huggingface import InferenceEndpointsLLM + + self_instruct = SelfInstruct( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ), + num_instructions=5, # This is the default value + ) + + self_instruct.load() + + result = next(self_instruct.process([{"input": "instruction"}])) + # result + # [ + # { + # 'input': 'instruction', + # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2', + # 'instructions': ["instruction 1", "instruction 2", "instruction 3", "instruction 4", "instruction 5"], + # } + # ] + ``` """ num_instructions: int = 5 diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py new file mode 100644 index 0000000000..b1ad50f5e1 --- /dev/null +++ b/src/distilabel/steps/tasks/sentence_transformers.py @@ -0,0 +1,291 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys +from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Optional, Union + +from jinja2 import Template + +from distilabel.steps.tasks.base import Task + +if sys.version_info < (3, 9): + import importlib_resources +else: + import importlib.resources as importlib_resources + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + +GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"] + +POSITIVE_NEGATIVE_PAIR_REGEX = re.compile( + r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$", + re.DOTALL, +) + +GENERATION_ACTION_SENTENCES: Final[Dict[GenerationAction, str]] = { + "paraphrase": "paraphrase", + "semantically-similar": "be semantically similar to", + "query": "be a query for", + "answer": "be an answer for", +} + +POSITIVE_SYSTEM_PROMPT: str = ( + "Your task is to generate a positive sentence given an anchor sentence.{context} The positive" + " sentence has to {action_sentence} the anchor sentence. You must output only one new" + " section: `## Positive`." +) + +POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = ( + "Your task is to generate a positive and a negative sentence given an anchor sentence.{context}" + " The positive sentence has to {action_sentence} the anchor sentence, while the negative" + " sentence can use similar words but must not be related to the anchor sentence. You" + " must output only two new sections: `## Positive` and `## Negative`." +) + +CONTEXT_INTRO: Final[str] = " Take into account the context given." 
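Before the class itself, a standalone sketch of how `POSITIVE_NEGATIVE_PAIR_REGEX` splits a completion into the sections requested by the system prompts above; the completion text is made up for illustration:

```python
import re

# Same pattern as `POSITIVE_NEGATIVE_PAIR_REGEX` above.
pattern = re.compile(
    r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
    re.DOTALL,
)

# Made-up completion following the `## Positive` / `## Negative` sections.
completion = """## Positive
How do additive manufacturing machines build parts layer by layer?

## Negative
How does a microwave oven heat food?"""

print(pattern.match(completion).groups())
# ('How do additive manufacturing machines build parts layer by layer?', 'How does a microwave oven heat food?')

# When only the positive section is present (`triplet=False`), the second group is `None`.
print(pattern.match("## Positive\nA paraphrased sentence.").groups())
# ('A paraphrased sentence.', None)
```

The two capture groups map directly onto the `positive` and `negative` output columns produced by `format_output` further down.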
+ + +class GenerateSentencePair(Task): + """Generate a positive and negative (optionally) sentences given an anchor sentence. + + `GenerateSentencePair` is a pre-defined task that given an anchor sentence generates + a positive sentence related to the anchor and optionally a negative sentence unrelated + to the anchor. Optionally, you can give a context to guide the LLM towards more specific + behavior. This task is useful to generate training datasets for training embeddings + models. + + Attributes: + triplet: a flag to indicate if the task should generate a triplet of sentences + (anchor, positive, negative). Defaults to `False`. + action: the action to perform to generate the positive sentence. + context: the context to use for the generation. Can be helpful to guide the LLM + towards more specific context. Not used by default. + + Input columns: + - anchor (`str`): The anchor sentence to generate the positive and negative sentences. + + Output columns: + - positive (`str`): The positive sentence related to the `anchor`. + - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`. + - model_name (`str`): The name of the model that was used to generate the sentences. + + Categories: + - embedding + + Examples: + + Paraphrasing: + + ```python + from distilabel.steps.tasks import GenerateSentencePair + from distilabel.llms import InferenceEndpointsLLM + + generate_sentence_pair = GenerateSentencePair( + triplet=True, # `False` to generate only positive + action="paraphrase", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + input_batch_size=10, + ) + + generate_sentence_pair.load() + + result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}]) + ``` + + Generating semantically similar sentences: + + ```python + from distilabel.llms import InferenceEndpointsLLM + from distilabel.steps.tasks import GenerateSentencePair + + generate_sentence_pair = GenerateSentencePair( + triplet=True, # `False` to generate only positive + action="semantically-similar", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + input_batch_size=10, + ) + + generate_sentence_pair.load() + + result = generate_sentence_pair.process([{"anchor": "How does 3D printing work?"}]) + ``` + + Generating queries: + + ```python + from distilabel.steps.tasks import GenerateSentencePair + from distilabel.llms import InferenceEndpointsLLM + + generate_sentence_pair = GenerateSentencePair( + triplet=True, # `False` to generate only positive + action="query", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + input_batch_size=10, + ) + + generate_sentence_pair.load() + + result = generate_sentence_pair.process([{"anchor": "Argilla is an open-source data curation platform for LLMs. 
Using Argilla, ..."}]) + ``` + + Generating answers: + + ```python + from distilabel.steps.tasks import GenerateSentencePair + from distilabel.llms import InferenceEndpointsLLM + + generate_sentence_pair = GenerateSentencePair( + triplet=True, # `False` to generate only positive + action="answer", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + input_batch_size=10, + ) + + generate_sentence_pair.load() + + result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}]) + ``` + + Generating queries with context (**applies to every action**): + + ```python + from distilabel.steps.tasks import GenerateSentencePair + from distilabel.llms import InferenceEndpointsLLM + + generate_sentence_pair = GenerateSentencePair( + triplet=True, # `False` to generate only positive + action="query", + context="Argilla is an open-source data curation platform for LLMs.", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + input_batch_size=10, + ) + + generate_sentence_pair.load() + + result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}]) + ``` + """ + + triplet: bool = False + action: GenerationAction + context: str = "" + + def load(self) -> None: + """Loads the Jinja2 template.""" + super().load() + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "generate-sentence-pair.jinja2" + ) + + self._template = Template(open(_path).read()) + + @property + def inputs(self) -> List[str]: + """The inputs for the task is the `anchor` sentence.""" + return ["anchor"] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + """The inputs are formatted as a `ChatType`, with a system prompt describing the + task of generating a positive and negative sentences for the anchor sentence. The + anchor is provided as the first user interaction in the conversation. + + Args: + input: The input containing the `anchor` sentence. + + Returns: + A list of dictionaries containing the system and user interactions. + """ + action_sentence = GENERATION_ACTION_SENTENCES[self.action] + system_prompt = ( + POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT + ).format( + action_sentence=action_sentence, + context=CONTEXT_INTRO if self.context else "", + ) + + return [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": self._template.render( + anchor=input["anchor"], + context=self.context if self.context else None, + ), + }, + ] + + @property + def outputs(self) -> List[str]: + """The outputs for the task are the `positive` and `negative` sentences, as well + as the `model_name` used to generate the sentences.""" + columns = ["positive", "negative"] if self.triplet else ["positive"] + columns += ["model_name"] + return columns + + def format_output( + self, output: Union[str, None], input: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Formats the output of the LLM, to extract the `positive` and `negative` sentences + generated. If the output is `None` or the regex doesn't match, then the outputs + will be set to `None` as well. + + Args: + output: The output of the LLM. + input: The input used to generate the output. + + Returns: + The formatted output containing the `positive` and `negative` sentences. 
+        """
+        if output is None:
+            return {"positive": None, "negative": None}
+
+        match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
+        if match is None:
+            formatted_output = {"positive": None}
+            if self.triplet:
+                formatted_output["negative"] = None
+            return formatted_output
+
+        groups = match.groups()
+        if self.triplet:
+            return {
+                "positive": groups[0].strip(),
+                "negative": groups[1].strip()
+                if len(groups) > 1 and groups[1] is not None
+                else None,
+            }
+
+        return {"positive": groups[0].strip()}
diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py
new file mode 100644
index 0000000000..240cd44698
--- /dev/null
+++ b/src/distilabel/steps/tasks/structured_generation.py
@@ -0,0 +1,187 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Any, Dict, List, Union
+
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.typing import StructuredInput
+
+
+class StructuredGeneration(Task):
+    """Generate structured content for a given `instruction` using an `LLM`.
+
+    `StructuredGeneration` is a pre-defined task that defines the `instruction` and the `structured_output`
+    as the inputs, and `generation` as the output. This task is used to generate structured content based on
+    the input instruction and following the schema provided within the `structured_output` column per each
+    `instruction`. The `model_name` is also returned as part of the output to keep track of the model that
+    produced each `generation`.
+
+    Attributes:
+        use_system_prompt: Whether to use the system prompt in the generation. Defaults to `False`.
+            When set to `True`, if the column `system_prompt` is defined within the input batch, then
+            the `system_prompt` will be used, otherwise, it will be ignored.
+
+    Input columns:
+        - instruction (`str`): The instruction to generate structured content from.
+        - structured_output (`Dict[str, Any]`): The structured_output to generate structured content from. It should be a
+            Python dictionary with the keys `format` and `schema`, where `format` should be one of `json` or
+            `regex`, and the `schema` should be either the JSON schema or the regex pattern, respectively.
+
+    Output columns:
+        - generation (`str`): The generated text matching the provided schema, if possible.
+        - model_name (`str`): The name of the model used to generate the text.
+ + Categories: + - outlines + - structured-generation + + Examples: + + Generate structured output from a JSON schema: + + ```python + from distilabel.steps.tasks import StructuredGeneration + from distilabel.llms import InferenceEndpointsLLM + + structured_gen = StructuredGeneration( + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + ) + + structured_gen.load() + + result = next( + structured_gen.process( + [ + { + "instruction": "Create an RPG character", + "structured_output": { + "type": "json", + "value": { + "properties": { + "name": { + "title": "Name", + "type": "string" + }, + "description": { + "title": "Description", + "type": "string" + }, + "role": { + "title": "Role", + "type": "string" + }, + "weapon": { + "title": "Weapon", + "type": "string" + } + }, + "required": [ + "name", + "description", + "role", + "weapon" + ], + "title": "Character", + "type": "object" + } + }, + } + ] + ) + ) + ``` + + Generate structured output from a regex pattern: + + ```python + from distilabel.steps.tasks import StructuredGeneration + from distilabel.llms import InferenceEndpointsLLM + + structured_gen = StructuredGeneration( + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + ), + ) + + structured_gen.load() + + result = next( + structured_gen.process( + [ + { + "instruction": "What's the weather like today in Seattle in Celsius degrees?", + "structured_output": { + "type": "regex", + "value": r"(\\d{1,2})°C" + }, + + } + ] + ) + ) + ``` + """ + + use_system_prompt: bool = False + + @property + def inputs(self) -> List[str]: + """The input for the task are the `instruction` and the `structured_output`. + Optionally, if the `use_system_prompt` flag is set to True, then the + `system_prompt` will be used too.""" + columns = ["instruction", "structured_output"] + if self.use_system_prompt: + columns = ["system_prompt"] + columns + return columns + + def format_input(self, input: Dict[str, Any]) -> StructuredInput: + """The input is formatted as a `ChatType` assuming that the instruction + is the first interaction from the user within a conversation.""" + if not isinstance(input["instruction"], str): + raise ValueError( + f"Input `instruction` must be a string. Got: {input['instruction']}." + ) + + messages = [{"role": "user", "content": input["instruction"]}] + if self.use_system_prompt: + if "system_prompt" in input: + messages.insert( + 0, {"role": "system", "content": input["system_prompt"]} + ) + else: + warnings.warn( + "`use_system_prompt` is set to `True`, but no `system_prompt` in input batch, so it will be ignored.", + UserWarning, + stacklevel=2, + ) + + return (messages, input.get("structured_output", None)) # type: ignore + + @property + def outputs(self) -> List[str]: + """The output for the task is the `generation` and the `model_name`.""" + return ["generation", "model_name"] + + def format_output( + self, output: Union[str, None], input: Dict[str, Any] + ) -> Dict[str, Any]: + """The output is formatted as a dictionary with the `generation`. The `model_name` + will be automatically included within the `process` method of `Task`. Note that even + if the `structured_output` is defined to produce a JSON schema, this method will return the raw + output i.e. 
a string without any parsing.""" + return {"generation": output} diff --git a/src/distilabel/steps/tasks/structured_outputs/instructor.py b/src/distilabel/steps/tasks/structured_outputs/instructor.py new file mode 100644 index 0000000000..94ab1097e5 --- /dev/null +++ b/src/distilabel/steps/tasks/structured_outputs/instructor.py @@ -0,0 +1,124 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +from typing import ( + TYPE_CHECKING, + Callable, + Literal, + Optional, + Tuple, + TypeAlias, + Union, + get_args, +) + +if TYPE_CHECKING: + import instructor + from anthropic import AsyncAnthropic + from cohere import AsyncClient as AsyncCohere + from groq import AsyncGroq + from mistralai.async_client import MistralAsyncClient + from openai import AsyncAzureOpenAI, AsyncOpenAI + + +InstructorFrameworks = Literal[ + "openai", "azure_openai", "anthropic", "cohere", "groq", "litellm", "mistral" +] +"""Available frameworks for the structured output configuration with `instructor`. """ + +InstructorAvailableClients: TypeAlias = Union[ + "AsyncAnthropic", + "AsyncAzureOpenAI", + "AsyncCohere", + "AsyncGroq", + "AsyncOpenAI", + "MistralAsyncClient", +] +"""Available clients that can be wrapped with `instructor`. """ + + +def _client_patcher(framework: InstructorFrameworks) -> Tuple[Callable, str]: + """Helper function to return the appropriate instructor client for the given framework. + + Args: + framework: The framework to use for the instructor client. + + Raises: + ValueError: If the framework is not one of the available frameworks. + + Returns: + Tuple of Callable and string, with the builder of the client patch and the + default mode to use. + """ + import instructor + + if framework in {"openai", "azure_openai"}: + patch = instructor.from_openai, instructor.Mode.TOOLS + elif framework == "anthropic": + patch = instructor.from_anthropic, instructor.Mode.ANTHROPIC_JSON + elif framework == "litellm": + patch = instructor.from_litellm, instructor.Mode.TOOLS + elif framework == "mistral": + patch = instructor.from_mistral, instructor.Mode.MISTRAL_TOOLS + elif framework == "cohere": + patch = instructor.from_cohere, instructor.Mode.COHERE_TOOLS + elif framework == "groq": + patch = instructor.from_groq, instructor.Mode.TOOLS + else: + raise ValueError( + f"Invalid framework '{framework}'. Must be one of {get_args(InstructorFrameworks)}" + ) + + return patch + + +def prepare_instructor( + client: InstructorAvailableClients, + mode: Optional["instructor.Mode"] = None, + framework: Optional[InstructorFrameworks] = None, +) -> "instructor.AsyncInstructor": + """Wraps the given client with the instructor client for the given framework. + + Args: + client: The client to wrap with the instructor client, corresponds to the internal + client we wrap on `LLM`, and one of the implemented in `instructor`. + mode: One of the `instructor.Mode` values. Defaults to None. + framework: The framework corresponding to the client. Defaults to None. 
+ + Raises: + ImportError: If `instructor` is not installed. + ValueError: If the mode is not one of the available modes. + + Returns: + patched_client: The instructor wrapping the original client to be used for + structured generation. + """ + if not importlib.util.find_spec("instructor"): + raise ImportError( + "`instructor` is not installed. Please install it using `pip install instructor`." + ) + import instructor + + builder, default_mode = _client_patcher(framework) + + mode = mode or default_mode + if mode.value not in [m.value for m in instructor.mode.Mode]: + raise ValueError( + f"Invalid mode '{mode}'. Must be one of {[m.value for m in instructor.mode.Mode]}" + ) + + patched_client: instructor.AsyncInstructor = builder(client, mode=mode) + + return patched_client diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index 087f4913bc..d726b5e4f5 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -19,55 +19,27 @@ Any, Callable, Dict, - List, Literal, - Optional, Tuple, Type, - TypedDict, Union, get_args, ) from pydantic import BaseModel +from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict +from distilabel.steps.tasks.typing import StructuredOutputType + Frameworks = Literal["transformers", "llamacpp", "vllm"] """Available frameworks for the structured output configuration. """ -class StructuredOutputType(TypedDict): - """TypedDict to represent the structured output configuration from outlines.""" - - format: Literal["json", "regex"] - """One of "json" or "regex".""" - schema: Union[str, Type[BaseModel]] - """The schema to use for the structured output. If "json", it - can be a pydantic.BaseModel class, or the schema as a string, - as obtained from `model_to_schema(BaseModel)`, if "regex", it - should be a regex pattern as a string. - """ - whitespace_pattern: Optional[Union[str, List[str]]] - """If "json" corresponds to a string or a list of - strings with a pattern (doesn't impact string literals). - For example, to allow only a single space or newline with - `whitespace_pattern=r"[\n ]?"` - """ - - def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]: """Helper function to return a string representation of the schema from a `pydantic.BaseModel` class.""" return json.dumps(schema.model_json_schema()) -def _schema_as_dict(schema: Union[str, Type[BaseModel]]) -> Dict[str, Any]: - """Helper function to obtain the schema and simplify serialization.""" - if type(schema) == type(BaseModel): - return schema.model_json_schema() - elif isinstance(schema, str): - return json.loads(schema) - return schema - - def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]: """Helper function to return the appropriate logits processor for the given framework.""" if framework == "transformers": @@ -137,7 +109,7 @@ def prepare_guided_output( llm, whitespace_pattern=structured_output.get("whitespace_pattern"), ), - "schema": _schema_as_dict(schema), + "schema": schema_as_dict(schema), } if format == "regex": diff --git a/src/distilabel/steps/tasks/structured_outputs/utils.py b/src/distilabel/steps/tasks/structured_outputs/utils.py new file mode 100644 index 0000000000..8bcebcb819 --- /dev/null +++ b/src/distilabel/steps/tasks/structured_outputs/utils.py @@ -0,0 +1,157 @@ +# Copyright 2023-present, Argilla, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Dict, List, Optional, Type, Union + +from pydantic import BaseModel, Field, create_model + + +def schema_as_dict(schema: Union[str, Type[BaseModel]]) -> Dict[str, Any]: + """Helper function to obtain the schema and simplify serialization.""" + if type(schema) == type(BaseModel): + return schema.model_json_schema() + elif isinstance(schema, str): + return json.loads(schema) + return schema + + +# NOTE: The following functions were copied from: +# https://github.com/pydantic/pydantic/issues/643#issuecomment-1999755873 +# and slightly modified to work with nested models. +# It would be nice to find the original source of this code to give credit. +# Other option would be working with this library: https://github.com/c32168/dyntamic + + +def json_schema_to_model(json_schema: Dict[str, Any]) -> Type[BaseModel]: + """Converts a JSON schema to a `pydantic.BaseModel` class. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A `pydantic.BaseModel` class. + """ + + # Extract the model name from the schema title. + model_name = json_schema.get("title") + if defs := json_schema.get("$defs", None): + # This is done to grab the content of nested classes that need to dereference + # the objects (those should be in a higher level). + pass + + # Extract the field definitions from the schema properties. + field_definitions = { + name: json_schema_to_pydantic_field( + name, prop, json_schema.get("required", []), defs=defs + ) + for name, prop in json_schema.get("properties", {}).items() + } + + # Create the BaseModel class using create_model(). + return create_model(model_name, **field_definitions) + + +def json_schema_to_pydantic_field( + name: str, + json_schema: Dict[str, Any], + required: List[str], + defs: Optional[Dict[str, Any]] = None, +) -> Any: + """Converts a JSON schema property to a `pydantic.Field`. + + Args: + name: The field name. + json_schema: The JSON schema property. + required: The list of required fields. + defs: The definitions of the JSON schema. It's used to dereference nested classes, + so we can grab the original definition from the json schema (it won't + work out of the box with just the reference). + + Returns: + A `pydantic.Field`. + """ + + # NOTE(plaguss): This needs more testing, nested classes need extra work to be converted + # here if we pass a reference to another class it will crash, we have to find the original + # definition and insert it here + # This takes into account single items referred to other classes + if ref := json_schema.get("$ref"): + json_schema = defs.get(ref.split("/")[-1]) + + # This takes into account lists of items referred to other classes + if "items" in json_schema and (ref := json_schema["items"].get("$ref")): + json_schema["items"] = defs.get(ref.split("/")[-1]) + + # Get the field type. + type_ = json_schema_to_pydantic_type(json_schema) + + # Get the field description. 
+ description = json_schema.get("description") + + # Get the field examples. + examples = json_schema.get("examples") + + # Create a Field object with the type, description, and examples. + # The "required" flag will be set later when creating the model. + return ( + type_, + Field( + description=description, + examples=examples, + default=... if name in required else None, + ), + ) + + +def json_schema_to_pydantic_type(json_schema: Dict[str, Any]) -> Any: + """Converts a JSON schema type to a Pydantic type. + + Args: + json_schema: The JSON schema to convert. + + Returns: + A Pydantic type. + """ + type_ = json_schema.get("type") + + if type_ == "string": + type_val = str + elif type_ == "integer": + type_val = int + elif type_ == "number": + type_val = float + elif type_ == "boolean": + type_val = bool + elif type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = json_schema_to_pydantic_type(items_schema) + type_val = List[item_type] + else: + type_val = List + elif type_ == "object": + # Handle nested models. + properties = json_schema.get("properties") + if properties: + nested_model = json_schema_to_model(json_schema) + type_val = nested_model + else: + type_val = Dict + elif type_ == "null": + type_val = Optional[Any] # Use Optional[Any] for nullable fields + else: + raise ValueError(f"Unsupported JSON schema type: {type_}") + + return type_val diff --git a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 new file mode 100644 index 0000000000..cac188e101 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 @@ -0,0 +1,11 @@ +{% if context is not none -%} +## Context + +{{ context }} + +{% endif -%} + +## Anchor + +{{ anchor }} + diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 new file mode 100644 index 0000000000..1cf238015f --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 @@ -0,0 +1,13 @@ +Write a {{ unit }} triple with one {{ unit }} in {{ source_language }} and two {{ unit }}s in {{ target_language }} with varying translation qualities in JSON format. + +The triple is denotes as ("S1", "S2", "S3"). The translation quality score ranges from 1 to 5, with higher scores are better. + +Please adhere to the following guidelines: + - The values of "S1" is a string in {{ source_language }}, the value of "S2" and "S3" are strings in {{ target_language }}. + - There should be some word overlaps between "S2" and "S3". + - The translation quality score of "S2" with respect to "S1" should be {{ high_score }}. + - The translation quality score of "S3" with respect to "S1" should be {{ low_score }}. + - "S3" should be grammatical and fluent, but contain some keyword or number translation errors, or miss some information, or contain some redundant information. + - "S1" requires {{ difficulty }} level education to understand and should be diverse in terms of topic and length. + +Your output must always be a JSON object only with three keys "S1", "S2" and "S3", do not explain yourself or output anything else. Be creative! 
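As a rough illustration of how these `improving_text_embeddings` templates are consumed, the snippet below renders `bitext-retrieval.jinja2` with the same variables that `BitextRetrievalGenerator.prompt` passes to `Template.render`; reading the template from a local path is an assumption for the sketch (the library resolves it via `importlib_resources`):

```python
from jinja2 import Template

# Placeholder path to the template shown above.
with open("bitext-retrieval.jinja2") as f:
    template = Template(f.read())

prompt = template.render(
    source_language="English",
    target_language="Spanish",
    unit="sentence",
    difficulty="elementary school",
    high_score="4",
    low_score="2.5",
).strip()

# `prompt` becomes the single user turn sent to the LLM by the generator task.
print(prompt)
```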
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 new file mode 100644 index 0000000000..3501b9332d --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 @@ -0,0 +1,6 @@ +Brainstorm a list of potentially useful text classification tasks. + +Please adhere to the following guidelines: + - Tasks should cover a diverse range of domains and task types. + +Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 new file mode 100644 index 0000000000..0090ef2af4 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 @@ -0,0 +1,7 @@ +Brainstorm a list of text matching tasks where the queries are long documents. + +Here are a few examples: + - Given a document that supports a debatable argument, find another document that contains opposite arguments. + - Provided a lengthy business proposal, retrieve competitive business strategies in the same industry. + +Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 new file mode 100644 index 0000000000..cf42fddae5 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 @@ -0,0 +1,8 @@ +Brainstorm a list of text matching tasks where both the queries and the groundtruth documents are very short (one or two sentences, even a short phrase). + +Here are a few examples: + - Given a scientific paper title, retrieve the title of papers that cite the given paper. + - Match a word with its definition. + - Provided a notable person's name, identify their occupation or achievement. + +Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 new file mode 100644 index 0000000000..464ed0e763 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 @@ -0,0 +1,11 @@ +Brainstorm a list of potentially useful text retrieval tasks. + +Here are a few examples for your reference: + - Provided a scientific claim as query, retrieve documents that help verify or refute the claim. + - Search for documents that answers a FAQ-style query on children's nutrition. 
+ +Please adhere to the following guidelines: + - Specify what the query is, and what the desired documents are. + - Each retrieval task should cover a wide range of queries, and should not be too specific. + +Your output should always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct retrieval task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 new file mode 100644 index 0000000000..cd8bf1922a --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 @@ -0,0 +1,12 @@ +You have been assigned a text matching task: {{ task }} + +Your mission is to write one example for this task in JSON format. The JSON object must contain the following keys: + - "input": a string, a random input specified by the task. + - "positive_document": a string, a relevant document for the "input" according to the task. + +Please adhere to the following guidelines: + - The values of all fields should be in {{ language }}. + - Both the "input" and "positive_document" should be long documents (at least 300 words), avoid substantial word overlaps, otherwise the task would be too easy. + - The "input" and "positive_document" should be independent of each other. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 new file mode 100644 index 0000000000..585d618620 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 @@ -0,0 +1,10 @@ +Write a {{ unit }} triple with varying semantic similarity scores in JSON format. The semantic similarity score ranges from 1 to 5, with 1 denotes least similar and 5 denotes most similar. + +Please adhere to the following guidelines: + - The keys in JSON are "S1", "S2", and "S3", the values are all strings in {{ language }}, do not add any other keys. + - There should be some word overlaps between all three {{ unit }}s. + - The similarity score between S1 and S2 should be {{ high_score }}. + - The similarity score between S1 and S3 should be {{ low_score }}. + - The {{ unit }}s require {{ difficulty }} level education to understand and should be diverse in terms of topic and length. + +Your output must always be a JSON object only with three keys "S1", "S2" and "S3", do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 new file mode 100644 index 0000000000..90b08f9e57 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 @@ -0,0 +1,12 @@ +You have been assigned a text matching task: {{ task }} + +Your mission is to write one example for this task in JSON format. The JSON object must contain the following keys: + - "input": a string, a random input specified by the task. + - "positive_document": a string, a relevant document for the "input" according to the task. 
+ +Please adhere to the following guidelines: + - The values of all fields should be in {{ language }}. + - Both the "input" and "positive_document" should be very short (a sentence or a phrase), avoid substantial word overlaps, otherwise the task would be too easy. + - The "input" and "positive_document" should be independent of each other. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 new file mode 100644 index 0000000000..74a184bc56 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 @@ -0,0 +1,15 @@ +You have been assigned a text classification task: {{ task }} + +Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys: + - "input_text": a string, the input text specified by the classification task. + - "label": a string, the correct label of the input text. + - "misleading_label": a string, an incorrect label that is related to the task. + +Please adhere to the following guidelines: + - The "input_text" should be diverse in expression. + - The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the "input_text". + - The values for all fields should be in {{ language }}. + - Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make the task too easy. + - The "input_text" is {{ clarity }} and requires {{ difficulty }} level education to comprehend. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 new file mode 100644 index 0000000000..c76ac8a698 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 @@ -0,0 +1,17 @@ +You have been assigned a retrieval task: {{ task }} + +Your mission is to write one text retrieval example for this task in JSON format. The JSON object must contain the following keys: + - "user_query": a string, a random user search query specified by the retrieval task. + - "positive_document": a string, a relevant document for the user query. + - "hard_negative_document": a string, a hard negative document that only appears relevant to the query. + +Please adhere to the following guidelines: + - The "user_query" should be {{ query_type }}, {{ query_length }}, {{ clarity }}, and diverse in topic. + - All documents must be created independent of the query. Avoid copying the query verbatim. It's acceptable if some parts of the "positive_document" are not topically related to the query. + - All documents should be at least {{ num_words}} words long. + - The "hard_negative_document" contains some useful information, but it should be less useful or comprehensive compared to the "positive_document". + - Both the query and documents should be in {{ language }}. + - Do not provide any explanation in any document on why it is relevant or not relevant to the query. + - Both the query and documents require {{ difficulty }} level education to understand. 
+ +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index 41eb4444dc..f5c4659651 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -37,16 +37,41 @@ class TextGeneration(Task): Output columns: - generation (`str`): The generated text. - - model_name (`str`): The model name used to generate the text. + - model_name (`str`): The name of the model used to generate the text. Categories: - text-generation Examples: + + Generate text from an instruction: + ```python from distilabel.steps.tasks import TextGeneration + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + text_gen = TextGeneration( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ) + ) - task = TextGeneration(llm=LLM(...)) + text_gen.load() + + result = next( + text_gen.process( + [{"instruction": "your instruction"}] + ) + ) + # result + # [ + # { + # 'instruction': 'your instruction', + # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2', + # 'generation': 'generation', + # } + # ] ``` """ @@ -62,14 +87,10 @@ def format_input(self, input: Dict[str, Any]) -> ChatType: is the first interaction from the user within a conversation.""" if is_openai_format(input["instruction"]): - warnings.warn( + raise ValueError( "Providing `instruction` formatted as an OpenAI chat / conversation is" - " about to be deprecated in `distilabel v1.2.0`, please make sure to use" - " `ChatTextGeneration` with `messages` as input instead.", - DeprecationWarning, - stacklevel=2, + " deprecated, you should use `ChatGeneration` with `messages` as input instead.", ) - return input["instruction"] if not isinstance(input["instruction"], str): raise ValueError( @@ -96,7 +117,7 @@ def outputs(self) -> List[str]: return ["generation", "model_name"] def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """The output is formatted as a dictionary with the `generation`. The `model_name` will be automatically included within the `process` method of `Task`.""" @@ -123,6 +144,44 @@ class ChatGeneration(Task): Icon: `:material-chat:` + + Examples: + + Generate text from a conversation in OpenAI chat format: + + ```python + from distilabel.steps.tasks import ChatGeneration + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + chat = ChatGeneration( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ) + ) + + chat.load() + + result = next( + chat.process( + [ + { + "messages": [ + {"role": "user", "content": "How much is 2+2?"}, + ] + } + ] + ) + ) + # result + # [ + # { + # 'messages': [{'role': 'user', 'content': 'How much is 2+2?'}], + # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2', + # 'generation': '4', + # } + # ] + ``` """ @property @@ -136,7 +195,7 @@ def format_input(self, input: Dict[str, Any]) -> ChatType: if not is_openai_format(input["messages"]): raise ValueError( - "Input `instruction` must be a string or an OpenAI chat-like format. " + "Input `messages` must be an OpenAI chat-like format conversation. " f"Got: {input['messages']}. 
Please check: 'https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models'." ) diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py index cbd6ffc09c..4f92cdc057 100644 --- a/src/distilabel/steps/tasks/typing.py +++ b/src/distilabel/steps/tasks/typing.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union +from pydantic import BaseModel from typing_extensions import TypedDict @@ -24,3 +25,47 @@ class ChatItem(TypedDict): ChatType = List[ChatItem] """ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format.""" + + +class OutlinesStructuredOutputType(TypedDict, total=False): + """TypedDict to represent the structured output configuration from `outlines`.""" + + format: Literal["json", "regex"] + """One of "json" or "regex".""" + schema: Union[str, Type[BaseModel], Dict[str, Any]] + """The schema to use for the structured output. If "json", it + can be a pydantic.BaseModel class, or the schema as a string, + as obtained from `model_to_schema(BaseModel)`, if "regex", it + should be a regex pattern as a string. + """ + whitespace_pattern: Optional[Union[str, List[str]]] = None + """If "json" corresponds to a string or a list of + strings with a pattern (doesn't impact string literals). + For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + + +class InstructorStructuredOutputType(TypedDict, total=False): + """TypedDict to represent the structured output configuration from `instructor`.""" + + schema: Union[Type[BaseModel], Dict[str, Any]] + """The schema to use for the structured output, a `pydantic.BaseModel` class. """ + mode: Optional[str] + """Generation mode. Take a look at `instructor.Mode` for more information, if not informed it will + be determined automatically. """ + max_retries: int + """Number of times to reask the model in case of error, if not set will default to the model's default. 
""" + + +StructuredOutputType = Union[ + OutlinesStructuredOutputType, InstructorStructuredOutputType +] +"""StructuredOutputType is an alias for the union of `OutlinesStructuredOutputType` and `InstructorStructuredOutputType`.""" + +StandardInput = ChatType +"""StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`.""" +StructuredInput = Tuple[StandardInput, Union[StructuredOutputType, None]] +"""StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it.""" +FormattedInput = Union[StandardInput, StructuredInput] +"""FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s.""" diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py index 9b200ddafd..c6cd95482c 100644 --- a/src/distilabel/steps/tasks/ultrafeedback.py +++ b/src/distilabel/steps/tasks/ultrafeedback.py @@ -60,6 +60,45 @@ class UltraFeedback(Task): References: - [`UltraFeedback: Boosting Language Models with High-quality Feedback`](https://arxiv.org/abs/2310.01377) - [`UltraFeedback - GitHub Repository`](https://github.com/OpenBMB/UltraFeedback) + + Examples: + + Rate generations from different LLMs based on the selected aspect: + + ```python + from distilabel.steps.tasks import UltraFeedback + from distilabel.llms.huggingface import InferenceEndpointsLLM + + # Consider this as a placeholder for your actual LLM. + ultrafeedback = UltraFeedback( + llm=InferenceEndpointsLLM( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + ) + ) + + ultrafeedback.load() + + result = next( + chat.process( + [ + { + "instruction": "How much is 2+2?", + "generations": ["4", "and a car"], + } + ] + ) + ) + # result + # [ + # { + # 'instruction': 'How much is 2+2?', + # 'generations': ['4', 'and a car'], + # 'ratings': [1, 2], + # 'rationales': ['explanation for 4', 'explanation for and a car'], + # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2', + # } + # ] + ``` """ aspect: Literal[ diff --git a/src/distilabel/utils/dicts.py b/src/distilabel/utils/dicts.py index 0ce96334f9..53d33d47f5 100644 --- a/src/distilabel/utils/dicts.py +++ b/src/distilabel/utils/dicts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from collections import defaultdict from typing import Any, Dict, List, TypeVar @@ -33,3 +34,7 @@ def combine_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]: for key, value in d.items(): combined_dict[key].append(value) return dict(combined_dict) + + +def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]: + return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()} diff --git a/src/distilabel/utils/huggingface.py b/src/distilabel/utils/huggingface.py new file mode 100644 index 0000000000..7a637a831c --- /dev/null +++ b/src/distilabel/utils/huggingface.py @@ -0,0 +1,53 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from typing import Final + +from huggingface_hub import constants + +_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME: Final[str] = "HF_TOKEN" + + +def get_hf_token(cls_name: str, token_arg: str) -> str: + """Get the token for the hugging face API. + + Tries to extract it from the environment variable, if it is not found + it tries to read it from the file using 'huggingface_hub', + and if not possible raises a ValueError. + + Args: + cls_name: Name of the class/function that requires the token. + token_arg: Argument name to use in the error message, normally + is "token" or "api_key". + + Raises: + ValueError: If the token is not found in the file. + + Returns: + The token for the hugging face API. + """ + token = os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) + if token is None: + if not Path(constants.HF_TOKEN_PATH).exists(): + raise ValueError( + f"To use `{cls_name}` an API key must be provided via" + f" `{token_arg}`, set the environment variable" + f" `{_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME}` or use the `huggingface-hub` CLI to login" + " with `huggingface-cli login`." + ) + with open(constants.HF_TOKEN_PATH) as f: + token = f.read().strip() + return token diff --git a/src/distilabel/utils/itertools.py b/src/distilabel/utils/itertools.py index 9428389188..88ce86cc4e 100644 --- a/src/distilabel/utils/itertools.py +++ b/src/distilabel/utils/itertools.py @@ -13,18 +13,20 @@ # limitations under the License. from itertools import zip_longest -from typing import Any, Iterable, Literal +from typing import Any, Iterable, List, Literal, TypeVar + +T = TypeVar("T") # Copy pasted from https://docs.python.org/3/library/itertools.html#itertools-recipes # Just added the type hints and use `if`s instead of `match` def grouper( - iterable: Iterable[Any], + iterable: Iterable[T], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: Any = None, -) -> Iterable[Any]: +) -> Iterable[List[T]]: "Collect data into non-overlapping fixed-length chunks or blocks." 
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py index 15a737f448..af1b26a18b 100644 --- a/src/distilabel/utils/logging.py +++ b/src/distilabel/utils/logging.py @@ -42,7 +42,9 @@ queue_listener: Union[QueueListener, None] = None -def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> None: +def setup_logging( + log_queue: Optional["Queue[Any]"] = None, filename: Optional[str] = None +) -> None: """Sets up logging to use a queue across all processes.""" global queue_listener @@ -53,7 +55,7 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No # If the current process is the main process, set up a `QueueListener` # to handle logs from all subprocesses - if mp.current_process().name == "MainProcess": + if mp.current_process().name == "MainProcess" and filename: formatter = logging.Formatter("['%(name)s'] %(message)s") handler = RichHandler(rich_tracebacks=True) handler.setFormatter(formatter) @@ -66,10 +68,11 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No ) file_handler.setFormatter(file_formatter) - queue_listener = QueueListener( - log_queue, handler, file_handler, respect_handler_level=True - ) - queue_listener.start() + if log_queue is not None: + queue_listener = QueueListener( + log_queue, handler, file_handler, respect_handler_level=True + ) + queue_listener.start() log_level = os.environ.get("DISTILABEL_LOG_LEVEL", "INFO").upper() if log_level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: @@ -80,9 +83,15 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No log_level = "INFO" root_logger = logging.getLogger() - root_logger.handlers.clear() + + running_test = "PYTEST_CURRENT_TEST" in os.environ + if not running_test: + root_logger.handlers.clear() + + if log_queue is not None: + root_logger.addHandler(QueueHandler(log_queue)) + root_logger.setLevel(log_level) - root_logger.addHandler(QueueHandler(log_queue)) def stop_logging() -> None: @@ -90,4 +99,5 @@ def stop_logging() -> None: global queue_listener if queue_listener is not None: queue_listener.stop() + queue_listener.queue.close() queue_listener = None diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index a5a5ba40e8..ae43c586af 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -21,7 +21,6 @@ from mkdocs.config.config_options import Type from mkdocs.plugins import BasePlugin from mkdocs.structure.files import File -from mkdocs.structure.pages import Page from mkdocs_section_index import SectionPage from distilabel.utils.export_components_info import export_components_info @@ -360,11 +359,26 @@ def on_nav( steps_file = files.get_file_from_path(self.file_paths["steps"][0]) tasks_file = files.get_file_from_path(self.file_paths["tasks"][0]) llms_file = files.get_file_from_path(self.file_paths["llms"][0]) + steps_files = [ + files.get_file_from_path(path) for path in self.file_paths["steps"][0:] + ] + tasks_files = [ + files.get_file_from_path(path) for path in self.file_paths["tasks"][0:] + ] + llms_files = [ + files.get_file_from_path(path) for path in self.file_paths["llms"][0:] + ] # Create subsections - steps_page = Page("Steps", file=steps_file, config=config) # type: ignore - tasks_page = Page("Tasks", 
file=tasks_file, config=config) # type: ignore - llms_page = Page("LLMs", file=llms_file, config=config) # type: ignore + steps_page = SectionPage( + "Steps", file=steps_file, config=config, children=steps_files + ) # type: ignore + tasks_page = SectionPage( + "Tasks", file=tasks_file, config=config, children=tasks_files + ) # type: ignore + llms_page = SectionPage( + "LLMs", file=llms_file, config=config, children=llms_files + ) # type: ignore # Create the gallery section page = SectionPage( diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2 index 3c465761c5..319c69164b 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2 +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2 @@ -1,8 +1,8 @@ --- -hide: +hide: - toc + - navigation --- - # {{ title }} {{ description }} diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md index 5c7af73180..eb2914b6a6 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md @@ -1,8 +1,8 @@ --- -hide: +hide: + - navigation - toc --- - # Components Gallery
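
Stepping back to the new `src/distilabel/utils/huggingface.py` helper introduced earlier in this patch, the following is a minimal usage sketch of `get_hf_token`; the class name and argument name passed here are purely illustrative, as they only end up in the error message:

```python
from distilabel.utils.huggingface import get_hf_token

# Reads the token from the `HF_TOKEN` environment variable, falling back to the
# file written by `huggingface-cli login`, and raises a ValueError if neither
# is available.
token = get_hf_token(cls_name="InferenceEndpointsLLM", token_arg="api_key")
```
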
diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2 index 212bbe8601..5d2b72dd90 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2 +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2 @@ -1,3 +1,7 @@ +--- +hide: + - navigation +--- # {{ llm.name }} {% if llm.docstring.short_description %} diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 index 486fc4b299..43a7d552b7 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 @@ -1,5 +1,8 @@ +--- +hide: + - navigation +--- # {{ step.name }} - {% if step.docstring.short_description %} {{ step.docstring.short_description }} {% endif %} @@ -56,7 +59,7 @@ {% for example_title, code in step.docstring.examples.items() %} #### {{ example_title }} ```python -{{ code | e }} +{{ code | replace("\n", "\n") }} ``` {% endfor %} {% endif %} diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py index b97669b809..8f32afc2eb 100644 --- a/src/distilabel/utils/serialization.py +++ b/src/distilabel/utils/serialization.py @@ -13,18 +13,30 @@ # limitations under the License. import importlib -import json import os import sys from enum import Enum +import orjson + if sys.version_info < (3, 11): from enum import EnumMeta as EnumType else: from enum import EnumType from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, get_args +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, +) import yaml from pydantic import BaseModel @@ -40,17 +52,35 @@ SaveFormats = Literal["json", "yaml"] +# Mapping to handle import paths that could have been serialized from previous versions +_OLD_IMPORT_MODULE_ATTR: Dict[Tuple[str, str], Tuple[str, str]] = { + ("distilabel.pipeline.base", "_Batch"): ("distilabel.pipeline.batch", "_Batch"), + ("distilabel.pipeline.base", "_BatchManager"): ( + "distilabel.pipeline.batch_manager", + "_BatchManager", + ), + ("distilabel.pipeline.base", "_BatchManagerStep"): ( + "distilabel.pipeline.batch_manager", + "_BatchManagerStep", + ), +} + + def _get_module_attr(module: str, name: str) -> Type: """Gets a class given the module and the name of the class. Returns: The type of the class. """ + + if (module, name) in _OLD_IMPORT_MODULE_ATTR: + module, name = _OLD_IMPORT_MODULE_ATTR[(module, name)] + mod = importlib.import_module(module) return getattr(mod, name) -def load_from_dict(class_: Dict[str, Any]) -> Any: +def load_with_type_info(class_: Any) -> Any: """Creates an instance of a class from a dictionary containing the type info and the serialized data of the class. @@ -60,17 +90,32 @@ def load_from_dict(class_: Dict[str, Any]) -> Any: Returns: An instance of the class with the data loaded from the dictionary. 
""" - type_info = class_.pop(TYPE_INFO_KEY) - if TYPE_INFO_KEY in type_info: - # There is a nested type_info, load the class recursively - type_info = load_from_dict(type_info) + if not isinstance(class_, (list, dict)): + return class_ - cls = _get_module_attr(type_info["module"], type_info["name"]) + if isinstance(class_, list): + return [load_with_type_info(x) for x in class_] for k, v in class_.items(): + class_[k] = load_with_type_info(v) if isinstance(v, (dict, list)) else v + if isinstance(v, dict) and "_type" in v and v["_type"] == "enum": class_[k] = Enum(v["_name"], v["_values"], type=eval(v["_enum_type"])) + if TYPE_INFO_KEY not in class_: + return class_ + + type_info = class_.pop(TYPE_INFO_KEY) + + cls = _get_module_attr(type_info["module"], type_info["name"]) + + if issubclass(cls, BaseModel): + # `pop` keys from the dictionary that are not in the model fields + field_names = cls.model_fields + keys_to_drop = [k for k in class_.keys() if k not in field_names] + for k in keys_to_drop: + class_.pop(k) + instance = cls(**class_) return instance @@ -83,8 +128,8 @@ def write_json(filename: Path, data: Any) -> None: data: the data to write to the file. """ filename.parent.mkdir(parents=True, exist_ok=True) - with open(filename, "w") as file: - json.dump(data, file, indent=2) + with open(filename, "wb") as f: + f.write(orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY)) def read_json(filename: StrOrPath) -> Any: @@ -96,8 +141,8 @@ def read_json(filename: StrOrPath) -> Any: Returns: The data from the file. """ - with open(filename, "r") as file: - return json.load(file) + with open(filename, "rb") as f: + return orjson.loads(f.read()) def write_yaml(filename: Path, data: Dict[str, Any]) -> None: @@ -159,10 +204,15 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: "_name": getattr(obj, k).__name__, "_values": {x.name: x.value for x in v}, # type: ignore } + elif isinstance(v, list): + dump[k] = {str(i): list_v for i, list_v in enumerate(v)} + # Grab the fields that need extra care (LLMs from inside tasks) to_update = _extra_serializable_fields(obj) + # Update those in the dumped dict - [dump.update(field) for field in to_update] + for field in to_update: + dump.update(field) return dump @@ -237,7 +287,7 @@ def from_dict(cls, data: Dict[str, Any]) -> Self: Returns: An instance of the class with the data loaded from the dictionary. """ - return load_from_dict(data) + return load_with_type_info(data) @classmethod def from_json(cls, path: StrOrPath) -> Self: @@ -303,12 +353,19 @@ def _check_is_dir(path: StrOrPath) -> None: def _extra_serializable_fields(obj: BaseModel) -> List[Dict[str, Dict[str, Any]]]: - # This function is here to loop over objects that contains nested _Serializable objects. - # Cannot work recursively due to the mix between models that inherit from BaseModel and - # those that don't, so we loop over the classes and update those that are _Serializable. - # Extra introspection to dump nested objects. - # Mainly for the LLMs inside a Task for the moment. - # This way we ensure the "type_info" is inserted in those objects. + """Gets the information of the nested `_Serializable` attributes within another `_Serializable` + instance. + + It's mainly used to get the information of the `LLM` objects inside a `Task` object, + as they are nested and need to be serialized (`type_info`). + + Args: + obj: the object to extract the information from. + + Returns: + A list of dictionaries containing the information of the nested `_Serializable` + attributes. 
+ """ from distilabel.pipeline.base import BasePipeline to_update = [] @@ -316,6 +373,12 @@ def _extra_serializable_fields(obj: BaseModel) -> List[Dict[str, Dict[str, Any]] field = getattr(obj, k) # Have to remove the Pipeline as it will be inside the Step objects but is really # in a higher level hierarchy. - if isinstance(field, _Serializable) and (not isinstance(field, BasePipeline)): + if isinstance(field, BasePipeline): + continue + + if isinstance(field, _Serializable): to_update.append({k: getattr(obj, k).dump()}) + elif isinstance(field, list) and field and isinstance(field[0], _Serializable): + to_update.append({k: {str(i): x.dump() for i, x in enumerate(field)}}) + return to_update diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000000..8337c9aaa9 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,27 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from typing import Generator + +import pytest + + +@pytest.fixture(autouse=True) +def temp_cache_dir() -> Generator[None, None, None]: + """Set the cache directory to a temporary directory for all tests.""" + with tempfile.TemporaryDirectory() as tmpdirname: + os.environ["DISTILABEL_CACHE_DIR"] = tmpdirname + yield diff --git a/tests/integration/test_cache.py b/tests/integration/test_cache.py new file mode 100644 index 0000000000..6eddd6f7ca --- /dev/null +++ b/tests/integration/test_cache.py @@ -0,0 +1,55 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, List + +import numpy as np +import pytest +from distilabel.pipeline import Pipeline +from distilabel.steps import GeneratorStep, StepInput, step + +if TYPE_CHECKING: + from distilabel.steps import GeneratorStepOutput, StepOutput + + +class NumpyBigArrayGenerator(GeneratorStep): + num_batches: int + + @property + def outputs(self) -> List[str]: + return ["array"] + + def process(self, offset: int = 0) -> "GeneratorStepOutput": + for i in range(self.num_batches): + yield ( + [{"array": np.random.randn(256)} for _ in range(self.batch_size)], # type: ignore + i == self.num_batches - 1, + ) # type: ignore + + +@step(step_type="global") +def ReceiveArrays(inputs: StepInput) -> "StepOutput": + yield inputs + + +@pytest.mark.benchmark +def test_cache_time() -> None: + with Pipeline(name="dummy") as pipeline: + numpy_generator = NumpyBigArrayGenerator(num_batches=2, batch_size=100) + + receive_arrays = ReceiveArrays() + + numpy_generator >> receive_arrays + + pipeline.run(use_cache=False) diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py index 8ab1fff29a..31a624b15f 100644 --- a/tests/integration/test_pipe_simple.py +++ b/tests/integration/test_pipe_simple.py @@ -166,15 +166,8 @@ def run_pipeline(): ) -def test_pipeline_cached(): - ds = run_pipeline() - print() - print("----- RUNNING PIPELINE AGAIN -----") - print() +def test_pipeline_cached() -> None: + run_pipeline() ds = run_pipeline() assert isinstance(ds, Distiset) assert len(ds["default"]["train"]) == 80 - - -if __name__ == "__main__": - test_pipeline_cached() diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py index 0ea2ee3cdc..228fb1c43e 100644 --- a/tests/integration/test_routing_batch_function.py +++ b/tests/integration/test_routing_batch_function.py @@ -74,7 +74,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput": yield combined_list -@pytest.mark.timeout(120) +@pytest.mark.timeout(240) def test_routing_batch_function() -> None: with Pipeline(name="test") as pipeline: load_dataset = LoadDataFromDicts( @@ -95,7 +95,7 @@ def test_routing_batch_function() -> None: assert len(row["generations"]) == 2 -@pytest.mark.timeout(120) +@pytest.mark.timeout(240) def test_routing_batch_function_irregular_batch_sizes() -> None: with Pipeline(name="test") as pipeline: load_dataset = LoadDataFromDicts( @@ -120,7 +120,7 @@ def test_routing_batch_function_irregular_batch_sizes() -> None: assert len(row["generations"]) == 2 -@pytest.mark.timeout(120) +@pytest.mark.timeout(240) def test_multiple_routing_batch_function() -> None: batch_size = 200 diff --git a/tests/integration/test_using_fs_to_pass_data.py b/tests/integration/test_using_fs_to_pass_data.py new file mode 100644 index 0000000000..811885e356 --- /dev/null +++ b/tests/integration/test_using_fs_to_pass_data.py @@ -0,0 +1,68 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, List + +import numpy as np +from distilabel.pipeline import Pipeline +from distilabel.steps import GeneratorStep, StepInput, step + +if TYPE_CHECKING: + from distilabel.steps import GeneratorStepOutput, StepOutput + + +class NumpyBigArrayGenerator(GeneratorStep): + num_batches: int + + @property + def outputs(self) -> List[str]: + return ["array"] + + def process(self, offset: int = 0) -> "GeneratorStepOutput": + for i in range(self.num_batches): + yield ( + [{"array": np.random.randn(128)} for _ in range(self.batch_size)], # type: ignore + i == self.num_batches - 1, + ) # type: ignore + + +@step(step_type="global") +def ReceiveArrays(inputs: StepInput) -> "StepOutput": + yield inputs + + +def test_passing_data_through_fs_only_global_steps() -> None: + with Pipeline(name="dummy") as pipeline: + numpy_generator = NumpyBigArrayGenerator(num_batches=5, batch_size=100) + + receive_arrays = ReceiveArrays() + + numpy_generator >> receive_arrays + + distiset = pipeline.run(use_fs_to_pass_data=False, use_cache=False) + + assert len(distiset["default"]["train"]) == 500 + + +def test_passing_data_through_fs() -> None: + with Pipeline(name="dummy") as pipeline: + numpy_generator = NumpyBigArrayGenerator(num_batches=2, batch_size=200) + + receive_arrays = ReceiveArrays() + + numpy_generator >> receive_arrays + + distiset = pipeline.run(use_fs_to_pass_data=True, use_cache=False) + + assert len(distiset["default"]["train"]) == 400 diff --git a/tests/unit/steps/tasks/utils.py b/tests/unit/conftest.py similarity index 56% rename from tests/unit/steps/tasks/utils.py rename to tests/unit/conftest.py index 989fb3ad5b..bbe6ca1ed4 100644 --- a/tests/unit/steps/tasks/utils.py +++ b/tests/unit/conftest.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import TYPE_CHECKING -from distilabel.llms.base import LLM -from distilabel.steps.tasks.typing import ChatType +import pytest +from distilabel.llms.base import AsyncLLM +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput + from distilabel.steps.tasks.typing import FormattedInput -class DummyLLM(LLM): + +# Defined here too, so that the serde still works +class DummyLLM(AsyncLLM): def load(self) -> None: pass @@ -26,7 +31,12 @@ def load(self) -> None: def model_name(self) -> str: return "test" - def generate( - self, inputs: List["ChatType"], num_generations: int = 1, **kwargs: Any - ) -> List[List[Union[str, None]]]: - return [["output" for _ in range(num_generations)] for _ in inputs] + async def agenerate( + self, input: "FormattedInput", num_generations: int = 1 + ) -> "GenerateOutput": + return ["output" for _ in range(num_generations)] + + +@pytest.fixture +def dummy_llm() -> AsyncLLM: + return DummyLLM() diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 9caccf43c4..87a890a38c 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random +from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio @@ -22,6 +24,36 @@ @patch("huggingface_hub.AsyncInferenceClient") @patch("openai.AsyncOpenAI") class TestInferenceEndpointsLLM: + def test_load_no_api_key( + self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral" + ) + + # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to not exist + with mock.patch("pathlib.Path.exists") as mock_exists: + mock_exists.return_value = False + with pytest.raises( + ValueError, + match="To use `InferenceEndpointsLLM` an API key must be provided", + ): + llm.load() + + def test_load_with_cached_token( + self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral" + ) + + # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist + with mock.patch("pathlib.Path.exists", return_value=True), mock.patch( + "builtins.open", new_callable=mock.mock_open, read_data="hf_token" + ): + # Should not raise any errors + llm.load() + def test_serverless_inference_endpoints_llm( self, mock_inference_client: MagicMock, mock_openai_client: MagicMock ) -> None: @@ -145,6 +177,7 @@ async def test_generate_via_openai_client( ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) + ... nest_asyncio.apply() assert llm.generate( @@ -159,8 +192,54 @@ async def test_generate_via_openai_client( ] ) == [(" Aenean hendrerit aliquam velit. ...",)] + @pytest.mark.asyncio + async def test_agenerate_with_structured_output( + self, mock_inference_client: MagicMock, _: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, + ) + llm._aclient = mock_inference_client + + llm._aclient.text_generation = AsyncMock( + return_value=" Aenean hendrerit aliquam velit. ..." + ) + + # Since there's a pseudo-random number within the generation kwargs, we set the seed + # here first to ensure reproducibility within the tests + random.seed(42) + + assert await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) == [" Aenean hendrerit aliquam velit. 
..."] + + kwargs = { + "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "max_new_tokens": 128, + "do_sample": False, + "typical_p": None, + "repetition_penalty": None, + "temperature": 1.0, + "top_p": None, + "top_k": None, + "stop_sequences": None, + "return_full_text": False, + "watermark": False, + "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, + "seed": 478163327, # pre-computed random value with `random.seed(42)` + } + mock_inference_client.text_generation.assert_called_with(**kwargs) + def test_serialization( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, + mock_inference_client: MagicMock, + mock_openai_client: MagicMock, ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", @@ -173,9 +252,9 @@ def test_serialization( "base_url": None, "tokenizer_id": None, "generation_kwargs": {}, + "structured_output": None, "model_display_name": None, "use_openai_client": False, - "structured_output": None, "type_info": { "module": "distilabel.llms.huggingface.inference_endpoints", "name": "InferenceEndpointsLLM", diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py index 28e486756b..75e7bcbf62 100644 --- a/tests/unit/llms/test_anthropic.py +++ b/tests/unit/llms/test_anthropic.py @@ -13,12 +13,16 @@ # limitations under the License. import os +import sys +from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio import pytest from distilabel.llms.anthropic import AnthropicLLM +from .utils import DummyUserDetail + @patch("anthropic.AsyncAnthropic") class TestAnthropicLLM: @@ -47,6 +51,37 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: ] ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: + llm = AnthropicLLM( + model="claude-3-opus-20240229", + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_openai + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.messages.create = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) @pytest.mark.asyncio async def test_generate(self, mock_anthropic: MagicMock) -> None: llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore @@ -71,7 +106,52 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: ] ) - def test_serialization(self, _: MagicMock) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "base_url": "https://api.anthropic.com", + "generation_kwargs": {}, + "max_retries": 6, + "model": "claude-3-opus-20240229", + "timeout": 600.0, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.anthropic", + "name": "AnthropicLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "base_url": "https://api.anthropic.com", + "generation_kwargs": {}, + "max_retries": 6, + "model": "claude-3-opus-20240229", + "timeout": 600.0, + "structured_output": { + "schema": 
DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.anthropic", + "name": "AnthropicLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: os.environ["ANTHROPIC_API_KEY"] = "api.key" llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore diff --git a/tests/unit/llms/test_azure.py b/tests/unit/llms/test_azure.py index a5208da95f..e8af5d7b8f 100644 --- a/tests/unit/llms/test_azure.py +++ b/tests/unit/llms/test_azure.py @@ -13,10 +13,14 @@ # limitations under the License. import os +from typing import Any, Dict from unittest import mock +import pytest from distilabel.llms.azure import AzureOpenAILLM +from .utils import DummyUserDetail + class TestAzureOpenAILLM: model_id: str = "gpt-4" @@ -56,20 +60,70 @@ def test_azure_openai_llm_env_vars(self) -> None: assert llm.api_key.get_secret_value() == "another.api.key" # type: ignore assert llm.api_version == self.api_version - def test_serialization(self) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "gpt-4", + "api_version": "preview", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://example-resource.azure.openai.com/", + "timeout": 120, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.azure", + "name": "AzureOpenAILLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "gpt-4", + "api_version": "preview", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://example-resource.azure.openai.com/", + "timeout": 120, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.azure", + "name": "AzureOpenAILLM", + }, + }, + ), + ], + ) + def test_serialization( + self, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: llm = AzureOpenAILLM( - model=self.model_id, base_url=self.base_url, api_version=self.api_version + model=self.model_id, + base_url=self.base_url, + api_version=self.api_version, + structured_output=structured_output, ) - _dump = { - "generation_kwargs": {}, - "model": "gpt-4", - "base_url": "https://example-resource.azure.openai.com/", - "max_retries": 6, - "timeout": 120, - "api_version": "preview", - "structured_output": None, - "type_info": {"module": "distilabel.llms.azure", "name": "AzureOpenAILLM"}, - } - assert llm.dump() == _dump - assert isinstance(AzureOpenAILLM.from_dict(_dump), AzureOpenAILLM) + # _dump = { + # "generation_kwargs": {}, + # "model": "gpt-4", + # "base_url": "https://example-resource.azure.openai.com/", + # "max_retries": 6, + # "timeout": 120, + # "api_version": "preview", + # "structured_output": None, + # "type_info": {"module": "distilabel.llms.azure", "name": "AzureOpenAILLM"}, + # } + assert llm.dump() == dump + assert isinstance(AzureOpenAILLM.from_dict(dump), AzureOpenAILLM) diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py index 0c2e2e213c..a16d904e11 100644 --- a/tests/unit/llms/test_cohere.py +++ b/tests/unit/llms/test_cohere.py @@ -13,12 +13,16 @@ # limitations under the License. 
import os +import sys +from typing import Any, Dict from unittest import mock import nest_asyncio import pytest from distilabel.llms.cohere import CohereLLM +from .utils import DummyUserDetail + @mock.patch("cohere.AsyncClient") class TestCohereLLM: @@ -64,6 +68,38 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: ] ) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) + @pytest.mark.asyncio + async def test_agenerate_structured( + self, mock_async_client: mock.MagicMock + ) -> None: + llm = CohereLLM( + model="command-r", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) + llm._aclient = mock_async_client # type: ignore + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat = mock.AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation == [sample_user.model_dump_json()] + @pytest.mark.asyncio async def test_generate(self, mock_async_client: mock.MagicMock) -> None: llm = CohereLLM(model="command-r") @@ -92,21 +128,53 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None: ] ) - def test_serialization(self, _: mock.MagicMock) -> None: - llm = CohereLLM(model="command-r") - - dump = { - "model": "command-r", - "generation_kwargs": {}, - "base_url": "https://api.cohere.ai/v1", - "timeout": 120, - "client_name": "distilabel", - "structured_output": None, - "type_info": { - "module": "distilabel.llms.cohere", - "name": "CohereLLM", - }, - } + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "command-r", + "generation_kwargs": {}, + "base_url": "https://api.cohere.ai/v1", + "timeout": 120, + "client_name": "distilabel", + "structured_output": None, + "type_info": { + "module": "distilabel.llms.cohere", + "name": "CohereLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "command-r", + "generation_kwargs": {}, + "base_url": "https://api.cohere.ai/v1", + "timeout": 120, + "client_name": "distilabel", + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.cohere", + "name": "CohereLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: mock.MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: + llm = CohereLLM(model="command-r", structured_output=structured_output) assert llm.dump() == dump assert isinstance(CohereLLM.from_dict(dump), CohereLLM) diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index e75166ce97..7607ab2cb2 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -13,12 +13,16 @@ # limitations under the License. import os +import sys +from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio import pytest from distilabel.llms.groq import GroqLLM +from .utils import DummyUserDetail + @patch("groq._client.AsyncGroq") class TestGroqLLM: @@ -47,6 +51,37 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: ] ) == [" Aenean hendrerit aliquam velit. 
..."] + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: + llm = GroqLLM( + model="llama3-70b-8192", + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_openai + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + @pytest.mark.asyncio async def test_generate(self, mock_groq: MagicMock) -> None: llm = GroqLLM(model="llama3-70b-8192", api_key="api.key") # type: ignore @@ -71,22 +106,54 @@ async def test_generate(self, mock_groq: MagicMock) -> None: ] ) == [(" Aenean hendrerit aliquam velit. ...",)] - def test_serialization(self, mock_groq: MagicMock) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "llama3-70b-8192", + "base_url": "https://api.groq.com", + "generation_kwargs": {}, + "max_retries": 2, + "timeout": 120, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.groq", + "name": "GroqLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "llama3-70b-8192", + "base_url": "https://api.groq.com", + "generation_kwargs": {}, + "max_retries": 2, + "timeout": 120, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.groq", + "name": "GroqLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: os.environ["GROQ_API_KEY"] = "api.key" - llm = GroqLLM(model="llama3-70b-8192") - - _dump = { - "model": "llama3-70b-8192", - "base_url": "https://api.groq.com", - "generation_kwargs": {}, - "max_retries": 2, - "timeout": 120, - "structured_output": None, - "type_info": { - "module": "distilabel.llms.groq", - "name": "GroqLLM", - }, - } + llm = GroqLLM(model="llama3-70b-8192", structured_output=structured_output) - assert llm.dump() == _dump - assert isinstance(GroqLLM.from_dict(_dump), GroqLLM) # type: ignore + assert llm.dump() == dump + assert isinstance(GroqLLM.from_dict(dump), GroqLLM) # type: ignore diff --git a/tests/unit/llms/test_llamacpp.py b/tests/unit/llms/test_llamacpp.py index c69b460ce5..b226d99292 100644 --- a/tests/unit/llms/test_llamacpp.py +++ b/tests/unit/llms/test_llamacpp.py @@ -14,11 +14,13 @@ import os import urllib.request -from typing import Generator +from typing import Any, Dict, Generator import pytest from distilabel.llms.llamacpp import LlamaCppLLM +from .utils import DummyUserDetail + @pytest.fixture(scope="module") def llm() -> Generator[LlamaCppLLM, None, None]: @@ -54,3 +56,62 @@ def test_generate(self, llm: LlamaCppLLM) -> None: assert len(responses) == 2 assert len(responses[0]) == 3 + + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "chat_format": None, + "extra_kwargs": {}, + "n_batch": 512, + "n_ctx": 512, + "n_gpu_layers": 0, + "seed": 4294967295, + "generation_kwargs": {}, + 
"structured_output": None, + "type_info": { + "module": "distilabel.llms.llamacpp", + "name": "LlamaCppLLM", + }, + "verbose": False, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "format": "json", + }, + { + "chat_format": None, + "extra_kwargs": {}, + "n_batch": 512, + "n_ctx": 512, + "n_gpu_layers": 0, + "seed": 4294967295, + "generation_kwargs": {}, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "format": "json", + }, + "type_info": { + "module": "distilabel.llms.llamacpp", + "name": "LlamaCppLLM", + }, + "verbose": False, + }, + ), + ], + ) + def test_serialization( + self, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: + llm = LlamaCppLLM( + model_path="tinyllama.gguf", + n_gpu_layers=0, + structured_output=structured_output, + ) + + assert llm.dump() == dump + assert isinstance(LlamaCppLLM.from_dict(dump), LlamaCppLLM) diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py index f31e903d3e..5bb2337481 100644 --- a/tests/unit/llms/test_mistral.py +++ b/tests/unit/llms/test_mistral.py @@ -14,11 +14,14 @@ import os import sys +from typing import Any, Dict from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio import pytest +from .utils import DummyUserDetail + try: from distilabel.llms.mistral import MistralLLM except ImportError: @@ -55,6 +58,37 @@ async def test_agenerate(self, mock_mistral: MagicMock) -> None: ] ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: + llm = MistralLLM( + model="mistral-tiny", + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_mistral + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) + # This should work just with the _aclient.chat method once it's fixed in instructor, and + # then in our code. 
+ # llm._aclient.chat = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + @pytest.mark.asyncio async def test_generate(self, mock_mistral: MagicMock) -> None: llm = MistralLLM(model="mistral-tiny", api_key="api.key") # type: ignore @@ -79,7 +113,54 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: ] ) - def test_serialization(self, mock_mistral: MagicMock) -> None: + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "mistral-tiny", + "endpoint": "https://api.mistral.ai", + "generation_kwargs": {}, + "max_retries": 6, + "timeout": 120, + "max_concurrent_requests": 64, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.mistral", + "name": "MistralLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "mistral-tiny", + "endpoint": "https://api.mistral.ai", + "generation_kwargs": {}, + "max_retries": 6, + "timeout": 120, + "max_concurrent_requests": 64, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.mistral", + "name": "MistralLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: os.environ["MISTRAL_API_KEY"] = "api.key" llm = MistralLLM(model="mistral-tiny") # type: ignore diff --git a/tests/unit/llms/test_mixins.py b/tests/unit/llms/test_mixins.py index feb8b00e01..c0c7b10671 100644 --- a/tests/unit/llms/test_mixins.py +++ b/tests/unit/llms/test_mixins.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import multiprocessing as mp import os import sys from typing import TYPE_CHECKING, Any, Generator, List, Union @@ -43,6 +42,10 @@ def load(self) -> None: super().load() CudaDevicePlacementMixin.load(self) + def unload(self) -> None: + super().unload() + CudaDevicePlacementMixin.unload(self) + @property def model_name(self) -> str: return "test" @@ -63,13 +66,7 @@ def test_set_cuda_visible_devices(self) -> None: assert os.environ["CUDA_VISIBLE_DEVICES"] == "0,1" - def test_cuda_visible_devices_not_cuda_devices(self) -> None: - llm = DummyCudaLLM() - llm._llm_identifier = "unit-test" - - llm.load() - - assert os.getenv("CUDA_VISIBLE_DEVICES") is None + llm.unload() def test_set_cuda_visible_devices_unvalid_devices(self) -> None: llm = DummyCudaLLM(cuda_devices=[5, 6]) @@ -80,84 +77,54 @@ def test_set_cuda_visible_devices_unvalid_devices(self) -> None: ): llm.load() - def test_set_device_placement_info(self) -> None: - llm = DummyCudaLLM(cuda_devices="auto") + llm.unload() + + def test_set_cuda_visible_devices_auto(self) -> None: + llm1 = DummyCudaLLM() + llm1._llm_identifier = "unit-test-1" + llm1.load() - with mp.Manager() as manager: - llm.set_device_placement_info( - llm_identifier="unit-test", - device_llm_placement_map=manager.dict(), - device_llm_placement_lock=manager.Lock(), # type: ignore - ) + assert os.environ["CUDA_VISIBLE_DEVICES"] == "0" - assert llm._llm_identifier == "unit-test" - assert llm._device_llm_placement_map is not None + llm2 = DummyCudaLLM() + llm2._llm_identifier = "unit-test-2" + llm2.load() - def test_set_cuda_visible_devices_auto(self) -> None: - with mp.Manager() as manager: - device_llm_placement_map = manager.dict() - lock = manager.Lock() - - llm1 = DummyCudaLLM() - llm1.set_device_placement_info( - llm_identifier="unit-test-1", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm1.load() - - assert os.environ["CUDA_VISIBLE_DEVICES"] == "0" - - llm2 = DummyCudaLLM() - llm2.set_device_placement_info( - llm_identifier="unit-test-2", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm2.load() - - assert os.environ["CUDA_VISIBLE_DEVICES"] == "1" + assert os.environ["CUDA_VISIBLE_DEVICES"] == "1" + + llm1.unload() + llm2.unload() def test_set_cuda_visible_devices_auto_not_enough_devices(self) -> None: - with mp.Manager() as manager: - device_llm_placement_map = manager.dict() - lock = manager.Lock() - - with pytest.raises( - RuntimeError, match="Couldn't find an available CUDA device" - ): - # 4 devices are available, but 5 LLMs are going to be loaded - for i in range(5): - llm = DummyCudaLLM() - llm.set_device_placement_info( - llm_identifier=f"unit-test-{i}", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm.load() + llms = [] + for i in range(5): + llm = DummyCudaLLM() + llm._llm_identifier = f"unit-test-{i}" + llms.append(llm) + + with pytest.raises( + RuntimeError, match="Couldn't find an available CUDA device" + ): + # 4 devices are available, but 5 LLMs are going to be loaded + for llm in llms: + llm.load() + + for llm in llms: + llm.unload() def test_check_cuda_devices(self, caplog) -> None: - with mp.Manager() as manager: - device_llm_placement_map = manager.dict() - lock = manager.Lock() - - llm1 = DummyCudaLLM(cuda_devices=[1]) - llm1.set_device_placement_info( - llm_identifier="unit-test-1", - device_llm_placement_map=device_llm_placement_map, - 
device_llm_placement_lock=lock, # type: ignore - ) - llm1.load() - - llm2 = DummyCudaLLM(cuda_devices=[1]) - llm2.set_device_placement_info( - llm_identifier="unit-test-2", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm2.load() - - assert ( - "LLM with identifier 'unit-test-1' is also going to use CUDA device '1'" - in caplog.text - ) + llm1 = DummyCudaLLM(cuda_devices=[1]) + llm1._llm_identifier = "unit-test-1" + llm1.load() + + llm2 = DummyCudaLLM(cuda_devices=[1]) + llm2._llm_identifier = "unit-test-2" + llm2.load() + + assert ( + "LLM with identifier 'unit-test-1' is also going to use CUDA device '1'" + in caplog.text + ) + + llm1.unload() + llm2.unload() diff --git a/tests/unit/llms/test_moa.py b/tests/unit/llms/test_moa.py new file mode 100644 index 0000000000..b3a92eded1 --- /dev/null +++ b/tests/unit/llms/test_moa.py @@ -0,0 +1,61 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgentsLLM + +from tests.unit.conftest import DummyLLM + + +class TestMixtureOfAgents: + def test_model_name(self) -> None: + llm = MixtureOfAgentsLLM( + aggregator_llm=DummyLLM(), + proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()], + ) + + assert llm.model_name == "moa-test-test-test-test" + + def test_build_moa_system_prompt(self) -> None: + llm = MixtureOfAgentsLLM( + aggregator_llm=DummyLLM(), + proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()], + ) + + system_prompt = llm._build_moa_system_prompt( + prev_outputs=["output1", "output2", "output3"] + ) + + assert ( + system_prompt == f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3" + ) + + def test_inject_moa_system_prompt(self) -> None: + llm = MixtureOfAgentsLLM( + aggregator_llm=DummyLLM(), + proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()], + ) + + results = llm._inject_moa_system_prompt( + input=[ + {"role": "system", "content": "I'm a system prompt."}, + ], + prev_outputs=["output1", "output2", "output3"], + ) + + assert results == [ + { + "role": "system", + "content": f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3\n\nI'm a system prompt.", + } + ] diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index 3562b6588b..7f90f513a2 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -13,6 +13,8 @@ # limitations under the License. 
import os +import sys +from typing import Any, Dict from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -20,6 +22,8 @@ import pytest from distilabel.llms.openai import OpenAILLM +from .utils import DummyUserDetail + @patch("openai.AsyncOpenAI") class TestOpenAILLM: @@ -63,6 +67,37 @@ async def test_agenerate(self, mock_openai: MagicMock) -> None: ] ) + @pytest.mark.asyncio + async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: + llm = OpenAILLM( + model=self.model_id, + api_key="api.key", + structured_output={ + "schema": DummyUserDetail, + "mode": "tool_call", + "max_retries": 1, + }, + ) # type: ignore + llm._aclient = mock_openai + + sample_user = DummyUserDetail(name="John Doe", age=30) + + llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) + + generation = await llm.agenerate( + input=[ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) + assert generation[0] == sample_user.model_dump_json() + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + ) @pytest.mark.asyncio async def test_generate(self, mock_openai: MagicMock) -> None: llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore @@ -101,21 +136,53 @@ async def test_generate(self, mock_openai: MagicMock) -> None: response_format="unkown_format", ) - def test_serialization(self, _: MagicMock) -> None: - llm = OpenAILLM(model=self.model_id) - - _dump = { - "model": self.model_id, - "generation_kwargs": {}, - "max_retries": 6, - "base_url": "https://api.openai.com/v1", - "timeout": 120, - "structured_output": None, - "type_info": { - "module": "distilabel.llms.openai", - "name": "OpenAILLM", - }, - } - - assert llm.dump() == _dump - assert isinstance(OpenAILLM.from_dict(_dump), OpenAILLM) + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "model": "gpt-4", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://api.openai.com/v1", + "timeout": 120, + "structured_output": None, + "type_info": { + "module": "distilabel.llms.openai", + "name": "OpenAILLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + { + "model": "gpt-4", + "generation_kwargs": {}, + "max_retries": 6, + "base_url": "https://api.openai.com/v1", + "timeout": 120, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "mode": "tool_call", + "max_retries": 1, + }, + "type_info": { + "module": "distilabel.llms.openai", + "name": "OpenAILLM", + }, + }, + ), + ], + ) + def test_serialization( + self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: + llm = OpenAILLM(model=self.model_id, structured_output=structured_output) + + assert llm.dump() == dump + assert isinstance(OpenAILLM.from_dict(dump), OpenAILLM) diff --git a/tests/unit/llms/test_vertexai.py b/tests/unit/llms/test_vertexai.py index b15262e26c..5d3f8d1217 100644 --- a/tests/unit/llms/test_vertexai.py +++ b/tests/unit/llms/test_vertexai.py @@ -115,7 +115,6 @@ def test_serialization(self, _: MagicMock) -> None: _dump = { "model": "gemini-1.0-pro", "generation_kwargs": {}, - "structured_output": None, "type_info": { "module": "distilabel.llms.vertexai", "name": "VertexAILLM", diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py new file mode 100644 index 0000000000..4c847aad8e --- 
/dev/null +++ b/tests/unit/llms/test_vllm.py @@ -0,0 +1,170 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import numpy as np +import pytest +from distilabel.llms import vLLM +from distilabel.llms.vllm import _sort_batches +from pydantic import BaseModel + + +class Character(BaseModel): + name: str + description: str + role: str + weapon: str + + +class Animal(BaseModel): + name: str + species: str + habitat: str + diet: str + + +SAMPLE_DATA = [ + [ + { + "instruction": "Generate a character from a RPG game.", + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": "Generate an animal from a zoo.", + "structured_output": { + "format": "json", + "schema": Animal.model_json_schema(), + }, + }, + { + "instruction": "Repeated character", + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": "What's the weather like today in Seattle in Celsius degrees?", + "structured_output": { + "format": "regex", + "schema": "(\\d{1,2})°C", + }, + }, + { + "instruction": "Other character", + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": "repeated regex", + "structured_output": { + "format": "regex", + "schema": "(\\d{1,2})°C", + }, + }, + ] +] + + +# Just a mock to avoid loading the model +class DummyTokenizer: + def __init__(self) -> None: + pass + + def apply_chat_template(self, input, **kwargs): + return input + + +class TestvLLM: + @pytest.mark.parametrize( + "num_generations, expected_sorted_batches", + [ + ( + 1, + [ + "Generate a character from a RPG game.", + "Generate an animal from a zoo.", + "Repeated character", + "What's the weather like today in Seattle in Celsius degrees?", + "Other character", + "repeated regex", + ], + ), + ( + 3, + np.repeat( + [ + "Generate a character from a RPG game.", + "Generate an animal from a zoo.", + "Repeated character", + "What's the weather like today in Seattle in Celsius degrees?", + "Other character", + "repeated regex", + ], + 3, + ).tolist(), + ), + ], + ) + def test_prepare_batches_and_sort_back( + self, num_generations: int, expected_sorted_batches: List[str] + ): + formatted_inputs = [ + (item["instruction"], item["structured_output"]) + for row in SAMPLE_DATA + for item in row + ] + llm = vLLM(model="dummy") + llm._tokenizer = DummyTokenizer() + batches, indices = llm._prepare_batches(formatted_inputs) + # NOTE: We have to simulate calling self._model.generate(n=num_generations) and then sorting the results + num_generations_batches = [] + for batch in batches: + num_generations_batches.append( + (np.repeat(batch[0], num_generations).tolist(), batch[1]) + ) + batches = num_generations_batches + # Recreate as the output from batched_outputs += [[output.text for output in outputs.outputs] for outputs in batch_outputs] + batches = [batch for batch, _ in batches] + sorted_batches = 
_sort_batches( + batches, indices, num_generations=num_generations + ) + + assert sorted_batches == [ + np.repeat( + [ + "Generate a character from a RPG game.", + "Generate an animal from a zoo.", + "Repeated character", + ], + num_generations, + ).tolist(), + np.repeat( + ["What's the weather like today in Seattle in Celsius degrees?"], + num_generations, + ).tolist(), + np.repeat( + [ + "Other character", + "repeated regex", + ], + num_generations, + ).tolist(), + ] diff --git a/tests/unit/llms/utils.py b/tests/unit/llms/utils.py new file mode 100644 index 0000000000..7b899253bb --- /dev/null +++ b/tests/unit/llms/utils.py @@ -0,0 +1,20 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel + + +class DummyUserDetail(BaseModel): + name: str + age: int diff --git a/tests/unit/mixins/test_runtime_parameters.py b/tests/unit/mixins/test_runtime_parameters.py index 82cac59b7e..8e8d7766d0 100644 --- a/tests/unit/mixins/test_runtime_parameters.py +++ b/tests/unit/mixins/test_runtime_parameters.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import List, Optional from distilabel.mixins.runtime_parameters import ( RuntimeParameter, @@ -32,6 +32,7 @@ class DummyNestedClass(RuntimeParametersMixin): class DummyClass(RuntimeParametersMixin): nested_class: DummyNestedClass + mixins_list: List[DummyNestedClass] runtime_param1: RuntimeParameter[SecretStr] = Field( default=None, description="Runtime param 1" @@ -43,7 +44,10 @@ class DummyClass(RuntimeParametersMixin): class TestRuntimeParametersMixin: def test_runtime_parameters_names(self) -> None: - dummy = DummyClass(nested_class=DummyNestedClass()) + dummy = DummyClass( + nested_class=DummyNestedClass(), + mixins_list=[DummyNestedClass(), DummyNestedClass(), DummyNestedClass()], + ) assert dummy.runtime_parameters_names == { "runtime_param1": False, @@ -52,10 +56,27 @@ def test_runtime_parameters_names(self) -> None: "runtime_param1": False, "runtime_param2": True, }, + "mixins_list": { + "0": { + "runtime_param1": False, + "runtime_param2": True, + }, + "1": { + "runtime_param1": False, + "runtime_param2": True, + }, + "2": { + "runtime_param1": False, + "runtime_param2": True, + }, + }, } def test_get_runtime_parameters_info(self) -> None: - dummy = DummyClass(nested_class=DummyNestedClass()) + dummy = DummyClass( + nested_class=DummyNestedClass(), + mixins_list=[DummyNestedClass(), DummyNestedClass(), DummyNestedClass()], + ) assert dummy.get_runtime_parameters_info() == [ { @@ -73,6 +94,47 @@ def test_get_runtime_parameters_info(self) -> None: }, ], }, + { + "name": "mixins_list", + "runtime_parameters_info": { + "0": [ + { + "name": "runtime_param1", + "description": "Runtime param 1", + "optional": False, + }, + { + "name": "runtime_param2", + "description": "Runtime param 2", + "optional": True, + }, + ], + "1": [ + { + "name": "runtime_param1", + 
"description": "Runtime param 1", + "optional": False, + }, + { + "name": "runtime_param2", + "description": "Runtime param 2", + "optional": True, + }, + ], + "2": [ + { + "name": "runtime_param1", + "description": "Runtime param 1", + "optional": False, + }, + { + "name": "runtime_param2", + "description": "Runtime param 2", + "optional": True, + }, + ], + }, + }, { "name": "runtime_param1", "description": "Runtime param 1", @@ -86,7 +148,10 @@ def test_get_runtime_parameters_info(self) -> None: ] def test_set_runtime_parameters(self) -> None: - dummy = DummyClass(nested_class=DummyNestedClass()) + dummy = DummyClass( + nested_class=DummyNestedClass(), + mixins_list=[DummyNestedClass(), DummyNestedClass(), DummyNestedClass()], + ) dummy.set_runtime_parameters( { diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 76a48ec4be..c18a30e143 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -12,33 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from queue import Queue +from typing import Any, Callable, Dict, List, Optional from unittest import mock import pytest -from distilabel.distiset import Distiset, create_distiset from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.pipeline._dag import DAG from distilabel.pipeline.base import ( + _STEP_LOAD_FAILED_CODE, + _STEP_NOT_LOADED_CODE, BasePipeline, - _Batch, - _BatchManager, - _BatchManagerStep, _GlobalPipelineManager, - _WriteBuffer, ) -from distilabel.pipeline.local import Pipeline -from distilabel.steps.base import GlobalStep, Step, StepInput +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager +from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME, LAST_BATCH_SENT_FLAG +from distilabel.pipeline.routing_batch_function import ( + routing_batch_function, + sample_n_steps, +) +from distilabel.pipeline.write_buffer import _WriteBuffer +from distilabel.steps.base import Step, StepInput, _Step +from distilabel.steps.typing import StepOutput from distilabel.utils.serialization import TYPE_INFO_KEY +from fsspec.implementations.local import LocalFileSystem from pydantic import Field +from upath import UPath + +from .utils import ( + DummyGeneratorStep, + DummyGlobalStep, + DummyStep1, + DummyStep2, +) + -from .utils import DummyGeneratorStep, DummyStep1, DummyStep2, batch_gen +class DummyPipeline(BasePipeline): + @property + def QueueClass(self) -> Callable: + return Queue -if TYPE_CHECKING: - from distilabel.steps.base import GeneratorStep + def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None: + pass + + def _stop(self) -> None: + pass class TestGlobalPipelineManager: @@ -46,7 +68,7 @@ def teardown_method(self) -> None: _GlobalPipelineManager.set_pipeline(None) def test_set_pipeline(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") _GlobalPipelineManager.set_pipeline(pipeline) assert _GlobalPipelineManager.get_pipeline() == pipeline @@ -55,7 +77,7 @@ def test_set_pipeline_none(self) -> None: assert _GlobalPipelineManager.get_pipeline() is None def test_get_pipeline(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") 
_GlobalPipelineManager.set_pipeline(pipeline) assert _GlobalPipelineManager.get_pipeline() == pipeline @@ -64,1906 +86,832 @@ class TestBasePipeline: def test_context_manager(self) -> None: assert _GlobalPipelineManager.get_pipeline() is None - with BasePipeline(name="unit-test-pipeline") as pipeline: + with DummyPipeline(name="unit-test-pipeline") as pipeline: assert pipeline is not None assert _GlobalPipelineManager.get_pipeline() == pipeline assert _GlobalPipelineManager.get_pipeline() is None - def test_get_runtime_parameters_info(self) -> None: - class DummyStep1(Step): - runtime_param1: RuntimeParameter[str] = Field( - default=None, description="runtime_param1 description" - ) - runtime_param2: Optional[RuntimeParameter[str]] = Field( - default=None, description="runtime_param2 description" + @pytest.mark.parametrize("use_cache", [False, True]) + def test_load_batch_manager(self, use_cache: bool) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + pipeline._load_batch_manager(use_cache=True) + pipeline._cache() + + with mock.patch( + "distilabel.pipeline.base._BatchManager.load_from_cache" + ) as mock_load_from_cache, mock.patch( + "distilabel.pipeline.base._BatchManager.from_dag" + ) as mock_from_dag: + pipeline._load_batch_manager(use_cache=use_cache) + + if use_cache: + mock_load_from_cache.assert_called_once_with( + pipeline._cache_location["batch_manager"] ) + mock_from_dag.assert_not_called() + else: + mock_load_from_cache.assert_not_called() + mock_from_dag.assert_called_once_with(pipeline.dag) - def process(self, inputs: StepInput) -> None: - pass + def test_setup_write_buffer(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") - class DummyStep2(Step): - runtime_param3: RuntimeParameter[str] = Field( - default=None, description="runtime_param3 description" - ) - runtime_param4: Optional[RuntimeParameter[str]] = Field( - default=None, description="runtime_param4 description" - ) + pipeline._setup_write_buffer() + assert isinstance(pipeline._write_buffer, _WriteBuffer) - def process(self, inputs: StepInput) -> None: - pass + def test_set_logging_parameters(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + pipeline._set_logging_parameters({"unit-test": "yes"}) - with BasePipeline(name="unit-test-pipeline") as pipeline: - DummyStep1(name="dummy_step_1") - DummyStep2(name="dummy_step_2") + assert pipeline._logging_parameters == {"unit-test": "yes"} - assert pipeline.get_runtime_parameters_info() == { - "dummy_step_1": [ - { - "description": "The number of rows that will contain the batches processed by the " - "step.", - "name": "input_batch_size", - "optional": True, - }, - { - "name": "runtime_param1", - "description": "runtime_param1 description", - "optional": False, - }, - { - "name": "runtime_param2", - "description": "runtime_param2 description", - "optional": True, - }, - ], - "dummy_step_2": [ - { - "description": "The number of rows that will contain the batches processed by the " - "step.", - "name": "input_batch_size", - "optional": True, - }, - { - "name": "runtime_param3", - "description": "runtime_param3 description", - "optional": False, - }, - { - "name": "runtime_param4", - "description": "runtime_param4 description", - "optional": True, - }, - ], + def test_setup_fsspec(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + + with mock.patch("fsspec.filesystem") as mock_filesystem: + pipeline._setup_fsspec({"path": "gcs://my-bucket", "extra": "stuff"}) + + 
mock_filesystem.assert_called_once_with("gcs", **{"extra": "stuff"}) + + def test_setup_fsspec_default(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + pipeline._setup_fsspec() + + assert isinstance(pipeline._fs, LocalFileSystem) + assert ( + pipeline._storage_base_path + == f"file://{pipeline._cache_location['batch_input_data']}" + ) + + def test_setup_fsspec_raises_value_error(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + + with pytest.raises(ValueError, match="The 'path' key must be present"): + pipeline._setup_fsspec({"key": "random"}) + + def test_init_steps_load_status(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._init_steps_load_status() + assert pipeline._steps_load_status == { + generator.name: _STEP_NOT_LOADED_CODE, + step.name: _STEP_NOT_LOADED_CODE, + step2.name: _STEP_NOT_LOADED_CODE, + step3.name: _STEP_NOT_LOADED_CODE, } - # Test no log, Test log, test log without close match - @pytest.mark.parametrize( - "parameters, expected", - ( - ( - { - "dummy_step_1": {"runtime_param1": "value1"}, - "dummy_step_2": {"runtime_param3": "value1"}, - }, - "", - ), - ( - { - "dummy_step_1": {"runtime_param1": "value1"}, - "dummy_step_2": { - "runtime_param3": "value1", - "runtime_param_unknown": "value1", - }, - }, - "Did you mean any of:", - ), - ( - { - "dummy_step_1": {"runtime_param1": "value1"}, - "dummy_step_2": { - "runtime_param3": "value1", - "weird_name": "value1", - }, - }, - "Available runtime parameters for the step", - ), - ), - ) - def test_check_runtime_parameters( - self, caplog, parameters: Dict[str, Any], expected: str - ) -> None: - class DummyStep1(Step): - runtime_param1: RuntimeParameter[str] = Field( - default=None, description="runtime_param1 description" - ) - runtime_param2: Optional[RuntimeParameter[str]] = Field( - default=None, description="runtime_param2 description" - ) + def test_run_load_queue_loop(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") - def process(self, inputs: StepInput) -> None: - pass + pipeline._load_queue = Queue() + pipeline._steps_load_status = {"dummy": 0} + pipeline._load_queue.put({"name": "dummy", "status": "loaded"}) - class DummyStep2(Step): - runtime_param3: RuntimeParameter[str] = Field( - default=None, description="runtime_param3 description" - ) - runtime_param4: Optional[RuntimeParameter[str]] = Field( - default=None, description="runtime_param4 description" - ) + thread = pipeline._run_load_queue_loop_in_thread() + pipeline._load_queue.put(None) + thread.join() - def process(self, inputs: StepInput) -> None: - pass + assert pipeline._steps_load_status["dummy"] == 1 - with BasePipeline(name="unit-test-pipeline") as pipeline: - gen_step = DummyGeneratorStep(name="dummy_generator_step") - step1 = DummyStep1(name="dummy_step_1") - step2 = DummyStep2(name="dummy_step_2") + def test_run_load_queue_loop_receiving_none(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") - gen_step >> step1 >> step2 + pipeline._load_queue = Queue() + pipeline._load_queue.put(None) - pipeline.run(parameters=parameters) - if expected: - assert expected in caplog.text - else: - assert caplog.text == expected + thread = pipeline._run_load_queue_loop_in_thread() + thread.join() - def test_cache_dir_env_variable(self) -> None: - with mock.patch.dict(os.environ, clear=True): - 
os.environ["DISTILABEL_CACHE_DIR"] = "/tmp/unit-test" - pipeline = BasePipeline(name="unit-test-pipeline") - assert pipeline._cache_dir == Path("/tmp/unit-test") + assert not thread.is_alive() - @pytest.mark.parametrize( - "in_pipeline, names", - ( - ( - True, - [ - "dummy_generator_step_0", - "dummy_step1_0", - "dummy_step2_0", - "dummy_step1_1", - ], - ), - # TODO: Activate this test once we merge the option of not passing a Pipeline - # ( - # False, ["dummy_generator_step", "dummy_step1", "dummy_step2"] - # ) - ), - ) - def test_step_names_inferred(self, in_pipeline: bool, names: List[str]) -> None: - if in_pipeline: - with BasePipeline(name="unit-test-pipeline"): - gen_step = DummyGeneratorStep() - step1_0 = DummyStep1() - step2 = DummyStep2() - step1_1 = DummyStep1() + def test_all_steps_loaded(self, caplog) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() - gen_step >> step1_0 >> step2 >> step1_1 - else: - gen_step = DummyGeneratorStep() - step1_0 = DummyStep1() - step2 = DummyStep2() - step1_1 = DummyStep1() + generator >> [step, step2] >> step3 - assert gen_step.name == names[0] - assert step1_0.name == names[1] - assert step2.name == names[2] - assert step1_1.name == names[3] + pipeline._steps_load_status = { # type: ignore + generator.name: 1, + step.name: 1, + step2.name: 1, + step3.name: 1, + } + caplog.set_level(logging.INFO) - def test_infer_step_names_big_pipeline(self) -> None: - # Tests that the name of the steps are inferred correctly when the pipeline is big (say 50 steps). - with BasePipeline(name="unit-test-pipeline") as pipe: - gen_step = DummyGeneratorStep() - for _ in range(50): - gen_step.connect(DummyStep1()) - assert list(pipe.dag.G)[-1] == "dummy_step1_49" + assert pipeline._all_steps_loaded() is True + assert "All the steps have been loaded!" 
in caplog.text + def test_all_steps_loaded_with_failing_step(self, caplog) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() -class TestBatch: - def test_set_data(self) -> None: - batch = _Batch(seq_no=0, step_name="step1", last_batch=False) - data = [[{"i": i} for i in range(5000)]] - batch.set_data(data) + generator >> [step, step2] >> step3 - assert batch.data == data - assert batch.size == 5000 + pipeline._init_steps_load_status() + pipeline._steps_load_status[generator.name] = _STEP_LOAD_FAILED_CODE # type: ignore + caplog.set_level(logging.INFO) - def test_next_batch(self) -> None: - batch = _Batch(seq_no=0, step_name="step1", last_batch=False) - next_batch = batch.next_batch() + assert pipeline._all_steps_loaded() is False + assert "Failed to load all the steps" in caplog.text - assert next_batch == _Batch(seq_no=1, step_name="step1", last_batch=False) + def test_all_steps_loaded_stop_called(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() - def test_accumulate(self) -> None: - batches = [ - [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - ), - _Batch( - seq_no=1, - step_name="step1", - last_batch=True, - data=[[{"a": 4}, {"a": 5}, {"a": 6}]], - ), - ], + generator >> [step, step2] >> step3 + + pipeline._init_steps_load_status() + pipeline._stop_called = True + + assert pipeline._all_steps_loaded() is False + + def test_handle_stop(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._add_batches_back_to_batch_manager = mock.MagicMock() + pipeline._wait_step_input_queue_empty = mock.MagicMock() + pipeline._consume_output_queue = mock.MagicMock() + + pipeline._handle_stop() + + pipeline._add_batches_back_to_batch_manager.assert_called_once() + pipeline._wait_step_input_queue_empty.assert_has_calls( [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}]], - ), - _Batch( - seq_no=1, - step_name="step2", - last_batch=True, - data=[[{"b": 4}, {"b": 5}, {"b": 6}]], - ), - ], + mock.call(generator.name), + mock.call(step.name), + mock.call(step2.name), + mock.call(step3.name), ], - ] + any_order=True, + ) + pipeline._consume_output_queue.assert_called_once() - batch = _Batch.accumulate("step3", batches) + @pytest.mark.parametrize( + "num_workers,expected", [(0, True), (_STEP_LOAD_FAILED_CODE, True), (1, False)] + ) + def test_check_step_not_loaded_or_finished( + self, num_workers: int, expected: bool + ) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + pipeline._steps_load_status = {"dummy": num_workers} - assert batch.seq_no == 0 - assert batch.step_name == "step3" - assert batch.last_batch is True - assert batch.data == [ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}], - ] + assert pipeline._check_step_not_loaded_or_finished("dummy") is expected - def test_dump(self) -> None: - batch = _Batch(seq_no=0, step_name="step1", last_batch=False) - assert batch.dump() == { - "seq_no": 0, - "size": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "accumulated": False, - "created_from": {}, - 
"batch_routed_to": [], - "type_info": {"module": "distilabel.pipeline.base", "name": "_Batch"}, - } + def test_is_convergence_step(self) -> None: + sample_two_steps = sample_n_steps(2) - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - accumulated=False, - created_from={"step0": [0, 1]}, - batch_routed_to=["step2", "step3"], - ) - assert batch.dump() == { - "seq_no": 0, - "size": 0, - "step_name": "step1", - "last_batch": False, - "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], - "accumulated": False, - "created_from": {"step0": [0, 1]}, - "batch_routed_to": ["step2", "step3"], - "type_info": {"module": "distilabel.pipeline.base", "name": "_Batch"}, - } + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() - def test_from_dict(self) -> None: + generator >> sample_two_steps >> [step, step2] >> step3 + + pipeline.dag.validate() + + assert not pipeline._is_convergence_step(generator.name) # type: ignore + assert not pipeline._is_convergence_step(step.name) # type: ignore + assert not pipeline._is_convergence_step(step2.name) # type: ignore + assert pipeline._is_convergence_step(step3.name) # type: ignore + + def test_create_step_input_queue(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + generator_name: str = generator.name # type: ignore + input_queue = pipeline._create_step_input_queue(generator_name) + assert isinstance(input_queue, Queue) assert isinstance( - _Batch.from_dict( - { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ), - _Batch, + pipeline.dag.get_step(generator_name)[INPUT_QUEUE_ATTR_NAME], Queue ) + def test_run_steps(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + pipeline._create_step_input_queue = mock.MagicMock() + pipeline._run_step = mock.MagicMock() + pipeline._run_steps() -class TestBatchManagerStep: - def test_add_batch(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} + pipeline._create_step_input_queue.assert_has_calls( + [ + mock.call(step_name=step.name), + mock.call(step_name=generator.name), + ], + any_order=True, ) - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + pipeline._run_step.assert_has_calls( + [ + mock.call(step=mock.ANY, input_queue=mock.ANY), + mock.call(step=mock.ANY, input_queue=mock.ANY), + ] ) - batch_manager_step.add_batch(batch) + def test_add_batches_back_to_batch_manager(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() - assert batch_manager_step.data["step1"] == [batch] - assert batch_manager_step.last_batch_received == [] + generator >> step - def test_add_batch_with_prepend(self) -> None: - batch_1 = _Batch( - seq_no=1, - step_name="step1", - last_batch=False, - data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], - ) - batch_manager_step = _BatchManagerStep( - step_name="step2", - accumulate=False, - input_batch_size=10, - data={"step1": [batch_1]}, + 
generator_name: str = generator.name # type: ignore + step_name: str = step.name # type: ignore + + pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) + generator_queue = Queue() + pipeline.dag.set_step_attr( + generator_name, INPUT_QUEUE_ATTR_NAME, generator_queue ) + step_queue = Queue() + pipeline.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, step_queue) - batch_0 = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + generator_queue.put( + _Batch(seq_no=0, step_name=generator_name, last_batch=False) + ) + generator_queue.put( + _Batch(seq_no=1, step_name=generator_name, last_batch=False) ) - batch_manager_step.add_batch(batch_0, prepend=True) - assert batch_manager_step.data["step1"] == [batch_0, batch_1] - assert batch_manager_step.last_batch_received == [] + step_batch_0 = _Batch(seq_no=0, step_name=step_name, last_batch=False) + step_batch_1 = _Batch(seq_no=0, step_name=step_name, last_batch=False) + step_queue.put(step_batch_0) + step_queue.put(step_batch_1) - def test_add_batch_last_batch(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} - ) + pipeline._add_batches_back_to_batch_manager() - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - ) + assert pipeline._batch_manager._steps[step_name].built_batches == [ + step_batch_0, + step_batch_1, + ] - batch_manager_step.add_batch(batch) - - assert batch_manager_step.data["step1"] == [batch] - assert batch_manager_step.last_batch_received == ["step1"] - - def test_get_batch(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=2, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - size=5, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - ] - ], - size=5, - ) - ], - }, - ) + def test_consume_output_queue(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() - batch = batch_manager_step.get_batch() + generator >> step - assert batch == _Batch( - step_name="step3", - seq_no=0, - last_batch=False, - data=[ - [ - {"a": 1}, - {"a": 2}, - ], - [ - {"b": 1}, - {"b": 2}, - ], - ], - created_from={"step1": [(0, 5)], "step2": [(0, 5)]}, - ) + pipeline._output_queue = Queue() + pipeline._write_buffer = mock.MagicMock() + pipeline._handle_batch_on_stop = mock.MagicMock() - batch = batch_manager_step.get_batch() + generator_name: str = generator.name # type: ignore + step_name: str = step.name # type: ignore - assert batch == _Batch( - step_name="step3", - seq_no=1, - last_batch=False, - data=[ - [ - {"a": 3}, - {"a": 4}, - ], - [ - {"b": 3}, - {"b": 4}, - ], - ], - created_from={"step1": [(0, 5)], "step2": [(0, 5)]}, - ) + generator_batch = _Batch(seq_no=0, step_name=generator_name, last_batch=False) + step_batch = _Batch(seq_no=0, step_name=step_name, last_batch=False) - def test_get_batches_accumulate(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - 
{"a": 5}, - ] - ], - size=5, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - ] - ], - size=6, - ) - ], - }, - last_batch_received=["step1", "step2"], - ) + pipeline._output_queue.put(generator_batch) + pipeline._output_queue.put(step_batch) - batch = batch_manager_step.get_batch() + pipeline._consume_output_queue() - assert batch == _Batch( - step_name="step3", - seq_no=0, - last_batch=True, - accumulated=True, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ], - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - ], - ], - created_from={"step1": [(0, 5)], "step2": [(0, 6)]}, + pipeline._write_buffer.add_batch.assert_called_once_with(step_batch) + pipeline._handle_batch_on_stop.assert_has_calls( + [ + mock.call(generator_batch), + mock.call(step_batch), + ] ) - def test_get_batches_not_enough_data(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=2, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[ - [ - {"a": 1}, - ] - ], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - ] - ], - ) - ], - }, - ) + def test_send_batch_to_step(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + global_step = DummyGlobalStep() - assert batch_manager_step.get_batch() is None + generator >> [step, global_step] - def test_from_step(self, dummy_step_1: "Step") -> None: - batch_manager_step = _BatchManagerStep.from_step( - step=dummy_step_1, predecessors=["step1", "step2"] - ) + pipeline._batch_manager = mock.MagicMock() + pipeline._send_to_step = mock.MagicMock() + pipeline._setup_fsspec() - assert batch_manager_step.step_name == "dummy_step_1" - assert batch_manager_step.accumulate is False - assert batch_manager_step.input_batch_size == 50 - assert batch_manager_step.data == {"step1": [], "step2": []} - assert batch_manager_step.seq_no == 0 - assert batch_manager_step.last_batch_received == [] + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + pipeline._send_batch_to_step(batch) + pipeline._batch_manager.set_last_batch_sent.assert_called_once_with(batch) - def test_from_step_with_global_step(self, dummy_global_step: "GlobalStep") -> None: - batch_manager_step = _BatchManagerStep.from_step( - step=dummy_global_step, predecessors=["step1", "step2"] - ) + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) - assert batch_manager_step.step_name == "dummy_global_step" - assert batch_manager_step.accumulate is True - assert batch_manager_step.input_batch_size == 50 - assert batch_manager_step.data == {"step1": [], "step2": []} - assert batch_manager_step.seq_no == 0 - assert batch_manager_step.last_batch_received == [] + # `write_batch_data_to_fs` shouldn't have been called because last batch sent with + # `_send_batch_to_step` is from a non-global step. 
+ mock_write.assert_not_called() - def test_get_seq_no(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", accumulate=False, input_batch_size=5, data={"step1": []} - ) + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore + ) - seq_no = batch_manager_step._get_seq_no() - - assert seq_no == 0 - assert batch_manager_step.seq_no == 1 - - def test_get_data(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] - ], - size=6, - batch_routed_to=["step1", "step2"], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - size=7, - batch_routed_to=["step1", "step2"], - ) - ], - }, + # `write_batch_data_to_fs` should have been called because last batch sent with + # `_send_batch_to_step` is from a global step. + mock_write.assert_called_once_with( + pipeline._fs, + UPath(pipeline._storage_base_path) / global_step.name, ) - data, created_from, routed_to = batch_manager_step._get_data() - assert data == [ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}], - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}], - ] - assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} - assert routed_to == ["step1", "step2"] - - assert batch_manager_step.data == { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 6}]], - size=6, - batch_routed_to=["step1", "step2"], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 6}, {"b": 7}]], - size=7, - batch_routed_to=["step1", "step2"], - ) - ], - } + pipeline._use_fs_to_pass_data = True - def test_get_data_accumulate(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] - ], - size=6, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - size=7, - ) - ], - }, + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + ) + + # `write_batch_data_to_fs` shouldn't have been called because generator receives + # empty batches, so there's no data to write. 
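+            # (`_use_fs_to_pass_data` is `True` here, but the generator batch carries no data, so nothing is written)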
+ mock_write.assert_not_called() + + with mock.patch( + "distilabel.pipeline.base._Batch.write_batch_data_to_fs" + ) as mock_write: + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) + pipeline._send_batch_to_step( + _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore + ) + + mock_write.assert_has_calls( + [ + mock.call( + pipeline._fs, + UPath(pipeline._storage_base_path) / step.name, + ), + mock.call( + pipeline._fs, + UPath(pipeline._storage_base_path) / global_step.name, + ), + ] ) - data, created_from, routed_to = batch_manager_step._get_data() + def test_register_batch(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() - assert data == [ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}], - ] - assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} - assert routed_to == [] + generator >> step - assert batch_manager_step.data == {"step1": [], "step2": []} + pipeline._batch_manager = mock.MagicMock() + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + pipeline._register_batch(batch) - def test_get_data_convergence_step(self) -> None: - batch_a_0 = _Batch( - seq_no=0, - step_name="A", - last_batch=False, - data=[ - [ - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - ] - ], - size=3, - created_from={"Z": [(0, 3)]}, + pipeline._batch_manager.register_batch.assert_called_once_with(batch) + + def test_send_last_batch_flag_to_step(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + step_name: str = step.name # type: ignore + + pipeline._batch_manager = _BatchManager( + steps={}, + last_batch_received={step_name: None}, + last_batch_sent={step_name: None}, + last_batch_flag_sent_to=[], ) - batch_a_1 = _Batch( - seq_no=1, - step_name="A", - last_batch=False, - data=[ - [ - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - ] - ], - size=3, - created_from={"Z": [(1, 3)]}, + with mock.patch.object(pipeline, "_send_to_step") as mock_sent_to_step: + pipeline._send_last_batch_flag_to_step(step_name) + + mock_sent_to_step.assert_called_once_with(step_name, LAST_BATCH_SENT_FLAG) + + pipeline._batch_manager._last_batch_sent[step_name] = _Batch( + seq_no=0, + step_name=step_name, + last_batch=True, ) + with mock.patch.object(pipeline, "_send_to_step") as mock_sent_to_step: + pipeline._send_last_batch_flag_to_step(step_name) + + mock_sent_to_step.assert_not_called() + + def test_request_initial_batches(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(input_batch_size=5) - batch_b_0 = _Batch( + generator >> step + + generator2 = DummyGeneratorStep() + step2 = DummyStep1(input_batch_size=5) + + generator2 >> step2 + + pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) + + # Simulate there were batches from the cache for the steps + batch_0 = _Batch( seq_no=0, - step_name="B", + step_name=generator.name, # type: ignore last_batch=False, - data=[ - [ - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - ] - ], - size=3, - created_from={"Z": [(0, 3)]}, 
+ data=[[{"a": i} for i in range(5)]], ) + pipeline._batch_manager._steps[step.name].data[generator.name] = [ # type: ignore + batch_0 + ] - batch_c_0 = _Batch( + batch_1 = _Batch( seq_no=0, - step_name="C", + step_name=generator2.name, # type: ignore last_batch=False, - data=[ - [ - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - ] + data=[[{"b": i} for i in range(5)]], + ) # type: ignore + pipeline._batch_manager._steps[step2.name].data[generator2.name] = [ # type: ignore + batch_1 + ] + + with mock.patch.object( + pipeline, "_send_batch_to_step" + ) as mock_send_batch_to_step: + pipeline._request_initial_batches() + + mock_send_batch_to_step.assert_has_calls( + [ + mock.call(mock.ANY), + mock.call(mock.ANY), + mock.call(_Batch(seq_no=0, step_name=generator.name, last_batch=False)), # type: ignore + mock.call( + _Batch(seq_no=0, step_name=generator2.name, last_batch=False) # type: ignore + ), ], - size=3, - created_from={"Z": [(1, 3)]}, + any_order=True, ) - batch_manager_step = _BatchManagerStep( - step_name="D", - input_batch_size=3, - convergence_step=True, - accumulate=False, - data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]}, - ) + def test_request_more_batches_if_needed(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() - data, created_from, routed_to = batch_manager_step._get_data() + generator >> step - assert data == [ - [ - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - ], - [ - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - ], - ] - assert created_from == {"A": [(0, 3)], "B": [(0, 3)]} - assert routed_to == [] - assert batch_manager_step.next_expected_created_from_batch_seq_no == 1 + generator_name: str = generator.name # type: ignore - data, created_from, routed_to = batch_manager_step._get_data() + pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) - assert data == [ - [ - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - ], - [ - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - ], - ] - assert created_from == {"A": [(1, 3)], "C": [(0, 3)]} - assert routed_to == [] - assert batch_manager_step.next_expected_created_from_batch_seq_no == 2 + batch = _Batch(seq_no=0, step_name=generator_name, last_batch=False) + pipeline._batch_manager._last_batch_sent[generator_name] = batch - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ] - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - ) - ] - }, - ["step1"], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 
4}, {"a": 5}]], - ) - ] - }, - ["step1"], - True, - ), - ], - ) - def test_last_batch( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", - accumulate=False, - input_batch_size=5, - data=data, - last_batch_received=last_batch_received, - ) + with mock.patch.object( + pipeline, "_send_batch_to_step" + ) as mock_send_batch_to_step: + pipeline._request_more_batches_if_needed(step) - assert batch_manager_step._last_batch() is expected + mock_send_batch_to_step.assert_called_once_with(batch.next_batch()) - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1"], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1", "step2"], - True, - ), - ], - ) - def test_last_batch_accumulate( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data=data, - last_batch_received=last_batch_received, + def test_handle_batch_on_stop(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(input_batch_size=5) + step2 = DummyStep1(input_batch_size=5) + step3 = DummyStep1(input_batch_size=5) + + generator >> [step, step2, step3] + + batch_manager_mock = mock.MagicMock() + pipeline._batch_manager = batch_manager_mock + + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + pipeline._handle_batch_on_stop(batch) + + batch_manager_mock.register_batch.assert_called_once_with(batch) + batch_manager_mock.add_batch.assert_has_calls( + [ + mock.call(step.name, batch), + mock.call(step2.name, batch), + mock.call(step3.name, batch), + ] ) - assert batch_manager_step._last_batch() is expected + def test_get_step_from_batch(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - 
created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - created_from={"step0": [(0, 3)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}]], - created_from={"step0": [(0, 3)]}, - ) - ], - }, - [], - True, - ), - ], - ) - def test_last_batch_convergence_step( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - data=data, - last_batch_received=last_batch_received, - input_batch_size=3, - convergence_step=True, + generator >> step + + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + assert pipeline._get_step_from_batch(batch) == generator + + batch = _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + assert pipeline._get_step_from_batch(batch) == step + + def test_notify_steps_to_stop(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(input_batch_size=5) + + generator >> step + + with mock.patch.object(pipeline, "_send_to_step") as mock_send_to_step: + pipeline._notify_steps_to_stop() + + mock_send_to_step.assert_has_calls( + [ + mock.call(generator.name, None), + mock.call(step.name, None), + ] ) - assert batch_manager_step._last_batch() is expected + def test_get_successors(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + ) == ([step.name, step2.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) == ([step3.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step2.name, last_batch=False) # type: ignore + ) == ([step3.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step3.name, last_batch=False) # type: ignore + ) == ([], False) + + def test_get_successors_with_routing_batch_function(self) -> None: + @routing_batch_function() + def fixed_routing_batch_function(steps: List[str]) -> List[str]: + return ["step_2", "step_3"] + + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(name="step_1") + step2 = DummyStep1(name="step_2") + step3 = DummyStep1(name="step_3") + step4 = DummyStep2(name="step_4") + + generator >> fixed_routing_batch_function >> [step, step2, step3] >> step4 + + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + ) == (["step_2", "step_3"], True) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) == ([step4.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step2.name, last_batch=False) # type: ignore + ) == ([step4.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, 
step_name=step3.name, last_batch=False) # type: ignore + ) == ([step4.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step4.name, last_batch=False) # type: ignore + ) == ([], False) - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( + def test_get_runtime_parameters_info(self) -> None: + class DummyStep1(Step): + runtime_param1: RuntimeParameter[str] = Field( + default=None, description="runtime_param1 description" + ) + runtime_param2: Optional[RuntimeParameter[str]] = Field( + default=None, description="runtime_param2 description" + ) + + def process(self, inputs: StepInput) -> None: + pass + + class DummyStep2(Step): + runtime_param3: RuntimeParameter[str] = Field( + default=None, description="runtime_param3 description" + ) + runtime_param4: Optional[RuntimeParameter[str]] = Field( + default=None, description="runtime_param4 description" + ) + + def process(self, inputs: StepInput) -> None: + pass + + with DummyPipeline(name="unit-test-pipeline") as pipeline: + DummyStep1(name="dummy_step_1") + DummyStep2(name="dummy_step_2") + + assert pipeline.get_runtime_parameters_info() == { + "dummy_step_1": [ { - "step1": [], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], + "description": "The number of rows that will contain the batches processed by the " + "step.", + "name": "input_batch_size", + "optional": True, }, - [], - False, - ), - ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], + "name": "runtime_param1", + "description": "runtime_param1 description", + "optional": False, }, - [], - False, - ), - ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], + "name": "runtime_param2", + "description": "runtime_param2 description", + "optional": True, }, - ["step1", "step2"], - True, - ), - ( + ], + "dummy_step_2": [ { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], - ) - ], + "description": "The number of rows that will contain the batches processed by the " + "step.", + "name": "input_batch_size", + "optional": True, }, - ["step1", "step2"], - True, - ), - ], - ) - def test_ready_to_create_batch( - self, - data: Dict[str, List[Dict[str, Any]]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", - accumulate=False, - input_batch_size=5, - data=data, - last_batch_received=last_batch_received, - ) - - assert batch_manager_step._ready_to_create_batch() is expected - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, 
{"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], + "name": "runtime_param3", + "description": "runtime_param3 description", + "optional": False, }, - ["step1", "step2"], - True, - ), - ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], + "name": "runtime_param4", + "description": "runtime_param4 description", + "optional": True, }, - ["step1"], - False, - ), - ], - ) - def test_ready_to_create_batch_accumulate( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data=data, - last_batch_received=last_batch_received, - ) - - assert batch_manager_step._ready_to_create_batch() is expected - - def test_dump(self) -> None: - batch_step_1 = _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], - size=6, - ) - batch_step_2 = _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[ - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}] ], - size=7, - ) - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data={ - "step1": [batch_step_1], - "step2": [batch_step_2], - }, - ) - assert batch_manager_step.dump() == { - "step_name": "step3", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {}, - "input_batch_size": None, - "data": { - "step1": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": True, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - "size": 6, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - }, - "seq_no": 0, - "last_batch_received": [], - "next_expected_created_from_batch_seq_no": 0, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, } + # Test no log, Test log, test log without close match @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - True, - ), + "parameters, expected", + ( ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - "step2": [ - 
_Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], + "dummy_step_1": {"runtime_param1": "value1"}, + "dummy_step_2": {"runtime_param3": "value1"}, }, - [], - False, + "", ), ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], + "dummy_step_1": {"runtime_param1": "value1"}, + "dummy_step_2": { + "runtime_param3": "value1", + "runtime_param_unknown": "value1", + }, }, - ["step1", "step2"], - True, + "Did you mean any of:", ), ( { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], + "dummy_step_1": {"runtime_param1": "value1"}, + "dummy_step_2": { + "runtime_param3": "value1", + "weird_name": "value1", + }, }, - [], - False, + "Available runtime parameters for the step", ), - ], + ), ) - def test_ready_to_create_batch_convergence_step( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, + def test_check_runtime_parameters( + self, caplog, parameters: Dict[str, Any], expected: str ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data=data, - last_batch_received=last_batch_received, - convergence_step=True, - ) - - assert batch_manager_step._ready_to_create_batch() is expected - - def test_from_dict(self) -> None: - batch_manager_step = _BatchManagerStep.from_dict( - { - "step_name": "step3", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {0: {"Z": 1234}}, - "input_batch_size": None, - "data": { - "step1": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": True, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - "size": 6, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - }, - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - } - ) - - assert isinstance(batch_manager_step, _BatchManagerStep) - assert batch_manager_step.step_name == "step3" - assert batch_manager_step.accumulate is True - assert batch_manager_step.convergence_step is False - assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}} - assert batch_manager_step.input_batch_size is None - assert batch_manager_step.seq_no == 0 - assert batch_manager_step.last_batch_received == [] - - -class TestBatchManager: - def 
test_add_batch(self) -> None: - batch_manager = _BatchManager( - steps={ - "step3": _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={"step1": [], "step2": []}, - ) - }, - last_batch_received={"step3": None}, - last_batch_sent={"step3": None}, - last_batch_flag_sent_to=[], - ) - - batch_from_step_1 = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - batch_manager.add_batch(to_step="step3", batch=batch_from_step_1) + class DummyStep1(Step): + runtime_param1: RuntimeParameter[str] = Field( + default=None, description="runtime_param1 description" + ) + runtime_param2: Optional[RuntimeParameter[str]] = Field( + default=None, description="runtime_param2 description" + ) - assert batch_manager._steps["step3"].data == { - "step1": [batch_from_step_1], - "step2": [], - } + def process(self, inputs: StepInput) -> StepOutput: # type: ignore + yield [{}] - def test_add_batch_with_prepend(self) -> None: - batch_1 = _Batch( - seq_no=1, - step_name="step1", - last_batch=False, - data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], - ) - batch_manager = _BatchManager( - steps={ - "step3": _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={ - "step1": [batch_1], - "step2": [], - }, - ) - }, - last_batch_received={"step3": None}, - last_batch_sent={"step3": None}, - last_batch_flag_sent_to=[], - ) - batch_0 = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - batch_manager.add_batch(to_step="step3", batch=batch_0, prepend=True) - assert batch_manager._steps["step3"].data == { - "step1": [batch_0, batch_1], - "step2": [], - } + class DummyStep2(Step): + runtime_param3: RuntimeParameter[str] = Field( + default=None, description="runtime_param3 description" + ) + runtime_param4: Optional[RuntimeParameter[str]] = Field( + default=None, description="runtime_param4 description" + ) - def test_from_dag( - self, - dummy_generator_step: "GeneratorStep", - dummy_step_1: "Step", - dummy_step_2: "Step", - dummy_global_step: "GlobalStep", - ) -> None: - dag = DAG() - dag.add_step(dummy_generator_step) - dag.add_step(dummy_step_1) - dag.add_step(dummy_step_2) - dag.add_step(dummy_global_step) - dag.add_edge("dummy_generator_step", "dummy_step_1") - dag.add_edge("dummy_generator_step", "dummy_global_step") - dag.add_edge("dummy_step_1", "dummy_step_2") - - batch_manager = _BatchManager.from_dag(dag) - - assert batch_manager._steps == { - "dummy_step_1": _BatchManagerStep( - step_name="dummy_step_1", - accumulate=False, - input_batch_size=50, - data={"dummy_generator_step": []}, - ), - "dummy_global_step": _BatchManagerStep( - step_name="dummy_global_step", - accumulate=True, - input_batch_size=50, - data={"dummy_generator_step": []}, - ), - "dummy_step_2": _BatchManagerStep( - step_name="dummy_step_2", - accumulate=False, - input_batch_size=50, - data={"dummy_step_1": []}, - ), - } + def process(self, inputs: StepInput) -> StepOutput: # type: ignore + yield [{}] - def test_can_generate(self) -> None: - batch_manager = _BatchManager( - steps={}, - last_batch_received={ - "step_1": _Batch(seq_no=0, step_name="step_1", last_batch=False), - "step_2": _Batch(seq_no=0, step_name="step_2", last_batch=False), - "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False), - }, - last_batch_sent={"step_1": None, "step_2": None, "step_3": None}, - last_batch_flag_sent_to=[], - ) + with 
DummyPipeline(name="unit-test-pipeline") as pipeline: + gen_step = DummyGeneratorStep(name="dummy_generator_step") + step1 = DummyStep1(name="dummy_step_1") + step2 = DummyStep2(name="dummy_step_2") - assert batch_manager.can_generate() + gen_step >> step1 >> step2 - batch_1 = _Batch(seq_no=0, step_name="step_1", last_batch=True) - batch_2 = _Batch(seq_no=0, step_name="step_2", last_batch=True) - batch_3 = _Batch(seq_no=0, step_name="step_3", last_batch=True) + pipeline.run(parameters=parameters) + if expected: + assert expected in caplog.text + else: + assert "Did you mean any of:" not in expected + assert "Available runtime parameters for the step" not in expected - batch_manager = _BatchManager( - steps={}, - last_batch_received={ - "step_1": batch_1, - "step_2": batch_2, - "step_3": batch_3, - }, - last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3}, - last_batch_flag_sent_to=[], - ) + def test_cache_dir_env_variable(self) -> None: + with mock.patch.dict(os.environ, clear=True): + os.environ["DISTILABEL_CACHE_DIR"] = "/tmp/unit-test" + pipeline = DummyPipeline(name="unit-test-pipeline") + assert pipeline._cache_dir == Path("/tmp/unit-test") - assert not batch_manager.can_generate() - - def test_dump(self) -> None: - batch_manager = _BatchManager( - steps={ - "step3": _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={"step1": [], "step2": []}, - seq_no=1, - ) - }, - last_batch_received={ - "step3": _Batch( - seq_no=0, - step_name="step3", - last_batch=False, - ) - }, - last_batch_sent={ - "step3": _Batch( - seq_no=1, - step_name="step3", - last_batch=False, - ) - }, - last_batch_flag_sent_to=["step99"], - ) - assert batch_manager.dump() == { - "steps": { - "step3": { - "step_name": "step3", - "accumulate": False, - "convergence_step": False, - "convergence_step_batches_consumed": {}, - "input_batch_size": 5, - "data": {"step1": [], "step2": []}, - "seq_no": 1, - "last_batch_received": [], - "next_expected_created_from_batch_seq_no": 0, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - }, - "last_batch_received": { - "step3": { - "seq_no": 0, - "step_name": "step3", - "batch_routed_to": [], - "created_from": {}, - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - }, - "last_batch_sent": { - "step3": { - "seq_no": 1, - "step_name": "step3", - "batch_routed_to": [], - "created_from": {}, - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - }, - "last_batch_flag_sent_to": ["step99"], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManager", - }, - } + @pytest.mark.parametrize( + "in_pipeline, names", + ( + ( + True, + [ + "dummy_generator_step_0", + "dummy_step1_0", + "dummy_step2_0", + "dummy_step1_1", + ], + ), + # TODO: Activate this test once we merge the option of not passing a Pipeline + # ( + # False, ["dummy_generator_step", "dummy_step1", "dummy_step2"] + # ) + ), + ) + def test_step_names_inferred(self, in_pipeline: bool, names: List[str]) -> None: + if in_pipeline: + with DummyPipeline(name="unit-test-pipeline"): + gen_step = DummyGeneratorStep() + step1_0 = DummyStep1() + step2 = DummyStep2() + step1_1 = DummyStep1() - def test_from_dict(self) -> None: - batch_manager_step = _BatchManagerStep.from_dict( - { - "step_name": 
"step3", - "accumulate": True, - "convergence_step": False, - "input_batch_size": None, - "data": { - "step1": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": True, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - }, - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - } - ) + gen_step >> step1_0 >> step2 >> step1_1 + else: + gen_step = DummyGeneratorStep() + step1_0 = DummyStep1() + step2 = DummyStep2() + step1_1 = DummyStep1() - with tempfile.TemporaryDirectory() as tmpdirname: - batch_manager_step.save(Path(tmpdirname) / "batch_manager_step3.json") + assert gen_step.name == names[0] + assert step1_0.name == names[1] + assert step2.name == names[2] + assert step1_1.name == names[3] - batch_manager = _BatchManager.from_dict( - { - "steps": { - "step3": str(Path(tmpdirname) / "batch_manager_step3.json") - }, - "last_batch_received": { - "step3": { - "seq_no": 0, - "step_name": "step3", - "last_batch": False, - "data": [], - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - }, - "last_batch_sent": { - "step3": { - "seq_no": 0, - "step_name": "step3", - "last_batch": False, - "data": [], - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_flag_sent_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManager", - }, - } - ) - assert isinstance(batch_manager, _BatchManager) - assert all( - isinstance(step, _BatchManagerStep) - for _, step in batch_manager._steps.items() - ) - assert all( - isinstance(batch, _Batch) - for _, batch in batch_manager._last_batch_received.items() - ) + def test_infer_step_names_big_pipeline(self) -> None: + # Tests that the name of the steps are inferred correctly when the pipeline is big (say 50 steps). 
+ with DummyPipeline(name="unit-test-pipeline") as pipe: + gen_step = DummyGeneratorStep() + for _ in range(50): + gen_step.connect(DummyStep1()) + assert list(pipe.dag.G)[-1] == "dummy_step1_49" class TestPipelineSerialization: def test_base_pipeline_dump(self): - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") dump = pipeline.dump() assert len(dump.keys()) == 2 assert "pipeline" in dump assert "distilabel" in dump assert TYPE_INFO_KEY in dump["pipeline"] - assert dump["pipeline"][TYPE_INFO_KEY]["module"] == "distilabel.pipeline.base" - assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "BasePipeline" + assert ( + dump["pipeline"][TYPE_INFO_KEY]["module"] == "tests.unit.pipeline.test_base" + ) + assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "DummyPipeline" def test_base_pipeline_from_dict(self): - pipeline = BasePipeline(name="unit-test-pipeline") - pipe = BasePipeline.from_dict(pipeline.dump()) - assert isinstance(pipe, BasePipeline) + pipeline = DummyPipeline(name="unit-test-pipeline") + pipe = DummyPipeline.from_dict(pipeline.dump()) + assert isinstance(pipe, DummyPipeline) def test_pipeline_dump(self): from distilabel.pipeline.local import Pipeline @@ -1980,8 +928,8 @@ def test_pipeline_dump(self): @pytest.mark.parametrize( "format, name, loader", [ - ("yaml", "pipe.yaml", BasePipeline.from_yaml), - ("json", "pipe.json", BasePipeline.from_json), + ("yaml", "pipe.yaml", DummyPipeline.from_yaml), + ("json", "pipe.json", DummyPipeline.from_json), ("invalid", "pipe.invalid", None), ], ) @@ -1991,7 +939,7 @@ def test_pipeline_to_from_file_format( name: str, loader: Callable, ) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") with tempfile.TemporaryDirectory() as tmpdirname: filename = Path(tmpdirname) / name @@ -2002,10 +950,10 @@ def test_pipeline_to_from_file_format( pipeline.save(filename, format=format) assert filename.exists() pipe_from_file = loader(filename) - assert isinstance(pipe_from_file, BasePipeline) + assert isinstance(pipe_from_file, DummyPipeline) def test_base_pipeline_signature(self): - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") # Doesn't matter if it's exactly this or not, the test should fail if we change the # way this is created. 
signature = pipeline._create_signature() @@ -2036,62 +984,6 @@ def test_base_pipeline_signature(self): signature = pipeline._create_signature() assert signature == "a11ac46253598e6fe126420b23b9ad31c6422c92" - @pytest.mark.parametrize("use_cache", [True, False]) - def test_run_pipe_and_load_from_cache(self, use_cache: bool): - # Maybe not the best place for this test, but does the work for now - from distilabel.pipeline.base import BasePipeline - from distilabel.pipeline.routing_batch_function import sample_n_steps - - from tests.unit.pipeline.utils import DummyGeneratorStep, DummyStep1, DummyStep2 - - sample_two_steps = sample_n_steps(2) - - with tempfile.TemporaryDirectory() as tmpdirname: - with BasePipeline( - name="unit-test-pipeline", cache_dir=tmpdirname - ) as pipeline: - dummy_generator = DummyGeneratorStep() - dummy_step_1_0 = DummyStep1() - dummy_step_1_1 = DummyStep1() - dummy_step_1_2 = DummyStep1() - dummy_step_2 = DummyStep2() - - ( - dummy_generator - >> sample_two_steps - >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2] - >> dummy_step_2 - ) - - pipeline.run({}, use_cache=use_cache) - - assert not pipeline._cache_location["pipeline"].exists() - # Set the _BatchManager to the pipeline to check it exists afterwards - pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) - pipeline._cache() - - assert pipeline._cache_location["pipeline"].exists() - - with BasePipeline(name="unit-test-pipeline", cache_dir=tmpdirname) as pipe: - dummy_generator = DummyGeneratorStep() - dummy_step_1_0 = DummyStep1() - dummy_step_1_1 = DummyStep1() - dummy_step_1_2 = DummyStep1() - dummy_step_2 = DummyStep2() - - ( - dummy_generator - >> sample_two_steps - >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2] - >> dummy_step_2 - ) - - pipe.run({}, use_cache=use_cache) - if use_cache: - assert pipe._batch_manager - else: - assert not pipe._batch_manager - def test_binary_rshift_operator(self) -> None: # Tests the steps can be connected using the >> operator. 
from distilabel.pipeline.local import Pipeline @@ -2210,126 +1102,3 @@ def test_binary_operators(self) -> None: signature_2 = pipeline_2._create_signature() assert signature_1 == signature_2 - - -class TestWriteBuffer: - def test_create(self) -> None: - with tempfile.TemporaryDirectory() as tmpdirname: - folder = Path(tmpdirname) / "data" - with Pipeline(name="unit-test-pipeline") as pipeline: - dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") - dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") - dummy_step_1 = DummyStep1(name="dummy_step_1") - dummy_step_2 = DummyStep2(name="dummy_step_2") - dummy_step_3 = DummyStep2(name="dummy_step_3") - - dummy_generator_1.connect(dummy_step_1) - dummy_generator_2.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_3) - - write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) - - assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []} - assert write_buffer._buffers_dump_batch_size == { - "dummy_step_2": 50, - "dummy_step_3": 50, - } - assert write_buffer._buffer_last_schema == {} - assert write_buffer._buffers_last_file == { - "dummy_step_2": 1, - "dummy_step_3": 1, - } - - def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None: - with tempfile.TemporaryDirectory() as tmpdirname: - folder = Path(tmpdirname) / "data" - with Pipeline(name="unit-test-pipeline") as pipeline: - dummy_generator = DummyGeneratorStep(name="dummy_generator_step") - dummy_step_1 = DummyStep1(name="dummy_step_1") - dummy_step_2 = DummyStep2(name="dummy_step_2") - - dummy_generator.connect(dummy_step_1) - dummy_step_1.connect(dummy_step_2) - - write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) - - # Add one batch with 5 rows, shouldn't write anything 5 < 50 - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - # Add 45 more rows, should write now - for _ in range(9): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_2", "00001.parquet").exists() - - # Add 50 more rows, we should have a new file - for _ in range(10): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_2", "00002.parquet").exists() - - # Add more rows and close the write buffer, we should have a new file - for _ in range(5): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - write_buffer.close() - - assert Path(folder, "dummy_step_2", "00003.parquet").exists() - - ds = create_distiset(write_buffer._path) - assert isinstance(ds, Distiset) - assert len(ds.keys()) == 1 - assert len(ds["default"]["train"]) == 125 - - def test_write_buffer_multiple_leaf_steps_and_create_dataset(self): - with tempfile.TemporaryDirectory() as tmpdirname: - folder = Path(tmpdirname) / "data" - with Pipeline(name="unit-test-pipeline") as pipeline: - dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") - dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") - dummy_step_1 = DummyStep1(name="dummy_step_1") - dummy_step_2 = DummyStep2(name="dummy_step_2") - dummy_step_3 = DummyStep2(name="dummy_step_3") - - dummy_generator_1.connect(dummy_step_1) - dummy_generator_2.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_3) - - write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) - - for _ in range(10): - batch = 
batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_2", "00001.parquet").exists() - - for _ in range(10): - batch = batch_gen(dummy_step_3.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_3", "00001.parquet").exists() - - for _ in range(5): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - for _ in range(5): - batch = batch_gen(dummy_step_3.name) - write_buffer.add_batch(batch) - - write_buffer.close() - - assert Path(folder, "dummy_step_2", "00002.parquet").exists() - assert Path(folder, "dummy_step_3", "00002.parquet").exists() - - ds = create_distiset(write_buffer._path) - assert isinstance(ds, Distiset) - assert len(ds.keys()) == 2 - assert len(ds["dummy_step_2"]["train"]) == 75 - assert len(ds["dummy_step_3"]["train"]) == 75 diff --git a/tests/unit/pipeline/test_batch.py b/tests/unit/pipeline/test_batch.py new file mode 100644 index 0000000000..ed246e491f --- /dev/null +++ b/tests/unit/pipeline/test_batch.py @@ -0,0 +1,172 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.pipeline.batch import _Batch + + +class TestBatch: + def test_get_data(self) -> None: + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[ + [ + {"a": 0}, + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + ) + + batch.set_data( + [ + [ + {"a": 0}, + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ] + ) + + old_hash = batch.data_hash + + data = batch.get_data(5) + assert data == [{"a": 0}, {"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}] + assert batch.data == [[{"a": 5}, {"a": 6}]] + assert batch.data_hash != old_hash + + def test_set_data(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False) + data = [[{"i": i} for i in range(5000)]] + batch.set_data(data) + + assert batch.data == data + assert batch.size == 5000 + + def test_next_batch(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False) + next_batch = batch.next_batch() + + assert next_batch == _Batch(seq_no=1, step_name="step1", last_batch=False) + + def test_accumulate(self) -> None: + batches = [ + [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + ), + _Batch( + seq_no=1, + step_name="step1", + last_batch=True, + data=[[{"a": 4}, {"a": 5}, {"a": 6}]], + ), + ], + [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}]], + ), + _Batch( + seq_no=1, + step_name="step2", + last_batch=True, + data=[[{"b": 4}, {"b": 5}, {"b": 6}]], + ), + ], + ] + + batch = _Batch.accumulate("step3", batches) + + assert batch.seq_no == 0 + assert batch.step_name == "step3" + assert batch.last_batch is True + assert batch.data == [ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}], + ] + + def test_dump(self) -> None: + batch = 
_Batch(seq_no=0, step_name="step1", last_batch=False) + assert batch.dump() == { + "seq_no": 0, + "size": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "data_hash": None, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": {"module": "distilabel.pipeline.batch", "name": "_Batch"}, + } + + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + data_hash="hash", + accumulated=False, + created_from={"step0": [(0, 5), (1, 5)]}, + batch_routed_to=["step2", "step3"], + ) + assert batch.dump() == { + "seq_no": 0, + "size": 0, + "step_name": "step1", + "last_batch": False, + "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], + "data_hash": "hash", + "accumulated": False, + "created_from": {"step0": [(0, 5), (1, 5)]}, + "batch_routed_to": ["step2", "step3"], + "type_info": {"module": "distilabel.pipeline.batch", "name": "_Batch"}, + } + + def test_from_dict(self) -> None: + batch = _Batch.from_dict( + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], + "accumulated": False, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ) + + assert isinstance(batch, _Batch) + assert batch.seq_no == 0 + assert batch.step_name == "step1" + assert batch.last_batch is False + assert batch.data == [[{"a": 1}, {"a": 2}, {"a": 3}]] + assert batch.accumulated is False diff --git a/tests/unit/pipeline/test_batch_manager.py b/tests/unit/pipeline/test_batch_manager.py new file mode 100644 index 0000000000..7b1cb1a8a6 --- /dev/null +++ b/tests/unit/pipeline/test_batch_manager.py @@ -0,0 +1,2214 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +from pathlib import Path +from typing import Dict, List + +import pytest +from distilabel.pipeline._dag import DAG +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager, _BatchManagerStep +from distilabel.steps.base import GeneratorStep, GlobalStep, Step + + +class TestBatchManagerStep: + def test_add_batch(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} + ) + + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + ) + + batch_manager_step.add_batch(batch) + + assert batch_manager_step.data["step1"] == [batch] + assert batch_manager_step.last_batch_received == [] + + def test_add_batch_with_prepend(self) -> None: + batch_1 = _Batch( + seq_no=1, + step_name="step1", + last_batch=False, + data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], + ) + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=10, + data={"step1": [batch_1]}, + ) + + batch_0 = _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager_step.add_batch(batch_0, prepend=True) + + assert batch_manager_step.built_batches == [batch_0] + assert batch_manager_step.data["step1"] == [batch_1] + assert batch_manager_step.last_batch_received == [] + + def test_add_batch_last_batch(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} + ) + + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + ) + + batch_manager_step.add_batch(batch) + + assert batch_manager_step.data["step1"] == [batch] + assert batch_manager_step.last_batch_received == ["step1"] + + def test_get_batch(self) -> None: + previously_built_batch = _Batch( + seq_no=0, + step_name="step3", + last_batch=False, + data=[ + [ + {"a": -1}, + {"a": 0}, + ], + [ + {"b": -1}, + {"b": 0}, + ], + ], + ) + + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=2, + seq_no=1, + data={ + "step1": [ + _Batch( + seq_no=1, + step_name="step1", + last_batch=False, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + size=5, + ) + ], + "step2": [ + _Batch( + seq_no=1, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + ] + ], + size=5, + ) + ], + }, + built_batches=[previously_built_batch], + ) + + batch = batch_manager_step.get_batch() + + assert batch == previously_built_batch + + batch = batch_manager_step.get_batch() + + assert batch == _Batch( + step_name="step3", + seq_no=1, + last_batch=False, + data=[ + [ + {"a": 1}, + {"a": 2}, + ], + [ + {"b": 1}, + {"b": 2}, + ], + ], + created_from={"step1": [(1, 5)], "step2": [(1, 5)]}, + ) + + batch = batch_manager_step.get_batch() + + assert batch == _Batch( + step_name="step3", + seq_no=2, + last_batch=False, + data=[ + [ + {"a": 3}, + {"a": 4}, + ], + [ + {"b": 3}, + {"b": 4}, + ], + ], + created_from={"step1": [(1, 5)], "step2": [(1, 5)]}, + ) + + def test_get_batches_accumulate(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data={ + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[ + [ + {"a": 1}, + {"a": 2}, + 
{"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + size=5, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + ] + ], + size=6, + ) + ], + }, + last_batch_received=["step1", "step2"], + ) + + batch = batch_manager_step.get_batch() + + assert batch == _Batch( + step_name="step3", + seq_no=0, + last_batch=True, + accumulated=True, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ], + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + ], + ], + created_from={"step1": [(0, 5)], "step2": [(0, 6)]}, + ) + + def test_get_batches_not_enough_data(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=2, + data={ + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[ + [ + {"a": 1}, + ] + ], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + ] + ], + ) + ], + }, + ) + + assert batch_manager_step.get_batch() is None + + def test_from_step(self, dummy_step_1: "Step") -> None: + batch_manager_step = _BatchManagerStep.from_step( + step=dummy_step_1, predecessors=["step1", "step2"] + ) + + assert batch_manager_step.step_name == "dummy_step_1" + assert batch_manager_step.accumulate is False + assert batch_manager_step.input_batch_size == 50 + assert batch_manager_step.data == {"step1": [], "step2": []} + assert batch_manager_step.seq_no == 0 + assert batch_manager_step.last_batch_received == [] + + def test_from_step_with_global_step(self, dummy_global_step: "GlobalStep") -> None: + batch_manager_step = _BatchManagerStep.from_step( + step=dummy_global_step, predecessors=["step1", "step2"] + ) + + assert batch_manager_step.step_name == "dummy_global_step" + assert batch_manager_step.accumulate is True + assert batch_manager_step.input_batch_size == 50 + assert batch_manager_step.data == {"step1": [], "step2": []} + assert batch_manager_step.seq_no == 0 + assert batch_manager_step.last_batch_received == [] + + def test_get_seq_no(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", accumulate=False, input_batch_size=5, data={"step1": []} + ) + + seq_no = batch_manager_step._get_seq_no() + + assert seq_no == 0 + assert batch_manager_step.seq_no == 1 + + def test_get_data(self) -> None: + batch_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], + size=6, + batch_routed_to=["step1", "step2"], + ) + batch_step_2 = _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + size=7, + batch_routed_to=["step1", "step2"], + ) + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={ + "step1": [batch_step_1], + "step2": [batch_step_2], + }, + ) + + data, created_from, routed_to = batch_manager_step._get_data() + assert data == [ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}], + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}], + ] + assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} + assert routed_to == ["step1", "step2"] + + assert batch_manager_step.data == { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 6}]], + data_hash=batch_step_1.data_hash, + 
size=6, + batch_routed_to=["step1", "step2"], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 6}, {"b": 7}]], + data_hash=batch_step_2.data_hash, + size=7, + batch_routed_to=["step1", "step2"], + ) + ], + } + + def test_get_data_accumulate(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data={ + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] + ], + size=6, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + size=7, + ) + ], + }, + ) + + data, created_from, routed_to = batch_manager_step._get_data() + + assert data == [ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}], + ] + assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} + assert routed_to == [] + + assert batch_manager_step.data == {"step1": [], "step2": []} + + def test_get_data_convergence_step(self) -> None: + batch_a_0 = _Batch( + seq_no=0, + step_name="A", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + ] + ], + size=3, + created_from={"Z": [(0, 3)]}, + ) + + batch_a_1 = _Batch( + seq_no=1, + step_name="A", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + ] + ], + size=3, + created_from={"Z": [(1, 3)]}, + ) + + batch_b_0 = _Batch( + seq_no=0, + step_name="B", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + ] + ], + size=3, + created_from={"Z": [(0, 3)]}, + ) + + batch_c_0 = _Batch( + seq_no=0, + step_name="C", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + ] + ], + size=3, + created_from={"Z": [(1, 3)]}, + ) + + batch_manager_step = _BatchManagerStep( + step_name="D", + input_batch_size=3, + convergence_step=True, + accumulate=False, + data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]}, + ) + + data, created_from, routed_to = batch_manager_step._get_data() + + assert data == [ + [ + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + ], + [ + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + ], + ] + assert created_from == {"A": [(0, 3)], "B": [(0, 3)]} + assert routed_to == [] + assert batch_manager_step.next_expected_created_from_batch_seq_no == 1 + + data, created_from, routed_to = batch_manager_step._get_data() + + assert data == [ + [ + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + ], + [ + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + ], + ] + assert created_from == {"A": [(1, 3)], "C": [(0, 3)]} + assert routed_to == [] + assert batch_manager_step.next_expected_created_from_batch_seq_no == 2 + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 
1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ] + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + ) + ] + }, + ["step1"], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ] + }, + ["step1"], + True, + ), + ], + ) + def test_last_batch( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=5, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._last_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1"], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ], + ) + def test_last_batch_accumulate( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._last_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 
3}]], + created_from={"step0": [(0, 3)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}]], + created_from={"step0": [(0, 3)]}, + ) + ], + }, + [], + True, + ), + ], + ) + def test_last_batch_convergence_step( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + data=data, + last_batch_received=last_batch_received, + input_batch_size=3, + convergence_step=True, + ) + + assert batch_manager_step._last_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ], + ) + def test_ready_to_create_batch( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=5, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._ready_to_create_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1"], + False, + ), + ], + ) + def test_ready_to_create_batch_accumulate( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._ready_to_create_batch() is expected + + def test_dump(self) -> None: + batch_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], + 
data_hash="hash0", + size=6, + ) + batch_step_2 = _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[ + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}] + ], + data_hash="hash1", + size=7, + ) + batch_step_3 = _Batch( + seq_no=0, + step_name="step3", + last_batch=True, + data=[[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]], + data_hash="hash2", + size=5, + ) + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data={ + "step1": [batch_step_1], + "step2": [batch_step_2], + }, + built_batches=[batch_step_3], + ) + assert batch_manager_step.dump() == { + "step_name": "step3", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {}, + "input_batch_size": None, + "data": { + "step1": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": True, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + "data_hash": "hash0", + "size": 6, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "hash1", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step3", + "last_batch": True, + "data": [[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]], + "data_hash": "hash2", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "next_expected_created_from_batch_seq_no": 0, + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + } + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], 
+ created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + }, + ["step1", "step2"], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ], + ) + def test_ready_to_create_batch_convergence_step( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data=data, + last_batch_received=last_batch_received, + convergence_step=True, + ) + + assert batch_manager_step._ready_to_create_batch() is expected + + def test_from_dict(self) -> None: + batch_manager_step = _BatchManagerStep.from_dict( + { + "step_name": "step3", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step1": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": True, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + "size": 6, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + } + ], + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + } + ], + }, + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + } + ) + + assert isinstance(batch_manager_step, _BatchManagerStep) + assert batch_manager_step.step_name == "step3" + assert batch_manager_step.accumulate is True + assert batch_manager_step.convergence_step is False + assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}} + assert batch_manager_step.input_batch_size is None + assert batch_manager_step.seq_no == 0 + assert batch_manager_step.last_batch_received == [] + + +class TestBatchManager: + def test_add_batch(self) -> None: + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={"step1": [], "step2": []}, + ) + }, + last_batch_received={"step3": None}, + last_batch_sent={"step3": None}, + last_batch_flag_sent_to=[], + ) + + batch_from_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager.add_batch(to_step="step3", batch=batch_from_step_1) + + assert batch_manager._steps["step3"].data == { + "step1": [batch_from_step_1], + "step2": [], + } + + def test_add_batch_with_prepend(self) -> None: + batch_1 = _Batch( + seq_no=1, + step_name="step1", + last_batch=False, + data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], + ) + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + 
accumulate=False, + input_batch_size=5, + data={ + "step1": [batch_1], + "step2": [], + }, + ) + }, + last_batch_received={"step3": None}, + last_batch_sent={"step3": None}, + last_batch_flag_sent_to=[], + ) + batch_0 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager.add_batch(to_step="step3", batch=batch_0, prepend=True) + assert batch_manager._steps["step3"].built_batches == [batch_0] + assert batch_manager._steps["step3"].data == { + "step1": [batch_1], + "step2": [], + } + + def test_from_dag( + self, + dummy_generator_step: "GeneratorStep", + dummy_step_1: "Step", + dummy_step_2: "Step", + dummy_global_step: "GlobalStep", + ) -> None: + dag = DAG() + dag.add_step(dummy_generator_step) + dag.add_step(dummy_step_1) + dag.add_step(dummy_step_2) + dag.add_step(dummy_global_step) + dag.add_edge("dummy_generator_step", "dummy_step_1") + dag.add_edge("dummy_generator_step", "dummy_global_step") + dag.add_edge("dummy_step_1", "dummy_step_2") + + batch_manager = _BatchManager.from_dag(dag) + + assert batch_manager._steps == { + "dummy_step_1": _BatchManagerStep( + step_name="dummy_step_1", + accumulate=False, + input_batch_size=50, + data={"dummy_generator_step": []}, + ), + "dummy_global_step": _BatchManagerStep( + step_name="dummy_global_step", + accumulate=True, + input_batch_size=50, + data={"dummy_generator_step": []}, + ), + "dummy_step_2": _BatchManagerStep( + step_name="dummy_step_2", + accumulate=False, + input_batch_size=50, + data={"dummy_step_1": []}, + ), + } + + def test_can_generate(self) -> None: + batch_manager = _BatchManager( + steps={}, + last_batch_received={ + "step_1": _Batch(seq_no=0, step_name="step_1", last_batch=False), + "step_2": _Batch(seq_no=0, step_name="step_2", last_batch=False), + "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False), + }, + last_batch_sent={"step_1": None, "step_2": None, "step_3": None}, + last_batch_flag_sent_to=[], + ) + + assert batch_manager.can_generate() + + batch_1 = _Batch(seq_no=0, step_name="step_1", last_batch=True) + batch_2 = _Batch(seq_no=0, step_name="step_2", last_batch=True) + batch_3 = _Batch(seq_no=0, step_name="step_3", last_batch=True) + + batch_manager = _BatchManager( + steps={}, + last_batch_received={ + "step_1": batch_1, + "step_2": batch_2, + "step_3": batch_3, + }, + last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3}, + last_batch_flag_sent_to=[], + ) + + assert not batch_manager.can_generate() + + def test_dump(self) -> None: + built_batch = _Batch( + seq_no=0, + last_batch=False, + step_name="step3", + data=[[]], + data_hash="hash", + ) + + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={"step1": [], "step2": []}, + built_batches=[built_batch], + seq_no=1, + ) + }, + last_batch_received={ + "step3": _Batch( + seq_no=0, + step_name="step3", + last_batch=False, + ) + }, + last_batch_sent={ + "step3": _Batch( + seq_no=1, + step_name="step3", + last_batch=False, + ) + }, + last_batch_flag_sent_to=["step99"], + ) + assert batch_manager.dump() == { + "steps": { + "step3": { + "step_name": "step3", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {}, + "input_batch_size": 5, + "data": {"step1": [], "step2": []}, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step3", + "last_batch": False, + "data": [[]], + "data_hash": "hash", + "size": 0, + 
"accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 1, + "last_batch_received": [], + "next_expected_created_from_batch_seq_no": 0, + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step3": { + "seq_no": 0, + "step_name": "step3", + "batch_routed_to": [], + "created_from": {}, + "last_batch": False, + "data": [], + "data_hash": None, + "size": 0, + "accumulated": False, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + }, + "last_batch_sent": { + "step3": { + "seq_no": 1, + "step_name": "step3", + "batch_routed_to": [], + "created_from": {}, + "last_batch": False, + "data": [], + "data_hash": None, + "size": 0, + "accumulated": False, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + }, + "last_batch_flag_sent_to": ["step99"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + + def test_from_dict(self) -> None: + batch_manager = _BatchManager.from_dict( + { + "steps": { + "step1": { + "step_name": "step1", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + "step2": { + "step_name": "step2", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, + "input_batch_size": 50, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_sent": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + 
"step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_flag_sent_to": ["step3"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + ) + + assert isinstance(batch_manager, _BatchManager) + + assert len(batch_manager._steps) == 2 + for step in batch_manager._steps.values(): + assert isinstance(step, _BatchManagerStep) + + assert len(batch_manager._last_batch_received) == 2 + for step in batch_manager._last_batch_received.values(): + assert isinstance(step, _Batch) + + assert len(batch_manager._last_batch_sent) == 2 + for step in batch_manager._last_batch_sent.values(): + assert isinstance(step, _Batch) + + assert batch_manager._last_batch_flag_sent_to == ["step3"] + + def test_cache(self) -> None: + batch_manager = _BatchManager.from_dict( + { + "steps": { + "step1": { + "step_name": "step1", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + "step2": { + "step_name": "step2", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": 50, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + 
"seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_sent": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_flag_sent_to": ["step3"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + batch_manager_path = Path(tmp_dir) / "batch_manager.json" + batch_manager.cache(batch_manager_path) + + assert batch_manager_path.exists() and batch_manager_path.is_file() + + for step_name, step in batch_manager._steps.items(): + batch_manager_step_dir = ( + Path(tmp_dir) / "batch_manager_steps" / step_name + ) + assert ( + batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir() + ) + + batch_manager_step_path = ( + batch_manager_step_dir / "batch_manager_step.json" + ) + assert ( + batch_manager_step_path.exists() + and batch_manager_step_path.is_file() + ) + + built_batches_dir = batch_manager_step_dir / "built_batches" + assert built_batches_dir.exists() + + for batch in step.built_batches: + batch_path = ( + built_batches_dir + / f"batch_{batch.seq_no}_{batch.data_hash}.json" + ) + assert batch_path.exists() and batch_path.is_file() + + for buffered_step_name in step.data: + buffered_step_dir = batch_manager_step_dir / buffered_step_name + assert buffered_step_dir.exists() and buffered_step_dir.is_dir() + + for batch in step.data[buffered_step_name]: + batch_path = ( + buffered_step_dir + / f"batch_{batch.seq_no}_{batch.data_hash}.json" + ) + assert batch_path.exists() and batch_path.is_file() + + def test_load_from_cache(self) -> None: + batch_manager = _BatchManager.from_dict( + { + "steps": { + "step1": { + "step_name": "step1", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + "step2": { + "step_name": "step2", + "accumulate": False, + "convergence_step": False, + 
"convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": 50, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + }, + "last_batch_sent": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + }, + "last_batch_flag_sent_to": ["step3"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + batch_manager_path = Path(tmp_dir) / "batch_manager.json" + batch_manager.cache(batch_manager_path) + loaded_batch_manager = _BatchManager.load_from_cache(batch_manager_path) + + assert batch_manager.dump() == loaded_batch_manager.dump() diff --git a/tests/unit/pipeline/test_local.py b/tests/unit/pipeline/test_local.py index 3c4a15b534..4797f8e66d 100644 --- a/tests/unit/pipeline/test_local.py +++ b/tests/unit/pipeline/test_local.py @@ -15,7 +15,8 @@ from typing import TYPE_CHECKING from unittest import mock -from distilabel.pipeline.base import _Batch, _BatchManager +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager from distilabel.pipeline.local import Pipeline from .utils import DummyGeneratorStep, DummyStep1, DummyStep2 @@ -58,17 +59,11 @@ def test_send_batch_to_step(self, dummy_generator_step: "GeneratorStep") -> None ) pipeline._send_batch_to_step(batch=batch) # type: ignore - batch_manager_mock.set_last_batch_sent.assert_called_once_with(batch) - get_step_mock.assert_called_once_with(dummy_generator_step.name) + get_step_mock.assert_has_calls([mock.call(dummy_generator_step.name)]) input_queue.put.assert_called_once_with(batch) 
@mock.patch("distilabel.pipeline.local._ProcessWrapper") def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: - pool = mock.MagicMock() - manager = mock.MagicMock() - queue = mock.MagicMock() - shared_info = mock.MagicMock() - with Pipeline(name="unit-test-pipeline") as pipeline: dummy_generator = DummyGeneratorStep(name="dummy_generator_step") dummy_step_1 = DummyStep1(name="dummy_step_1") @@ -77,51 +72,52 @@ def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: dummy_generator.connect(dummy_step_1) dummy_step_1.connect(dummy_step_2) - pipeline._run_steps_in_loop(pool, manager, queue, shared_info) + pipeline._pool = mock.MagicMock() + pipeline._manager = mock.MagicMock() + pipeline._output_queue = mock.MagicMock() + pipeline._load_queue = mock.MagicMock() + pipeline._run_steps() - assert manager.Queue.call_count == 3 + assert pipeline._manager.Queue.call_count == 3 process_wrapper_mock.assert_has_calls( [ mock.call( step=dummy_generator, input_queue=mock.ANY, - output_queue=queue, - shared_info=shared_info, + output_queue=pipeline._output_queue, + load_queue=pipeline._load_queue, dry_run=False, ), mock.call( step=dummy_step_1, input_queue=mock.ANY, - output_queue=queue, - shared_info=shared_info, + output_queue=pipeline._output_queue, + load_queue=pipeline._load_queue, dry_run=False, ), mock.call( step=dummy_step_2, input_queue=mock.ANY, - output_queue=queue, - shared_info=shared_info, + output_queue=pipeline._output_queue, + load_queue=pipeline._load_queue, dry_run=False, ), ], ) - pool.apply_async.assert_has_calls( + pipeline._pool.apply_async.assert_has_calls( [ mock.call( process_wrapper_mock.return_value.run, - callback=pipeline._finished_callback, error_callback=pipeline._error_callback, ), mock.call( process_wrapper_mock.return_value.run, - callback=pipeline._finished_callback, error_callback=pipeline._error_callback, ), mock.call( process_wrapper_mock.return_value.run, - callback=pipeline._finished_callback, error_callback=pipeline._error_callback, ), ] diff --git a/tests/unit/pipeline/test_routing_batch_function.py b/tests/unit/pipeline/test_routing_batch_function.py index 5e3f208c5b..6cc3090eb7 100644 --- a/tests/unit/pipeline/test_routing_batch_function.py +++ b/tests/unit/pipeline/test_routing_batch_function.py @@ -14,7 +14,7 @@ from typing import List -from distilabel.pipeline.base import _Batch +from distilabel.pipeline.batch import _Batch from distilabel.pipeline.local import Pipeline from distilabel.pipeline.routing_batch_function import ( RoutingBatchFunction, diff --git a/tests/unit/pipeline/test_write_buffer.py b/tests/unit/pipeline/test_write_buffer.py new file mode 100644 index 0000000000..a7ae64c91e --- /dev/null +++ b/tests/unit/pipeline/test_write_buffer.py @@ -0,0 +1,150 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tempfile +from pathlib import Path + +from distilabel.distiset import Distiset, create_distiset +from distilabel.pipeline.local import Pipeline +from distilabel.pipeline.write_buffer import _WriteBuffer + +from tests.unit.pipeline.utils import ( + DummyGeneratorStep, + DummyStep1, + DummyStep2, + batch_gen, +) + + +class TestWriteBuffer: + def test_create(self) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "data" + with Pipeline(name="unit-test-pipeline") as pipeline: + dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") + dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") + dummy_step_1 = DummyStep1(name="dummy_step_1") + dummy_step_2 = DummyStep2(name="dummy_step_2") + dummy_step_3 = DummyStep2(name="dummy_step_3") + + dummy_generator_1.connect(dummy_step_1) + dummy_generator_2.connect(dummy_step_2) + dummy_step_1.connect(dummy_step_2) + dummy_step_1.connect(dummy_step_3) + + write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + + assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []} + assert write_buffer._buffers_dump_batch_size == { + "dummy_step_2": 50, + "dummy_step_3": 50, + } + assert write_buffer._buffer_last_schema == {} + assert write_buffer._buffers_last_file == { + "dummy_step_2": 1, + "dummy_step_3": 1, + } + + def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "data" + with Pipeline(name="unit-test-pipeline") as pipeline: + dummy_generator = DummyGeneratorStep(name="dummy_generator_step") + dummy_step_1 = DummyStep1(name="dummy_step_1") + dummy_step_2 = DummyStep2(name="dummy_step_2") + + dummy_generator.connect(dummy_step_1) + dummy_step_1.connect(dummy_step_2) + + write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + + # Add one batch with 5 rows, shouldn't write anything 5 < 50 + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + # Add 45 more rows, should write now + for _ in range(9): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_2", "00001.parquet").exists() + + # Add 50 more rows, we should have a new file + for _ in range(10): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_2", "00002.parquet").exists() + + # Add more rows and close the write buffer, we should have a new file + for _ in range(5): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + write_buffer.close() + + assert Path(folder, "dummy_step_2", "00003.parquet").exists() + + ds = create_distiset(write_buffer._path) + assert isinstance(ds, Distiset) + assert len(ds.keys()) == 1 + assert len(ds["default"]["train"]) == 125 + + def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "data" + with Pipeline(name="unit-test-pipeline") as pipeline: + dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") + dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") + dummy_step_1 = DummyStep1(name="dummy_step_1") + dummy_step_2 = DummyStep2(name="dummy_step_2") + dummy_step_3 = DummyStep2(name="dummy_step_3") + + dummy_generator_1.connect(dummy_step_1) + dummy_generator_2.connect(dummy_step_2) + 
dummy_step_1.connect(dummy_step_2) + dummy_step_1.connect(dummy_step_3) + + write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + + for _ in range(10): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_2", "00001.parquet").exists() + + for _ in range(10): + batch = batch_gen(dummy_step_3.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_3", "00001.parquet").exists() + + for _ in range(5): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + for _ in range(5): + batch = batch_gen(dummy_step_3.name) # type: ignore + write_buffer.add_batch(batch) + + write_buffer.close() + + assert Path(folder, "dummy_step_2", "00002.parquet").exists() + assert Path(folder, "dummy_step_3", "00002.parquet").exists() + + ds = create_distiset(write_buffer._path) + assert isinstance(ds, Distiset) + assert len(ds.keys()) == 2 + assert len(ds["dummy_step_2"]["train"]) == 75 + assert len(ds["dummy_step_3"]["train"]) == 75 diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py index 8d02340114..7f771271d0 100644 --- a/tests/unit/pipeline/utils.py +++ b/tests/unit/pipeline/utils.py @@ -14,7 +14,7 @@ from typing import List -from distilabel.pipeline.base import _Batch +from distilabel.pipeline.batch import _Batch from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput from distilabel.steps.typing import GeneratorStepOutput, StepOutput diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py index c816c8bcac..dbb8773923 100644 --- a/tests/unit/steps/argilla/test_base.py +++ b/tests/unit/steps/argilla/test_base.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os +import sys from typing import TYPE_CHECKING, List import pytest @@ -83,7 +84,9 @@ def test_with_errors(self, caplog) -> None: with pytest.raises( TypeError, - match="Can't instantiate abstract class Argilla with abstract methods inputs, process", + match="Can't instantiate abstract class Argilla with abstract methods inputs, process" + if sys.version_info < (3, 12) + else "Can't instantiate abstract class Argilla without an implementation for abstract methods 'inputs', 'process'", ): Argilla(name="step", pipeline=Pipeline(name="unit-test-pipeline")) # type: ignore diff --git a/tests/unit/steps/generators/sample_functions.jsonl b/tests/unit/steps/generators/sample_functions.jsonl new file mode 100644 index 0000000000..700d21ad5b --- /dev/null +++ b/tests/unit/steps/generators/sample_functions.jsonl @@ -0,0 +1,11 @@ +{"type": "function", "function": {"name": "code_interpreter", "description": "Execute the provided Python code string on the terminal using exec.\n\n The string should contain valid, executable and pure Python code in markdown syntax.\n Code should also import any required Python packages.\n\n Args:\n code_markdown (str): The Python code with markdown syntax to be executed.\n For example: ```python\n\n```\n\n Returns:\n dict | str: A dictionary containing variables declared and values returned by function calls,\n or an error message if an exception occurred.\n\n Note:\n Use this function with caution, as executing arbitrary code can pose security risks.", "parameters": {"type": "object", "properties": {"code_markdown": {"type": "string"}}, "required": ["code_markdown"]}}} +{"type": "function", "function": {"name": "google_search_and_scrape", "description": "Performs a Google search for the given query, retrieves the top search result URLs,\nand scrapes the text content and table data from those pages in parallel.\n\nArgs:\n query (str): The search query.\nReturns:\n list: A list of dictionaries containing the URL, text content, and table data for each scraped page.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "get_current_stock_price", "description": "Get the current stock price for a given symbol.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_company_news", "description": "Get company news and press releases for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press releases.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_company_profile", "description": "Get company profile and overview for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing company profile and overview.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_stock_fundamentals", "description": "Get fundamental data for a given stock symbol using yfinance API.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n dict: A dictionary containing fundamental data.\n Keys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the 
company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_financial_statements", "description": "Get financial statements for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_key_financial_ratios", "description": "Get key financial ratios for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing key financial ratios.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_analyst_recommendations", "description": "Get analyst recommendations for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing analyst recommendations.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_dividend_data", "description": "Get dividend data for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing dividend data.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} +{"type": "function", "function": {"name": "get_technical_indicators", "description": "Get technical indicators for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing technical indicators.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} diff --git a/tests/unit/steps/generators/test_data.py b/tests/unit/steps/generators/test_data.py index 9817837e20..c35b9db86d 100644 --- a/tests/unit/steps/generators/test_data.py +++ b/tests/unit/steps/generators/test_data.py @@ -17,7 +17,7 @@ from pydantic import ValidationError -class TestLoadDataFromDictsTask: +class TestLoadDataFromDicts: data = [{"instruction": "test"}] * 10 def test_init(self) -> None: diff --git a/tests/unit/steps/generators/test_huggingface.py b/tests/unit/steps/generators/test_huggingface.py index 34b44f4fc5..e72a70acb2 100644 --- a/tests/unit/steps/generators/test_huggingface.py +++ b/tests/unit/steps/generators/test_huggingface.py @@ -13,19 +13,27 @@ # limitations under the License. 
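# The tests below cover the loader steps renamed in this PR (`LoadDataFromHub`,
# `LoadDataFromFileSystem`, `LoadDataFromDisk`) plus the `LoadHubDataset`
# deprecation shim, using the `sample_functions.jsonl` fixture above (11 rows).
# The file-system loader presumably delegates to `datasets.load_dataset`, so a
# rough standalone equivalent of what the JSONL tests expect is:

from datasets import load_dataset


def count_jsonl_rows(data_files: str) -> int:
    # The "json" builder also reads JSON Lines; `data_files` can be a single
    # file or a glob such as "some_folder/*.jsonl" for the folder-based tests.
    dataset = load_dataset("json", data_files=data_files, split="train")
    return len(dataset)


# count_jsonl_rows("tests/unit/steps/generators/sample_functions.jsonl")  # -> 11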
import os +import tempfile +from pathlib import Path from typing import Generator, Union import pytest from datasets import Dataset, IterableDataset +from distilabel.distiset import Distiset from distilabel.pipeline import Pipeline -from distilabel.steps.generators.huggingface import LoadHubDataset +from distilabel.steps.generators.huggingface import ( + LoadDataFromDisk, + LoadDataFromFileSystem, + LoadDataFromHub, + LoadHubDataset, +) DISTILABEL_RUN_SLOW_TESTS = os.getenv("DISTILABEL_RUN_SLOW_TESTS", False) @pytest.fixture(scope="module") def dataset_loader() -> Generator[Union[Dataset, IterableDataset], None, None]: - load_hub_dataset = LoadHubDataset( + load_hub_dataset = LoadDataFromHub( name="load_dataset", repo_id="distilabel-internal-testing/instruction-dataset-mini", split="test", @@ -39,12 +47,12 @@ def dataset_loader() -> Generator[Union[Dataset, IterableDataset], None, None]: not DISTILABEL_RUN_SLOW_TESTS, reason="These tests depend on internet connection, are slow and depend mainly on HF API, we don't need to test them often.", ) -class TestLoadHubDataset: +class TestLoadDataFromHub: @pytest.mark.parametrize( "streaming, ds_type", [(True, IterableDataset), (False, Dataset)] ) def test_runtime_parameters(self, streaming: bool, ds_type) -> None: - load_hub_dataset = LoadHubDataset( + load_hub_dataset = LoadDataFromHub( name="load_dataset", repo_id="distilabel-internal-testing/instruction-dataset-mini", split="test", @@ -60,6 +68,131 @@ def test_runtime_parameters(self, streaming: bool, ds_type) -> None: assert isinstance(generator_step_output[1], bool) assert len(generator_step_output[0]) == 2 - def test_dataset_outputs(self, dataset_loader: LoadHubDataset) -> None: + def test_dataset_outputs(self, dataset_loader: LoadDataFromHub) -> None: # TODO: This test can be run with/without internet connection, we should emulate it here with a mock. 
assert dataset_loader.outputs == ["prompt", "completion", "meta"] + + +class TestLoadDataFromFileSystem: + @pytest.mark.parametrize("filetype", ["json", None]) + @pytest.mark.parametrize("streaming", [True, False]) + def test_read_from_jsonl(self, streaming: bool, filetype: Union[str, None]) -> None: + loader = LoadDataFromFileSystem( + filetype=filetype, + data_files=str(Path(__file__).parent / "sample_functions.jsonl"), + streaming=streaming, + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 11 + + @pytest.mark.parametrize("filetype", ["json", None]) + def test_read_from_jsonl_with_folder(self, filetype: Union[str, None]) -> None: + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + filename = "sample_functions.jsonl" + sample_file = Path(__file__).parent / filename + for i in range(3): + Path(tmpdir).mkdir(parents=True, exist_ok=True) + (Path(tmpdir) / f"sample_functions_{i}.jsonl").write_text( + sample_file.read_text(), encoding="utf-8" + ) + + loader = LoadDataFromFileSystem( + filetype=filetype, + data_files=tmpdir, + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 33 + + @pytest.mark.parametrize("filetype", ["json", None]) + def test_read_from_jsonl_with_nested_folder( + self, filetype: Union[str, None] + ) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + filename = "sample_functions.jsonl" + sample_file = Path(__file__).parent / filename + for folder in ["train", "validation"]: + (Path(tmpdir) / folder).mkdir(parents=True, exist_ok=True) + (Path(tmpdir) / folder / filename).write_text( + sample_file.read_text(), encoding="utf-8" + ) + + loader = LoadDataFromFileSystem( + filetype=filetype, + data_files=tmpdir, + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 22 + + @pytest.mark.parametrize("load", [True, False]) + def test_outputs(self, load: bool) -> None: + loader = LoadDataFromFileSystem( + filetype="json", + data_files=str(Path(__file__).parent / "sample_functions.jsonl"), + ) + if load: + loader.load() + assert loader.outputs == ["type", "function"] + else: + with pytest.raises(ValueError): + loader.outputs # noqa: B018 + + +class TestLoadDataFromDisk: + def test_load_dataset_from_disk(self) -> None: + dataset = Dataset.from_dict({"a": [1, 2, 3]}) + with tempfile.TemporaryDirectory() as tmpdir: + dataset_path = str(Path(tmpdir) / "dataset_path") + dataset.save_to_disk(dataset_path) + + loader = LoadDataFromDisk(dataset_path=dataset_path) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 3 + + def test_load_distiset_from_disk(self) -> None: + distiset = Distiset( + { + "leaf_step_1": Dataset.from_dict({"a": [1, 2, 3]}), + "leaf_step_2": Dataset.from_dict( + {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]} + ), + } + ) + with tempfile.TemporaryDirectory() as tmpdir: + dataset_path = str(Path(tmpdir) / "dataset_path") + distiset.save_to_disk(dataset_path) + + loader = LoadDataFromDisk( + dataset_path=dataset_path, 
is_distiset=True, config="leaf_step_1" + ) + loader.load() + generator_step_output = next(loader.process()) + assert isinstance(generator_step_output, tuple) + assert isinstance(generator_step_output[1], bool) + assert len(generator_step_output[0]) == 3 + + +def test_LoadHubDataset_deprecation_warning(): + with pytest.deprecated_call(): + LoadHubDataset( + repo_id="distilabel-internal-testing/instruction-dataset-mini", + split="test", + batch_size=2, + ) + import distilabel + from packaging.version import Version + + assert Version(distilabel.__version__) <= Version("1.3.0") diff --git a/tests/unit/steps/tasks/conftest.py b/tests/unit/steps/tasks/conftest.py deleted file mode 100644 index da1493c9ce..0000000000 --- a/tests/unit/steps/tasks/conftest.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, List - -import pytest -from distilabel.llms.base import LLM - -if TYPE_CHECKING: - from distilabel.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import ChatType - - -@pytest.fixture -def dummy_llm() -> LLM: - class DummyLLM(LLM): - def load(self) -> None: - pass - - @property - def model_name(self) -> str: - return "test" - - def generate( # type: ignore - self, inputs: List["ChatType"], num_generations: int = 1 - ) -> List["GenerateOutput"]: - return [["output"] for _ in inputs] - - return DummyLLM() - - -# Defined here too, so that the serde still works -class DummyLLM(LLM): - def load(self) -> None: - pass - - @property - def model_name(self) -> str: - return "test" - - def generate( # type: ignore - self, inputs: List["ChatType"], num_generations: int = 1 - ) -> List["GenerateOutput"]: - return [["output"] for _ in inputs] diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py index d3999684c1..9b679f6dfb 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_base.py +++ b/tests/unit/steps/tasks/evol_instruct/test_base.py @@ -121,13 +121,12 @@ def test_serialization(self, dummy_llm: LLM) -> None: task.load() assert task.dump() == { "name": "task", - "add_raw_output": False, + "add_raw_output": True, "input_mappings": task.input_mappings, "output_mappings": task.output_mappings, "input_batch_size": task.input_batch_size, "llm": { "generation_kwargs": {}, - "structured_output": None, "type_info": { "module": task.llm.__module__, "name": task.llm.__class__.__name__, @@ -163,6 +162,11 @@ def test_serialization(self, dummy_llm: LLM) -> None: } ], }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, { "name": "num_generations", "optional": True, diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py index fee2234083..13cf6e2783 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_generator.py 
+++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py @@ -117,13 +117,12 @@ def test_serialization(self, dummy_llm: LLM) -> None: "name": "task", "llm": { "generation_kwargs": {}, - "structured_output": None, "type_info": { "module": task.llm.__class__.__module__, "name": task.llm.__class__.__name__, }, }, - "add_raw_output": False, + "add_raw_output": True, "input_mappings": task.input_mappings, "output_mappings": task.output_mappings, "batch_size": task.batch_size, @@ -158,6 +157,11 @@ def test_serialization(self, dummy_llm: LLM) -> None: }, ], }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, { "name": "num_generations", "optional": True, diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py index 251e377d9e..fffacafa06 100644 --- a/tests/unit/steps/tasks/evol_quality/test_base.py +++ b/tests/unit/steps/tasks/evol_quality/test_base.py @@ -34,6 +34,18 @@ def test_with_errors( EvolQuality(name="task", llm=dummy_llm, num_evolutions=2) assert "Step 'task' hasn't received a pipeline" in caplog.text + def test_apply_random_mutation(self, dummy_llm: LLM) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + task = EvolQuality( + name="task", llm=dummy_llm, num_evolutions=2, pipeline=pipeline + ) + task.load() + + mutated = task._apply_random_mutation("I'm an instruction", "I'm a response") + + assert "I'm an instruction" in mutated + assert "I'm a response" in mutated + def test_process(self, dummy_llm: LLM) -> None: pipeline = Pipeline(name="unit-test-pipeline") task = EvolQuality( @@ -80,13 +92,12 @@ def test_serialization(self, dummy_llm: LLM) -> None: task.load() assert task.dump() == { "name": "task", - "add_raw_output": False, + "add_raw_output": True, "input_mappings": task.input_mappings, "output_mappings": task.output_mappings, "input_batch_size": task.input_batch_size, "llm": { "generation_kwargs": {}, - "structured_output": None, "type_info": { "module": task.llm.__module__, "name": task.llm.__class__.__name__, @@ -112,9 +123,14 @@ def test_serialization(self, dummy_llm: LLM) -> None: "name": "generation_kwargs", "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.", "keys": [], - } + }, ], }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, { "name": "num_generations", "optional": True, diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index 549155076b..e174f53716 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -17,9 +17,10 @@ import pytest from distilabel.llms.huggingface.transformers import TransformersLLM from distilabel.steps.tasks.structured_outputs.outlines import ( - StructuredOutputType, + # StructuredOutputType, model_to_schema, ) +from distilabel.steps.tasks.typing import OutlinesStructuredOutputType from pydantic import BaseModel @@ -88,10 +89,6 @@ class DummyUserTest(BaseModel): class TestOutlinesIntegration: - # @pytest.mark.skipif( - # not DISTILABEL_RUN_SLOW_TESTS, - # reason="Slow tests, run locally when needed.", - # ) @pytest.mark.parametrize( "format, schema, prompt", 
[ @@ -99,7 +96,7 @@ class TestOutlinesIntegration: "json", DummyUserTest, "Create a user profile with the fields name, last_name and id", - ), # + ), ( "json", model_to_schema(DummyUserTest), @@ -117,7 +114,9 @@ def test_generation( ) -> None: llm = TransformersLLM( model="openaccess-ai-collective/tiny-mistral", - structured_output=StructuredOutputType(format=format, schema=schema), + structured_output=OutlinesStructuredOutputType( + format=format, schema=schema + ), ) llm.load() @@ -154,7 +153,9 @@ def test_serialization( ) -> None: llm = TransformersLLM( model="openaccess-ai-collective/tiny-mistral", - structured_output=StructuredOutputType(format=format, schema=schema), + structured_output=OutlinesStructuredOutputType( + format=format, schema=schema + ), ) llm.load() assert llm.dump() == dump diff --git a/tests/unit/steps/tasks/structured_outputs/test_utils.py b/tests/unit/steps/tasks/structured_outputs/test_utils.py new file mode 100644 index 0000000000..6238c8567f --- /dev/null +++ b/tests/unit/steps/tasks/structured_outputs/test_utils.py @@ -0,0 +1,75 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import List + +from distilabel.steps.tasks.structured_outputs.utils import json_schema_to_model +from pydantic import BaseModel, Field, StringConstraints, conint +from typing_extensions import Annotated + + +class Node(BaseModel): + id: int + label: str + color: str + + +class Edge(BaseModel): + source: int + target: int + label: str + color: str = "black" + + +class KnowledgeGraph(BaseModel): + nodes: List[Node] = Field(..., default_factory=list) + edges: List[Edge] = Field(..., default_factory=list) + + +class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace = "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + +class Armor(str, Enum): + leather = "leather" + chainmail = "chainmail" + plate = "plate" + mithril = "mithril" + + +class Character(BaseModel): + name: Annotated[str, StringConstraints(max_length=30)] + age: conint(gt=1, lt=3000) + armor: Armor + weapon: Weapon + + +def test_json_schema_to_model(): + assert type(json_schema_to_model(Node.model_json_schema())) == type(Node) + + +def test_json_schema_to_model_with_enum(): + assert type(json_schema_to_model(Character.model_json_schema())) == type(Character) + + +def test_json_schema_to_model_nested(): + assert type(json_schema_to_model(KnowledgeGraph.model_json_schema())) == type( + KnowledgeGraph + ) diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index ed1fd956cf..c5ff87e1e0 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import sys from dataclasses import field from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -21,7 +22,7 @@ from distilabel.steps.tasks.base import Task from pydantic import ValidationError -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM if TYPE_CHECKING: from distilabel.steps.tasks.typing import ChatType @@ -77,7 +78,9 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: with pytest.raises( TypeError, - match="Can't instantiate abstract class Task with abstract methods format_input, format_output", + match="Can't instantiate abstract class Task with abstract methods format_input, format_output" + if sys.version_info < (3, 12) + else "Can't instantiate abstract class Task without an implementation for abstract methods 'format_input', 'format_output'", ): Task(name="task", llm=DummyLLM()) # type: ignore @@ -91,16 +94,19 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "instruction": "test", "output": "output", "model_name": "test", + "distilabel_metadata": {"raw_output_task": "output"}, }, { "instruction": "test", "output": "output", "model_name": "test", + "distilabel_metadata": {"raw_output_task": "output"}, }, { "instruction": "test", "output": "output", "model_name": "test", + "distilabel_metadata": {"raw_output_task": "output"}, }, ], ), @@ -111,6 +117,11 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "instruction": "test", "output": ["output", "output", "output"], "model_name": "test", + "distilabel_metadata": [ + {"raw_output_task": "output"}, + {"raw_output_task": "output"}, + {"raw_output_task": "output"}, + ], }, ], ), @@ -145,7 +156,7 @@ def test_process_with_runtime_parameters(self) -> None: assert task.llm.runtime_parameters_names == { "runtime_parameter": False, "runtime_parameter_optional": True, - "generation_kwargs": {"kwargs": False}, + "generation_kwargs": {}, } # 2. Runtime parameters in init @@ -160,7 +171,7 @@ def test_process_with_runtime_parameters(self) -> None: assert task.llm.runtime_parameters_names == { "runtime_parameter": False, "runtime_parameter_optional": True, - "generation_kwargs": {"kwargs": False}, + "generation_kwargs": {}, } # 3. 
Runtime parameters in init superseded by runtime parameters @@ -176,7 +187,7 @@ def test_process_with_runtime_parameters(self) -> None: assert task.llm.runtime_parameters_names == { "runtime_parameter": False, "runtime_parameter_optional": True, - "generation_kwargs": {"kwargs": False}, + "generation_kwargs": {}, } def test_serialization(self) -> None: @@ -185,15 +196,14 @@ def test_serialization(self) -> None: task = DummyTask(name="task", llm=llm, pipeline=pipeline) assert task.dump() == { "name": "task", - "add_raw_output": False, + "add_raw_output": True, "input_mappings": {}, "output_mappings": {}, "input_batch_size": 50, "llm": { "generation_kwargs": {}, - "structured_output": None, "type_info": { - "module": "tests.unit.steps.tasks.utils", + "module": "tests.unit.conftest", "name": "DummyLLM", }, }, @@ -211,16 +221,16 @@ def test_serialization(self) -> None: { "description": "The kwargs to be propagated to either `generate` or " "`agenerate` methods within each `LLM`.", - "keys": [ - { - "name": "kwargs", - "optional": False, - }, - ], + "keys": [], "name": "generation_kwargs", }, ], }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, { "name": "num_generations", "description": "The number of generations to be produced per input.", diff --git a/tests/unit/steps/tasks/test_complexity_scorer.py b/tests/unit/steps/tasks/test_complexity_scorer.py index a47a16445d..ec0575d745 100644 --- a/tests/unit/steps/tasks/test_complexity_scorer.py +++ b/tests/unit/steps/tasks/test_complexity_scorer.py @@ -18,7 +18,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.complexity_scorer import ComplexityScorer -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestComplexityScorer: diff --git a/tests/unit/steps/tasks/test_genstruct.py b/tests/unit/steps/tasks/test_genstruct.py index 8ecc9d2d58..12878b9f26 100644 --- a/tests/unit/steps/tasks/test_genstruct.py +++ b/tests/unit/steps/tasks/test_genstruct.py @@ -18,7 +18,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.genstruct import Genstruct -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestGenstruct: diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py new file mode 100644 index 0000000000..8ab9b2fd51 --- /dev/null +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -0,0 +1,406 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
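# The `EmbeddingTaskGenerator` test below asserts that a single LLM output such
# as "[ 'A', 'B', 'C' ]" is either kept as one row ({"tasks": [...]}) or, with
# `flatten_tasks=True`, exploded into one row per task. A minimal sketch of that
# post-processing (illustrative only, not distilabel's code; `explode_tasks` is
# a made-up helper name):

from typing import Any, Dict, List


def explode_tasks(
    tasks: List[str], model_name: str, flatten: bool
) -> List[Dict[str, Any]]:
    if not flatten:
        return [{"tasks": tasks, "model_name": model_name}]
    return [{"task": task, "model_name": model_name} for task in tasks]


# explode_tasks(["A", "B", "C"], "test", flatten=True)
# -> [{"task": "A", ...}, {"task": "B", ...}, {"task": "C", ...}] as in `test_process`.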
+ +import json +from typing import Any, List + +import pytest +from distilabel.llms import LLM +from distilabel.llms.typing import GenerateOutput +from distilabel.pipeline.local import Pipeline +from distilabel.steps.tasks.improving_text_embeddings import ( + BitextRetrievalGenerator, + EmbeddingTaskGenerator, + GenerateLongTextMatchingData, + GenerateShortTextMatchingData, + GenerateTextClassificationData, + GenerateTextRetrievalData, + MonolingualTripletGenerator, +) +from distilabel.steps.tasks.typing import ChatType + + +class MockLLM(LLM): + output: str + + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + def generate( # type: ignore + self, inputs: List[ChatType], num_generations: int = 1 + ) -> List[GenerateOutput]: + return [[self.output] for _ in range(num_generations)] + + +class TestEmbeddingTaskGenerator: + @pytest.mark.parametrize( + "category", + [ + "text-retrieval", + "text-matching-short", + "text-matching-long", + "text-classification", + ], + ) + @pytest.mark.parametrize("flatten_tasks", [True, False]) + def test_process(self, category: str, flatten_tasks: bool) -> None: + task = EmbeddingTaskGenerator( + name="embedding_task_generator", + category=category, # type: ignore + flatten_tasks=flatten_tasks, + add_raw_output=False, + llm=MockLLM(output="[ 'A', 'B', 'C' ]"), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.outputs == ["tasks" if not flatten_tasks else "task", "model_name"] + + result = ( + ([{"tasks": ["A", "B", "C"], "model_name": "test"}], True) + if not flatten_tasks + else ( + [ + {"task": "A", "model_name": "test"}, + {"task": "B", "model_name": "test"}, + {"task": "C", "model_name": "test"}, + ], + True, + ) + ) + assert next(task.process()) == result + + +class TestBitextRetrievalGenerator: + @pytest.mark.parametrize( + "task_kwargs", + [ + { + "source_language": "English", + "target_language": "French", + "unit": "sentence", + "difficulty": "elementary school", + "high_score": "4", + "low_score": "2.5", + } + ], + ) + def test_prompt(self, task_kwargs: Any) -> None: + task = BitextRetrievalGenerator( + name="bitext_retrieval_generator", + **task_kwargs, + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert all( + task.prompt[-1]["content"].__contains__(v) for _, v in task_kwargs.items() + ) + + def test_process(self) -> None: + task = BitextRetrievalGenerator( + name="bitext_retrieval_generator", + source_language="English", + target_language="French", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.outputs == ["S1", "S2", "S3", "model_name"] + + assert next(task.process()) == ( + [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + True, + ) + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = BitextRetrievalGenerator( + name="bitext_retrieval_generator", + source_language="English", + target_language="French", + add_raw_output=False, + seed=42, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + unique_prompts.add(task.prompt[-1]["content"]) + + assert len(unique_prompts) == 1 + + +class TestMonolingualTripletGenerator: + @pytest.mark.parametrize( + "task_kwargs", + [ + { + "language": 
"English", + "unit": "sentence", + "difficulty": "elementary school", + "high_score": "4", + "low_score": "2.5", + } + ], + ) + def test_prompt(self, task_kwargs: Any) -> None: + task = MonolingualTripletGenerator( + name="monolingual_triplet_generator", + **task_kwargs, + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert all( + task.prompt[-1]["content"].__contains__(v) for _, v in task_kwargs.items() + ) + + def test_process(self) -> None: + task = MonolingualTripletGenerator( + name="monolingual_triplet_generator", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == ["S1", "S2", "S3", "model_name"] + assert next(task.process()) == ( + [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + True, + ) + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = MonolingualTripletGenerator( + name="monolingual_triplet_generator", + language="English", + add_raw_output=False, + seed=42, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + unique_prompts.add(task.prompt[-1]["content"]) + assert len(unique_prompts) == 1 + + +class TestGenerateLongTextMatchingData: + def test_format_input(self) -> None: + task = GenerateLongTextMatchingData( + name="generate_long_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a text matching task: A" + ) + + def test_process(self) -> None: + task = GenerateLongTextMatchingData( + name="generate_long_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.outputs == ["input", "positive_document", "model_name"] + + assert next(task.process(inputs=[{"task": "A"}])) == [ + {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + ] + + +class TestGenerateShortTextMatchingData: + def test_format_input(self) -> None: + task = GenerateShortTextMatchingData( + name="generate_short_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a text matching task: A" + ) + + def test_process(self) -> None: + task = GenerateShortTextMatchingData( + name="generate_short_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == ["input", "positive_document", "model_name"] + assert next(task.process(inputs=[{"task": "A"}])) == [ + {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + ] + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = 
GenerateShortTextMatchingData( + name="generate_short_text_matching_data", + language="English", + add_raw_output=False, + seed=42, + llm=MockLLM( + output=json.dumps({"input": "A", "positive_document": "B"}) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + unique_prompts.add(task.format_input({"task": "A"})[-1]["content"]) + + assert len(unique_prompts) == 1 + + +class TestGenerateTextClassificationData: + def test_format_input(self) -> None: + task = GenerateTextClassificationData( + name="generate_text_classification_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + {"input_text": "A", "label": "B", "misleading_label": "C"} + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a text classification task: A" + ) + + def test_process(self) -> None: + task = GenerateTextClassificationData( + name="generate_text_classification_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + {"input_text": "A", "label": "B", "misleading_label": "C"} + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == ["input_text", "label", "misleading_label", "model_name"] + assert next(task.process(inputs=[{"task": "A"}])) == [ + { + "task": "A", + "input_text": "A", + "label": "B", + "misleading_label": "C", + "model_name": "test", + } + ] + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = GenerateTextClassificationData( + name="generate_text_classification_data", + language="English", + add_raw_output=False, + seed=42, + llm=MockLLM( + output=json.dumps( + {"input_text": "A", "label": "B", "misleading_label": "C"} + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + unique_prompts.add(task.format_input({"task": "A"})[-1]["content"]) + + assert len(unique_prompts) == 1 + + +class TestGenerateTextRetrievalData: + def test_format_input(self) -> None: + task = GenerateTextRetrievalData( + name="generate_text_retrieval_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + { + "user_query": "A", + "positive_document": "B", + "hard_negative_document": "C", + } + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a retrieval task: A" + ) + + def test_process(self) -> None: + task = GenerateTextRetrievalData( + name="generate_text_retrieval_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + { + "user_query": "A", + "positive_document": "B", + "hard_negative_document": "C", + } + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == [ + "user_query", + "positive_document", + "hard_negative_document", + "model_name", + ] + assert next(task.process(inputs=[{"task": "A"}])) == [ + { + "task": "A", + "user_query": "A", + "positive_document": "B", + "hard_negative_document": "C", + "model_name": "test", + } + ] diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py index 4c8a8df7fe..a6f2793285 100644 --- a/tests/unit/steps/tasks/test_instruction_backtranslation.py +++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py @@ -86,5 +86,8 @@ def test_process(self) -> None: "score": 
1, "reason": "This is the reason.", "model_name": "instruction-backtranslation-model", + "distilabel_metadata": { + "raw_output_instruction-backtranslation": "This is the reason. Score: 1" + }, } ] diff --git a/tests/unit/steps/tasks/test_prometheus_eval.py b/tests/unit/steps/tasks/test_prometheus_eval.py index e5a4ad8590..31a437fdab 100644 --- a/tests/unit/steps/tasks/test_prometheus_eval.py +++ b/tests/unit/steps/tasks/test_prometheus_eval.py @@ -27,7 +27,7 @@ from jinja2 import Template from pydantic import ValidationError -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM def load_template(template: str) -> Template: diff --git a/tests/unit/steps/tasks/test_quality_scorer.py b/tests/unit/steps/tasks/test_quality_scorer.py index 0a3db8261b..608631e9a2 100644 --- a/tests/unit/steps/tasks/test_quality_scorer.py +++ b/tests/unit/steps/tasks/test_quality_scorer.py @@ -18,7 +18,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.quality_scorer import QualityScorer -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestQualityScorer: diff --git a/tests/unit/steps/tasks/test_self_instruct.py b/tests/unit/steps/tasks/test_self_instruct.py index 8525b88e6b..e3378e7e93 100644 --- a/tests/unit/steps/tasks/test_self_instruct.py +++ b/tests/unit/steps/tasks/test_self_instruct.py @@ -15,7 +15,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.self_instruct import SelfInstruct -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestSelfInstruct: diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py new file mode 100644 index 0000000000..2f81240755 --- /dev/null +++ b/tests/unit/steps/tasks/test_sentence_transformers.py @@ -0,0 +1,215 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
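# The `GenerateSentencePair` tests below check that the raw generation is split
# on "## Positive" / "## Negative" markdown headers, with a missing section
# mapped to None. A regex-based re-implementation that reproduces the
# expectations of `test_format_output` (a sketch, not the library's code):

import re
from typing import Any, Dict, Optional


def parse_positive_negative(output: str, triplet: bool) -> Dict[str, Any]:
    def section(header: str) -> Optional[str]:
        match = re.search(
            rf"## {header}\n\n(.*?)(?:\n## |\Z)", output, flags=re.DOTALL
        )
        return match.group(1).strip() if match else None

    result: Dict[str, Any] = {"positive": section("Positive")}
    if triplet:
        result["negative"] = section("Negative")
    return result


# parse_positive_negative("## Positive\n\nThis is a paraphrase", triplet=True)
# -> {"positive": "This is a paraphrase", "negative": None}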
+ +from typing import Any, Dict + +import pytest +from distilabel.steps.tasks.sentence_transformers import ( + CONTEXT_INTRO, + POSITIVE_NEGATIVE_SYSTEM_PROMPT, + POSITIVE_SYSTEM_PROMPT, + GenerateSentencePair, + GenerationAction, +) + +from tests.unit.conftest import DummyLLM + + +class TestGenerateSentencePair: + @pytest.mark.parametrize( + "action,triplet,system_prompt", + [ + ( + "paraphrase", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="paraphrase", context="" + ), + ), + ( + "paraphrase", + False, + POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""), + ), + ( + "semantically-similar", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be semantically similar to", context="" + ), + ), + ( + "semantically-similar", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="be semantically similar to", context="" + ), + ), + ( + "query", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be a query for", context="" + ), + ), + ( + "query", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="be a query for", context="" + ), + ), + ( + "answer", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be an answer for", context="" + ), + ), + ( + "answer", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="be an answer for", context="" + ), + ), + ], + ) + def test_format_input( + self, action: GenerationAction, triplet: bool, system_prompt: str + ) -> None: + task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet) + task.load() + content = "## Anchor\n\nThis is a unit test\n" + assert task.format_input({"anchor": "This is a unit test"}) == [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": content}, + ] + + @pytest.mark.parametrize( + "action,triplet,system_prompt", + [ + ( + "paraphrase", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="paraphrase", context=CONTEXT_INTRO + ), + ), + ( + "paraphrase", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="paraphrase", context=CONTEXT_INTRO + ), + ), + ( + "semantically-similar", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be semantically similar to", context=CONTEXT_INTRO + ), + ), + ( + "semantically-similar", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="be semantically similar to", context=CONTEXT_INTRO + ), + ), + ( + "query", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be a query for", context=CONTEXT_INTRO + ), + ), + ( + "query", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="be a query for", context=CONTEXT_INTRO + ), + ), + ( + "answer", + True, + POSITIVE_NEGATIVE_SYSTEM_PROMPT.format( + action_sentence="be an answer for", context=CONTEXT_INTRO + ), + ), + ( + "answer", + False, + POSITIVE_SYSTEM_PROMPT.format( + action_sentence="be an answer for", context=CONTEXT_INTRO + ), + ), + ], + ) + def test_format_input_with_context( + self, action: GenerationAction, triplet: bool, system_prompt: str + ) -> None: + context = "This is your context." 
+ task = GenerateSentencePair( + llm=DummyLLM(), + action=action, + triplet=triplet, + context=context, + ) + task.load() + content = f"## Context\n\n{context}\n\n## Anchor\n\nThis is a unit test\n" + # content = f"## Anchor\n\nThis is a unit test\n## Context\n\n{context}" + assert task.format_input({"anchor": "This is a unit test"}) == [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": content}, + ] + + @pytest.mark.parametrize( + "output,triplet,expected", + [ + ( + "## Positive\n\nThis is a paraphrase\n## Negative\n\nThis is not a paraphrase", + True, + { + "positive": "This is a paraphrase", + "negative": "This is not a paraphrase", + }, + ), + ( + "## Positive\n\nThis is a paraphrase", + True, + {"positive": "This is a paraphrase", "negative": None}, + ), + ( + "## Positive\n\nThis is a paraphrase", + False, + {"positive": "This is a paraphrase"}, + ), + ( + "random", + False, + {"positive": None}, + ), + ], + ) + def test_format_output( + self, output: str, triplet: bool, expected: Dict[str, Any] + ) -> None: + task = GenerateSentencePair( + llm=DummyLLM(), action="paraphrase", triplet=triplet + ) + task.load() + + assert task.format_output(output) == expected diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py new file mode 100644 index 0000000000..e2c230ef7e --- /dev/null +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -0,0 +1,125 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, List + +from distilabel.llms.base import LLM +from distilabel.llms.typing import GenerateOutput +from distilabel.pipeline.local import Pipeline +from distilabel.steps.tasks.structured_generation import StructuredGeneration +from distilabel.steps.tasks.typing import StructuredInput +from typing_extensions import override + + +class DummyStructuredLLM(LLM): + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + @override + def generate( # type: ignore + self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any + ) -> List["GenerateOutput"]: + return [ + [json.dumps({"test": "output"}) for _ in range(num_generations)] + for _ in inputs + ] + + +class TestStructuredGeneration: + def test_format_input(self) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + llm = DummyStructuredLLM() + task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline) + + # 1. Including the `grammar` field within the input + assert task.format_input( + { + "instruction": "test", + "system_prompt": "test", + "structured_output": {"format": "regex", "schema": r"[a-zA-Z]+"}, + } + ) == ( + [{"role": "user", "content": "test"}], + {"format": "regex", "schema": r"[a-zA-Z]+"}, + ) + + # 2. 
Not including the `structured_output` field within the input + assert task.format_input({"instruction": "test", "system_prompt": "test"}) == ( + [{"role": "user", "content": "test"}], + None, + ) + + def test_format_input_with_system_prompt(self) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + llm = DummyStructuredLLM() + task = StructuredGeneration( + name="task", + llm=llm, + pipeline=pipeline, + use_system_prompt=True, + ) + + assert task.format_input({"instruction": "test", "system_prompt": "test"}) == ( + [ + {"role": "system", "content": "test"}, + {"role": "user", "content": "test"}, + ], + None, + ) + + def test_process(self) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + llm = DummyStructuredLLM() + task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline) + assert next( + task.process( + [ + { + "instruction": "test", + "structured_output": { + "format": "json", + "schema": { + "properties": { + "test": {"title": "Test", "type": "string"} + }, + "required": ["test"], + "title": "Test", + "type": "object", + }, + }, + } + ] + ) + ) == [ + { + "instruction": "test", + "structured_output": { + "format": "json", + "schema": { + "properties": {"test": {"title": "Test", "type": "string"}}, + "required": ["test"], + "title": "Test", + "type": "object", + }, + }, + "generation": '{"test": "output"}', + "model_name": "test", + "distilabel_metadata": {"raw_output_task": '{"test": "output"}'}, + } + ] diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index ecff0e1d90..c98adb00e5 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -16,7 +16,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestTextGeneration: @@ -53,6 +53,12 @@ def test_format_input_errors(self) -> None: name="task", llm=llm, pipeline=pipeline, use_system_prompt=True ) + with pytest.raises( + ValueError, + match=r"Providing \`instruction\` formatted as an OpenAI chat / conversation is deprecated", + ): + task.format_input({"instruction": [{"role": "user", "content": "test"}]}) + + with pytest.raises( + ValueError, + match=r"Input \`instruction\` must be a string. Got: 1."
): @@ -76,26 +82,12 @@ def test_process(self) -> None: "instruction": "test", "generation": "output", "model_name": "test", + "distilabel_metadata": { + "raw_output_task": "output", + }, } ] - def test_deprecation_warning(self) -> None: - pipeline = Pipeline(name="unit-test-pipeline") - llm = DummyLLM() - task = TextGeneration(name="task", llm=llm, pipeline=pipeline) - - with pytest.warns( - DeprecationWarning, - match=r"Providing \`instruction\` formatted as an OpenAI chat \/ conversation is about to be deprecated in \`distilabel v1.2.0\`", - ): - task.format_input( - { - "instruction": [ - {"role": "user", "content": "Tell me a joke."}, - ] - } - ) - class TestChatGeneration: def test_format_input(self) -> None: @@ -150,5 +142,6 @@ def test_process(self) -> None: "messages": [{"role": "user", "content": "Tell me a joke."}], "generation": "output", "model_name": "test", + "distilabel_metadata": {"raw_output_task": "output"}, } ] diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 69e9326570..fa72ff9442 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -63,6 +63,9 @@ def test_process_with_simple_aspect(self) -> None: "ratings": [1, 2], "rationales": ["text", "text"], "model_name": "ultrafeedback-model", + "distilabel_metadata": { + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + }, } ] @@ -89,5 +92,8 @@ def test_process_with_complex_aspect(self) -> None: "ratings": [1, 2], "rationales-for-ratings": ["text", "text"], "model_name": "ultrafeedback-model", + "distilabel_metadata": { + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + }, } ] diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py index 543e5a43d2..6b469bf324 100644 --- a/tests/unit/steps/test_base.py +++ b/tests/unit/steps/test_base.py @@ -310,7 +310,7 @@ def test_step_from_dict(self) -> None: **{ "name": "dummy", TYPE_INFO_KEY: { - "module": "tests.unit.pipeline.step.test_base", + "module": "tests.unit.steps.test_base", "name": "DummyStep", }, } @@ -327,7 +327,7 @@ def test_step_from_dict_without_pipeline_context( **{ "name": "dummy", TYPE_INFO_KEY: { - "module": "tests.pipeline.step.test_base", + "module": "tests.unit.steps.test_base", "name": "DummyStep", }, } diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py index 18de3f8769..07e6549d7b 100644 --- a/tests/unit/test_distiset.py +++ b/tests/unit/test_distiset.py @@ -12,9 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +import re +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional + import pytest +import yaml from datasets import Dataset, DatasetDict from distilabel.distiset import Distiset +from upath import UPath @pytest.fixture(scope="function") @@ -27,6 +35,24 @@ def distiset(): ) +def make_fake_file(filename: Path) -> None: + if not filename.parent.exists(): + filename.parent.mkdir(parents=True) + filename.touch() + + +def add_config_to_distiset(distiset: Distiset, folder: Path) -> Distiset: + from distilabel.distiset import DISTISET_CONFIG_FOLDER + + pipeline_yaml = folder / DISTISET_CONFIG_FOLDER / "pipeline.yaml" + pipeline_log = folder / DISTISET_CONFIG_FOLDER / "pipeline.log" + make_fake_file(pipeline_yaml) + make_fake_file(pipeline_log) + distiset.pipeline_path = pipeline_yaml + distiset.pipeline_log_path = pipeline_log + return distiset + + class TestDistiset: def test_train_test_split(self, distiset: Distiset) -> None: assert isinstance(distiset["leaf_step_1"], Dataset) @@ -34,3 +60,111 @@ def test_train_test_split(self, distiset: Distiset) -> None: assert isinstance(ds, Distiset) assert len(ds) == 2 assert isinstance(ds["leaf_step_1"], DatasetDict) + + @pytest.mark.parametrize("storage_options", [None, {"test": "option"}]) + @pytest.mark.parametrize("with_config", [False, True]) + def test_save_to_disk( + self, + distiset: Distiset, + with_config: bool, + storage_options: Optional[Dict[str, Any]], + ) -> None: + full_distiset = copy.deepcopy(distiset) + # Distiset with Dataset + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "distiset_folder" + if with_config: + full_distiset = add_config_to_distiset(full_distiset, folder) + + full_distiset.save_to_disk( + folder, + save_card=with_config, + save_pipeline_config=with_config, + save_pipeline_log=with_config, + storage_options=storage_options, + ) + assert folder.is_dir() + assert len(list(folder.iterdir())) == 3 + + full_distiset = copy.deepcopy(distiset) + # Distiset with DatasetDict + distiset_with_dict = full_distiset.train_test_split(0.8) + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "distiset_folder" + if with_config: + distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder) + + distiset_with_dict.save_to_disk( + folder, + save_card=with_config, + save_pipeline_config=with_config, + save_pipeline_log=with_config, + ) + + assert folder.is_dir() + assert len(list(folder.iterdir())) == 3 + + @pytest.mark.parametrize("pathlib_implementation", [Path, UPath]) + @pytest.mark.parametrize("storage_options", [None, {"project": "experiments"}]) + @pytest.mark.parametrize("with_config", [False, True]) + def test_load_from_disk( + self, + distiset: Distiset, + with_config: bool, + storage_options: Optional[Dict[str, Any]], + pathlib_implementation: type, + ) -> None: + full_distiset = copy.deepcopy(distiset) + # Distiset with Dataset + with tempfile.TemporaryDirectory() as tmpdirname: + # This way we can also test that it works with UPath, using the FilePath protocol, as it + # should behave the same as S3Path, GCSPath, etc.
+ folder = pathlib_implementation(tmpdirname) / "distiset_folder" + if with_config: + full_distiset = add_config_to_distiset(full_distiset, folder) + full_distiset.save_to_disk( + folder, + save_card=with_config, + save_pipeline_config=with_config, + save_pipeline_log=with_config, + storage_options=storage_options, + ) + ds = Distiset.load_from_disk( + folder, + storage_options=storage_options, + ) + assert isinstance(ds, Distiset) + assert isinstance(ds["leaf_step_1"], Dataset) + + if with_config: + assert ds.pipeline_path.exists() + assert ds.log_filename_path.exists() + + full_distiset = copy.deepcopy(distiset) + # Distiset with DatasetDict + distiset_with_dict = full_distiset.train_test_split(0.8) + with tempfile.TemporaryDirectory() as tmpdirname: + folder = pathlib_implementation(tmpdirname) / "distiset_folder" + if with_config: + distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder) + + distiset_with_dict.save_to_disk(folder) + ds = Distiset.load_from_disk(folder, storage_options=storage_options) + + assert folder.is_dir() + assert isinstance(ds["leaf_step_1"], DatasetDict) + + if with_config: + assert ds.pipeline_path.exists() + assert ds.log_filename_path.exists() + + def test_dataset_card(self, distiset: Distiset) -> None: + # Test the metadata we generate by default, without extracting the already generated content from the HF hub. + # We parse the content and check it's the same as the one we generate. + distiset_card = distiset._get_card("repo_name_or_path") + metadata = re.findall(r"---\n(.*?)\n---", str(distiset_card), re.DOTALL)[0] + metadata = yaml.safe_load(metadata) + assert metadata == { + "size_categories": "n<1K", + "tags": ["synthetic", "distilabel", "rlaif"], + } diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py index cfccb1585f..e20e186c8e 100644 --- a/tests/unit/test_imports.py +++ b/tests/unit/test_imports.py @@ -51,6 +51,8 @@ def test_imports() -> None: GeneratorStepOutput, KeepColumns, LoadDataFromDicts, + LoadDataFromHub, + LoadDataFromDisk, LoadHubDataset, PushToHub, Step, @@ -72,11 +74,19 @@ def test_imports() -> None: EvolInstructGenerator, GenerateEmbeddings, Genstruct, + BitextRetrievalGenerator, + EmbeddingTaskGenerator, + GenerateLongTextMatchingData, + GenerateShortTextMatchingData, + GenerateTextClassificationData, + GenerateTextRetrievalData, + MonolingualTripletGenerator, InstructionBacktranslation, PairRM, PrometheusEval, QualityScorer, SelfInstruct, + StructuredGeneration, TextGeneration, UltraFeedback, ) diff --git a/tests/unit/utils/test_serialization.py b/tests/unit/utils/test_serialization.py new file mode 100644 index 0000000000..153e2a8692 --- /dev/null +++ b/tests/unit/utils/test_serialization.py @@ -0,0 +1,37 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from distilabel.utils.serialization import _extra_serializable_fields, _Serializable +from pydantic import BaseModel + + +def test_extra_serializable_fields() -> None: + class DummyAttribute(BaseModel, _Serializable): + pass + + class Dummy(BaseModel, _Serializable): + attr: DummyAttribute + + dummy = Dummy(attr=DummyAttribute()) + + assert _extra_serializable_fields(dummy) == [ + { + "attr": { + "type_info": { + "module": "tests.unit.utils.test_serialization", + "name": "DummyAttribute", + } + } + } + ]
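As a rough illustration of what the `type_info` block asserted above is expected to encode, the short sketch below derives the same module/name pair with plain Python introspection; this is an assumption about the intent of the serialized payload rather than a description of distilabel's actual implementation, and the helper name `type_info_for` is hypothetical and not part of the library.

from pydantic import BaseModel


def type_info_for(obj: object) -> dict:
    # Record where the class of `obj` is defined so it could be re-imported later,
    # i.e. the same module/name shape the assertion above checks for.
    cls = type(obj)
    return {"type_info": {"module": cls.__module__, "name": cls.__name__}}


class DummyAttribute(BaseModel):
    pass


# Prints {'type_info': {'module': '__main__', 'name': 'DummyAttribute'}} when run as a script,
# mirroring the nested structure asserted in test_extra_serializable_fields.
print(type_info_for(DummyAttribute()))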