Skip to content

Commit

Permalink
Add task decorator (#1028)
Browse files Browse the repository at this point in the history
* Add `task` decorator

* Add `task` decorator unit tests

* Add `task` decorator docs

* Fix typing
  • Loading branch information
gabrielmbmb authored Oct 10, 2024
1 parent 95a418c commit 0666dd4
Show file tree
Hide file tree
Showing 6 changed files with 478 additions and 29 deletions.
76 changes: 52 additions & 24 deletions docs/sections/how_to_guides/basic/task/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,35 +211,63 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe

- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.

```python
from typing import Any, Dict, List, Union, TYPE_CHECKING
=== "Inherit from `Task`"

When using the `Task` class inheritance method for creating a custom task, we can also optionally override the `Task.process` method to define a more complex processing logic involving an `LLM`, as the default one just calls the `LLM.generate` method once previously formatting the input and subsequently formatting the output. For example, [EvolInstruct][distilabel.steps.tasks.EvolInstruct] task overrides this method to call the `LLM.generate` multiple times (one for each evolution).

from distilabel.steps.tasks.base import Task
```python
from typing import Any, Dict, List, Union, TYPE_CHECKING

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.tasks import Task

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns
from distilabel.steps.tasks.typing import ChatType

class MyCustomTask(Task):
@property
def inputs(self) -> "StepColumns":
return ["input_field"]

def format_input(self, input: Dict[str, Any]) -> "ChatType":
return [
{
"role": "user",
"content": input["input_field"],
},
]
class MyCustomTask(Task):
@property
def inputs(self) -> "StepColumns":
return ["input_field"]

@property
def outputs(self) -> "StepColumns":
return ["output_field", "model_name"]
def format_input(self, input: Dict[str, Any]) -> "ChatType":
return [
{
"role": "user",
"content": input["input_field"],
},
]

def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
@property
def outputs(self) -> "StepColumns":
return ["output_field", "model_name"]

def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
return {"output_field": output}
```

=== "Using the `@task` decorator"

If your task just needs a system prompt, a user message template and a way to format the output given by the `LLM`, then you can use the `@task` decorator to avoid writing too much boilerplate code.

```python
from typing import Any, Dict, Union
from distilabel.steps.tasks import task


@task(inputs=["input_field"], outputs=["output_field"])
def MyCustomTask(output: Union[str, None], input: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
"""
---
system_prompt: |
My custom system prompt

user_message_template: |
My custom user message template: {input_field}
---
"""
# Format the `LLM` output here
return {"output_field": output}
```
```
5 changes: 2 additions & 3 deletions src/distilabel/steps/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
TYPE_CHECKING,
Any,
Callable,
List,
Literal,
Type,
Union,
Expand Down Expand Up @@ -175,10 +174,10 @@ def decorator(func: ProcessingFunc) -> Type["_Step"]:
**runtime_parameters, # type: ignore
)

def inputs_property(self) -> List[str]:
def inputs_property(self) -> "StepColumns":
return inputs

def outputs_property(self) -> List[str]:
def outputs_property(self) -> "StepColumns":
return outputs

def process(
Expand Down
2 changes: 2 additions & 0 deletions src/distilabel/steps/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from distilabel.steps.tasks.base import GeneratorTask, Task
from distilabel.steps.tasks.clair import CLAIR
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
from distilabel.steps.tasks.decorator import task
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.steps.tasks.evol_instruct.evol_complexity.base import EvolComplexity
from distilabel.steps.tasks.evol_instruct.evol_complexity.generator import (
Expand Down Expand Up @@ -62,6 +63,7 @@
"APIGenGenerator",
"APIGenSemanticChecker",
"ComplexityScorer",
"task",
"EvolInstruct",
"EvolComplexity",
"EvolComplexityGenerator",
Expand Down
220 changes: 220 additions & 0 deletions src/distilabel/steps/tasks/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import re
from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Tuple, Type, Union

import yaml

from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import FormattedInput

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns


TaskFormattingOutputFunc = Callable[..., Dict[str, Any]]


def task(
inputs: Union["StepColumns", None] = None,
outputs: Union["StepColumns", None] = None,
) -> Callable[..., Type["Task"]]:
"""Creates a `Task` from a formatting output function.
Args:
inputs: a list containing the name of the inputs columns/keys or a dictionary
where the keys are the columns and the values are booleans indicating whether
the column is required or not, that are required by the step. If not provided
the default will be an empty list `[]` and it will be assumed that the step
doesn't need any specific columns. Defaults to `None`.
outputs: a list containing the name of the outputs columns/keys or a dictionary
where the keys are the columns and the values are booleans indicating whether
the column will be generated or not. If not provided the default will be an
empty list `[]` and it will be assumed that the step doesn't need any specific
columns. Defaults to `None`.
"""

inputs = inputs or []
outputs = outputs or []

def decorator(func: TaskFormattingOutputFunc) -> Type["Task"]:
doc = inspect.getdoc(func)
if doc is None:
raise DistilabelUserError(
"When using the `task` decorator, including a docstring in the formatting"
" function is mandatory. The docstring must follow the format described"
" in the documentation.",
page="",
)

system_prompt, user_message_template = _parse_docstring(doc)
_validate_templates(inputs, system_prompt, user_message_template)

def inputs_property(self) -> "StepColumns":
return inputs

def outputs_property(self) -> "StepColumns":
return outputs

def format_input(self, input: Dict[str, Any]) -> "FormattedInput":
return [
{"role": "system", "content": system_prompt.format(**input)},
{"role": "user", "content": user_message_template.format(**input)},
]

def format_output(
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
return func(output, input)

return type(
func.__name__,
(Task,),
{
"inputs": property(inputs_property),
"outputs": property(outputs_property),
"__module__": func.__module__,
"format_input": format_input,
"format_output": format_output,
},
)

return decorator


_SYSTEM_PROMPT_YAML_KEY: Final[str] = "system_prompt"
_USER_MESSAGE_TEMPLATE_YAML_KEY: Final[str] = "user_message_template"
_DOCSTRING_FORMATTING_FUNCTION_ERROR: Final[str] = (
"Formatting function decorated with `task` doesn't follow the expected format. Please,"
" check the documentation and update the function to include a docstring with the expected"
" format."
)


def _parse_docstring(docstring: str) -> Tuple[str, str]:
"""Parses the docstring of the formatting function that was built using the `task`
decorator.
Args:
docstring: the docstring of the formatting function.
Returns:
A tuple containing the system prompt and the user message template.
Raises:
DistilabelUserError: if the docstring doesn't follow the expected format or if
the expected keys are missing.
"""
parts = docstring.split("---")

if len(parts) != 3:
raise DistilabelUserError(
_DOCSTRING_FORMATTING_FUNCTION_ERROR,
page="",
)

yaml_content = parts[1]

try:
parsed_yaml = yaml.safe_load(yaml_content)
if not isinstance(parsed_yaml, dict):
raise DistilabelUserError(
_DOCSTRING_FORMATTING_FUNCTION_ERROR,
page="",
)

system_prompt = parsed_yaml.get(_SYSTEM_PROMPT_YAML_KEY)
user_template = parsed_yaml.get(_USER_MESSAGE_TEMPLATE_YAML_KEY)
if system_prompt is None or user_template is None:
raise DistilabelUserError(
"The formatting function decorated with `task` must include both the `system_prompt`"
" and `user_message_template` keys in the docstring. Please, check the documentation"
" and update the docstring of the formatting function to include the expected"
" keys.",
page="",
)

return system_prompt.strip(), user_template.strip()

except yaml.YAMLError as e:
raise DistilabelUserError(_DOCSTRING_FORMATTING_FUNCTION_ERROR, page="") from e


TEMPLATE_PLACEHOLDERS_REGEX = re.compile(r"\{(\w+)\}")


def _validate_templates(
inputs: "StepColumns", system_prompt: str, user_message_template: str
) -> None:
"""Validates the system prompt and user message template to ensure that they only
contain the allowed placeholders i.e. the columns/keys that are provided as inputs.
Args:
inputs: the list of inputs columns/keys.
system_prompt: the system prompt.
user_message_template: the user message template.
Raises:
DistilabelUserError: if the system prompt or the user message template contain
invalid placeholders.
"""
list_inputs = list(inputs.keys()) if isinstance(inputs, dict) else inputs

valid_system_prompt, invalid_system_prompt_placeholders = _validate_template(
system_prompt, list_inputs
)
if not valid_system_prompt:
raise DistilabelUserError(
f"The formatting function decorated with `task` includes invalid placeholders"
f" in the extracted `system_prompt` from the function docstring. Valid placeholders"
f" are: {list_inputs}, but the following placeholders were found: {invalid_system_prompt_placeholders}."
f" Please, update the `system_prompt` to only include the valid placeholders.",
page="",
)

valid_user_message_template, invalid_user_message_template_placeholders = (
_validate_template(user_message_template, list_inputs)
)
if not valid_user_message_template:
raise DistilabelUserError(
f"The formatting function decorated with `task` includes invalid placeholders"
f" in the extracted `user_message_template` from the function docstring. Valid"
f" placeholders are: {list_inputs}, but the following placeholders were found:"
f" {invalid_user_message_template_placeholders}. Please, update the `system_prompt`"
" to only include the valid placeholders.",
page="",
)


def _validate_template(
template: str, allowed_placeholders: List[str]
) -> Tuple[bool, set[str]]:
"""Validates that the template only contains the allowed placeholders.
Args:
template: the template to validate.
allowed_placeholders: the list of allowed placeholders.
Returns:
A tuple containing a boolean indicating if the template is valid and a set
with the invalid placeholders.
"""
placeholders = set(TEMPLATE_PLACEHOLDERS_REGEX.findall(template))
allowed_placeholders_set = set(allowed_placeholders)
are_valid = placeholders.issubset(allowed_placeholders_set)
invalid_placeholders = placeholders - allowed_placeholders_set
return are_valid, invalid_placeholders
4 changes: 2 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def model_name(self) -> str:

def generate( # type: ignore
self, inputs: "FormattedInput", num_generations: int = 1
) -> "GenerateOutput":
return ["output" for _ in range(num_generations)]
) -> List["GenerateOutput"]:
return [["output" for _ in range(num_generations)]]


class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
Expand Down
Loading

0 comments on commit 0666dd4

Please sign in to comment.