From 53019bc3b974d956ec9b1ab8422a5c160306fca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 3 Jun 2024 19:15:44 +0200 Subject: [PATCH] Update task to use system prompt --- .../steps/tasks/sentence_transformers.py | 30 ++++++++++++++----- .../templates/generate-sentence-pair.jinja2 | 8 ----- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py index 14f6ad0db4..63157d6d47 100644 --- a/src/distilabel/steps/tasks/sentence_transformers.py +++ b/src/distilabel/steps/tasks/sentence_transformers.py @@ -31,11 +31,10 @@ GenerationAction = Literal["paraphrase", "semantically-similar", "query"] POSITIVE_NEGATIVE_PAIR_REGEX = re.compile( - r"## Positive\s+(.*?)\s+(?:## Negative\s+(.*?)\s*)?$", + r"## Positive\s+(.*?)(?:\s+## Negative\s+(.*?))?\s*$", re.DOTALL, ) - GENERATION_ACTION_SENTENCES: Final[Dict[GenerationAction, str]] = { "paraphrase": "paraphrase", "semantically-similar": "be semantically similar to", @@ -89,14 +88,27 @@ def inputs(self) -> List[str]: return ["anchor"] def format_input(self, input: Dict[str, Any]) -> "ChatType": + action_sentence = GENERATION_ACTION_SENTENCES[self.action] return [ { - "role": "user", - "content": self._template.render( - anchor=input["anchor"], - action=GENERATION_ACTION_SENTENCES[self.action], + "role": "system", + "content": ( + "Your task is to generate a positive and a negative sentence given" + f" an anchor sentence. The positive sentence has to {action_sentence}" + " the anchor sentence, while the negative sentence has to do the opposite." + " You must output only two new sections: `## Positive` and `## Negative`." + ) + if self.triplet + else ( + "Your task is to generate a positive sentence given an anchor sentence." + f" The positive sentence has to {action_sentence} the anchor sentence." + " You must output only one new section: `## Positive`." ), - } + }, + { + "role": "user", + "content": self._template.render(anchor=input["anchor"]), + }, ] @property @@ -120,7 +132,9 @@ def format_output( if self.triplet: return { "positive": groups[0].strip(), - "negative": groups[1].strip() if len(groups) > 1 else None, + "negative": groups[1].strip() + if len(groups) > 1 and groups[1] is not None + else None, } return {"positive": groups[0].strip()} diff --git a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 index 6b745949f6..82594f18a8 100644 --- a/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 +++ b/src/distilabel/steps/tasks/templates/generate-sentence-pair.jinja2 @@ -1,12 +1,4 @@ -# Task Description -{% if triplet %} -Your task is to generate a positive and a negative sentence given an anchor sentence. The positive sentence has to {{ action }} the anchor sentence, while the negative sentence is the contrary. You must output only two new sections: `## Positive` and `## Negative`. -{% else %} -Your task is to generate a positive sentence given an anchor sentence. The positive sentence has to {{ action }} the anchor sentence. You must output only one new section: `## Positive`. -{% endif %} - ## Anchor {{ anchor }} -## Positive