Skip to content

Commit

Permalink
Initial version of the generator
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Nov 6, 2024
1 parent 80a76b3 commit 6d63567
Showing 1 changed file with 19 additions and 43 deletions.
62 changes: 19 additions & 43 deletions src/distilabel/steps/tasks/math_shepherd/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@
# limitations under the License.

import json
import re
from typing import TYPE_CHECKING, Any, Dict, Final, Optional, Union

from jinja2 import Template
from pydantic import PositiveInt

from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.math_shepherd.utils import split_solution_steps

if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import StepColumns


SYSTEM_PROMPT = """\
You are a math tutor that helps students solve math problems by breaking them down into clear, logical steps{% if include_errors %}, but trick them by including unannounced errors in the reasoning without explaining them{% endif %}. Follow these guidelines:
You are a math tutor that helps students solve math problems by breaking them down into clear, logical steps. Follow these guidelines:
# For each step:
- Clearly explain the reasoning
Expand Down Expand Up @@ -97,53 +97,22 @@
Step 4: Therefore, the answer is $\boxed{59}$. The answer is: 59
"""


TEMPLATE = """\
This is your instruction, don't include it in your answer:
{{ instruction }}{% if example_solutions %}{{ example_solutions }}{% endif %}"""


def split_solution_steps(text):
"""
Split a step-by-step solution text into individual components.
Returns a list of steps and the final answer.
"""
# Pattern to match:
# 1. Steps starting with "Step N:" and capturing all content until the next step or answer
# 2. The final answer starting with "The answer is:"
pattern = r"Step \d+:.*?(?=Step \d+:|The answer is:|$)|The answer is:.*"

# Find all matches, strip whitespace
matches = [match.strip() for match in re.findall(pattern, text, re.DOTALL)]

return matches
TEMPLATE = """{{ instruction }}
{% if M %}Generate {{ M }} example solutions to the same problem, separated by a single `---`{% endif %}"""


class MathShepherdGenerator(Task):
system_prompt: Optional[str] = SYSTEM_PROMPT
extra_rules: Optional[str] = RULES_GSM8K
few_shots: Optional[str] = FEW_SHOTS_GSM8K
include_errors: bool = True
N: PositiveInt = 1
M: Optional[PositiveInt] = None

def load(self) -> None:
super().load()
errors = ""
self._example_solutions = ""
if self.include_errors:
errors = "\n\nInclude errors to help students learn from their mistakes in any of the steps, including the final answer."
self._example_solutions = (
f"\n\nGenerate {self.N} example solution"
if self.N == 1
else f"\n\nGenerate {self.N} example solutions to the same problem, separated by a single `---`"
)

if self.system_prompt is not None:
self.system_prompt = Template(self.system_prompt).render(
extra_rules=self.extra_rules or "",
few_shots=self.few_shots or "",
errors=errors,
include_errors=self.include_errors,
)
self._template = Template(TEMPLATE)

Expand All @@ -153,15 +122,19 @@ def inputs(self) -> "StepColumns":

@property
def outputs(self) -> "StepColumns":
return ["steps", "model_name"]
return {
"solutions": False,
"golden_solution": False,
"model_name": True,
}

def format_input(self, input: Dict[str, Any]) -> "ChatType":
messages = [
{
"role": "user",
"content": self._template.render(
instruction=input["instruction"],
example_solutions=self._example_solutions,
M=self.M,
),
}
]
Expand All @@ -172,12 +145,15 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
def format_output(
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
output_name = "solutions" if self.M else "golden_solution"
if output is None:
input.update(**{"steps": None})
input.update(**{output_name: None})
return input
if self.include_errors:
examples = [split_solution_steps(o) for o in output.split("---")]

if self.M:
solutions = [split_solution_steps(o) for o in output.split("---")]
else:
examples = [split_solution_steps(output)]
input.update(**{"steps": json.dumps(examples)})
solutions = split_solution_steps(output)

input.update(**{output_name: json.dumps(solutions)})
return input

0 comments on commit 6d63567

Please sign in to comment.