Commit

Fix tests from merge responses and group generations
plaguss committed Oct 23, 2024
1 parent f108670 commit c8063a4
Showing 4 changed files with 282 additions and 326 deletions.
42 changes: 26 additions & 16 deletions src/distilabel/llms/base.py
@@ -21,6 +21,7 @@
 import time
 from abc import ABC, abstractmethod
 from functools import cached_property
+from itertools import islice
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -459,21 +460,15 @@ async def _agenerate(
                 for input in inputs
             ]
             result = await asyncio.gather(*tasks)
-            print("\n_agenerate\n\n", result)
-            print("\n_agenerate MERGED\n\n", merge_responses(result))
-            print(
-                "CORRECT merge_response, ITS GROUPING num_generations MIXED WITH THE INPUTS PASSED"
-            )
-            # TODO: Update this,
-            return merge_responses(result)
+            return result

         tasks = [
             asyncio.create_task(self.agenerate(input=input, **kwargs))
             for input in inputs
             for _ in range(num_generations)
         ]
         outputs = await asyncio.gather(*tasks)
-        return merge_responses(outputs)
+        return merge_responses(outputs, n=num_generations)

     def generate(
         self,
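Note: the fix relies on task ordering. The nested comprehension above schedules all `num_generations` calls for one input before moving on to the next, so grouping the gathered results into consecutive chunks of size `num_generations` restores the per-input structure. A standalone sketch of that ordering (names invented for illustration):

```python
# Two inputs, three generations each: tasks are created input-major, so the
# flat list gathered by asyncio.gather keeps each input's generations adjacent.
inputs = ["in0", "in1"]
num_generations = 3

order = [f"{i}_gen{g}" for i in inputs for g in range(num_generations)]
assert order == [
    "in0_gen0", "in0_gen1", "in0_gen2",
    "in1_gen0", "in1_gen1", "in1_gen2",
]
```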
@@ -595,26 +590,41 @@ def _prepare_kwargs(
     return arguments


-def merge_responses(responses: List[Dict[str, Any]]) -> List["GenerateOutput"]:
+def merge_responses(
+    responses: List[Dict[str, Any]], n: int = 1
+) -> List[Dict[str, Any]]:
     """Helper function to group the responses from `LLM.agenerate` method according
     to the number of generations requested.

     Args:
         responses: the responses from the `LLM.agenerate` method.
+        n: number of responses to group together. Defaults to 1.

     Returns:
-        Merges the texts and statistics of the responses into a single response.
+        List of merged responses, where each merged response contains n generations
+        and their corresponding statistics.
     """
     if not responses:
         return []

-    first = responses[0]
-    return [
-        {
-            "generations": sum((r["generations"] for r in responses), []),
+    def chunks(lst, n):
+        """Yield successive n-sized chunks from lst."""
+        for i in range(0, len(lst), n):
+            yield list(islice(lst, i, i + n))
+
+    # Split responses into groups of size n
+    grouped_responses = list(chunks(responses, n))
+
+    result = []
+    for group in grouped_responses:
+        first = group[0]
+        merged = {
+            "generations": sum((r["generations"] for r in group), []),
             "statistics": {
-                key: sum((r["statistics"][key] for r in responses), [])
+                key: sum((r["statistics"][key] for r in group), [])
                 for key in first["statistics"]
             },
         }
-    ]
+        result.append(merged)
+
+    return result
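A minimal sketch of the regrouping that the updated `merge_responses` performs (payload values invented; the import path matches the file above):

```python
from distilabel.llms.base import merge_responses

# Flat outputs from the fallback path: two inputs x two generations each,
# in input-major order, one single-generation response per agenerate call.
responses = [
    {"generations": ["a0"], "statistics": {"input_tokens": [1], "output_tokens": [2]}},
    {"generations": ["a1"], "statistics": {"input_tokens": [1], "output_tokens": [3]}},
    {"generations": ["b0"], "statistics": {"input_tokens": [4], "output_tokens": [5]}},
    {"generations": ["b1"], "statistics": {"input_tokens": [4], "output_tokens": [6]}},
]

# n=num_generations groups consecutive chunks back into one response per input.
assert merge_responses(responses, n=2) == [
    {
        "generations": ["a0", "a1"],
        "statistics": {"input_tokens": [1, 1], "output_tokens": [2, 3]},
    },
    {
        "generations": ["b0", "b1"],
        "statistics": {"input_tokens": [4, 4], "output_tokens": [5, 6]},
    },
]
```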
17 changes: 2 additions & 15 deletions src/distilabel/steps/tasks/base.py
@@ -170,32 +170,23 @@ def _format_outputs(
         A list containing a dictionary with the outputs of the task for each input.
         """
         inputs = [None] if input is None else [input]
-        print("INPUTS", inputs)
         formatted_outputs = []
         repeate_inputs = len(outputs.get("generations"))
         outputs = normalize_statistics(outputs)

         for (output, stats), input in zip(
             iterate_generations_with_stats(outputs), inputs * repeate_inputs
         ):  # type: ignore
-            # for output, input in zip(outputs, inputs * len(outputs)):  # type: ignore
             try:
                 # Extract the generations, and move the statistics to the distilabel_metadata,
                 # to keep everything clean
-                # TODO: THIS WOULD FAIL IF THE LLM DOESN'T RETURN generations,
-                # WE HAVE TO REMOVE THE STATISTICS AND PASS EVERYTHING ELSE
-                print("OUTPUT", output)
-                print("STATS", stats)
-                print("INPUT", input)
-                # output_generations: "LLMOutput" = output.get("generations", [])
                 formatted_output = self.format_output(output, input)
                 formatted_output = self._create_metadata(
                     formatted_output,
                     output,
                     input,
                     add_raw_output=self.add_raw_output,  # type: ignore
                     add_raw_input=self.add_raw_input,  # type: ignore
-                    # statistics=output.get("statistics"),
                     statistics=stats,
                 )
                 formatted_outputs.append(formatted_output)
@@ -224,7 +215,6 @@ def _output_on_failure(
         )
         return outputs

-    # TODO: Rename to _create_metadata
     def _create_metadata(
         self,
         output: Dict[str, Any],
@@ -447,8 +437,6 @@ def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
         )

         task_outputs = []
-        print("INPUTS", inputs)
-        print("OUTPUTS", outputs)
         for input, input_outputs in zip(inputs, outputs):
             formatted_outputs = self._format_outputs(input_outputs, input)

@@ -461,7 +449,6 @@ def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore

             # Create a row per generation
             for formatted_output in formatted_outputs:
-                print("FORMATED", formatted_output)
                 task_outputs.append(
                     {**input, **formatted_output, "model_name": self.llm.model_name}
                 )
@@ -516,8 +503,8 @@ def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput":
     }
     ```
     """
-    if not (statistics := output.get("statistics")):
-        print(statistics)
+    statistics = output.get("statistics")
+    if not statistics:
         return output
     gen_length = len(output["generations"])

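For reference, `normalize_statistics` (touched above) pads statistics that were reported once so they line up one-to-one with the generations. A small sketch of the documented behavior, with invented token counts:

```python
from distilabel.steps.tasks.base import normalize_statistics

# num_generations=2 but tokens counted once: per the docstring, the single
# value is repeated until each generation has a matching statistics entry.
output = {
    "generations": ["text 1", "text 2"],
    "statistics": {"input_tokens": [12], "output_tokens": [12]},
}

normalized = normalize_statistics(output)
# -> {"generations": ["text 1", "text 2"],
#     "statistics": {"input_tokens": [12, 12], "output_tokens": [12, 12]}}
```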
12 changes: 4 additions & 8 deletions tests/unit/conftest.py
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Union

 import pytest
+from pydantic import PrivateAttr

 from distilabel.llms.base import LLM, AsyncLLM
 from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
@@ -28,9 +29,11 @@
 # Defined here too, so that the serde still works
 class DummyAsyncLLM(AsyncLLM):
     structured_output: Any = None
+    n_generations_supported: bool = True  # To work as OpenAI or an LLM that doesn't allow num_generations out of the box
+    _num_generations_param_supported: bool = PrivateAttr(default=True)

     def load(self) -> None:
-        pass
+        self._num_generations_param_supported = self.n_generations_supported

     @property
     def model_name(self) -> str:
@@ -39,13 +42,6 @@ def model_name(self) -> str:
     async def agenerate(  # type: ignore
         self, input: "FormattedInput", num_generations: int = 1
     ) -> "GenerateOutput":
-        # return {
-        #     "generations": ["output"],
-        #     "statistics": {
-        #         "input_tokens": [12],
-        #         "output_tokens": [12],
-        #     },
-        # }
         return {
             "generations": ["output" for i in range(num_generations)],
             "statistics": {
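With the updated fixture, tests can exercise both `_agenerate` code paths. A hypothetical usage sketch (importing a conftest fixture directly and calling the private `_agenerate` are for illustration only):

```python
import asyncio

from tests.unit.conftest import DummyAsyncLLM  # hypothetical import path

# Simulate an LLM without native num_generations support: load() copies the
# flag, so _agenerate falls back to one agenerate call per generation and
# regroups the flat outputs with merge_responses(..., n=num_generations).
llm = DummyAsyncLLM(n_generations_supported=False)
llm.load()

outputs = asyncio.run(
    llm._agenerate(inputs=[[{"role": "user", "content": "hi"}]], num_generations=2)
)
# Expected: one merged response holding both generations, e.g.
# [{"generations": ["output", "output"],
#   "statistics": {"input_tokens": [12, 12], "output_tokens": [12, 12]}}]
```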