diff --git a/docs/sections/learn/tutorial/task/index.md b/docs/sections/learn/tutorial/task/index.md
index 6f6f1259bd..2e322895e7 100644
--- a/docs/sections/learn/tutorial/task/index.md
+++ b/docs/sections/learn/tutorial/task/index.md
@@ -8,6 +8,7 @@ The subclasses of [`Task`][distilabel.steps.tasks.Task] are intended to be used
 
 For example, the most basic task is the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task, which generates text based on a given instruction, and it can be used standalone as well as within a [`Pipeline`][distilabel.pipeline.Pipeline].
 
+```python
 from distilabel.steps.tasks import TextGeneration
 
@@ -18,12 +19,23 @@ task = TextGeneration(
 task.load()
 
 next(task.process([{"instruction": "What's the capital of Spain?"}]))
-# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid.", "model_name": "gpt-4"}]
+# [
+#     {
+#         "instruction": "What's the capital of Spain?",
+#         "generation": "The capital of Spain is Madrid.",
+#         "model_name": "gpt-4",
+#         "distilabel_metadata": {
+#             "raw_output_text-generation": "The capital of Spain is Madrid"
+#         }
+#     }
+# ]
 ```
 
 !!! NOTE
     The `load` method ALWAYS needs to be called when using the tasks standalone; if the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used instead, there's no need to call it, since it will be called automatically on `Pipeline.run`. In any other case, the `load` method needs to be called from the parent class, e.g. a [`Task`][distilabel.steps.tasks.Task] with an [`LLM`][distilabel.llms.LLM] will need to call `Task.load` to load both the task and the LLM.
 
+As we can see in the comments of the code snippet above, the task has enriched the input dictionaries by adding the `generation`, the `model_name` that was used to generate it, and finally the `distilabel_metadata` dictionary, which contains the raw output (without post-processing) from the LLM. In this case, the `TextGeneration` task does no post-processing, so the `generation` and the raw output are the same, but other tasks do post-process the LLM output, which in some situations can fail. That's why it is useful to have the raw output available in the `distilabel_metadata` dictionary. If this default behaviour is not desired, all the `Task`s have an `add_raw_output` attribute that can be set to `False` either when creating the instance of the task or at runtime.
+
 ## Defining custom Tasks
 
 In order to define custom tasks, we need to inherit from the [`Task`][distilabel.steps.tasks.Task] class and implement the `format_input` and `format_output` methods, as well as set the properties `inputs` and `outputs`, as for [`Step`][distilabel.steps.Step] subclasses.
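To make the new docs paragraph concrete, here is a minimal sketch of opting out of the raw output at construction time, reusing the `TextGeneration` example from the snippet above (the `OpenAILLM` instance and `gpt-4` model are placeholders):

```python
from distilabel.llms import OpenAILLM
from distilabel.steps.tasks import TextGeneration

# Sketch: with `add_raw_output=False` the output dictionaries should no longer
# carry the `distilabel_metadata` entry with the raw LLM output.
task = TextGeneration(
    llm=OpenAILLM(model="gpt-4"),
    add_raw_output=False,
)
task.load()

next(task.process([{"instruction": "What's the capital of Spain?"}]))
# [{"instruction": "What's the capital of Spain?", "generation": "...", "model_name": "gpt-4"}]
```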
diff --git a/src/distilabel/mixins/runtime_parameters.py b/src/distilabel/mixins/runtime_parameters.py
index 5959244803..9a6fec512b 100644
--- a/src/distilabel/mixins/runtime_parameters.py
+++ b/src/distilabel/mixins/runtime_parameters.py
@@ -115,13 +115,13 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
                 name, runtime_parameters_names, cutoff=0.5
             )
             msg = (
-                f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'."
+                f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'."  # type: ignore
             )
             if closest:
                 msg += f" Did you mean any of: {closest}"
             else:
                 msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
-            self.pipeline._logger.warning(msg)
+            self.pipeline._logger.warning(msg)  # type: ignore
             continue
 
         attr = getattr(self, name)
diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py
index bc77a53536..9150d8c3d0 100644
--- a/src/distilabel/pipeline/base.py
+++ b/src/distilabel/pipeline/base.py
@@ -298,11 +298,18 @@ def run(
             The `Distiset` created by the pipeline.
         """
-        setup_logging(**self._logging_parameters)
-
-        # Set the runtime parameters that will be used during the pipeline execution
+        # Set the runtime parameters that will be used during the pipeline execution.
+        # They are used to generate the signature of the pipeline that is used to hit the
+        # cache when the pipeline is run, so it's important to do it first.
         self._set_runtime_parameters(parameters or {})
 
+        setup_logging(
+            **{
+                **self._logging_parameters,
+                "filename": str(self._cache_location["log_file"]),
+            }
+        )
+
         # Validate the pipeline DAG to check that all the steps are chainable, there are
         # no missing runtime parameters, batch sizes are correct, etc.
         self.dag.validate()
diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py
index f785f9eb6a..72e2216e71 100644
--- a/src/distilabel/steps/tasks/__init__.py
+++ b/src/distilabel/steps/tasks/__init__.py
@@ -30,17 +30,15 @@
 from distilabel.steps.tasks.prometheus_eval import PrometheusEval
 from distilabel.steps.tasks.quality_scorer import QualityScorer
 from distilabel.steps.tasks.self_instruct import SelfInstruct
+from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair
 from distilabel.steps.tasks.structured_generation import StructuredGeneration
 from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
 from distilabel.steps.tasks.typing import ChatItem, ChatType
 from distilabel.steps.tasks.ultrafeedback import UltraFeedback
 
 __all__ = [
-    "Task",
     "GeneratorTask",
-    "ChatGeneration",
-    "ChatItem",
-    "ChatType",
+    "Task",
     "ComplexityScorer",
     "EvolInstruct",
     "EvolComplexity",
@@ -54,7 +52,11 @@
     "PrometheusEval",
     "QualityScorer",
     "SelfInstruct",
+    "GenerateSentencePair",
     "StructuredGeneration",
+    "ChatGeneration",
     "TextGeneration",
+    "ChatItem",
+    "ChatType",
     "UltraFeedback",
 ]
diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py
index 7e2cfc2520..2c19d8c8a4 100644
--- a/src/distilabel/steps/tasks/base.py
+++ b/src/distilabel/steps/tasks/base.py
@@ -52,7 +52,13 @@ class _Task(_Step, ABC):
 
     llm: LLM
     group_generations: bool = False
-    add_raw_output: bool = False
+    add_raw_output: RuntimeParameter[bool] = Field(
+        default=True,
+        description=(
+            "Whether to include the raw output of the LLM in the key `raw_output_`"
+            " of the `distilabel_metadata` dictionary output column"
+        ),
+    )
     num_generations: RuntimeParameter[int] = Field(
         default=1, description="The number of generations to be produced per input."
     )
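Since `add_raw_output` is now declared as a `RuntimeParameter`, it can also be toggled per step when running a pipeline, without touching the task definition. A minimal sketch, assuming a pipeline with a single `TextGeneration` step; the step name `text_generation`, the data, and the model are hypothetical:

```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="toggle-raw-output") as pipeline:
    load_data = LoadDataFromDicts(
        name="load_data",
        data=[{"instruction": "What's the capital of Spain?"}],
    )
    text_generation = TextGeneration(
        name="text_generation", llm=OpenAILLM(model="gpt-4")
    )
    load_data >> text_generation

if __name__ == "__main__":
    # Runtime parameters are keyed by step name, so `add_raw_output` can be
    # flipped per run instead of per task instance.
    distiset = pipeline.run(
        parameters={"text_generation": {"add_raw_output": False}},
    )
```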
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import sys
+from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Optional, Union
+
+from jinja2 import Template
+
+from distilabel.steps.tasks.base import Task
+
+if sys.version_info < (3, 9):
+    import importlib_resources
+else:
+    import importlib.resources as importlib_resources
+
+if TYPE_CHECKING:
+    from distilabel.steps.tasks.typing import ChatType
+
+GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"]
+
+POSITIVE_NEGATIVE_PAIR_REGEX = re.compile(
+    r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
+    re.DOTALL,
+)
+
+GENERATION_ACTION_SENTENCES: Final[Dict[GenerationAction, str]] = {
+    "paraphrase": "paraphrase",
+    "semantically-similar": "be semantically similar to",
+    "query": "be a query for",
+    "answer": "be an answer for",
+}
+
+POSITIVE_SYSTEM_PROMPT: str = (
+    "Your task is to generate a positive sentence given an anchor sentence. The positive"
+    " sentence has to {action_sentence} the anchor sentence. You must output only one new"
+    " section: `## Positive`."
+)
+
+POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
+    "Your task is to generate a positive and a negative sentence given an anchor sentence."
+    " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
+    " sentence can use similar words but must not be related to the anchor sentence. You"
+    " must output only two new sections: `## Positive` and `## Negative`."
+)
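As a sanity check on `POSITIVE_NEGATIVE_PAIR_REGEX`, here is a small sketch of how it splits an LLM answer into its sections (the sample output strings are invented, but follow the format the system prompts request):

```python
import re

POSITIVE_NEGATIVE_PAIR_REGEX = re.compile(
    r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
    re.DOTALL,
)

# Invented LLM output following the `## Positive` / `## Negative` format.
output = "## Positive\n\nA rephrased anchor.\n\n## Negative\n\nAn unrelated sentence."
match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
print(match.groups())  # ('A rephrased anchor.', 'An unrelated sentence.')

# The `## Negative` section is optional, so a positive-only answer still matches.
match = POSITIVE_NEGATIVE_PAIR_REGEX.match("## Positive\n\nA rephrased anchor.")
print(match.groups())  # ('A rephrased anchor.', None)
```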
+
+
+class GenerateSentencePair(Task):
+    """Generate a positive and (optionally) a negative sentence given an anchor sentence.
+
+    `GenerateSentencePair` is a pre-defined task that, given an anchor sentence, generates
+    a positive sentence related to the anchor and optionally a negative sentence unrelated
+    to the anchor. This task is useful for generating training datasets for embedding
+    models.
+
+    Attributes:
+        triplet: a flag to indicate if the task should generate a triplet of sentences
+            (anchor, positive, negative). Defaults to `False`.
+        action: the action to perform to generate the positive sentence.
+
+    Input columns:
+        - anchor (`str`): The anchor sentence from which to generate the positive and negative sentences.
+
+    Output columns:
+        - positive (`str`): The positive sentence related to the `anchor`.
+        - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`.
+        - model_name (`str`): The name of the model that was used to generate the sentences.
+
+    Categories:
+        - embedding
+
+    Examples:
+
+        Paraphrasing:
+
+        ```python
+        from distilabel.steps.tasks import GenerateSentencePair
+        from distilabel.llms import InferenceEndpointsLLM
+
+        generate_sentence_pair = GenerateSentencePair(
+            triplet=True,  # `False` to generate only positive
+            action="paraphrase",
+            llm=InferenceEndpointsLLM(
+                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+            ),
+            input_batch_size=10,
+        )
+
+        generate_sentence_pair.load()
+
+        result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
+        ```
+
+        Generating semantically similar sentences:
+
+        ```python
+        from distilabel.llms import InferenceEndpointsLLM
+        from distilabel.steps.tasks import GenerateSentencePair
+
+        generate_sentence_pair = GenerateSentencePair(
+            triplet=True,  # `False` to generate only positive
+            action="semantically-similar",
+            llm=InferenceEndpointsLLM(
+                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+            ),
+            input_batch_size=10,
+        )
+
+        generate_sentence_pair.load()
+
+        result = generate_sentence_pair.process([{"anchor": "How does 3D printing work?"}])
+        ```
+
+        Generating queries:
+
+        ```python
+        from distilabel.steps.tasks import GenerateSentencePair
+        from distilabel.llms import InferenceEndpointsLLM
+
+        generate_sentence_pair = GenerateSentencePair(
+            triplet=True,  # `False` to generate only positive
+            action="query",
+            llm=InferenceEndpointsLLM(
+                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+            ),
+            input_batch_size=10,
+        )
+
+        generate_sentence_pair.load()
+
+        result = generate_sentence_pair.process([{"anchor": "Argilla is an open-source data curation platform for LLMs. Using Argilla, ..."}])
+        ```
+
+        Generating answers:
+
+        ```python
+        from distilabel.steps.tasks import GenerateSentencePair
+        from distilabel.llms import InferenceEndpointsLLM
+
+        generate_sentence_pair = GenerateSentencePair(
+            triplet=True,  # `False` to generate only positive
+            action="answer",
+            llm=InferenceEndpointsLLM(
+                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+            ),
+            input_batch_size=10,
+        )
+
+        generate_sentence_pair.load()
+
+        result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
+        ```
+    """
+
+    triplet: bool = False
+    action: GenerationAction
+
+    def load(self) -> None:
+        """Loads the Jinja2 template."""
+        super().load()
+
+        _path = str(
+            importlib_resources.files("distilabel")
+            / "steps"
+            / "tasks"
+            / "templates"
+            / "generate-sentence-pair.jinja2"
+        )
+
+        self._template = Template(open(_path).read())
+
+    @property
+    def inputs(self) -> List[str]:
+        """The input for the task is the `anchor` sentence."""
+        return ["anchor"]
+
+    def format_input(self, input: Dict[str, Any]) -> "ChatType":
+        """The input is formatted as a `ChatType`, with a system prompt describing the
+        task of generating the positive and negative sentences for the anchor sentence.
+        The anchor is provided as the first user interaction in the conversation.
+
+        Args:
+            input: The input containing the `anchor` sentence.
+
+        Returns:
+            A list of dictionaries containing the system and user interactions.
+        """
+        action_sentence = GENERATION_ACTION_SENTENCES[self.action]
+        system_prompt = (
+            POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
+        ).format(action_sentence=action_sentence)
+
+        return [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": self._template.render(anchor=input["anchor"])},
+        ]
+
+    @property
+    def outputs(self) -> List[str]:
+        """The outputs for the task are the `positive` and `negative` sentences, as well
+        as the `model_name` used to generate the sentences."""
+        columns = ["positive", "negative"] if self.triplet else ["positive"]
+        columns += ["model_name"]
+        return columns
+
+    def format_output(
+        self, output: Union[str, None], input: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        """Formats the output of the LLM, extracting the `positive` and `negative` sentences
+        generated. If the output is `None` or the regex doesn't match, then the outputs
+        will be set to `None` as well.
+
+        Args:
+            output: The output of the LLM.
+            input: The input used to generate the output.
+
+        Returns:
+            The formatted output containing the `positive` and `negative` sentences.
+        """
+        if output is None:
+            return {"positive": None, "negative": None}
+
+        match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
+        if match is None:
+            formatted_output = {"positive": None}
+            if self.triplet:
+                formatted_output["negative"] = None
+            return formatted_output
+
+        groups = match.groups()
+        if self.triplet:
+            return {
+                "positive": groups[0].strip(),
+                "negative": groups[1].strip()
+                if len(groups) > 1 and groups[1] is not None
+                else None,
+            }
+
+        return {"positive": groups[0].strip()}
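Tying the pieces together, a short sketch of what `format_input` produces for a triplet paraphrase task, mirroring the expectations in the new unit tests below. The `OpenAILLM` instance is a placeholder; note that `Task.load()` also loads the LLM, so valid credentials (or any other configured LLM) are assumed:

```python
from distilabel.llms import OpenAILLM
from distilabel.steps.tasks import GenerateSentencePair

task = GenerateSentencePair(
    llm=OpenAILLM(model="gpt-4"),  # placeholder; `format_input` never invokes it
    action="paraphrase",
    triplet=True,
)
task.load()

print(task.format_input({"anchor": "This is a unit test"}))
# [
#     {"role": "system", "content": "Your task is to generate a positive and a negative sentence ..."},
#     {"role": "user", "content": "## Anchor\n\nThis is a unit test\n"},
# ]
```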
+ """ + action_sentence = GENERATION_ACTION_SENTENCES[self.action] + system_prompt = ( + POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT + ).format(action_sentence=action_sentence) + + return [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": self._template.render(anchor=input["anchor"])}, + ] + + @property + def outputs(self) -> List[str]: + """The outputs for the task are the `positive` and `negative` sentences, as well + as the `model_name` used to generate the sentences.""" + columns = ["positive", "negative"] if self.triplet else ["positive"] + columns += ["model_name"] + return columns + + def format_output( + self, output: Union[str, None], input: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Formats the output of the LLM, to extract the `positive` and `negative` sentences + generated. If the output is `None` or the regex doesn't match, then the outputs + will be set to `None` as well. + + Args: + output: The output of the LLM. + input: The input used to generate the output. + + Returns: + The formatted output containing the `positive` and `negative` sentences. + """ + if output is None: + return {"positive": None, "negative": None} + + match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output) + if match is None: + formatted_output = {"positive": None} + if self.triplet: + formatted_output["negative"] = None + return formatted_output + + groups = match.groups() + if self.triplet: + return { + "positive": groups[0].strip(), + "negative": groups[1].strip() + if len(groups) > 1 and groups[1] is not None + else None, + } + + return {"positive": groups[0].strip()} diff --git a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 new file mode 100644 index 0000000000..82594f18a8 --- /dev/null +++ b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 @@ -0,0 +1,4 @@ +## Anchor + +{{ anchor }} + diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 b/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 index 486fc4b299..4be5fdc1ab 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/step-detail.jinja2 @@ -56,7 +56,7 @@ {% for example_title, code in step.docstring.examples.items() %} #### {{ example_title }} ```python -{{ code | e }} +{{ code | replace("\n", "\n") }} ``` {% endfor %} {% endif %} diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py index d3999684c1..cea6fd75ca 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_base.py +++ b/tests/unit/steps/tasks/evol_instruct/test_base.py @@ -121,7 +121,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: task.load() assert task.dump() == { "name": "task", - "add_raw_output": False, + "add_raw_output": True, "input_mappings": task.input_mappings, "output_mappings": task.output_mappings, "input_batch_size": task.input_batch_size, @@ -163,6 +163,11 @@ def test_serialization(self, dummy_llm: LLM) -> None: } ], }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, { "name": "num_generations", "optional": True, diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py 
diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py
index fee2234083..bdc09c9162 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_generator.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -123,7 +123,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
                 "name": task.llm.__class__.__name__,
             },
         },
-        "add_raw_output": False,
+        "add_raw_output": True,
         "input_mappings": task.input_mappings,
         "output_mappings": task.output_mappings,
         "batch_size": task.batch_size,
@@ -158,6 +158,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
                 },
             ],
         },
+        {
+            "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+            "name": "add_raw_output",
+            "optional": True,
+        },
         {
             "name": "num_generations",
             "optional": True,
diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py
index 251e377d9e..d4445a4613 100644
--- a/tests/unit/steps/tasks/evol_quality/test_base.py
+++ b/tests/unit/steps/tasks/evol_quality/test_base.py
@@ -80,7 +80,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
         task.load()
         assert task.dump() == {
             "name": "task",
-            "add_raw_output": False,
+            "add_raw_output": True,
             "input_mappings": task.input_mappings,
             "output_mappings": task.output_mappings,
             "input_batch_size": task.input_batch_size,
@@ -112,9 +112,14 @@ def test_serialization(self, dummy_llm: LLM) -> None:
                         "name": "generation_kwargs",
                         "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
                         "keys": [],
-                    }
+                    },
                 ],
             },
+            {
+                "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+                "name": "add_raw_output",
+                "optional": True,
+            },
             {
                 "name": "num_generations",
                 "optional": True,
diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py
index 0cccbd5c9d..4a9f566c6a 100644
--- a/tests/unit/steps/tasks/test_base.py
+++ b/tests/unit/steps/tasks/test_base.py
@@ -94,16 +94,19 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
                     "instruction": "test",
                     "output": "output",
                     "model_name": "test",
+                    "distilabel_metadata": {"raw_output_task": "output"},
                 },
                 {
                     "instruction": "test",
                     "output": "output",
                     "model_name": "test",
+                    "distilabel_metadata": {"raw_output_task": "output"},
                 },
                 {
                     "instruction": "test",
                     "output": "output",
                     "model_name": "test",
+                    "distilabel_metadata": {"raw_output_task": "output"},
                 },
             ],
         ),
@@ -114,6 +117,11 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
                     "instruction": "test",
                     "output": ["output", "output", "output"],
                     "model_name": "test",
+                    "distilabel_metadata": [
+                        {"raw_output_task": "output"},
+                        {"raw_output_task": "output"},
+                        {"raw_output_task": "output"},
+                    ],
                 },
             ],
         ),
@@ -188,7 +196,7 @@ def test_serialization(self) -> None:
         task = DummyTask(name="task", llm=llm, pipeline=pipeline)
         assert task.dump() == {
             "name": "task",
-            "add_raw_output": False,
+            "add_raw_output": True,
             "input_mappings": {},
             "output_mappings": {},
             "input_batch_size": 50,
@@ -224,6 +232,11 @@ def test_serialization(self) -> None:
                     },
                 ],
             },
+            {
+                "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column",
+                "name": "add_raw_output",
+                "optional": True,
+            },
             {
                 "name": "num_generations",
                 "description": "The number of generations to be produced per input.",
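A detail worth noting from these test fixtures: the metadata key appears to be built by prefixing `raw_output_` to the step name, which is why the dummy step named `task` yields `raw_output_task` (and, further down, the `ultrafeedback` step yields `raw_output_ultrafeedback`). A one-line sketch of that inferred convention:

```python
# Assumed key construction, inferred from the test fixtures in this PR.
step_name = "task"
metadata_key = f"raw_output_{step_name}"  # -> "raw_output_task"
```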
diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py
index 4c8a8df7fe..a6f2793285 100644
--- a/tests/unit/steps/tasks/test_instruction_backtranslation.py
+++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py
@@ -86,5 +86,8 @@ def test_process(self) -> None:
             "score": 1,
             "reason": "This is the reason.",
             "model_name": "instruction-backtranslation-model",
+            "distilabel_metadata": {
+                "raw_output_instruction-backtranslation": "This is the reason. Score: 1"
+            },
         }
     ]
diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py
new file mode 100644
index 0000000000..3e50e7e3f1
--- /dev/null
+++ b/tests/unit/steps/tasks/test_sentence_transformers.py
@@ -0,0 +1,129 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+
+import pytest
+from distilabel.steps.tasks.sentence_transformers import (
+    POSITIVE_NEGATIVE_SYSTEM_PROMPT,
+    POSITIVE_SYSTEM_PROMPT,
+    GenerateSentencePair,
+    GenerationAction,
+)
+
+from tests.unit.steps.tasks.utils import DummyLLM
+
+
+class TestGenerateSentencePair:
+    @pytest.mark.parametrize(
+        "action,triplet,system_prompt",
+        [
+            (
+                "paraphrase",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
+            ),
+            (
+                "paraphrase",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
+            ),
+            (
+                "semantically-similar",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be semantically similar to"
+                ),
+            ),
+            (
+                "semantically-similar",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be semantically similar to"
+                ),
+            ),
+            (
+                "query",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be a query for"
+                ),
+            ),
+            (
+                "query",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(action_sentence="be a query for"),
+            ),
+            (
+                "answer",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be an answer for"
+                ),
+            ),
+            (
+                "answer",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(action_sentence="be an answer for"),
+            ),
+        ],
+    )
+    def test_format_input(
+        self, action: GenerationAction, triplet: bool, system_prompt: str
+    ) -> None:
+        task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet)
+        task.load()
+
+        assert task.format_input({"anchor": "This is a unit test"}) == [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": "## Anchor\n\nThis is a unit test\n"},
+        ]
+
+    @pytest.mark.parametrize(
+        "output,triplet,expected",
+        [
+            (
+                "## Positive\n\nThis is a paraphrase\n## Negative\n\nThis is not a paraphrase",
+                True,
+                {
+                    "positive": "This is a paraphrase",
+                    "negative": "This is not a paraphrase",
+                },
+            ),
+            (
+                "## Positive\n\nThis is a paraphrase",
+                True,
+                {"positive": "This is a paraphrase", "negative": None},
+            ),
+            (
+                "## Positive\n\nThis is a paraphrase",
+                False,
+                {"positive": "This is a paraphrase"},
+            ),
+            (
+                "random",
+                False,
+                {"positive": None},
+            ),
+        ],
+    )
+    def test_format_output(
+        self, output: str, triplet: bool, expected: Dict[str, Any]
+    ) -> None:
+        task = GenerateSentencePair(
+            llm=DummyLLM(), action="paraphrase", triplet=triplet
+        )
+        task.load()
+
+        assert task.format_output(output) == expected
{"positive": None}, + ), + ], + ) + def test_format_output( + self, output: str, triplet: bool, expected: Dict[str, Any] + ) -> None: + task = GenerateSentencePair( + llm=DummyLLM(), action="paraphrase", triplet=triplet + ) + task.load() + + assert task.format_output(output) == expected diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py index c4766aaa57..febc7f698f 100644 --- a/tests/unit/steps/tasks/test_structured_generation.py +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -120,5 +120,6 @@ def test_process(self) -> None: }, "generation": '{"test": "output"}', "model_name": "test", + "distilabel_metadata": {"raw_output_task": '{"test": "output"}'}, } ] diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index d07ba464a3..545cf6a7b8 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -82,6 +82,9 @@ def test_process(self) -> None: "instruction": "test", "generation": "output", "model_name": "test", + "distilabel_metadata": { + "raw_output_task": "output", + }, } ] @@ -139,5 +142,6 @@ def test_process(self) -> None: "messages": [{"role": "user", "content": "Tell me a joke."}], "generation": "output", "model_name": "test", + "distilabel_metadata": {"raw_output_task": "output"}, } ] diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 69e9326570..fa72ff9442 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -63,6 +63,9 @@ def test_process_with_simple_aspect(self) -> None: "ratings": [1, 2], "rationales": ["text", "text"], "model_name": "ultrafeedback-model", + "distilabel_metadata": { + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + }, } ] @@ -89,5 +92,8 @@ def test_process_with_complex_aspect(self) -> None: "ratings": [1, 2], "rationales-for-ratings": ["text", "text"], "model_name": "ultrafeedback-model", + "distilabel_metadata": { + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + }, } ]