", response)
)
def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]:
diff --git a/src/distilabel/steps/tasks/generate_embeddings.py b/src/distilabel/steps/tasks/generate_embeddings.py
index 39c17f016e..1b0df634c6 100644
--- a/src/distilabel/steps/tasks/generate_embeddings.py
+++ b/src/distilabel/steps/tasks/generate_embeddings.py
@@ -47,6 +47,33 @@ class GenerateEmbeddings(Step):
References:
- [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685)
+
+ Examples:
+
+ Generate sentence embeddings with an LLM:
+
+ ```python
+ from distilabel.steps.tasks import GenerateEmbeddings
+ from distilabel.llms.huggingface import TransformersLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ embedder = GenerateEmbeddings(
+ llm=TransformersLLM(
+ model="TaylorAI/bge-micro-v2",
+ model_kwargs={"is_decoder": True},
+ cuda_devices=[],
+ )
+ )
+ embedder.load()
+
+ result = next(
+ embedder.process(
+ [
+ {"text": "Hello, how are you?"},
+ ]
+ )
+ )
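+ # result (illustrative values only):
+ # [{'text': 'Hello, how are you?', 'embedding': [0.06, -0.23, ...], 'model_name': 'TaylorAI/bge-micro-v2'}]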
+ ```
"""
llm: LLM
diff --git a/src/distilabel/steps/tasks/genstruct.py b/src/distilabel/steps/tasks/genstruct.py
index 550e1220d5..1e80fcb429 100644
--- a/src/distilabel/steps/tasks/genstruct.py
+++ b/src/distilabel/steps/tasks/genstruct.py
@@ -67,6 +67,42 @@ class Genstruct(Task):
References:
- [Genstruct 7B by Nous Research](https://huggingface.co/NousResearch/Genstruct-7B)
- [Ada-Instruct: Adapting Instruction Generators for Complex Reasoning](https://arxiv.org/abs/2310.04484)
+
+ Examples:
+
+ Generate instructions from raw documents using the title and content:
+
+ ```python
+ from distilabel.steps.tasks import Genstruct
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ genstruct = Genstruct(
+ llm=InferenceEndpointsLLM(
+ model_id="NousResearch/Genstruct-7B",
+ ),
+ )
+
+ genstruct.load()
+
+ result = next(
+ genstruct.process(
+ [
+ {"title": "common instruction", "content": "content of the document"},
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'title': 'common instruction',
+ # 'content': 'content of the document',
+ # 'model_name': 'NousResearch/Genstruct-7B',
+ # 'user': 'An instruction',
+ # 'assistant': 'content of the document',
+ # }
+ # ]
+ ```
"""
_template: Union[Template, None] = PrivateAttr(...)
diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py
new file mode 100644
index 0000000000..0e91354274
--- /dev/null
+++ b/src/distilabel/steps/tasks/improving_text_embeddings.py
@@ -0,0 +1,941 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import re
+import sys
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Literal, Optional, Union
+
+if sys.version_info < (3, 9):
+ import importlib_resources
+else:
+ import importlib.resources as importlib_resources
+
+from jinja2 import Template
+from pydantic import Field, PrivateAttr
+from typing_extensions import override
+
+from distilabel.steps.tasks.base import GeneratorTask, Task
+from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.typing import GeneratorStepOutput
+
+
+# BASE CLASSES
+class _JSONFormatter(ABC):
+ """Abstract class that sets the `outputs` property and `format_output` method, assuming
+ that the output is a JSON string with the keys specified in the `keys` property. So on,
+ this class is intended to be used whenever we get a JSON string as the `LLM` output with
+ a set of `keys` we know are there.
+
+ Note:
+ At the moment this abstract class is only intended to be used for the tasks defined
+ below based on the output generated by those. Also note that this is not a replacement
+ for neither the `StructuredGeneration` task nor for the `structured_output` argument
+ of an `LLM` subclass.
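+
+ For illustration, a subclass defining `keys = ["input", "positive_document"]` would parse an
+ `LLM` output such as `'{"input": "A", "positive_document": "B"}'` into the Python dict
+ `{"input": "A", "positive_document": "B"}` via `format_output`.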
+ """
+
+ @property
+ @abstractmethod
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ ...
+
+ @property
+ def outputs(self) -> List[str]:
+ """Contains the output columns produced by the `process` method of the task. In this
+ case, it consists of the `keys` (i.e. the JSON keys) and the `model_name`.
+ """
+ return self.keys + ["model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """Method to parse the JSON output into a Python dictionary based on the `keys` property.
+
+ Args:
+ output: The JSON string output produced by the `LLM`.
+ input: The input dictionary that was used to generate the output.
+
+ Returns:
+ A Python dictionary with the parsed output based on the `keys` property.
+ """
+ if output is None:
+ return {key: None for key in self.keys}
+
+ def escape_backslashes_in_values(s):
+ # Regular expression to match the key-value pairs in the dictionary
+ pattern = re.compile(r'(".*?":\s*")(.*?)(",?)', re.DOTALL)
+
+ def replace_backslashes(match):
+ return (
+ match.group(1)
+ + re.sub(
+ r"(? None:
+ """Loads the Jinja2 template and sets the random seed."""
+ super().load()
+
+ random.seed(self.seed)
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "improving_text_embeddings"
+ / f"{self._template_name}.jinja2" # type: ignore
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ def inputs(self) -> List[str]:
+ """Contains the input columns expected by the `process` method of the task. In this
+ case, it consists of the `task`, ideally produced by a previous task, preferably
+ `EmbeddingTaskGenerator` (as per the original implementation)."""
+ return ["task"]
+
+
+class _EmbeddingDataGenerator(_JSONFormatter, GeneratorTask, ABC):
+ """Base class for the subtasks related to embedding data generation as presented in the
+ paper "Improving Text Embeddings with Large Language Models" that generate data without
+ an input i.e. `GeneratorStep` or `GeneratorTask`. This class includes a pre-defined `load`
+ method to load a Jinja2 template based on the `_template_name` private attribute (to be set
+ in each of the subclasses), assuming that the `prompt` property only expects the `task`, while
+ keeping the `format_input` as an abstract method to be implemented in the subclasses.
+
+ Attributes:
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+ _template: The Jinja2 template to be rendered within the `format_input` method with the
+ provided arguments.
+ _template_name: The name of the Jinja2 template file within the
+ `distilabel/steps/tasks/templates/improving_text_embeddings` directory.
+ """
+
+ seed: int = 42
+
+ _template: Union[Template, None] = PrivateAttr(...)
+ _template_name: str = PrivateAttr(...)
+
+ def load(self) -> None:
+ """Loads the Jinja2 template and sets the random seed."""
+ super().load()
+
+ random.seed(self.seed)
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "improving_text_embeddings"
+ / f"{self._template_name}.jinja2" # type: ignore
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ @abstractmethod
+ def prompt(self) -> ChatType:
+ """The prompt to be used for the generation step, ideally rendering the `_template`."""
+ ...
+
+ @override
+ def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore
+ """Method to run the `LLM` generation with the `prompt`, as well as formatting the
+ outputs accordingly for the task i.e. via the `_JSONFormatter` inheritance. So on, the
+ `LLM` ideally will be prompted to produce JSON content and then the `format_output`
+ method will parse it into a Python dictionary based on the `keys` property.
+
+ Args:
+ offset: The offset to start the generation from. Defaults to 0.
+
+ Yields:
+ The output rows and a boolean indicating if it's the last batch or not.
+ """
+ formatted_inputs = [self.prompt]
+ outputs = self.llm.generate(
+ inputs=formatted_inputs,
+ num_generations=self.num_generations,
+ **self.llm.generation_kwargs, # type: ignore
+ )
+
+ task_outputs = []
+ for input_outputs in outputs:
+ formatted_outputs = self._format_outputs(input_outputs) # type: ignore
+ for formatted_output in formatted_outputs:
+ task_outputs.append(
+ {
+ **formatted_output,
+ "model_name": self.llm.model_name,
+ }
+ )
+ yield task_outputs, True
+
+
+# IMPLEMENTED TASKS
+class EmbeddingTaskGenerator(GeneratorTask):
+ """Generate task descriptions for embedding-related tasks using an `LLM`.
+
+ `EmbeddingTaskGenerator` is a `GeneratorTask` that doesn't receive any input besides the
+ provided attributes, and generates task descriptions for embedding-related tasks using a
+ pre-defined prompt based on the `category` attribute. The `category` attribute should be
+ one of the following:
+
+ - `text-retrieval`: Generate task descriptions for text retrieval tasks.
+ - `text-matching-short`: Generate task descriptions for short text matching tasks.
+ - `text-matching-long`: Generate task descriptions for long text matching tasks.
+ - `text-classification`: Generate task descriptions for text classification tasks.
+
+ Attributes:
+ category: The category of the task to be generated, which can either be `text-retrieval`,
+ `text-matching-short`, `text-matching-long`, or `text-classification`.
+ flatten_tasks: Whether to flatten the tasks i.e. since a list of tasks is generated by the
+ `LLM`, this attribute indicates whether to flatten the list or not. Defaults to `False`,
+ meaning that running this task with `num_generations=1` will return a `distilabel.Distiset`
+ with one row only containing a list with around 20 tasks; otherwise, if set to `True`, it
+ will return a `distilabel.Distiset` with around 20 rows, each containing one task.
+
+ References:
+ - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
+
+ Examples:
+
+ Generate embedding tasks for text retrieval:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import EmbeddingTaskGenerator
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = EmbeddingTaskGenerator(
+ category="text-retrieval",
+ flatten_tasks=True,
+ llm=..., # LLM instance
+ )
+
+ ...
+
+ task >> ...
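+
+ # Illustrative output with `flatten_tasks=True`: one row per generated task
+ # description, each with a "task" and a "model_name" column.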
+ ```
+ """
+
+ category: Literal[
+ "text-retrieval",
+ "text-matching-short",
+ "text-matching-long",
+ "text-classification",
+ ]
+ flatten_tasks: bool = False
+
+ _template: Union[Template, None] = PrivateAttr(...)
+
+ def load(self) -> None:
+ """Loads the Jinja2 template."""
+ super().load()
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "improving_text_embeddings"
+ / "brainstorming"
+ / f"{self.category}.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ def prompt(self) -> ChatType: # type: ignore
+ """The prompt to be used in the `process` method, rendering the `_template` with the
+ provided args / attributes.
+ """
+ return [{"role": "user", "content": self._template.render().strip()}] # type: ignore
+
+ @override
+ def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore
+ """Method to run the `LLM` generation with the `prompt`, as well as formatting the
+ outputs accordingly for the task i.e. via the `_JSONFormatter` inheritance. So on, the
+ `LLM` ideally will be prompted to produce JSON content and then the `format_output`
+ method will parse it into a Python dictionary based on the `keys` property.
+
+ Args:
+ offset: The offset to start the generation from. Defaults to 0.
+
+ Yields:
+ The output rows and a boolean indicating if it's the last batch or not.
+ """
+ formatted_inputs = [self.prompt]
+ outputs = self.llm.generate(
+ inputs=formatted_inputs,
+ num_generations=self.num_generations,
+ **self.llm.generation_kwargs, # type: ignore
+ )
+
+ task_outputs = []
+ for input_outputs in outputs:
+ formatted_outputs = self._format_outputs(input_outputs) # type: ignore
+ for formatted_output in formatted_outputs:
+ if isinstance(formatted_output["tasks"], list) and self.flatten_tasks:
+ tasks = formatted_output.pop("tasks")
+ task_outputs.extend(
+ [
+ {
+ "task": task,
+ **formatted_output,
+ "model_name": self.llm.model_name,
+ }
+ for task in tasks
+ ]
+ )
+ else:
+ if self.flatten_tasks:
+ formatted_output["task"] = formatted_output.pop("tasks")
+ task_outputs.append(
+ {**formatted_output, "model_name": self.llm.model_name}
+ )
+ yield task_outputs, True
+
+ @property
+ def outputs(self) -> List[str]:
+ """Contains the output columns produced by the `process` method of the task. In this
+ case, it consists of the `tasks` or `task` (depending on the `flatten_tasks` attribute)
+ and the `model_name`.
+ """
+ return ["tasks" if not self.flatten_tasks else "task", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """Method to parse the JSON output into a Python dictionary based on the `keys` property.
+
+ Args:
+ output: The JSON string output produced by the `LLM`.
+ input: The input dictionary that was used to generate the output.
+
+ Returns:
+ A Python dictionary with the parsed output based on the `keys` property.
+ """
+ try:
+ if output is not None:
+ output = eval(output)
+ except Exception:
+ pass
+ return {"tasks": output}
+
+
+class GenerateTextRetrievalData(_EmbeddingDataGeneration):
+ """Generate text retrieval data with an `LLM` to later on train an embedding model.
+
+ `GenerateTextRetrievalData` is a `Task` that generates text retrieval data with an
+ `LLM` to later on train an embedding model. The task is based on the paper "Improving
+ Text Embeddings with Large Language Models" and the data is generated based on the
+ provided attributes, or randomly sampled if not provided.
+
+ Note:
+ Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True`
+ and `category="text-retrieval"`, so that the list of tasks generated by the `LLM` is
+ flattened and each row contains a single task for the text-retrieval category.
+
+ Attributes:
+ language: The language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ query_type: The type of query to be generated, which can be `extremely long-tail`, `long-tail`,
+ or `common`. Defaults to `None`, meaning that it will be randomly sampled.
+ query_length: The length of the query to be generated, which can be `less than 5 words`, `5 to 15 words`,
+ or `at least 10 words`. Defaults to `None`, meaning that it will be randomly sampled.
+ difficulty: The difficulty of the query to be generated, which can be `high school`, `college`, or `PhD`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`,
+ or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled.
+ num_words: The number of words in the query to be generated, which can be `50`, `100`, `200`, `300`, `400`, or `500`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+
+ References:
+ - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
+
+ Examples:
+
+ Generate synthetic text retrieval data for training embedding models:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextRetrievalData
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = EmbeddingTaskGenerator(
+ category="text-retrieval",
+ flatten_tasks=True,
+ llm=..., # LLM instance
+ )
+
+ generate = GenerateTextRetrievalData(
+ language="English",
+ query_type="common",
+ query_length="5 to 15 words",
+ difficulty="high school",
+ clarity="clear",
+ num_words=100,
+ llm=..., # LLM instance
+ )
+
+ task >> generate
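+
+ # Illustrative output columns for each generated row: "user_query",
+ # "positive_document", "hard_negative_document" and "model_name".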
+ ```
+ """
+
+ language: str = Field(
+ default="English",
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+
+ query_type: Optional[Literal["extremely long-tail", "long-tail", "common"]] = None
+ query_length: Optional[
+ Literal["less than 5 words", "5 to 15 words", "at least 10 words"]
+ ] = None
+ difficulty: Optional[Literal["high school", "college", "PhD"]] = None
+ clarity: Optional[
+ Literal["clear", "understandable with some effort", "ambiguous"]
+ ] = None
+ num_words: Optional[Literal[50, 100, 200, 300, 400, 500]] = None
+
+ _template_name: str = PrivateAttr(default="text-retrieval")
+
+ def format_input(self, input: Dict[str, Any]) -> ChatType:
+ """Method to format the input based on the `task` and the provided attributes, or just
+ randomly sampling those if not provided. This method will render the `_template` with
+ the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that
+ there's only one turn, being from the user with the content being the rendered `_template`.
+
+ Args:
+ input: The input dictionary containing the `task` to be used in the `_template`.
+
+ Returns:
+ A list with a single chat containing the user's message with the rendered `_template`.
+ """
+ return [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ task=input["task"],
+ language=self.language,
+ query_type=self.query_type
+ or random.choice(["extremely long-tail", "long-tail", "common"]),
+ query_length=self.query_length
+ or random.choice(
+ ["less than 5 words", "5 to 15 words", "at least 10 words"]
+ ),
+ difficulty=self.difficulty
+ or random.choice(["high school", "college", "PhD"]),
+ clarity=self.clarity
+ or random.choice(
+ ["clear", "understandable with some effort", "ambiguous"]
+ ),
+ num_words=self.num_words
+ or random.choice([50, 100, 200, 300, 400, 500]),
+ ).strip(),
+ }
+ ]
+
+ @property
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ return [
+ "user_query",
+ "positive_document",
+ "hard_negative_document",
+ ]
+
+
+class GenerateShortTextMatchingData(_EmbeddingDataGeneration):
+ """Generate short text matching data with an `LLM` to later on train an embedding model.
+
+ `GenerateShortTextMatchingData` is a `Task` that generates short text matching data with an
+ `LLM` to later on train an embedding model. The task is based on the paper "Improving
+ Text Embeddings with Large Language Models" and the data is generated based on the
+ provided attributes, or randomly sampled if not provided.
+
+ Note:
+ Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True`
+ and `category="text-matching-short"`, so that the list of tasks generated by the `LLM` is
+ flattened and each row contains a single task for the text-matching-short category.
+
+ Attributes:
+ language: The language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+ Note that in this task the `seed` has no effect since there are no sampling params.
+
+ References:
+ - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
+
+ Examples:
+
+ Generate synthetic short text matching data for training embedding models:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateShortTextMatchingData
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = EmbeddingTaskGenerator(
+ category="text-matching-short",
+ flatten_tasks=True,
+ llm=..., # LLM instance
+ )
+
+ generate = GenerateShortTextMatchingData(
+ language="English",
+ llm=..., # LLM instance
+ )
+
+ task >> generate
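+
+ # Illustrative output columns for each generated row: "input",
+ # "positive_document" and "model_name".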
+ ```
+ """
+
+ language: str = Field(
+ default="English",
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+
+ _template_name: str = PrivateAttr(default="short-text-matching")
+
+ def format_input(self, input: Dict[str, Any]) -> ChatType:
+ """Method to format the input based on the `task` and the provided attributes, or just
+ randomly sampling those if not provided. This method will render the `_template` with
+ the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that
+ there's only one turn, being from the user with the content being the rendered `_template`.
+
+ Args:
+ input: The input dictionary containing the `task` to be used in the `_template`.
+
+ Returns:
+ A list with a single chat containing the user's message with the rendered `_template`.
+ """
+ return [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ task=input["task"],
+ language=self.language,
+ ).strip(),
+ }
+ ]
+
+ @property
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ return ["input", "positive_document"]
+
+
+class GenerateLongTextMatchingData(_EmbeddingDataGeneration):
+ """Generate long text matching data with an `LLM` to later on train an embedding model.
+
+ `GenerateLongTextMatchingData` is a `Task` that generates long text matching data with an
+ `LLM` to later on train an embedding model. The task is based on the paper "Improving
+ Text Embeddings with Large Language Models" and the data is generated based on the
+ provided attributes, or randomly sampled if not provided.
+
+ Note:
+ Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True`
+ and `category="text-matching-long"`, so that the list of tasks generated by the `LLM` is
+ flattened and each row contains a single task for the text-matching-long category.
+
+ Attributes:
+ language: The language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+ Note that in this task the `seed` has no effect since there are no sampling params.
+
+ References:
+ - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
+
+ Examples:
+
+ Generate synthetic long text matching data for training embedding models:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateLongTextMatchingData
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = EmbeddingTaskGenerator(
+ category="text-matching-long",
+ flatten_tasks=True,
+ llm=..., # LLM instance
+ )
+
+ generate = GenerateLongTextMatchingData(
+ language="English",
+ llm=..., # LLM instance
+ )
+
+ task >> generate
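+
+ # Illustrative output columns for each generated row: "input",
+ # "positive_document" and "model_name".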
+ ```
+ """
+
+ language: str = Field(
+ default="English",
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+
+ _template_name: str = PrivateAttr(default="long-text-matching")
+
+ def format_input(self, input: Dict[str, Any]) -> ChatType:
+ """Method to format the input based on the `task` and the provided attributes, or just
+ randomly sampling those if not provided. This method will render the `_template` with
+ the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that
+ there's only one turn, being from the user with the content being the rendered `_template`.
+
+ Args:
+ input: The input dictionary containing the `task` to be used in the `_template`.
+
+ Returns:
+ A list with a single chat containing the user's message with the rendered `_template`.
+ """
+ return [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ task=input["task"],
+ language=self.language,
+ ).strip(),
+ }
+ ]
+
+ @property
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ return ["input", "positive_document"]
+
+
+class GenerateTextClassificationData(_EmbeddingDataGeneration):
+ """Generate text classification data with an `LLM` to later on train an embedding model.
+
+ `GenerateTextClassificationData` is a `Task` that generates text classification data with an
+ `LLM` to later on train an embedding model. The task is based on the paper "Improving
+ Text Embeddings with Large Language Models" and the data is generated based on the
+ provided attributes, or randomly sampled if not provided.
+
+ Note:
+ Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True`
+ and `category="text-classification"`, so that the list of tasks generated by the `LLM` is
+ flattened and each row contains a single task for the text-classification category.
+
+ Attributes:
+ language: The language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ difficulty: The difficulty of the query to be generated, which can be `high school`, `college`, or `PhD`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`,
+ or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled.
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+
+ References:
+ - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368)
+
+ Examples:
+
+ Generate synthetic text classification data for training embedding models:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextClassificationData
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = EmbeddingTaskGenerator(
+ category="text-classification",
+ flatten_tasks=True,
+ llm=..., # LLM instance
+ )
+
+ generate = GenerateTextClassificationData(
+ language="English",
+ difficulty="high school",
+ clarity="clear",
+ llm=..., # LLM instance
+ )
+
+ task >> generate
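+
+ # Illustrative output columns for each generated row: "input_text", "label",
+ # "misleading_label" and "model_name".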
+ ```
+ """
+
+ language: str = Field(
+ default="English",
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+
+ difficulty: Optional[Literal["high school", "college", "PhD"]] = None
+ clarity: Optional[
+ Literal["clear", "understandable with some effort", "ambiguous"]
+ ] = None
+
+ _template_name: str = PrivateAttr(default="text-classification")
+
+ def format_input(self, input: Dict[str, Any]) -> ChatType:
+ """Method to format the input based on the `task` and the provided attributes, or just
+ randomly sampling those if not provided. This method will render the `_template` with
+ the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that
+ there's only one turn, being from the user with the content being the rendered `_template`.
+
+ Args:
+ input: The input dictionary containing the `task` to be used in the `_template`.
+
+ Returns:
+ A list with a single chat containing the user's message with the rendered `_template`.
+ """
+ return [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ task=input["task"],
+ language=self.language,
+ difficulty=self.difficulty
+ or random.choice(["high school", "college", "PhD"]),
+ clarity=self.clarity
+ or random.choice(
+ ["clear", "understandable with some effort", "ambiguous"]
+ ),
+ ).strip(),
+ }
+ ]
+
+ @property
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ return ["input_text", "label", "misleading_label"]
+
+
+class MonolingualTripletGenerator(_EmbeddingDataGenerator):
+ """Generate monolingual triplets with an `LLM` to later on train an embedding model.
+
+ `MonolingualTripletGenerator` is a `GeneratorTask` that generates monolingual triplets with an
+ `LLM` to later on train an embedding model. The task is based on the paper "Improving
+ Text Embeddings with Large Language Models" and the data is generated based on the
+ provided attributes, or randomly sampled if not provided.
+
+ Attributes:
+ language: The language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ unit: The unit of the data to be generated, which can be `sentence`, `phrase`, or `passage`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ difficulty: The difficulty of the query to be generated, which can be `elementary school`, `high school`, or `college`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ high_score: The high score of the query to be generated, which can be `4`, `4.5`, or `5`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ low_score: The low score of the query to be generated, which can be `2.5`, `3`, or `3.5`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+
+ Examples:
+
+ Generate monolingual triplets for training embedding models:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import MonolingualTripletGenerator
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = MonolingualTripletGenerator(
+ language="English",
+ unit="sentence",
+ difficulty="elementary school",
+ high_score="4",
+ low_score="2.5",
+ llm=...,
+ )
+
+ ...
+
+ task >> ...
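+
+ # Illustrative output columns for each generated row: "S1", "S2", "S3" and "model_name".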
+ ```
+ """
+
+ language: str = Field(
+ default="English",
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+
+ unit: Optional[Literal["sentence", "phrase", "passage"]] = None
+ difficulty: Optional[Literal["elementary school", "high school", "college"]] = None
+ high_score: Optional[Literal["4", "4.5", "5"]] = None
+ low_score: Optional[Literal["2.5", "3", "3.5"]] = None
+
+ _template_name: str = PrivateAttr(default="monolingual-triplet")
+
+ @property
+ def prompt(self) -> ChatType:
+ """Contains the `prompt` to be used in the `process` method, rendering the `_template`; and
+ formatted as an OpenAI formatted chat i.e. a `ChatType`, assuming that there's only one turn,
+ being from the user with the content being the rendered `_template`.
+ """
+ return [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ language=self.language,
+ unit=self.unit or random.choice(["sentence", "phrase", "passage"]),
+ difficulty=self.difficulty
+ or random.choice(["elementary school", "high school", "college"]),
+ high_score=self.high_score or random.choice(["4", "4.5", "5"]),
+ low_score=self.low_score or random.choice(["2.5", "3", "3.5"]),
+ ).strip(),
+ }
+ ] # type: ignore
+
+ @property
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ return ["S1", "S2", "S3"]
+
+
+class BitextRetrievalGenerator(_EmbeddingDataGenerator):
+ """Generate bitext retrieval data with an `LLM` to later on train an embedding model.
+
+ `BitextRetrievalGenerator` is a `GeneratorTask` that generates bitext retrieval data with an
+ `LLM` to later on train an embedding model. The task is based on the paper "Improving
+ Text Embeddings with Large Language Models" and the data is generated based on the
+ provided attributes, or randomly sampled if not provided.
+
+ Attributes:
+ source_language: The source language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ target_language: The target language of the data to be generated, which can be any of the languages
+ retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf.
+ unit: The unit of the data to be generated, which can be `sentence`, `phrase`, or `passage`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ difficulty: The difficulty of the query to be generated, which can be `elementary school`, `high school`, or `college`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ high_score: The high score of the query to be generated, which can be `4`, `4.5`, or `5`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ low_score: The low score of the query to be generated, which can be `2.5`, `3`, or `3.5`.
+ Defaults to `None`, meaning that it will be randomly sampled.
+ seed: The random seed to be set in case there's any sampling within the `format_input` method.
+
+ Examples:
+
+ Generate bitext retrieval data for training embedding models:
+
+ ```python
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import BitextRetrievalGenerator
+
+ with Pipeline("my-pipeline") as pipeline:
+ task = BitextRetrievalGenerator(
+ source_language="English",
+ target_language="Spanish",
+ unit="sentence",
+ difficulty="elementary school",
+ high_score="4",
+ low_score="2.5",
+ llm=...,
+ )
+
+ ...
+
+ task >> ...
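+
+ # Illustrative output columns for each generated row: "S1", "S2", "S3" and "model_name".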
+ ```
+ """
+
+ source_language: str = Field(
+ default="English",
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+ target_language: str = Field(
+ default=...,
+ description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf",
+ )
+
+ unit: Optional[Literal["sentence", "phrase", "passage"]] = None
+ difficulty: Optional[Literal["elementary school", "high school", "college"]] = None
+ high_score: Optional[Literal["4", "4.5", "5"]] = None
+ low_score: Optional[Literal["2.5", "3", "3.5"]] = None
+
+ _template_name: str = PrivateAttr(default="bitext-retrieval")
+
+ @property
+ def prompt(self) -> ChatType:
+ """Contains the `prompt` to be used in the `process` method, rendering the `_template`; and
+ formatted as an OpenAI formatted chat i.e. a `ChatType`, assuming that there's only one turn,
+ being from the user with the content being the rendered `_template`.
+ """
+ return [
+ {
+ "role": "user",
+ "content": self._template.render( # type: ignore
+ source_language=self.source_language,
+ target_language=self.target_language,
+ unit=self.unit or random.choice(["sentence", "phrase", "passage"]),
+ difficulty=self.difficulty
+ or random.choice(["elementary school", "high school", "college"]),
+ high_score=self.high_score or random.choice(["4", "4.5", "5"]),
+ low_score=self.low_score or random.choice(["2.5", "3", "3.5"]),
+ ).strip(),
+ }
+ ] # type: ignore
+
+ @property
+ def keys(self) -> List[str]:
+ """Contains the `keys` that will be parsed from the `LLM` output into a Python dict."""
+ return ["S1", "S2", "S3"]
diff --git a/src/distilabel/steps/tasks/pair_rm.py b/src/distilabel/steps/tasks/pair_rm.py
index 3c1ecc7a56..be23a38699 100644
--- a/src/distilabel/steps/tasks/pair_rm.py
+++ b/src/distilabel/steps/tasks/pair_rm.py
@@ -49,6 +49,37 @@ class PairRM(Step):
Note:
This step differs to other tasks as there is a single implementation of this model
currently, and we will use a specific `LLM`.
+
+ Examples:
+
+ Rank LLM candidates:
+
+ ```python
+ from distilabel.steps.tasks import PairRM
+
+ # `PairRM` uses the `llm-blender/PairRM` model by default, so no `LLM` needs to be provided.
+ pair_rm = PairRM()
+
+ pair_rm.load()
+
+ result = next(
+ pair_rm.process(
+ [
+ {"input": "Hello, how are you?", "candidates": ["fine", "good", "bad"]},
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'input': 'Hello, how are you?',
+ # 'candidates': ['fine', 'good', 'bad'],
+ # 'ranks': [2, 1, 3],
+ # 'ranked_candidates': ['good', 'fine', 'bad'],
+ # 'model_name': 'llm-blender/PairRM',
+ # }
+ # ]
+ ```
"""
model: str = "llm-blender/PairRM"
diff --git a/src/distilabel/steps/tasks/prometheus_eval.py b/src/distilabel/steps/tasks/prometheus_eval.py
index 0edde308df..294f9f4d0e 100644
--- a/src/distilabel/steps/tasks/prometheus_eval.py
+++ b/src/distilabel/steps/tasks/prometheus_eval.py
@@ -135,6 +135,165 @@ class PrometheusEval(Task):
References:
- [Prometheus 2: An Open Source Language Model Specialized in Evaluating Other Language Models](https://arxiv.org/abs/2405.01535)
- [prometheus-eval: Evaluate your LLM's response with Prometheus 💯](https://github.com/prometheus-eval/prometheus-eval)
+
+ Examples:
+
+ Critique and evaluate LLM generation quality using Prometheus 2.0:
+
+ ```python
+ from distilabel.steps.tasks import PrometheusEval
+ from distilabel.llms import vLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ prometheus = PrometheusEval(
+ llm=vLLM(
+ model="prometheus-eval/prometheus-7b-v2.0",
+ chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ ),
+ mode="absolute",
+ rubric="factual-validity"
+ )
+
+ prometheus.load()
+
+ result = next(
+ prometheus.process(
+ [
+ {"instruction": "make something", "generation": "something done"},
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'make something',
+ # 'generation': 'something done',
+ # 'model_name': 'prometheus-eval/prometheus-7b-v2.0',
+ # 'feedback': 'the feedback',
+ # 'result': 6,
+ # }
+ # ]
+ ```
+
+ Critique for relative evaluation:
+
+ ```python
+ from distilabel.steps.tasks import PrometheusEval
+ from distilabel.llms import vLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ prometheus = PrometheusEval(
+ llm=vLLM(
+ model="prometheus-eval/prometheus-7b-v2.0",
+ chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ ),
+ mode="relative",
+ rubric="honesty"
+ )
+
+ prometheus.load()
+
+ result = next(
+ prometheus.process(
+ [
+ {"instruction": "make something", "generations": ["something done", "other thing"]},
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'make something',
+ # 'generations': ['something done', 'other thing'],
+ # 'model_name': 'prometheus-eval/prometheus-7b-v2.0',
+ # 'feedback': 'the feedback',
+ # 'result': 'something done',
+ # }
+ # ]
+ ```
+
+ Critique with a custom rubric:
+
+ ```python
+ from distilabel.steps.tasks import PrometheusEval
+ from distilabel.llms import vLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ prometheus = PrometheusEval(
+ llm=vLLM(
+ model="prometheus-eval/prometheus-7b-v2.0",
+ chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ ),
+ mode="absolute",
+ rubric="custom",
+ rubrics={
+ "custom": "[A]\nScore 1: A\nScore 2: B\nScore 3: C\nScore 4: D\nScore 5: E"
+ }
+ )
+
+ prometheus.load()
+
+ result = next(
+ prometheus.process(
+ [
+ {"instruction": "make something", "generation": "something done"},
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'make something',
+ # 'generation': 'something done',
+ # 'model_name': 'prometheus-eval/prometheus-7b-v2.0',
+ # 'feedback': 'the feedback',
+ # 'result': 6,
+ # }
+ # ]
+ ```
+
+ Critique using a reference answer:
+
+ ```python
+ from distilabel.steps.tasks import PrometheusEval
+ from distilabel.llms import vLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ prometheus = PrometheusEval(
+ llm=vLLM(
+ model="prometheus-eval/prometheus-7b-v2.0",
+ chat_template="[INST] {{ messages[0]\"content\" }}\n{{ messages[1]\"content\" }}[/INST]",
+ ),
+ mode="absolute",
+ rubric="helpfulness",
+ reference=True,
+ )
+
+ prometheus.load()
+
+ result = next(
+ prometheus.process(
+ [
+ {
+ "instruction": "make something",
+ "generation": "something done",
+ "reference": "this is a reference answer",
+ },
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'make something',
+ # 'generation': 'something done',
+ # 'reference': 'this is a reference answer',
+ # 'model_name': 'prometheus-eval/prometheus-7b-v2.0',
+ # 'feedback': 'the feedback',
+ # 'result': 6,
+ # }
+ # ]
+ ```
"""
mode: Literal["absolute", "relative"]
@@ -202,7 +361,7 @@ def inputs(self) -> List[str]:
if self.reference:
return ["instruction", "generation", "reference"]
return ["instruction", "generation"]
- else: # self.mode == "relative"
+ else:
if self.reference:
return ["instruction", "generations", "reference"]
return ["instruction", "generations"]
diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py
index f805c91e38..a93c2a399a 100644
--- a/src/distilabel/steps/tasks/quality_scorer.py
+++ b/src/distilabel/steps/tasks/quality_scorer.py
@@ -59,6 +59,43 @@ class QualityScorer(Task):
References:
- [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
+
+ Examples:
+
+ Evaluate the quality of responses to a given instruction:
+
+ ```python
+ from distilabel.steps.tasks import QualityScorer
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ scorer = QualityScorer(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ )
+
+ scorer.load()
+
+ result = next(
+ scorer.process(
+ [
+ {
+ "instruction": "instruction",
+ "responses": ["good response", "weird response", "bad response"]
+ }
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'instruction',
+ # 'responses': ['good response', 'weird response', 'bad response'],
+ # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2',
+ # 'scores': [5, 3, 1],
+ # }
+ # ]
+ ```
"""
_template: Union[Template, None] = PrivateAttr(...)
diff --git a/src/distilabel/steps/tasks/self_instruct.py b/src/distilabel/steps/tasks/self_instruct.py
index 6bb673b8cf..34d3ffee06 100644
--- a/src/distilabel/steps/tasks/self_instruct.py
+++ b/src/distilabel/steps/tasks/self_instruct.py
@@ -60,6 +60,34 @@ class SelfInstruct(Task):
Reference:
- [`Self-Instruct: Aligning Language Models with Self-Generated Instructions`](https://arxiv.org/abs/2212.10560)
+
+ Examples:
+
+ Generate instructions based on a given input:
+
+ ```python
+ from distilabel.steps.tasks import SelfInstruct
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ self_instruct = SelfInstruct(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ ),
+ num_instructions=5, # This is the default value
+ )
+
+ self_instruct.load()
+
+ result = next(self_instruct.process([{"input": "instruction"}]))
+ # result
+ # [
+ # {
+ # 'input': 'instruction',
+ # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2',
+ # 'instructions': ["instruction 1", "instruction 2", "instruction 3", "instruction 4", "instruction 5"],
+ # }
+ # ]
+ ```
"""
num_instructions: int = 5
diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py
new file mode 100644
index 0000000000..b1ad50f5e1
--- /dev/null
+++ b/src/distilabel/steps/tasks/sentence_transformers.py
@@ -0,0 +1,291 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import sys
+from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Optional, Union
+
+from jinja2 import Template
+
+from distilabel.steps.tasks.base import Task
+
+if sys.version_info < (3, 9):
+ import importlib_resources
+else:
+ import importlib.resources as importlib_resources
+
+if TYPE_CHECKING:
+ from distilabel.steps.tasks.typing import ChatType
+
+GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"]
+
+POSITIVE_NEGATIVE_PAIR_REGEX = re.compile(
+ r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
+ re.DOTALL,
+)
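+# For illustration, an LLM output such as
+#   "## Positive\nA rephrased sentence.\n\n## Negative\nAn unrelated sentence."
+# is captured by this regex as the groups ("A rephrased sentence.", "An unrelated sentence.").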
+
+GENERATION_ACTION_SENTENCES: Final[Dict[GenerationAction, str]] = {
+ "paraphrase": "paraphrase",
+ "semantically-similar": "be semantically similar to",
+ "query": "be a query for",
+ "answer": "be an answer for",
+}
+
+POSITIVE_SYSTEM_PROMPT: str = (
+ "Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
+ " sentence has to {action_sentence} the anchor sentence. You must output only one new"
+ " section: `## Positive`."
+)
+
+POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
+ "Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
+ " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
+ " sentence can use similar words but must not be related to the anchor sentence. You"
+ " must output only two new sections: `## Positive` and `## Negative`."
+)
+
+CONTEXT_INTRO: Final[str] = " Take into account the context given."
+
+
+class GenerateSentencePair(Task):
+ """Generate a positive and negative (optionally) sentences given an anchor sentence.
+
+ `GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
+ a positive sentence related to the anchor and optionally a negative sentence unrelated
+ to the anchor. Optionally, you can give a context to guide the LLM towards more specific
+ behavior. This task is useful to generate training datasets for training embeddings
+ models.
+
+ Attributes:
+ triplet: a flag to indicate if the task should generate a triplet of sentences
+ (anchor, positive, negative). Defaults to `False`.
+ action: the action to perform to generate the positive sentence.
+ context: the context to use for the generation. Can be helpful to guide the LLM
+ towards more specific behavior. Not used by default.
+
+ Input columns:
+ - anchor (`str`): The anchor sentence to generate the positive and negative sentences.
+
+ Output columns:
+ - positive (`str`): The positive sentence related to the `anchor`.
+ - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`.
+ - model_name (`str`): The name of the model that was used to generate the sentences.
+
+ Categories:
+ - embedding
+
+ Examples:
+
+ Paraphrasing:
+
+ ```python
+ from distilabel.steps.tasks import GenerateSentencePair
+ from distilabel.llms import InferenceEndpointsLLM
+
+ generate_sentence_pair = GenerateSentencePair(
+ triplet=True, # `False` to generate only positive
+ action="paraphrase",
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ input_batch_size=10,
+ )
+
+ generate_sentence_pair.load()
+
+ result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
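+ # Illustrative: `process` yields batches of rows, each one with the "anchor", the generated
+ # "positive" and "negative" (since `triplet=True`) sentences, and the "model_name".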
+ ```
+
+ Generating semantically similar sentences:
+
+ ```python
+ from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.steps.tasks import GenerateSentencePair
+
+ generate_sentence_pair = GenerateSentencePair(
+ triplet=True, # `False` to generate only positive
+ action="semantically-similar",
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ input_batch_size=10,
+ )
+
+ generate_sentence_pair.load()
+
+ result = generate_sentence_pair.process([{"anchor": "How does 3D printing work?"}])
+ ```
+
+ Generating queries:
+
+ ```python
+ from distilabel.steps.tasks import GenerateSentencePair
+ from distilabel.llms import InferenceEndpointsLLM
+
+ generate_sentence_pair = GenerateSentencePair(
+ triplet=True, # `False` to generate only positive
+ action="query",
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ input_batch_size=10,
+ )
+
+ generate_sentence_pair.load()
+
+ result = generate_sentence_pair.process([{"anchor": "Argilla is an open-source data curation platform for LLMs. Using Argilla, ..."}])
+ ```
+
+ Generating answers:
+
+ ```python
+ from distilabel.steps.tasks import GenerateSentencePair
+ from distilabel.llms import InferenceEndpointsLLM
+
+ generate_sentence_pair = GenerateSentencePair(
+ triplet=True, # `False` to generate only positive
+ action="answer",
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ input_batch_size=10,
+ )
+
+ generate_sentence_pair.load()
+
+ result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
+ ```
+
+ Generating queries with context (**applies to every action**):
+
+ ```python
+ from distilabel.steps.tasks import GenerateSentencePair
+ from distilabel.llms import InferenceEndpointsLLM
+
+ generate_sentence_pair = GenerateSentencePair(
+ triplet=True, # `False` to generate only positive
+ action="query",
+ context="Argilla is an open-source data curation platform for LLMs.",
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ input_batch_size=10,
+ )
+
+ generate_sentence_pair.load()
+
+ result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
+ ```
+ """
+
+ triplet: bool = False
+ action: GenerationAction
+ context: str = ""
+
+ def load(self) -> None:
+ """Loads the Jinja2 template."""
+ super().load()
+
+ _path = str(
+ importlib_resources.files("distilabel")
+ / "steps"
+ / "tasks"
+ / "templates"
+ / "generate-sentence-pair.jinja2"
+ )
+
+ self._template = Template(open(_path).read())
+
+ @property
+ def inputs(self) -> List[str]:
+ """The inputs for the task is the `anchor` sentence."""
+ return ["anchor"]
+
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ """The inputs are formatted as a `ChatType`, with a system prompt describing the
+ task of generating a positive and, optionally, a negative sentence for the anchor sentence. The
+ anchor is provided as the first user interaction in the conversation.
+
+ Args:
+ input: The input containing the `anchor` sentence.
+
+ Returns:
+ A list of dictionaries containing the system and user interactions.
+ """
+ action_sentence = GENERATION_ACTION_SENTENCES[self.action]
+ system_prompt = (
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
+ ).format(
+ action_sentence=action_sentence,
+ context=CONTEXT_INTRO if self.context else "",
+ )
+
+ return [
+ {"role": "system", "content": system_prompt},
+ {
+ "role": "user",
+ "content": self._template.render(
+ anchor=input["anchor"],
+ context=self.context if self.context else None,
+ ),
+ },
+ ]
+
+ @property
+ def outputs(self) -> List[str]:
+ """The outputs for the task are the `positive` and `negative` sentences, as well
+ as the `model_name` used to generate the sentences."""
+ columns = ["positive", "negative"] if self.triplet else ["positive"]
+ columns += ["model_name"]
+ return columns
+
+ def format_output(
+ self, output: Union[str, None], input: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+ """Formats the output of the LLM, to extract the `positive` and `negative` sentences
+ generated. If the output is `None` or the regex doesn't match, then the outputs
+ will be set to `None` as well.
+
+ Args:
+ output: The output of the LLM.
+ input: The input used to generate the output.
+
+ Returns:
+ The formatted output containing the `positive` and `negative` sentences.
+ """
+ if output is None:
+ return {"positive": None, "negative": None}
+
+ match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
+ if match is None:
+ formatted_output = {"positive": None}
+ if self.triplet:
+ formatted_output["negative"] = None
+ return formatted_output
+
+ groups = match.groups()
+ if self.triplet:
+ return {
+ "positive": groups[0].strip(),
+ "negative": groups[1].strip()
+ if len(groups) > 1 and groups[1] is not None
+ else None,
+ }
+
+ return {"positive": groups[0].strip()}
diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py
new file mode 100644
index 0000000000..240cd44698
--- /dev/null
+++ b/src/distilabel/steps/tasks/structured_generation.py
@@ -0,0 +1,187 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Any, Dict, List, Union
+
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.typing import StructuredInput
+
+
+class StructuredGeneration(Task):
+ """Generate structured content for a given `instruction` using an `LLM`.
+
+ `StructuredGeneration` is a pre-defined task that defines the `instruction` and the `structured_output`
+ as the inputs, and `generation` as the output. This task is used to generate structured content based on
+ the input instruction and following the schema provided within the `structured_output` column of each
+ `instruction`. The `model_name` is also returned as part of the output in order to enrich it.
+
+ Attributes:
+ use_system_prompt: Whether to use the system prompt in the generation. Defaults to `False`;
+ if set to `True` and the column `system_prompt` is defined within the input batch, then
+ the `system_prompt` will be used, otherwise, it will be ignored.
+
+ Input columns:
+ - instruction (`str`): The instruction to generate structured content from.
+ - structured_output (`Dict[str, Any]`): The structured_output to generate structured content from. It should be a
+ Python dictionary with the keys `format` and `schema`, where `format` should be one of `json` or
+ `regex`, and the `schema` should be either the JSON schema or the regex pattern, respectively.
+
+ Output columns:
+ - generation (`str`): The generated text matching the provided schema, if possible.
+ - model_name (`str`): The name of the model used to generate the text.
+
+ Categories:
+ - outlines
+ - structured-generation
+
+ Examples:
+
+ Generate structured output from a JSON schema:
+
+ ```python
+ from distilabel.steps.tasks import StructuredGeneration
+ from distilabel.llms import InferenceEndpointsLLM
+
+ structured_gen = StructuredGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ )
+
+ structured_gen.load()
+
+ result = next(
+ structured_gen.process(
+ [
+ {
+ "instruction": "Create an RPG character",
+ "structured_output": {
+ "type": "json",
+ "value": {
+ "properties": {
+ "name": {
+ "title": "Name",
+ "type": "string"
+ },
+ "description": {
+ "title": "Description",
+ "type": "string"
+ },
+ "role": {
+ "title": "Role",
+ "type": "string"
+ },
+ "weapon": {
+ "title": "Weapon",
+ "type": "string"
+ }
+ },
+ "required": [
+ "name",
+ "description",
+ "role",
+ "weapon"
+ ],
+ "title": "Character",
+ "type": "object"
+ }
+ },
+ }
+ ]
+ )
+ )
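+ # Illustrative: each output row adds a "generation" column with the raw string produced
+ # by the LLM (ideally matching the provided schema) and the "model_name".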
+ ```
+
+ Generate structured output from a regex pattern:
+
+ ```python
+ from distilabel.steps.tasks import StructuredGeneration
+ from distilabel.llms import InferenceEndpointsLLM
+
+ structured_gen = StructuredGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+ ),
+ )
+
+ structured_gen.load()
+
+ result = next(
+ structured_gen.process(
+ [
+ {
+ "instruction": "What's the weather like today in Seattle in Celsius degrees?",
+ "structured_output": {
+ "type": "regex",
+ "value": r"(\\d{1,2})°C"
+ },
+
+ }
+ ]
+ )
+ )
+ ```
+ """
+
+ use_system_prompt: bool = False
+
+ @property
+ def inputs(self) -> List[str]:
+ """The input for the task are the `instruction` and the `structured_output`.
+ Optionally, if the `use_system_prompt` flag is set to True, then the
+ `system_prompt` will be used too."""
+ columns = ["instruction", "structured_output"]
+ if self.use_system_prompt:
+ columns = ["system_prompt"] + columns
+ return columns
+
+ def format_input(self, input: Dict[str, Any]) -> StructuredInput:
+ """The input is formatted as a `ChatType` assuming that the instruction
+ is the first interaction from the user within a conversation."""
+ if not isinstance(input["instruction"], str):
+ raise ValueError(
+ f"Input `instruction` must be a string. Got: {input['instruction']}."
+ )
+
+ messages = [{"role": "user", "content": input["instruction"]}]
+ if self.use_system_prompt:
+ if "system_prompt" in input:
+ messages.insert(
+ 0, {"role": "system", "content": input["system_prompt"]}
+ )
+ else:
+ warnings.warn(
+ "`use_system_prompt` is set to `True`, but no `system_prompt` in input batch, so it will be ignored.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ return (messages, input.get("structured_output", None)) # type: ignore
+
+ @property
+ def outputs(self) -> List[str]:
+ """The output for the task is the `generation` and the `model_name`."""
+ return ["generation", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """The output is formatted as a dictionary with the `generation`. The `model_name`
+ will be automatically included within the `process` method of `Task`. Note that even
+ if the `structured_output` is defined to produce a JSON schema, this method will return the raw
+ output i.e. a string without any parsing."""
+ return {"generation": output}
diff --git a/src/distilabel/steps/tasks/structured_outputs/instructor.py b/src/distilabel/steps/tasks/structured_outputs/instructor.py
new file mode 100644
index 0000000000..94ab1097e5
--- /dev/null
+++ b/src/distilabel/steps/tasks/structured_outputs/instructor.py
@@ -0,0 +1,124 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+)
+
+from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:
+ import instructor
+ from anthropic import AsyncAnthropic
+ from cohere import AsyncClient as AsyncCohere
+ from groq import AsyncGroq
+ from mistralai.async_client import MistralAsyncClient
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
+
+
+InstructorFrameworks = Literal[
+ "openai", "azure_openai", "anthropic", "cohere", "groq", "litellm", "mistral"
+]
+"""Available frameworks for the structured output configuration with `instructor`. """
+
+InstructorAvailableClients: TypeAlias = Union[
+ "AsyncAnthropic",
+ "AsyncAzureOpenAI",
+ "AsyncCohere",
+ "AsyncGroq",
+ "AsyncOpenAI",
+ "MistralAsyncClient",
+]
+"""Available clients that can be wrapped with `instructor`. """
+
+
+def _client_patcher(framework: InstructorFrameworks) -> Tuple[Callable, str]:
+ """Helper function to return the appropriate instructor client for the given framework.
+
+ Args:
+ framework: The framework to use for the instructor client.
+
+ Raises:
+ ValueError: If the framework is not one of the available frameworks.
+
+ Returns:
+ Tuple of Callable and string, with the builder of the client patch and the
+ default mode to use.
+ """
+ import instructor
+
+ if framework in {"openai", "azure_openai"}:
+ patch = instructor.from_openai, instructor.Mode.TOOLS
+ elif framework == "anthropic":
+ patch = instructor.from_anthropic, instructor.Mode.ANTHROPIC_JSON
+ elif framework == "litellm":
+ patch = instructor.from_litellm, instructor.Mode.TOOLS
+ elif framework == "mistral":
+ patch = instructor.from_mistral, instructor.Mode.MISTRAL_TOOLS
+ elif framework == "cohere":
+ patch = instructor.from_cohere, instructor.Mode.COHERE_TOOLS
+ elif framework == "groq":
+ patch = instructor.from_groq, instructor.Mode.TOOLS
+ else:
+ raise ValueError(
+ f"Invalid framework '{framework}'. Must be one of {get_args(InstructorFrameworks)}"
+ )
+
+ return patch
+
+
+def prepare_instructor(
+ client: InstructorAvailableClients,
+ mode: Optional["instructor.Mode"] = None,
+ framework: Optional[InstructorFrameworks] = None,
+) -> "instructor.AsyncInstructor":
+ """Wraps the given client with the instructor client for the given framework.
+
+ Args:
+ client: The client to wrap with the instructor client, corresponds to the internal
+ client we wrap on `LLM`, and one of the implemented in `instructor`.
+ mode: One of the `instructor.Mode` values. Defaults to None.
+ framework: The framework corresponding to the client. Defaults to None.
+
+ Raises:
+ ImportError: If `instructor` is not installed.
+ ValueError: If the mode is not one of the available modes.
+
+ Returns:
+ patched_client: The instructor wrapping the original client to be used for
+ structured generation.
+ """
+ if not importlib.util.find_spec("instructor"):
+ raise ImportError(
+ "`instructor` is not installed. Please install it using `pip install instructor`."
+ )
+ import instructor
+
+ builder, default_mode = _client_patcher(framework)
+
+ mode = mode or default_mode
+ if mode.value not in [m.value for m in instructor.mode.Mode]:
+ raise ValueError(
+ f"Invalid mode '{mode}'. Must be one of {[m.value for m in instructor.mode.Mode]}"
+ )
+
+ patched_client: instructor.AsyncInstructor = builder(client, mode=mode)
+
+ return patched_client
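As a usage reference for the helper above, a hedged sketch of wrapping an async OpenAI client (assumes `openai` and `instructor` are installed; the API key is a placeholder):

```python
from openai import AsyncOpenAI

from distilabel.steps.tasks.structured_outputs.instructor import prepare_instructor

# Wrap the raw async client; for the "openai" framework the default mode is
# `instructor.Mode.TOOLS`, as returned by `_client_patcher`.
patched_client = prepare_instructor(
    AsyncOpenAI(api_key="sk-placeholder"),  # placeholder key, not a real credential
    framework="openai",
)
# `patched_client` is an `instructor.AsyncInstructor`, so `chat.completions.create`
# now accepts a `response_model` argument for structured generation.
```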
diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py
index 087f4913bc..d726b5e4f5 100644
--- a/src/distilabel/steps/tasks/structured_outputs/outlines.py
+++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -19,55 +19,27 @@
Any,
Callable,
Dict,
- List,
Literal,
- Optional,
Tuple,
Type,
- TypedDict,
Union,
get_args,
)
from pydantic import BaseModel
+from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict
+from distilabel.steps.tasks.typing import StructuredOutputType
+
Frameworks = Literal["transformers", "llamacpp", "vllm"]
"""Available frameworks for the structured output configuration. """
-class StructuredOutputType(TypedDict):
- """TypedDict to represent the structured output configuration from outlines."""
-
- format: Literal["json", "regex"]
- """One of "json" or "regex"."""
- schema: Union[str, Type[BaseModel]]
- """The schema to use for the structured output. If "json", it
- can be a pydantic.BaseModel class, or the schema as a string,
- as obtained from `model_to_schema(BaseModel)`, if "regex", it
- should be a regex pattern as a string.
- """
- whitespace_pattern: Optional[Union[str, List[str]]]
- """If "json" corresponds to a string or a list of
- strings with a pattern (doesn't impact string literals).
- For example, to allow only a single space or newline with
- `whitespace_pattern=r"[\n ]?"`
- """
-
-
def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:
"""Helper function to return a string representation of the schema from a `pydantic.BaseModel` class."""
return json.dumps(schema.model_json_schema())
-def _schema_as_dict(schema: Union[str, Type[BaseModel]]) -> Dict[str, Any]:
- """Helper function to obtain the schema and simplify serialization."""
- if type(schema) == type(BaseModel):
- return schema.model_json_schema()
- elif isinstance(schema, str):
- return json.loads(schema)
- return schema
-
-
def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
"""Helper function to return the appropriate logits processor for the given framework."""
if framework == "transformers":
@@ -137,7 +109,7 @@ def prepare_guided_output(
llm,
whitespace_pattern=structured_output.get("whitespace_pattern"),
),
- "schema": _schema_as_dict(schema),
+ "schema": schema_as_dict(schema),
}
if format == "regex":
diff --git a/src/distilabel/steps/tasks/structured_outputs/utils.py b/src/distilabel/steps/tasks/structured_outputs/utils.py
new file mode 100644
index 0000000000..8bcebcb819
--- /dev/null
+++ b/src/distilabel/steps/tasks/structured_outputs/utils.py
@@ -0,0 +1,157 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel, Field, create_model
+
+
+def schema_as_dict(schema: Union[str, Type[BaseModel]]) -> Dict[str, Any]:
+ """Helper function to obtain the schema and simplify serialization."""
+ if type(schema) == type(BaseModel):
+ return schema.model_json_schema()
+ elif isinstance(schema, str):
+ return json.loads(schema)
+ return schema
+
+
+# NOTE: The following functions were copied from:
+# https://github.com/pydantic/pydantic/issues/643#issuecomment-1999755873
+# and slightly modified to work with nested models.
+# It would be nice to find the original source of this code to give credit.
+# Another option would be to work with this library: https://github.com/c32168/dyntamic
+
+
+def json_schema_to_model(json_schema: Dict[str, Any]) -> Type[BaseModel]:
+ """Converts a JSON schema to a `pydantic.BaseModel` class.
+
+ Args:
+ json_schema: The JSON schema to convert.
+
+ Returns:
+ A `pydantic.BaseModel` class.
+ """
+
+ # Extract the model name from the schema title.
+ model_name = json_schema.get("title")
+ if defs := json_schema.get("$defs", None):
+ # This is done to grab the content of nested classes that need to dereference
+ # the objects (those should be in a higher level).
+ pass
+
+ # Extract the field definitions from the schema properties.
+ field_definitions = {
+ name: json_schema_to_pydantic_field(
+ name, prop, json_schema.get("required", []), defs=defs
+ )
+ for name, prop in json_schema.get("properties", {}).items()
+ }
+
+ # Create the BaseModel class using create_model().
+ return create_model(model_name, **field_definitions)
+
+
+def json_schema_to_pydantic_field(
+ name: str,
+ json_schema: Dict[str, Any],
+ required: List[str],
+ defs: Optional[Dict[str, Any]] = None,
+) -> Any:
+ """Converts a JSON schema property to a `pydantic.Field`.
+
+ Args:
+ name: The field name.
+ json_schema: The JSON schema property.
+ required: The list of required fields.
+ defs: The definitions of the JSON schema. It's used to dereference nested classes,
+ so we can grab the original definition from the json schema (it won't
+ work out of the box with just the reference).
+
+ Returns:
+ A `pydantic.Field`.
+ """
+
+    # NOTE(plaguss): This needs more testing. Nested classes need extra work to be converted:
+    # if we pass a reference to another class directly it will crash, so we have to find the
+    # original definition in `defs` and insert it here.
+ # This takes into account single items referred to other classes
+ if ref := json_schema.get("$ref"):
+ json_schema = defs.get(ref.split("/")[-1])
+
+ # This takes into account lists of items referred to other classes
+ if "items" in json_schema and (ref := json_schema["items"].get("$ref")):
+ json_schema["items"] = defs.get(ref.split("/")[-1])
+
+ # Get the field type.
+ type_ = json_schema_to_pydantic_type(json_schema)
+
+ # Get the field description.
+ description = json_schema.get("description")
+
+ # Get the field examples.
+ examples = json_schema.get("examples")
+
+ # Create a Field object with the type, description, and examples.
+ # The "required" flag will be set later when creating the model.
+ return (
+ type_,
+ Field(
+ description=description,
+ examples=examples,
+ default=... if name in required else None,
+ ),
+ )
+
+
+def json_schema_to_pydantic_type(json_schema: Dict[str, Any]) -> Any:
+ """Converts a JSON schema type to a Pydantic type.
+
+ Args:
+ json_schema: The JSON schema to convert.
+
+ Returns:
+ A Pydantic type.
+ """
+ type_ = json_schema.get("type")
+
+ if type_ == "string":
+ type_val = str
+ elif type_ == "integer":
+ type_val = int
+ elif type_ == "number":
+ type_val = float
+ elif type_ == "boolean":
+ type_val = bool
+ elif type_ == "array":
+ items_schema = json_schema.get("items")
+ if items_schema:
+ item_type = json_schema_to_pydantic_type(items_schema)
+ type_val = List[item_type]
+ else:
+ type_val = List
+ elif type_ == "object":
+ # Handle nested models.
+ properties = json_schema.get("properties")
+ if properties:
+ nested_model = json_schema_to_model(json_schema)
+ type_val = nested_model
+ else:
+ type_val = Dict
+ elif type_ == "null":
+ type_val = Optional[Any] # Use Optional[Any] for nullable fields
+ else:
+ raise ValueError(f"Unsupported JSON schema type: {type_}")
+
+ return type_val
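A minimal round-trip sketch for `json_schema_to_model` on a flat schema without `$defs` (assumes `pydantic` v2; the `Character` model is illustrative):

```python
from pydantic import BaseModel

from distilabel.steps.tasks.structured_outputs.utils import json_schema_to_model


class Character(BaseModel):
    name: str
    level: int


# Rebuild a dynamic model from the JSON schema and instantiate it.
Rebuilt = json_schema_to_model(Character.model_json_schema())
print(Rebuilt(name="Kael", level=3))  # name='Kael' level=3
```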
diff --git a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
new file mode 100644
index 0000000000..cac188e101
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
@@ -0,0 +1,11 @@
+{% if context is not none -%}
+## Context
+
+{{ context }}
+
+{% endif -%}
+
+## Anchor
+
+{{ anchor }}
+
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2
new file mode 100644
index 0000000000..1cf238015f
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2
@@ -0,0 +1,13 @@
+Write a {{ unit }} triple with one {{ unit }} in {{ source_language }} and two {{ unit }}s in {{ target_language }} with varying translation qualities in JSON format.
+
+The triple is denoted as ("S1", "S2", "S3"). The translation quality score ranges from 1 to 5, with higher scores indicating better quality.
+
+Please adhere to the following guidelines:
+ - The values of "S1" is a string in {{ source_language }}, the value of "S2" and "S3" are strings in {{ target_language }}.
+ - There should be some word overlaps between "S2" and "S3".
+ - The translation quality score of "S2" with respect to "S1" should be {{ high_score }}.
+ - The translation quality score of "S3" with respect to "S1" should be {{ low_score }}.
+ - "S3" should be grammatical and fluent, but contain some keyword or number translation errors, or miss some information, or contain some redundant information.
+ - "S1" requires {{ difficulty }} level education to understand and should be diverse in terms of topic and length.
+
+Your output must always be a JSON object only with three keys "S1", "S2" and "S3", do not explain yourself or output anything else. Be creative!
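To see what one of these prompt templates produces, a rendering sketch with `jinja2` (the sample values are illustrative, and the resource path is assumed from the file layout in this diff):

```python
import importlib.resources as importlib_resources

from jinja2 import Template

# Locate the packaged template (path assumed from this diff's layout).
template_file = (
    importlib_resources.files("distilabel")
    / "steps" / "tasks" / "templates" / "improving_text_embeddings" / "bitext-retrieval.jinja2"
)

prompt = Template(template_file.read_text()).render(
    unit="sentence",
    source_language="English",
    target_language="Spanish",
    high_score=4,
    low_score=2,
    difficulty="college",
)
print(prompt)
```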
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2
new file mode 100644
index 0000000000..3501b9332d
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2
@@ -0,0 +1,6 @@
+Brainstorm a list of potentially useful text classification tasks.
+
+Please adhere to the following guidelines:
+ - Tasks should cover a diverse range of domains and task types.
+
+Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2
new file mode 100644
index 0000000000..0090ef2af4
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2
@@ -0,0 +1,7 @@
+Brainstorm a list of text matching tasks where the queries are long documents.
+
+Here are a few examples:
+ - Given a document that supports a debatable argument, find another document that contains opposite arguments.
+ - Provided a lengthy business proposal, retrieve competitive business strategies in the same industry.
+
+Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct task in one sentence. Do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2
new file mode 100644
index 0000000000..cf42fddae5
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2
@@ -0,0 +1,8 @@
+Brainstorm a list of text matching tasks where both the queries and the groundtruth documents are very short (one or two sentences, even a short phrase).
+
+Here are a few examples:
+ - Given a scientific paper title, retrieve the title of papers that cite the given paper.
+ - Match a word with its definition.
+ - Provided a notable person's name, identify their occupation or achievement.
+
+Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct task in one sentence. Do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2
new file mode 100644
index 0000000000..464ed0e763
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2
@@ -0,0 +1,11 @@
+Brainstorm a list of potentially useful text retrieval tasks.
+
+Here are a few examples for your reference:
+ - Provided a scientific claim as query, retrieve documents that help verify or refute the claim.
+ - Search for documents that answer a FAQ-style query on children's nutrition.
+
+Please adhere to the following guidelines:
+ - Specify what the query is, and what the desired documents are.
+ - Each retrieval task should cover a wide range of queries, and should not be too specific.
+
+Your output should always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct retrieval task in one sentence. Do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2
new file mode 100644
index 0000000000..cd8bf1922a
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2
@@ -0,0 +1,12 @@
+You have been assigned a text matching task: {{ task }}
+
+Your mission is to write one example for this task in JSON format. The JSON object must contain the following keys:
+ - "input": a string, a random input specified by the task.
+ - "positive_document": a string, a relevant document for the "input" according to the task.
+
+Please adhere to the following guidelines:
+ - The values of all fields should be in {{ language }}.
+ - Both the "input" and "positive_document" should be long documents (at least 300 words), avoid substantial word overlaps, otherwise the task would be too easy.
+ - The "input" and "positive_document" should be independent of each other.
+
+Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2
new file mode 100644
index 0000000000..585d618620
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2
@@ -0,0 +1,10 @@
+Write a {{ unit }} triple with varying semantic similarity scores in JSON format. The semantic similarity score ranges from 1 to 5, with 1 denoting least similar and 5 denoting most similar.
+
+Please adhere to the following guidelines:
+ - The keys in JSON are "S1", "S2", and "S3", the values are all strings in {{ language }}, do not add any other keys.
+ - There should be some word overlaps between all three {{ unit }}s.
+ - The similarity score between S1 and S2 should be {{ high_score }}.
+ - The similarity score between S1 and S3 should be {{ low_score }}.
+ - The {{ unit }}s require {{ difficulty }} level education to understand and should be diverse in terms of topic and length.
+
+Your output must always be a JSON object only with three keys "S1", "S2" and "S3", do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2
new file mode 100644
index 0000000000..90b08f9e57
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2
@@ -0,0 +1,12 @@
+You have been assigned a text matching task: {{ task }}
+
+Your mission is to write one example for this task in JSON format. The JSON object must contain the following keys:
+ - "input": a string, a random input specified by the task.
+ - "positive_document": a string, a relevant document for the "input" according to the task.
+
+Please adhere to the following guidelines:
+ - The values of all fields should be in {{ language }}.
+ - Both the "input" and "positive_document" should be very short (a sentence or a phrase), avoid substantial word overlaps, otherwise the task would be too easy.
+ - The "input" and "positive_document" should be independent of each other.
+
+Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2
new file mode 100644
index 0000000000..74a184bc56
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2
@@ -0,0 +1,15 @@
+You have been assigned a text classification task: {{ task }}
+
+Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
+ - "input_text": a string, the input text specified by the classification task.
+ - "label": a string, the correct label of the input text.
+ - "misleading_label": a string, an incorrect label that is related to the task.
+
+Please adhere to the following guidelines:
+ - The "input_text" should be diverse in expression.
+ - The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the "input_text".
+ - The values for all fields should be in {{ language }}.
+ - Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make the task too easy.
+ - The "input_text" is {{ clarity }} and requires {{ difficulty }} level education to comprehend.
+
+Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2
new file mode 100644
index 0000000000..c76ac8a698
--- /dev/null
+++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2
@@ -0,0 +1,17 @@
+You have been assigned a retrieval task: {{ task }}
+
+Your mission is to write one text retrieval example for this task in JSON format. The JSON object must contain the following keys:
+ - "user_query": a string, a random user search query specified by the retrieval task.
+ - "positive_document": a string, a relevant document for the user query.
+ - "hard_negative_document": a string, a hard negative document that only appears relevant to the query.
+
+Please adhere to the following guidelines:
+ - The "user_query" should be {{ query_type }}, {{ query_length }}, {{ clarity }}, and diverse in topic.
+ - All documents must be created independent of the query. Avoid copying the query verbatim. It's acceptable if some parts of the "positive_document" are not topically related to the query.
+ - All documents should be at least {{ num_words }} words long.
+ - The "hard_negative_document" contains some useful information, but it should be less useful or comprehensive compared to the "positive_document".
+ - Both the query and documents should be in {{ language }}.
+ - Do not provide any explanation in any document on why it is relevant or not relevant to the query.
+ - Both the query and documents require {{ difficulty }} level education to understand.
+
+Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py
index 41eb4444dc..f5c4659651 100644
--- a/src/distilabel/steps/tasks/text_generation.py
+++ b/src/distilabel/steps/tasks/text_generation.py
@@ -37,16 +37,41 @@ class TextGeneration(Task):
Output columns:
- generation (`str`): The generated text.
- - model_name (`str`): The model name used to generate the text.
+ - model_name (`str`): The name of the model used to generate the text.
Categories:
- text-generation
Examples:
+
+ Generate text from an instruction:
+
```python
from distilabel.steps.tasks import TextGeneration
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ text_gen = TextGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ )
- task = TextGeneration(llm=LLM(...))
+ text_gen.load()
+
+ result = next(
+ text_gen.process(
+ [{"instruction": "your instruction"}]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'your instruction',
+ # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2',
+ # 'generation': 'generation',
+ # }
+ # ]
```
"""
@@ -62,14 +87,10 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
is the first interaction from the user within a conversation."""
if is_openai_format(input["instruction"]):
- warnings.warn(
+ raise ValueError(
"Providing `instruction` formatted as an OpenAI chat / conversation is"
- " about to be deprecated in `distilabel v1.2.0`, please make sure to use"
- " `ChatTextGeneration` with `messages` as input instead.",
- DeprecationWarning,
- stacklevel=2,
+ " deprecated, you should use `ChatGeneration` with `messages` as input instead.",
)
- return input["instruction"]
if not isinstance(input["instruction"], str):
raise ValueError(
@@ -96,7 +117,7 @@ def outputs(self) -> List[str]:
return ["generation", "model_name"]
def format_output(
- self, output: Union[str, None], input: Dict[str, Any]
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
"""The output is formatted as a dictionary with the `generation`. The `model_name`
will be automatically included within the `process` method of `Task`."""
@@ -123,6 +144,44 @@ class ChatGeneration(Task):
Icon:
`:material-chat:`
+
+ Examples:
+
+ Generate text from a conversation in OpenAI chat format:
+
+ ```python
+ from distilabel.steps.tasks import ChatGeneration
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ chat = ChatGeneration(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ )
+
+ chat.load()
+
+ result = next(
+ chat.process(
+ [
+ {
+ "messages": [
+ {"role": "user", "content": "How much is 2+2?"},
+ ]
+ }
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'messages': [{'role': 'user', 'content': 'How much is 2+2?'}],
+ # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2',
+ # 'generation': '4',
+ # }
+ # ]
+ ```
"""
@property
@@ -136,7 +195,7 @@ def format_input(self, input: Dict[str, Any]) -> ChatType:
if not is_openai_format(input["messages"]):
raise ValueError(
- "Input `instruction` must be a string or an OpenAI chat-like format. "
+ "Input `messages` must be an OpenAI chat-like format conversation. "
f"Got: {input['messages']}. Please check: 'https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models'."
)
diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py
index cbd6ffc09c..4f92cdc057 100644
--- a/src/distilabel/steps/tasks/typing.py
+++ b/src/distilabel/steps/tasks/typing.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
+from pydantic import BaseModel
from typing_extensions import TypedDict
@@ -24,3 +25,47 @@ class ChatItem(TypedDict):
ChatType = List[ChatItem]
"""ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format."""
+
+
+class OutlinesStructuredOutputType(TypedDict, total=False):
+ """TypedDict to represent the structured output configuration from `outlines`."""
+
+ format: Literal["json", "regex"]
+ """One of "json" or "regex"."""
+ schema: Union[str, Type[BaseModel], Dict[str, Any]]
+ """The schema to use for the structured output. If "json", it
+ can be a pydantic.BaseModel class, or the schema as a string,
+ as obtained from `model_to_schema(BaseModel)`, if "regex", it
+ should be a regex pattern as a string.
+ """
+    whitespace_pattern: Optional[Union[str, List[str]]]
+ """If "json" corresponds to a string or a list of
+ strings with a pattern (doesn't impact string literals).
+ For example, to allow only a single space or newline with
+ `whitespace_pattern=r"[\n ]?"`
+ """
+
+
+class InstructorStructuredOutputType(TypedDict, total=False):
+ """TypedDict to represent the structured output configuration from `instructor`."""
+
+ schema: Union[Type[BaseModel], Dict[str, Any]]
+ """The schema to use for the structured output, a `pydantic.BaseModel` class. """
+ mode: Optional[str]
+ """Generation mode. Take a look at `instructor.Mode` for more information, if not informed it will
+ be determined automatically. """
+ max_retries: int
+ """Number of times to reask the model in case of error, if not set will default to the model's default. """
+
+
+StructuredOutputType = Union[
+ OutlinesStructuredOutputType, InstructorStructuredOutputType
+]
+"""StructuredOutputType is an alias for the union of `OutlinesStructuredOutputType` and `InstructorStructuredOutputType`."""
+
+StandardInput = ChatType
+"""StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`."""
+StructuredInput = Tuple[StandardInput, Union[StructuredOutputType, None]]
+"""StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it."""
+FormattedInput = Union[StandardInput, StructuredInput]
+"""FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s."""
diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py
index 9b200ddafd..c6cd95482c 100644
--- a/src/distilabel/steps/tasks/ultrafeedback.py
+++ b/src/distilabel/steps/tasks/ultrafeedback.py
@@ -60,6 +60,45 @@ class UltraFeedback(Task):
References:
- [`UltraFeedback: Boosting Language Models with High-quality Feedback`](https://arxiv.org/abs/2310.01377)
- [`UltraFeedback - GitHub Repository`](https://github.com/OpenBMB/UltraFeedback)
+
+ Examples:
+
+ Rate generations from different LLMs based on the selected aspect:
+
+ ```python
+ from distilabel.steps.tasks import UltraFeedback
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+ # Consider this as a placeholder for your actual LLM.
+ ultrafeedback = UltraFeedback(
+ llm=InferenceEndpointsLLM(
+ model_id="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ )
+
+ ultrafeedback.load()
+
+ result = next(
+        ultrafeedback.process(
+ [
+ {
+ "instruction": "How much is 2+2?",
+ "generations": ["4", "and a car"],
+ }
+ ]
+ )
+ )
+ # result
+ # [
+ # {
+ # 'instruction': 'How much is 2+2?',
+ # 'generations': ['4', 'and a car'],
+ # 'ratings': [1, 2],
+ # 'rationales': ['explanation for 4', 'explanation for and a car'],
+ # 'model_name': 'mistralai/Mistral-7B-Instruct-v0.2',
+ # }
+ # ]
+ ```
"""
aspect: Literal[
diff --git a/src/distilabel/utils/dicts.py b/src/distilabel/utils/dicts.py
index 0ce96334f9..53d33d47f5 100644
--- a/src/distilabel/utils/dicts.py
+++ b/src/distilabel/utils/dicts.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
from collections import defaultdict
from typing import Any, Dict, List, TypeVar
@@ -33,3 +34,7 @@ def combine_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]:
for key, value in d.items():
combined_dict[key].append(value)
return dict(combined_dict)
+
+
+def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]:
+ return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()}
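The new `flatten_dict` helper simply JSON-encodes nested dictionaries, for example:

```python
from distilabel.utils.dicts import flatten_dict

print(flatten_dict({"a": 1, "b": {"c": 2}}))
# {'a': 1, 'b': '{"c": 2}'}
```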
diff --git a/src/distilabel/utils/huggingface.py b/src/distilabel/utils/huggingface.py
new file mode 100644
index 0000000000..7a637a831c
--- /dev/null
+++ b/src/distilabel/utils/huggingface.py
@@ -0,0 +1,53 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+from typing import Final
+
+from huggingface_hub import constants
+
+_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME: Final[str] = "HF_TOKEN"
+
+
+def get_hf_token(cls_name: str, token_arg: str) -> str:
+ """Get the token for the hugging face API.
+
+ Tries to extract it from the environment variable, if it is not found
+ it tries to read it from the file using 'huggingface_hub',
+ and if not possible raises a ValueError.
+
+ Args:
+ cls_name: Name of the class/function that requires the token.
+ token_arg: Argument name to use in the error message, normally
+ is "token" or "api_key".
+
+ Raises:
+ ValueError: If the token is not found in the file.
+
+ Returns:
+        The token for the Hugging Face API.
+ """
+ token = os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME)
+ if token is None:
+ if not Path(constants.HF_TOKEN_PATH).exists():
+ raise ValueError(
+ f"To use `{cls_name}` an API key must be provided via"
+ f" `{token_arg}`, set the environment variable"
+ f" `{_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME}` or use the `huggingface-hub` CLI to login"
+ " with `huggingface-cli login`."
+ )
+ with open(constants.HF_TOKEN_PATH) as f:
+ token = f.read().strip()
+ return token
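A usage sketch for `get_hf_token` (the token value below is a placeholder set only for the example):

```python
import os

from distilabel.utils.huggingface import get_hf_token

os.environ["HF_TOKEN"] = "hf_placeholder"  # placeholder, not a real token
token = get_hf_token(cls_name="InferenceEndpointsLLM", token_arg="api_key")
```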
diff --git a/src/distilabel/utils/itertools.py b/src/distilabel/utils/itertools.py
index 9428389188..88ce86cc4e 100644
--- a/src/distilabel/utils/itertools.py
+++ b/src/distilabel/utils/itertools.py
@@ -13,18 +13,20 @@
# limitations under the License.
from itertools import zip_longest
-from typing import Any, Iterable, Literal
+from typing import Any, Iterable, List, Literal, TypeVar
+
+T = TypeVar("T")
# Copy pasted from https://docs.python.org/3/library/itertools.html#itertools-recipes
# Just added the type hints and use `if`s instead of `match`
def grouper(
- iterable: Iterable[Any],
+ iterable: Iterable[T],
n: int,
*,
incomplete: Literal["fill", "strict", "ignore"] = "fill",
fillvalue: Any = None,
-) -> Iterable[Any]:
+) -> Iterable[List[T]]:
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
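With the new type hints in place, `grouper` keeps the original recipe's behaviour of yielding tuples; a quick sketch:

```python
from distilabel.utils.itertools import grouper

print(list(grouper("ABCDEFG", 3, fillvalue="x")))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
```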
diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py
index 15a737f448..af1b26a18b 100644
--- a/src/distilabel/utils/logging.py
+++ b/src/distilabel/utils/logging.py
@@ -42,7 +42,9 @@
queue_listener: Union[QueueListener, None] = None
-def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> None:
+def setup_logging(
+ log_queue: Optional["Queue[Any]"] = None, filename: Optional[str] = None
+) -> None:
"""Sets up logging to use a queue across all processes."""
global queue_listener
@@ -53,7 +55,7 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No
# If the current process is the main process, set up a `QueueListener`
# to handle logs from all subprocesses
- if mp.current_process().name == "MainProcess":
+ if mp.current_process().name == "MainProcess" and filename:
formatter = logging.Formatter("['%(name)s'] %(message)s")
handler = RichHandler(rich_tracebacks=True)
handler.setFormatter(formatter)
@@ -66,10 +68,11 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No
)
file_handler.setFormatter(file_formatter)
- queue_listener = QueueListener(
- log_queue, handler, file_handler, respect_handler_level=True
- )
- queue_listener.start()
+ if log_queue is not None:
+ queue_listener = QueueListener(
+ log_queue, handler, file_handler, respect_handler_level=True
+ )
+ queue_listener.start()
log_level = os.environ.get("DISTILABEL_LOG_LEVEL", "INFO").upper()
if log_level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
@@ -80,9 +83,15 @@ def setup_logging(log_queue: "Queue[Any]", filename: Optional[str] = None) -> No
log_level = "INFO"
root_logger = logging.getLogger()
- root_logger.handlers.clear()
+
+ running_test = "PYTEST_CURRENT_TEST" in os.environ
+ if not running_test:
+ root_logger.handlers.clear()
+
+ if log_queue is not None:
+ root_logger.addHandler(QueueHandler(log_queue))
+
root_logger.setLevel(log_level)
- root_logger.addHandler(QueueHandler(log_queue))
def stop_logging() -> None:
@@ -90,4 +99,5 @@ def stop_logging() -> None:
global queue_listener
if queue_listener is not None:
queue_listener.stop()
+ queue_listener.queue.close()
queue_listener = None
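With `log_queue` now optional, `setup_logging` can also be called without a queue or filename (e.g. in the main process or in tests); a minimal sketch of that code path:

```python
from distilabel.utils.logging import setup_logging

# Without a queue or filename the `QueueListener`/`QueueHandler` branches are skipped
# and only the root logger level (from `DISTILABEL_LOG_LEVEL`) is applied.
setup_logging()
```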
diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py
index a5a5ba40e8..ae43c586af 100644
--- a/src/distilabel/utils/mkdocs/components_gallery.py
+++ b/src/distilabel/utils/mkdocs/components_gallery.py
@@ -21,7 +21,6 @@
from mkdocs.config.config_options import Type
from mkdocs.plugins import BasePlugin
from mkdocs.structure.files import File
-from mkdocs.structure.pages import Page
from mkdocs_section_index import SectionPage
from distilabel.utils.export_components_info import export_components_info
@@ -360,11 +359,26 @@ def on_nav(
steps_file = files.get_file_from_path(self.file_paths["steps"][0])
tasks_file = files.get_file_from_path(self.file_paths["tasks"][0])
llms_file = files.get_file_from_path(self.file_paths["llms"][0])
+ steps_files = [
+ files.get_file_from_path(path) for path in self.file_paths["steps"][0:]
+ ]
+ tasks_files = [
+ files.get_file_from_path(path) for path in self.file_paths["tasks"][0:]
+ ]
+ llms_files = [
+ files.get_file_from_path(path) for path in self.file_paths["llms"][0:]
+ ]
# Create subsections
- steps_page = Page("Steps", file=steps_file, config=config) # type: ignore
- tasks_page = Page("Tasks", file=tasks_file, config=config) # type: ignore
- llms_page = Page("LLMs", file=llms_file, config=config) # type: ignore
+ steps_page = SectionPage(
+ "Steps", file=steps_file, config=config, children=steps_files
+ ) # type: ignore
+ tasks_page = SectionPage(
+ "Tasks", file=tasks_file, config=config, children=tasks_files
+ ) # type: ignore
+ llms_page = SectionPage(
+ "LLMs", file=llms_file, config=config, children=llms_files
+ ) # type: ignore
# Create the gallery section
page = SectionPage(
diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2
index 3c465761c5..319c69164b 100644
--- a/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2
+++ b/src/distilabel/utils/mkdocs/templates/components-gallery/components-list.jinja2
@@ -1,8 +1,8 @@
---
-hide:
+hide:
- toc
+ - navigation
---
-
# {{ title }}
{{ description }}
diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md
index 5c7af73180..eb2914b6a6 100644
--- a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md
+++ b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md
@@ -1,8 +1,8 @@
---
-hide:
+hide:
+ - navigation
- toc
---
-
# Components Gallery
diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2
index 212bbe8601..5d2b72dd90 100644
--- a/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2
+++ b/src/distilabel/utils/mkdocs/templates/components-gallery/llm-detail.jinja2
@@ -1,3 +1,7 @@
+---
+hide:
+ - navigation
+---
# {{ llm.name }}
{% if llm.docstring.short_description %}
diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2
index 486fc4b299..43a7d552b7 100644
--- a/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2
+++ b/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2
@@ -1,5 +1,8 @@
+---
+hide:
+ - navigation
+---
# {{ step.name }}
-
{% if step.docstring.short_description %}
{{ step.docstring.short_description }}
{% endif %}
@@ -56,7 +59,7 @@
{% for example_title, code in step.docstring.examples.items() %}
#### {{ example_title }}
```python
-{{ code | e }}
+{{ code | replace("\n", "\n") }}
```
{% endfor %}
{% endif %}
diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py
index b97669b809..8f32afc2eb 100644
--- a/src/distilabel/utils/serialization.py
+++ b/src/distilabel/utils/serialization.py
@@ -13,18 +13,30 @@
# limitations under the License.
import importlib
-import json
import os
import sys
from enum import Enum
+import orjson
+
if sys.version_info < (3, 11):
from enum import EnumMeta as EnumType
else:
from enum import EnumType
from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, get_args
+from typing import (
+ Any,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ get_args,
+)
import yaml
from pydantic import BaseModel
@@ -40,17 +52,35 @@
SaveFormats = Literal["json", "yaml"]
+# Mapping to handle import paths that could have been serialized from previous versions
+_OLD_IMPORT_MODULE_ATTR: Dict[Tuple[str, str], Tuple[str, str]] = {
+ ("distilabel.pipeline.base", "_Batch"): ("distilabel.pipeline.batch", "_Batch"),
+ ("distilabel.pipeline.base", "_BatchManager"): (
+ "distilabel.pipeline.batch_manager",
+ "_BatchManager",
+ ),
+ ("distilabel.pipeline.base", "_BatchManagerStep"): (
+ "distilabel.pipeline.batch_manager",
+ "_BatchManagerStep",
+ ),
+}
+
+
def _get_module_attr(module: str, name: str) -> Type:
"""Gets a class given the module and the name of the class.
Returns:
The type of the class.
"""
+
+ if (module, name) in _OLD_IMPORT_MODULE_ATTR:
+ module, name = _OLD_IMPORT_MODULE_ATTR[(module, name)]
+
mod = importlib.import_module(module)
return getattr(mod, name)
-def load_from_dict(class_: Dict[str, Any]) -> Any:
+def load_with_type_info(class_: Any) -> Any:
"""Creates an instance of a class from a dictionary containing the type info and the
serialized data of the class.
@@ -60,17 +90,32 @@ def load_from_dict(class_: Dict[str, Any]) -> Any:
Returns:
An instance of the class with the data loaded from the dictionary.
"""
- type_info = class_.pop(TYPE_INFO_KEY)
- if TYPE_INFO_KEY in type_info:
- # There is a nested type_info, load the class recursively
- type_info = load_from_dict(type_info)
+ if not isinstance(class_, (list, dict)):
+ return class_
- cls = _get_module_attr(type_info["module"], type_info["name"])
+ if isinstance(class_, list):
+ return [load_with_type_info(x) for x in class_]
for k, v in class_.items():
+ class_[k] = load_with_type_info(v) if isinstance(v, (dict, list)) else v
+
if isinstance(v, dict) and "_type" in v and v["_type"] == "enum":
class_[k] = Enum(v["_name"], v["_values"], type=eval(v["_enum_type"]))
+ if TYPE_INFO_KEY not in class_:
+ return class_
+
+ type_info = class_.pop(TYPE_INFO_KEY)
+
+ cls = _get_module_attr(type_info["module"], type_info["name"])
+
+ if issubclass(cls, BaseModel):
+ # `pop` keys from the dictionary that are not in the model fields
+ field_names = cls.model_fields
+ keys_to_drop = [k for k in class_.keys() if k not in field_names]
+ for k in keys_to_drop:
+ class_.pop(k)
+
instance = cls(**class_)
return instance
@@ -83,8 +128,8 @@ def write_json(filename: Path, data: Any) -> None:
data: the data to write to the file.
"""
filename.parent.mkdir(parents=True, exist_ok=True)
- with open(filename, "w") as file:
- json.dump(data, file, indent=2)
+ with open(filename, "wb") as f:
+ f.write(orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY))
def read_json(filename: StrOrPath) -> Any:
@@ -96,8 +141,8 @@ def read_json(filename: StrOrPath) -> Any:
Returns:
The data from the file.
"""
- with open(filename, "r") as file:
- return json.load(file)
+ with open(filename, "rb") as f:
+ return orjson.loads(f.read())
def write_yaml(filename: Path, data: Dict[str, Any]) -> None:
@@ -159,10 +204,15 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"_name": getattr(obj, k).__name__,
"_values": {x.name: x.value for x in v}, # type: ignore
}
+ elif isinstance(v, list):
+ dump[k] = {str(i): list_v for i, list_v in enumerate(v)}
+
# Grab the fields that need extra care (LLMs from inside tasks)
to_update = _extra_serializable_fields(obj)
+
# Update those in the dumped dict
- [dump.update(field) for field in to_update]
+ for field in to_update:
+ dump.update(field)
return dump
@@ -237,7 +287,7 @@ def from_dict(cls, data: Dict[str, Any]) -> Self:
Returns:
An instance of the class with the data loaded from the dictionary.
"""
- return load_from_dict(data)
+ return load_with_type_info(data)
@classmethod
def from_json(cls, path: StrOrPath) -> Self:
@@ -303,12 +353,19 @@ def _check_is_dir(path: StrOrPath) -> None:
def _extra_serializable_fields(obj: BaseModel) -> List[Dict[str, Dict[str, Any]]]:
- # This function is here to loop over objects that contains nested _Serializable objects.
- # Cannot work recursively due to the mix between models that inherit from BaseModel and
- # those that don't, so we loop over the classes and update those that are _Serializable.
- # Extra introspection to dump nested objects.
- # Mainly for the LLMs inside a Task for the moment.
- # This way we ensure the "type_info" is inserted in those objects.
+ """Gets the information of the nested `_Serializable` attributes within another `_Serializable`
+ instance.
+
+ It's mainly used to get the information of the `LLM` objects inside a `Task` object,
+ as they are nested and need to be serialized (`type_info`).
+
+ Args:
+ obj: the object to extract the information from.
+
+ Returns:
+ A list of dictionaries containing the information of the nested `_Serializable`
+ attributes.
+ """
from distilabel.pipeline.base import BasePipeline
to_update = []
@@ -316,6 +373,12 @@ def _extra_serializable_fields(obj: BaseModel) -> List[Dict[str, Dict[str, Any]]
field = getattr(obj, k)
# Have to remove the Pipeline as it will be inside the Step objects but is really
# in a higher level hierarchy.
- if isinstance(field, _Serializable) and (not isinstance(field, BasePipeline)):
+ if isinstance(field, BasePipeline):
+ continue
+
+ if isinstance(field, _Serializable):
to_update.append({k: getattr(obj, k).dump()})
+ elif isinstance(field, list) and field and isinstance(field[0], _Serializable):
+ to_update.append({k: {str(i): x.dump() for i, x in enumerate(field)}})
+
return to_update
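To show what `load_with_type_info` does with a `type_info` entry, a hedged sketch that assumes the serialized key is `"type_info"` and uses `fractions.Fraction` purely as an illustrative target class:

```python
from distilabel.utils.serialization import load_with_type_info

payload = {
    "type_info": {"module": "fractions", "name": "Fraction"},  # assumes TYPE_INFO_KEY == "type_info"
    "numerator": 3,
    "denominator": 4,
}
print(load_with_type_info(payload))  # 3/4
```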
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
new file mode 100644
index 0000000000..8337c9aaa9
--- /dev/null
+++ b/tests/integration/conftest.py
@@ -0,0 +1,27 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+from typing import Generator
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def temp_cache_dir() -> Generator[None, None, None]:
+ """Set the cache directory to a temporary directory for all tests."""
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ os.environ["DISTILABEL_CACHE_DIR"] = tmpdirname
+ yield
diff --git a/tests/integration/test_cache.py b/tests/integration/test_cache.py
new file mode 100644
index 0000000000..6eddd6f7ca
--- /dev/null
+++ b/tests/integration/test_cache.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List
+
+import numpy as np
+import pytest
+from distilabel.pipeline import Pipeline
+from distilabel.steps import GeneratorStep, StepInput, step
+
+if TYPE_CHECKING:
+ from distilabel.steps import GeneratorStepOutput, StepOutput
+
+
+class NumpyBigArrayGenerator(GeneratorStep):
+ num_batches: int
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["array"]
+
+ def process(self, offset: int = 0) -> "GeneratorStepOutput":
+ for i in range(self.num_batches):
+ yield (
+ [{"array": np.random.randn(256)} for _ in range(self.batch_size)], # type: ignore
+ i == self.num_batches - 1,
+ ) # type: ignore
+
+
+@step(step_type="global")
+def ReceiveArrays(inputs: StepInput) -> "StepOutput":
+ yield inputs
+
+
+@pytest.mark.benchmark
+def test_cache_time() -> None:
+ with Pipeline(name="dummy") as pipeline:
+ numpy_generator = NumpyBigArrayGenerator(num_batches=2, batch_size=100)
+
+ receive_arrays = ReceiveArrays()
+
+ numpy_generator >> receive_arrays
+
+ pipeline.run(use_cache=False)
diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py
index 8ab1fff29a..31a624b15f 100644
--- a/tests/integration/test_pipe_simple.py
+++ b/tests/integration/test_pipe_simple.py
@@ -166,15 +166,8 @@ def run_pipeline():
)
-def test_pipeline_cached():
- ds = run_pipeline()
- print()
- print("----- RUNNING PIPELINE AGAIN -----")
- print()
+def test_pipeline_cached() -> None:
+ run_pipeline()
ds = run_pipeline()
assert isinstance(ds, Distiset)
assert len(ds["default"]["train"]) == 80
-
-
-if __name__ == "__main__":
- test_pipeline_cached()
diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py
index 0ea2ee3cdc..228fb1c43e 100644
--- a/tests/integration/test_routing_batch_function.py
+++ b/tests/integration/test_routing_batch_function.py
@@ -74,7 +74,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput":
yield combined_list
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(240)
def test_routing_batch_function() -> None:
with Pipeline(name="test") as pipeline:
load_dataset = LoadDataFromDicts(
@@ -95,7 +95,7 @@ def test_routing_batch_function() -> None:
assert len(row["generations"]) == 2
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(240)
def test_routing_batch_function_irregular_batch_sizes() -> None:
with Pipeline(name="test") as pipeline:
load_dataset = LoadDataFromDicts(
@@ -120,7 +120,7 @@ def test_routing_batch_function_irregular_batch_sizes() -> None:
assert len(row["generations"]) == 2
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(240)
def test_multiple_routing_batch_function() -> None:
batch_size = 200
diff --git a/tests/integration/test_using_fs_to_pass_data.py b/tests/integration/test_using_fs_to_pass_data.py
new file mode 100644
index 0000000000..811885e356
--- /dev/null
+++ b/tests/integration/test_using_fs_to_pass_data.py
@@ -0,0 +1,68 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List
+
+import numpy as np
+from distilabel.pipeline import Pipeline
+from distilabel.steps import GeneratorStep, StepInput, step
+
+if TYPE_CHECKING:
+ from distilabel.steps import GeneratorStepOutput, StepOutput
+
+
+class NumpyBigArrayGenerator(GeneratorStep):
+ num_batches: int
+
+ @property
+ def outputs(self) -> List[str]:
+ return ["array"]
+
+ def process(self, offset: int = 0) -> "GeneratorStepOutput":
+ for i in range(self.num_batches):
+ yield (
+ [{"array": np.random.randn(128)} for _ in range(self.batch_size)], # type: ignore
+ i == self.num_batches - 1,
+ ) # type: ignore
+
+
+@step(step_type="global")
+def ReceiveArrays(inputs: StepInput) -> "StepOutput":
+ yield inputs
+
+
+def test_passing_data_through_fs_only_global_steps() -> None:
+ with Pipeline(name="dummy") as pipeline:
+ numpy_generator = NumpyBigArrayGenerator(num_batches=5, batch_size=100)
+
+ receive_arrays = ReceiveArrays()
+
+ numpy_generator >> receive_arrays
+
+ distiset = pipeline.run(use_fs_to_pass_data=False, use_cache=False)
+
+ assert len(distiset["default"]["train"]) == 500
+
+
+def test_passing_data_through_fs() -> None:
+ with Pipeline(name="dummy") as pipeline:
+ numpy_generator = NumpyBigArrayGenerator(num_batches=2, batch_size=200)
+
+ receive_arrays = ReceiveArrays()
+
+ numpy_generator >> receive_arrays
+
+ distiset = pipeline.run(use_fs_to_pass_data=True, use_cache=False)
+
+ assert len(distiset["default"]["train"]) == 400
diff --git a/tests/unit/steps/tasks/utils.py b/tests/unit/conftest.py
similarity index 56%
rename from tests/unit/steps/tasks/utils.py
rename to tests/unit/conftest.py
index 989fb3ad5b..bbe6ca1ed4 100644
--- a/tests/unit/steps/tasks/utils.py
+++ b/tests/unit/conftest.py
@@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Union
+from typing import TYPE_CHECKING
-from distilabel.llms.base import LLM
-from distilabel.steps.tasks.typing import ChatType
+import pytest
+from distilabel.llms.base import AsyncLLM
+if TYPE_CHECKING:
+ from distilabel.llms.typing import GenerateOutput
+ from distilabel.steps.tasks.typing import FormattedInput
-class DummyLLM(LLM):
+
+# Also defined here so that serialization/deserialization (serde) of the dummy LLM keeps working in the existing tests.
+class DummyLLM(AsyncLLM):
def load(self) -> None:
pass
@@ -26,7 +31,12 @@ def load(self) -> None:
def model_name(self) -> str:
return "test"
- def generate(
- self, inputs: List["ChatType"], num_generations: int = 1, **kwargs: Any
- ) -> List[List[Union[str, None]]]:
- return [["output" for _ in range(num_generations)] for _ in inputs]
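+    # Returns a fixed "output" string per generation, so tests can make deterministic assertions.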
+ async def agenerate(
+ self, input: "FormattedInput", num_generations: int = 1
+ ) -> "GenerateOutput":
+ return ["output" for _ in range(num_generations)]
+
+
+@pytest.fixture
+def dummy_llm() -> AsyncLLM:
+ return DummyLLM()
diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py
index 9caccf43c4..87a890a38c 100644
--- a/tests/unit/llms/huggingface/test_inference_endpoints.py
+++ b/tests/unit/llms/huggingface/test_inference_endpoints.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
+from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import nest_asyncio
@@ -22,6 +24,36 @@
@patch("huggingface_hub.AsyncInferenceClient")
@patch("openai.AsyncOpenAI")
class TestInferenceEndpointsLLM:
+ def test_load_no_api_key(
+ self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+ ) -> None:
+ llm = InferenceEndpointsLLM(
+ model_id="distilabel-internal-testing/tiny-random-mistral"
+ )
+
+ # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to not exist
+ with mock.patch("pathlib.Path.exists") as mock_exists:
+ mock_exists.return_value = False
+ with pytest.raises(
+ ValueError,
+ match="To use `InferenceEndpointsLLM` an API key must be provided",
+ ):
+ llm.load()
+
+ def test_load_with_cached_token(
+ self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+ ) -> None:
+ llm = InferenceEndpointsLLM(
+ model_id="distilabel-internal-testing/tiny-random-mistral"
+ )
+
+ # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist
+ with mock.patch("pathlib.Path.exists", return_value=True), mock.patch(
+ "builtins.open", new_callable=mock.mock_open, read_data="hf_token"
+ ):
+ # Should not raise any errors
+ llm.load()
+
def test_serverless_inference_endpoints_llm(
self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
) -> None:
@@ -145,6 +177,7 @@ async def test_generate_via_openai_client(
)
llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
+ ...
nest_asyncio.apply()
assert llm.generate(
@@ -159,8 +192,54 @@ async def test_generate_via_openai_client(
]
) == [(" Aenean hendrerit aliquam velit. ...",)]
+ @pytest.mark.asyncio
+ async def test_agenerate_with_structured_output(
+ self, mock_inference_client: MagicMock, _: MagicMock
+ ) -> None:
+ llm = InferenceEndpointsLLM(
+ model_id="distilabel-internal-testing/tiny-random-mistral",
+ structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"},
+ )
+ llm._aclient = mock_inference_client
+
+ llm._aclient.text_generation = AsyncMock(
+ return_value=" Aenean hendrerit aliquam velit. ..."
+ )
+
+ # Since there's a pseudo-random number within the generation kwargs, we set the seed
+ # here first to ensure reproducibility within the tests
+ random.seed(42)
+
+ assert await llm.agenerate(
+ input=[
+ {
+ "role": "user",
+ "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ },
+ ]
+ ) == [" Aenean hendrerit aliquam velit. ..."]
+
+ kwargs = {
+ "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ "max_new_tokens": 128,
+ "do_sample": False,
+ "typical_p": None,
+ "repetition_penalty": None,
+ "temperature": 1.0,
+ "top_p": None,
+ "top_k": None,
+ "stop_sequences": None,
+ "return_full_text": False,
+ "watermark": False,
+ "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"},
+ "seed": 478163327, # pre-computed random value with `random.seed(42)`
+ }
+ mock_inference_client.text_generation.assert_called_with(**kwargs)
+
def test_serialization(
- self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+ self,
+ mock_inference_client: MagicMock,
+ mock_openai_client: MagicMock,
) -> None:
llm = InferenceEndpointsLLM(
model_id="distilabel-internal-testing/tiny-random-mistral",
@@ -173,9 +252,9 @@ def test_serialization(
"base_url": None,
"tokenizer_id": None,
"generation_kwargs": {},
+ "structured_output": None,
"model_display_name": None,
"use_openai_client": False,
- "structured_output": None,
"type_info": {
"module": "distilabel.llms.huggingface.inference_endpoints",
"name": "InferenceEndpointsLLM",
diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py
index 28e486756b..75e7bcbf62 100644
--- a/tests/unit/llms/test_anthropic.py
+++ b/tests/unit/llms/test_anthropic.py
@@ -13,12 +13,16 @@
# limitations under the License.
import os
+import sys
+from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import nest_asyncio
import pytest
from distilabel.llms.anthropic import AnthropicLLM
+from .utils import DummyUserDetail
+
@patch("anthropic.AsyncAnthropic")
class TestAnthropicLLM:
@@ -47,6 +51,37 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None:
]
)
+ @pytest.mark.asyncio
+ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
+ llm = AnthropicLLM(
+ model="claude-3-opus-20240229",
+ api_key="api.key",
+ structured_output={
+ "schema": DummyUserDetail,
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ ) # type: ignore
+ llm._aclient = mock_openai
+
+ sample_user = DummyUserDetail(name="John Doe", age=30)
+
+ llm._aclient.messages.create = AsyncMock(return_value=sample_user)
+
+ generation = await llm.agenerate(
+ input=[
+ {"role": "system", "content": ""},
+ {
+ "role": "user",
+ "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ },
+ ]
+ )
+ assert generation[0] == sample_user.model_dump_json()
+
+ @pytest.mark.skipif(
+ sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
+ )
@pytest.mark.asyncio
async def test_generate(self, mock_anthropic: MagicMock) -> None:
llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore
@@ -71,7 +106,52 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None:
]
)
- def test_serialization(self, _: MagicMock) -> None:
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "base_url": "https://api.anthropic.com",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "model": "claude-3-opus-20240229",
+ "timeout": 600.0,
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.anthropic",
+ "name": "AnthropicLLM",
+ },
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ {
+ "base_url": "https://api.anthropic.com",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "model": "claude-3-opus-20240229",
+ "timeout": 600.0,
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ "type_info": {
+ "module": "distilabel.llms.anthropic",
+ "name": "AnthropicLLM",
+ },
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
os.environ["ANTHROPIC_API_KEY"] = "api.key"
llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore
diff --git a/tests/unit/llms/test_azure.py b/tests/unit/llms/test_azure.py
index a5208da95f..e8af5d7b8f 100644
--- a/tests/unit/llms/test_azure.py
+++ b/tests/unit/llms/test_azure.py
@@ -13,10 +13,14 @@
# limitations under the License.
import os
+from typing import Any, Dict
from unittest import mock
+import pytest
from distilabel.llms.azure import AzureOpenAILLM
+from .utils import DummyUserDetail
+
class TestAzureOpenAILLM:
model_id: str = "gpt-4"
@@ -56,20 +60,70 @@ def test_azure_openai_llm_env_vars(self) -> None:
assert llm.api_key.get_secret_value() == "another.api.key" # type: ignore
assert llm.api_version == self.api_version
- def test_serialization(self) -> None:
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "model": "gpt-4",
+ "api_version": "preview",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "base_url": "https://example-resource.azure.openai.com/",
+ "timeout": 120,
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.azure",
+ "name": "AzureOpenAILLM",
+ },
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ {
+ "model": "gpt-4",
+ "api_version": "preview",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "base_url": "https://example-resource.azure.openai.com/",
+ "timeout": 120,
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ "type_info": {
+ "module": "distilabel.llms.azure",
+ "name": "AzureOpenAILLM",
+ },
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
llm = AzureOpenAILLM(
- model=self.model_id, base_url=self.base_url, api_version=self.api_version
+ model=self.model_id,
+ base_url=self.base_url,
+ api_version=self.api_version,
+ structured_output=structured_output,
)
- _dump = {
- "generation_kwargs": {},
- "model": "gpt-4",
- "base_url": "https://example-resource.azure.openai.com/",
- "max_retries": 6,
- "timeout": 120,
- "api_version": "preview",
- "structured_output": None,
- "type_info": {"module": "distilabel.llms.azure", "name": "AzureOpenAILLM"},
- }
- assert llm.dump() == _dump
- assert isinstance(AzureOpenAILLM.from_dict(_dump), AzureOpenAILLM)
+ assert llm.dump() == dump
+ assert isinstance(AzureOpenAILLM.from_dict(dump), AzureOpenAILLM)
diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py
index 0c2e2e213c..a16d904e11 100644
--- a/tests/unit/llms/test_cohere.py
+++ b/tests/unit/llms/test_cohere.py
@@ -13,12 +13,16 @@
# limitations under the License.
import os
+import sys
+from typing import Any, Dict
from unittest import mock
import nest_asyncio
import pytest
from distilabel.llms.cohere import CohereLLM
+from .utils import DummyUserDetail
+
@mock.patch("cohere.AsyncClient")
class TestCohereLLM:
@@ -64,6 +68,38 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None:
]
)
+ @pytest.mark.skipif(
+ sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
+ )
+ @pytest.mark.asyncio
+ async def test_agenerate_structured(
+ self, mock_async_client: mock.MagicMock
+ ) -> None:
+ llm = CohereLLM(
+ model="command-r",
+ structured_output={
+ "schema": DummyUserDetail,
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ )
+ llm._aclient = mock_async_client # type: ignore
+
+ sample_user = DummyUserDetail(name="John Doe", age=30)
+
+ llm._aclient.chat = mock.AsyncMock(return_value=sample_user)
+
+ generation = await llm.agenerate(
+ input=[
+ {"role": "system", "content": ""},
+ {
+ "role": "user",
+ "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ },
+ ]
+ )
+ assert generation == [sample_user.model_dump_json()]
+
@pytest.mark.asyncio
async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
llm = CohereLLM(model="command-r")
@@ -92,21 +128,53 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
]
)
- def test_serialization(self, _: mock.MagicMock) -> None:
- llm = CohereLLM(model="command-r")
-
- dump = {
- "model": "command-r",
- "generation_kwargs": {},
- "base_url": "https://api.cohere.ai/v1",
- "timeout": 120,
- "client_name": "distilabel",
- "structured_output": None,
- "type_info": {
- "module": "distilabel.llms.cohere",
- "name": "CohereLLM",
- },
- }
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "model": "command-r",
+ "generation_kwargs": {},
+ "base_url": "https://api.cohere.ai/v1",
+ "timeout": 120,
+ "client_name": "distilabel",
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.cohere",
+ "name": "CohereLLM",
+ },
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ {
+ "model": "command-r",
+ "generation_kwargs": {},
+ "base_url": "https://api.cohere.ai/v1",
+ "timeout": 120,
+ "client_name": "distilabel",
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ "type_info": {
+ "module": "distilabel.llms.cohere",
+ "name": "CohereLLM",
+ },
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, _: mock.MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
+ llm = CohereLLM(model="command-r", structured_output=structured_output)
assert llm.dump() == dump
assert isinstance(CohereLLM.from_dict(dump), CohereLLM)
diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py
index e75166ce97..7607ab2cb2 100644
--- a/tests/unit/llms/test_groq.py
+++ b/tests/unit/llms/test_groq.py
@@ -13,12 +13,16 @@
# limitations under the License.
import os
+import sys
+from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import nest_asyncio
import pytest
from distilabel.llms.groq import GroqLLM
+from .utils import DummyUserDetail
+
@patch("groq._client.AsyncGroq")
class TestGroqLLM:
@@ -47,6 +51,37 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None:
]
) == [" Aenean hendrerit aliquam velit. ..."]
+ @pytest.mark.skipif(
+ sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
+ )
+ @pytest.mark.asyncio
+ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
+ llm = GroqLLM(
+ model="llama3-70b-8192",
+ api_key="api.key",
+ structured_output={
+ "schema": DummyUserDetail,
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ ) # type: ignore
+ llm._aclient = mock_openai
+
+ sample_user = DummyUserDetail(name="John Doe", age=30)
+
+ llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user)
+
+ generation = await llm.agenerate(
+ input=[
+ {"role": "system", "content": ""},
+ {
+ "role": "user",
+ "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ },
+ ]
+ )
+ assert generation[0] == sample_user.model_dump_json()
+
@pytest.mark.asyncio
async def test_generate(self, mock_groq: MagicMock) -> None:
llm = GroqLLM(model="llama3-70b-8192", api_key="api.key") # type: ignore
@@ -71,22 +106,54 @@ async def test_generate(self, mock_groq: MagicMock) -> None:
]
) == [(" Aenean hendrerit aliquam velit. ...",)]
- def test_serialization(self, mock_groq: MagicMock) -> None:
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "model": "llama3-70b-8192",
+ "base_url": "https://api.groq.com",
+ "generation_kwargs": {},
+ "max_retries": 2,
+ "timeout": 120,
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.groq",
+ "name": "GroqLLM",
+ },
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ {
+ "model": "llama3-70b-8192",
+ "base_url": "https://api.groq.com",
+ "generation_kwargs": {},
+ "max_retries": 2,
+ "timeout": 120,
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ "type_info": {
+ "module": "distilabel.llms.groq",
+ "name": "GroqLLM",
+ },
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
os.environ["GROQ_API_KEY"] = "api.key"
- llm = GroqLLM(model="llama3-70b-8192")
-
- _dump = {
- "model": "llama3-70b-8192",
- "base_url": "https://api.groq.com",
- "generation_kwargs": {},
- "max_retries": 2,
- "timeout": 120,
- "structured_output": None,
- "type_info": {
- "module": "distilabel.llms.groq",
- "name": "GroqLLM",
- },
- }
+ llm = GroqLLM(model="llama3-70b-8192", structured_output=structured_output)
- assert llm.dump() == _dump
- assert isinstance(GroqLLM.from_dict(_dump), GroqLLM) # type: ignore
+ assert llm.dump() == dump
+ assert isinstance(GroqLLM.from_dict(dump), GroqLLM) # type: ignore
diff --git a/tests/unit/llms/test_llamacpp.py b/tests/unit/llms/test_llamacpp.py
index c69b460ce5..b226d99292 100644
--- a/tests/unit/llms/test_llamacpp.py
+++ b/tests/unit/llms/test_llamacpp.py
@@ -14,11 +14,13 @@
import os
import urllib.request
-from typing import Generator
+from typing import Any, Dict, Generator
import pytest
from distilabel.llms.llamacpp import LlamaCppLLM
+from .utils import DummyUserDetail
+
@pytest.fixture(scope="module")
def llm() -> Generator[LlamaCppLLM, None, None]:
@@ -54,3 +56,62 @@ def test_generate(self, llm: LlamaCppLLM) -> None:
assert len(responses) == 2
assert len(responses[0]) == 3
+
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "chat_format": None,
+ "extra_kwargs": {},
+ "n_batch": 512,
+ "n_ctx": 512,
+ "n_gpu_layers": 0,
+ "seed": 4294967295,
+ "generation_kwargs": {},
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.llamacpp",
+ "name": "LlamaCppLLM",
+ },
+ "verbose": False,
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "format": "json",
+ },
+ {
+ "chat_format": None,
+ "extra_kwargs": {},
+ "n_batch": 512,
+ "n_ctx": 512,
+ "n_gpu_layers": 0,
+ "seed": 4294967295,
+ "generation_kwargs": {},
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "format": "json",
+ },
+ "type_info": {
+ "module": "distilabel.llms.llamacpp",
+ "name": "LlamaCppLLM",
+ },
+ "verbose": False,
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
+ llm = LlamaCppLLM(
+ model_path="tinyllama.gguf",
+ n_gpu_layers=0,
+ structured_output=structured_output,
+ )
+
+ assert llm.dump() == dump
+ assert isinstance(LlamaCppLLM.from_dict(dump), LlamaCppLLM)
diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py
index f31e903d3e..5bb2337481 100644
--- a/tests/unit/llms/test_mistral.py
+++ b/tests/unit/llms/test_mistral.py
@@ -14,11 +14,14 @@
import os
import sys
+from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import nest_asyncio
import pytest
+from .utils import DummyUserDetail
+
try:
from distilabel.llms.mistral import MistralLLM
except ImportError:
@@ -55,6 +58,37 @@ async def test_agenerate(self, mock_mistral: MagicMock) -> None:
]
)
+ @pytest.mark.asyncio
+ async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None:
+ llm = MistralLLM(
+ model="mistral-tiny",
+ api_key="api.key",
+ structured_output={
+ "schema": DummyUserDetail,
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ ) # type: ignore
+ llm._aclient = mock_mistral
+
+ sample_user = DummyUserDetail(name="John Doe", age=30)
+
+ llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user)
+        # Once this is fixed in `instructor` (and then in our code), mocking `_aclient.chat`
+        # directly should be enough:
+ # llm._aclient.chat = AsyncMock(return_value=sample_user)
+
+ generation = await llm.agenerate(
+ input=[
+ {"role": "system", "content": ""},
+ {
+ "role": "user",
+ "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ },
+ ]
+ )
+ assert generation[0] == sample_user.model_dump_json()
+
@pytest.mark.asyncio
async def test_generate(self, mock_mistral: MagicMock) -> None:
llm = MistralLLM(model="mistral-tiny", api_key="api.key") # type: ignore
@@ -79,7 +113,54 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
]
)
- def test_serialization(self, mock_mistral: MagicMock) -> None:
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "model": "mistral-tiny",
+ "endpoint": "https://api.mistral.ai",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "timeout": 120,
+ "max_concurrent_requests": 64,
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.mistral",
+ "name": "MistralLLM",
+ },
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ {
+ "model": "mistral-tiny",
+ "endpoint": "https://api.mistral.ai",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "timeout": 120,
+ "max_concurrent_requests": 64,
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ "type_info": {
+ "module": "distilabel.llms.mistral",
+ "name": "MistralLLM",
+ },
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
os.environ["MISTRAL_API_KEY"] = "api.key"
llm = MistralLLM(model="mistral-tiny") # type: ignore
diff --git a/tests/unit/llms/test_mixins.py b/tests/unit/llms/test_mixins.py
index feb8b00e01..c0c7b10671 100644
--- a/tests/unit/llms/test_mixins.py
+++ b/tests/unit/llms/test_mixins.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import multiprocessing as mp
import os
import sys
from typing import TYPE_CHECKING, Any, Generator, List, Union
@@ -43,6 +42,10 @@ def load(self) -> None:
super().load()
CudaDevicePlacementMixin.load(self)
+ def unload(self) -> None:
+ super().unload()
+ CudaDevicePlacementMixin.unload(self)
+
@property
def model_name(self) -> str:
return "test"
@@ -63,13 +66,7 @@ def test_set_cuda_visible_devices(self) -> None:
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0,1"
- def test_cuda_visible_devices_not_cuda_devices(self) -> None:
- llm = DummyCudaLLM()
- llm._llm_identifier = "unit-test"
-
- llm.load()
-
- assert os.getenv("CUDA_VISIBLE_DEVICES") is None
+ llm.unload()
def test_set_cuda_visible_devices_unvalid_devices(self) -> None:
llm = DummyCudaLLM(cuda_devices=[5, 6])
@@ -80,84 +77,54 @@ def test_set_cuda_visible_devices_unvalid_devices(self) -> None:
):
llm.load()
- def test_set_device_placement_info(self) -> None:
- llm = DummyCudaLLM(cuda_devices="auto")
+ llm.unload()
+
+ def test_set_cuda_visible_devices_auto(self) -> None:
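+        # With automatic placement, each loaded LLM should be assigned the next free CUDA device.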
+ llm1 = DummyCudaLLM()
+ llm1._llm_identifier = "unit-test-1"
+ llm1.load()
- with mp.Manager() as manager:
- llm.set_device_placement_info(
- llm_identifier="unit-test",
- device_llm_placement_map=manager.dict(),
- device_llm_placement_lock=manager.Lock(), # type: ignore
- )
+ assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"
- assert llm._llm_identifier == "unit-test"
- assert llm._device_llm_placement_map is not None
+ llm2 = DummyCudaLLM()
+ llm2._llm_identifier = "unit-test-2"
+ llm2.load()
- def test_set_cuda_visible_devices_auto(self) -> None:
- with mp.Manager() as manager:
- device_llm_placement_map = manager.dict()
- lock = manager.Lock()
-
- llm1 = DummyCudaLLM()
- llm1.set_device_placement_info(
- llm_identifier="unit-test-1",
- device_llm_placement_map=device_llm_placement_map,
- device_llm_placement_lock=lock, # type: ignore
- )
- llm1.load()
-
- assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"
-
- llm2 = DummyCudaLLM()
- llm2.set_device_placement_info(
- llm_identifier="unit-test-2",
- device_llm_placement_map=device_llm_placement_map,
- device_llm_placement_lock=lock, # type: ignore
- )
- llm2.load()
-
- assert os.environ["CUDA_VISIBLE_DEVICES"] == "1"
+ assert os.environ["CUDA_VISIBLE_DEVICES"] == "1"
+
+ llm1.unload()
+ llm2.unload()
def test_set_cuda_visible_devices_auto_not_enough_devices(self) -> None:
- with mp.Manager() as manager:
- device_llm_placement_map = manager.dict()
- lock = manager.Lock()
-
- with pytest.raises(
- RuntimeError, match="Couldn't find an available CUDA device"
- ):
- # 4 devices are available, but 5 LLMs are going to be loaded
- for i in range(5):
- llm = DummyCudaLLM()
- llm.set_device_placement_info(
- llm_identifier=f"unit-test-{i}",
- device_llm_placement_map=device_llm_placement_map,
- device_llm_placement_lock=lock, # type: ignore
- )
- llm.load()
+ llms = []
+ for i in range(5):
+ llm = DummyCudaLLM()
+ llm._llm_identifier = f"unit-test-{i}"
+ llms.append(llm)
+
+ with pytest.raises(
+ RuntimeError, match="Couldn't find an available CUDA device"
+ ):
+ # 4 devices are available, but 5 LLMs are going to be loaded
+ for llm in llms:
+ llm.load()
+
+ for llm in llms:
+ llm.unload()
def test_check_cuda_devices(self, caplog) -> None:
- with mp.Manager() as manager:
- device_llm_placement_map = manager.dict()
- lock = manager.Lock()
-
- llm1 = DummyCudaLLM(cuda_devices=[1])
- llm1.set_device_placement_info(
- llm_identifier="unit-test-1",
- device_llm_placement_map=device_llm_placement_map,
- device_llm_placement_lock=lock, # type: ignore
- )
- llm1.load()
-
- llm2 = DummyCudaLLM(cuda_devices=[1])
- llm2.set_device_placement_info(
- llm_identifier="unit-test-2",
- device_llm_placement_map=device_llm_placement_map,
- device_llm_placement_lock=lock, # type: ignore
- )
- llm2.load()
-
- assert (
- "LLM with identifier 'unit-test-1' is also going to use CUDA device '1'"
- in caplog.text
- )
+ llm1 = DummyCudaLLM(cuda_devices=[1])
+ llm1._llm_identifier = "unit-test-1"
+ llm1.load()
+
+ llm2 = DummyCudaLLM(cuda_devices=[1])
+ llm2._llm_identifier = "unit-test-2"
+ llm2.load()
+
+ assert (
+ "LLM with identifier 'unit-test-1' is also going to use CUDA device '1'"
+ in caplog.text
+ )
+
+ llm1.unload()
+ llm2.unload()
diff --git a/tests/unit/llms/test_moa.py b/tests/unit/llms/test_moa.py
new file mode 100644
index 0000000000..b3a92eded1
--- /dev/null
+++ b/tests/unit/llms/test_moa.py
@@ -0,0 +1,61 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgentsLLM
+
+from tests.unit.conftest import DummyLLM
+
+
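+# `DummyLLM` (from the unit tests' conftest) reports "test" as its model name, which the
+# assertions on the aggregated model name below rely on.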
+class TestMixtureOfAgents:
+ def test_model_name(self) -> None:
+ llm = MixtureOfAgentsLLM(
+ aggregator_llm=DummyLLM(),
+ proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
+ )
+
+ assert llm.model_name == "moa-test-test-test-test"
+
+ def test_build_moa_system_prompt(self) -> None:
+ llm = MixtureOfAgentsLLM(
+ aggregator_llm=DummyLLM(),
+ proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
+ )
+
+ system_prompt = llm._build_moa_system_prompt(
+ prev_outputs=["output1", "output2", "output3"]
+ )
+
+ assert (
+ system_prompt == f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3"
+ )
+
+ def test_inject_moa_system_prompt(self) -> None:
+ llm = MixtureOfAgentsLLM(
+ aggregator_llm=DummyLLM(),
+ proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
+ )
+
+ results = llm._inject_moa_system_prompt(
+ input=[
+ {"role": "system", "content": "I'm a system prompt."},
+ ],
+ prev_outputs=["output1", "output2", "output3"],
+ )
+
+ assert results == [
+ {
+ "role": "system",
+ "content": f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3\n\nI'm a system prompt.",
+ }
+ ]
diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py
index 3562b6588b..7f90f513a2 100644
--- a/tests/unit/llms/test_openai.py
+++ b/tests/unit/llms/test_openai.py
@@ -13,6 +13,8 @@
# limitations under the License.
import os
+import sys
+from typing import Any, Dict
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch
@@ -20,6 +22,8 @@
import pytest
from distilabel.llms.openai import OpenAILLM
+from .utils import DummyUserDetail
+
@patch("openai.AsyncOpenAI")
class TestOpenAILLM:
@@ -63,6 +67,37 @@ async def test_agenerate(self, mock_openai: MagicMock) -> None:
]
)
+ @pytest.mark.asyncio
+ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
+ llm = OpenAILLM(
+ model=self.model_id,
+ api_key="api.key",
+ structured_output={
+ "schema": DummyUserDetail,
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ ) # type: ignore
+ llm._aclient = mock_openai
+
+ sample_user = DummyUserDetail(name="John Doe", age=30)
+
+ llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user)
+
+ generation = await llm.agenerate(
+ input=[
+ {"role": "system", "content": ""},
+ {
+ "role": "user",
+ "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+ },
+ ]
+ )
+ assert generation[0] == sample_user.model_dump_json()
+
+ @pytest.mark.skipif(
+ sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
+ )
@pytest.mark.asyncio
async def test_generate(self, mock_openai: MagicMock) -> None:
llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore
@@ -101,21 +136,53 @@ async def test_generate(self, mock_openai: MagicMock) -> None:
response_format="unkown_format",
)
- def test_serialization(self, _: MagicMock) -> None:
- llm = OpenAILLM(model=self.model_id)
-
- _dump = {
- "model": self.model_id,
- "generation_kwargs": {},
- "max_retries": 6,
- "base_url": "https://api.openai.com/v1",
- "timeout": 120,
- "structured_output": None,
- "type_info": {
- "module": "distilabel.llms.openai",
- "name": "OpenAILLM",
- },
- }
-
- assert llm.dump() == _dump
- assert isinstance(OpenAILLM.from_dict(_dump), OpenAILLM)
+ @pytest.mark.parametrize(
+ "structured_output, dump",
+ [
+ (
+ None,
+ {
+ "model": "gpt-4",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "base_url": "https://api.openai.com/v1",
+ "timeout": 120,
+ "structured_output": None,
+ "type_info": {
+ "module": "distilabel.llms.openai",
+ "name": "OpenAILLM",
+ },
+ },
+ ),
+ (
+ {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ {
+ "model": "gpt-4",
+ "generation_kwargs": {},
+ "max_retries": 6,
+ "base_url": "https://api.openai.com/v1",
+ "timeout": 120,
+ "structured_output": {
+ "schema": DummyUserDetail.model_json_schema(),
+ "mode": "tool_call",
+ "max_retries": 1,
+ },
+ "type_info": {
+ "module": "distilabel.llms.openai",
+ "name": "OpenAILLM",
+ },
+ },
+ ),
+ ],
+ )
+ def test_serialization(
+ self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
+ ) -> None:
+ llm = OpenAILLM(model=self.model_id, structured_output=structured_output)
+
+ assert llm.dump() == dump
+ assert isinstance(OpenAILLM.from_dict(dump), OpenAILLM)
diff --git a/tests/unit/llms/test_vertexai.py b/tests/unit/llms/test_vertexai.py
index b15262e26c..5d3f8d1217 100644
--- a/tests/unit/llms/test_vertexai.py
+++ b/tests/unit/llms/test_vertexai.py
@@ -115,7 +115,6 @@ def test_serialization(self, _: MagicMock) -> None:
_dump = {
"model": "gemini-1.0-pro",
"generation_kwargs": {},
- "structured_output": None,
"type_info": {
"module": "distilabel.llms.vertexai",
"name": "VertexAILLM",
diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py
new file mode 100644
index 0000000000..4c847aad8e
--- /dev/null
+++ b/tests/unit/llms/test_vllm.py
@@ -0,0 +1,170 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import numpy as np
+import pytest
+from distilabel.llms import vLLM
+from distilabel.llms.vllm import _sort_batches
+from pydantic import BaseModel
+
+
+class Character(BaseModel):
+ name: str
+ description: str
+ role: str
+ weapon: str
+
+
+class Animal(BaseModel):
+ name: str
+ species: str
+ habitat: str
+ diet: str
+
+
+SAMPLE_DATA = [
+ [
+ {
+ "instruction": "Generate a character from a RPG game.",
+ "structured_output": {
+ "format": "json",
+ "schema": Character.model_json_schema(),
+ },
+ },
+ {
+ "instruction": "Generate an animal from a zoo.",
+ "structured_output": {
+ "format": "json",
+ "schema": Animal.model_json_schema(),
+ },
+ },
+ {
+ "instruction": "Repeated character",
+ "structured_output": {
+ "format": "json",
+ "schema": Character.model_json_schema(),
+ },
+ },
+ {
+ "instruction": "What's the weather like today in Seattle in Celsius degrees?",
+ "structured_output": {
+ "format": "regex",
+ "schema": "(\\d{1,2})°C",
+ },
+ },
+ {
+ "instruction": "Other character",
+ "structured_output": {
+ "format": "json",
+ "schema": Character.model_json_schema(),
+ },
+ },
+ {
+ "instruction": "repeated regex",
+ "structured_output": {
+ "format": "regex",
+ "schema": "(\\d{1,2})°C",
+ },
+ },
+ ]
+]
+
+
+# Just a mock to avoid loading the model
+class DummyTokenizer:
+ def __init__(self) -> None:
+ pass
+
+ def apply_chat_template(self, input, **kwargs):
+ return input
+
+
+class TestvLLM:
+ @pytest.mark.parametrize(
+ "num_generations, expected_sorted_batches",
+ [
+ (
+ 1,
+ [
+ "Generate a character from a RPG game.",
+ "Generate an animal from a zoo.",
+ "Repeated character",
+ "What's the weather like today in Seattle in Celsius degrees?",
+ "Other character",
+ "repeated regex",
+ ],
+ ),
+ (
+ 3,
+ np.repeat(
+ [
+ "Generate a character from a RPG game.",
+ "Generate an animal from a zoo.",
+ "Repeated character",
+ "What's the weather like today in Seattle in Celsius degrees?",
+ "Other character",
+ "repeated regex",
+ ],
+ 3,
+ ).tolist(),
+ ),
+ ],
+ )
+ def test_prepare_batches_and_sort_back(
+ self, num_generations: int, expected_sorted_batches: List[str]
+ ):
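+        # Flatten the sample rows into (instruction, structured_output) tuples before batching.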
+ formatted_inputs = [
+ (item["instruction"], item["structured_output"])
+ for row in SAMPLE_DATA
+ for item in row
+ ]
+ llm = vLLM(model="dummy")
+ llm._tokenizer = DummyTokenizer()
+ batches, indices = llm._prepare_batches(formatted_inputs)
+ # NOTE: We have to simulate calling self._model.generate(n=num_generations) and then sorting the results
+ num_generations_batches = []
+ for batch in batches:
+ num_generations_batches.append(
+ (np.repeat(batch[0], num_generations).tolist(), batch[1])
+ )
+ batches = num_generations_batches
+        # Recreate the structure produced by: batched_outputs += [[output.text for output in outputs.outputs] for outputs in batch_outputs]
+ batches = [batch for batch, _ in batches]
+ sorted_batches = _sort_batches(
+ batches, indices, num_generations=num_generations
+ )
+
+ assert sorted_batches == [
+ np.repeat(
+ [
+ "Generate a character from a RPG game.",
+ "Generate an animal from a zoo.",
+ "Repeated character",
+ ],
+ num_generations,
+ ).tolist(),
+ np.repeat(
+ ["What's the weather like today in Seattle in Celsius degrees?"],
+ num_generations,
+ ).tolist(),
+ np.repeat(
+ [
+ "Other character",
+ "repeated regex",
+ ],
+ num_generations,
+ ).tolist(),
+ ]
diff --git a/tests/unit/llms/utils.py b/tests/unit/llms/utils.py
new file mode 100644
index 0000000000..7b899253bb
--- /dev/null
+++ b/tests/unit/llms/utils.py
@@ -0,0 +1,20 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import BaseModel
+
+
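+# Minimal Pydantic model used as the `schema` in the structured output tests.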
+class DummyUserDetail(BaseModel):
+ name: str
+ age: int
diff --git a/tests/unit/mixins/test_runtime_parameters.py b/tests/unit/mixins/test_runtime_parameters.py
index 82cac59b7e..8e8d7766d0 100644
--- a/tests/unit/mixins/test_runtime_parameters.py
+++ b/tests/unit/mixins/test_runtime_parameters.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import List, Optional
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
@@ -32,6 +32,7 @@ class DummyNestedClass(RuntimeParametersMixin):
class DummyClass(RuntimeParametersMixin):
nested_class: DummyNestedClass
+ mixins_list: List[DummyNestedClass]
runtime_param1: RuntimeParameter[SecretStr] = Field(
default=None, description="Runtime param 1"
@@ -43,7 +44,10 @@ class DummyClass(RuntimeParametersMixin):
class TestRuntimeParametersMixin:
def test_runtime_parameters_names(self) -> None:
- dummy = DummyClass(nested_class=DummyNestedClass())
+ dummy = DummyClass(
+ nested_class=DummyNestedClass(),
+ mixins_list=[DummyNestedClass(), DummyNestedClass(), DummyNestedClass()],
+ )
assert dummy.runtime_parameters_names == {
"runtime_param1": False,
@@ -52,10 +56,27 @@ def test_runtime_parameters_names(self) -> None:
"runtime_param1": False,
"runtime_param2": True,
},
+ "mixins_list": {
+ "0": {
+ "runtime_param1": False,
+ "runtime_param2": True,
+ },
+ "1": {
+ "runtime_param1": False,
+ "runtime_param2": True,
+ },
+ "2": {
+ "runtime_param1": False,
+ "runtime_param2": True,
+ },
+ },
}
def test_get_runtime_parameters_info(self) -> None:
- dummy = DummyClass(nested_class=DummyNestedClass())
+ dummy = DummyClass(
+ nested_class=DummyNestedClass(),
+ mixins_list=[DummyNestedClass(), DummyNestedClass(), DummyNestedClass()],
+ )
assert dummy.get_runtime_parameters_info() == [
{
@@ -73,6 +94,47 @@ def test_get_runtime_parameters_info(self) -> None:
},
],
},
+ {
+ "name": "mixins_list",
+ "runtime_parameters_info": {
+ "0": [
+ {
+ "name": "runtime_param1",
+ "description": "Runtime param 1",
+ "optional": False,
+ },
+ {
+ "name": "runtime_param2",
+ "description": "Runtime param 2",
+ "optional": True,
+ },
+ ],
+ "1": [
+ {
+ "name": "runtime_param1",
+ "description": "Runtime param 1",
+ "optional": False,
+ },
+ {
+ "name": "runtime_param2",
+ "description": "Runtime param 2",
+ "optional": True,
+ },
+ ],
+ "2": [
+ {
+ "name": "runtime_param1",
+ "description": "Runtime param 1",
+ "optional": False,
+ },
+ {
+ "name": "runtime_param2",
+ "description": "Runtime param 2",
+ "optional": True,
+ },
+ ],
+ },
+ },
{
"name": "runtime_param1",
"description": "Runtime param 1",
@@ -86,7 +148,10 @@ def test_get_runtime_parameters_info(self) -> None:
]
def test_set_runtime_parameters(self) -> None:
- dummy = DummyClass(nested_class=DummyNestedClass())
+ dummy = DummyClass(
+ nested_class=DummyNestedClass(),
+ mixins_list=[DummyNestedClass(), DummyNestedClass(), DummyNestedClass()],
+ )
dummy.set_runtime_parameters(
{
diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py
index 76a48ec4be..c18a30e143 100644
--- a/tests/unit/pipeline/test_base.py
+++ b/tests/unit/pipeline/test_base.py
@@ -12,33 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
import tempfile
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from queue import Queue
+from typing import Any, Callable, Dict, List, Optional
from unittest import mock
import pytest
-from distilabel.distiset import Distiset, create_distiset
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.pipeline._dag import DAG
from distilabel.pipeline.base import (
+ _STEP_LOAD_FAILED_CODE,
+ _STEP_NOT_LOADED_CODE,
BasePipeline,
- _Batch,
- _BatchManager,
- _BatchManagerStep,
_GlobalPipelineManager,
- _WriteBuffer,
)
-from distilabel.pipeline.local import Pipeline
-from distilabel.steps.base import GlobalStep, Step, StepInput
+from distilabel.pipeline.batch import _Batch
+from distilabel.pipeline.batch_manager import _BatchManager
+from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME, LAST_BATCH_SENT_FLAG
+from distilabel.pipeline.routing_batch_function import (
+ routing_batch_function,
+ sample_n_steps,
+)
+from distilabel.pipeline.write_buffer import _WriteBuffer
+from distilabel.steps.base import Step, StepInput, _Step
+from distilabel.steps.typing import StepOutput
from distilabel.utils.serialization import TYPE_INFO_KEY
+from fsspec.implementations.local import LocalFileSystem
from pydantic import Field
+from upath import UPath
+
+from .utils import (
+ DummyGeneratorStep,
+ DummyGlobalStep,
+ DummyStep1,
+ DummyStep2,
+)
+
-from .utils import DummyGeneratorStep, DummyStep1, DummyStep2, batch_gen
+class DummyPipeline(BasePipeline):
+ @property
+ def QueueClass(self) -> Callable:
+ return Queue
-if TYPE_CHECKING:
- from distilabel.steps.base import GeneratorStep
+ def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None:
+ pass
+
+ def _stop(self) -> None:
+ pass
class TestGlobalPipelineManager:
@@ -46,7 +68,7 @@ def teardown_method(self) -> None:
_GlobalPipelineManager.set_pipeline(None)
def test_set_pipeline(self) -> None:
- pipeline = BasePipeline(name="unit-test-pipeline")
+ pipeline = DummyPipeline(name="unit-test-pipeline")
_GlobalPipelineManager.set_pipeline(pipeline)
assert _GlobalPipelineManager.get_pipeline() == pipeline
@@ -55,7 +77,7 @@ def test_set_pipeline_none(self) -> None:
assert _GlobalPipelineManager.get_pipeline() is None
def test_get_pipeline(self) -> None:
- pipeline = BasePipeline(name="unit-test-pipeline")
+ pipeline = DummyPipeline(name="unit-test-pipeline")
_GlobalPipelineManager.set_pipeline(pipeline)
assert _GlobalPipelineManager.get_pipeline() == pipeline
@@ -64,1906 +86,832 @@ class TestBasePipeline:
def test_context_manager(self) -> None:
assert _GlobalPipelineManager.get_pipeline() is None
- with BasePipeline(name="unit-test-pipeline") as pipeline:
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
assert pipeline is not None
assert _GlobalPipelineManager.get_pipeline() == pipeline
assert _GlobalPipelineManager.get_pipeline() is None
- def test_get_runtime_parameters_info(self) -> None:
- class DummyStep1(Step):
- runtime_param1: RuntimeParameter[str] = Field(
- default=None, description="runtime_param1 description"
- )
- runtime_param2: Optional[RuntimeParameter[str]] = Field(
- default=None, description="runtime_param2 description"
+ @pytest.mark.parametrize("use_cache", [False, True])
+ def test_load_batch_manager(self, use_cache: bool) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+ pipeline._load_batch_manager(use_cache=True)
+ pipeline._cache()
+
+ with mock.patch(
+ "distilabel.pipeline.base._BatchManager.load_from_cache"
+ ) as mock_load_from_cache, mock.patch(
+ "distilabel.pipeline.base._BatchManager.from_dag"
+ ) as mock_from_dag:
+ pipeline._load_batch_manager(use_cache=use_cache)
+
+ if use_cache:
+ mock_load_from_cache.assert_called_once_with(
+ pipeline._cache_location["batch_manager"]
)
+ mock_from_dag.assert_not_called()
+ else:
+ mock_load_from_cache.assert_not_called()
+ mock_from_dag.assert_called_once_with(pipeline.dag)
- def process(self, inputs: StepInput) -> None:
- pass
+ def test_setup_write_buffer(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
- class DummyStep2(Step):
- runtime_param3: RuntimeParameter[str] = Field(
- default=None, description="runtime_param3 description"
- )
- runtime_param4: Optional[RuntimeParameter[str]] = Field(
- default=None, description="runtime_param4 description"
- )
+ pipeline._setup_write_buffer()
+ assert isinstance(pipeline._write_buffer, _WriteBuffer)
- def process(self, inputs: StepInput) -> None:
- pass
+ def test_set_logging_parameters(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+ pipeline._set_logging_parameters({"unit-test": "yes"})
- with BasePipeline(name="unit-test-pipeline") as pipeline:
- DummyStep1(name="dummy_step_1")
- DummyStep2(name="dummy_step_2")
+ assert pipeline._logging_parameters == {"unit-test": "yes"}
- assert pipeline.get_runtime_parameters_info() == {
- "dummy_step_1": [
- {
- "description": "The number of rows that will contain the batches processed by the "
- "step.",
- "name": "input_batch_size",
- "optional": True,
- },
- {
- "name": "runtime_param1",
- "description": "runtime_param1 description",
- "optional": False,
- },
- {
- "name": "runtime_param2",
- "description": "runtime_param2 description",
- "optional": True,
- },
- ],
- "dummy_step_2": [
- {
- "description": "The number of rows that will contain the batches processed by the "
- "step.",
- "name": "input_batch_size",
- "optional": True,
- },
- {
- "name": "runtime_param3",
- "description": "runtime_param3 description",
- "optional": False,
- },
- {
- "name": "runtime_param4",
- "description": "runtime_param4 description",
- "optional": True,
- },
- ],
+ def test_setup_fsspec(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+
+ with mock.patch("fsspec.filesystem") as mock_filesystem:
+ pipeline._setup_fsspec({"path": "gcs://my-bucket", "extra": "stuff"})
+
+ mock_filesystem.assert_called_once_with("gcs", **{"extra": "stuff"})
+
+ def test_setup_fsspec_default(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+ pipeline._setup_fsspec()
+
+ assert isinstance(pipeline._fs, LocalFileSystem)
+ assert (
+ pipeline._storage_base_path
+ == f"file://{pipeline._cache_location['batch_input_data']}"
+ )
+
+ def test_setup_fsspec_raises_value_error(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+
+ with pytest.raises(ValueError, match="The 'path' key must be present"):
+ pipeline._setup_fsspec({"key": "random"})
+
+ def test_init_steps_load_status(self) -> None:
+ with DummyPipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ pipeline._init_steps_load_status()
+ assert pipeline._steps_load_status == {
+ generator.name: _STEP_NOT_LOADED_CODE,
+ step.name: _STEP_NOT_LOADED_CODE,
+ step2.name: _STEP_NOT_LOADED_CODE,
+ step3.name: _STEP_NOT_LOADED_CODE,
}
- # Test no log, Test log, test log without close match
- @pytest.mark.parametrize(
- "parameters, expected",
- (
- (
- {
- "dummy_step_1": {"runtime_param1": "value1"},
- "dummy_step_2": {"runtime_param3": "value1"},
- },
- "",
- ),
- (
- {
- "dummy_step_1": {"runtime_param1": "value1"},
- "dummy_step_2": {
- "runtime_param3": "value1",
- "runtime_param_unknown": "value1",
- },
- },
- "Did you mean any of:",
- ),
- (
- {
- "dummy_step_1": {"runtime_param1": "value1"},
- "dummy_step_2": {
- "runtime_param3": "value1",
- "weird_name": "value1",
- },
- },
- "Available runtime parameters for the step",
- ),
- ),
- )
- def test_check_runtime_parameters(
- self, caplog, parameters: Dict[str, Any], expected: str
- ) -> None:
- class DummyStep1(Step):
- runtime_param1: RuntimeParameter[str] = Field(
- default=None, description="runtime_param1 description"
- )
- runtime_param2: Optional[RuntimeParameter[str]] = Field(
- default=None, description="runtime_param2 description"
- )
+ def test_run_load_queue_loop(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
- def process(self, inputs: StepInput) -> None:
- pass
+ pipeline._load_queue = Queue()
+ pipeline._steps_load_status = {"dummy": 0}
+ pipeline._load_queue.put({"name": "dummy", "status": "loaded"})
- class DummyStep2(Step):
- runtime_param3: RuntimeParameter[str] = Field(
- default=None, description="runtime_param3 description"
- )
- runtime_param4: Optional[RuntimeParameter[str]] = Field(
- default=None, description="runtime_param4 description"
- )
+ thread = pipeline._run_load_queue_loop_in_thread()
+ pipeline._load_queue.put(None)
+ thread.join()
- def process(self, inputs: StepInput) -> None:
- pass
+ assert pipeline._steps_load_status["dummy"] == 1
- with BasePipeline(name="unit-test-pipeline") as pipeline:
- gen_step = DummyGeneratorStep(name="dummy_generator_step")
- step1 = DummyStep1(name="dummy_step_1")
- step2 = DummyStep2(name="dummy_step_2")
+ def test_run_load_queue_loop_receiving_none(self) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
- gen_step >> step1 >> step2
+ pipeline._load_queue = Queue()
+ pipeline._load_queue.put(None)
- pipeline.run(parameters=parameters)
- if expected:
- assert expected in caplog.text
- else:
- assert caplog.text == expected
+ thread = pipeline._run_load_queue_loop_in_thread()
+ thread.join()
- def test_cache_dir_env_variable(self) -> None:
- with mock.patch.dict(os.environ, clear=True):
- os.environ["DISTILABEL_CACHE_DIR"] = "/tmp/unit-test"
- pipeline = BasePipeline(name="unit-test-pipeline")
- assert pipeline._cache_dir == Path("/tmp/unit-test")
+ assert not thread.is_alive()
- @pytest.mark.parametrize(
- "in_pipeline, names",
- (
- (
- True,
- [
- "dummy_generator_step_0",
- "dummy_step1_0",
- "dummy_step2_0",
- "dummy_step1_1",
- ],
- ),
- # TODO: Activate this test once we merge the option of not passing a Pipeline
- # (
- # False, ["dummy_generator_step", "dummy_step1", "dummy_step2"]
- # )
- ),
- )
- def test_step_names_inferred(self, in_pipeline: bool, names: List[str]) -> None:
- if in_pipeline:
- with BasePipeline(name="unit-test-pipeline"):
- gen_step = DummyGeneratorStep()
- step1_0 = DummyStep1()
- step2 = DummyStep2()
- step1_1 = DummyStep1()
+ def test_all_steps_loaded(self, caplog) -> None:
+ with DummyPipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
- gen_step >> step1_0 >> step2 >> step1_1
- else:
- gen_step = DummyGeneratorStep()
- step1_0 = DummyStep1()
- step2 = DummyStep2()
- step1_1 = DummyStep1()
+ generator >> [step, step2] >> step3
- assert gen_step.name == names[0]
- assert step1_0.name == names[1]
- assert step2.name == names[2]
- assert step1_1.name == names[3]
+ pipeline._steps_load_status = { # type: ignore
+ generator.name: 1,
+ step.name: 1,
+ step2.name: 1,
+ step3.name: 1,
+ }
+ caplog.set_level(logging.INFO)
- def test_infer_step_names_big_pipeline(self) -> None:
- # Tests that the name of the steps are inferred correctly when the pipeline is big (say 50 steps).
- with BasePipeline(name="unit-test-pipeline") as pipe:
- gen_step = DummyGeneratorStep()
- for _ in range(50):
- gen_step.connect(DummyStep1())
- assert list(pipe.dag.G)[-1] == "dummy_step1_49"
+ assert pipeline._all_steps_loaded() is True
+ assert "All the steps have been loaded!" in caplog.text
+ def test_all_steps_loaded_with_failing_step(self, caplog) -> None:
+ with DummyPipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
-class TestBatch:
- def test_set_data(self) -> None:
- batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
- data = [[{"i": i} for i in range(5000)]]
- batch.set_data(data)
+ generator >> [step, step2] >> step3
- assert batch.data == data
- assert batch.size == 5000
+ pipeline._init_steps_load_status()
+ pipeline._steps_load_status[generator.name] = _STEP_LOAD_FAILED_CODE # type: ignore
+ caplog.set_level(logging.INFO)
- def test_next_batch(self) -> None:
- batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
- next_batch = batch.next_batch()
+ assert pipeline._all_steps_loaded() is False
+ assert "Failed to load all the steps" in caplog.text
- assert next_batch == _Batch(seq_no=1, step_name="step1", last_batch=False)
+    def test_all_steps_loaded_stop_called(self) -> None:
+ with DummyPipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
- def test_accumulate(self) -> None:
- batches = [
- [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
- ),
- _Batch(
- seq_no=1,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 4}, {"a": 5}, {"a": 6}]],
- ),
- ],
+ generator >> [step, step2] >> step3
+
+ pipeline._init_steps_load_status()
+ pipeline._stop_called = True
+
+ assert pipeline._all_steps_loaded() is False
+
+ def test_handle_stop(self) -> None:
+ with DummyPipeline(name="dummy") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ pipeline._add_batches_back_to_batch_manager = mock.MagicMock()
+ pipeline._wait_step_input_queue_empty = mock.MagicMock()
+ pipeline._consume_output_queue = mock.MagicMock()
+
+ pipeline._handle_stop()
+
+ pipeline._add_batches_back_to_batch_manager.assert_called_once()
+ pipeline._wait_step_input_queue_empty.assert_has_calls(
[
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
- ),
- _Batch(
- seq_no=1,
- step_name="step2",
- last_batch=True,
- data=[[{"b": 4}, {"b": 5}, {"b": 6}]],
- ),
+ mock.call(generator.name),
+ mock.call(step.name),
+ mock.call(step2.name),
+ mock.call(step3.name),
],
- ]
+ any_order=True,
+ )
+ pipeline._consume_output_queue.assert_called_once()
- batch = _Batch.accumulate("step3", batches)
+ @pytest.mark.parametrize(
+ "num_workers,expected", [(0, True), (_STEP_LOAD_FAILED_CODE, True), (1, False)]
+ )
+ def test_check_step_not_loaded_or_finished(
+ self, num_workers: int, expected: bool
+ ) -> None:
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+ pipeline._steps_load_status = {"dummy": num_workers}
- assert batch.seq_no == 0
- assert batch.step_name == "step3"
- assert batch.last_batch is True
- assert batch.data == [
- [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
- [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}],
- ]
+ assert pipeline._check_step_not_loaded_or_finished("dummy") is expected
- def test_dump(self) -> None:
- batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
- assert batch.dump() == {
- "seq_no": 0,
- "size": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [],
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- "type_info": {"module": "distilabel.pipeline.base", "name": "_Batch"},
- }
+ def test_is_convergence_step(self) -> None:
+ sample_two_steps = sample_n_steps(2)
- batch = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
- accumulated=False,
- created_from={"step0": [0, 1]},
- batch_routed_to=["step2", "step3"],
- )
- assert batch.dump() == {
- "seq_no": 0,
- "size": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [[{"a": 1}, {"a": 2}, {"a": 3}]],
- "accumulated": False,
- "created_from": {"step0": [0, 1]},
- "batch_routed_to": ["step2", "step3"],
- "type_info": {"module": "distilabel.pipeline.base", "name": "_Batch"},
- }
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
- def test_from_dict(self) -> None:
+ generator >> sample_two_steps >> [step, step2] >> step3
+
+ pipeline.dag.validate()
+
+ assert not pipeline._is_convergence_step(generator.name) # type: ignore
+ assert not pipeline._is_convergence_step(step.name) # type: ignore
+ assert not pipeline._is_convergence_step(step2.name) # type: ignore
+ assert pipeline._is_convergence_step(step3.name) # type: ignore
+
+ def test_create_step_input_queue(self) -> None:
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+
+ generator >> step
+
+ generator_name: str = generator.name # type: ignore
+ input_queue = pipeline._create_step_input_queue(generator_name)
+ assert isinstance(input_queue, Queue)
assert isinstance(
- _Batch.from_dict(
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": False,
- "data": [[{"a": 1}, {"a": 2}, {"a": 3}]],
- "accumulated": False,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_Batch",
- },
- }
- ),
- _Batch,
+ pipeline.dag.get_step(generator_name)[INPUT_QUEUE_ATTR_NAME], Queue
)
+ def test_run_steps(self) -> None:
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+
+ generator >> step
+
+ pipeline._create_step_input_queue = mock.MagicMock()
+ pipeline._run_step = mock.MagicMock()
+ pipeline._run_steps()
-class TestBatchManagerStep:
- def test_add_batch(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []}
+ pipeline._create_step_input_queue.assert_has_calls(
+ [
+ mock.call(step_name=step.name),
+ mock.call(step_name=generator.name),
+ ],
+ any_order=True,
)
- batch = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ pipeline._run_step.assert_has_calls(
+ [
+ mock.call(step=mock.ANY, input_queue=mock.ANY),
+ mock.call(step=mock.ANY, input_queue=mock.ANY),
+ ]
)
- batch_manager_step.add_batch(batch)
+ def test_add_batches_back_to_batch_manager(self) -> None:
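+ # Batches still sitting in the steps' input queues when the pipeline stops
+ # should be added back to the batch manager as built batches.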
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
- assert batch_manager_step.data["step1"] == [batch]
- assert batch_manager_step.last_batch_received == []
+ generator >> step
- def test_add_batch_with_prepend(self) -> None:
- batch_1 = _Batch(
- seq_no=1,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]],
- )
- batch_manager_step = _BatchManagerStep(
- step_name="step2",
- accumulate=False,
- input_batch_size=10,
- data={"step1": [batch_1]},
+ generator_name: str = generator.name # type: ignore
+ step_name: str = step.name # type: ignore
+
+ pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag)
+ generator_queue = Queue()
+ pipeline.dag.set_step_attr(
+ generator_name, INPUT_QUEUE_ATTR_NAME, generator_queue
)
+ step_queue = Queue()
+ pipeline.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, step_queue)
- batch_0 = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ generator_queue.put(
+ _Batch(seq_no=0, step_name=generator_name, last_batch=False)
+ )
+ generator_queue.put(
+ _Batch(seq_no=1, step_name=generator_name, last_batch=False)
)
- batch_manager_step.add_batch(batch_0, prepend=True)
- assert batch_manager_step.data["step1"] == [batch_0, batch_1]
- assert batch_manager_step.last_batch_received == []
+ step_batch_0 = _Batch(seq_no=0, step_name=step_name, last_batch=False)
+ step_batch_1 = _Batch(seq_no=0, step_name=step_name, last_batch=False)
+ step_queue.put(step_batch_0)
+ step_queue.put(step_batch_1)
- def test_add_batch_last_batch(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []}
- )
+ pipeline._add_batches_back_to_batch_manager()
- batch = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
- )
+ assert pipeline._batch_manager._steps[step_name].built_batches == [
+ step_batch_0,
+ step_batch_1,
+ ]
- batch_manager_step.add_batch(batch)
-
- assert batch_manager_step.data["step1"] == [batch]
- assert batch_manager_step.last_batch_received == ["step1"]
-
- def test_get_batch(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=2,
- data={
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- size=5,
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- ]
- ],
- size=5,
- )
- ],
- },
- )
+ def test_consume_output_queue(self) -> None:
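+ # Only batches coming from leaf steps should be written to the write buffer,
+ # while every batch taken from the output queue goes through the on-stop handler.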
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
- batch = batch_manager_step.get_batch()
+ generator >> step
- assert batch == _Batch(
- step_name="step3",
- seq_no=0,
- last_batch=False,
- data=[
- [
- {"a": 1},
- {"a": 2},
- ],
- [
- {"b": 1},
- {"b": 2},
- ],
- ],
- created_from={"step1": [(0, 5)], "step2": [(0, 5)]},
- )
+ pipeline._output_queue = Queue()
+ pipeline._write_buffer = mock.MagicMock()
+ pipeline._handle_batch_on_stop = mock.MagicMock()
- batch = batch_manager_step.get_batch()
+ generator_name: str = generator.name # type: ignore
+ step_name: str = step.name # type: ignore
- assert batch == _Batch(
- step_name="step3",
- seq_no=1,
- last_batch=False,
- data=[
- [
- {"a": 3},
- {"a": 4},
- ],
- [
- {"b": 3},
- {"b": 4},
- ],
- ],
- created_from={"step1": [(0, 5)], "step2": [(0, 5)]},
- )
+ generator_batch = _Batch(seq_no=0, step_name=generator_name, last_batch=False)
+ step_batch = _Batch(seq_no=0, step_name=step_name, last_batch=False)
- def test_get_batches_accumulate(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=True,
- data={
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ]
- ],
- size=5,
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=True,
- data=[
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- ]
- ],
- size=6,
- )
- ],
- },
- last_batch_received=["step1", "step2"],
- )
+ pipeline._output_queue.put(generator_batch)
+ pipeline._output_queue.put(step_batch)
- batch = batch_manager_step.get_batch()
+ pipeline._consume_output_queue()
- assert batch == _Batch(
- step_name="step3",
- seq_no=0,
- last_batch=True,
- accumulated=True,
- data=[
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- ],
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- ],
- ],
- created_from={"step1": [(0, 5)], "step2": [(0, 6)]},
+ pipeline._write_buffer.add_batch.assert_called_once_with(step_batch)
+ pipeline._handle_batch_on_stop.assert_has_calls(
+ [
+ mock.call(generator_batch),
+ mock.call(step_batch),
+ ]
)
- def test_get_batches_not_enough_data(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=2,
- data={
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[
- [
- {"a": 1},
- ]
- ],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[
- [
- {"b": 1},
- {"b": 2},
- ]
- ],
- )
- ],
- },
- )
+ def test_send_batch_to_step(self) -> None:
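+ # Batch data should only be written to the filesystem when the destination is a
+ # global step, or for any non-empty batch once `_use_fs_to_pass_data` is enabled.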
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ global_step = DummyGlobalStep()
- assert batch_manager_step.get_batch() is None
+ generator >> [step, global_step]
- def test_from_step(self, dummy_step_1: "Step") -> None:
- batch_manager_step = _BatchManagerStep.from_step(
- step=dummy_step_1, predecessors=["step1", "step2"]
- )
+ pipeline._batch_manager = mock.MagicMock()
+ pipeline._send_to_step = mock.MagicMock()
+ pipeline._setup_fsspec()
- assert batch_manager_step.step_name == "dummy_step_1"
- assert batch_manager_step.accumulate is False
- assert batch_manager_step.input_batch_size == 50
- assert batch_manager_step.data == {"step1": [], "step2": []}
- assert batch_manager_step.seq_no == 0
- assert batch_manager_step.last_batch_received == []
+ with mock.patch(
+ "distilabel.pipeline.base._Batch.write_batch_data_to_fs"
+ ) as mock_write:
+ batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ pipeline._send_batch_to_step(batch)
+ pipeline._batch_manager.set_last_batch_sent.assert_called_once_with(batch)
- def test_from_step_with_global_step(self, dummy_global_step: "GlobalStep") -> None:
- batch_manager_step = _BatchManagerStep.from_step(
- step=dummy_global_step, predecessors=["step1", "step2"]
- )
+ pipeline._send_batch_to_step(
+ _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore
+ )
- assert batch_manager_step.step_name == "dummy_global_step"
- assert batch_manager_step.accumulate is True
- assert batch_manager_step.input_batch_size == 50
- assert batch_manager_step.data == {"step1": [], "step2": []}
- assert batch_manager_step.seq_no == 0
- assert batch_manager_step.last_batch_received == []
+ # `write_batch_data_to_fs` shouldn't have been called because the last batch sent
+ # with `_send_batch_to_step` came from a non-global step.
+ mock_write.assert_not_called()
- def test_get_seq_no(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step2", accumulate=False, input_batch_size=5, data={"step1": []}
- )
+ with mock.patch(
+ "distilabel.pipeline.base._Batch.write_batch_data_to_fs"
+ ) as mock_write:
+ pipeline._send_batch_to_step(
+ _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore
+ )
- seq_no = batch_manager_step._get_seq_no()
-
- assert seq_no == 0
- assert batch_manager_step.seq_no == 1
-
- def test_get_data(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=5,
- data={
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[
- [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]
- ],
- size=6,
- batch_routed_to=["step1", "step2"],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- size=7,
- batch_routed_to=["step1", "step2"],
- )
- ],
- },
+ # `write_batch_data_to_fs` should have been called because the last batch sent
+ # with `_send_batch_to_step` came from a global step.
+ mock_write.assert_called_once_with(
+ pipeline._fs,
+ UPath(pipeline._storage_base_path) / global_step.name,
)
- data, created_from, routed_to = batch_manager_step._get_data()
- assert data == [
- [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}],
- [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}],
- ]
- assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
- assert routed_to == ["step1", "step2"]
-
- assert batch_manager_step.data == {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 6}]],
- size=6,
- batch_routed_to=["step1", "step2"],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 6}, {"b": 7}]],
- size=7,
- batch_routed_to=["step1", "step2"],
- )
- ],
- }
+ pipeline._use_fs_to_pass_data = True
- def test_get_data_accumulate(self) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=True,
- data={
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[
- [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]
- ],
- size=6,
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- size=7,
- )
- ],
- },
+ with mock.patch(
+ "distilabel.pipeline.base._Batch.write_batch_data_to_fs"
+ ) as mock_write:
+ pipeline._send_batch_to_step(
+ _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ )
+
+ # `write_batch_data_to_fs` shouldn't have been called because the generator
+ # receives empty batches, so there's no data to write.
+ mock_write.assert_not_called()
+
+ with mock.patch(
+ "distilabel.pipeline.base._Batch.write_batch_data_to_fs"
+ ) as mock_write:
+ pipeline._send_batch_to_step(
+ _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore
+ )
+ pipeline._send_batch_to_step(
+ _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore
+ )
+
+ mock_write.assert_has_calls(
+ [
+ mock.call(
+ pipeline._fs,
+ UPath(pipeline._storage_base_path) / step.name,
+ ),
+ mock.call(
+ pipeline._fs,
+ UPath(pipeline._storage_base_path) / global_step.name,
+ ),
+ ]
)
- data, created_from, routed_to = batch_manager_step._get_data()
+ def test_register_batch(self) -> None:
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
- assert data == [
- [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
- [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}],
- ]
- assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
- assert routed_to == []
+ generator >> step
- assert batch_manager_step.data == {"step1": [], "step2": []}
+ pipeline._batch_manager = mock.MagicMock()
+ batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ pipeline._register_batch(batch)
- def test_get_data_convergence_step(self) -> None:
- batch_a_0 = _Batch(
- seq_no=0,
- step_name="A",
- last_batch=False,
- data=[
- [
- {"generation": "Hello, I'm A 0"},
- {"generation": "Hello, I'm A 0"},
- {"generation": "Hello, I'm A 0"},
- ]
- ],
- size=3,
- created_from={"Z": [(0, 3)]},
+ pipeline._batch_manager.register_batch.assert_called_once_with(batch)
+
+ def test_send_last_batch_flag_to_step(self) -> None:
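+ # The last batch flag should only be sent if the step hasn't already been sent
+ # a batch marked as the last one.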
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+
+ generator >> step
+
+ step_name: str = step.name # type: ignore
+
+ pipeline._batch_manager = _BatchManager(
+ steps={},
+ last_batch_received={step_name: None},
+ last_batch_sent={step_name: None},
+ last_batch_flag_sent_to=[],
)
- batch_a_1 = _Batch(
- seq_no=1,
- step_name="A",
- last_batch=False,
- data=[
- [
- {"generation": "Hello, I'm A 1"},
- {"generation": "Hello, I'm A 1"},
- {"generation": "Hello, I'm A 1"},
- ]
- ],
- size=3,
- created_from={"Z": [(1, 3)]},
+ with mock.patch.object(pipeline, "_send_to_step") as mock_sent_to_step:
+ pipeline._send_last_batch_flag_to_step(step_name)
+
+ mock_sent_to_step.assert_called_once_with(step_name, LAST_BATCH_SENT_FLAG)
+
+ pipeline._batch_manager._last_batch_sent[step_name] = _Batch(
+ seq_no=0,
+ step_name=step_name,
+ last_batch=True,
)
+ with mock.patch.object(pipeline, "_send_to_step") as mock_sent_to_step:
+ pipeline._send_last_batch_flag_to_step(step_name)
+
+ mock_sent_to_step.assert_not_called()
+
+ def test_request_initial_batches(self) -> None:
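+ # Batches restored from the cache should be sent to their steps, and each
+ # generator should still be asked for its first batch.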
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1(input_batch_size=5)
- batch_b_0 = _Batch(
+ generator >> step
+
+ generator2 = DummyGeneratorStep()
+ step2 = DummyStep1(input_batch_size=5)
+
+ generator2 >> step2
+
+ pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag)
+
+ # Simulate that batches for these steps were restored from the cache
+ batch_0 = _Batch(
seq_no=0,
- step_name="B",
+ step_name=generator.name, # type: ignore
last_batch=False,
- data=[
- [
- {"generation": "Hello, I'm B 0"},
- {"generation": "Hello, I'm B 0"},
- {"generation": "Hello, I'm B 0"},
- ]
- ],
- size=3,
- created_from={"Z": [(0, 3)]},
+ data=[[{"a": i} for i in range(5)]],
)
+ pipeline._batch_manager._steps[step.name].data[generator.name] = [ # type: ignore
+ batch_0
+ ]
- batch_c_0 = _Batch(
+ batch_1 = _Batch(
seq_no=0,
- step_name="C",
+ step_name=generator2.name, # type: ignore
last_batch=False,
- data=[
- [
- {"generation": "Hello, I'm C 0"},
- {"generation": "Hello, I'm C 0"},
- {"generation": "Hello, I'm C 0"},
- ]
+ data=[[{"b": i} for i in range(5)]],
+ ) # type: ignore
+ pipeline._batch_manager._steps[step2.name].data[generator2.name] = [ # type: ignore
+ batch_1
+ ]
+
+ with mock.patch.object(
+ pipeline, "_send_batch_to_step"
+ ) as mock_send_batch_to_step:
+ pipeline._request_initial_batches()
+
+ mock_send_batch_to_step.assert_has_calls(
+ [
+ mock.call(mock.ANY),
+ mock.call(mock.ANY),
+ mock.call(_Batch(seq_no=0, step_name=generator.name, last_batch=False)), # type: ignore
+ mock.call(
+ _Batch(seq_no=0, step_name=generator2.name, last_batch=False) # type: ignore
+ ),
],
- size=3,
- created_from={"Z": [(1, 3)]},
+ any_order=True,
)
- batch_manager_step = _BatchManagerStep(
- step_name="D",
- input_batch_size=3,
- convergence_step=True,
- accumulate=False,
- data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]},
- )
+ def test_request_more_batches_if_needed(self) -> None:
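+ # When a step needs more data and the last batch sent by its generator
+ # predecessor wasn't the final one, the generator's next batch should be requested.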
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
- data, created_from, routed_to = batch_manager_step._get_data()
+ generator >> step
- assert data == [
- [
- {"generation": "Hello, I'm A 0"},
- {"generation": "Hello, I'm A 0"},
- {"generation": "Hello, I'm A 0"},
- ],
- [
- {"generation": "Hello, I'm B 0"},
- {"generation": "Hello, I'm B 0"},
- {"generation": "Hello, I'm B 0"},
- ],
- ]
- assert created_from == {"A": [(0, 3)], "B": [(0, 3)]}
- assert routed_to == []
- assert batch_manager_step.next_expected_created_from_batch_seq_no == 1
+ generator_name: str = generator.name # type: ignore
- data, created_from, routed_to = batch_manager_step._get_data()
+ pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag)
- assert data == [
- [
- {"generation": "Hello, I'm A 1"},
- {"generation": "Hello, I'm A 1"},
- {"generation": "Hello, I'm A 1"},
- ],
- [
- {"generation": "Hello, I'm C 0"},
- {"generation": "Hello, I'm C 0"},
- {"generation": "Hello, I'm C 0"},
- ],
- ]
- assert created_from == {"A": [(1, 3)], "C": [(0, 3)]}
- assert routed_to == []
- assert batch_manager_step.next_expected_created_from_batch_seq_no == 2
+ batch = _Batch(seq_no=0, step_name=generator_name, last_batch=False)
+ pipeline._batch_manager._last_batch_sent[generator_name] = batch
- @pytest.mark.parametrize(
- "data, last_batch_received, expected",
- [
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ]
- },
- [],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
- )
- ],
- },
- [],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- {"a": 6},
- ]
- ],
- )
- ]
- },
- ["step1"],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ]
- },
- ["step1"],
- True,
- ),
- ],
- )
- def test_last_batch(
- self,
- data: Dict[str, List[_Batch]],
- last_batch_received: List[str],
- expected: bool,
- ) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step2",
- accumulate=False,
- input_batch_size=5,
- data=data,
- last_batch_received=last_batch_received,
- )
+ with mock.patch.object(
+ pipeline, "_send_batch_to_step"
+ ) as mock_send_batch_to_step:
+ pipeline._request_more_batches_if_needed(step)
- assert batch_manager_step._last_batch() is expected
+ mock_send_batch_to_step.assert_called_once_with(batch.next_batch())
- @pytest.mark.parametrize(
- "data, last_batch_received, expected",
- [
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
- },
- [],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
- },
- ["step1"],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
- },
- ["step1", "step2"],
- True,
- ),
- ],
- )
- def test_last_batch_accumulate(
- self,
- data: Dict[str, List[_Batch]],
- last_batch_received: List[str],
- expected: bool,
- ) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=True,
- data=data,
- last_batch_received=last_batch_received,
+ def test_handle_batch_on_stop(self) -> None:
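+ # A batch received while stopping should be registered and added as input for
+ # every successor of the step that produced it.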
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1(input_batch_size=5)
+ step2 = DummyStep1(input_batch_size=5)
+ step3 = DummyStep1(input_batch_size=5)
+
+ generator >> [step, step2, step3]
+
+ batch_manager_mock = mock.MagicMock()
+ pipeline._batch_manager = batch_manager_mock
+
+ batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ pipeline._handle_batch_on_stop(batch)
+
+ batch_manager_mock.register_batch.assert_called_once_with(batch)
+ batch_manager_mock.add_batch.assert_has_calls(
+ [
+ mock.call(step.name, batch),
+ mock.call(step2.name, batch),
+ mock.call(step3.name, batch),
+ ]
)
- assert batch_manager_step._last_batch() is expected
+ def test_get_step_from_batch(self) -> None:
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
- @pytest.mark.parametrize(
- "data, last_batch_received, expected",
- [
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- created_from={"step0": [(0, 5)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- created_from={"step0": [(0, 5)]},
- )
- ],
- },
- [],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- created_from={"step0": [(0, 5)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- created_from={"step0": [(0, 5)]},
- )
- ],
- },
- [],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
- created_from={"step0": [(0, 3)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
- created_from={"step0": [(0, 3)]},
- )
- ],
- },
- [],
- True,
- ),
- ],
- )
- def test_last_batch_convergence_step(
- self,
- data: Dict[str, List[_Batch]],
- last_batch_received: List[str],
- expected: bool,
- ) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- data=data,
- last_batch_received=last_batch_received,
- input_batch_size=3,
- convergence_step=True,
+ generator >> step
+
+ batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ assert pipeline._get_step_from_batch(batch) == generator
+
+ batch = _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore
+ assert pipeline._get_step_from_batch(batch) == step
+
+ def test_notify_steps_to_stop(self) -> None:
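+ # Notifying the steps to stop should send the `None` sentinel to every step.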
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1(input_batch_size=5)
+
+ generator >> step
+
+ with mock.patch.object(pipeline, "_send_to_step") as mock_send_to_step:
+ pipeline._notify_steps_to_stop()
+
+ mock_send_to_step.assert_has_calls(
+ [
+ mock.call(generator.name, None),
+ mock.call(step.name, None),
+ ]
)
- assert batch_manager_step._last_batch() is expected
+ def test_get_successors(self) -> None:
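+ # `_get_successors` returns the names of the batch's successor steps and whether
+ # they were selected by a routing batch function (`False` here, as there is none).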
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1()
+ step2 = DummyStep1()
+ step3 = DummyStep2()
+
+ generator >> [step, step2] >> step3
+
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ ) == ([step.name, step2.name], False)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore
+ ) == ([step3.name], False)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step2.name, last_batch=False) # type: ignore
+ ) == ([step3.name], False)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step3.name, last_batch=False) # type: ignore
+ ) == ([], False)
+
+ def test_get_successors_with_routing_batch_function(self) -> None:
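+ # With a routing batch function, only the steps it returns should be selected
+ # as successors and the routing flag should be `True`.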
+ @routing_batch_function()
+ def fixed_routing_batch_function(steps: List[str]) -> List[str]:
+ return ["step_2", "step_3"]
+
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ generator = DummyGeneratorStep()
+ step = DummyStep1(name="step_1")
+ step2 = DummyStep1(name="step_2")
+ step3 = DummyStep1(name="step_3")
+ step4 = DummyStep2(name="step_4")
+
+ generator >> fixed_routing_batch_function >> [step, step2, step3] >> step4
+
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore
+ ) == (["step_2", "step_3"], True)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore
+ ) == ([step4.name], False)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step2.name, last_batch=False) # type: ignore
+ ) == ([step4.name], False)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step3.name, last_batch=False) # type: ignore
+ ) == ([step4.name], False)
+ assert pipeline._get_successors(
+ _Batch(seq_no=0, step_name=step4.name, last_batch=False) # type: ignore
+ ) == ([], False)
- @pytest.mark.parametrize(
- "data, last_batch_received, expected",
- [
- (
+ def test_get_runtime_parameters_info(self) -> None:
+ class DummyStep1(Step):
+ runtime_param1: RuntimeParameter[str] = Field(
+ default=None, description="runtime_param1 description"
+ )
+ runtime_param2: Optional[RuntimeParameter[str]] = Field(
+ default=None, description="runtime_param2 description"
+ )
+
+ def process(self, inputs: StepInput) -> None:
+ pass
+
+ class DummyStep2(Step):
+ runtime_param3: RuntimeParameter[str] = Field(
+ default=None, description="runtime_param3 description"
+ )
+ runtime_param4: Optional[RuntimeParameter[str]] = Field(
+ default=None, description="runtime_param4 description"
+ )
+
+ def process(self, inputs: StepInput) -> None:
+ pass
+
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ DummyStep1(name="dummy_step_1")
+ DummyStep2(name="dummy_step_2")
+
+ assert pipeline.get_runtime_parameters_info() == {
+ "dummy_step_1": [
{
- "step1": [],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
+ "description": "The number of rows that will contain the batches processed by the "
+ "step.",
+ "name": "input_batch_size",
+ "optional": True,
},
- [],
- False,
- ),
- (
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
+ "name": "runtime_param1",
+ "description": "runtime_param1 description",
+ "optional": False,
},
- [],
- False,
- ),
- (
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
+ "name": "runtime_param2",
+ "description": "runtime_param2 description",
+ "optional": True,
},
- ["step1", "step2"],
- True,
- ),
- (
+ ],
+ "dummy_step_2": [
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
- )
- ],
+ "description": "The number of rows that will contain the batches processed by the "
+ "step.",
+ "name": "input_batch_size",
+ "optional": True,
},
- ["step1", "step2"],
- True,
- ),
- ],
- )
- def test_ready_to_create_batch(
- self,
- data: Dict[str, List[Dict[str, Any]]],
- last_batch_received: List[str],
- expected: bool,
- ) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step2",
- accumulate=False,
- input_batch_size=5,
- data=data,
- last_batch_received=last_batch_received,
- )
-
- assert batch_manager_step._ready_to_create_batch() is expected
-
- @pytest.mark.parametrize(
- "data, last_batch_received, expected",
- [
- (
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
+ "name": "runtime_param3",
+ "description": "runtime_param3 description",
+ "optional": False,
},
- ["step1", "step2"],
- True,
- ),
- (
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- )
- ],
+ "name": "runtime_param4",
+ "description": "runtime_param4 description",
+ "optional": True,
},
- ["step1"],
- False,
- ),
- ],
- )
- def test_ready_to_create_batch_accumulate(
- self,
- data: Dict[str, List[_Batch]],
- last_batch_received: List[str],
- expected: bool,
- ) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=True,
- data=data,
- last_batch_received=last_batch_received,
- )
-
- assert batch_manager_step._ready_to_create_batch() is expected
-
- def test_dump(self) -> None:
- batch_step_1 = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]],
- size=6,
- )
- batch_step_2 = _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=True,
- data=[
- [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}]
],
- size=7,
- )
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=True,
- data={
- "step1": [batch_step_1],
- "step2": [batch_step_2],
- },
- )
- assert batch_manager_step.dump() == {
- "step_name": "step3",
- "accumulate": True,
- "convergence_step": False,
- "convergence_step_batches_consumed": {},
- "input_batch_size": None,
- "data": {
- "step1": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": True,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- {"a": 6},
- ]
- ],
- "size": 6,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- }
- ],
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- }
- ],
- },
- "seq_no": 0,
- "last_batch_received": [],
- "next_expected_created_from_batch_seq_no": 0,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_BatchManagerStep",
- },
}
+ # Cases: valid parameters (nothing logged), an unknown parameter with a close
+ # match suggestion, and an unknown parameter without a close match.
@pytest.mark.parametrize(
- "data, last_batch_received, expected",
- [
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
- )
- ],
- "step2": [],
- },
- [],
- False,
- ),
- (
- {
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
- )
- ],
- },
- [],
- True,
- ),
+ "parameters, expected",
+ (
(
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
- )
- ],
+ "dummy_step_1": {"runtime_param1": "value1"},
+ "dummy_step_2": {"runtime_param3": "value1"},
},
- [],
- False,
+ "",
),
(
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=True,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=True,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
- )
- ],
+ "dummy_step_1": {"runtime_param1": "value1"},
+ "dummy_step_2": {
+ "runtime_param3": "value1",
+ "runtime_param_unknown": "value1",
+ },
},
- ["step1", "step2"],
- True,
+ "Did you mean any of:",
),
(
{
- "step1": [
- _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 4)]},
- )
- ],
- "step2": [
- _Batch(
- seq_no=0,
- step_name="step2",
- last_batch=False,
- data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
- batch_routed_to=["step1", "step2"],
- created_from={"step0": [(0, 5)]},
- )
- ],
+ "dummy_step_1": {"runtime_param1": "value1"},
+ "dummy_step_2": {
+ "runtime_param3": "value1",
+ "weird_name": "value1",
+ },
},
- [],
- False,
+ "Available runtime parameters for the step",
),
- ],
+ ),
)
- def test_ready_to_create_batch_convergence_step(
- self,
- data: Dict[str, List[_Batch]],
- last_batch_received: List[str],
- expected: bool,
+ def test_check_runtime_parameters(
+ self, caplog, parameters: Dict[str, Any], expected: str
) -> None:
- batch_manager_step = _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=5,
- data=data,
- last_batch_received=last_batch_received,
- convergence_step=True,
- )
-
- assert batch_manager_step._ready_to_create_batch() is expected
-
- def test_from_dict(self) -> None:
- batch_manager_step = _BatchManagerStep.from_dict(
- {
- "step_name": "step3",
- "accumulate": True,
- "convergence_step": False,
- "convergence_step_batches_consumed": {0: {"Z": 1234}},
- "input_batch_size": None,
- "data": {
- "step1": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": True,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- {"a": 6},
- ]
- ],
- "size": 6,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- }
- ],
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "size": 7,
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- }
- ],
- },
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_BatchManagerStep",
- },
- }
- )
-
- assert isinstance(batch_manager_step, _BatchManagerStep)
- assert batch_manager_step.step_name == "step3"
- assert batch_manager_step.accumulate is True
- assert batch_manager_step.convergence_step is False
- assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}}
- assert batch_manager_step.input_batch_size is None
- assert batch_manager_step.seq_no == 0
- assert batch_manager_step.last_batch_received == []
-
-
-class TestBatchManager:
- def test_add_batch(self) -> None:
- batch_manager = _BatchManager(
- steps={
- "step3": _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=5,
- data={"step1": [], "step2": []},
- )
- },
- last_batch_received={"step3": None},
- last_batch_sent={"step3": None},
- last_batch_flag_sent_to=[],
- )
-
- batch_from_step_1 = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- batch_manager.add_batch(to_step="step3", batch=batch_from_step_1)
+ class DummyStep1(Step):
+ runtime_param1: RuntimeParameter[str] = Field(
+ default=None, description="runtime_param1 description"
+ )
+ runtime_param2: Optional[RuntimeParameter[str]] = Field(
+ default=None, description="runtime_param2 description"
+ )
- assert batch_manager._steps["step3"].data == {
- "step1": [batch_from_step_1],
- "step2": [],
- }
+ def process(self, inputs: StepInput) -> StepOutput: # type: ignore
+ yield [{}]
- def test_add_batch_with_prepend(self) -> None:
- batch_1 = _Batch(
- seq_no=1,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]],
- )
- batch_manager = _BatchManager(
- steps={
- "step3": _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=5,
- data={
- "step1": [batch_1],
- "step2": [],
- },
- )
- },
- last_batch_received={"step3": None},
- last_batch_sent={"step3": None},
- last_batch_flag_sent_to=[],
- )
- batch_0 = _Batch(
- seq_no=0,
- step_name="step1",
- last_batch=False,
- data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
- )
- batch_manager.add_batch(to_step="step3", batch=batch_0, prepend=True)
- assert batch_manager._steps["step3"].data == {
- "step1": [batch_0, batch_1],
- "step2": [],
- }
+ class DummyStep2(Step):
+ runtime_param3: RuntimeParameter[str] = Field(
+ default=None, description="runtime_param3 description"
+ )
+ runtime_param4: Optional[RuntimeParameter[str]] = Field(
+ default=None, description="runtime_param4 description"
+ )
- def test_from_dag(
- self,
- dummy_generator_step: "GeneratorStep",
- dummy_step_1: "Step",
- dummy_step_2: "Step",
- dummy_global_step: "GlobalStep",
- ) -> None:
- dag = DAG()
- dag.add_step(dummy_generator_step)
- dag.add_step(dummy_step_1)
- dag.add_step(dummy_step_2)
- dag.add_step(dummy_global_step)
- dag.add_edge("dummy_generator_step", "dummy_step_1")
- dag.add_edge("dummy_generator_step", "dummy_global_step")
- dag.add_edge("dummy_step_1", "dummy_step_2")
-
- batch_manager = _BatchManager.from_dag(dag)
-
- assert batch_manager._steps == {
- "dummy_step_1": _BatchManagerStep(
- step_name="dummy_step_1",
- accumulate=False,
- input_batch_size=50,
- data={"dummy_generator_step": []},
- ),
- "dummy_global_step": _BatchManagerStep(
- step_name="dummy_global_step",
- accumulate=True,
- input_batch_size=50,
- data={"dummy_generator_step": []},
- ),
- "dummy_step_2": _BatchManagerStep(
- step_name="dummy_step_2",
- accumulate=False,
- input_batch_size=50,
- data={"dummy_step_1": []},
- ),
- }
+ def process(self, inputs: StepInput) -> StepOutput: # type: ignore
+ yield [{}]
- def test_can_generate(self) -> None:
- batch_manager = _BatchManager(
- steps={},
- last_batch_received={
- "step_1": _Batch(seq_no=0, step_name="step_1", last_batch=False),
- "step_2": _Batch(seq_no=0, step_name="step_2", last_batch=False),
- "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False),
- },
- last_batch_sent={"step_1": None, "step_2": None, "step_3": None},
- last_batch_flag_sent_to=[],
- )
+ with DummyPipeline(name="unit-test-pipeline") as pipeline:
+ gen_step = DummyGeneratorStep(name="dummy_generator_step")
+ step1 = DummyStep1(name="dummy_step_1")
+ step2 = DummyStep2(name="dummy_step_2")
- assert batch_manager.can_generate()
+ gen_step >> step1 >> step2
- batch_1 = _Batch(seq_no=0, step_name="step_1", last_batch=True)
- batch_2 = _Batch(seq_no=0, step_name="step_2", last_batch=True)
- batch_3 = _Batch(seq_no=0, step_name="step_3", last_batch=True)
+ pipeline.run(parameters=parameters)
+ if expected:
+ assert expected in caplog.text
+ else:
+ assert "Did you mean any of:" not in expected
+ assert "Available runtime parameters for the step" not in expected
- batch_manager = _BatchManager(
- steps={},
- last_batch_received={
- "step_1": batch_1,
- "step_2": batch_2,
- "step_3": batch_3,
- },
- last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3},
- last_batch_flag_sent_to=[],
- )
+ def test_cache_dir_env_variable(self) -> None:
+ with mock.patch.dict(os.environ, clear=True):
+ os.environ["DISTILABEL_CACHE_DIR"] = "/tmp/unit-test"
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+ assert pipeline._cache_dir == Path("/tmp/unit-test")
- assert not batch_manager.can_generate()
-
- def test_dump(self) -> None:
- batch_manager = _BatchManager(
- steps={
- "step3": _BatchManagerStep(
- step_name="step3",
- accumulate=False,
- input_batch_size=5,
- data={"step1": [], "step2": []},
- seq_no=1,
- )
- },
- last_batch_received={
- "step3": _Batch(
- seq_no=0,
- step_name="step3",
- last_batch=False,
- )
- },
- last_batch_sent={
- "step3": _Batch(
- seq_no=1,
- step_name="step3",
- last_batch=False,
- )
- },
- last_batch_flag_sent_to=["step99"],
- )
- assert batch_manager.dump() == {
- "steps": {
- "step3": {
- "step_name": "step3",
- "accumulate": False,
- "convergence_step": False,
- "convergence_step_batches_consumed": {},
- "input_batch_size": 5,
- "data": {"step1": [], "step2": []},
- "seq_no": 1,
- "last_batch_received": [],
- "next_expected_created_from_batch_seq_no": 0,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_BatchManagerStep",
- },
- },
- },
- "last_batch_received": {
- "step3": {
- "seq_no": 0,
- "step_name": "step3",
- "batch_routed_to": [],
- "created_from": {},
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_Batch",
- },
- }
- },
- "last_batch_sent": {
- "step3": {
- "seq_no": 1,
- "step_name": "step3",
- "batch_routed_to": [],
- "created_from": {},
- "last_batch": False,
- "data": [],
- "size": 0,
- "accumulated": False,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_Batch",
- },
- }
- },
- "last_batch_flag_sent_to": ["step99"],
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_BatchManager",
- },
- }
+ @pytest.mark.parametrize(
+ "in_pipeline, names",
+ (
+ (
+ True,
+ [
+ "dummy_generator_step_0",
+ "dummy_step1_0",
+ "dummy_step2_0",
+ "dummy_step1_1",
+ ],
+ ),
+ # TODO: Activate this test once we merge the option of not passing a Pipeline
+ # (
+ # False, ["dummy_generator_step", "dummy_step1", "dummy_step2"]
+ # )
+ ),
+ )
+ def test_step_names_inferred(self, in_pipeline: bool, names: List[str]) -> None:
+ if in_pipeline:
+ with DummyPipeline(name="unit-test-pipeline"):
+ gen_step = DummyGeneratorStep()
+ step1_0 = DummyStep1()
+ step2 = DummyStep2()
+ step1_1 = DummyStep1()
- def test_from_dict(self) -> None:
- batch_manager_step = _BatchManagerStep.from_dict(
- {
- "step_name": "step3",
- "accumulate": True,
- "convergence_step": False,
- "input_batch_size": None,
- "data": {
- "step1": [
- {
- "seq_no": 0,
- "step_name": "step1",
- "last_batch": True,
- "data": [
- [
- {"a": 1},
- {"a": 2},
- {"a": 3},
- {"a": 4},
- {"a": 5},
- {"a": 6},
- ]
- ],
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- }
- ],
- "step2": [
- {
- "seq_no": 0,
- "step_name": "step2",
- "last_batch": True,
- "data": [
- [
- {"b": 1},
- {"b": 2},
- {"b": 3},
- {"b": 4},
- {"b": 5},
- {"b": 6},
- {"b": 7},
- ]
- ],
- "accumulated": False,
- "created_from": {},
- "batch_routed_to": [],
- }
- ],
- },
- "seq_no": 0,
- "last_batch_received": [],
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_BatchManagerStep",
- },
- }
- )
+ gen_step >> step1_0 >> step2 >> step1_1
+ else:
+ gen_step = DummyGeneratorStep()
+ step1_0 = DummyStep1()
+ step2 = DummyStep2()
+ step1_1 = DummyStep1()
- with tempfile.TemporaryDirectory() as tmpdirname:
- batch_manager_step.save(Path(tmpdirname) / "batch_manager_step3.json")
+ assert gen_step.name == names[0]
+ assert step1_0.name == names[1]
+ assert step2.name == names[2]
+ assert step1_1.name == names[3]
- batch_manager = _BatchManager.from_dict(
- {
- "steps": {
- "step3": str(Path(tmpdirname) / "batch_manager_step3.json")
- },
- "last_batch_received": {
- "step3": {
- "seq_no": 0,
- "step_name": "step3",
- "last_batch": False,
- "data": [],
- "accumulated": False,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_Batch",
- },
- }
- },
- "last_batch_sent": {
- "step3": {
- "seq_no": 0,
- "step_name": "step3",
- "last_batch": False,
- "data": [],
- "accumulated": False,
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_Batch",
- },
- },
- },
- "last_batch_flag_sent_to": [],
- "type_info": {
- "module": "distilabel.pipeline.base",
- "name": "_BatchManager",
- },
- }
- )
- assert isinstance(batch_manager, _BatchManager)
- assert all(
- isinstance(step, _BatchManagerStep)
- for _, step in batch_manager._steps.items()
- )
- assert all(
- isinstance(batch, _Batch)
- for _, batch in batch_manager._last_batch_received.items()
- )
+ def test_infer_step_names_big_pipeline(self) -> None:
+ # Tests that step names are inferred correctly when the pipeline is large (e.g. 50 steps).
+ with DummyPipeline(name="unit-test-pipeline") as pipe:
+ gen_step = DummyGeneratorStep()
+ for _ in range(50):
+ gen_step.connect(DummyStep1())
+ assert list(pipe.dag.G)[-1] == "dummy_step1_49"
class TestPipelineSerialization:
def test_base_pipeline_dump(self):
- pipeline = BasePipeline(name="unit-test-pipeline")
+ pipeline = DummyPipeline(name="unit-test-pipeline")
dump = pipeline.dump()
assert len(dump.keys()) == 2
assert "pipeline" in dump
assert "distilabel" in dump
assert TYPE_INFO_KEY in dump["pipeline"]
- assert dump["pipeline"][TYPE_INFO_KEY]["module"] == "distilabel.pipeline.base"
- assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "BasePipeline"
+ assert (
+ dump["pipeline"][TYPE_INFO_KEY]["module"] == "tests.unit.pipeline.test_base"
+ )
+ assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "DummyPipeline"
def test_base_pipeline_from_dict(self):
- pipeline = BasePipeline(name="unit-test-pipeline")
- pipe = BasePipeline.from_dict(pipeline.dump())
- assert isinstance(pipe, BasePipeline)
+ pipeline = DummyPipeline(name="unit-test-pipeline")
+ pipe = DummyPipeline.from_dict(pipeline.dump())
+ assert isinstance(pipe, DummyPipeline)
def test_pipeline_dump(self):
from distilabel.pipeline.local import Pipeline
@@ -1980,8 +928,8 @@ def test_pipeline_dump(self):
@pytest.mark.parametrize(
"format, name, loader",
[
- ("yaml", "pipe.yaml", BasePipeline.from_yaml),
- ("json", "pipe.json", BasePipeline.from_json),
+ ("yaml", "pipe.yaml", DummyPipeline.from_yaml),
+ ("json", "pipe.json", DummyPipeline.from_json),
("invalid", "pipe.invalid", None),
],
)
@@ -1991,7 +939,7 @@ def test_pipeline_to_from_file_format(
name: str,
loader: Callable,
) -> None:
- pipeline = BasePipeline(name="unit-test-pipeline")
+ pipeline = DummyPipeline(name="unit-test-pipeline")
with tempfile.TemporaryDirectory() as tmpdirname:
filename = Path(tmpdirname) / name
@@ -2002,10 +950,10 @@ def test_pipeline_to_from_file_format(
pipeline.save(filename, format=format)
assert filename.exists()
pipe_from_file = loader(filename)
- assert isinstance(pipe_from_file, BasePipeline)
+ assert isinstance(pipe_from_file, DummyPipeline)
def test_base_pipeline_signature(self):
- pipeline = BasePipeline(name="unit-test-pipeline")
+ pipeline = DummyPipeline(name="unit-test-pipeline")
        # The exact value doesn't matter; the test should fail if we change the
        # way the signature is created.
signature = pipeline._create_signature()
@@ -2036,62 +984,6 @@ def test_base_pipeline_signature(self):
signature = pipeline._create_signature()
assert signature == "a11ac46253598e6fe126420b23b9ad31c6422c92"
- @pytest.mark.parametrize("use_cache", [True, False])
- def test_run_pipe_and_load_from_cache(self, use_cache: bool):
- # Maybe not the best place for this test, but does the work for now
- from distilabel.pipeline.base import BasePipeline
- from distilabel.pipeline.routing_batch_function import sample_n_steps
-
- from tests.unit.pipeline.utils import DummyGeneratorStep, DummyStep1, DummyStep2
-
- sample_two_steps = sample_n_steps(2)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- with BasePipeline(
- name="unit-test-pipeline", cache_dir=tmpdirname
- ) as pipeline:
- dummy_generator = DummyGeneratorStep()
- dummy_step_1_0 = DummyStep1()
- dummy_step_1_1 = DummyStep1()
- dummy_step_1_2 = DummyStep1()
- dummy_step_2 = DummyStep2()
-
- (
- dummy_generator
- >> sample_two_steps
- >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
- >> dummy_step_2
- )
-
- pipeline.run({}, use_cache=use_cache)
-
- assert not pipeline._cache_location["pipeline"].exists()
- # Set the _BatchManager to the pipeline to check it exists afterwards
- pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag)
- pipeline._cache()
-
- assert pipeline._cache_location["pipeline"].exists()
-
- with BasePipeline(name="unit-test-pipeline", cache_dir=tmpdirname) as pipe:
- dummy_generator = DummyGeneratorStep()
- dummy_step_1_0 = DummyStep1()
- dummy_step_1_1 = DummyStep1()
- dummy_step_1_2 = DummyStep1()
- dummy_step_2 = DummyStep2()
-
- (
- dummy_generator
- >> sample_two_steps
- >> [dummy_step_1_0, dummy_step_1_1, dummy_step_1_2]
- >> dummy_step_2
- )
-
- pipe.run({}, use_cache=use_cache)
- if use_cache:
- assert pipe._batch_manager
- else:
- assert not pipe._batch_manager
-
def test_binary_rshift_operator(self) -> None:
# Tests the steps can be connected using the >> operator.
from distilabel.pipeline.local import Pipeline
@@ -2210,126 +1102,3 @@ def test_binary_operators(self) -> None:
signature_2 = pipeline_2._create_signature()
assert signature_1 == signature_2
-
-
-class TestWriteBuffer:
- def test_create(self) -> None:
- with tempfile.TemporaryDirectory() as tmpdirname:
- folder = Path(tmpdirname) / "data"
- with Pipeline(name="unit-test-pipeline") as pipeline:
- dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
- dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
- dummy_step_1 = DummyStep1(name="dummy_step_1")
- dummy_step_2 = DummyStep2(name="dummy_step_2")
- dummy_step_3 = DummyStep2(name="dummy_step_3")
-
- dummy_generator_1.connect(dummy_step_1)
- dummy_generator_2.connect(dummy_step_2)
- dummy_step_1.connect(dummy_step_2)
- dummy_step_1.connect(dummy_step_3)
-
- write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
-
- assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []}
- assert write_buffer._buffers_dump_batch_size == {
- "dummy_step_2": 50,
- "dummy_step_3": 50,
- }
- assert write_buffer._buffer_last_schema == {}
- assert write_buffer._buffers_last_file == {
- "dummy_step_2": 1,
- "dummy_step_3": 1,
- }
-
- def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
- with tempfile.TemporaryDirectory() as tmpdirname:
- folder = Path(tmpdirname) / "data"
- with Pipeline(name="unit-test-pipeline") as pipeline:
- dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
- dummy_step_1 = DummyStep1(name="dummy_step_1")
- dummy_step_2 = DummyStep2(name="dummy_step_2")
-
- dummy_generator.connect(dummy_step_1)
- dummy_step_1.connect(dummy_step_2)
-
- write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
-
- # Add one batch with 5 rows, shouldn't write anything 5 < 50
- batch = batch_gen(dummy_step_2.name)
- write_buffer.add_batch(batch)
-
- # Add 45 more rows, should write now
- for _ in range(9):
- batch = batch_gen(dummy_step_2.name)
- write_buffer.add_batch(batch)
-
- assert Path(folder, "dummy_step_2", "00001.parquet").exists()
-
- # Add 50 more rows, we should have a new file
- for _ in range(10):
- batch = batch_gen(dummy_step_2.name)
- write_buffer.add_batch(batch)
-
- assert Path(folder, "dummy_step_2", "00002.parquet").exists()
-
- # Add more rows and close the write buffer, we should have a new file
- for _ in range(5):
- batch = batch_gen(dummy_step_2.name)
- write_buffer.add_batch(batch)
-
- write_buffer.close()
-
- assert Path(folder, "dummy_step_2", "00003.parquet").exists()
-
- ds = create_distiset(write_buffer._path)
- assert isinstance(ds, Distiset)
- assert len(ds.keys()) == 1
- assert len(ds["default"]["train"]) == 125
-
- def test_write_buffer_multiple_leaf_steps_and_create_dataset(self):
- with tempfile.TemporaryDirectory() as tmpdirname:
- folder = Path(tmpdirname) / "data"
- with Pipeline(name="unit-test-pipeline") as pipeline:
- dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
- dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
- dummy_step_1 = DummyStep1(name="dummy_step_1")
- dummy_step_2 = DummyStep2(name="dummy_step_2")
- dummy_step_3 = DummyStep2(name="dummy_step_3")
-
- dummy_generator_1.connect(dummy_step_1)
- dummy_generator_2.connect(dummy_step_2)
- dummy_step_1.connect(dummy_step_2)
- dummy_step_1.connect(dummy_step_3)
-
- write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
-
- for _ in range(10):
- batch = batch_gen(dummy_step_2.name)
- write_buffer.add_batch(batch)
-
- assert Path(folder, "dummy_step_2", "00001.parquet").exists()
-
- for _ in range(10):
- batch = batch_gen(dummy_step_3.name)
- write_buffer.add_batch(batch)
-
- assert Path(folder, "dummy_step_3", "00001.parquet").exists()
-
- for _ in range(5):
- batch = batch_gen(dummy_step_2.name)
- write_buffer.add_batch(batch)
-
- for _ in range(5):
- batch = batch_gen(dummy_step_3.name)
- write_buffer.add_batch(batch)
-
- write_buffer.close()
-
- assert Path(folder, "dummy_step_2", "00002.parquet").exists()
- assert Path(folder, "dummy_step_3", "00002.parquet").exists()
-
- ds = create_distiset(write_buffer._path)
- assert isinstance(ds, Distiset)
- assert len(ds.keys()) == 2
- assert len(ds["dummy_step_2"]["train"]) == 75
- assert len(ds["dummy_step_3"]["train"]) == 75
diff --git a/tests/unit/pipeline/test_batch.py b/tests/unit/pipeline/test_batch.py
new file mode 100644
index 0000000000..ed246e491f
--- /dev/null
+++ b/tests/unit/pipeline/test_batch.py
@@ -0,0 +1,172 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.pipeline.batch import _Batch
+
+
+class TestBatch:
+ def test_get_data(self) -> None:
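+ # `get_data` should pop the requested number of rows from the batch and
+ # invalidate the previously computed data hash.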
+ batch = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[
+ [
+ {"a": 0},
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ {"a": 6},
+ ]
+ ],
+ )
+
+ batch.set_data(
+ [
+ [
+ {"a": 0},
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ {"a": 6},
+ ]
+ ]
+ )
+
+ old_hash = batch.data_hash
+
+ data = batch.get_data(5)
+ assert data == [{"a": 0}, {"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]
+ assert batch.data == [[{"a": 5}, {"a": 6}]]
+ assert batch.data_hash != old_hash
+
+ def test_set_data(self) -> None:
+ batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
+ data = [[{"i": i} for i in range(5000)]]
+ batch.set_data(data)
+
+ assert batch.data == data
+ assert batch.size == 5000
+
+ def test_next_batch(self) -> None:
+ batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
+ next_batch = batch.next_batch()
+
+ assert next_batch == _Batch(seq_no=1, step_name="step1", last_batch=False)
+
+ def test_accumulate(self) -> None:
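+ # `accumulate` should merge all the batches of each predecessor into a single
+ # batch for the accumulating step, marked as the last batch.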
+ batches = [
+ [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ ),
+ _Batch(
+ seq_no=1,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 4}, {"a": 5}, {"a": 6}]],
+ ),
+ ],
+ [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
+ ),
+ _Batch(
+ seq_no=1,
+ step_name="step2",
+ last_batch=True,
+ data=[[{"b": 4}, {"b": 5}, {"b": 6}]],
+ ),
+ ],
+ ]
+
+ batch = _Batch.accumulate("step3", batches)
+
+ assert batch.seq_no == 0
+ assert batch.step_name == "step3"
+ assert batch.last_batch is True
+ assert batch.data == [
+ [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
+ [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}],
+ ]
+
+ def test_dump(self) -> None:
+ batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
+ assert batch.dump() == {
+ "seq_no": 0,
+ "size": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "data_hash": None,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {"module": "distilabel.pipeline.batch", "name": "_Batch"},
+ }
+
+ batch = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ data_hash="hash",
+ accumulated=False,
+ created_from={"step0": [(0, 5), (1, 5)]},
+ batch_routed_to=["step2", "step3"],
+ )
+ assert batch.dump() == {
+ "seq_no": 0,
+ "size": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [[{"a": 1}, {"a": 2}, {"a": 3}]],
+ "data_hash": "hash",
+ "accumulated": False,
+ "created_from": {"step0": [(0, 5), (1, 5)]},
+ "batch_routed_to": ["step2", "step3"],
+ "type_info": {"module": "distilabel.pipeline.batch", "name": "_Batch"},
+ }
+
+ def test_from_dict(self) -> None:
+ batch = _Batch.from_dict(
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [[{"a": 1}, {"a": 2}, {"a": 3}]],
+ "accumulated": False,
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ )
+
+ assert isinstance(batch, _Batch)
+ assert batch.seq_no == 0
+ assert batch.step_name == "step1"
+ assert batch.last_batch is False
+ assert batch.data == [[{"a": 1}, {"a": 2}, {"a": 3}]]
+ assert batch.accumulated is False
diff --git a/tests/unit/pipeline/test_batch_manager.py b/tests/unit/pipeline/test_batch_manager.py
new file mode 100644
index 0000000000..7b1cb1a8a6
--- /dev/null
+++ b/tests/unit/pipeline/test_batch_manager.py
@@ -0,0 +1,2214 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from pathlib import Path
+from typing import Dict, List
+
+import pytest
+from distilabel.pipeline._dag import DAG
+from distilabel.pipeline.batch import _Batch
+from distilabel.pipeline.batch_manager import _BatchManager, _BatchManagerStep
+from distilabel.steps.base import GeneratorStep, GlobalStep, Step
+
+
+class TestBatchManagerStep:
+ def test_add_batch(self) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []}
+ )
+
+ batch = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ )
+
+ batch_manager_step.add_batch(batch)
+
+ assert batch_manager_step.data["step1"] == [batch]
+ assert batch_manager_step.last_batch_received == []
+
+ def test_add_batch_with_prepend(self) -> None:
+ batch_1 = _Batch(
+ seq_no=1,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]],
+ )
+ batch_manager_step = _BatchManagerStep(
+ step_name="step2",
+ accumulate=False,
+ input_batch_size=10,
+ data={"step1": [batch_1]},
+ )
+
+ batch_0 = _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ batch_manager_step.add_batch(batch_0, prepend=True)
+
+ assert batch_manager_step.built_batches == [batch_0]
+ assert batch_manager_step.data["step1"] == [batch_1]
+ assert batch_manager_step.last_batch_received == []
+
+ def test_add_batch_last_batch(self) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []}
+ )
+
+ batch = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ )
+
+ batch_manager_step.add_batch(batch)
+
+ assert batch_manager_step.data["step1"] == [batch]
+ assert batch_manager_step.last_batch_received == ["step1"]
+
+ def test_get_batch(self) -> None:
+ previously_built_batch = _Batch(
+ seq_no=0,
+ step_name="step3",
+ last_batch=False,
+ data=[
+ [
+ {"a": -1},
+ {"a": 0},
+ ],
+ [
+ {"b": -1},
+ {"b": 0},
+ ],
+ ],
+ )
+
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=2,
+ seq_no=1,
+ data={
+ "step1": [
+ _Batch(
+ seq_no=1,
+ step_name="step1",
+ last_batch=False,
+ data=[
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ]
+ ],
+ size=5,
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=1,
+ step_name="step2",
+ last_batch=False,
+ data=[
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ ]
+ ],
+ size=5,
+ )
+ ],
+ },
+ built_batches=[previously_built_batch],
+ )
+
+ batch = batch_manager_step.get_batch()
+
+ assert batch == previously_built_batch
+
+ batch = batch_manager_step.get_batch()
+
+ assert batch == _Batch(
+ step_name="step3",
+ seq_no=1,
+ last_batch=False,
+ data=[
+ [
+ {"a": 1},
+ {"a": 2},
+ ],
+ [
+ {"b": 1},
+ {"b": 2},
+ ],
+ ],
+ created_from={"step1": [(1, 5)], "step2": [(1, 5)]},
+ )
+
+ batch = batch_manager_step.get_batch()
+
+ assert batch == _Batch(
+ step_name="step3",
+ seq_no=2,
+ last_batch=False,
+ data=[
+ [
+ {"a": 3},
+ {"a": 4},
+ ],
+ [
+ {"b": 3},
+ {"b": 4},
+ ],
+ ],
+ created_from={"step1": [(1, 5)], "step2": [(1, 5)]},
+ )
+
+ def test_get_batches_accumulate(self) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=True,
+ data={
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ]
+ ],
+ size=5,
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=True,
+ data=[
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ ]
+ ],
+ size=6,
+ )
+ ],
+ },
+ last_batch_received=["step1", "step2"],
+ )
+
+ batch = batch_manager_step.get_batch()
+
+ assert batch == _Batch(
+ step_name="step3",
+ seq_no=0,
+ last_batch=True,
+ accumulated=True,
+ data=[
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ],
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ ],
+ ],
+ created_from={"step1": [(0, 5)], "step2": [(0, 6)]},
+ )
+
+ def test_get_batches_not_enough_data(self) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=2,
+ data={
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[
+ [
+ {"a": 1},
+ ]
+ ],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[
+ [
+ {"b": 1},
+ {"b": 2},
+ ]
+ ],
+ )
+ ],
+ },
+ )
+
+ assert batch_manager_step.get_batch() is None
+
+ def test_from_step(self, dummy_step_1: "Step") -> None:
+ batch_manager_step = _BatchManagerStep.from_step(
+ step=dummy_step_1, predecessors=["step1", "step2"]
+ )
+
+ assert batch_manager_step.step_name == "dummy_step_1"
+ assert batch_manager_step.accumulate is False
+ assert batch_manager_step.input_batch_size == 50
+ assert batch_manager_step.data == {"step1": [], "step2": []}
+ assert batch_manager_step.seq_no == 0
+ assert batch_manager_step.last_batch_received == []
+
+ def test_from_step_with_global_step(self, dummy_global_step: "GlobalStep") -> None:
+ batch_manager_step = _BatchManagerStep.from_step(
+ step=dummy_global_step, predecessors=["step1", "step2"]
+ )
+
+ assert batch_manager_step.step_name == "dummy_global_step"
+ assert batch_manager_step.accumulate is True
+ assert batch_manager_step.input_batch_size == 50
+ assert batch_manager_step.data == {"step1": [], "step2": []}
+ assert batch_manager_step.seq_no == 0
+ assert batch_manager_step.last_batch_received == []
+
+ def test_get_seq_no(self) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step2", accumulate=False, input_batch_size=5, data={"step1": []}
+ )
+
+ seq_no = batch_manager_step._get_seq_no()
+
+ assert seq_no == 0
+ assert batch_manager_step.seq_no == 1
+
+ def test_get_data(self) -> None:
+ batch_step_1 = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]],
+ size=6,
+ batch_routed_to=["step1", "step2"],
+ )
+ batch_step_2 = _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ size=7,
+ batch_routed_to=["step1", "step2"],
+ )
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data={
+ "step1": [batch_step_1],
+ "step2": [batch_step_2],
+ },
+ )
+
+ data, created_from, routed_to = batch_manager_step._get_data()
+ assert data == [
+ [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}],
+ [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}],
+ ]
+ assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
+ assert routed_to == ["step1", "step2"]
+
+ assert batch_manager_step.data == {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 6}]],
+ data_hash=batch_step_1.data_hash,
+ size=6,
+ batch_routed_to=["step1", "step2"],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 6}, {"b": 7}]],
+ data_hash=batch_step_2.data_hash,
+ size=7,
+ batch_routed_to=["step1", "step2"],
+ )
+ ],
+ }
+
+ def test_get_data_accumulate(self) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=True,
+ data={
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[
+ [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]
+ ],
+ size=6,
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ size=7,
+ )
+ ],
+ },
+ )
+
+ data, created_from, routed_to = batch_manager_step._get_data()
+
+ assert data == [
+ [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
+ [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}],
+ ]
+ assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
+ assert routed_to == []
+
+ assert batch_manager_step.data == {"step1": [], "step2": []}
+
+ def test_get_data_convergence_step(self) -> None:
+ batch_a_0 = _Batch(
+ seq_no=0,
+ step_name="A",
+ last_batch=False,
+ data=[
+ [
+ {"generation": "Hello, I'm A 0"},
+ {"generation": "Hello, I'm A 0"},
+ {"generation": "Hello, I'm A 0"},
+ ]
+ ],
+ size=3,
+ created_from={"Z": [(0, 3)]},
+ )
+
+ batch_a_1 = _Batch(
+ seq_no=1,
+ step_name="A",
+ last_batch=False,
+ data=[
+ [
+ {"generation": "Hello, I'm A 1"},
+ {"generation": "Hello, I'm A 1"},
+ {"generation": "Hello, I'm A 1"},
+ ]
+ ],
+ size=3,
+ created_from={"Z": [(1, 3)]},
+ )
+
+ batch_b_0 = _Batch(
+ seq_no=0,
+ step_name="B",
+ last_batch=False,
+ data=[
+ [
+ {"generation": "Hello, I'm B 0"},
+ {"generation": "Hello, I'm B 0"},
+ {"generation": "Hello, I'm B 0"},
+ ]
+ ],
+ size=3,
+ created_from={"Z": [(0, 3)]},
+ )
+
+ batch_c_0 = _Batch(
+ seq_no=0,
+ step_name="C",
+ last_batch=False,
+ data=[
+ [
+ {"generation": "Hello, I'm C 0"},
+ {"generation": "Hello, I'm C 0"},
+ {"generation": "Hello, I'm C 0"},
+ ]
+ ],
+ size=3,
+ created_from={"Z": [(1, 3)]},
+ )
+
+ batch_manager_step = _BatchManagerStep(
+ step_name="D",
+ input_batch_size=3,
+ convergence_step=True,
+ accumulate=False,
+ data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]},
+ )
+
+ data, created_from, routed_to = batch_manager_step._get_data()
+
+ assert data == [
+ [
+ {"generation": "Hello, I'm A 0"},
+ {"generation": "Hello, I'm A 0"},
+ {"generation": "Hello, I'm A 0"},
+ ],
+ [
+ {"generation": "Hello, I'm B 0"},
+ {"generation": "Hello, I'm B 0"},
+ {"generation": "Hello, I'm B 0"},
+ ],
+ ]
+ assert created_from == {"A": [(0, 3)], "B": [(0, 3)]}
+ assert routed_to == []
+ assert batch_manager_step.next_expected_created_from_batch_seq_no == 1
+
+ data, created_from, routed_to = batch_manager_step._get_data()
+
+ assert data == [
+ [
+ {"generation": "Hello, I'm A 1"},
+ {"generation": "Hello, I'm A 1"},
+ {"generation": "Hello, I'm A 1"},
+ ],
+ [
+ {"generation": "Hello, I'm C 0"},
+ {"generation": "Hello, I'm C 0"},
+ {"generation": "Hello, I'm C 0"},
+ ],
+ ]
+ assert created_from == {"A": [(1, 3)], "C": [(0, 3)]}
+ assert routed_to == []
+ assert batch_manager_step.next_expected_created_from_batch_seq_no == 2
+
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ]
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ {"a": 6},
+ ]
+ ],
+ )
+ ]
+ },
+ ["step1"],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ]
+ },
+ ["step1"],
+ True,
+ ),
+ ],
+ )
+ def test_last_batch(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step2",
+ accumulate=False,
+ input_batch_size=5,
+ data=data,
+ last_batch_received=last_batch_received,
+ )
+
+ assert batch_manager_step._last_batch() is expected
+
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ ["step1"],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ ["step1", "step2"],
+ True,
+ ),
+ ],
+ )
+ def test_last_batch_accumulate(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=True,
+ data=data,
+ last_batch_received=last_batch_received,
+ )
+
+ assert batch_manager_step._last_batch() is expected
+
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ created_from={"step0": [(0, 3)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
+ created_from={"step0": [(0, 3)]},
+ )
+ ],
+ },
+ [],
+ True,
+ ),
+ ],
+ )
+ def test_last_batch_convergence_step(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ data=data,
+ last_batch_received=last_batch_received,
+ input_batch_size=3,
+ convergence_step=True,
+ )
+
+ assert batch_manager_step._last_batch() is expected
+
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ ["step1", "step2"],
+ True,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
+ )
+ ],
+ },
+ ["step1", "step2"],
+ True,
+ ),
+ ],
+ )
+ def test_ready_to_create_batch(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step2",
+ accumulate=False,
+ input_batch_size=5,
+ data=data,
+ last_batch_received=last_batch_received,
+ )
+
+ assert batch_manager_step._ready_to_create_batch() is expected
+
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ ["step1", "step2"],
+ True,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ )
+ ],
+ },
+ ["step1"],
+ False,
+ ),
+ ],
+ )
+ def test_ready_to_create_batch_accumulate(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=True,
+ data=data,
+ last_batch_received=last_batch_received,
+ )
+
+ assert batch_manager_step._ready_to_create_batch() is expected
+
+ def test_dump(self) -> None:
+ batch_step_1 = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]],
+ data_hash="hash0",
+ size=6,
+ )
+ batch_step_2 = _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=True,
+ data=[
+ [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}]
+ ],
+ data_hash="hash1",
+ size=7,
+ )
+ batch_step_3 = _Batch(
+ seq_no=0,
+ step_name="step3",
+ last_batch=True,
+ data=[[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]],
+ data_hash="hash2",
+ size=5,
+ )
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=True,
+ data={
+ "step1": [batch_step_1],
+ "step2": [batch_step_2],
+ },
+ built_batches=[batch_step_3],
+ )
+ assert batch_manager_step.dump() == {
+ "step_name": "step3",
+ "accumulate": True,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {},
+ "input_batch_size": None,
+ "data": {
+ "step1": [
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": True,
+ "data": [
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ {"a": 6},
+ ]
+ ],
+ "data_hash": "hash0",
+ "size": 6,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "data_hash": "hash1",
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "built_batches": [
+ {
+ "seq_no": 0,
+ "step_name": "step3",
+ "last_batch": True,
+ "data": [[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]],
+ "data_hash": "hash2",
+ "size": 5,
+ "accumulated": False,
+ "batch_routed_to": [],
+ "created_from": {},
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "seq_no": 0,
+ "last_batch_received": [],
+ "next_expected_created_from_batch_seq_no": 0,
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ }
+
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ "step2": [],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ True,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
+ )
+ ],
+ },
+ ["step1", "step2"],
+ True,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ ],
+ )
+ def test_ready_to_create_batch_convergence_step(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data=data,
+ last_batch_received=last_batch_received,
+ convergence_step=True,
+ )
+
+ assert batch_manager_step._ready_to_create_batch() is expected
+
+ def test_from_dict(self) -> None:
+ batch_manager_step = _BatchManagerStep.from_dict(
+ {
+ "step_name": "step3",
+ "accumulate": True,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {0: {"Z": 1234}},
+ "input_batch_size": None,
+ "data": {
+ "step1": [
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": True,
+ "data": [
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ {"a": 6},
+ ]
+ ],
+ "size": 6,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ }
+ ],
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ }
+ ],
+ },
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ }
+ )
+
+ assert isinstance(batch_manager_step, _BatchManagerStep)
+ assert batch_manager_step.step_name == "step3"
+ assert batch_manager_step.accumulate is True
+ assert batch_manager_step.convergence_step is False
+ assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}}
+ assert batch_manager_step.input_batch_size is None
+ assert batch_manager_step.seq_no == 0
+ assert batch_manager_step.last_batch_received == []
+
+
+class TestBatchManager:
+ def test_add_batch(self) -> None:
+ batch_manager = _BatchManager(
+ steps={
+ "step3": _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data={"step1": [], "step2": []},
+ )
+ },
+ last_batch_received={"step3": None},
+ last_batch_sent={"step3": None},
+ last_batch_flag_sent_to=[],
+ )
+
+ batch_from_step_1 = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ batch_manager.add_batch(to_step="step3", batch=batch_from_step_1)
+
+ assert batch_manager._steps["step3"].data == {
+ "step1": [batch_from_step_1],
+ "step2": [],
+ }
+
+ def test_add_batch_with_prepend(self) -> None:
+ batch_1 = _Batch(
+ seq_no=1,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]],
+ )
+ batch_manager = _BatchManager(
+ steps={
+ "step3": _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data={
+ "step1": [batch_1],
+ "step2": [],
+ },
+ )
+ },
+ last_batch_received={"step3": None},
+ last_batch_sent={"step3": None},
+ last_batch_flag_sent_to=[],
+ )
+ batch_0 = _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ )
+ batch_manager.add_batch(to_step="step3", batch=batch_0, prepend=True)
+ assert batch_manager._steps["step3"].built_batches == [batch_0]
+ assert batch_manager._steps["step3"].data == {
+ "step1": [batch_1],
+ "step2": [],
+ }
+
+ def test_from_dag(
+ self,
+ dummy_generator_step: "GeneratorStep",
+ dummy_step_1: "Step",
+ dummy_step_2: "Step",
+ dummy_global_step: "GlobalStep",
+ ) -> None:
+ dag = DAG()
+ dag.add_step(dummy_generator_step)
+ dag.add_step(dummy_step_1)
+ dag.add_step(dummy_step_2)
+ dag.add_step(dummy_global_step)
+ dag.add_edge("dummy_generator_step", "dummy_step_1")
+ dag.add_edge("dummy_generator_step", "dummy_global_step")
+ dag.add_edge("dummy_step_1", "dummy_step_2")
+
+ batch_manager = _BatchManager.from_dag(dag)
+
+ assert batch_manager._steps == {
+ "dummy_step_1": _BatchManagerStep(
+ step_name="dummy_step_1",
+ accumulate=False,
+ input_batch_size=50,
+ data={"dummy_generator_step": []},
+ ),
+ "dummy_global_step": _BatchManagerStep(
+ step_name="dummy_global_step",
+ accumulate=True,
+ input_batch_size=50,
+ data={"dummy_generator_step": []},
+ ),
+ "dummy_step_2": _BatchManagerStep(
+ step_name="dummy_step_2",
+ accumulate=False,
+ input_batch_size=50,
+ data={"dummy_step_1": []},
+ ),
+ }
+
+ def test_can_generate(self) -> None:
+ batch_manager = _BatchManager(
+ steps={},
+ last_batch_received={
+ "step_1": _Batch(seq_no=0, step_name="step_1", last_batch=False),
+ "step_2": _Batch(seq_no=0, step_name="step_2", last_batch=False),
+ "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False),
+ },
+ last_batch_sent={"step_1": None, "step_2": None, "step_3": None},
+ last_batch_flag_sent_to=[],
+ )
+
+ assert batch_manager.can_generate()
+
+ batch_1 = _Batch(seq_no=0, step_name="step_1", last_batch=True)
+ batch_2 = _Batch(seq_no=0, step_name="step_2", last_batch=True)
+ batch_3 = _Batch(seq_no=0, step_name="step_3", last_batch=True)
+
+ batch_manager = _BatchManager(
+ steps={},
+ last_batch_received={
+ "step_1": batch_1,
+ "step_2": batch_2,
+ "step_3": batch_3,
+ },
+ last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3},
+ last_batch_flag_sent_to=[],
+ )
+
+ assert not batch_manager.can_generate()
+
+ def test_dump(self) -> None:
+ built_batch = _Batch(
+ seq_no=0,
+ last_batch=False,
+ step_name="step3",
+ data=[[]],
+ data_hash="hash",
+ )
+
+ batch_manager = _BatchManager(
+ steps={
+ "step3": _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ input_batch_size=5,
+ data={"step1": [], "step2": []},
+ built_batches=[built_batch],
+ seq_no=1,
+ )
+ },
+ last_batch_received={
+ "step3": _Batch(
+ seq_no=0,
+ step_name="step3",
+ last_batch=False,
+ )
+ },
+ last_batch_sent={
+ "step3": _Batch(
+ seq_no=1,
+ step_name="step3",
+ last_batch=False,
+ )
+ },
+ last_batch_flag_sent_to=["step99"],
+ )
+ assert batch_manager.dump() == {
+ "steps": {
+ "step3": {
+ "step_name": "step3",
+ "accumulate": False,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {},
+ "input_batch_size": 5,
+ "data": {"step1": [], "step2": []},
+ "built_batches": [
+ {
+ "seq_no": 0,
+ "step_name": "step3",
+ "last_batch": False,
+ "data": [[]],
+ "data_hash": "hash",
+ "size": 0,
+ "accumulated": False,
+ "batch_routed_to": [],
+ "created_from": {},
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "seq_no": 1,
+ "last_batch_received": [],
+ "next_expected_created_from_batch_seq_no": 0,
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ },
+ "last_batch_received": {
+ "step3": {
+ "seq_no": 0,
+ "step_name": "step3",
+ "batch_routed_to": [],
+ "created_from": {},
+ "last_batch": False,
+ "data": [],
+ "data_hash": None,
+ "size": 0,
+ "accumulated": False,
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ },
+ "last_batch_sent": {
+ "step3": {
+ "seq_no": 1,
+ "step_name": "step3",
+ "batch_routed_to": [],
+ "created_from": {},
+ "last_batch": False,
+ "data": [],
+ "data_hash": None,
+ "size": 0,
+ "accumulated": False,
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ },
+ "last_batch_flag_sent_to": ["step99"],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManager",
+ },
+ }
+
+ def test_from_dict(self) -> None:
+ batch_manager = _BatchManager.from_dict(
+ {
+ "steps": {
+ "step1": {
+ "step_name": "step1",
+ "accumulate": True,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {0: {"Z": 1234}},
+ "input_batch_size": None,
+ "data": {
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ "step2": {
+ "step_name": "step2",
+ "accumulate": False,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {0: {"Z": 1234}},
+ "input_batch_size": 50,
+ "data": {
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ },
+ "last_batch_received": {
+ "step1": {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ "step2": {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ },
+ "last_batch_sent": {
+ "step1": {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ "step2": {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ },
+ "last_batch_flag_sent_to": ["step3"],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManager",
+ },
+ }
+ )
+
+ assert isinstance(batch_manager, _BatchManager)
+
+ assert len(batch_manager._steps) == 2
+ for step in batch_manager._steps.values():
+ assert isinstance(step, _BatchManagerStep)
+
+ assert len(batch_manager._last_batch_received) == 2
+ for step in batch_manager._last_batch_received.values():
+ assert isinstance(step, _Batch)
+
+ assert len(batch_manager._last_batch_sent) == 2
+ for step in batch_manager._last_batch_sent.values():
+ assert isinstance(step, _Batch)
+
+ assert batch_manager._last_batch_flag_sent_to == ["step3"]
+
+ def test_cache(self) -> None:
+ batch_manager = _BatchManager.from_dict(
+ {
+ "steps": {
+ "step1": {
+ "step_name": "step1",
+ "accumulate": True,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {"0": {"Z": 1234}},
+ "input_batch_size": None,
+ "data": {
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "built_batches": [
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 5,
+ "accumulated": False,
+ "batch_routed_to": [],
+ "created_from": {},
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ "step2": {
+ "step_name": "step2",
+ "accumulate": False,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {"0": {"Z": 1234}},
+ "input_batch_size": 50,
+ "data": {
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "built_batches": [
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 5,
+ "accumulated": False,
+ "batch_routed_to": [],
+ "created_from": {},
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ },
+ "last_batch_received": {
+ "step1": {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ "step2": {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ },
+ "last_batch_sent": {
+ "step1": {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ "step2": {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_Batch",
+ },
+ },
+ },
+ "last_batch_flag_sent_to": ["step3"],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManager",
+ },
+ }
+ )
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ batch_manager_path = Path(tmp_dir) / "batch_manager.json"
+ batch_manager.cache(batch_manager_path)
+
+ assert batch_manager_path.exists() and batch_manager_path.is_file()
+
+ for step_name, step in batch_manager._steps.items():
+ batch_manager_step_dir = (
+ Path(tmp_dir) / "batch_manager_steps" / step_name
+ )
+ assert (
+ batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir()
+ )
+
+ batch_manager_step_path = (
+ batch_manager_step_dir / "batch_manager_step.json"
+ )
+ assert (
+ batch_manager_step_path.exists()
+ and batch_manager_step_path.is_file()
+ )
+
+ built_batches_dir = batch_manager_step_dir / "built_batches"
+ assert built_batches_dir.exists()
+
+ for batch in step.built_batches:
+ batch_path = (
+ built_batches_dir
+ / f"batch_{batch.seq_no}_{batch.data_hash}.json"
+ )
+ assert batch_path.exists() and batch_path.is_file()
+
+ for buffered_step_name in step.data:
+ buffered_step_dir = batch_manager_step_dir / buffered_step_name
+ assert buffered_step_dir.exists() and buffered_step_dir.is_dir()
+
+ for batch in step.data[buffered_step_name]:
+ batch_path = (
+ buffered_step_dir
+ / f"batch_{batch.seq_no}_{batch.data_hash}.json"
+ )
+ assert batch_path.exists() and batch_path.is_file()
+
+ def test_load_from_cache(self) -> None:
+ batch_manager = _BatchManager.from_dict(
+ {
+ "steps": {
+ "step1": {
+ "step_name": "step1",
+ "accumulate": True,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {"0": {"Z": 1234}},
+ "input_batch_size": None,
+ "data": {
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "built_batches": [
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 5,
+ "accumulated": False,
+ "batch_routed_to": [],
+ "created_from": {},
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ "step2": {
+ "step_name": "step2",
+ "accumulate": False,
+ "convergence_step": False,
+ "convergence_step_batches_consumed": {"0": {"Z": 1234}},
+ "input_batch_size": 50,
+ "data": {
+ "step2": [
+ {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": True,
+ "data": [
+ [
+ {"b": 1},
+ {"b": 2},
+ {"b": 3},
+ {"b": 4},
+ {"b": 5},
+ {"b": 6},
+ {"b": 7},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 7,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ },
+ "built_batches": [
+ {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [
+ [
+ {"a": 1},
+ {"a": 2},
+ {"a": 3},
+ {"a": 4},
+ {"a": 5},
+ ]
+ ],
+ "data_hash": "1234",
+ "size": 5,
+ "accumulated": False,
+ "batch_routed_to": [],
+ "created_from": {},
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ }
+ ],
+ "seq_no": 0,
+ "last_batch_received": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManagerStep",
+ },
+ },
+ },
+ "last_batch_received": {
+ "step1": {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ },
+ "step2": {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ },
+ },
+ "last_batch_sent": {
+ "step1": {
+ "seq_no": 0,
+ "step_name": "step1",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ },
+ "step2": {
+ "seq_no": 0,
+ "step_name": "step2",
+ "last_batch": False,
+ "data": [],
+ "size": 0,
+ "accumulated": False,
+ "created_from": {},
+ "batch_routed_to": [],
+ "type_info": {
+ "module": "distilabel.pipeline.batch",
+ "name": "_Batch",
+ },
+ },
+ },
+ "last_batch_flag_sent_to": ["step3"],
+ "type_info": {
+ "module": "distilabel.pipeline.batch_manager",
+ "name": "_BatchManager",
+ },
+ }
+ )
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ batch_manager_path = Path(tmp_dir) / "batch_manager.json"
+ batch_manager.cache(batch_manager_path)
+ loaded_batch_manager = _BatchManager.load_from_cache(batch_manager_path)
+
+ assert batch_manager.dump() == loaded_batch_manager.dump()
diff --git a/tests/unit/pipeline/test_local.py b/tests/unit/pipeline/test_local.py
index 3c4a15b534..4797f8e66d 100644
--- a/tests/unit/pipeline/test_local.py
+++ b/tests/unit/pipeline/test_local.py
@@ -15,7 +15,8 @@
from typing import TYPE_CHECKING
from unittest import mock
-from distilabel.pipeline.base import _Batch, _BatchManager
+from distilabel.pipeline.batch import _Batch
+from distilabel.pipeline.batch_manager import _BatchManager
from distilabel.pipeline.local import Pipeline
from .utils import DummyGeneratorStep, DummyStep1, DummyStep2
@@ -58,17 +59,11 @@ def test_send_batch_to_step(self, dummy_generator_step: "GeneratorStep") -> None
)
pipeline._send_batch_to_step(batch=batch) # type: ignore
- batch_manager_mock.set_last_batch_sent.assert_called_once_with(batch)
- get_step_mock.assert_called_once_with(dummy_generator_step.name)
+ get_step_mock.assert_has_calls([mock.call(dummy_generator_step.name)])
input_queue.put.assert_called_once_with(batch)
@mock.patch("distilabel.pipeline.local._ProcessWrapper")
def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None:
- pool = mock.MagicMock()
- manager = mock.MagicMock()
- queue = mock.MagicMock()
- shared_info = mock.MagicMock()
-
with Pipeline(name="unit-test-pipeline") as pipeline:
dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
dummy_step_1 = DummyStep1(name="dummy_step_1")
@@ -77,51 +72,52 @@ def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None:
dummy_generator.connect(dummy_step_1)
dummy_step_1.connect(dummy_step_2)
- pipeline._run_steps_in_loop(pool, manager, queue, shared_info)
+ pipeline._pool = mock.MagicMock()
+ pipeline._manager = mock.MagicMock()
+ pipeline._output_queue = mock.MagicMock()
+ pipeline._load_queue = mock.MagicMock()
+ pipeline._run_steps()
- assert manager.Queue.call_count == 3
+ assert pipeline._manager.Queue.call_count == 3
process_wrapper_mock.assert_has_calls(
[
mock.call(
step=dummy_generator,
input_queue=mock.ANY,
- output_queue=queue,
- shared_info=shared_info,
+ output_queue=pipeline._output_queue,
+ load_queue=pipeline._load_queue,
dry_run=False,
),
mock.call(
step=dummy_step_1,
input_queue=mock.ANY,
- output_queue=queue,
- shared_info=shared_info,
+ output_queue=pipeline._output_queue,
+ load_queue=pipeline._load_queue,
dry_run=False,
),
mock.call(
step=dummy_step_2,
input_queue=mock.ANY,
- output_queue=queue,
- shared_info=shared_info,
+ output_queue=pipeline._output_queue,
+ load_queue=pipeline._load_queue,
dry_run=False,
),
],
)
- pool.apply_async.assert_has_calls(
+ pipeline._pool.apply_async.assert_has_calls(
[
mock.call(
process_wrapper_mock.return_value.run,
- callback=pipeline._finished_callback,
error_callback=pipeline._error_callback,
),
mock.call(
process_wrapper_mock.return_value.run,
- callback=pipeline._finished_callback,
error_callback=pipeline._error_callback,
),
mock.call(
process_wrapper_mock.return_value.run,
- callback=pipeline._finished_callback,
error_callback=pipeline._error_callback,
),
]
diff --git a/tests/unit/pipeline/test_routing_batch_function.py b/tests/unit/pipeline/test_routing_batch_function.py
index 5e3f208c5b..6cc3090eb7 100644
--- a/tests/unit/pipeline/test_routing_batch_function.py
+++ b/tests/unit/pipeline/test_routing_batch_function.py
@@ -14,7 +14,7 @@
from typing import List
-from distilabel.pipeline.base import _Batch
+from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.local import Pipeline
from distilabel.pipeline.routing_batch_function import (
RoutingBatchFunction,
diff --git a/tests/unit/pipeline/test_write_buffer.py b/tests/unit/pipeline/test_write_buffer.py
new file mode 100644
index 0000000000..a7ae64c91e
--- /dev/null
+++ b/tests/unit/pipeline/test_write_buffer.py
@@ -0,0 +1,150 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from pathlib import Path
+
+from distilabel.distiset import Distiset, create_distiset
+from distilabel.pipeline.local import Pipeline
+from distilabel.pipeline.write_buffer import _WriteBuffer
+
+from tests.unit.pipeline.utils import (
+ DummyGeneratorStep,
+ DummyStep1,
+ DummyStep2,
+ batch_gen,
+)
+
+
+class TestWriteBuffer:
+ def test_create(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ folder = Path(tmpdirname) / "data"
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
+ dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
+ dummy_step_1 = DummyStep1(name="dummy_step_1")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+ dummy_step_3 = DummyStep2(name="dummy_step_3")
+
+ dummy_generator_1.connect(dummy_step_1)
+ dummy_generator_2.connect(dummy_step_2)
+ dummy_step_1.connect(dummy_step_2)
+ dummy_step_1.connect(dummy_step_3)
+
+ write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+
+ assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []}
+ assert write_buffer._buffers_dump_batch_size == {
+ "dummy_step_2": 50,
+ "dummy_step_3": 50,
+ }
+ assert write_buffer._buffer_last_schema == {}
+ assert write_buffer._buffers_last_file == {
+ "dummy_step_2": 1,
+ "dummy_step_3": 1,
+ }
+
+ def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ folder = Path(tmpdirname) / "data"
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator = DummyGeneratorStep(name="dummy_generator_step")
+ dummy_step_1 = DummyStep1(name="dummy_step_1")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+
+ dummy_generator.connect(dummy_step_1)
+ dummy_step_1.connect(dummy_step_2)
+
+ write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+
+            # Add one batch with 5 rows, shouldn't write anything since 5 < 50
+ batch = batch_gen(dummy_step_2.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ # Add 45 more rows, should write now
+ for _ in range(9):
+ batch = batch_gen(dummy_step_2.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ assert Path(folder, "dummy_step_2", "00001.parquet").exists()
+
+ # Add 50 more rows, we should have a new file
+ for _ in range(10):
+ batch = batch_gen(dummy_step_2.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ assert Path(folder, "dummy_step_2", "00002.parquet").exists()
+
+ # Add more rows and close the write buffer, we should have a new file
+ for _ in range(5):
+ batch = batch_gen(dummy_step_2.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ write_buffer.close()
+
+ assert Path(folder, "dummy_step_2", "00003.parquet").exists()
+
+ ds = create_distiset(write_buffer._path)
+ assert isinstance(ds, Distiset)
+ assert len(ds.keys()) == 1
+ assert len(ds["default"]["train"]) == 125
+
+ def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ folder = Path(tmpdirname) / "data"
+ with Pipeline(name="unit-test-pipeline") as pipeline:
+ dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1")
+ dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2")
+ dummy_step_1 = DummyStep1(name="dummy_step_1")
+ dummy_step_2 = DummyStep2(name="dummy_step_2")
+ dummy_step_3 = DummyStep2(name="dummy_step_3")
+
+ dummy_generator_1.connect(dummy_step_1)
+ dummy_generator_2.connect(dummy_step_2)
+ dummy_step_1.connect(dummy_step_2)
+ dummy_step_1.connect(dummy_step_3)
+
+ write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps)
+
+ for _ in range(10):
+ batch = batch_gen(dummy_step_2.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ assert Path(folder, "dummy_step_2", "00001.parquet").exists()
+
+ for _ in range(10):
+ batch = batch_gen(dummy_step_3.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ assert Path(folder, "dummy_step_3", "00001.parquet").exists()
+
+ for _ in range(5):
+ batch = batch_gen(dummy_step_2.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ for _ in range(5):
+ batch = batch_gen(dummy_step_3.name) # type: ignore
+ write_buffer.add_batch(batch)
+
+ write_buffer.close()
+
+ assert Path(folder, "dummy_step_2", "00002.parquet").exists()
+ assert Path(folder, "dummy_step_3", "00002.parquet").exists()
+
+ ds = create_distiset(write_buffer._path)
+ assert isinstance(ds, Distiset)
+ assert len(ds.keys()) == 2
+ assert len(ds["dummy_step_2"]["train"]) == 75
+ assert len(ds["dummy_step_3"]["train"]) == 75
diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py
index 8d02340114..7f771271d0 100644
--- a/tests/unit/pipeline/utils.py
+++ b/tests/unit/pipeline/utils.py
@@ -14,7 +14,7 @@
from typing import List
-from distilabel.pipeline.base import _Batch
+from distilabel.pipeline.batch import _Batch
from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput
from distilabel.steps.typing import GeneratorStepOutput, StepOutput
diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py
index c816c8bcac..dbb8773923 100644
--- a/tests/unit/steps/argilla/test_base.py
+++ b/tests/unit/steps/argilla/test_base.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
+import sys
from typing import TYPE_CHECKING, List
import pytest
@@ -83,7 +84,9 @@ def test_with_errors(self, caplog) -> None:
with pytest.raises(
TypeError,
- match="Can't instantiate abstract class Argilla with abstract methods inputs, process",
+ match="Can't instantiate abstract class Argilla with abstract methods inputs, process"
+ if sys.version_info < (3, 12)
+ else "Can't instantiate abstract class Argilla without an implementation for abstract methods 'inputs', 'process'",
):
Argilla(name="step", pipeline=Pipeline(name="unit-test-pipeline")) # type: ignore
diff --git a/tests/unit/steps/generators/sample_functions.jsonl b/tests/unit/steps/generators/sample_functions.jsonl
new file mode 100644
index 0000000000..700d21ad5b
--- /dev/null
+++ b/tests/unit/steps/generators/sample_functions.jsonl
@@ -0,0 +1,11 @@
+{"type": "function", "function": {"name": "code_interpreter", "description": "Execute the provided Python code string on the terminal using exec.\n\n The string should contain valid, executable and pure Python code in markdown syntax.\n Code should also import any required Python packages.\n\n Args:\n code_markdown (str): The Python code with markdown syntax to be executed.\n For example: ```python\n\n```\n\n Returns:\n dict | str: A dictionary containing variables declared and values returned by function calls,\n or an error message if an exception occurred.\n\n Note:\n Use this function with caution, as executing arbitrary code can pose security risks.", "parameters": {"type": "object", "properties": {"code_markdown": {"type": "string"}}, "required": ["code_markdown"]}}}
+{"type": "function", "function": {"name": "google_search_and_scrape", "description": "Performs a Google search for the given query, retrieves the top search result URLs,\nand scrapes the text content and table data from those pages in parallel.\n\nArgs:\n query (str): The search query.\nReturns:\n list: A list of dictionaries containing the URL, text content, and table data for each scraped page.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}}
+{"type": "function", "function": {"name": "get_current_stock_price", "description": "Get the current stock price for a given symbol.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_company_news", "description": "Get company news and press releases for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press releases.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_company_profile", "description": "Get company profile and overview for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing company profile and overview.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_stock_fundamentals", "description": "Get fundamental data for a given stock symbol using yfinance API.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n dict: A dictionary containing fundamental data.\n Keys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_financial_statements", "description": "Get financial statements for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_key_financial_ratios", "description": "Get key financial ratios for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\ndict: Dictionary containing key financial ratios.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_analyst_recommendations", "description": "Get analyst recommendations for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing analyst recommendations.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_dividend_data", "description": "Get dividend data for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing dividend data.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
+{"type": "function", "function": {"name": "get_technical_indicators", "description": "Get technical indicators for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing technical indicators.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}}
diff --git a/tests/unit/steps/generators/test_data.py b/tests/unit/steps/generators/test_data.py
index 9817837e20..c35b9db86d 100644
--- a/tests/unit/steps/generators/test_data.py
+++ b/tests/unit/steps/generators/test_data.py
@@ -17,7 +17,7 @@
from pydantic import ValidationError
-class TestLoadDataFromDictsTask:
+class TestLoadDataFromDicts:
data = [{"instruction": "test"}] * 10
def test_init(self) -> None:
diff --git a/tests/unit/steps/generators/test_huggingface.py b/tests/unit/steps/generators/test_huggingface.py
index 34b44f4fc5..e72a70acb2 100644
--- a/tests/unit/steps/generators/test_huggingface.py
+++ b/tests/unit/steps/generators/test_huggingface.py
@@ -13,19 +13,27 @@
# limitations under the License.
import os
+import tempfile
+from pathlib import Path
from typing import Generator, Union
import pytest
from datasets import Dataset, IterableDataset
+from distilabel.distiset import Distiset
from distilabel.pipeline import Pipeline
-from distilabel.steps.generators.huggingface import LoadHubDataset
+from distilabel.steps.generators.huggingface import (
+ LoadDataFromDisk,
+ LoadDataFromFileSystem,
+ LoadDataFromHub,
+ LoadHubDataset,
+)
DISTILABEL_RUN_SLOW_TESTS = os.getenv("DISTILABEL_RUN_SLOW_TESTS", False)
@pytest.fixture(scope="module")
def dataset_loader() -> Generator[Union[Dataset, IterableDataset], None, None]:
- load_hub_dataset = LoadHubDataset(
+ load_hub_dataset = LoadDataFromHub(
name="load_dataset",
repo_id="distilabel-internal-testing/instruction-dataset-mini",
split="test",
@@ -39,12 +47,12 @@ def dataset_loader() -> Generator[Union[Dataset, IterableDataset], None, None]:
not DISTILABEL_RUN_SLOW_TESTS,
reason="These tests depend on internet connection, are slow and depend mainly on HF API, we don't need to test them often.",
)
-class TestLoadHubDataset:
+class TestLoadDataFromHub:
@pytest.mark.parametrize(
"streaming, ds_type", [(True, IterableDataset), (False, Dataset)]
)
def test_runtime_parameters(self, streaming: bool, ds_type) -> None:
- load_hub_dataset = LoadHubDataset(
+ load_hub_dataset = LoadDataFromHub(
name="load_dataset",
repo_id="distilabel-internal-testing/instruction-dataset-mini",
split="test",
@@ -60,6 +68,131 @@ def test_runtime_parameters(self, streaming: bool, ds_type) -> None:
assert isinstance(generator_step_output[1], bool)
assert len(generator_step_output[0]) == 2
- def test_dataset_outputs(self, dataset_loader: LoadHubDataset) -> None:
+ def test_dataset_outputs(self, dataset_loader: LoadDataFromHub) -> None:
# TODO: This test can be run with/without internet connection, we should emulate it here with a mock.
assert dataset_loader.outputs == ["prompt", "completion", "meta"]
+
+
+class TestLoadDataFromFileSystem:
+ @pytest.mark.parametrize("filetype", ["json", None])
+ @pytest.mark.parametrize("streaming", [True, False])
+ def test_read_from_jsonl(self, streaming: bool, filetype: Union[str, None]) -> None:
+ loader = LoadDataFromFileSystem(
+ filetype=filetype,
+ data_files=str(Path(__file__).parent / "sample_functions.jsonl"),
+ streaming=streaming,
+ )
+ loader.load()
+ generator_step_output = next(loader.process())
+ assert isinstance(generator_step_output, tuple)
+ assert isinstance(generator_step_output[1], bool)
+ assert len(generator_step_output[0]) == 11
+
+ @pytest.mark.parametrize("filetype", ["json", None])
+ def test_read_from_jsonl_with_folder(self, filetype: Union[str, None]) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ filename = "sample_functions.jsonl"
+ sample_file = Path(__file__).parent / filename
+ for i in range(3):
+ (Path(tmpdir) / f"sample_functions_{i}.jsonl").write_text(
+ sample_file.read_text(), encoding="utf-8"
+ )
+
+ loader = LoadDataFromFileSystem(
+ filetype=filetype,
+ data_files=tmpdir,
+ )
+ loader.load()
+ generator_step_output = next(loader.process())
+ assert isinstance(generator_step_output, tuple)
+ assert isinstance(generator_step_output[1], bool)
+ assert len(generator_step_output[0]) == 33
+
+ @pytest.mark.parametrize("filetype", ["json", None])
+ def test_read_from_jsonl_with_nested_folder(
+ self, filetype: Union[str, None]
+ ) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ filename = "sample_functions.jsonl"
+ sample_file = Path(__file__).parent / filename
+ for folder in ["train", "validation"]:
+ (Path(tmpdir) / folder).mkdir(parents=True, exist_ok=True)
+ (Path(tmpdir) / folder / filename).write_text(
+ sample_file.read_text(), encoding="utf-8"
+ )
+
+ loader = LoadDataFromFileSystem(
+ filetype=filetype,
+ data_files=tmpdir,
+ )
+ loader.load()
+ generator_step_output = next(loader.process())
+ assert isinstance(generator_step_output, tuple)
+ assert isinstance(generator_step_output[1], bool)
+ assert len(generator_step_output[0]) == 22
+
+ @pytest.mark.parametrize("load", [True, False])
+ def test_outputs(self, load: bool) -> None:
+ loader = LoadDataFromFileSystem(
+ filetype="json",
+ data_files=str(Path(__file__).parent / "sample_functions.jsonl"),
+ )
+ if load:
+ loader.load()
+ assert loader.outputs == ["type", "function"]
+ else:
+ with pytest.raises(ValueError):
+ loader.outputs # noqa: B018
+
+
+class TestLoadDataFromDisk:
+ def test_load_dataset_from_disk(self) -> None:
+ dataset = Dataset.from_dict({"a": [1, 2, 3]})
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dataset_path = str(Path(tmpdir) / "dataset_path")
+ dataset.save_to_disk(dataset_path)
+
+ loader = LoadDataFromDisk(dataset_path=dataset_path)
+ loader.load()
+ generator_step_output = next(loader.process())
+ assert isinstance(generator_step_output, tuple)
+ assert isinstance(generator_step_output[1], bool)
+ assert len(generator_step_output[0]) == 3
+
+ def test_load_distiset_from_disk(self) -> None:
+ distiset = Distiset(
+ {
+ "leaf_step_1": Dataset.from_dict({"a": [1, 2, 3]}),
+ "leaf_step_2": Dataset.from_dict(
+ {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}
+ ),
+ }
+ )
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dataset_path = str(Path(tmpdir) / "dataset_path")
+ distiset.save_to_disk(dataset_path)
+
+ loader = LoadDataFromDisk(
+ dataset_path=dataset_path, is_distiset=True, config="leaf_step_1"
+ )
+ loader.load()
+ generator_step_output = next(loader.process())
+ assert isinstance(generator_step_output, tuple)
+ assert isinstance(generator_step_output[1], bool)
+ assert len(generator_step_output[0]) == 3
+
+
+def test_LoadHubDataset_deprecation_warning():
+ with pytest.deprecated_call():
+ LoadHubDataset(
+ repo_id="distilabel-internal-testing/instruction-dataset-mini",
+ split="test",
+ batch_size=2,
+ )
+ import distilabel
+ from packaging.version import Version
+
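+    # Guard: once `distilabel` moves past v1.3.0 this assert fails, reminding us to
+    # drop the deprecated `LoadHubDataset` alias (and this test) for good.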
+ assert Version(distilabel.__version__) <= Version("1.3.0")
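For reference, a usage sketch of the three loaders exercised above, using the same arguments the tests use; the repo id and JSONL path come from the tests, while the saved dataset path is illustrative:

```python
from distilabel.steps.generators.huggingface import (
    LoadDataFromDisk,
    LoadDataFromFileSystem,
    LoadDataFromHub,
)

# Formerly `LoadHubDataset`: stream or download a dataset from the Hugging Face Hub.
from_hub = LoadDataFromHub(
    repo_id="distilabel-internal-testing/instruction-dataset-mini",
    split="test",
    batch_size=2,
)

# Read local (or fsspec-reachable) files; `filetype` may be omitted and inferred
# from the file extension.
from_files = LoadDataFromFileSystem(
    filetype="json",
    data_files="tests/unit/steps/generators/sample_functions.jsonl",
)
from_files.load()
batch, is_last = next(from_files.process())

# Reload a `Dataset` (or one configuration of a `Distiset`) saved with `save_to_disk`.
from_disk = LoadDataFromDisk(dataset_path="path/to/saved/dataset")
```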
diff --git a/tests/unit/steps/tasks/conftest.py b/tests/unit/steps/tasks/conftest.py
deleted file mode 100644
index da1493c9ce..0000000000
--- a/tests/unit/steps/tasks/conftest.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2023-present, Argilla, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING, List
-
-import pytest
-from distilabel.llms.base import LLM
-
-if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
- from distilabel.steps.tasks.typing import ChatType
-
-
-@pytest.fixture
-def dummy_llm() -> LLM:
- class DummyLLM(LLM):
- def load(self) -> None:
- pass
-
- @property
- def model_name(self) -> str:
- return "test"
-
- def generate( # type: ignore
- self, inputs: List["ChatType"], num_generations: int = 1
- ) -> List["GenerateOutput"]:
- return [["output"] for _ in inputs]
-
- return DummyLLM()
-
-
-# Defined here too, so that the serde still works
-class DummyLLM(LLM):
- def load(self) -> None:
- pass
-
- @property
- def model_name(self) -> str:
- return "test"
-
- def generate( # type: ignore
- self, inputs: List["ChatType"], num_generations: int = 1
- ) -> List["GenerateOutput"]:
- return [["output"] for _ in inputs]
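The shared helper does not disappear: the imports updated throughout this patch (`from tests.unit.conftest import DummyLLM`) suggest the same stub now lives once in `tests/unit/conftest.py`, which is not shown here. A sketch of that relocated helper, mirroring the deleted code:

```python
# Sketch of tests/unit/conftest.py (assumed location): the stub LLM the deleted
# module defined, now shared by the whole unit-test suite.
from typing import TYPE_CHECKING, List

from distilabel.llms.base import LLM

if TYPE_CHECKING:
    from distilabel.llms.typing import GenerateOutput
    from distilabel.steps.tasks.typing import ChatType


class DummyLLM(LLM):
    def load(self) -> None:
        pass

    @property
    def model_name(self) -> str:
        return "test"

    def generate(  # type: ignore
        self, inputs: List["ChatType"], num_generations: int = 1
    ) -> List["GenerateOutput"]:
        return [["output"] for _ in inputs]
```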
diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py
index d3999684c1..9b679f6dfb 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_base.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_base.py
@@ -121,13 +121,12 @@ def test_serialization(self, dummy_llm: LLM) -> None:
task.load()
assert task.dump() == {
"name": "task",
- "add_raw_output": False,
+ "add_raw_output": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
- "structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -163,6 +162,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
}
],
},
+ {
+ "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+ "name": "add_raw_output",
+ "optional": True,
+ },
{
"name": "num_generations",
"optional": True,
diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py
index fee2234083..13cf6e2783 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_generator.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -117,13 +117,12 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "task",
"llm": {
"generation_kwargs": {},
- "structured_output": None,
"type_info": {
"module": task.llm.__class__.__module__,
"name": task.llm.__class__.__name__,
},
},
- "add_raw_output": False,
+ "add_raw_output": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"batch_size": task.batch_size,
@@ -158,6 +157,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
},
],
},
+ {
+ "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+ "name": "add_raw_output",
+ "optional": True,
+ },
{
"name": "num_generations",
"optional": True,
diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py
index 251e377d9e..fffacafa06 100644
--- a/tests/unit/steps/tasks/evol_quality/test_base.py
+++ b/tests/unit/steps/tasks/evol_quality/test_base.py
@@ -34,6 +34,18 @@ def test_with_errors(
EvolQuality(name="task", llm=dummy_llm, num_evolutions=2)
assert "Step 'task' hasn't received a pipeline" in caplog.text
+ def test_apply_random_mutation(self, dummy_llm: LLM) -> None:
+ pipeline = Pipeline(name="unit-test-pipeline")
+ task = EvolQuality(
+ name="task", llm=dummy_llm, num_evolutions=2, pipeline=pipeline
+ )
+ task.load()
+
+ mutated = task._apply_random_mutation("I'm an instruction", "I'm a response")
+
+ assert "I'm an instruction" in mutated
+ assert "I'm a response" in mutated
+
def test_process(self, dummy_llm: LLM) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
task = EvolQuality(
@@ -80,13 +92,12 @@ def test_serialization(self, dummy_llm: LLM) -> None:
task.load()
assert task.dump() == {
"name": "task",
- "add_raw_output": False,
+ "add_raw_output": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
- "structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -112,9 +123,14 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "generation_kwargs",
"description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
"keys": [],
- }
+ },
],
},
+ {
+ "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+ "name": "add_raw_output",
+ "optional": True,
+ },
{
"name": "num_generations",
"optional": True,
diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
index 549155076b..e174f53716 100644
--- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py
+++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
@@ -17,9 +17,10 @@
import pytest
from distilabel.llms.huggingface.transformers import TransformersLLM
from distilabel.steps.tasks.structured_outputs.outlines import (
- StructuredOutputType,
model_to_schema,
)
+from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
from pydantic import BaseModel
@@ -88,10 +89,6 @@ class DummyUserTest(BaseModel):
class TestOutlinesIntegration:
- # @pytest.mark.skipif(
- # not DISTILABEL_RUN_SLOW_TESTS,
- # reason="Slow tests, run locally when needed.",
- # )
@pytest.mark.parametrize(
"format, schema, prompt",
[
@@ -99,7 +96,7 @@ class TestOutlinesIntegration:
"json",
DummyUserTest,
"Create a user profile with the fields name, last_name and id",
- ), #
+ ),
(
"json",
model_to_schema(DummyUserTest),
@@ -117,7 +114,9 @@ def test_generation(
) -> None:
llm = TransformersLLM(
model="openaccess-ai-collective/tiny-mistral",
- structured_output=StructuredOutputType(format=format, schema=schema),
+ structured_output=OutlinesStructuredOutputType(
+ format=format, schema=schema
+ ),
)
llm.load()
@@ -154,7 +153,9 @@ def test_serialization(
) -> None:
llm = TransformersLLM(
model="openaccess-ai-collective/tiny-mistral",
- structured_output=StructuredOutputType(format=format, schema=schema),
+ structured_output=OutlinesStructuredOutputType(
+ format=format, schema=schema
+ ),
)
llm.load()
assert llm.dump() == dump
diff --git a/tests/unit/steps/tasks/structured_outputs/test_utils.py b/tests/unit/steps/tasks/structured_outputs/test_utils.py
new file mode 100644
index 0000000000..6238c8567f
--- /dev/null
+++ b/tests/unit/steps/tasks/structured_outputs/test_utils.py
@@ -0,0 +1,75 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import List
+
+from distilabel.steps.tasks.structured_outputs.utils import json_schema_to_model
+from pydantic import BaseModel, Field, StringConstraints, conint
+from typing_extensions import Annotated
+
+
+class Node(BaseModel):
+ id: int
+ label: str
+ color: str
+
+
+class Edge(BaseModel):
+ source: int
+ target: int
+ label: str
+ color: str = "black"
+
+
+class KnowledgeGraph(BaseModel):
+ nodes: List[Node] = Field(..., default_factory=list)
+ edges: List[Edge] = Field(..., default_factory=list)
+
+
+class Weapon(str, Enum):
+ sword = "sword"
+ axe = "axe"
+ mace = "mace"
+ spear = "spear"
+ bow = "bow"
+ crossbow = "crossbow"
+
+
+class Armor(str, Enum):
+ leather = "leather"
+ chainmail = "chainmail"
+ plate = "plate"
+ mithril = "mithril"
+
+
+class Character(BaseModel):
+ name: Annotated[str, StringConstraints(max_length=30)]
+ age: conint(gt=1, lt=3000)
+ armor: Armor
+ weapon: Weapon
+
+
+def test_json_schema_to_model():
+ assert type(json_schema_to_model(Node.model_json_schema())) == type(Node)
+
+
+def test_json_schema_to_model_with_enum():
+ assert type(json_schema_to_model(Character.model_json_schema())) == type(Character)
+
+
+def test_json_schema_to_model_nested():
+ assert type(json_schema_to_model(KnowledgeGraph.model_json_schema())) == type(
+ KnowledgeGraph
+ )
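The assertions above only compare metaclasses, so as a sketch of what `json_schema_to_model` is meant to provide, assuming the rebuilt model is functionally usable (which these tests do not verify):

```python
from distilabel.steps.tasks.structured_outputs.utils import json_schema_to_model
from pydantic import BaseModel


class Node(BaseModel):
    id: int
    label: str
    color: str


# Round-trip: pydantic model -> JSON schema -> dynamically rebuilt pydantic model.
schema = Node.model_json_schema()
RebuiltNode = json_schema_to_model(schema)

node = RebuiltNode(id=1, label="distilabel", color="black")
print(node.model_dump())  # expected: {'id': 1, 'label': 'distilabel', 'color': 'black'}
```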
diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py
index ed1fd956cf..c5ff87e1e0 100644
--- a/tests/unit/steps/tasks/test_base.py
+++ b/tests/unit/steps/tasks/test_base.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import sys
from dataclasses import field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -21,7 +22,7 @@
from distilabel.steps.tasks.base import Task
from pydantic import ValidationError
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
@@ -77,7 +78,9 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
with pytest.raises(
TypeError,
- match="Can't instantiate abstract class Task with abstract methods format_input, format_output",
+ match="Can't instantiate abstract class Task with abstract methods format_input, format_output"
+ if sys.version_info < (3, 12)
+ else "Can't instantiate abstract class Task without an implementation for abstract methods 'format_input', 'format_output'",
):
Task(name="task", llm=DummyLLM()) # type: ignore
@@ -91,16 +94,19 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"instruction": "test",
"output": "output",
"model_name": "test",
+ "distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test",
"output": "output",
"model_name": "test",
+ "distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test",
"output": "output",
"model_name": "test",
+ "distilabel_metadata": {"raw_output_task": "output"},
},
],
),
@@ -111,6 +117,11 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"instruction": "test",
"output": ["output", "output", "output"],
"model_name": "test",
+ "distilabel_metadata": [
+ {"raw_output_task": "output"},
+ {"raw_output_task": "output"},
+ {"raw_output_task": "output"},
+ ],
},
],
),
@@ -145,7 +156,7 @@ def test_process_with_runtime_parameters(self) -> None:
assert task.llm.runtime_parameters_names == {
"runtime_parameter": False,
"runtime_parameter_optional": True,
- "generation_kwargs": {"kwargs": False},
+ "generation_kwargs": {},
}
# 2. Runtime parameters in init
@@ -160,7 +171,7 @@ def test_process_with_runtime_parameters(self) -> None:
assert task.llm.runtime_parameters_names == {
"runtime_parameter": False,
"runtime_parameter_optional": True,
- "generation_kwargs": {"kwargs": False},
+ "generation_kwargs": {},
}
# 3. Runtime parameters in init superseded by runtime parameters
@@ -176,7 +187,7 @@ def test_process_with_runtime_parameters(self) -> None:
assert task.llm.runtime_parameters_names == {
"runtime_parameter": False,
"runtime_parameter_optional": True,
- "generation_kwargs": {"kwargs": False},
+ "generation_kwargs": {},
}
def test_serialization(self) -> None:
@@ -185,15 +196,14 @@ def test_serialization(self) -> None:
task = DummyTask(name="task", llm=llm, pipeline=pipeline)
assert task.dump() == {
"name": "task",
- "add_raw_output": False,
+ "add_raw_output": True,
"input_mappings": {},
"output_mappings": {},
"input_batch_size": 50,
"llm": {
"generation_kwargs": {},
- "structured_output": None,
"type_info": {
- "module": "tests.unit.steps.tasks.utils",
+ "module": "tests.unit.conftest",
"name": "DummyLLM",
},
},
@@ -211,16 +221,16 @@ def test_serialization(self) -> None:
{
"description": "The kwargs to be propagated to either `generate` or "
"`agenerate` methods within each `LLM`.",
- "keys": [
- {
- "name": "kwargs",
- "optional": False,
- },
- ],
+ "keys": [],
"name": "generation_kwargs",
},
],
},
+ {
+ "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+ "name": "add_raw_output",
+ "optional": True,
+ },
{
"name": "num_generations",
"description": "The number of generations to be produced per input.",
diff --git a/tests/unit/steps/tasks/test_complexity_scorer.py b/tests/unit/steps/tasks/test_complexity_scorer.py
index a47a16445d..ec0575d745 100644
--- a/tests/unit/steps/tasks/test_complexity_scorer.py
+++ b/tests/unit/steps/tasks/test_complexity_scorer.py
@@ -18,7 +18,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
class TestComplexityScorer:
diff --git a/tests/unit/steps/tasks/test_genstruct.py b/tests/unit/steps/tasks/test_genstruct.py
index 8ecc9d2d58..12878b9f26 100644
--- a/tests/unit/steps/tasks/test_genstruct.py
+++ b/tests/unit/steps/tasks/test_genstruct.py
@@ -18,7 +18,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.genstruct import Genstruct
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
class TestGenstruct:
diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py
new file mode 100644
index 0000000000..8ab9b2fd51
--- /dev/null
+++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py
@@ -0,0 +1,406 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, List
+
+import pytest
+from distilabel.llms import LLM
+from distilabel.llms.typing import GenerateOutput
+from distilabel.pipeline.local import Pipeline
+from distilabel.steps.tasks.improving_text_embeddings import (
+ BitextRetrievalGenerator,
+ EmbeddingTaskGenerator,
+ GenerateLongTextMatchingData,
+ GenerateShortTextMatchingData,
+ GenerateTextClassificationData,
+ GenerateTextRetrievalData,
+ MonolingualTripletGenerator,
+)
+from distilabel.steps.tasks.typing import ChatType
+
+
+class MockLLM(LLM):
+ output: str
+
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ def generate( # type: ignore
+ self, inputs: List[ChatType], num_generations: int = 1
+ ) -> List[GenerateOutput]:
+ return [[self.output] for _ in range(num_generations)]
+
+
+class TestEmbeddingTaskGenerator:
+ @pytest.mark.parametrize(
+ "category",
+ [
+ "text-retrieval",
+ "text-matching-short",
+ "text-matching-long",
+ "text-classification",
+ ],
+ )
+ @pytest.mark.parametrize("flatten_tasks", [True, False])
+ def test_process(self, category: str, flatten_tasks: bool) -> None:
+ task = EmbeddingTaskGenerator(
+ name="embedding_task_generator",
+ category=category, # type: ignore
+ flatten_tasks=flatten_tasks,
+ add_raw_output=False,
+ llm=MockLLM(output="[ 'A', 'B', 'C' ]"),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ assert task.outputs == ["tasks" if not flatten_tasks else "task", "model_name"]
+
+ result = (
+ ([{"tasks": ["A", "B", "C"], "model_name": "test"}], True)
+ if not flatten_tasks
+ else (
+ [
+ {"task": "A", "model_name": "test"},
+ {"task": "B", "model_name": "test"},
+ {"task": "C", "model_name": "test"},
+ ],
+ True,
+ )
+ )
+ assert next(task.process()) == result
+
+
+class TestBitextRetrievalGenerator:
+ @pytest.mark.parametrize(
+ "task_kwargs",
+ [
+ {
+ "source_language": "English",
+ "target_language": "French",
+ "unit": "sentence",
+ "difficulty": "elementary school",
+ "high_score": "4",
+ "low_score": "2.5",
+ }
+ ],
+ )
+ def test_prompt(self, task_kwargs: Any) -> None:
+ task = BitextRetrievalGenerator(
+ name="bitext_retrieval_generator",
+ **task_kwargs,
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+        assert all(
+            v in task.prompt[-1]["content"] for v in task_kwargs.values()
+        )
+
+ def test_process(self) -> None:
+ task = BitextRetrievalGenerator(
+ name="bitext_retrieval_generator",
+ source_language="English",
+ target_language="French",
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ assert task.outputs == ["S1", "S2", "S3", "model_name"]
+
+ assert next(task.process()) == (
+ [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}],
+ True,
+ )
+
+ def test_reproducibility(self) -> None:
+ unique_prompts = set()
+ for _ in range(10):
+ task = BitextRetrievalGenerator(
+ name="bitext_retrieval_generator",
+ source_language="English",
+ target_language="French",
+ add_raw_output=False,
+ seed=42,
+ llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ unique_prompts.add(task.prompt[-1]["content"])
+
+ assert len(unique_prompts) == 1
+
+
+class TestMonolingualTripletGenerator:
+ @pytest.mark.parametrize(
+ "task_kwargs",
+ [
+ {
+ "language": "English",
+ "unit": "sentence",
+ "difficulty": "elementary school",
+ "high_score": "4",
+ "low_score": "2.5",
+ }
+ ],
+ )
+ def test_prompt(self, task_kwargs: Any) -> None:
+ task = MonolingualTripletGenerator(
+ name="monolingual_triplet_generator",
+ **task_kwargs,
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+        assert all(
+            v in task.prompt[-1]["content"] for v in task_kwargs.values()
+        )
+
+ def test_process(self) -> None:
+ task = MonolingualTripletGenerator(
+ name="monolingual_triplet_generator",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.outputs == ["S1", "S2", "S3", "model_name"]
+ assert next(task.process()) == (
+ [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}],
+ True,
+ )
+
+ def test_reproducibility(self) -> None:
+ unique_prompts = set()
+ for _ in range(10):
+ task = MonolingualTripletGenerator(
+ name="monolingual_triplet_generator",
+ language="English",
+ add_raw_output=False,
+ seed=42,
+ llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ unique_prompts.add(task.prompt[-1]["content"])
+ assert len(unique_prompts) == 1
+
+
+class TestGenerateLongTextMatchingData:
+ def test_format_input(self) -> None:
+ task = GenerateLongTextMatchingData(
+ name="generate_long_text_matching_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ assert task.format_input({"task": "A"})[-1]["content"].startswith(
+ "You have been assigned a text matching task: A"
+ )
+
+ def test_process(self) -> None:
+ task = GenerateLongTextMatchingData(
+ name="generate_long_text_matching_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+
+ assert task.outputs == ["input", "positive_document", "model_name"]
+
+ assert next(task.process(inputs=[{"task": "A"}])) == [
+ {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"}
+ ]
+
+
+class TestGenerateShortTextMatchingData:
+ def test_format_input(self) -> None:
+ task = GenerateShortTextMatchingData(
+ name="generate_short_text_matching_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.format_input({"task": "A"})[-1]["content"].startswith(
+ "You have been assigned a text matching task: A"
+ )
+
+ def test_process(self) -> None:
+ task = GenerateShortTextMatchingData(
+ name="generate_short_text_matching_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.outputs == ["input", "positive_document", "model_name"]
+ assert next(task.process(inputs=[{"task": "A"}])) == [
+ {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"}
+ ]
+
+ def test_reproducibility(self) -> None:
+ unique_prompts = set()
+ for _ in range(10):
+ task = GenerateShortTextMatchingData(
+ name="generate_short_text_matching_data",
+ language="English",
+ add_raw_output=False,
+ seed=42,
+ llm=MockLLM(
+ output=json.dumps({"input": "A", "positive_document": "B"})
+ ),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ unique_prompts.add(task.format_input({"task": "A"})[-1]["content"])
+
+ assert len(unique_prompts) == 1
+
+
+class TestGenerateTextClassificationData:
+ def test_format_input(self) -> None:
+ task = GenerateTextClassificationData(
+ name="generate_text_classification_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(
+ output=json.dumps(
+ {"input_text": "A", "label": "B", "misleading_label": "C"}
+ )
+ ),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.format_input({"task": "A"})[-1]["content"].startswith(
+ "You have been assigned a text classification task: A"
+ )
+
+ def test_process(self) -> None:
+ task = GenerateTextClassificationData(
+ name="generate_text_classification_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(
+ output=json.dumps(
+ {"input_text": "A", "label": "B", "misleading_label": "C"}
+ )
+ ),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.outputs == ["input_text", "label", "misleading_label", "model_name"]
+ assert next(task.process(inputs=[{"task": "A"}])) == [
+ {
+ "task": "A",
+ "input_text": "A",
+ "label": "B",
+ "misleading_label": "C",
+ "model_name": "test",
+ }
+ ]
+
+ def test_reproducibility(self) -> None:
+ unique_prompts = set()
+ for _ in range(10):
+ task = GenerateTextClassificationData(
+ name="generate_text_classification_data",
+ language="English",
+ add_raw_output=False,
+ seed=42,
+ llm=MockLLM(
+ output=json.dumps(
+ {"input_text": "A", "label": "B", "misleading_label": "C"}
+ )
+ ),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ unique_prompts.add(task.format_input({"task": "A"})[-1]["content"])
+
+ assert len(unique_prompts) == 1
+
+
+class TestGenerateTextRetrievalData:
+ def test_format_input(self) -> None:
+ task = GenerateTextRetrievalData(
+ name="generate_text_retrieval_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(
+ output=json.dumps(
+ {
+ "user_query": "A",
+ "positive_document": "B",
+ "hard_negative_document": "C",
+ }
+ )
+ ),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.format_input({"task": "A"})[-1]["content"].startswith(
+ "You have been assigned a retrieval task: A"
+ )
+
+ def test_process(self) -> None:
+ task = GenerateTextRetrievalData(
+ name="generate_text_retrieval_data",
+ language="English",
+ add_raw_output=False,
+ llm=MockLLM(
+ output=json.dumps(
+ {
+ "user_query": "A",
+ "positive_document": "B",
+ "hard_negative_document": "C",
+ }
+ )
+ ),
+ pipeline=Pipeline(name="unit-test-pipeline"),
+ )
+ task.load()
+ assert task.outputs == [
+ "user_query",
+ "positive_document",
+ "hard_negative_document",
+ "model_name",
+ ]
+ assert next(task.process(inputs=[{"task": "A"}])) == [
+ {
+ "task": "A",
+ "user_query": "A",
+ "positive_document": "B",
+ "hard_negative_document": "C",
+ "model_name": "test",
+ }
+ ]
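These tasks reproduce the synthetic-data recipe from "Improving Text Embeddings with Large Language Models": a brainstorming step proposes embedding tasks and a second step turns each task into training examples. A hedged end-to-end sketch (the OpenAI model choice is a placeholder; any distilabel LLM works):

```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import (
    EmbeddingTaskGenerator,
    GenerateTextRetrievalData,
)

with Pipeline(name="improving-text-embeddings") as pipeline:
    # Brainstorm retrieval tasks; `flatten_tasks=True` yields one `task` per row.
    brainstorm = EmbeddingTaskGenerator(
        category="text-retrieval",
        flatten_tasks=True,
        llm=OpenAILLM(model="gpt-4"),
    )
    # For each brainstormed task, generate (user_query, positive_document,
    # hard_negative_document) rows.
    generate = GenerateTextRetrievalData(
        language="English",
        llm=OpenAILLM(model="gpt-4"),
    )
    brainstorm >> generate

# distiset = pipeline.run()  # needs a valid OPENAI_API_KEY
```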
diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py
index 4c8a8df7fe..a6f2793285 100644
--- a/tests/unit/steps/tasks/test_instruction_backtranslation.py
+++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py
@@ -86,5 +86,8 @@ def test_process(self) -> None:
"score": 1,
"reason": "This is the reason.",
"model_name": "instruction-backtranslation-model",
+ "distilabel_metadata": {
+ "raw_output_instruction-backtranslation": "This is the reason. Score: 1"
+ },
}
]
diff --git a/tests/unit/steps/tasks/test_prometheus_eval.py b/tests/unit/steps/tasks/test_prometheus_eval.py
index e5a4ad8590..31a437fdab 100644
--- a/tests/unit/steps/tasks/test_prometheus_eval.py
+++ b/tests/unit/steps/tasks/test_prometheus_eval.py
@@ -27,7 +27,7 @@
from jinja2 import Template
from pydantic import ValidationError
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
def load_template(template: str) -> Template:
diff --git a/tests/unit/steps/tasks/test_quality_scorer.py b/tests/unit/steps/tasks/test_quality_scorer.py
index 0a3db8261b..608631e9a2 100644
--- a/tests/unit/steps/tasks/test_quality_scorer.py
+++ b/tests/unit/steps/tasks/test_quality_scorer.py
@@ -18,7 +18,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.quality_scorer import QualityScorer
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
class TestQualityScorer:
diff --git a/tests/unit/steps/tasks/test_self_instruct.py b/tests/unit/steps/tasks/test_self_instruct.py
index 8525b88e6b..e3378e7e93 100644
--- a/tests/unit/steps/tasks/test_self_instruct.py
+++ b/tests/unit/steps/tasks/test_self_instruct.py
@@ -15,7 +15,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.self_instruct import SelfInstruct
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
class TestSelfInstruct:
diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py
new file mode 100644
index 0000000000..2f81240755
--- /dev/null
+++ b/tests/unit/steps/tasks/test_sentence_transformers.py
@@ -0,0 +1,215 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import pytest
+from distilabel.steps.tasks.sentence_transformers import (
+ CONTEXT_INTRO,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT,
+ POSITIVE_SYSTEM_PROMPT,
+ GenerateSentencePair,
+ GenerationAction,
+)
+
+from tests.unit.conftest import DummyLLM
+
+
+class TestGenerateSentencePair:
+ @pytest.mark.parametrize(
+ "action,triplet,system_prompt",
+ [
+ (
+ "paraphrase",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="paraphrase", context=""
+ ),
+ ),
+ (
+ "paraphrase",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""),
+ ),
+ (
+ "semantically-similar",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="be semantically similar to", context=""
+ ),
+ ),
+ (
+ "semantically-similar",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be semantically similar to", context=""
+ ),
+ ),
+ (
+ "query",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="be a query for", context=""
+ ),
+ ),
+ (
+ "query",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be a query for", context=""
+ ),
+ ),
+ (
+ "answer",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="be an answer for", context=""
+ ),
+ ),
+ (
+ "answer",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be an answer for", context=""
+ ),
+ ),
+ ],
+ )
+ def test_format_input(
+ self, action: GenerationAction, triplet: bool, system_prompt: str
+ ) -> None:
+ task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet)
+ task.load()
+ content = "## Anchor\n\nThis is a unit test\n"
+ assert task.format_input({"anchor": "This is a unit test"}) == [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": content},
+ ]
+
+ @pytest.mark.parametrize(
+ "action,triplet,system_prompt",
+ [
+ (
+ "paraphrase",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="paraphrase", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "paraphrase",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="paraphrase", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "semantically-similar",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="be semantically similar to", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "semantically-similar",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be semantically similar to", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "query",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="be a query for", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "query",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be a query for", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "answer",
+ True,
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="be an answer for", context=CONTEXT_INTRO
+ ),
+ ),
+ (
+ "answer",
+ False,
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be an answer for", context=CONTEXT_INTRO
+ ),
+ ),
+ ],
+ )
+ def test_format_input_with_context(
+ self, action: GenerationAction, triplet: bool, system_prompt: str
+ ) -> None:
+ context = "This is your context."
+ task = GenerateSentencePair(
+ llm=DummyLLM(),
+ action=action,
+ triplet=triplet,
+ context=context,
+ )
+ task.load()
+ content = f"## Context\n\n{context}\n\n## Anchor\n\nThis is a unit test\n"
+ assert task.format_input({"anchor": "This is a unit test"}) == [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": content},
+ ]
+
+ @pytest.mark.parametrize(
+ "output,triplet,expected",
+ [
+ (
+ "## Positive\n\nThis is a paraphrase\n## Negative\n\nThis is not a paraphrase",
+ True,
+ {
+ "positive": "This is a paraphrase",
+ "negative": "This is not a paraphrase",
+ },
+ ),
+ (
+ "## Positive\n\nThis is a paraphrase",
+ True,
+ {"positive": "This is a paraphrase", "negative": None},
+ ),
+ (
+ "## Positive\n\nThis is a paraphrase",
+ False,
+ {"positive": "This is a paraphrase"},
+ ),
+ (
+ "random",
+ False,
+ {"positive": None},
+ ),
+ ],
+ )
+ def test_format_output(
+ self, output: str, triplet: bool, expected: Dict[str, Any]
+ ) -> None:
+ task = GenerateSentencePair(
+ llm=DummyLLM(), action="paraphrase", triplet=triplet
+ )
+ task.load()
+
+ assert task.format_output(output) == expected
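For context, a short sketch of how `GenerateSentencePair` ties together the pieces tested above: the anchor (plus the optional `context`) is rendered into `## Context` / `## Anchor` sections, and the reply is parsed back from `## Positive` / `## Negative` sections. The dummy LLM is only a stand-in; a real LLM is needed to produce useful pairs:

```python
from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair

from tests.unit.conftest import DummyLLM  # stand-in LLM

task = GenerateSentencePair(
    llm=DummyLLM(),
    action="query",   # one of: paraphrase, semantically-similar, query, answer
    triplet=True,     # also request a hard negative
    context="Documentation about the distilabel library.",
)
task.load()

messages = task.format_input({"anchor": "How do I install distilabel?"})

parsed = task.format_output(
    "## Positive\n\nHow can distilabel be installed?\n"
    "## Negative\n\nWhat is the weather like today?"
)
# parsed == {"positive": "How can distilabel be installed?",
#            "negative": "What is the weather like today?"}
```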
diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py
new file mode 100644
index 0000000000..e2c230ef7e
--- /dev/null
+++ b/tests/unit/steps/tasks/test_structured_generation.py
@@ -0,0 +1,125 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, List
+
+from distilabel.llms.base import LLM
+from distilabel.llms.typing import GenerateOutput
+from distilabel.pipeline.local import Pipeline
+from distilabel.steps.tasks.structured_generation import StructuredGeneration
+from distilabel.steps.tasks.typing import StructuredInput
+from typing_extensions import override
+
+
+class DummyStructuredLLM(LLM):
+ def load(self) -> None:
+ pass
+
+ @property
+ def model_name(self) -> str:
+ return "test"
+
+ @override
+ def generate( # type: ignore
+ self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any
+ ) -> List["GenerateOutput"]:
+ return [
+ [json.dumps({"test": "output"}) for _ in range(num_generations)]
+ for _ in inputs
+ ]
+
+
+class TestStructuredGeneration:
+ def test_format_input(self) -> None:
+ pipeline = Pipeline(name="unit-test-pipeline")
+ llm = DummyStructuredLLM()
+ task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline)
+
+ # 1. Including the `grammar` field within the input
+ assert task.format_input(
+ {
+ "instruction": "test",
+ "system_prompt": "test",
+ "structured_output": {"format": "regex", "schema": r"[a-zA-Z]+"},
+ }
+ ) == (
+ [{"role": "user", "content": "test"}],
+ {"format": "regex", "schema": r"[a-zA-Z]+"},
+ )
+
+ # 2. Not including the `grammar` field within the input
+ assert task.format_input({"instruction": "test", "system_prompt": "test"}) == (
+ [{"role": "user", "content": "test"}],
+ None,
+ )
+
+ def test_format_input_with_system_prompt(self) -> None:
+ pipeline = Pipeline(name="unit-test-pipeline")
+ llm = DummyStructuredLLM()
+ task = StructuredGeneration(
+ name="task",
+ llm=llm,
+ pipeline=pipeline,
+ use_system_prompt=True,
+ )
+
+ assert task.format_input({"instruction": "test", "system_prompt": "test"}) == (
+ [
+ {"role": "system", "content": "test"},
+ {"role": "user", "content": "test"},
+ ],
+ None,
+ )
+
+ def test_process(self) -> None:
+ pipeline = Pipeline(name="unit-test-pipeline")
+ llm = DummyStructuredLLM()
+ task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline)
+ assert next(
+ task.process(
+ [
+ {
+ "instruction": "test",
+ "structured_output": {
+ "format": "json",
+ "schema": {
+ "properties": {
+ "test": {"title": "Test", "type": "string"}
+ },
+ "required": ["test"],
+ "title": "Test",
+ "type": "object",
+ },
+ },
+ }
+ ]
+ )
+ ) == [
+ {
+ "instruction": "test",
+ "structured_output": {
+ "format": "json",
+ "schema": {
+ "properties": {"test": {"title": "Test", "type": "string"}},
+ "required": ["test"],
+ "title": "Test",
+ "type": "object",
+ },
+ },
+ "generation": '{"test": "output"}',
+ "model_name": "test",
+ "distilabel_metadata": {"raw_output_task": '{"test": "output"}'},
+ }
+ ]
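A sketch of the intended usage: every input row may carry its own `structured_output` spec, so different rows can be constrained by different schemas. The toy LLM below just echoes a fixed JSON string; in practice an outlines-capable LLM enforces the schema:

```python
import json
from typing import Any, List

from distilabel.llms.base import LLM
from distilabel.llms.typing import GenerateOutput
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import StructuredGeneration
from distilabel.steps.tasks.typing import StructuredInput


class EchoJSONLLM(LLM):
    """Toy stand-in that always returns the same JSON string."""

    def load(self) -> None:
        pass

    @property
    def model_name(self) -> str:
        return "echo-json"

    def generate(  # type: ignore
        self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any
    ) -> List["GenerateOutput"]:
        return [[json.dumps({"name": "John"})] * num_generations for _ in inputs]


task = StructuredGeneration(
    name="task", llm=EchoJSONLLM(), pipeline=Pipeline(name="unit-test-pipeline")
)
task.load()

row = {
    "instruction": "Create a user profile for John",
    # Per-row spec; {"format": "regex", "schema": r"[A-Z]{3}"} is also accepted.
    "structured_output": {
        "format": "json",
        "schema": {
            "properties": {"name": {"title": "Name", "type": "string"}},
            "required": ["name"],
            "title": "User",
            "type": "object",
        },
    },
}
result = next(task.process([row]))  # the row gains `generation` == '{"name": "John"}'
```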
diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py
index ecff0e1d90..c98adb00e5 100644
--- a/tests/unit/steps/tasks/test_text_generation.py
+++ b/tests/unit/steps/tasks/test_text_generation.py
@@ -16,7 +16,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
-from tests.unit.steps.tasks.utils import DummyLLM
+from tests.unit.conftest import DummyLLM
class TestTextGeneration:
@@ -53,6 +53,12 @@ def test_format_input_errors(self) -> None:
name="task", llm=llm, pipeline=pipeline, use_system_prompt=True
)
+ with pytest.raises(
+ ValueError,
+ match=r"Providing \`instruction\` formatted as an OpenAI chat / conversation is deprecated",
+ ):
+ task.format_input({"instruction": [{"role": "user", "content": "test"}]})
+
with pytest.raises(
ValueError, match=r"Input \`instruction\` must be a string. Got: 1."
):
@@ -76,26 +82,12 @@ def test_process(self) -> None:
"instruction": "test",
"generation": "output",
"model_name": "test",
+ "distilabel_metadata": {
+ "raw_output_task": "output",
+ },
}
]
- def test_deprecation_warning(self) -> None:
- pipeline = Pipeline(name="unit-test-pipeline")
- llm = DummyLLM()
- task = TextGeneration(name="task", llm=llm, pipeline=pipeline)
-
- with pytest.warns(
- DeprecationWarning,
- match=r"Providing \`instruction\` formatted as an OpenAI chat \/ conversation is about to be deprecated in \`distilabel v1.2.0\`",
- ):
- task.format_input(
- {
- "instruction": [
- {"role": "user", "content": "Tell me a joke."},
- ]
- }
- )
-
class TestChatGeneration:
def test_format_input(self) -> None:
@@ -150,5 +142,6 @@ def test_process(self) -> None:
"messages": [{"role": "user", "content": "Tell me a joke."}],
"generation": "output",
"model_name": "test",
+ "distilabel_metadata": {"raw_output_task": "output"},
}
]
diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py
index 69e9326570..fa72ff9442 100644
--- a/tests/unit/steps/tasks/test_ultrafeedback.py
+++ b/tests/unit/steps/tasks/test_ultrafeedback.py
@@ -63,6 +63,9 @@ def test_process_with_simple_aspect(self) -> None:
"ratings": [1, 2],
"rationales": ["text", "text"],
"model_name": "ultrafeedback-model",
+ "distilabel_metadata": {
+ "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text"
+ },
}
]
@@ -89,5 +92,8 @@ def test_process_with_complex_aspect(self) -> None:
"ratings": [1, 2],
"rationales-for-ratings": ["text", "text"],
"model_name": "ultrafeedback-model",
+ "distilabel_metadata": {
+ "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text"
+ },
}
]
diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py
index 543e5a43d2..6b469bf324 100644
--- a/tests/unit/steps/test_base.py
+++ b/tests/unit/steps/test_base.py
@@ -310,7 +310,7 @@ def test_step_from_dict(self) -> None:
**{
"name": "dummy",
TYPE_INFO_KEY: {
- "module": "tests.unit.pipeline.step.test_base",
+ "module": "tests.unit.steps.test_base",
"name": "DummyStep",
},
}
@@ -327,7 +327,7 @@ def test_step_from_dict_without_pipeline_context(
**{
"name": "dummy",
TYPE_INFO_KEY: {
- "module": "tests.pipeline.step.test_base",
+ "module": "tests.unit.steps.test_base",
"name": "DummyStep",
},
}
diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py
index 18de3f8769..07e6549d7b 100644
--- a/tests/unit/test_distiset.py
+++ b/tests/unit/test_distiset.py
@@ -12,9 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
+import re
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, Optional
+
import pytest
+import yaml
from datasets import Dataset, DatasetDict
from distilabel.distiset import Distiset
+from upath import UPath
@pytest.fixture(scope="function")
@@ -27,6 +35,24 @@ def distiset():
)
+def make_fake_file(filename: Path) -> None:
+ if not filename.parent.exists():
+ filename.parent.mkdir(parents=True)
+ filename.touch()
+
+
+def add_config_to_distiset(distiset: Distiset, folder: Path) -> Distiset:
+ from distilabel.distiset import DISTISET_CONFIG_FOLDER
+
+ pipeline_yaml = folder / DISTISET_CONFIG_FOLDER / "pipeline.yaml"
+ pipeline_log = folder / DISTISET_CONFIG_FOLDER / "pipeline.log"
+ make_fake_file(pipeline_yaml)
+ make_fake_file(pipeline_log)
+ distiset.pipeline_path = pipeline_yaml
+ distiset.pipeline_log_path = pipeline_log
+ return distiset
+
+
class TestDistiset:
def test_train_test_split(self, distiset: Distiset) -> None:
assert isinstance(distiset["leaf_step_1"], Dataset)
@@ -34,3 +60,111 @@ def test_train_test_split(self, distiset: Distiset) -> None:
assert isinstance(ds, Distiset)
assert len(ds) == 2
assert isinstance(ds["leaf_step_1"], DatasetDict)
+
+ @pytest.mark.parametrize("storage_options", [None, {"test": "option"}])
+ @pytest.mark.parametrize("with_config", [False, True])
+ def test_save_to_disk(
+ self,
+ distiset: Distiset,
+ with_config: bool,
+ storage_options: Optional[Dict[str, Any]],
+ ) -> None:
+ full_distiset = copy.deepcopy(distiset)
+ # Distiset with Distiset
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ folder = Path(tmpdirname) / "distiset_folder"
+ if with_config:
+ full_distiset = add_config_to_distiset(full_distiset, folder)
+
+ full_distiset.save_to_disk(
+ folder,
+ save_card=with_config,
+ save_pipeline_config=with_config,
+ save_pipeline_log=with_config,
+ storage_options=storage_options,
+ )
+ assert folder.is_dir()
+ assert len(list(folder.iterdir())) == 3
+
+ full_distiset = copy.deepcopy(distiset)
+ # Distiset with DatasetDict
+ distiset_with_dict = full_distiset.train_test_split(0.8)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ folder = Path(tmpdirname) / "distiset_folder"
+ if with_config:
+ distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder)
+
+ distiset_with_dict.save_to_disk(
+ folder,
+ save_card=with_config,
+ save_pipeline_config=with_config,
+ save_pipeline_log=with_config,
+ )
+
+ assert folder.is_dir()
+ assert len(list(folder.iterdir())) == 3
+
+ @pytest.mark.parametrize("pathlib_implementation", [Path, UPath])
+ @pytest.mark.parametrize("storage_options", [None, {"project": "experiments"}])
+ @pytest.mark.parametrize("with_config", [False, True])
+ def test_load_from_disk(
+ self,
+ distiset: Distiset,
+ with_config: bool,
+ storage_options: Optional[Dict[str, Any]],
+ pathlib_implementation: type,
+ ) -> None:
+ full_distiset = copy.deepcopy(distiset)
+ # Distiset with Distiset
+ with tempfile.TemporaryDirectory() as tmpdirname:
+            # This also exercises `UPath` via the FilePath protocol, which should
+            # behave the same as S3Path, GCSPath, etc.
+ folder = pathlib_implementation(tmpdirname) / "distiset_folder"
+ if with_config:
+ full_distiset = add_config_to_distiset(full_distiset, folder)
+ full_distiset.save_to_disk(
+ folder,
+ save_card=with_config,
+ save_pipeline_config=with_config,
+ save_pipeline_log=with_config,
+ storage_options=storage_options,
+ )
+ ds = Distiset.load_from_disk(
+ folder,
+ storage_options=storage_options,
+ )
+ assert isinstance(ds, Distiset)
+ assert isinstance(ds["leaf_step_1"], Dataset)
+
+ if with_config:
+ assert ds.pipeline_path.exists()
+ assert ds.log_filename_path.exists()
+
+ full_distiset = copy.deepcopy(distiset)
+ # Distiset with DatasetDict
+ distiset_with_dict = full_distiset.train_test_split(0.8)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ folder = pathlib_implementation(tmpdirname) / "distiset_folder"
+ if with_config:
+ distiset_with_dict = add_config_to_distiset(distiset_with_dict, folder)
+
+ distiset_with_dict.save_to_disk(folder)
+ ds = Distiset.load_from_disk(folder, storage_options=storage_options)
+
+ assert folder.is_dir()
+ assert isinstance(ds["leaf_step_1"], DatasetDict)
+
+ if with_config:
+ assert ds.pipeline_path.exists()
+ assert ds.log_filename_path.exists()
+
+ def test_dataset_card(self, distiset: Distiset) -> None:
+        # Test the metadata we generate by default, without pulling the already generated content from the HF Hub.
+        # We parse the content and check it is the same as the one we generate.
+ distiset_card = distiset._get_card("repo_name_or_path")
+ metadata = re.findall(r"---\n(.*?)\n---", str(distiset_card), re.DOTALL)[0]
+ metadata = yaml.safe_load(metadata)
+ assert metadata == {
+ "size_categories": "n<1K",
+ "tags": ["synthetic", "distilabel", "rlaif"],
+ }
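A sketch of the save/load round-trip these tests cover; the remote example is illustrative (bucket name and credentials are placeholders):

```python
from datasets import Dataset
from distilabel.distiset import Distiset

distiset = Distiset({"leaf_step_1": Dataset.from_dict({"a": [1, 2, 3]})})

# Local round-trip; the dataset card, pipeline.yaml and pipeline.log are only
# written when the corresponding flags are enabled and the artifacts exist.
distiset.save_to_disk(
    "my-distiset",
    save_card=False,
    save_pipeline_config=False,
    save_pipeline_log=False,
)
reloaded = Distiset.load_from_disk("my-distiset")
assert isinstance(reloaded["leaf_step_1"], Dataset)

# Remote filesystems go through fsspec/universal_pathlib plus `storage_options`:
# distiset.save_to_disk("gcs://my-bucket/my-distiset", storage_options={"project": "experiments"})
# Distiset.load_from_disk("gcs://my-bucket/my-distiset", storage_options={"project": "experiments"})
```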
diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py
index cfccb1585f..e20e186c8e 100644
--- a/tests/unit/test_imports.py
+++ b/tests/unit/test_imports.py
@@ -51,6 +51,8 @@ def test_imports() -> None:
GeneratorStepOutput,
KeepColumns,
LoadDataFromDicts,
+ LoadDataFromHub,
+ LoadDataFromDisk,
LoadHubDataset,
PushToHub,
Step,
@@ -72,11 +74,19 @@ def test_imports() -> None:
EvolInstructGenerator,
GenerateEmbeddings,
Genstruct,
+ BitextRetrievalGenerator,
+ EmbeddingTaskGenerator,
+ GenerateLongTextMatchingData,
+ GenerateShortTextMatchingData,
+ GenerateTextClassificationData,
+ GenerateTextRetrievalData,
+ MonolingualTripletGenerator,
InstructionBacktranslation,
PairRM,
PrometheusEval,
QualityScorer,
SelfInstruct,
+ StructuredGeneration,
TextGeneration,
UltraFeedback,
)
diff --git a/tests/unit/utils/test_serialization.py b/tests/unit/utils/test_serialization.py
new file mode 100644
index 0000000000..153e2a8692
--- /dev/null
+++ b/tests/unit/utils/test_serialization.py
@@ -0,0 +1,37 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.utils.serialization import _extra_serializable_fields, _Serializable
+from pydantic import BaseModel
+
+
+def test_extra_serializable_fields() -> None:
+ class DummyAttribute(BaseModel, _Serializable):
+ pass
+
+ class Dummy(BaseModel, _Serializable):
+ attr: DummyAttribute
+
+ dummy = Dummy(attr=DummyAttribute())
+
+ assert _extra_serializable_fields(dummy) == [
+ {
+ "attr": {
+ "type_info": {
+ "module": "tests.unit.utils.test_serialization",
+ "name": "DummyAttribute",
+ }
+ }
+ }
+ ]