diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py
index 98974b00d..3e0d1479f 100644
--- a/src/distilabel/steps/tasks/__init__.py
+++ b/src/distilabel/steps/tasks/__init__.py
@@ -50,7 +50,11 @@
 from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair
 from distilabel.steps.tasks.structured_generation import StructuredGeneration
 from distilabel.steps.tasks.text_classification import TextClassification
-from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
+from distilabel.steps.tasks.text_generation import (
+    ChatGeneration,
+    TextGeneration,
+    TextGenerationWithCotReflection,
+)
 from distilabel.steps.tasks.typing import ChatItem, ChatType
 from distilabel.steps.tasks.ultrafeedback import UltraFeedback
 from distilabel.steps.tasks.urial import URIAL
@@ -90,6 +94,7 @@
     "TextClassification",
     "ChatGeneration",
     "TextGeneration",
+    "TextGenerationWithCotReflection",
     "ChatItem",
     "ChatType",
     "CLAIR",
diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py
index a8b2048e5..5083e5c78 100644
--- a/src/distilabel/steps/tasks/text_generation.py
+++ b/src/distilabel/steps/tasks/text_generation.py
@@ -299,6 +299,141 @@ def format_output(
 
         return {"generation": output}
 
+# Slightly modified version of: https://github.com/codelion/optillm/blob/main/optillm/cot_reflection.py
+COT_REFLECTION_SYSTEM_PROMPT = """
+You are an AI assistant that uses a Chain of Thought (CoT) approach with reflection to answer queries. Follow these steps:
+
+1. Think through the problem step by step within the <thinking> tags.
+2. Reflect on your thinking to check for any errors or improvements within the <reflection> tags.
+3. Make any necessary adjustments based on your reflection.
+4. Provide your final, concise answer within the <output> tags.
+
+Important: The <thinking> and <reflection> sections are for your internal reasoning process only.
+Do not include any part of the final answer in these sections.
+
+You can only create one <thinking> and the <reflection> blocks must be contained within it.
+The actual response to the query must be entirely contained within the <output> tags.
+You must always include </output> at the end of the generation.
+
+Use the following format for your response:
+
+```
+<thinking>
+[Your step-by-step reasoning goes here. This is your internal thought process, not the final answer.]
+<reflection>
+[Your reflection on your reasoning, checking for errors or improvements]
+</reflection>
+[Any adjustments to your thinking based on your reflection]
+</thinking>
+<output>
+[Your final, concise answer to the query. This is the only part that will be shown to the user.]
+</output>
+```
+""".lstrip()
+
+# Sometimes `LLM`s don't generate the closing `</output>` tag, so it is optional.
+# NOTE: the `(?:</output>|$)` alternation (rather than `(?:</output>)?`) is required:
+# a lazy group followed by a purely optional tag would always succeed with an empty capture.
+COT_REFLECTION_OUTPUT_REGEX = re.compile(
+    r"<thinking>([\s\S]*?)</thinking>\s*<output>([\s\S]*?)(?:</output>|$)"
+)
+
+
+class TextGenerationWithCotReflection(Task):
+    """Text generation with an `LLM` using Chain of Thought (CoT) reflection.
+
+    `TextGenerationWithCotReflection` is a `Task` that allows generating a response for
+    a given instruction using a Chain of Thought (CoT) approach with reflection. The `LLM`
+    will first think through the problem step by step, reflect on the thinking process, make
+    any necessary adjustments based on the reflection, and provide a final, concise answer.
+    This method usually helps in generating more accurate and thoughtful responses at the
+    cost of generating more tokens and being slower.
+
+    Attributes:
+        system_prompt: The system prompt to use in the generation and that will be appended
+            to the CoT Reflection system prompt. If not provided, then it will check if
+            the input row has a column named `system_prompt` and use it. If not, then no
+            system prompt will be used. Defaults to `None`.
+
+    Input columns:
+        - instruction (`str`): The instruction to generate the response.
+        - system_prompt (`str`, optional): The system prompt to use in the generation and
+            that will be appended to the CoT Reflection system prompt. Defaults to `None`.
+
+    Output columns:
+        - thinking (`str`): The step-by-step reasoning process.
+        - output (`str`): The final, concise answer.
+
+    Categories:
+        - text-generation
+
+    Examples:
+        Generate text from an instruction:
+
+        ```python
+        from distilabel.llms import InferenceEndpointsLLM
+        from distilabel.steps.tasks import TextGenerationWithCotReflection
+
+        task = TextGenerationWithCotReflection(
+            llm=InferenceEndpointsLLM(
+                model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                generation_kwargs={"temperature": 0.8, "max_new_tokens": 2048},
+            ),
+            use_cache=False,
+        )
+
+        task.load()
+
+
+        result = next(
+            task.process_applying_mappings(
+                [
+                    {
+                        "instruction": "If all cats have whiskers, and Fluffy is a cat, but Fluffy doesn't have whiskers, what can we conclude about this situation?"
+                    }
+                ]
+            )
+        )
+        # {
+        #     "instruction": "If all cats have whiskers, and Fluffy is a cat, but Fluffy doesn't have whiskers, what can we conclude about this situation?",
+        #     "thinking": "Let's break down the information provided: \n- All cats have whiskers.\n- Fluffy is a cat.\n- Fluffy doesn't have whiskers...",
+        #     "output": 'We can conclude that either the general rule "all cats have whiskers" is incorrect, ...',
+        # }
+        ```
+    """
+
+    system_prompt: Union[str, None] = None
+
+    @property
+    def inputs(self) -> "StepColumns":
+        return {"instruction": True, "system_prompt": False}
+
+    def format_input(self, input: Dict[str, Any]) -> "ChatType":
+        system_prompt = COT_REFLECTION_SYSTEM_PROMPT
+        if additional_system_prompt := input.get("system_prompt", self.system_prompt):
+            system_prompt = f"{additional_system_prompt}\n\n{system_prompt}"
+        return [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": input["instruction"]},
+        ]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        return ["thinking", "output"]
+
+    def format_output(
+        self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+    ) -> Dict[str, Any]:
+        if output is None:
+            return {"thinking": None, "output": None}
+
+        match = COT_REFLECTION_OUTPUT_REGEX.search(output)
+        if match is None:
+            return {"thinking": None, "output": None}
+
+        return {"thinking": match.group(1).strip(), "output": match.group(2).strip()}
+
+
 class ChatGeneration(Task):
     """Generates text based on a conversation.
 
diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py
index 2a6abefb2..889c41a38 100644
--- a/tests/unit/steps/tasks/test_text_generation.py
+++ b/tests/unit/steps/tasks/test_text_generation.py
@@ -18,7 +18,12 @@
 
 from distilabel.errors import DistilabelUserError
 from distilabel.pipeline.local import Pipeline
-from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
+from distilabel.steps.tasks.text_generation import (
+    COT_REFLECTION_SYSTEM_PROMPT,
+    ChatGeneration,
+    TextGeneration,
+    TextGenerationWithCotReflection,
+)
 from tests.unit.conftest import DummyAsyncLLM
 
 
@@ -175,6 +180,72 @@
         task.load()
 
 
+class TestTextGenerationWithCotReflection:
+    def test_format_input(self) -> None:
+        llm = DummyAsyncLLM()
+        task = TextGenerationWithCotReflection(name="task", llm=llm)
+        task.load()
+
+        assert task.format_input({"instruction": "test"}) == [
+            {"role": "system", "content": COT_REFLECTION_SYSTEM_PROMPT},
+            {"role": "user", "content": "test"},
+        ]
+
+    def test_format_input_with_system_prompt(self) -> None:
+        llm = DummyAsyncLLM()
+        task = TextGenerationWithCotReflection(
+            name="task", llm=llm, system_prompt="test"
+        )
+        task.load()
+
+        assert task.format_input({"instruction": "test"}) == [
+            {"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT},
+            {"role": "user", "content": "test"},
+        ]
+
+    def test_format_input_with_row_system_prompt(self) -> None:
+        llm = DummyAsyncLLM()
+        task = TextGenerationWithCotReflection(name="task", llm=llm)
+        task.load()
+
+        assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
+            {"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT},
+            {"role": "user", "content": "test"},
+        ]
+
+    def test_format_input_with_row_system_prompt_and_system_prompt(self) -> None:
+        llm = DummyAsyncLLM()
+        task = TextGenerationWithCotReflection(
+            name="task", llm=llm, system_prompt="i won't be used"
+        )
+        task.load()
+
+        assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
+            {"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT},
+            {"role": "user", "content": "test"},
+        ]
+
+    def test_format_output(self) -> None:
+        llm = DummyAsyncLLM()
+        task = TextGenerationWithCotReflection(
+            name="task", llm=llm, system_prompt="i won't be used"
+        )
+        task.load()
+
+        assert task.format_output(None) == {"thinking": None, "output": None}
+        assert task.format_output("i'm not following the output format") == {
+            "thinking": None,
+            "output": None,
+        }
+
+        assert task.format_output(
+            output="<thinking>\ni'm thinking\n<reflection>\nI'm having a reflection\n</reflection>\n</thinking>\n<output>\ni'm the output\n</output>"
+        ) == {
+            "thinking": "i'm thinking\n<reflection>\nI'm having a reflection\n</reflection>",
+            "output": "i'm the output",
+        }
+
+
 class TestChatGeneration:
     def test_format_input(self) -> None:
         pipeline = Pipeline(name="unit-test-pipeline")