Add context to guide the generate sentence pair task if informed #706

Merged: 3 commits, Jun 10, 2024
47 changes: 42 additions & 5 deletions src/distilabel/steps/tasks/sentence_transformers.py
@@ -43,31 +43,36 @@
}

POSITIVE_SYSTEM_PROMPT: str = (
"Your task is to generate a positive sentence given an anchor sentence. The positive"
"Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
" sentence has to {action_sentence} the anchor sentence. You must output only one new"
" section: `## Positive`."
)

POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
"Your task is to generate a positive and a negative sentence given an anchor sentence."
"Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
" The positive sentence has to {action_sentence} the anchor sentence, while the negative"
" sentence can use similar words but must not be related to the anchor sentence. You"
" must output only two new sections: `## Positive` and `## Negative`."
)

CONTEXT_INTRO: Final[str] = " Take into account the context given."


class GenerateSentencePair(Task):
"""Generate a positive and negative (optionally) sentences given an anchor sentence.

`GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
a positive sentence related to the anchor and optionally a negative sentence unrelated
- to the anchor. This task is useful to generate training datasets for training embeddings
+ to the anchor. Optionally, you can give a context to guide the LLM towards more specific
+ behavior. This task is useful to generate training datasets for training embeddings
models.

Attributes:
triplet: a flag to indicate if the task should generate a triplet of sentences
(anchor, positive, negative). Defaults to `False`.
action: the action to perform to generate the positive sentence.
context: the context to use for the generation. Can be helpful to guide the LLM
towards more specific context. Not used by default.

Input columns:
- anchor (`str`): The anchor sentence to generate the positive and negative sentences.
@@ -165,10 +170,33 @@ class GenerateSentencePair(Task):

result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
```

Generating queries with context (**applies to every action**):

```python
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.llms import InferenceEndpointsLLM

generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
action="query",
context="Argilla is an open-source data curation platform for LLMs.",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
),
input_batch_size=10,
)

generate_sentence_pair.load()

result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
```
"""

triplet: bool = False
action: GenerationAction
context: str = ""

def load(self) -> None:
"""Loads the Jinja2 template."""
@@ -203,11 +231,20 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
action_sentence = GENERATION_ACTION_SENTENCES[self.action]
system_prompt = (
POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
- ).format(action_sentence=action_sentence)
+ ).format(
+ action_sentence=action_sentence,
+ context=CONTEXT_INTRO if self.context else "",
+ )

return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": self._template.render(anchor=input["anchor"])},
{
"role": "user",
"content": self._template.render(
anchor=input["anchor"],
context=self.context if self.context else None,
),
},
]

@property
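
To see what the prompt change does in isolation, here is a minimal, self-contained sketch: the two constants are copied verbatim from the diff above, while the sample action and the `print` calls are illustrative only.

```python
# Minimal sketch of the reworked system-prompt assembly. The two constants
# are copied from the diff above; the sample action is illustrative.
POSITIVE_SYSTEM_PROMPT = (
    "Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
    " sentence has to {action_sentence} the anchor sentence. You must output only one new"
    " section: `## Positive`."
)
CONTEXT_INTRO = " Take into account the context given."

# No context: the `{context}` placeholder collapses to an empty string.
print(POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""))

# With context: the intro sentence is spliced in right after the first sentence,
# which is what `format_input` now does whenever `self.context` is non-empty.
print(POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=CONTEXT_INTRO))
```
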
Jinja2 template used by `GenerateSentencePair` (file path not shown in this view):
@@ -1,3 +1,10 @@
{% if context is not none -%}
## Context

{{ context }}

{% endif -%}

## Anchor

{{ anchor }}
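
The template change can be exercised on its own with Jinja2. The inline string below reproduces the template shown above (the actual file ships inside the distilabel package; its path is not visible in this view), and the two assertions mirror the expectations used in the unit tests further down.

```python
from jinja2 import Template

# Inline reproduction of the template above. `keep_trailing_newline=True` is
# only needed here so the rendered text ends with "\n" as the unit tests
# expect; the packaged template file handles that with its own trailing line.
template = Template(
    "{% if context is not none -%}\n"
    "## Context\n"
    "\n"
    "{{ context }}\n"
    "\n"
    "{% endif -%}\n"
    "\n"
    "## Anchor\n"
    "\n"
    "{{ anchor }}\n",
    keep_trailing_newline=True,
)

# Without a context only the anchor section is rendered.
assert template.render(anchor="This is a unit test", context=None) == (
    "## Anchor\n\nThis is a unit test\n"
)

# With a context, a `## Context` section is prepended.
assert template.render(anchor="This is a unit test", context="This is your context.") == (
    "## Context\n\nThis is your context.\n\n## Anchor\n\nThis is a unit test\n"
)
```
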
104 changes: 95 additions & 9 deletions tests/unit/steps/tasks/test_sentence_transformers.py
@@ -16,6 +16,7 @@

import pytest
from distilabel.steps.tasks.sentence_transformers import (
CONTEXT_INTRO,
POSITIVE_NEGATIVE_SYSTEM_PROMPT,
POSITIVE_SYSTEM_PROMPT,
GenerateSentencePair,
@@ -32,50 +33,56 @@ class TestGenerateSentencePair:
(
"paraphrase",
True,
- POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
+ POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+ action_sentence="paraphrase", context=""
+ ),
),
(
"paraphrase",
False,
- POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
+ POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""),
),
(
"semantically-similar",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to"
action_sentence="be semantically similar to", context=""
),
),
(
"semantically-similar",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to"
action_sentence="be semantically similar to", context=""
),
),
(
"query",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for"
action_sentence="be a query for", context=""
),
),
(
"query",
False,
- POSITIVE_SYSTEM_PROMPT.format(action_sentence="be a query for"),
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be a query for", context=""
+ ),
),
(
"answer",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for"
action_sentence="be an answer for", context=""
),
),
(
"answer",
False,
- POSITIVE_SYSTEM_PROMPT.format(action_sentence="be an answer for"),
+ POSITIVE_SYSTEM_PROMPT.format(
+ action_sentence="be an answer for", context=""
+ ),
),
],
)
@@ -84,10 +91,89 @@ def test_format_input(
) -> None:
task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet)
task.load()
content = "## Anchor\n\nThis is a unit test\n"
assert task.format_input({"anchor": "This is a unit test"}) == [
{"role": "system", "content": system_prompt},
{"role": "user", "content": content},
]

@pytest.mark.parametrize(
"action,triplet,system_prompt",
[
(
"paraphrase",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="paraphrase", context=CONTEXT_INTRO
),
),
(
"paraphrase",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="paraphrase", context=CONTEXT_INTRO
),
),
(
"semantically-similar",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to", context=CONTEXT_INTRO
),
),
(
"semantically-similar",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be semantically similar to", context=CONTEXT_INTRO
),
),
(
"query",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for", context=CONTEXT_INTRO
),
),
(
"query",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be a query for", context=CONTEXT_INTRO
),
),
(
"answer",
True,
POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for", context=CONTEXT_INTRO
),
),
(
"answer",
False,
POSITIVE_SYSTEM_PROMPT.format(
action_sentence="be an answer for", context=CONTEXT_INTRO
),
),
],
)
def test_format_input_with_context(
self, action: GenerationAction, triplet: bool, system_prompt: str
) -> None:
context = "This is your context."
task = GenerateSentencePair(
llm=DummyLLM(),
action=action,
triplet=triplet,
context=context,
)
task.load()
content = f"## Context\n\n{context}\n\n## Anchor\n\nThis is a unit test\n"
# content = f"## Anchor\n\nThis is a unit test\n## Context\n\n{context}"
assert task.format_input({"anchor": "This is a unit test"}) == [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "## Anchor\n\nThis is a unit test\n"},
{"role": "user", "content": content},
]

@pytest.mark.parametrize(
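
For one concrete parametrization (`action="query"`, `triplet=False`, `context="This is your context."`), the messages that `format_input` is now expected to return look as follows. This only spells out the values asserted above, with the system prompt and action sentence expanded by hand; it is an illustration, not additional test code.

```python
# Fully expanded chat messages for action="query", triplet=False and a
# non-empty context, written out from the assertions above (illustration only).
expected_messages = [
    {
        "role": "system",
        "content": (
            "Your task is to generate a positive sentence given an anchor sentence."
            " Take into account the context given."
            " The positive sentence has to be a query for the anchor sentence."
            " You must output only one new section: `## Positive`."
        ),
    },
    {
        "role": "user",
        "content": "## Context\n\nThis is your context.\n\n## Anchor\n\nThis is a unit test\n",
    },
]
```
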