
Commit

Update generate_prompt in Task subclasses to always return `Prompt` (#199)

* Remove LLM-specific `Task` implementations

* Fix `generate_prompt` to always return `Prompt` instead of `Any`

* Update `examples/*.py`

* Fix `Pipeline` docstrings

Includes fixes within the examples themselves, as well as updates w.r.t. the recent removal of the `Llama2TextGenerationTask`

* Update `docs/` upon `{Llama2,OpenAI}TextGenerationTask` removal

* Remove extra line-break in `formatted_prompt` of `UltraCMTask.generate_prompt`

Since the default prompt format will be applied within `Prompt` if no format is specified, and it implies joining both the `system_prompt` and `formatted_prompt` with a line-break, the leading line-break is not needed within the `formatted_prompt` (see the sketch after this list)

* Update `task` arg type-hint to `Task`
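
A minimal sketch of the resulting pattern, condensed from `examples/inference-endpoints-llm-custom-task.py` in this diff (the class name and the `question` field are illustrative):

```python
from typing import Dict

from distilabel.tasks import Prompt, TextGenerationTask


class QuestionAnsweringTask(TextGenerationTask):
    def generate_prompt(self, question: str) -> Prompt:
        # `generate_prompt` now always returns a `Prompt`; the concrete format
        # (e.g. "llama2" or "openai") is applied later by the LLM through its
        # `prompt_format` argument, rather than via `format_as` inside the task.
        return Prompt(
            system_prompt=self.system_prompt,
            formatted_prompt=question,
        )

    def parse_output(self, output: str) -> Dict[str, str]:
        return {"answer": output.strip()}
```

Moving the formatting out of the task and onto the `LLM` via `prompt_format` lets the same task be reused with models that expect different prompt templates.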
alvarobartt authored Dec 27, 2023
1 parent fe16a4b commit 0407bd0
Showing 16 changed files with 52 additions and 229 deletions.
5 changes: 3 additions & 2 deletions docs/snippets/technical-reference/llm/openai_generate.py
@@ -1,11 +1,12 @@
import os

from distilabel.llm import OpenAILLM
from distilabel.tasks import OpenAITextGenerationTask
from distilabel.tasks import TextGenerationTask

openaillm = OpenAILLM(
model="gpt-3.5-turbo",
task=OpenAITextGenerationTask(),
task=TextGenerationTask(),
prompt_format="openai",
max_new_tokens=256,
openai_api_key=os.environ.get("OPENAI_API_KEY"),
temperature=0.3,

This file was deleted.

This file was deleted.

20 changes: 0 additions & 20 deletions docs/technical-reference/tasks.md
@@ -38,26 +38,6 @@ This is the base class for *text generation*, and includes the following fields

For the API reference visit [TextGenerationTask][distilabel.tasks.text_generation.base.TextGenerationTask].

### Llama2TextGenerationTask

This class inherits from the `TextGenerationTask` and is specially prepared to deal with prompts in the format of the *Llama2* model, so it should be the go-to task for `LLMs` intended for text generation that were trained using this prompt format. The specific prompt formats can be found in the source code of the [Prompt][distilabel.tasks.prompt.Prompt] class.

```python
--8<-- "docs/snippets/technical-reference/tasks/generic_llama2_textgeneration.py"
```

For the API reference visit [Llama2TextGenerationTask][distilabel.tasks.text_generation.llama.Llama2TextGenerationTask].

### OpenAITextGenerationTask

The OpenAI task for text generation is similar to the `Llama2TextGenerationTask`, but with the specific prompt format expected by the *chat completion* task from OpenAI.

```python
--8<-- "docs/snippets/technical-reference/tasks/generic_openai_textgeneration.py"
```

For the API reference visit [OpenAITextGenerationTask][distilabel.tasks.text_generation.openai.OpenAITextGenerationTask].

### SelfInstructTask

The task specially designed to build the prompts following the Self-Instruct paper: [SELF-INSTRUCT: Aligning Language Models
9 changes: 5 additions & 4 deletions examples/inference-endpoints-llm-custom-task.py
@@ -16,15 +16,15 @@
from typing import Dict

from distilabel.llm import InferenceEndpointsLLM
from distilabel.tasks import Llama2TextGenerationTask, Prompt
from distilabel.tasks import Prompt, TextGenerationTask


class Llama2QuestionAnsweringTask(Llama2TextGenerationTask):
def generate_prompt(self, question: str) -> str:
class Llama2QuestionAnsweringTask(TextGenerationTask):
def generate_prompt(self, question: str) -> Prompt:
return Prompt(
system_prompt=self.system_prompt,
formatted_prompt=question,
).format_as("llama2") # type: ignore
)

def parse_output(self, output: str) -> Dict[str, str]:
return {"answer": output.strip()}
@@ -47,6 +47,7 @@ def output_args_names(self) -> list[str]:
endpoint_namespace=os.getenv("HF_NAMESPACE"), # type: ignore
token=os.getenv("HF_TOKEN", None),
task=Llama2QuestionAnsweringTask(),
prompt_format="llama2",
)
print(llm.generate([{"question": "What's the capital of Spain?"}]))
# Output: [
2 changes: 1 addition & 1 deletion examples/pipeline-accelerate-and-openai.py
@@ -31,7 +31,7 @@

def get_current_device() -> int:
"""Get the current device. For GPU we return the local process index to enable multiple GPU training."""
return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" # type: ignore


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions examples/pipeline-fn-ultrafeedback.py
@@ -18,7 +18,7 @@
from datasets import load_dataset
from distilabel.llm import InferenceEndpointsLLM
from distilabel.pipeline import pipeline
from distilabel.tasks import Llama2TextGenerationTask
from distilabel.tasks import TextGenerationTask

if __name__ == "__main__":
dataset = (
@@ -33,7 +33,8 @@
generator=InferenceEndpointsLLM(
endpoint_name=os.getenv("HF_INFERENCE_ENDPOINT_NAME"), # type: ignore
endpoint_namespace=os.getenv("HF_NAMESPACE", None),
task=Llama2TextGenerationTask(),
task=TextGenerationTask(),
prompt_format="llama2",
max_new_tokens=256,
num_threads=2,
temperature=0.3,
2 changes: 1 addition & 1 deletion examples/pipeline-pool-llm.py
@@ -64,7 +64,7 @@ def load_openai(task):
)

dataset = pipeline.generate(
dataset=dataset,
dataset=dataset, # type: ignore
num_generations=3,
batch_size=5,
)
10 changes: 7 additions & 3 deletions examples/pipeline-preference-dataset-llmpool.py
@@ -86,7 +86,7 @@ def load_neural_chat(task: Task) -> LLM:
)


def load_gpt_4(task: UltraFeedbackTask) -> LLM:
def load_gpt_4(task: Task) -> LLM:
from distilabel.llm import OpenAILLM

return OpenAILLM(
@@ -108,7 +108,8 @@ def load_gpt_4(task: UltraFeedbackTask) -> LLM:
]
),
labeller=ProcessLLM(
task=UltraFeedbackTask.for_instruction_following(), load_llm_fn=load_gpt_4
task=UltraFeedbackTask.for_instruction_following(),
load_llm_fn=load_gpt_4,
),
)

@@ -119,7 +120,10 @@ def load_gpt_4(task: UltraFeedbackTask) -> LLM:
)

dataset = pipeline.generate(
dataset=dataset, num_generations=2, batch_size=10, display_progress_bar=True
dataset=dataset, # type: ignore
num_generations=2,
batch_size=10,
display_progress_bar=True, # type: ignore
)

rg_argilla = dataset.to_argilla()
87 changes: 23 additions & 64 deletions src/distilabel/pipeline.py
@@ -66,16 +66,15 @@ def __init__(
ValueError: if no LLM is provided.
Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
>>> from distilabel.llm.openai_ import OpenAILLM
>>> from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
>>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from distilabel.llm import OpenAILLM, TransformersLLM
>>> from distilabel.tasks import TextGenerationTask, UltraFeedbackTask
>>> from distilabel.pipeline import Pipeline
>>> generator = TransformersLLM(
... model="meta-llama/Llama-2-7b-chat-hf",
... tokenizer="meta-llama/Llama-2-7b-chat-hf",
... task=Llama2TextGenerationTask(),
... model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
... tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
... task=TextGenerationTask(),
... prompt_format="llama2",
... )
>>> labeller = OpenAILLM(
... model="gpt-3.5-turbo",
@@ -532,49 +531,8 @@ def _generate( # noqa: C901
display_progress_bar: bool = False,
) -> CustomDataset:
"""Generates the outputs for the given dataset using the LLMs provided to the
`Pipeline`.
Args:
dataset (Dataset): the dataset to be used for generation.
num_generations (int, optional): the number of generations to be performed
for each input. Defaults to `1`.
batch_size (int, optional): the batch size to be used for generation. Defaults
to `1`.
shuffle_before_labelling (bool, optional): whether to shuffle the generations
before labelling or not. This is useful to avoid the labelling LLM to be
biased by the order of the generations. Defaults to `True`.
enable_checkpoints (bool, optional): whether to enable checkpoints or not.
Defaults to `True`.
display_progress_bar (bool, optional): whether to display the progress bar
or not. Defaults to `False`.
`Pipeline`."""

Returns:
CustomDataset: the final dataset.
Raises:
RuntimeError: if the `Pipeline` fails during the generation or labelling steps.
UserWarning: if the `Pipeline` fails during the generation or labelling steps
and `enable_checkpoints` is set to `False`.
Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
>>> from distilabel.llm.openai_ import OpenAILLM
>>> from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
>>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
>>> from distilabel.pipeline import Pipeline
>>> generator = TransformersLLM(
... model="meta-llama/Llama-2-7b-chat-hf",
... tokenizer="meta-llama/Llama-2-7b-chat-hf",
... task=Llama2TextGenerationTask(),
... )
>>> labeller = OpenAILLM(
... model="gpt-3.5-turbo",
... task=UltraFeedbackTask.for_text_quality(),
... )
>>> pipeline = Pipeline(generator=generator, labeller=labeller)
>>> dataset = pipeline.generate(dataset=..., num_generations=1, batch_size=1)
"""
if (
self.labeller is not None
and self.generator is not None
@@ -739,16 +697,15 @@ def generate(
`enable_checkpoints` is set to `False`.
Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
>>> from distilabel.llm.openai_ import OpenAILLM
>>> from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
>>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from distilabel.llm import OpenAILLM, TransformersLLM
>>> from distilabel.tasks import TextGenerationTask, UltraFeedbackTask
>>> from distilabel.pipeline import Pipeline
>>> generator = TransformersLLM(
... model="meta-llama/Llama-2-7b-chat-hf",
... tokenizer="meta-llama/Llama-2-7b-chat-hf",
... task=Llama2TextGenerationTask(),
... model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
... tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
... task=TextGenerationTask(),
... prompt_format="llama2",
... )
>>> labeller = OpenAILLM(
... model="gpt-3.5-turbo",
@@ -808,20 +765,22 @@ def pipeline(
Pipeline: the `Pipeline` instance.
Examples:
>>> from distilabel.llm.huggingface import TransformersLLM
>>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from distilabel.llm import TransformersLLM
>>> from distilabel.tasks import TextGenerationTask
>>> from distilabel.pipeline import pipeline
>>> generator = TransformersLLM(
... model="meta-llama/Llama-2-7b-chat-hf",
... tokenizer="meta-llama/Llama-2-7b-chat-hf",
... task=Llama2TextGenerationTask(),
... model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
... tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
... task=TextGenerationTask(),
... prompt_format="llama2",
... )
>>> pipeline = pipeline(
... task="preference",
... subtask="text-quality",
... generator=generator,
... )
>>> dataset = pipeline.generate(dataset=..., num_generations=1, batch_size=1)
"""
if task == "preference":
if labeller is None:
4 changes: 0 additions & 4 deletions src/distilabel/tasks/__init__.py
@@ -21,8 +21,6 @@
from distilabel.tasks.preference.ultrajudge import UltraJudgeTask
from distilabel.tasks.prompt import Prompt
from distilabel.tasks.text_generation.base import TextGenerationTask
from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
from distilabel.tasks.text_generation.openai import OpenAITextGenerationTask
from distilabel.tasks.text_generation.self_instruct import SelfInstructTask

__all__ = [
@@ -35,7 +33,5 @@
"UltraJudgeTask",
"Prompt",
"TextGenerationTask",
"OpenAITextGenerationTask",
"Llama2TextGenerationTask",
"SelfInstructTask",
]
2 changes: 1 addition & 1 deletion src/distilabel/tasks/base.py
@@ -71,7 +71,7 @@ def template(self) -> "Template":
return Template(open(self.__jinja2_template__).read())

@abstractmethod
def generate_prompt(self, **kwargs: Any) -> Union[Prompt, Any]:
def generate_prompt(self, **kwargs: Any) -> Prompt:
pass

@abstractmethod
4 changes: 2 additions & 2 deletions src/distilabel/tasks/critique/prometheus.py
@@ -38,7 +38,7 @@ def input_args_names(self) -> List[str]:

def generate_prompt(
self, input: str, generations: str, ref_completion: str, **_: Any
) -> str:
) -> Prompt:
render_kwargs = {
"instruction": input,
"completion": generations,
@@ -49,7 +49,7 @@ def generate_prompt(
return Prompt(
system_prompt=self.system_prompt,
formatted_prompt=self.template.render(**render_kwargs),
).format_as(format="llama2") # type: ignore
)

def parse_output(self, output: str) -> CritiqueTaskOutput: # type: ignore
"""Parses the output of the model into the desired format."""
8 changes: 6 additions & 2 deletions src/distilabel/tasks/critique/ultracm.py
@@ -18,6 +18,7 @@

from distilabel.tasks.base import get_template
from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput
from distilabel.tasks.prompt import Prompt

_ULTRACM_TEMPLATE = get_template("ultracm.jinja2")

@@ -32,12 +33,15 @@ class UltraCMTask(CritiqueTask):
" the user's questions.</s>"
)

def generate_prompt(self, input: str, generations: str, **_: Any) -> str:
def generate_prompt(self, input: str, generations: str, **_: Any) -> Prompt:
render_kwargs = {
"instruction": input,
"completion": generations,
}
return f"{self.system_prompt}\nUser: {self.template.render(**render_kwargs)}</s>\nAssistant: ### Feedback\nOverall Score: "
return Prompt(
system_prompt=self.system_prompt,
formatted_prompt=f"User: {self.template.render(**render_kwargs)}</s>\nAssistant: ### Feedback\nOverall Score: ",
)

def parse_output(self, output: str) -> CritiqueTaskOutput: # type: ignore
"""Parses the output of the model into the desired format."""
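For reference, a minimal sketch of the `UltraCMTask` prompt after this change (placeholder values; the default-format behaviour is as described in the commit message above, not re-verified against the `Prompt` source):

```python
from distilabel.tasks import Prompt

# With no explicit `prompt_format`, the default format joins `system_prompt`
# and `formatted_prompt` with a single line-break, which is why the leading
# "\n" was dropped from `formatted_prompt` in `UltraCMTask.generate_prompt`.
prompt = Prompt(
    system_prompt="<UltraCM system prompt>",
    formatted_prompt="User: <rendered ultracm.jinja2 template></s>\nAssistant: ### Feedback\nOverall Score: ",
)
```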