From f108670a9707a343184340ead79232e263d71e58 Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 22 Oct 2024 11:31:06 +0200 Subject: [PATCH] Checkpoint --- src/distilabel/llms/base.py | 6 + src/distilabel/steps/tasks/base.py | 94 ++++++- tests/unit/conftest.py | 40 ++- tests/unit/steps/tasks/test_base.py | 418 +++++++++++++--------------- 4 files changed, 306 insertions(+), 252 deletions(-) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index cabfd3706..dcdc346a6 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -459,6 +459,12 @@ async def _agenerate( for input in inputs ] result = await asyncio.gather(*tasks) + print("\n_agenerate\n\n", result) + print("\n_agenerate MERGED\n\n", merge_responses(result)) + print( + "CORRECT merge_response, ITS GROUPING num_generations MIXED WITH THE INPUTS PASSED" + ) + # TODO: Update this, return merge_responses(result) tasks = [ diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index d27a3b80f..f8f33df14 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -34,7 +34,7 @@ from distilabel.utils.dicts import group_dicts if TYPE_CHECKING: - from distilabel.llms.typing import GenerateOutput, LLMOutput, LLMStatistics + from distilabel.llms.typing import GenerateOutput, LLMStatistics from distilabel.steps.tasks.typing import ChatType, FormattedInput from distilabel.steps.typing import StepOutput @@ -170,30 +170,40 @@ def _format_outputs( A list containing a dictionary with the outputs of the task for each input. 
""" inputs = [None] if input is None else [input] - + print("INPUTS", inputs) formatted_outputs = [] - for output, input in zip(outputs, inputs * len(outputs)): # type: ignore + repeate_inputs = len(outputs.get("generations")) + outputs = normalize_statistics(outputs) + + for (output, stats), input in zip( + iterate_generations_with_stats(outputs), inputs * repeate_inputs + ): # type: ignore + # for output, input in zip(outputs, inputs * len(outputs)): # type: ignore try: # Extract the generations, and move the statistics to the distilabel_metadata, # to keep everything clean - output_generations: "LLMOutput" = output.get("generations", []) - formatted_output = self.format_output(output_generations, input) + # TODO: THIS WOULD FAIL IF THE LLM DOESN'T RETURN generations, + # WE HAVE TO REMOVE THE STATISTICS AND PASS EVERYTHING ELSE + print("OUTPUT", output) + print("STATS", stats) + print("INPUT", input) + # output_generations: "LLMOutput" = output.get("generations", []) + formatted_output = self.format_output(output, input) formatted_output = self._create_metadata( formatted_output, - output_generations, + output, input, add_raw_output=self.add_raw_output, # type: ignore add_raw_input=self.add_raw_input, # type: ignore - statistics=output.get("statistics"), + # statistics=output.get("statistics"), + statistics=stats, ) formatted_outputs.append(formatted_output) except Exception as e: self._logger.warning( # type: ignore f"Task '{self.name}' failed to format output: {e}. Saving raw response." 
# type: ignore ) - formatted_outputs.append( - self._output_on_failure(output.get("generations", []), input) - ) + formatted_outputs.append(self._output_on_failure(output, input)) return formatted_outputs def _output_on_failure( @@ -437,6 +447,8 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore ) task_outputs = [] + print("INPUTS", inputs) + print("OUTPUTS", outputs) for input, input_outputs in zip(inputs, outputs): formatted_outputs = self._format_outputs(input_outputs, input) @@ -449,6 +461,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore # Create a row per generation for formatted_output in formatted_outputs: + print("FORMATED", formatted_output) task_outputs.append( {**input, **formatted_output, "model_name": self.llm.model_name} ) @@ -477,3 +490,64 @@ class GlobalTask(_Task, GlobalStep): """ pass + + +def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": + """Transforms the GenerateOutput statistics to have the same length as the generations. + + Args: + output: A generate output that possibly has different lengths of statistics + vs generations (due to num_generations=3 returning 3 generations, but + for example the tokens are only counted once). + + Returns: + Normalized statistics according to the generations length. 
+ + Examples: + ```python + data = { + "generations": ["text1", "text2", "text3"], + "statistics": {"input_tokens": [1], "output_tokens": [1, 2, 3]} + } + normalize_statistics(data) + data = { + "generations": ["text1", "text2", "text3"], + "statistics": {"input_tokens": [1, 1, 1], "output_tokens": [1, 2, 3]} + } + ``` + """ + if not (statistics := output.get("statistics")): + print(statistics) + return output + gen_length = len(output["generations"]) + + for stat_key, stat_values in output["statistics"].items(): + current_length = len(stat_values) + + if current_length < gen_length: + # Calculate how many times to repeat the tokens + repeats = gen_length // current_length + remainder = gen_length % current_length + + # Create new list with repeated values + new_values = stat_values * repeats + stat_values[:remainder] + output["statistics"][stat_key] = new_values + + return output + + +def iterate_generations_with_stats(output: "GenerateOutput") -> "GenerateOutput": + """Helper function to iterate together generations and statistics while + processing them inside _format_outputs. + + Args: + output: Output from the LLM.generate_outputs method. + + Yields: + Iterator of generation and statistics paired. 
+ """ + for i, generation in enumerate(output["generations"]): + # Create a new dictionary with the statistics for this index + stats = {key: values[i] for key, values in output["statistics"].items()} + + yield generation, stats diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 905b0f723..2127e26a5 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -39,11 +39,20 @@ def model_name(self) -> str: async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": - # return ["output" for _ in range(num_generations)] - return [ - {"generations": "output", "statistics": {"test": "test"}} - for _ in range(num_generations) - ] + # return { + # "generations": ["output"], + # "statistics": { + # "input_tokens": [12], + # "output_tokens": [12], + # }, + # } + return { + "generations": ["output" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class DummyLLM(LLM): @@ -60,11 +69,14 @@ def generate( # type: ignore self, inputs: "FormattedInput", num_generations: int = 1 ) -> List["GenerateOutput"]: return [ - [ - {"generations": "output", "statistics": {"test": "test"}} - for _ in range(num_generations) - ] - ] + { + "generations": [f"output {i}" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): @@ -80,7 +92,13 @@ def generate( ) -> List["GenerateOutput"]: return [ [ - {"generations": "output", "statistics": {"test": "test"}} + { + "generations": ["output"] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in range(num_generations) ] for _ in range(len(inputs)) diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index 
ef400acfc..87ce7198d 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -109,7 +109,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -124,22 +124,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - "statistics": {"test": "test"}, - }, - }, - { - "instruction": "test_0", - "additional_info": "additional_info_0", - "output": "output", - "info_from_input": "additional_info_0", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_0", "role": "user"}, - ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -154,7 +139,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -169,37 +154,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], - "statistics": {"test": "test"}, - }, - }, - { - "instruction": "test_1", - "additional_info": "additional_info_1", - "output": "output", - "info_from_input": "additional_info_1", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_1", "role": "user"}, - ], - "statistics": {"test": "test"}, - }, - }, - { - "instruction": "test_2", - "additional_info": "additional_info_2", - "output": "output", - "info_from_input": 
"additional_info_2", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_2", "role": "user"}, - ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -214,7 +169,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -229,186 +184,186 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, ], ), - ( - [ - {"instruction": "test_0", "additional_info": "additional_info_0"}, - {"instruction": "test_1", "additional_info": "additional_info_1"}, - {"instruction": "test_2", "additional_info": "additional_info_2"}, - ], - True, - [ - { - "instruction": "test_0", - "additional_info": "additional_info_0", - "output": ["output", "output", "output"], - "info_from_input": [ - "additional_info_0", - "additional_info_0", - "additional_info_0", - ], - "model_name": "test", - "distilabel_metadata": [ - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - ], - }, - { - "instruction": "test_1", - 
"additional_info": "additional_info_1", - "output": ["output", "output", "output"], - "info_from_input": [ - "additional_info_1", - "additional_info_1", - "additional_info_1", - ], - "model_name": "test", - "distilabel_metadata": [ - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - ], - }, - { - "instruction": "test_2", - "additional_info": "additional_info_2", - "output": ["output", "output", "output"], - "info_from_input": [ - "additional_info_2", - "additional_info_2", - "additional_info_2", - ], - "model_name": "test", - "distilabel_metadata": [ - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - ], - }, - ], - ), + # ( + # [ + # {"instruction": "test_0", "additional_info": "additional_info_0"}, + # {"instruction": "test_1", "additional_info": "additional_info_1"}, + # {"instruction": "test_2", "additional_info": "additional_info_2"}, + # ], + # True, + # [ + # { + # "instruction": "test_0", + # 
"additional_info": "additional_info_0", + # "output": ["output", "output", "output"], + # "info_from_input": [ + # "additional_info_0", + # "additional_info_0", + # "additional_info_0", + # ], + # "model_name": "test", + # "distilabel_metadata": [ + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_0", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_0", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_0", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # ], + # }, + # { + # "instruction": "test_1", + # "additional_info": "additional_info_1", + # "output": ["output", "output", "output"], + # "info_from_input": [ + # "additional_info_1", + # "additional_info_1", + # "additional_info_1", + # ], + # "model_name": "test", + # "distilabel_metadata": [ + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_1", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_1", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_1", + # "role": "user", + # }, + # ], + # "statistics": 
{"input_tokens": 12, "output_tokens": 12}, + # }, + # ], + # }, + # { + # "instruction": "test_2", + # "additional_info": "additional_info_2", + # "output": ["output", "output", "output"], + # "info_from_input": [ + # "additional_info_2", + # "additional_info_2", + # "additional_info_2", + # ], + # "model_name": "test", + # "distilabel_metadata": [ + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_2", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_2", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_2", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # ], + # }, + # ], + # ), ], ) def test_process( @@ -424,7 +379,7 @@ def test_process( llm=llm, pipeline=pipeline, group_generations=group_generations, - num_generations=3, + num_generations=2, ) task.load() result = next(task.process(input)) @@ -436,7 +391,7 @@ def test_process_overriding_inputs(self) -> None: name="task", llm=llm, group_generations=False, - num_generations=3, + num_generations=2, input_mappings={"instruction": "instruction_2"}, ) task.load() @@ -452,6 +407,7 @@ def test_process_overriding_inputs(self) -> None: ] ) ) + print("REUSLT", result) assert result == [ { @@ -468,7 +424,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", 
@@ -490,7 +446,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -512,7 +468,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping",