Skip to content

Commit

Permalink
Implement "Improving Text Embeddings with LLMs" (#683)
Browse files Browse the repository at this point in the history
* Set `input` as optional in `format_output`

* Implement "Improving Text Embeddings with LLMs" (WIP)

* Implement "Improving Text Embeddings with LLMs" (WIP)

* Add `model_name` at the end of each batch

* Move `text_embeddings.py` to `improving_text_embeddings.py`

* Fix `re.sub` to also capture `\t` and `\r`

* Add `MonolingualTripletGenerator` and `BitextRetrievalGenerator`

* Move all `templates` from `str` to `jinja2` files

* Update class naming and imports

* Add some docstrings and fix `jinja2` file paths

* Fix `prompt` accross tasks

* Add missing docstrings

* Fix `process` method in `EmbeddingTaskGenerator`

* Add unit tests for `...Generator` tasks

* Add remaining unit tests

* Remove duplicated imports in `distilabel.steps.tasks`

* Add examples in docstrings and add notes
  • Loading branch information
alvarobartt authored Jun 12, 2024
1 parent a0d7e93 commit 0e8c752
Show file tree
Hide file tree
Showing 15 changed files with 1,494 additions and 6 deletions.
16 changes: 16 additions & 0 deletions src/distilabel/steps/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from distilabel.steps.tasks.evol_quality.base import EvolQuality
from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings
from distilabel.steps.tasks.genstruct import Genstruct
from distilabel.steps.tasks.improving_text_embeddings import (
BitextRetrievalGenerator,
EmbeddingTaskGenerator,
GenerateLongTextMatchingData,
GenerateShortTextMatchingData,
GenerateTextClassificationData,
GenerateTextRetrievalData,
MonolingualTripletGenerator,
)
from distilabel.steps.tasks.instruction_backtranslation import (
InstructionBacktranslation,
)
Expand All @@ -47,6 +56,13 @@
"EvolQuality",
"GenerateEmbeddings",
"Genstruct",
"BitextRetrievalGenerator",
"EmbeddingTaskGenerator",
"GenerateLongTextMatchingData",
"GenerateShortTextMatchingData",
"GenerateTextClassificationData",
"GenerateTextRetrievalData",
"MonolingualTripletGenerator",
"InstructionBacktranslation",
"PairRM",
"PrometheusEval",
Expand Down
19 changes: 13 additions & 6 deletions src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def load(self) -> None:

@abstractmethod
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
self,
output: Union[str, None],
input: Union[Dict[str, Any], None] = None,
) -> Dict[str, Any]:
"""Abstract method to format the outputs of the task. It needs to receive an output
as a string, and generates a Python dictionary with the outputs of the task. In
Expand All @@ -80,7 +82,9 @@ def format_output(
pass

def _format_outputs(
self, outputs: "GenerateOutput", inputs: List[Dict[str, Any]]
self,
outputs: "GenerateOutput",
inputs: Union[List[Dict[str, Any]], None] = None,
) -> List[Dict[str, Any]]:
"""Formats the outputs of the task using the `format_output` method. If the output
is `None` (i.e. the LLM failed to generate a response), then the outputs will be
Expand All @@ -93,8 +97,11 @@ def _format_outputs(
Returns:
A list containing a dictionary with the outputs of the task for each input.
"""
if inputs is None:
inputs = [None] # type: ignore

formatted_outputs = []
for output, input in zip(outputs, inputs * len(outputs)):
for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
try:
formatted_output = self.format_output(output, input)
formatted_output = self._maybe_add_raw_output(
Expand All @@ -109,7 +116,7 @@ def _format_outputs(
return formatted_outputs

def _output_on_failure(
self, output: Union[str, None], input: Dict[str, Any]
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
"""In case of failure to format the output, this method will return a dictionary including
a new field `distilabel_meta` with the raw output of the LLM.
Expand Down Expand Up @@ -189,14 +196,14 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
if self.group_generations:
combined = combine_dicts(*formatted_outputs)
task_outputs.append(
{**input, "model_name": self.llm.model_name, **combined}
{**input, **combined, "model_name": self.llm.model_name}
)
continue

# Create a row per generation
for formatted_output in formatted_outputs:
task_outputs.append(
{**input, "model_name": self.llm.model_name, **formatted_output}
{**input, **formatted_output, "model_name": self.llm.model_name}
)

yield task_outputs
Expand Down
Loading

0 comments on commit 0e8c752

Please sign in to comment.