diff --git a/docs/snippets/technical-reference/llm/openai_generate.py b/docs/snippets/technical-reference/llm/openai_generate.py
index 7b41df7781..77fbc995e2 100644
--- a/docs/snippets/technical-reference/llm/openai_generate.py
+++ b/docs/snippets/technical-reference/llm/openai_generate.py
@@ -1,11 +1,12 @@
 import os
 
 from distilabel.llm import OpenAILLM
-from distilabel.tasks import OpenAITextGenerationTask
+from distilabel.tasks import TextGenerationTask
 
 openaillm = OpenAILLM(
     model="gpt-3.5-turbo",
-    task=OpenAITextGenerationTask(),
+    task=TextGenerationTask(),
+    prompt_format="openai",
     max_new_tokens=256,
     openai_api_key=os.environ.get("OPENAI_API_KEY"),
     temperature=0.3,
diff --git a/docs/snippets/technical-reference/tasks/generic_llama2_textgeneration.py b/docs/snippets/technical-reference/tasks/generic_llama2_textgeneration.py
deleted file mode 100644
index 49ab3ca205..0000000000
--- a/docs/snippets/technical-reference/tasks/generic_llama2_textgeneration.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from distilabel.llm import TransformersLLM
-from distilabel.tasks import Llama2TextGenerationTask
-
-# This snippet uses `TransformersLLM`, but is the same for every other `LLM`.
-generator = TransformersLLM(
-    model=...,
-    tokenizer=...,
-    task=Llama2TextGenerationTask(),
-)
diff --git a/docs/snippets/technical-reference/tasks/generic_openai_textgeneration.py b/docs/snippets/technical-reference/tasks/generic_openai_textgeneration.py
deleted file mode 100644
index 7f10affc03..0000000000
--- a/docs/snippets/technical-reference/tasks/generic_openai_textgeneration.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import os
-
-from distilabel.llm import OpenAILLM
-from distilabel.tasks import OpenAITextGenerationTask
-
-generator = OpenAILLM(
-    task=OpenAITextGenerationTask(), openai_api_key=os.getenv("OPENAI_API_KEY")
-)
diff --git a/docs/technical-reference/tasks.md b/docs/technical-reference/tasks.md
index 794365f8f0..1105f20957 100644
--- a/docs/technical-reference/tasks.md
+++ b/docs/technical-reference/tasks.md
@@ -38,26 +38,6 @@ This is the base class for *text generation*, and includes the following fields
 
 For the API reference visit [TextGenerationTask][distilabel.tasks.text_generation.base.TextGenerationTask].
 
-### Llama2TextGenerationTask
-
-This class inherits from the `TextGenerationTask` and it's specially prepared to deal with prompts in the form of the *Llama2* model, so it should be the go to task for `LLMs` intented for text generation that were trained using this prompt format. The specific prompt formats can be found in the source code of the [Prompt][distilabel.tasks.prompt.Prompt] class.
-
-```python
---8<-- "docs/snippets/technical-reference/tasks/generic_llama2_textgeneration.py"
-```
-
-For the API reference visit [Llama2TextGenerationTask][distilabel.tasks.text_generation.llama.Llama2TextGenerationTask].
-
-### OpenAITextGenerationTask
-
-The OpenAI task for text generation is similar to the `Llama2TextGenerationTask`, but with the specific prompt format expected by the *chat completion* task from OpenAI.
-
-```python
---8<-- "docs/snippets/technical-reference/tasks/generic_openai_textgeneration.py"
-```
-
-For the API reference visit [OpenAITextGenerationTask][distilabel.tasks.text_generation.openai.OpenAITextGenerationTask].
-
 ### SelfInstructTask
 
 The task specially designed to build the prompts following the Self-Instruct paper: [SELF-INSTRUCT: Aligning Language Models
diff --git a/examples/inference-endpoints-llm-custom-task.py b/examples/inference-endpoints-llm-custom-task.py
index 5099bdc81d..d215ae831f 100644
--- a/examples/inference-endpoints-llm-custom-task.py
+++ b/examples/inference-endpoints-llm-custom-task.py
@@ -16,15 +16,15 @@
 from typing import Dict
 
 from distilabel.llm import InferenceEndpointsLLM
-from distilabel.tasks import Llama2TextGenerationTask, Prompt
+from distilabel.tasks import Prompt, TextGenerationTask
 
 
-class Llama2QuestionAnsweringTask(Llama2TextGenerationTask):
-    def generate_prompt(self, question: str) -> str:
+class Llama2QuestionAnsweringTask(TextGenerationTask):
+    def generate_prompt(self, question: str) -> Prompt:
         return Prompt(
             system_prompt=self.system_prompt,
             formatted_prompt=question,
-        ).format_as("llama2")  # type: ignore
+        )
 
     def parse_output(self, output: str) -> Dict[str, str]:
         return {"answer": output.strip()}
@@ -47,6 +47,7 @@ def output_args_names(self) -> list[str]:
     endpoint_namespace=os.getenv("HF_NAMESPACE"),  # type: ignore
     token=os.getenv("HF_TOKEN", None),
     task=Llama2QuestionAnsweringTask(),
+    prompt_format="llama2",
 )
 print(llm.generate([{"question": "What's the capital of Spain?"}]))
 # Output: [
diff --git a/examples/pipeline-accelerate-and-openai.py b/examples/pipeline-accelerate-and-openai.py
index 43cfd4c3ea..f4d5a077eb 100644
--- a/examples/pipeline-accelerate-and-openai.py
+++ b/examples/pipeline-accelerate-and-openai.py
@@ -31,7 +31,7 @@
 def get_current_device() -> int:
     """Get the current device. For GPU we return the local process index to enable multiple
     GPU training."""
-    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
+    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"  # type: ignore
 
 
 if __name__ == "__main__":
diff --git a/examples/pipeline-fn-ultrafeedback.py b/examples/pipeline-fn-ultrafeedback.py
index 24e614c383..257a46909d 100644
--- a/examples/pipeline-fn-ultrafeedback.py
+++ b/examples/pipeline-fn-ultrafeedback.py
@@ -18,7 +18,7 @@
 from datasets import load_dataset
 from distilabel.llm import InferenceEndpointsLLM
 from distilabel.pipeline import pipeline
-from distilabel.tasks import Llama2TextGenerationTask
+from distilabel.tasks import TextGenerationTask
 
 if __name__ == "__main__":
     dataset = (
@@ -33,7 +33,8 @@
         generator=InferenceEndpointsLLM(
             endpoint_name=os.getenv("HF_INFERENCE_ENDPOINT_NAME"),  # type: ignore
             endpoint_namespace=os.getenv("HF_NAMESPACE", None),
-            task=Llama2TextGenerationTask(),
+            task=TextGenerationTask(),
+            prompt_format="llama2",
             max_new_tokens=256,
             num_threads=2,
             temperature=0.3,
diff --git a/examples/pipeline-pool-llm.py b/examples/pipeline-pool-llm.py
index 868acf0cc8..56b2389a80 100644
--- a/examples/pipeline-pool-llm.py
+++ b/examples/pipeline-pool-llm.py
@@ -64,7 +64,7 @@ def load_openai(task):
     )
 
     dataset = pipeline.generate(
-        dataset=dataset,
+        dataset=dataset,  # type: ignore
         num_generations=3,
         batch_size=5,
     )
diff --git a/examples/pipeline-preference-dataset-llmpool.py b/examples/pipeline-preference-dataset-llmpool.py
index d069bed2d3..a0904c7aa9 100644
--- a/examples/pipeline-preference-dataset-llmpool.py
+++ b/examples/pipeline-preference-dataset-llmpool.py
@@ -86,7 +86,7 @@ def load_neural_chat(task: Task) -> LLM:
     )
 
 
-def load_gpt_4(task: UltraFeedbackTask) -> LLM:
+def load_gpt_4(task: Task) -> LLM:
     from distilabel.llm import OpenAILLM
 
     return OpenAILLM(
@@ -108,7 +108,8 @@ def load_gpt_4(task: UltraFeedbackTask) -> LLM:
             ]
         ),
         labeller=ProcessLLM(
-            task=UltraFeedbackTask.for_instruction_following(), load_llm_fn=load_gpt_4
+            task=UltraFeedbackTask.for_instruction_following(),
+            load_llm_fn=load_gpt_4,
         ),
     )
 
@@ -119,7 +120,10 @@ def load_gpt_4(task: UltraFeedbackTask) -> LLM:
     )
 
     dataset = pipeline.generate(
-        dataset=dataset, num_generations=2, batch_size=10, display_progress_bar=True
+        dataset=dataset,  # type: ignore
+        num_generations=2,
+        batch_size=10,
+        display_progress_bar=True,  # type: ignore
     )
 
     rg_argilla = dataset.to_argilla()
diff --git a/src/distilabel/pipeline.py b/src/distilabel/pipeline.py
index 4adb7e9b60..f92363f25f 100644
--- a/src/distilabel/pipeline.py
+++ b/src/distilabel/pipeline.py
@@ -66,16 +66,16 @@ def __init__(
             ValueError: if no LLM is provided.
 
         Examples:
-            >>> from distilabel.llm.huggingface import TransformersLLM
-            >>> from distilabel.llm.openai_ import OpenAILLM
-            >>> from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
-            >>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
+            >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+            >>> from distilabel.llm import OpenAILLM, TransformersLLM
+            >>> from distilabel.tasks import TextGenerationTask, UltraFeedbackTask
             >>> from distilabel.pipeline import Pipeline
 
             >>> generator = TransformersLLM(
-            ...     model="meta-llama/Llama-2-7b-chat-hf",
-            ...     tokenizer="meta-llama/Llama-2-7b-chat-hf",
-            ...     task=Llama2TextGenerationTask(),
+            ...     model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
+            ...     tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
+            ...     task=TextGenerationTask(),
+            ...     prompt_format="llama2",
             ... )
             >>> labeller = OpenAILLM(
             ...     model="gpt-3.5-turbo",
@@ -532,49 +531,8 @@ def _generate(  # noqa: C901
         display_progress_bar: bool = False,
     ) -> CustomDataset:
         """Generates the outputs for the given dataset using the LLMs provided to the
-        `Pipeline`.
-
-        Args:
-            dataset (Dataset): the dataset to be used for generation.
-            num_generations (int, optional): the number of generations to be performed
-                for each input. Defaults to `1`.
-            batch_size (int, optional): the batch size to be used for generation. Defaults
-                to `1`.
-            shuffle_before_labelling (bool, optional): whether to shuffle the generations
-                before labelling or not. This is useful to avoid the labelling LLM to be
-                biased by the order of the generations. Defaults to `True`.
-            enable_checkpoints (bool, optional): whether to enable checkpoints or not.
-                Defaults to `True`.
-            display_progress_bar (bool, optional): whether to display the progress bar
-                or not. Defaults to `False`.
+        `Pipeline`."""
-
-        Returns:
-            CustomDataset: the final dataset.
-
-        Raises:
-            RuntimeError: if the `Pipeline` fails during the generation or labelling steps.
-            UserWarning: if the `Pipeline` fails during the generation or labelling steps
-                and `enable_checkpoints` is set to `False`.
-
-        Examples:
-            >>> from distilabel.llm.huggingface import TransformersLLM
-            >>> from distilabel.llm.openai_ import OpenAILLM
-            >>> from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask
-            >>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
-            >>> from distilabel.pipeline import Pipeline
-
-            >>> generator = TransformersLLM(
-            ...     model="meta-llama/Llama-2-7b-chat-hf",
-            ...     tokenizer="meta-llama/Llama-2-7b-chat-hf",
-            ...     task=Llama2TextGenerationTask(),
-            ... )
-            >>> labeller = OpenAILLM(
-            ...     model="gpt-3.5-turbo",
model="gpt-3.5-turbo", - ... task=UltraFeedbackTask.for_text_quality(), - ... ) - >>> pipeline = Pipeline(generator=generator, labeller=labeller) - >>> dataset = pipeline.generate(dataset=..., num_generations=1, batch_size=1) - """ if ( self.labeller is not None and self.generator is not None @@ -739,16 +697,15 @@ def generate( `enable_checkpoints` is set to `False`. Examples: - >>> from distilabel.llm.huggingface import TransformersLLM - >>> from distilabel.llm.openai_ import OpenAILLM - >>> from distilabel.tasks.preference.ultrafeedback import UltraFeedbackTask - >>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask + >>> from transformers import AutoModelForCaualLM, AutoTokenizer + >>> from distilabel.llm import OpenAILLM, TransformersLLM + >>> from distilabel.tasks import TextGenerationTask, UltraFeedbackTask >>> from distilabel.pipeline import Pipeline - >>> generator = TransformersLLM( - ... model="meta-llama/Llama-2-7b-chat-hf", - ... tokenizer="meta-llama/Llama-2-7b-chat-hf", - ... task=Llama2TextGenerationTask(), + ... model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"), + ... tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"), + ... task=TextGenerationTask(), + ... prompt_format="llama2", ... ) >>> labeller = OpenAILLM( ... model="gpt-3.5-turbo", @@ -808,20 +765,22 @@ def pipeline( Pipeline: the `Pipeline` instance. Examples: - >>> from distilabel.llm.huggingface import TransformersLLM - >>> from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from distilabel.llm import TransformersLLM + >>> from distilabel.tasks import TextGenerationTask >>> from distilabel.pipeline import pipeline - >>> generator = TransformersLLM( - ... model="meta-llama/Llama-2-7b-chat-hf", - ... tokenizer="meta-llama/Llama-2-7b-chat-hf", - ... task=Llama2TextGenerationTask(), + ... model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"), + ... tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"), + ... task=TextGenerationTask(), + ... prompt_format="llama2", ... ) >>> pipeline = pipeline( ... task="preference", ... subtask="text-quality", ... generator=generator, ... 
+        >>> dataset = pipeline.generate(dataset=..., num_generations=1, batch_size=1)
     """
     if task == "preference":
         if labeller is None:
diff --git a/src/distilabel/tasks/__init__.py b/src/distilabel/tasks/__init__.py
index 8723e2c422..d5942793d2 100644
--- a/src/distilabel/tasks/__init__.py
+++ b/src/distilabel/tasks/__init__.py
@@ -21,8 +21,6 @@
 from distilabel.tasks.preference.ultrajudge import UltraJudgeTask
 from distilabel.tasks.prompt import Prompt
 from distilabel.tasks.text_generation.base import TextGenerationTask
-from distilabel.tasks.text_generation.llama import Llama2TextGenerationTask
-from distilabel.tasks.text_generation.openai import OpenAITextGenerationTask
 from distilabel.tasks.text_generation.self_instruct import SelfInstructTask
 
 __all__ = [
@@ -35,7 +33,5 @@
     "UltraJudgeTask",
     "Prompt",
     "TextGenerationTask",
-    "OpenAITextGenerationTask",
-    "Llama2TextGenerationTask",
     "SelfInstructTask",
 ]
diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py
index 48f1c7f9d1..02b6f4f11c 100644
--- a/src/distilabel/tasks/base.py
+++ b/src/distilabel/tasks/base.py
@@ -71,7 +71,7 @@ def template(self) -> "Template":
         return Template(open(self.__jinja2_template__).read())
 
     @abstractmethod
-    def generate_prompt(self, **kwargs: Any) -> Union[Prompt, Any]:
+    def generate_prompt(self, **kwargs: Any) -> Prompt:
         pass
 
     @abstractmethod
diff --git a/src/distilabel/tasks/critique/prometheus.py b/src/distilabel/tasks/critique/prometheus.py
index 980da395e2..95a4910d22 100644
--- a/src/distilabel/tasks/critique/prometheus.py
+++ b/src/distilabel/tasks/critique/prometheus.py
@@ -38,7 +38,7 @@ def input_args_names(self) -> List[str]:
 
     def generate_prompt(
         self, input: str, generations: str, ref_completion: str, **_: Any
-    ) -> str:
+    ) -> Prompt:
         render_kwargs = {
             "instruction": input,
             "completion": generations,
@@ -49,7 +49,7 @@ def generate_prompt(
         return Prompt(
             system_prompt=self.system_prompt,
             formatted_prompt=self.template.render(**render_kwargs),
-        ).format_as(format="llama2")  # type: ignore
+        )
 
     def parse_output(self, output: str) -> CritiqueTaskOutput:  # type: ignore
         """Parses the output of the model into the desired format."""
diff --git a/src/distilabel/tasks/critique/ultracm.py b/src/distilabel/tasks/critique/ultracm.py
index 4e12e30c81..098d51e888 100644
--- a/src/distilabel/tasks/critique/ultracm.py
+++ b/src/distilabel/tasks/critique/ultracm.py
@@ -18,6 +18,7 @@
 
 from distilabel.tasks.base import get_template
 from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput
+from distilabel.tasks.prompt import Prompt
 
 _ULTRACM_TEMPLATE = get_template("ultracm.jinja2")
 
@@ -32,12 +33,15 @@ class UltraCMTask(CritiqueTask):
         " the user's questions."
     )
 
-    def generate_prompt(self, input: str, generations: str, **_: Any) -> str:
+    def generate_prompt(self, input: str, generations: str, **_: Any) -> Prompt:
         render_kwargs = {
             "instruction": input,
             "completion": generations,
         }
-        return f"{self.system_prompt}\nUser: {self.template.render(**render_kwargs)}\nAssistant: ### Feedback\nOverall Score: "
+        return Prompt(
+            system_prompt=self.system_prompt,
+            formatted_prompt=f"User: {self.template.render(**render_kwargs)}\nAssistant: ### Feedback\nOverall Score: ",
+        )
 
     def parse_output(self, output: str) -> CritiqueTaskOutput:  # type: ignore
         """Parses the output of the model into the desired format."""
diff --git a/src/distilabel/tasks/text_generation/llama.py b/src/distilabel/tasks/text_generation/llama.py
deleted file mode 100644
index 5ffe1d1953..0000000000
--- a/src/distilabel/tasks/text_generation/llama.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2023-present, Argilla, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-from distilabel.tasks.prompt import Prompt
-from distilabel.tasks.text_generation.base import TextGenerationTask
-
-
-class Llama2TextGenerationTask(TextGenerationTask):
-    """A `TextGenerationTask` for the Llama2 model.
-
-    Args:
-        system_prompt (str, optional): the system prompt to be used. Defaults to `None`.
-        principles (Dict[str, List[str]], optional): the principles to be used for the system prompt.
-            Defaults to `None`.
-        principles_distribution (Union[Dict[str, float], Literal["balanced"], None], optional): the
-            distribution of principles to be used for the system prompt. Defaults to `None`.
-    """
-
-    def generate_prompt(self, input: str, **_: Any) -> str:
-        """Generates a prompt for the Llama2 model.
-
-        Args:
-            input (str): the input to be used for the prompt.
-
-        Returns:
-            str: the generated prompt.
-
-        Examples:
-            >>> from distilabel.tasks.text_generation import Llama2TextGenerationTask
-            >>> task = Llama2TextGenerationTask(system_prompt="You are a helpful assistant.")
-            >>> task.generate_prompt("What are the first 5 Fibonacci numbers?")
-            '[INST] <<SYS>>\nYou are a helpful assistant.<</SYS>>\n\nWhat are the first 5 Fibonacci numbers? [/INST]'
-        """
-        return Prompt(
-            system_prompt=self.system_prompt,
-            formatted_prompt=input,
-        ).format_as("llama2")  # type: ignore
diff --git a/src/distilabel/tasks/text_generation/openai.py b/src/distilabel/tasks/text_generation/openai.py
deleted file mode 100644
index 4e51441c66..0000000000
--- a/src/distilabel/tasks/text_generation/openai.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2023-present, Argilla, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING, Any, List
-
-from distilabel.tasks.prompt import Prompt
-from distilabel.tasks.text_generation.base import TextGenerationTask
-
-if TYPE_CHECKING:
-    from distilabel.tasks.prompt import ChatCompletion
-
-
-class OpenAITextGenerationTask(TextGenerationTask):
-    """A `TextGenerationTask` for any chat-completion OpenAI model.
-
-    Args:
-        system_prompt (str, optional): the system prompt to be used. Defaults to `None`.
-        principles (Dict[str, List[str]], optional): the principles to be used for the system prompt.
-            Defaults to `None`.
-        principles_distribution (Union[Dict[str, float], Literal["balanced"], None], optional): the
-            distribution of principles to be used for the system prompt. Defaults to `None`.
-    """
-
-    def generate_prompt(self, input: str, **_: Any) -> List["ChatCompletion"]:
-        """Generates a prompt for any chat-completion OpenAI model.
-
-        Args:
-            input (str): the input to be used for the prompt.
-
-        Returns:
-            List[ChatCompletion]: the generated prompt.
-
-        Examples:
-            >>> from distilabel.tasks.text_generation import OpenAITextGenerationTask
-            >>> task = OpenAITextGenerationTask(system_prompt="You are a helpful assistant.")
-            >>> task.generate_prompt("What are the first 5 Fibonacci numbers?")
-            [
-                {'role': 'system', 'content': 'You are a helpful assistant.'},
-                {'role': 'user', 'content': 'What are the first 5 Fibonacci numbers?'},
-            ]
-        """
-        return Prompt(
-            system_prompt=self.system_prompt,
-            formatted_prompt=input,
-        ).format_as("openai")  # type: ignore
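
The pattern for downstream code, as the updated snippets and examples in this diff suggest: tasks now return a plain `Prompt`, and the model-specific formatting is selected on the `LLM` through `prompt_format` ("llama2", "openai", ...). Below is a minimal sketch of what the deleted `generic_llama2_textgeneration.py` snippet would look like after this change; the `model`/`tokenizer` ellipses are placeholders carried over from the original snippet, not working values.

from distilabel.llm import TransformersLLM
from distilabel.tasks import TextGenerationTask

# Formerly `task=Llama2TextGenerationTask()`: the task no longer formats the
# prompt itself; the LLM applies the chosen `prompt_format` to the `Prompt`
# the task returns. `model` and `tokenizer` are placeholders, as in the
# deleted snippet, and the same pattern applies to every other `LLM`.
generator = TransformersLLM(
    model=...,
    tokenizer=...,
    task=TextGenerationTask(),
    prompt_format="llama2",
)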