From 0e8c75242d91fe6e1f482c18e1d08a7639bbed82 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Wed, 12 Jun 2024 10:09:39 +0200 Subject: [PATCH] Implement "Improving Text Embeddings with LLMs" (#683) * Set `input` as optional in `format_output` * Implement "Improving Text Embeddings with LLMs" (WIP) * Implement "Improving Text Embeddings with LLMs" (WIP) * Add `model_name` at the end of each batch * Move `text_embeddings.py` to `improving_text_embeddings.py` * Fix `re.sub` to also capture `\t` and `\r` * Add `MonolingualTripletGenerator` and `BitextRetrievalGenerator` * Move all `templates` from `str` to `jinja2` files * Update class naming and imports * Add some docstrings and fix `jinja2` file paths * Fix `prompt` accross tasks * Add missing docstrings * Fix `process` method in `EmbeddingTaskGenerator` * Add unit tests for `...Generator` tasks * Add remaining unit tests * Remove duplicated imports in `distilabel.steps.tasks` * Add examples in docstrings and add notes --- src/distilabel/steps/tasks/__init__.py | 16 + src/distilabel/steps/tasks/base.py | 19 +- .../steps/tasks/improving_text_embeddings.py | 941 ++++++++++++++++++ .../bitext-retrieval.jinja2 | 13 + .../brainstorming/text-classification.jinja2 | 6 + .../brainstorming/text-matching-long.jinja2 | 7 + .../brainstorming/text-matching-short.jinja2 | 8 + .../brainstorming/text-retrieval.jinja2 | 11 + .../long-text-matching.jinja2 | 12 + .../monolingual-triplet.jinja2 | 10 + .../short-text-matching.jinja2 | 12 + .../text-classification.jinja2 | 15 + .../text-retrieval.jinja2 | 17 + .../tasks/test_improving_text_embeddings.py | 406 ++++++++ tests/unit/test_imports.py | 7 + 15 files changed, 1494 insertions(+), 6 deletions(-) create mode 100644 src/distilabel/steps/tasks/improving_text_embeddings.py create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 create mode 100644 src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 create mode 100644 tests/unit/steps/tasks/test_improving_text_embeddings.py diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 72e2216e71..b2456d7824 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -23,6 +23,15 @@ from distilabel.steps.tasks.evol_quality.base import EvolQuality from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings from distilabel.steps.tasks.genstruct import Genstruct +from distilabel.steps.tasks.improving_text_embeddings import ( + BitextRetrievalGenerator, + EmbeddingTaskGenerator, + GenerateLongTextMatchingData, + GenerateShortTextMatchingData, + GenerateTextClassificationData, + GenerateTextRetrievalData, + MonolingualTripletGenerator, +) from distilabel.steps.tasks.instruction_backtranslation import ( InstructionBacktranslation, ) @@ -47,6 +56,13 @@ "EvolQuality", "GenerateEmbeddings", "Genstruct", + "BitextRetrievalGenerator", + "EmbeddingTaskGenerator", + "GenerateLongTextMatchingData", + "GenerateShortTextMatchingData", + "GenerateTextClassificationData", + "GenerateTextRetrievalData", + "MonolingualTripletGenerator", "InstructionBacktranslation", "PairRM", "PrometheusEval", diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index 2c19d8c8a4..fda1e1e248 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -70,7 +70,9 @@ def load(self) -> None: @abstractmethod def format_output( - self, output: Union[str, None], input: Dict[str, Any] + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, ) -> Dict[str, Any]: """Abstract method to format the outputs of the task. It needs to receive an output as a string, and generates a Python dictionary with the outputs of the task. In @@ -80,7 +82,9 @@ def format_output( pass def _format_outputs( - self, outputs: "GenerateOutput", inputs: List[Dict[str, Any]] + self, + outputs: "GenerateOutput", + inputs: Union[List[Dict[str, Any]], None] = None, ) -> List[Dict[str, Any]]: """Formats the outputs of the task using the `format_output` method. If the output is `None` (i.e. the LLM failed to generate a response), then the outputs will be @@ -93,8 +97,11 @@ def _format_outputs( Returns: A list containing a dictionary with the outputs of the task for each input. """ + if inputs is None: + inputs = [None] # type: ignore + formatted_outputs = [] - for output, input in zip(outputs, inputs * len(outputs)): + for output, input in zip(outputs, inputs * len(outputs)): # type: ignore try: formatted_output = self.format_output(output, input) formatted_output = self._maybe_add_raw_output( @@ -109,7 +116,7 @@ def _format_outputs( return formatted_outputs def _output_on_failure( - self, output: Union[str, None], input: Dict[str, Any] + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None ) -> Dict[str, Any]: """In case of failure to format the output, this method will return a dictionary including a new field `distilabel_meta` with the raw output of the LLM. @@ -189,14 +196,14 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore if self.group_generations: combined = combine_dicts(*formatted_outputs) task_outputs.append( - {**input, "model_name": self.llm.model_name, **combined} + {**input, **combined, "model_name": self.llm.model_name} ) continue # Create a row per generation for formatted_output in formatted_outputs: task_outputs.append( - {**input, "model_name": self.llm.model_name, **formatted_output} + {**input, **formatted_output, "model_name": self.llm.model_name} ) yield task_outputs diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py new file mode 100644 index 0000000000..0e91354274 --- /dev/null +++ b/src/distilabel/steps/tasks/improving_text_embeddings.py @@ -0,0 +1,941 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import re +import sys +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional, Union + +if sys.version_info < (3, 9): + import importlib_resources +else: + import importlib.resources as importlib_resources + +from jinja2 import Template +from pydantic import Field, PrivateAttr +from typing_extensions import override + +from distilabel.steps.tasks.base import GeneratorTask, Task +from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.typing import GeneratorStepOutput + + +# BASE CLASSES +class _JSONFormatter(ABC): + """Abstract class that sets the `outputs` property and `format_output` method, assuming + that the output is a JSON string with the keys specified in the `keys` property. So on, + this class is intended to be used whenever we get a JSON string as the `LLM` output with + a set of `keys` we know are there. + + Note: + At the moment this abstract class is only intended to be used for the tasks defined + below based on the output generated by those. Also note that this is not a replacement + for neither the `StructuredGeneration` task nor for the `structured_output` argument + of an `LLM` subclass. + """ + + @property + @abstractmethod + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + ... + + @property + def outputs(self) -> List[str]: + """Contains the output columns produced by the `process` method of the task. In this + case, it consists of the `keys` (i.e. the JSON keys) and the `model_name`. + """ + return self.keys + ["model_name"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + """Method to parse the JSON output into a Python dictionary based on the `keys` property. + + Args: + output: The JSON string output produced by the `LLM`. + input: The input dictionary that was used to generate the output. + + Returns: + A Python dictionary with the parsed output based on the `keys` property. + """ + if output is None: + return {key: None for key in self.keys} + + def escape_backslashes_in_values(s): + # Regular expression to match the key-value pairs in the dictionary + pattern = re.compile(r'(".*?":\s*")(.*?)(",?)', re.DOTALL) + + def replace_backslashes(match): + return ( + match.group(1) + + re.sub( + r"(? None: + """Loads the Jinja2 template and sets the random seed.""" + super().load() + + random.seed(self.seed) + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "improving_text_embeddings" + / f"{self._template_name}.jinja2" # type: ignore + ) + + self._template = Template(open(_path).read()) + + @property + def inputs(self) -> List[str]: + """Contains the input columns expected by the `process` method of the task. In this + case, it consists of the `task`; ideally produced in a previous task which should be + preferrably `EmbeddingTaskGenerator` (as per the original implementation).""" + return ["task"] + + +class _EmbeddingDataGenerator(_JSONFormatter, GeneratorTask, ABC): + """Base class for the subtasks related to embedding data generation as presented in the + paper "Improving Text Embeddings with Large Language Models" that generate data without + an input i.e. `GeneratorStep` or `GeneratorTask`. This class includes a pre-defined `load` + method to load a Jinja2 template based on the `_template_name` private attribute (to be set + in each of the subclasses), assuming that the `prompt` property only expects the `task`, while + keeping the `format_input` as an abstract method to be implemented in the subclasses. + + Attributes: + seed: The random seed to be set in case there's any sampling within the `format_input` method. + _template: The Jinja2 template to be rendered within the `format_input` method with the + provided arguments. + _template_name: The name of the Jinja2 template file within the + `distilabel/steps/tasks/templates/improving_text_embeddings` directory. + """ + + seed: int = 42 + + _template: Union[Template, None] = PrivateAttr(...) + _template_name: str = PrivateAttr(...) + + def load(self) -> None: + """Loads the Jinja2 template and sets the random seed.""" + super().load() + + random.seed(self.seed) + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "improving_text_embeddings" + / f"{self._template_name}.jinja2" # type: ignore + ) + + self._template = Template(open(_path).read()) + + @property + @abstractmethod + def prompt(self) -> ChatType: + """The prompt to be used for the generation step, ideally rendering the `_template`.""" + ... + + @override + def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore + """Method to run the `LLM` generation with the `prompt`, as well as formatting the + outputs accordingly for the task i.e. via the `_JSONFormatter` inheritance. So on, the + `LLM` ideally will be prompted to produce JSON content and then the `format_output` + method will parse it into a Python dictionary based on the `keys` property. + + Args: + offset: The offset to start the generation from. Defaults to 0. + + Yields: + The output rows and a boolean indicating if it's the last batch or not. + """ + formatted_inputs = [self.prompt] + outputs = self.llm.generate( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.generation_kwargs, # type: ignore + ) + + task_outputs = [] + for input_outputs in outputs: + formatted_outputs = self._format_outputs(input_outputs) # type: ignore + for formatted_output in formatted_outputs: + task_outputs.append( + { + **formatted_output, + "model_name": self.llm.model_name, + } + ) + yield task_outputs, True + + +# IMPLEMENTED TASKS +class EmbeddingTaskGenerator(GeneratorTask): + """Generate task descriptions for embedding-related tasks using an `LLM`. + + `EmbeddingTaskGenerator` is a `GeneratorTask` that doesn't receieve any input besides the + provided attributes that generates task descriptions for embedding-related tasks using a + pre-defined prompt based on the `category` attribute. The `category` attribute should be + one of the following: + + - `text-retrieval`: Generate task descriptions for text retrieval tasks. + - `text-matching-short`: Generate task descriptions for short text matching tasks. + - `text-matching-long`: Generate task descriptions for long text matching tasks. + - `text-classification`: Generate task descriptions for text classification tasks. + + Attributes: + category: The category of the task to be generated, which can either be `text-retrieval`, + `text-matching-short`, `text-matching-long`, or `text-classification`. + flatten_tasks: Whether to flatten the tasks i.e. since a list of tasks is generated by the + `LLM`, this attribute indicates whether to flatten the list or not. Defaults to `False`, + meaning that running this task with `num_generations=1` will return a `distilabel.Distiset` + with one row only containing a list with around 20 tasks; otherwise, if set to `True`, it + will return a `distilabel.Distiset` with around 20 rows, each containing one task. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate embedding tasks for text retrieval: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-retrieval", + flatten_tasks=True, + llm=..., # LLM instance + ) + + ... + + task >> ... + ``` + """ + + category: Literal[ + "text-retrieval", + "text-matching-short", + "text-matching-long", + "text-classification", + ] + flatten_tasks: bool = False + + _template: Union[Template, None] = PrivateAttr(...) + + def load(self) -> None: + """Loads the Jinja2 template.""" + super().load() + + _path = str( + importlib_resources.files("distilabel") + / "steps" + / "tasks" + / "templates" + / "improving_text_embeddings" + / "brainstorming" + / f"{self.category}.jinja2" + ) + + self._template = Template(open(_path).read()) + + @property + def prompt(self) -> ChatType: # type: ignore + """The prompt to be used in the `process` method, rendering the `_template` with the + provided args / attributes. + """ + return [{"role": "user", "content": self._template.render().strip()}] # type: ignore + + @override + def process(self, offset: int = 0) -> GeneratorStepOutput: # type: ignore + """Method to run the `LLM` generation with the `prompt`, as well as formatting the + outputs accordingly for the task i.e. via the `_JSONFormatter` inheritance. So on, the + `LLM` ideally will be prompted to produce JSON content and then the `format_output` + method will parse it into a Python dictionary based on the `keys` property. + + Args: + offset: The offset to start the generation from. Defaults to 0. + + Yields: + The output rows and a boolean indicating if it's the last batch or not. + """ + formatted_inputs = [self.prompt] + outputs = self.llm.generate( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.generation_kwargs, # type: ignore + ) + + task_outputs = [] + for input_outputs in outputs: + formatted_outputs = self._format_outputs(input_outputs) # type: ignore + for formatted_output in formatted_outputs: + if isinstance(formatted_output["tasks"], list) and self.flatten_tasks: + tasks = formatted_output.pop("tasks") + task_outputs.extend( + [ + { + "task": task, + **formatted_output, + "model_name": self.llm.model_name, + } + for task in tasks + ] + ) + else: + if self.flatten_tasks: + formatted_output["task"] = formatted_output.pop("tasks") + task_outputs.append( + {**formatted_output, "model_name": self.llm.model_name} + ) + yield task_outputs, True + + @property + def outputs(self) -> List[str]: + """Contains the output columns produced by the `process` method of the task. In this + case, it consists of the `tasks` or `task` (depending on the `flatten_tasks` attribute) + and the `model_name`. + """ + return ["tasks" if not self.flatten_tasks else "task", "model_name"] + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + """Method to parse the JSON output into a Python dictionary based on the `keys` property. + + Args: + output: The JSON string output produced by the `LLM`. + input: The input dictionary that was used to generate the output. + + Returns: + A Python dictionary with the parsed output based on the `keys` property. + """ + try: + if output is not None: + output = eval(output) + except Exception: + pass + return {"tasks": output} + + +class GenerateTextRetrievalData(_EmbeddingDataGeneration): + """Generate text retrieval data with an `LLM` to later on train an embedding model. + + `GenerateTextRetrievalData` is a `Task` that generates text retrieval data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-retrieval"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-retrieval category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + query_type: The type of query to be generated, which can be `extremely long-tail`, `long-tail`, + or `common`. Defaults to `None`, meaning that it will be randomly sampled. + query_length: The length of the query to be generated, which can be `less than 5 words`, `5 to 15 words`, + or `at least 10 words`. Defaults to `None`, meaning that it will be randomly sampled. + difficulty: The difficulty of the query to be generated, which can be `high school`, `college`, or `PhD`. + Defaults to `None`, meaning that it will be randomly sampled. + clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`, + or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled. + num_words: The number of words in the query to be generated, which can be `50`, `100`, `200`, `300`, `400`, or `500`. + Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic text retrieval data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextRetrievalData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-retrieval", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateTextRetrievalData( + language="English", + query_type="common", + query_length="5 to 15 words", + difficulty="high school", + clarity="clear", + num_words=100, + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + query_type: Optional[Literal["extremely long-tail", "long-tail", "common"]] = None + query_length: Optional[ + Literal["less than 5 words", "5 to 15 words", "at least 10 words"] + ] = None + difficulty: Optional[Literal["high school", "college", "PhD"]] = None + clarity: Optional[ + Literal["clear", "understandable with some effort", "ambiguous"] + ] = None + num_words: Optional[Literal[50, 100, 200, 300, 400, 500]] = None + + _template_name: str = PrivateAttr(default="text-retrieval") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + query_type=self.query_type + or random.choice(["extremely long-tail", "long-tail", "common"]), + query_length=self.query_length + or random.choice( + ["less than 5 words", "5 to 15 words", "at least 10 words"] + ), + difficulty=self.difficulty + or random.choice(["high school", "college", "PhD"]), + clarity=self.clarity + or random.choice( + ["clear", "understandable with some effort", "ambiguous"] + ), + num_words=self.num_words + or random.choice([50, 100, 200, 300, 400, 500]), + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return [ + "user_query", + "positive_document", + "hard_negative_document", + ] + + +class GenerateShortTextMatchingData(_EmbeddingDataGeneration): + """Generate short text matching data with an `LLM` to later on train an embedding model. + + `GenerateShortTextMatchingData` is a `Task` that generates short text matching data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-matching-short"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-matching-short category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + Note that in this task the `seed` has no effect since there are no sampling params. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic short text matching data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateShortTextMatchingData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-matching-short", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateShortTextMatchingData( + language="English", + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + _template_name: str = PrivateAttr(default="short-text-matching") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["input", "positive_document"] + + +class GenerateLongTextMatchingData(_EmbeddingDataGeneration): + """Generate long text matching data with an `LLM` to later on train an embedding model. + + `GenerateLongTextMatchingData` is a `Task` that generates long text matching data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-matching-long"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-matching-long category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + Note that in this task the `seed` has no effect since there are no sampling params. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic long text matching data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateLongTextMatchingData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-matching-long", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateLongTextMatchingData( + language="English", + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + _template_name: str = PrivateAttr(default="long-text-matching") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["input", "positive_document"] + + +class GenerateTextClassificationData(_EmbeddingDataGeneration): + """Generate text classification data with an `LLM` to later on train an embedding model. + + `GenerateTextClassificationData` is a `Task` that generates text classification data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Note: + Ideally this task should be used with `EmbeddingTaskGenerator` with `flatten_tasks=True` + with the `category="text-classification"`; so that the `LLM` generates a list of tasks that + are flattened so that each row contains a single task for the text-classification category. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + difficulty: The difficulty of the query to be generated, which can be `high school`, `college`, or `PhD`. + Defaults to `None`, meaning that it will be randomly sampled. + clarity: The clarity of the query to be generated, which can be `clear`, `understandable with some effort`, + or `ambiguous`. Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + References: + - [Improving Text Embeddings with Large Language Models](https://arxiv.org/abs/2401.00368) + + Examples: + + Generate synthetic text classification data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import EmbeddingTaskGenerator, GenerateTextClassificationData + + with Pipeline("my-pipeline") as pipeline: + task = EmbeddingTaskGenerator( + category="text-classification", + flatten_tasks=True, + llm=..., # LLM instance + ) + + generate = GenerateTextClassificationData( + language="English", + difficulty="high school", + clarity="clear", + llm=..., # LLM instance + ) + + task >> generate + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + difficulty: Optional[Literal["high school", "college", "PhD"]] = None + clarity: Optional[ + Literal["clear", "understandable with some effort", "ambiguous"] + ] = None + + _template_name: str = PrivateAttr(default="text-classification") + + def format_input(self, input: Dict[str, Any]) -> ChatType: + """Method to format the input based on the `task` and the provided attributes, or just + randomly sampling those if not provided. This method will render the `_template` with + the provided arguments and return an OpenAI formatted chat i.e. a `ChatType`, assuming that + there's only one turn, being from the user with the content being the rendered `_template`. + + Args: + input: The input dictionary containing the `task` to be used in the `_template`. + + Returns: + A list with a single chat containing the user's message with the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + task=input["task"], + language=self.language, + difficulty=self.difficulty + or random.choice(["high school", "college", "PhD"]), + clarity=self.clarity + or random.choice( + ["clear", "understandable with some effort", "ambiguous"] + ), + ).strip(), + } + ] + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["input_text", "label", "misleading_label"] + + +class MonolingualTripletGenerator(_EmbeddingDataGenerator): + """Generate monolingual triplets with an `LLM` to later on train an embedding model. + + `MonolingualTripletGenerator` is a `GeneratorTask` that generates monolingual triplets with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Attributes: + language: The language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + unit: The unit of the data to be generated, which can be `sentence`, `phrase`, or `passage`. + Defaults to `None`, meaning that it will be randomly sampled. + difficulty: The difficulty of the query to be generated, which can be `elementary school`, `high school`, or `college`. + Defaults to `None`, meaning that it will be randomly sampled. + high_score: The high score of the query to be generated, which can be `4`, `4.5`, or `5`. + Defaults to `None`, meaning that it will be randomly sampled. + low_score: The low score of the query to be generated, which can be `2.5`, `3`, or `3.5`. + Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + Examples: + + Generate monolingual triplets for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import MonolingualTripletGenerator + + with Pipeline("my-pipeline") as pipeline: + task = MonolingualTripletGenerator( + language="English", + unit="sentence", + difficulty="elementary school", + high_score="4", + low_score="2.5", + llm=..., + ) + + ... + + task >> ... + ``` + """ + + language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + unit: Optional[Literal["sentence", "phrase", "passage"]] = None + difficulty: Optional[Literal["elementary school", "high school", "college"]] = None + high_score: Optional[Literal["4", "4.5", "5"]] = None + low_score: Optional[Literal["2.5", "3", "3.5"]] = None + + _template_name: str = PrivateAttr(default="monolingual-triplet") + + @property + def prompt(self) -> ChatType: + """Contains the `prompt` to be used in the `process` method, rendering the `_template`; and + formatted as an OpenAI formatted chat i.e. a `ChatType`, assuming that there's only one turn, + being from the user with the content being the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + language=self.language, + unit=self.unit or random.choice(["sentence", "phrase", "passage"]), + difficulty=self.difficulty + or random.choice(["elementary school", "high school", "college"]), + high_score=self.high_score or random.choice(["4", "4.5", "5"]), + low_score=self.low_score or random.choice(["2.5", "3", "3.5"]), + ).strip(), + } + ] # type: ignore + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["S1", "S2", "S3"] + + +class BitextRetrievalGenerator(_EmbeddingDataGenerator): + """Generate bitext retrieval data with an `LLM` to later on train an embedding model. + + `BitextRetrievalGenerator` is a `GeneratorTask` that generates bitext retrieval data with an + `LLM` to later on train an embedding model. The task is based on the paper "Improving + Text Embeddings with Large Language Models" and the data is generated based on the + provided attributes, or randomly sampled if not provided. + + Attributes: + source_language: The source language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + target_language: The target language of the data to be generated, which can be any of the languages + retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf. + unit: The unit of the data to be generated, which can be `sentence`, `phrase`, or `passage`. + Defaults to `None`, meaning that it will be randomly sampled. + difficulty: The difficulty of the query to be generated, which can be `elementary school`, `high school`, or `college`. + Defaults to `None`, meaning that it will be randomly sampled. + high_score: The high score of the query to be generated, which can be `4`, `4.5`, or `5`. + Defaults to `None`, meaning that it will be randomly sampled. + low_score: The low score of the query to be generated, which can be `2.5`, `3`, or `3.5`. + Defaults to `None`, meaning that it will be randomly sampled. + seed: The random seed to be set in case there's any sampling within the `format_input` method. + + Examples: + + Generate bitext retrieval data for training embedding models: + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks import BitextRetrievalGenerator + + with Pipeline("my-pipeline") as pipeline: + task = BitextRetrievalGenerator( + source_language="English", + target_language="Spanish", + unit="sentence", + difficulty="elementary school", + high_score="4", + low_score="2.5", + llm=..., + ) + + ... + + task >> ... + ``` + """ + + source_language: str = Field( + default="English", + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + target_language: str = Field( + default=..., + description="The languages are retrieved from the list of XLM-R in the Appendix A of https://aclanthology.org/2020.acl-main.747.pdf", + ) + + unit: Optional[Literal["sentence", "phrase", "passage"]] = None + difficulty: Optional[Literal["elementary school", "high school", "college"]] = None + high_score: Optional[Literal["4", "4.5", "5"]] = None + low_score: Optional[Literal["2.5", "3", "3.5"]] = None + + _template_name: str = PrivateAttr(default="bitext-retrieval") + + @property + def prompt(self) -> ChatType: + """Contains the `prompt` to be used in the `process` method, rendering the `_template`; and + formatted as an OpenAI formatted chat i.e. a `ChatType`, assuming that there's only one turn, + being from the user with the content being the rendered `_template`. + """ + return [ + { + "role": "user", + "content": self._template.render( # type: ignore + source_language=self.source_language, + target_language=self.target_language, + unit=self.unit or random.choice(["sentence", "phrase", "passage"]), + difficulty=self.difficulty + or random.choice(["elementary school", "high school", "college"]), + high_score=self.high_score or random.choice(["4", "4.5", "5"]), + low_score=self.low_score or random.choice(["2.5", "3", "3.5"]), + ).strip(), + } + ] # type: ignore + + @property + def keys(self) -> List[str]: + """Contains the `keys` that will be parsed from the `LLM` output into a Python dict.""" + return ["S1", "S2", "S3"] diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 new file mode 100644 index 0000000000..1cf238015f --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/bitext-retrieval.jinja2 @@ -0,0 +1,13 @@ +Write a {{ unit }} triple with one {{ unit }} in {{ source_language }} and two {{ unit }}s in {{ target_language }} with varying translation qualities in JSON format. + +The triple is denotes as ("S1", "S2", "S3"). The translation quality score ranges from 1 to 5, with higher scores are better. + +Please adhere to the following guidelines: + - The values of "S1" is a string in {{ source_language }}, the value of "S2" and "S3" are strings in {{ target_language }}. + - There should be some word overlaps between "S2" and "S3". + - The translation quality score of "S2" with respect to "S1" should be {{ high_score }}. + - The translation quality score of "S3" with respect to "S1" should be {{ low_score }}. + - "S3" should be grammatical and fluent, but contain some keyword or number translation errors, or miss some information, or contain some redundant information. + - "S1" requires {{ difficulty }} level education to understand and should be diverse in terms of topic and length. + +Your output must always be a JSON object only with three keys "S1", "S2" and "S3", do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 new file mode 100644 index 0000000000..3501b9332d --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-classification.jinja2 @@ -0,0 +1,6 @@ +Brainstorm a list of potentially useful text classification tasks. + +Please adhere to the following guidelines: + - Tasks should cover a diverse range of domains and task types. + +Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 new file mode 100644 index 0000000000..0090ef2af4 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-long.jinja2 @@ -0,0 +1,7 @@ +Brainstorm a list of text matching tasks where the queries are long documents. + +Here are a few examples: + - Given a document that supports a debatable argument, find another document that contains opposite arguments. + - Provided a lengthy business proposal, retrieve competitive business strategies in the same industry. + +Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 new file mode 100644 index 0000000000..cf42fddae5 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-matching-short.jinja2 @@ -0,0 +1,8 @@ +Brainstorm a list of text matching tasks where both the queries and the groundtruth documents are very short (one or two sentences, even a short phrase). + +Here are a few examples: + - Given a scientific paper title, retrieve the title of papers that cite the given paper. + - Match a word with its definition. + - Provided a notable person's name, identify their occupation or achievement. + +Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 new file mode 100644 index 0000000000..464ed0e763 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/brainstorming/text-retrieval.jinja2 @@ -0,0 +1,11 @@ +Brainstorm a list of potentially useful text retrieval tasks. + +Here are a few examples for your reference: + - Provided a scientific claim as query, retrieve documents that help verify or refute the claim. + - Search for documents that answers a FAQ-style query on children's nutrition. + +Please adhere to the following guidelines: + - Specify what the query is, and what the desired documents are. + - Each retrieval task should cover a wide range of queries, and should not be too specific. + +Your output should always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct retrieval task in one sentence. Do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 new file mode 100644 index 0000000000..cd8bf1922a --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/long-text-matching.jinja2 @@ -0,0 +1,12 @@ +You have been assigned a text matching task: {{ task }} + +Your mission is to write one example for this task in JSON format. The JSON object must contain the following keys: + - "input": a string, a random input specified by the task. + - "positive_document": a string, a relevant document for the "input" according to the task. + +Please adhere to the following guidelines: + - The values of all fields should be in {{ language }}. + - Both the "input" and "positive_document" should be long documents (at least 300 words), avoid substantial word overlaps, otherwise the task would be too easy. + - The "input" and "positive_document" should be independent of each other. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 new file mode 100644 index 0000000000..585d618620 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/monolingual-triplet.jinja2 @@ -0,0 +1,10 @@ +Write a {{ unit }} triple with varying semantic similarity scores in JSON format. The semantic similarity score ranges from 1 to 5, with 1 denotes least similar and 5 denotes most similar. + +Please adhere to the following guidelines: + - The keys in JSON are "S1", "S2", and "S3", the values are all strings in {{ language }}, do not add any other keys. + - There should be some word overlaps between all three {{ unit }}s. + - The similarity score between S1 and S2 should be {{ high_score }}. + - The similarity score between S1 and S3 should be {{ low_score }}. + - The {{ unit }}s require {{ difficulty }} level education to understand and should be diverse in terms of topic and length. + +Your output must always be a JSON object only with three keys "S1", "S2" and "S3", do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 new file mode 100644 index 0000000000..90b08f9e57 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/short-text-matching.jinja2 @@ -0,0 +1,12 @@ +You have been assigned a text matching task: {{ task }} + +Your mission is to write one example for this task in JSON format. The JSON object must contain the following keys: + - "input": a string, a random input specified by the task. + - "positive_document": a string, a relevant document for the "input" according to the task. + +Please adhere to the following guidelines: + - The values of all fields should be in {{ language }}. + - Both the "input" and "positive_document" should be very short (a sentence or a phrase), avoid substantial word overlaps, otherwise the task would be too easy. + - The "input" and "positive_document" should be independent of each other. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 new file mode 100644 index 0000000000..74a184bc56 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-classification.jinja2 @@ -0,0 +1,15 @@ +You have been assigned a text classification task: {{ task }} + +Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys: + - "input_text": a string, the input text specified by the classification task. + - "label": a string, the correct label of the input text. + - "misleading_label": a string, an incorrect label that is related to the task. + +Please adhere to the following guidelines: + - The "input_text" should be diverse in expression. + - The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the "input_text". + - The values for all fields should be in {{ language }}. + - Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make the task too easy. + - The "input_text" is {{ clarity }} and requires {{ difficulty }} level education to comprehend. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 new file mode 100644 index 0000000000..c76ac8a698 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/improving_text_embeddings/text-retrieval.jinja2 @@ -0,0 +1,17 @@ +You have been assigned a retrieval task: {{ task }} + +Your mission is to write one text retrieval example for this task in JSON format. The JSON object must contain the following keys: + - "user_query": a string, a random user search query specified by the retrieval task. + - "positive_document": a string, a relevant document for the user query. + - "hard_negative_document": a string, a hard negative document that only appears relevant to the query. + +Please adhere to the following guidelines: + - The "user_query" should be {{ query_type }}, {{ query_length }}, {{ clarity }}, and diverse in topic. + - All documents must be created independent of the query. Avoid copying the query verbatim. It's acceptable if some parts of the "positive_document" are not topically related to the query. + - All documents should be at least {{ num_words}} words long. + - The "hard_negative_document" contains some useful information, but it should be less useful or comprehensive compared to the "positive_document". + - Both the query and documents should be in {{ language }}. + - Do not provide any explanation in any document on why it is relevant or not relevant to the query. + - Both the query and documents require {{ difficulty }} level education to understand. + +Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py new file mode 100644 index 0000000000..8ab9b2fd51 --- /dev/null +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -0,0 +1,406 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, List + +import pytest +from distilabel.llms import LLM +from distilabel.llms.typing import GenerateOutput +from distilabel.pipeline.local import Pipeline +from distilabel.steps.tasks.improving_text_embeddings import ( + BitextRetrievalGenerator, + EmbeddingTaskGenerator, + GenerateLongTextMatchingData, + GenerateShortTextMatchingData, + GenerateTextClassificationData, + GenerateTextRetrievalData, + MonolingualTripletGenerator, +) +from distilabel.steps.tasks.typing import ChatType + + +class MockLLM(LLM): + output: str + + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + def generate( # type: ignore + self, inputs: List[ChatType], num_generations: int = 1 + ) -> List[GenerateOutput]: + return [[self.output] for _ in range(num_generations)] + + +class TestEmbeddingTaskGenerator: + @pytest.mark.parametrize( + "category", + [ + "text-retrieval", + "text-matching-short", + "text-matching-long", + "text-classification", + ], + ) + @pytest.mark.parametrize("flatten_tasks", [True, False]) + def test_process(self, category: str, flatten_tasks: bool) -> None: + task = EmbeddingTaskGenerator( + name="embedding_task_generator", + category=category, # type: ignore + flatten_tasks=flatten_tasks, + add_raw_output=False, + llm=MockLLM(output="[ 'A', 'B', 'C' ]"), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.outputs == ["tasks" if not flatten_tasks else "task", "model_name"] + + result = ( + ([{"tasks": ["A", "B", "C"], "model_name": "test"}], True) + if not flatten_tasks + else ( + [ + {"task": "A", "model_name": "test"}, + {"task": "B", "model_name": "test"}, + {"task": "C", "model_name": "test"}, + ], + True, + ) + ) + assert next(task.process()) == result + + +class TestBitextRetrievalGenerator: + @pytest.mark.parametrize( + "task_kwargs", + [ + { + "source_language": "English", + "target_language": "French", + "unit": "sentence", + "difficulty": "elementary school", + "high_score": "4", + "low_score": "2.5", + } + ], + ) + def test_prompt(self, task_kwargs: Any) -> None: + task = BitextRetrievalGenerator( + name="bitext_retrieval_generator", + **task_kwargs, + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert all( + task.prompt[-1]["content"].__contains__(v) for _, v in task_kwargs.items() + ) + + def test_process(self) -> None: + task = BitextRetrievalGenerator( + name="bitext_retrieval_generator", + source_language="English", + target_language="French", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.outputs == ["S1", "S2", "S3", "model_name"] + + assert next(task.process()) == ( + [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + True, + ) + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = BitextRetrievalGenerator( + name="bitext_retrieval_generator", + source_language="English", + target_language="French", + add_raw_output=False, + seed=42, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + unique_prompts.add(task.prompt[-1]["content"]) + + assert len(unique_prompts) == 1 + + +class TestMonolingualTripletGenerator: + @pytest.mark.parametrize( + "task_kwargs", + [ + { + "language": "English", + "unit": "sentence", + "difficulty": "elementary school", + "high_score": "4", + "low_score": "2.5", + } + ], + ) + def test_prompt(self, task_kwargs: Any) -> None: + task = MonolingualTripletGenerator( + name="monolingual_triplet_generator", + **task_kwargs, + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert all( + task.prompt[-1]["content"].__contains__(v) for _, v in task_kwargs.items() + ) + + def test_process(self) -> None: + task = MonolingualTripletGenerator( + name="monolingual_triplet_generator", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == ["S1", "S2", "S3", "model_name"] + assert next(task.process()) == ( + [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + True, + ) + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = MonolingualTripletGenerator( + name="monolingual_triplet_generator", + language="English", + add_raw_output=False, + seed=42, + llm=MockLLM(output=json.dumps({"S1": "A", "S2": "B", "S3": "C"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + unique_prompts.add(task.prompt[-1]["content"]) + assert len(unique_prompts) == 1 + + +class TestGenerateLongTextMatchingData: + def test_format_input(self) -> None: + task = GenerateLongTextMatchingData( + name="generate_long_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a text matching task: A" + ) + + def test_process(self) -> None: + task = GenerateLongTextMatchingData( + name="generate_long_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + + assert task.outputs == ["input", "positive_document", "model_name"] + + assert next(task.process(inputs=[{"task": "A"}])) == [ + {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + ] + + +class TestGenerateShortTextMatchingData: + def test_format_input(self) -> None: + task = GenerateShortTextMatchingData( + name="generate_short_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a text matching task: A" + ) + + def test_process(self) -> None: + task = GenerateShortTextMatchingData( + name="generate_short_text_matching_data", + language="English", + add_raw_output=False, + llm=MockLLM(output=json.dumps({"input": "A", "positive_document": "B"})), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == ["input", "positive_document", "model_name"] + assert next(task.process(inputs=[{"task": "A"}])) == [ + {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + ] + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = GenerateShortTextMatchingData( + name="generate_short_text_matching_data", + language="English", + add_raw_output=False, + seed=42, + llm=MockLLM( + output=json.dumps({"input": "A", "positive_document": "B"}) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + unique_prompts.add(task.format_input({"task": "A"})[-1]["content"]) + + assert len(unique_prompts) == 1 + + +class TestGenerateTextClassificationData: + def test_format_input(self) -> None: + task = GenerateTextClassificationData( + name="generate_text_classification_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + {"input_text": "A", "label": "B", "misleading_label": "C"} + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a text classification task: A" + ) + + def test_process(self) -> None: + task = GenerateTextClassificationData( + name="generate_text_classification_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + {"input_text": "A", "label": "B", "misleading_label": "C"} + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == ["input_text", "label", "misleading_label", "model_name"] + assert next(task.process(inputs=[{"task": "A"}])) == [ + { + "task": "A", + "input_text": "A", + "label": "B", + "misleading_label": "C", + "model_name": "test", + } + ] + + def test_reproducibility(self) -> None: + unique_prompts = set() + for _ in range(10): + task = GenerateTextClassificationData( + name="generate_text_classification_data", + language="English", + add_raw_output=False, + seed=42, + llm=MockLLM( + output=json.dumps( + {"input_text": "A", "label": "B", "misleading_label": "C"} + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + unique_prompts.add(task.format_input({"task": "A"})[-1]["content"]) + + assert len(unique_prompts) == 1 + + +class TestGenerateTextRetrievalData: + def test_format_input(self) -> None: + task = GenerateTextRetrievalData( + name="generate_text_retrieval_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + { + "user_query": "A", + "positive_document": "B", + "hard_negative_document": "C", + } + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.format_input({"task": "A"})[-1]["content"].startswith( + "You have been assigned a retrieval task: A" + ) + + def test_process(self) -> None: + task = GenerateTextRetrievalData( + name="generate_text_retrieval_data", + language="English", + add_raw_output=False, + llm=MockLLM( + output=json.dumps( + { + "user_query": "A", + "positive_document": "B", + "hard_negative_document": "C", + } + ) + ), + pipeline=Pipeline(name="unit-test-pipeline"), + ) + task.load() + assert task.outputs == [ + "user_query", + "positive_document", + "hard_negative_document", + "model_name", + ] + assert next(task.process(inputs=[{"task": "A"}])) == [ + { + "task": "A", + "user_query": "A", + "positive_document": "B", + "hard_negative_document": "C", + "model_name": "test", + } + ] diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py index 07eb3e5b63..e20e186c8e 100644 --- a/tests/unit/test_imports.py +++ b/tests/unit/test_imports.py @@ -74,6 +74,13 @@ def test_imports() -> None: EvolInstructGenerator, GenerateEmbeddings, Genstruct, + BitextRetrievalGenerator, + EmbeddingTaskGenerator, + GenerateLongTextMatchingData, + GenerateShortTextMatchingData, + GenerateTextClassificationData, + GenerateTextRetrievalData, + MonolingualTripletGenerator, InstructionBacktranslation, PairRM, PrometheusEval,