Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TextGenerationWithCotReflection task #1031

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/distilabel/steps/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@
from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from distilabel.steps.tasks.text_classification import TextClassification
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.text_generation import (
ChatGeneration,
TextGeneration,
TextGenerationWithCotReflection,
)
from distilabel.steps.tasks.typing import ChatItem, ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback
from distilabel.steps.tasks.urial import URIAL
Expand Down Expand Up @@ -90,6 +94,7 @@
"TextClassification",
"ChatGeneration",
"TextGeneration",
"TextGenerationWithCotReflection",
"ChatItem",
"ChatType",
"CLAIR",
Expand Down
131 changes: 131 additions & 0 deletions src/distilabel/steps/tasks/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,137 @@ def format_output(
return {"generation": output}


# Slighty modified version: https://github.com/codelion/optillm/blob/main/optillm/cot_reflection.py
COT_REFLECTION_SYSTEM_PROMPT = """
You are an AI assistant that uses a Chain of Thought (CoT) approach with reflection to answer queries. Follow these steps:

1. Think through the problem step by step within the <thinking> tags.
2. Reflect on your thinking to check for any errors or improvements within the <reflection> tags.
3. Make any necessary adjustments based on your reflection.
4. Provide your final, concise answer within the <output> tags.

Important: The <thinking> and <reflection> sections are for your internal reasoning process only.
Do not include any part of the final answer in these sections.
You can only create one <thinking> and the <reflection> blocks must be contained within it.
The actual response to the query must be entirely contained within the <output> tags.
You must always include </output> at the end of the generation.

Use the following format for your response:

```
<thinking>
[Your step-by-step reasoning goes here. This is your internal thought process, not the final answer.]
<reflection>
[Your reflection on your reasoning, checking for errors or improvements]
</reflection>
[Any adjustments to your thinking based on your reflection]
</thinking>
<output>
[Your final, concise answer to the query. This is the only part that will be shown to the user.]
</output>
```
""".lstrip()

# Sometimes `LLM`s doesn't generate the `</output>` that's why it's optional
COT_REFLECTION_OUTPUT_REGEX = re.compile(
r"<thinking>([\s\S]*?)</thinking>\s*<output>([\s\S]*?)(?:</output>)?"
)


class TextGenerationWithCotReflection(Task):
"""Text generation with an `LLM` using Chain of Thought (CoT) reflection.

`TextGenerationWithCotReflection` is a `Task` that allows generating a response for
a given instruction using a Chain of Thought (CoT) approach with reflection. The `LLM`
will first think through the problem step by step, reflect on the thinking process, make
any necessary adjustments based on the reflection, and provide a final, concise answer.
This method usually helps in generating more accurate and thoughtful responses at the
cost of generating more tokens and being slower.

Attributes:
system_prompt: The system prompt to use in the generation and that will be appended
to the CoT Reflection system prompt. If not provided, then it will check if
the input row has a column named `system_prompt` and use it. If not, then no
system prompt will be used. Defaults to `None`.

Input columns:
- instruction (`str`): The instruction to generate the response.
- system_prompt (`str`, optional): The system prompt to use in the generation and
that will be appended to the CoT Reflection system prompt. Defaults to `None`.

Output columns:
- thinking (`str`): The step-by-step reasoning process.

Categories:
- text-generation

Examples:
Generate text from an instruction:

```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGenerationWithCotReflection

task = TextGenerationWithCotReflection(
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={"temperature": 0.8, "max_new_tokens": 2048},
),
use_cache=False,
)

task.load()


result = next(
task.process_applying_mappings(
[
{
"instruction": "If all cats have whiskers, and Fluffy is a cat, but Fluffy doesn't have whiskers, what can we conclude about this situation?"
}
]
)
)
# {
# "instruction": "If all cats have whiskers, and Fluffy is a cat, but Fluffy doesn't have whiskers, what can we conclude about this situation?",
# "thinking": "Let's break down the information provided: \n- All cats have whiskers.\n- Fluffy is a cat.\n- Fluffy doesn't have whiskers...",
# "output": 'We can conclude that either the general rule "all cats have whiskers" is incorrect, ...',
# }
```
"""

system_prompt: Union[str, None] = None

@property
def inputs(self) -> "StepColumns":
return {"instruction": True, "system_prompt": False}

def format_input(self, input: Dict[str, Any]) -> "ChatType":
system_prompt = COT_REFLECTION_SYSTEM_PROMPT
if additional_system_prompt := input.get("system_prompt", self.system_prompt):
system_prompt = f"{additional_system_prompt}\n\n{system_prompt}"
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": input["instruction"]},
]

@property
def outputs(self) -> "StepColumns":
return ["thinking", "output"]

def format_output(
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
if output is None:
return {"thinking": None, "output": None}

match = COT_REFLECTION_OUTPUT_REGEX.search(output)
if match is None:
return {"thinking": None, "output": None}

return {"thinking": match.group(1).strip(), "output": match.group(2).strip()}


class ChatGeneration(Task):
"""Generates text based on a conversation.

Expand Down
73 changes: 72 additions & 1 deletion tests/unit/steps/tasks/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@

from distilabel.errors import DistilabelUserError
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration
from distilabel.steps.tasks.text_generation import (
COT_REFLECTION_SYSTEM_PROMPT,
ChatGeneration,
TextGeneration,
TextGenerationWithCotReflection,
)
from tests.unit.conftest import DummyAsyncLLM


Expand Down Expand Up @@ -175,6 +180,72 @@ def test_format_input_custom_columns_expected_errors(
task.load()


class TestTextGenerationWithCotReflection:
def test_format_input(self) -> None:
llm = DummyAsyncLLM()
task = TextGenerationWithCotReflection(name="task", llm=llm)
task.load()

assert task.format_input({"instruction": "test"}) == [
{"role": "system", "content": COT_REFLECTION_SYSTEM_PROMPT},
{"role": "user", "content": "test"},
]

def test_format_input_with_system_prompt(self) -> None:
llm = DummyAsyncLLM()
task = TextGenerationWithCotReflection(
name="task", llm=llm, system_prompt="test"
)
task.load()

assert task.format_input({"instruction": "test"}) == [
{"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT},
{"role": "user", "content": "test"},
]

def test_format_input_with_row_system_prompt(self) -> None:
llm = DummyAsyncLLM()
task = TextGenerationWithCotReflection(name="task", llm=llm)
task.load()

assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
{"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT},
{"role": "user", "content": "test"},
]

def test_format_input_with_row_system_prompt_and_system_prompt(self) -> None:
llm = DummyAsyncLLM()
task = TextGenerationWithCotReflection(
name="task", llm=llm, system_prompt="i won't be used"
)
task.load()

assert task.format_input({"instruction": "test", "system_prompt": "test"}) == [
{"role": "system", "content": "test\n\n" + COT_REFLECTION_SYSTEM_PROMPT},
{"role": "user", "content": "test"},
]

def test_format_ouptut(self) -> None:
llm = DummyAsyncLLM()
task = TextGenerationWithCotReflection(
name="task", llm=llm, system_prompt="i won't be used"
)
task.load()

assert task.format_output(None) == {"thinking": None, "output": None}
assert task.format_output("i'm not following the output format") == {
"thinking": None,
"output": None,
}

assert task.format_output(
output="<thinking>\ni'm thinking\n<reflection>\nI'm having a reflection\n</reflection>\n</thinking>\n<output>\ni'm the output\n</output>"
) == {
"thinking": "i'm thinking\n<reflection>\nI'm having a reflection\n</reflection>",
"output": "i'm the output",
}


class TestChatGeneration:
def test_format_input(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
Expand Down
Loading