diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py
index 12e39eb08e..b1ad50f5e1 100644
--- a/src/distilabel/steps/tasks/sentence_transformers.py
+++ b/src/distilabel/steps/tasks/sentence_transformers.py
@@ -43,31 +43,36 @@
 }
 
 POSITIVE_SYSTEM_PROMPT: str = (
-    "Your task is to generate a positive sentence given an anchor sentence. The positive"
+    "Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
     " sentence has to {action_sentence} the anchor sentence. You must output only one new"
     " section: `## Positive`."
 )
 
 POSITIVE_NEGATIVE_SYSTEM_PROMPT: str = (
-    "Your task is to generate a positive and a negative sentence given an anchor sentence."
+    "Your task is to generate a positive and a negative sentence given an anchor sentence.{context}"
     " The positive sentence has to {action_sentence} the anchor sentence, while the negative"
     " sentence can use similar words but must not be related to the anchor sentence. You"
     " must output only two new sections: `## Positive` and `## Negative`."
 )
 
+CONTEXT_INTRO: Final[str] = " Take into account the context given."
+
 
 class GenerateSentencePair(Task):
     """Generate a positive and negative (optionally) sentences given an anchor sentence.
 
     `GenerateSentencePair` is a pre-defined task that given an anchor sentence generates
     a positive sentence related to the anchor and optionally a negative sentence unrelated
-    to the anchor. This task is useful to generate training datasets for training embeddings
+    to the anchor. Optionally, you can give a context to guide the LLM towards more specific
+    behavior. This task is useful to generate training datasets for training embeddings
     models.
 
     Attributes:
         triplet: a flag to indicate if the task should generate a triplet of sentences
             (anchor, positive, negative). Defaults to `False`.
         action: the action to perform to generate the positive sentence.
+        context: the context to use for the generation. Can be helpful to guide the LLM
+            towards more specific behavior. Not used by default.
 
     Input columns:
         - anchor (`str`): The anchor sentence to generate the positive and negative sentences.
@@ -165,10 +170,33 @@ class GenerateSentencePair(Task):
 
         result = generate_sentence_pair.process([{"anchor": "What Game of Thrones villain would be the most likely to give you mercy?"}])
         ```
+
+        Generating queries with context (**applies to every action**):
+
+        ```python
+        from distilabel.steps.tasks import GenerateSentencePair
+        from distilabel.llms import InferenceEndpointsLLM
+
+        generate_sentence_pair = GenerateSentencePair(
+            triplet=True, # `False` to generate only positive
+            action="query",
+            context="Argilla is an open-source data curation platform for LLMs.",
+            llm=InferenceEndpointsLLM(
+                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
+                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
+            ),
+            input_batch_size=10,
+        )
+
+        generate_sentence_pair.load()
+
+        result = generate_sentence_pair.process([{"anchor": "I want to generate queries for my LLM."}])
+        ```
     """
 
     triplet: bool = False
     action: GenerationAction
+    context: str = ""
 
     def load(self) -> None:
         """Loads the Jinja2 template."""
@@ -203,11 +231,20 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
         action_sentence = GENERATION_ACTION_SENTENCES[self.action]
         system_prompt = (
             POSITIVE_NEGATIVE_SYSTEM_PROMPT if self.triplet else POSITIVE_SYSTEM_PROMPT
-        ).format(action_sentence=action_sentence)
+        ).format(
+            action_sentence=action_sentence,
+            context=CONTEXT_INTRO if self.context else "",
+        )
 
         return [
             {"role": "system", "content": system_prompt},
-            {"role": "user", "content": self._template.render(anchor=input["anchor"])},
+            {
+                "role": "user",
+                "content": self._template.render(
+                    anchor=input["anchor"],
+                    context=self.context if self.context else None,
+                ),
+            },
         ]
 
     @property
diff --git a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
index 82594f18a8..cac188e101 100644
--- a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
+++ b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2
@@ -1,3 +1,10 @@
+{% if context is not none -%}
+## Context
+
+{{ context }}
+
+{% endif -%}
+
 ## Anchor
 
 {{ anchor }}
diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py
index 3e50e7e3f1..e63b14bc83 100644
--- a/tests/unit/steps/tasks/test_sentence_transformers.py
+++ b/tests/unit/steps/tasks/test_sentence_transformers.py
@@ -16,6 +16,7 @@
 import pytest
 
 from distilabel.steps.tasks.sentence_transformers import (
+    CONTEXT_INTRO,
     POSITIVE_NEGATIVE_SYSTEM_PROMPT,
     POSITIVE_SYSTEM_PROMPT,
     GenerateSentencePair,
@@ -32,50 +33,56 @@ class TestGenerateSentencePair:
             (
                 "paraphrase",
                 True,
-                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="paraphrase", context=""
+                ),
             ),
             (
                 "paraphrase",
                 False,
-                POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase"),
+                POSITIVE_SYSTEM_PROMPT.format(action_sentence="paraphrase", context=""),
             ),
             (
                 "semantically-similar",
                 True,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be semantically similar to"
+                    action_sentence="be semantically similar to", context=""
                 ),
             ),
             (
                 "semantically-similar",
                 False,
                 POSITIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be semantically similar to"
+                    action_sentence="be semantically similar to", context=""
                 ),
             ),
             (
                 "query",
                 True,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be a query for"
+                    action_sentence="be a query for", context=""
                 ),
             ),
             (
                 "query",
                 False,
-                POSITIVE_SYSTEM_PROMPT.format(action_sentence="be a query for"),
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be a query for", context=""
+                ),
             ),
             (
                 "answer",
                 True,
                 POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
-                    action_sentence="be an answer for"
+                    action_sentence="be an answer for", context=""
                 ),
             ),
             (
                 "answer",
                 False,
-                POSITIVE_SYSTEM_PROMPT.format(action_sentence="be an answer for"),
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be an answer for", context=""
+                ),
             ),
         ],
     )
@@ -84,10 +91,89 @@ def test_format_input(
         self, action: GenerationAction, triplet: bool, system_prompt: str
     ) -> None:
         task = GenerateSentencePair(llm=DummyLLM(), action=action, triplet=triplet)
         task.load()
+        content = "## Anchor\n\nThis is a unit test\n"
+        assert task.format_input({"anchor": "This is a unit test"}) == [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": content},
+        ]
+
+    @pytest.mark.parametrize(
+        "action,triplet,system_prompt",
+        [
+            (
+                "paraphrase",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="paraphrase", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "paraphrase",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="paraphrase", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "semantically-similar",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be semantically similar to", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "semantically-similar",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be semantically similar to", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "query",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be a query for", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "query",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be a query for", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "answer",
+                True,
+                POSITIVE_NEGATIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be an answer for", context=CONTEXT_INTRO
+                ),
+            ),
+            (
+                "answer",
+                False,
+                POSITIVE_SYSTEM_PROMPT.format(
+                    action_sentence="be an answer for", context=CONTEXT_INTRO
+                ),
+            ),
+        ],
+    )
+    def test_format_input_with_context(
+        self, action: GenerationAction, triplet: bool, system_prompt: str
+    ) -> None:
+        context = "This is your context."
+        task = GenerateSentencePair(
+            llm=DummyLLM(),
+            action=action,
+            triplet=triplet,
+            context=context,
+        )
+        task.load()
+        content = f"## Context\n\n{context}\n\n## Anchor\n\nThis is a unit test\n"
         assert task.format_input({"anchor": "This is a unit test"}) == [
             {"role": "system", "content": system_prompt},
-            {"role": "user", "content": "## Anchor\n\nThis is a unit test\n"},
+            {"role": "user", "content": content},
         ]
 
     @pytest.mark.parametrize(
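
For reference, the snippet below is a standalone sketch (not part of the patch) of how `format_input` composes the chat messages once `context` is set. The `POSITIVE_SYSTEM_PROMPT`, `CONTEXT_INTRO`, and Jinja2 template strings are copied from the diff above; the `build_messages` helper, the `keep_trailing_newline=True` option, and the example values are illustrative assumptions.

```python
# Standalone sketch (not distilabel code): mirrors how the patched `format_input`
# assembles the system and user messages when a context is provided.
from jinja2 import Template

POSITIVE_SYSTEM_PROMPT = (
    "Your task is to generate a positive sentence given an anchor sentence.{context} The positive"
    " sentence has to {action_sentence} the anchor sentence. You must output only one new"
    " section: `## Positive`."
)
CONTEXT_INTRO = " Take into account the context given."

# Same content as templates/generate-sentence-pair.jinja2 after this patch.
# `keep_trailing_newline=True` preserves the final newline so the render matches
# the "...\n" strings expected in the tests.
USER_TEMPLATE = Template(
    "{% if context is not none -%}\n"
    "## Context\n\n{{ context }}\n\n"
    "{% endif -%}\n\n"
    "## Anchor\n\n{{ anchor }}\n",
    keep_trailing_newline=True,
)


def build_messages(anchor: str, context: str = "", action_sentence: str = "be a query for"):
    """Builds the (system, user) chat messages the task would send to the LLM."""
    system_prompt = POSITIVE_SYSTEM_PROMPT.format(
        action_sentence=action_sentence,
        # The system prompt only gains a generic reminder; the context text itself
        # is rendered into the user message by the template.
        context=CONTEXT_INTRO if context else "",
    )
    user_prompt = USER_TEMPLATE.render(
        anchor=anchor,
        context=context if context else None,
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


if __name__ == "__main__":
    messages = build_messages(
        anchor="I want to generate queries for my LLM.",
        context="Argilla is an open-source data curation platform for LLMs.",
    )
    for message in messages:
        print(f"[{message['role']}]\n{message['content']}\n")
```

Note that the context text only lands in the user message (as a `## Context` section before `## Anchor`), while the system prompt merely gains the fixed `CONTEXT_INTRO` reminder; this is why the new tests format the expected system prompts with `context=CONTEXT_INTRO` rather than with the context string itself.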