Add GenerateSentencePair task #689

Merged · 16 commits · Jun 4, 2024
14 changes: 13 additions & 1 deletion docs/sections/learn/tutorial/task/index.md
@@ -8,6 +8,7 @@ The subclasses of [`Task`][distilabel.steps.tasks.Task] are intended to be used

For example, the most basic task is the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task, which generates text based on a given instruction, and it can be used standalone as well as within a [`Pipeline`][distilabel.pipeline.Pipeline].

```python
from distilabel.steps.tasks import TextGeneration

@@ -18,12 +19,23 @@ task = TextGeneration(
task.load()

next(task.process([{"instruction": "What's the capital of Spain?"}]))
# [
# {
# "instruction": "What's the capital of Spain?",
# "generation": "The capital of Spain is Madrid.",
# "model_name": "gpt-4",
# "distilabel_metadata": {
# "raw_output_text-generation": "The capital of Spain is Madrid"
# }
# }
# ]
```

!!! NOTE
    The `load` method always needs to be called when using a task standalone; if the [`Pipeline`][distilabel.pipeline.Pipeline] context manager is used instead, there is no need to call it, since it will be called automatically on `Pipeline.run`. Note that `load` has to be called on the parent class, e.g. a [`Task`][distilabel.steps.tasks.Task] with an [`LLM`][distilabel.llms.LLM] needs to call `Task.load` to load both the task and the LLM.
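
For reference, a minimal sketch of the `Pipeline` case, where no explicit `load` call is needed (`OpenAILLM`, `LoadDataFromDicts`, and the step names are illustrative choices, not part of this documentation page):

```python
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="text-generation-pipeline") as pipeline:
    load_data = LoadDataFromDicts(
        name="load_data",
        data=[{"instruction": "What's the capital of Spain?"}],
    )
    task = TextGeneration(name="text_generation", llm=OpenAILLM(model="gpt-4"))
    load_data >> task

# `Pipeline.run` loads every step (and its LLM) before processing starts.
distiset = pipeline.run()
```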

As shown in the comments of the code snippet above, the task has enriched the input dictionaries by adding the `generation`, the `model_name` that was used to generate it, and finally the `distilabel_metadata` dictionary that contains the raw output (without post-processing) from the LLM. In this case the `TextGeneration` task applies no post-processing, so the `generation` and the raw output are the same; but other tasks do post-process the LLM output, which in some situations can fail. That's why it's useful to have the raw output available in the `distilabel_metadata` dictionary. If this default behaviour is not desired, all `Task`s have an `add_raw_output` attribute that can be set to `False` when creating the task instance or at runtime, as sketched below.
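
For instance, disabling it at creation time (the `OpenAILLM` choice and the step name are illustrative):

```python
from distilabel.llms import OpenAILLM
from distilabel.steps.tasks import TextGeneration

task = TextGeneration(
    name="text_generation",
    llm=OpenAILLM(model="gpt-4"),
    add_raw_output=False,  # don't add the raw LLM output to `distilabel_metadata`
)
```

Since `add_raw_output` is a `RuntimeParameter`, it should also be possible to toggle it per run, e.g. via `Pipeline.run(parameters={"text_generation": {"add_raw_output": False}})`.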

## Defining custom Tasks

In order to define custom tasks, we need to inherit from the [`Task`][distilabel.steps.tasks.Task] class and implement the `format_input` and `format_output` methods, as well as set the `inputs` and `outputs` properties, just as for [`Step`][distilabel.steps.Step] subclasses.
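
A minimal sketch under those requirements (the task name, its columns, and the prompt are hypothetical):

```python
from typing import Any, Dict, List, Optional

from distilabel.steps.tasks import Task
from distilabel.steps.tasks.typing import ChatType


class SummarizeText(Task):
    """A toy task that asks the LLM to summarize a `text` column."""

    @property
    def inputs(self) -> List[str]:
        return ["text"]

    def format_input(self, input: Dict[str, Any]) -> ChatType:
        # Build the chat conversation that is sent to the LLM from one input row.
        return [{"role": "user", "content": f"Summarize: {input['text']}"}]

    @property
    def outputs(self) -> List[str]:
        return ["summary", "model_name"]

    def format_output(
        self, output: Optional[str], input: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Map the raw LLM completion back to the output column; `model_name`
        # is added automatically by the base `Task`.
        return {"summary": output}
```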
4 changes: 2 additions & 2 deletions src/distilabel/mixins/runtime_parameters.py
@@ -115,13 +115,13 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
                    name, runtime_parameters_names, cutoff=0.5
                )
                msg = (
                    f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'."  # type: ignore
                )
                if closest:
                    msg += f" Did you mean any of: {closest}"
                else:
                    msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
                self.pipeline._logger.warning(msg)  # type: ignore
                continue

            attr = getattr(self, name)
13 changes: 10 additions & 3 deletions src/distilabel/pipeline/base.py
@@ -298,11 +298,18 @@ def run(
            The `Distiset` created by the pipeline.
        """

        # Set the runtime parameters that will be used during the pipeline execution.
        # They are used to generate the signature of the pipeline that is used to hit the
        # cache when the pipeline is run, so it's important to do it first.
        self._set_runtime_parameters(parameters or {})

        setup_logging(
            **{
                **self._logging_parameters,
                "filename": str(self._cache_location["log_file"]),
            }
        )

        # Validate the pipeline DAG to check that all the steps are chainable, there are
        # no missing runtime parameters, batch sizes are correct, etc.
        self.dag.validate()
10 changes: 6 additions & 4 deletions src/distilabel/steps/tasks/__init__.py
@@ -30,17 +30,15 @@
from distilabel.steps.tasks.prometheus_eval import PrometheusEval
from distilabel.steps.tasks.quality_scorer import QualityScorer
from distilabel.steps.tasks.self_instruct import SelfInstruct
from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.typing import ChatItem, ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback

__all__ = [
Member: The import order here was alphabetical; what's the rationale behind this change? Maybe we should change this in other places too to make sure we're aligned?

Member Author: The rationale was to have the names in `__all__` ordered by the order in which they were imported, which is more common than having them alphabetically.

"Task",
"GeneratorTask",
"ChatGeneration",
"ChatItem",
"ChatType",
"Task",
"ComplexityScorer",
"EvolInstruct",
"EvolComplexity",
@@ -54,7 +52,11 @@
"PrometheusEval",
"QualityScorer",
"SelfInstruct",
"GenerateSentencePair",
"StructuredGeneration",
"ChatGeneration",
"TextGeneration",
"ChatItem",
"ChatType",
"UltraFeedback",
]
8 changes: 7 additions & 1 deletion src/distilabel/steps/tasks/base.py
@@ -52,7 +52,13 @@ class _Task(_Step, ABC):
    llm: LLM

    group_generations: bool = False
add_raw_output: bool = False
    add_raw_output: RuntimeParameter[bool] = Field(
        default=True,
        description=(
            "Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>`"
            " of the `distilabel_metadata` dictionary output column"
        ),
    )
    num_generations: RuntimeParameter[int] = Field(
        default=1, description="The number of generations to be produced per input."
    )
254 changes: 254 additions & 0 deletions src/distilabel/steps/tasks/sentence_transformers.py
@@ -0,0 +1,254 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import sys
from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Optional, Union

from jinja2 import Template

from distilabel.steps.tasks.base import Task

if sys.version_info < (3, 9):
    import importlib_resources
else:
    import importlib.resources as importlib_resources

if TYPE_CHECKING:
    from distilabel.steps.tasks.typing import ChatType

GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"]

# Extracts the positive sentence and, when present, the negative sentence from an LLM
# completion shaped as "## Positive\n<sentence>\n## Negative\n<sentence>".
POSITIVE_NEGATIVE_PAIR_REGEX = re.compile(
    r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$",
    re.DOTALL,
)

GENERATION_ACTION_SENTENCES: Final[Dict[GenerationAction, str]] = {
    "paraphrase": "paraphrase",
    "semantically-similar": "be semantically similar to",
    "query": "be a query for",
    "answer": "be an answer for",
}

POSITIVE_SYSTEM_PROMPT: str = (
    "Your task is to generate a positive sentence given an anchor sentence. The positive"
    " sentence has to {action_sentence} the anchor sentence. You must output only one new"
    " section: `## Positive`."
)

POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
    "Your task is to generate a positive and a negative sentence given an anchor sentence."
    " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
    " sentence can use similar words but must not be related to the anchor sentence. You"
    " must output only two new sections: `## Positive` and `## Negative`."
)


class GenerateSentencePair(Task):
    """Generate a positive and (optionally) a negative sentence given an anchor sentence.

    `GenerateSentencePair` is a pre-defined task that, given an anchor sentence, generates
    a positive sentence related to the anchor and optionally a negative sentence unrelated
    to it. This task is useful for generating training datasets for embedding models.

    Attributes:
        triplet: a flag to indicate if the task should generate a triplet of sentences
            (anchor, positive, negative). Defaults to `False`.
        action: the action to perform to generate the positive sentence.

    Input columns:
        - anchor (`str`): The anchor sentence from which to generate the positive and
            negative sentences.

    Output columns:
        - positive (`str`): The positive sentence related to the `anchor`.
        - negative (`str`): The negative sentence unrelated to the `anchor` if `triplet=True`.
        - model_name (`str`): The name of the model that was used to generate the sentences.

    Categories:
        - embedding

    Examples:

        Paraphrasing:

        ```python
        from distilabel.steps.tasks import GenerateSentencePair
        from distilabel.llms import InferenceEndpointsLLM

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="paraphrase",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
        ```

        Generating semantically similar sentences:

        ```python
        from distilabel.llms import InferenceEndpointsLLM
        from distilabel.steps.tasks import GenerateSentencePair

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="semantically-similar",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "How does 3D printing work?"}])
        ```

        Generating queries:

        ```python
        from distilabel.steps.tasks import GenerateSentencePair
        from distilabel.llms import InferenceEndpointsLLM

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="query",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "Argilla is an open-source data curation platform for LLMs. Using Argilla, ..."}])
        ```

        Generating answers:

        ```python
        from distilabel.steps.tasks import GenerateSentencePair
        from distilabel.llms import InferenceEndpointsLLM

        generate_sentence_pair = GenerateSentencePair(
            triplet=True,  # `False` to generate only positive
            action="answer",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            input_batch_size=10,
        )

        generate_sentence_pair.load()

        result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
        ```
    """

    triplet: bool = False
    action: GenerationAction

    def load(self) -> None:
        """Loads the Jinja2 template."""
        super().load()

        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "generate-sentence-pair.jinja2"
        )

        self._template = Template(open(_path).read())

    @property
    def inputs(self) -> List[str]:
        """The input for the task is the `anchor` sentence."""
        return ["anchor"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """The inputs are formatted as a `ChatType`, with a system prompt describing the
        task of generating positive and negative sentences for the anchor sentence. The
        anchor is provided as the first user interaction in the conversation.

        Args:
            input: The input containing the `anchor` sentence.

        Returns:
            A list of dictionaries containing the system and user interactions.
        """
        action_sentence = GENERATION_ACTION_SENTENCES[self.action]
        system_prompt = (
            POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
        ).format(action_sentence=action_sentence)

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": self._template.render(anchor=input["anchor"])},
        ]

    @property
    def outputs(self) -> List[str]:
        """The outputs for the task are the `positive` and `negative` sentences, as well
        as the `model_name` used to generate the sentences."""
        columns = ["positive", "negative"] if self.triplet else ["positive"]
        columns += ["model_name"]
        return columns

    def format_output(
        self, output: Union[str, None], input: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Formats the output of the LLM to extract the `positive` and `negative` sentences
        generated. If the output is `None` or the regex doesn't match, then the outputs
        will be set to `None` as well.

        Args:
            output: The output of the LLM.
            input: The input used to generate the output.

        Returns:
            The formatted output containing the `positive` and `negative` sentences.
        """
        if output is None:
            return {"positive": None, "negative": None}

        match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output)
        if match is None:
            formatted_output = {"positive": None}
            if self.triplet:
                formatted_output["negative"] = None
            return formatted_output

        groups = match.groups()
        if self.triplet:
            return {
                "positive": groups[0].strip(),
                "negative": groups[1].strip()
                if len(groups) > 1 and groups[1] is not None
                else None,
            }

        return {"positive": groups[0].strip()}
4 changes: 4 additions & 0 deletions src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
@@ -0,0 +1,4 @@
## Anchor

{{ anchor }}

@@ -56,7 +56,7 @@
{% for example_title, code in step.docstring.examples.items() %}
#### {{ example_title }}
```python
{{ code | replace("\\n", "\n") }}
```
{% endfor %}
{% endif %}