Skip to content

Commit

Permalink
Checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Oct 22, 2024
1 parent 9746d75 commit f108670
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 252 deletions.
6 changes: 6 additions & 0 deletions src/distilabel/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,12 @@ async def _agenerate(
for input in inputs
]
result = await asyncio.gather(*tasks)
print("\n_agenerate\n\n", result)
print("\n_agenerate MERGED\n\n", merge_responses(result))
print(
"CORRECT merge_response, ITS GROUPING num_generations MIXED WITH THE INPUTS PASSED"
)
# TODO: Update this,
return merge_responses(result)

tasks = [
Expand Down
94 changes: 84 additions & 10 deletions src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from distilabel.utils.dicts import group_dicts

if TYPE_CHECKING:
from distilabel.llms.typing import GenerateOutput, LLMOutput, LLMStatistics
from distilabel.llms.typing import GenerateOutput, LLMStatistics
from distilabel.steps.tasks.typing import ChatType, FormattedInput
from distilabel.steps.typing import StepOutput

Expand Down Expand Up @@ -170,30 +170,40 @@ def _format_outputs(
A list containing a dictionary with the outputs of the task for each input.
"""
inputs = [None] if input is None else [input]

print("INPUTS", inputs)
formatted_outputs = []
for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
repeate_inputs = len(outputs.get("generations"))
outputs = normalize_statistics(outputs)

for (output, stats), input in zip(
iterate_generations_with_stats(outputs), inputs * repeate_inputs
): # type: ignore
# for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
try:
# Extract the generations, and move the statistics to the distilabel_metadata,
# to keep everything clean
output_generations: "LLMOutput" = output.get("generations", [])
formatted_output = self.format_output(output_generations, input)
# TODO: THIS WOULD FAIL IF THE LLM DOESN'T RETURN generations,
# WE HAVE TO REMOVE THE STATISTICS AND PASS EVERYTHING ELSE
print("OUTPUT", output)
print("STATS", stats)
print("INPUT", input)
# output_generations: "LLMOutput" = output.get("generations", [])
formatted_output = self.format_output(output, input)
formatted_output = self._create_metadata(
formatted_output,
output_generations,
output,
input,
add_raw_output=self.add_raw_output, # type: ignore
add_raw_input=self.add_raw_input, # type: ignore
statistics=output.get("statistics"),
# statistics=output.get("statistics"),
statistics=stats,
)
formatted_outputs.append(formatted_output)
except Exception as e:
self._logger.warning( # type: ignore
f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore
)
formatted_outputs.append(
self._output_on_failure(output.get("generations", []), input)
)
formatted_outputs.append(self._output_on_failure(output, input))
return formatted_outputs

def _output_on_failure(
Expand Down Expand Up @@ -437,6 +447,8 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
)

task_outputs = []
print("INPUTS", inputs)
print("OUTPUTS", outputs)
for input, input_outputs in zip(inputs, outputs):
formatted_outputs = self._format_outputs(input_outputs, input)

Expand All @@ -449,6 +461,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore

# Create a row per generation
for formatted_output in formatted_outputs:
print("FORMATED", formatted_output)
task_outputs.append(
{**input, **formatted_output, "model_name": self.llm.model_name}
)
Expand Down Expand Up @@ -477,3 +490,64 @@ class GlobalTask(_Task, GlobalStep):
"""

pass


def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput":
    """Transforms the GenerateOutput statistics to have the same length as the generations.

    Args:
        output: A generate output that possibly has different lengths of statistics
            vs generations (due to num_generations=3 returning 3 generations, but
            for example the tokens are only counted once).

    Returns:
        The same output with its statistics normalized (in place) to the
        generations length.

    Examples:
        ```python
        data = {
            "generations": ["text1", "text2", "text3"],
            "statistics": {"input_tokens": [1], "output_tokens": [1, 2, 3]}
        }
        normalize_statistics(data)
        data = {
            "generations": ["text1", "text2", "text3"],
            "statistics": {"input_tokens": [1, 1, 1], "output_tokens": [1, 2, 3]}
        }
        ```
    """
    # Nothing to normalize when the LLM returned no statistics at all.
    if not (statistics := output.get("statistics")):
        return output

    gen_length = len(output["generations"])

    for stat_key, stat_values in statistics.items():
        current_length = len(stat_values)

        if current_length < gen_length:
            # Repeat the available values cyclically until the lengths match,
            # e.g. [1] -> [1, 1, 1] or [1, 2] -> [1, 2, 1] for 3 generations.
            repeats = gen_length // current_length
            remainder = gen_length % current_length
            statistics[stat_key] = stat_values * repeats + stat_values[:remainder]

    return output


def iterate_generations_with_stats(
    output: "GenerateOutput",
) -> "Iterator[Tuple[Any, Dict[str, Any]]]":
    """Helper function to iterate together generations and statistics while
    processing them inside `_format_outputs`.

    Args:
        output: Output from the LLM.generate_outputs method, assumed to have been
            passed through `normalize_statistics` so every statistic has one
            value per generation.

    Yields:
        Tuples of `(generation, statistics)` paired by index.
    """
    # Tolerate a missing "statistics" key, mirroring the early return in
    # `normalize_statistics`: each generation is then paired with empty stats.
    statistics = output.get("statistics", {})
    for i, generation in enumerate(output["generations"]):
        # Pick the i-th value of every statistic so each generation carries
        # only its own counters.
        stats = {key: values[i] for key, values in statistics.items()}
        yield generation, stats
40 changes: 29 additions & 11 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,20 @@ def model_name(self) -> str:
async def agenerate(  # type: ignore
    self, input: "FormattedInput", num_generations: int = 1
) -> "GenerateOutput":
    """Dummy async generation used in tests.

    Args:
        input: The formatted input (ignored by this dummy implementation).
        num_generations: Number of generations to produce. Defaults to 1.

    Returns:
        A `GenerateOutput`-shaped dict with `num_generations` fixed "output"
        strings and one token-count entry per generation, matching the shape
        real LLMs return.
    """
    return {
        "generations": ["output"] * num_generations,
        "statistics": {
            "input_tokens": [12] * num_generations,
            "output_tokens": [12] * num_generations,
        },
    }


class DummyLLM(LLM):
Expand All @@ -60,11 +69,14 @@ def generate( # type: ignore
self, inputs: "FormattedInput", num_generations: int = 1
) -> List["GenerateOutput"]:
return [
[
{"generations": "output", "statistics": {"test": "test"}}
for _ in range(num_generations)
]
]
{
"generations": [f"output {i}" for i in range(num_generations)],
"statistics": {
"input_tokens": [12] * num_generations,
"output_tokens": [12] * num_generations,
},
}
] * len(inputs)


class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
Expand All @@ -80,7 +92,13 @@ def generate(
) -> List["GenerateOutput"]:
return [
[
{"generations": "output", "statistics": {"test": "test"}}
{
"generations": ["output"] * num_generations,
"statistics": {
"input_tokens": [12] * num_generations,
"output_tokens": [12] * num_generations,
},
}
for _ in range(num_generations)
]
for _ in range(len(inputs))
Expand Down
Loading

0 comments on commit f108670

Please sign in to comment.