From e1b4c8b766ef8767f49fbb0ab5bbd8ded1190c28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 9 Jul 2024 10:12:53 +0200 Subject: [PATCH 01/30] Move `CudaDevicePlacementMixin` to new module --- src/distilabel/llms/__init__.py | 2 +- src/distilabel/llms/huggingface/transformers.py | 2 +- src/distilabel/llms/mixins/__init__.py | 14 ++++++++++++++ .../{mixins.py => mixins/cuda_device_placement.py} | 0 src/distilabel/llms/vllm.py | 7 +++---- src/distilabel/pipeline/local.py | 2 +- src/distilabel/steps/tasks/typing.py | 2 +- tests/unit/llms/mixins/__init__.py | 14 ++++++++++++++ .../test_cuda_device_placement.py} | 2 +- 9 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 src/distilabel/llms/mixins/__init__.py rename src/distilabel/llms/{mixins.py => mixins/cuda_device_placement.py} (100%) create mode 100644 tests/unit/llms/mixins/__init__.py rename tests/unit/llms/{test_mixins.py => mixins/test_cuda_device_placement.py} (97%) diff --git a/src/distilabel/llms/__init__.py b/src/distilabel/llms/__init__.py index 3e50ddefaa..5f6ab9abbf 100644 --- a/src/distilabel/llms/__init__.py +++ b/src/distilabel/llms/__init__.py @@ -22,7 +22,7 @@ from distilabel.llms.litellm import LiteLLM from distilabel.llms.llamacpp import LlamaCppLLM from distilabel.llms.mistral import MistralLLM -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.moa import MixtureOfAgentsLLM from distilabel.llms.ollama import OllamaLLM from distilabel.llms.openai import OpenAILLM diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 6e7736d006..08029ee978 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -19,7 +19,7 @@ from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput diff --git a/src/distilabel/llms/mixins/__init__.py b/src/distilabel/llms/mixins/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/llms/mixins/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/llms/mixins.py b/src/distilabel/llms/mixins/cuda_device_placement.py similarity index 100% rename from src/distilabel/llms/mixins.py rename to src/distilabel/llms/mixins/cuda_device_placement.py diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index ee124c7dfa..42c6b5aeea 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -30,7 +30,7 @@ from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -304,14 +304,13 @@ def generate( # type: ignore if extra_sampling_params is None: extra_sampling_params = {} structured_output = None - needs_sorting = False if isinstance(inputs[0], tuple): prepared_batches, sorted_indices = self._prepare_batches(inputs) - needs_sorting = True else: # Simulate a batch without the structured output content prepared_batches = [([self.prepare_input(input) for input in inputs], None)] + sorted_indices = None # In case we have a single structured output for the dataset, we can logits_processors = None @@ -348,7 +347,7 @@ def generate( # type: ignore # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) - if needs_sorting: + if sorted_indices is not None: batched_outputs = _sort_batches( batched_outputs, sorted_indices, num_generations=num_generations ) diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 874757a6de..0e89436c5b 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -21,7 +21,7 @@ import tblib from distilabel.distiset import create_distiset -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.pipeline.base import ( BasePipeline, ) diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py index 4f92cdc057..ae9fd9519e 100644 --- a/src/distilabel/steps/tasks/typing.py +++ b/src/distilabel/steps/tasks/typing.py @@ -38,7 +38,7 @@ class OutlinesStructuredOutputType(TypedDict, total=False): as obtained from `model_to_schema(BaseModel)`, if "regex", it should be a regex pattern as a string. """ - whitespace_pattern: Optional[Union[str, List[str]]] = None + whitespace_pattern: Optional[Union[str, List[str]]] """If "json" corresponds to a string or a list of strings with a pattern (doesn't impact string literals). For example, to allow only a single space or newline with diff --git a/tests/unit/llms/mixins/__init__.py b/tests/unit/llms/mixins/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/llms/mixins/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/llms/test_mixins.py b/tests/unit/llms/mixins/test_cuda_device_placement.py similarity index 97% rename from tests/unit/llms/test_mixins.py rename to tests/unit/llms/mixins/test_cuda_device_placement.py index c0c7b10671..80690bbf41 100644 --- a/tests/unit/llms/test_mixins.py +++ b/tests/unit/llms/mixins/test_cuda_device_placement.py @@ -19,7 +19,7 @@ import pytest from distilabel.llms.base import LLM -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin if TYPE_CHECKING: from distilabel.steps.tasks.typing import ChatType From d557542e30c9b0a61cc2de69c065b150cfa8a0b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 9 Jul 2024 17:32:28 +0200 Subject: [PATCH 02/30] Initial work for implementing Magpie --- .../llms/huggingface/transformers.py | 10 ++- src/distilabel/llms/mixins/magpie.py | 86 +++++++++++++++++++ src/distilabel/steps/tasks/__init__.py | 2 + .../steps/tasks/evol_instruct/generator.py | 1 + src/distilabel/steps/tasks/magpie/__init__.py | 14 +++ src/distilabel/steps/tasks/magpie/base.py | 86 +++++++++++++++++++ .../steps/tasks/magpie/generator.py | 31 +++++++ 7 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 src/distilabel/llms/mixins/magpie.py create mode 100644 src/distilabel/steps/tasks/magpie/__init__.py create mode 100644 src/distilabel/steps/tasks/magpie/base.py create mode 100644 src/distilabel/steps/tasks/magpie/generator.py diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 08029ee978..200992845e 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -20,6 +20,7 @@ from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput @@ -32,7 +33,7 @@ from distilabel.llms.typing import HiddenState -class TransformersLLM(LLM, CudaDevicePlacementMixin): +class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): """Hugging Face `transformers` library LLM implementation using the text generation pipeline. @@ -106,6 +107,8 @@ class TransformersLLM(LLM, CudaDevicePlacementMixin): def load(self) -> None: """Loads the model and tokenizer and creates the text generation pipeline. In addition, it will configure the tokenizer chat template.""" + MagpieChatTemplateMixin.load(self) + if self.device == "cuda": CudaDevicePlacementMixin.load(self) @@ -160,6 +163,9 @@ def prepare_input(self, input: "StandardInput") -> str: """Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt. """ + if (prepared_input := super().prepare_input(input=input)) is not None: + return prepared_input + return self._pipeline.tokenizer.apply_chat_template( # type: ignore input, # type: ignore tokenize=False, @@ -176,6 +182,7 @@ def generate( # type: ignore repetition_penalty: float = 1.1, top_p: float = 1.0, top_k: int = 0, + stop_sequence: Union[str, List[str], None] = None, do_sample: bool = True, ) -> List[GenerateOutput]: """Generates `num_generations` responses for each input using the text generation @@ -206,6 +213,7 @@ def generate( # type: ignore repetition_penalty=repetition_penalty, top_p=top_p, top_k=top_k, + stop_sequence=stop_sequence, do_sample=do_sample, num_return_sequences=num_generations, prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py new file mode 100644 index 0000000000..f8f456eab3 --- /dev/null +++ b/src/distilabel/llms/mixins/magpie.py @@ -0,0 +1,86 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import TYPE_CHECKING, Dict, Final, Literal, Union + +import jinja2 +from pydantic import BaseModel, PrivateAttr +from typing_extensions import TypedDict + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import StandardInput + + +MagpieAvailableTemplates = Literal[ + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-70B-Instruct", +] + + +class MagpieChatTemplate(TypedDict): + chat_template: str + generate_instruction: str + generate_instruction_with_system_prompt: str + + +MAGPIE_TEMPLATES: Final[Dict["MagpieAvailableTemplates", "MagpieChatTemplate"]] = { + "meta-llama/Meta-Llama-3-8B-Instruct": { + "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "generate_instruction": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "generate_instruction_with_system_prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + }, + "meta-llama/Meta-Llama-3-70B-Instruct": { + "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "generate_instruction": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "generate_instruction_with_system_prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + }, +} + + +class MagpieChatTemplateMixin(BaseModel): + model: str + use_magpie_template: bool = False + template: Union[MagpieChatTemplate, None] = None + + _chat_template: jinja2.Template = PrivateAttr(default=None) + + def load(self) -> None: + if not self.use_magpie_template: + return + + if self.template is None: + self.template = MAGPIE_TEMPLATES[ + "meta-llama/Meta-Llama-3-8B-Instruct" + ].copy() + self._chat_template = jinja2.Template(self.template["chat_template"]) + + def prepare_input(self, input: "StandardInput") -> Union[str, None]: + if not self.use_magpie_template: + return None + + assert self.template + + if len(input) == 0: + return self.template["generate_instruction"] + + if len(input) == 1 and input[0]["role"] == "system": + template = copy.copy( + self.template["generate_instruction_with_system_prompt"] + ) + return template.format(system_prompt=input[0]["content"]) + + # TODO: case there are messages + + return None diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index b2456d7824..2534b5d6ba 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -35,6 +35,7 @@ from distilabel.steps.tasks.instruction_backtranslation import ( InstructionBacktranslation, ) +from distilabel.steps.tasks.magpie.base import Magpie from distilabel.steps.tasks.pair_rm import PairRM from distilabel.steps.tasks.prometheus_eval import PrometheusEval from distilabel.steps.tasks.quality_scorer import QualityScorer @@ -64,6 +65,7 @@ "GenerateTextRetrievalData", "MonolingualTripletGenerator", "InstructionBacktranslation", + "Magpie", "PairRM", "PrometheusEval", "QualityScorer", diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index bc8c5d2eff..0ab8b9ab9b 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -280,6 +280,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore instructions = [] mutation_no = 0 + # TODO: update to take into account `offset` iter_no = 0 while len(instructions) < self.num_instructions: prompts = self._apply_random_mutation(iter_no=iter_no) diff --git a/src/distilabel/steps/tasks/magpie/__init__.py b/src/distilabel/steps/tasks/magpie/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/steps/tasks/magpie/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py new file mode 100644 index 0000000000..947e85b307 --- /dev/null +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -0,0 +1,86 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from pydantic import Field, PositiveInt + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps.base import StepInput +from distilabel.steps.tasks.base import Task + +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput + from distilabel.steps.tasks.typing import ChatType, FormattedInput + from distilabel.steps.typing import StepOutput + + +class Magpie(Task): + n_turns: Optional[RuntimeParameter[PositiveInt]] = Field( + default=None, + description="If provided, then the number of turns to generate for the conversation.", + ) + + @property + def inputs(self) -> List[str]: + return [] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + return [] + + @property + def outputs(self) -> List[str]: + if self.n_turns is None: + return ["instruction"] + + return ["conversation"] + + def format_output( + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, + ) -> Dict[str, Any]: + return {} + + def _prepare_inputs_for_instruction_generation( + self, inputs: List[Dict[str, Any]] + ) -> List["FormattedInput"]: + return [ + [{"role": "system", "content": input["system_prompt"]}] + if "system_prompt" in input + else [] + for input in inputs + ] + + def _format_instruction_generation_output(self, outputs: List["GenerateOutput"]): + instructions = [] + for output in outputs: + if output[0] is None: + instructions.append({"instruction": None}) + else: + parts = output[0].split("\n") + instructions.append({"instruction": parts[0]}) + return instructions + + def process(self, inputs: StepInput) -> "StepOutput": + inputs_for_instruction_generation = ( + self._prepare_inputs_for_instruction_generation(inputs=inputs) + ) + outputs = self.llm.generate( + inputs=inputs_for_instruction_generation, num_generations=1 + ) + instructions = self._format_instruction_generation_output(outputs=outputs) + + if self.n_turns is None: + yield instructions diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py new file mode 100644 index 0000000000..bf897138bb --- /dev/null +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -0,0 +1,31 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import TYPE_CHECKING +from pydantic import Field + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps.tasks.base import GeneratorTask + +if TYPE_CHECKING: + from distilabel.steps.typing import GeneratorStepOutput + + +class MagpieInstructionGenerator(GeneratorTask): + num_instructions: RuntimeParameter[int] = Field( + default=None, description="The number of instructions to generate." + ) + + def process(self, offset: int = 0) -> "GeneratorStepOutput": + pass From 755f7ecd84bdd0de5201b30cf7a1ef634677eb5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 9 Jul 2024 19:49:03 +0200 Subject: [PATCH 03/30] Simplify magpie implementation --- .../llms/huggingface/transformers.py | 10 +-- src/distilabel/llms/mixins/magpie.py | 74 +++-------------- src/distilabel/steps/tasks/magpie/base.py | 80 ++++++++++++------- 3 files changed, 66 insertions(+), 98 deletions(-) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 200992845e..bde07342a4 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -107,8 +107,6 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): def load(self) -> None: """Loads the model and tokenizer and creates the text generation pipeline. In addition, it will configure the tokenizer chat template.""" - MagpieChatTemplateMixin.load(self) - if self.device == "cuda": CudaDevicePlacementMixin.load(self) @@ -163,14 +161,12 @@ def prepare_input(self, input: "StandardInput") -> str: """Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt. """ - if (prepared_input := super().prepare_input(input=input)) is not None: - return prepared_input - - return self._pipeline.tokenizer.apply_chat_template( # type: ignore + prompt = self._pipeline.tokenizer.apply_chat_template( # type: ignore input, # type: ignore tokenize=False, add_generation_prompt=True, ) + return super().apply_pre_query_template(prompt, input) @validate_call def generate( # type: ignore @@ -182,7 +178,6 @@ def generate( # type: ignore repetition_penalty: float = 1.1, top_p: float = 1.0, top_k: int = 0, - stop_sequence: Union[str, List[str], None] = None, do_sample: bool = True, ) -> List[GenerateOutput]: """Generates `num_generations` responses for each input using the text generation @@ -213,7 +208,6 @@ def generate( # type: ignore repetition_penalty=repetition_penalty, top_p=top_p, top_k=top_k, - stop_sequence=stop_sequence, do_sample=do_sample, num_return_sequences=num_generations, prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py index f8f456eab3..94fb7655fd 100644 --- a/src/distilabel/llms/mixins/magpie.py +++ b/src/distilabel/llms/mixins/magpie.py @@ -12,75 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -from typing import TYPE_CHECKING, Dict, Final, Literal, Union +from typing import TYPE_CHECKING -import jinja2 -from pydantic import BaseModel, PrivateAttr -from typing_extensions import TypedDict +from pydantic import BaseModel if TYPE_CHECKING: from distilabel.steps.tasks.typing import StandardInput -MagpieAvailableTemplates = Literal[ - "meta-llama/Meta-Llama-3-8B-Instruct", - "meta-llama/Meta-Llama-3-70B-Instruct", -] - - -class MagpieChatTemplate(TypedDict): - chat_template: str - generate_instruction: str - generate_instruction_with_system_prompt: str - - -MAGPIE_TEMPLATES: Final[Dict["MagpieAvailableTemplates", "MagpieChatTemplate"]] = { - "meta-llama/Meta-Llama-3-8B-Instruct": { - "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", - "generate_instruction": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", - "generate_instruction_with_system_prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", - }, - "meta-llama/Meta-Llama-3-70B-Instruct": { - "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", - "generate_instruction": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", - "generate_instruction_with_system_prompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", - }, -} - - class MagpieChatTemplateMixin(BaseModel): model: str use_magpie_template: bool = False - template: Union[MagpieChatTemplate, None] = None - - _chat_template: jinja2.Template = PrivateAttr(default=None) - - def load(self) -> None: - if not self.use_magpie_template: - return - - if self.template is None: - self.template = MAGPIE_TEMPLATES[ - "meta-llama/Meta-Llama-3-8B-Instruct" - ].copy() - self._chat_template = jinja2.Template(self.template["chat_template"]) - - def prepare_input(self, input: "StandardInput") -> Union[str, None]: - if not self.use_magpie_template: - return None - - assert self.template - - if len(input) == 0: - return self.template["generate_instruction"] - - if len(input) == 1 and input[0]["role"] == "system": - template = copy.copy( - self.template["generate_instruction_with_system_prompt"] - ) - return template.format(system_prompt=input[0]["content"]) - - # TODO: case there are messages - - return None + # TODO: harcoded to llama 3 + pre_query_template: str = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + ) + + def apply_pre_query_template(self, prompt: str, input: "StandardInput") -> str: + if not self.use_magpie_template or input[-1]["role"] == "assistant": + return prompt + return prompt + self.pre_query_template diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 947e85b307..68c7cd6c2d 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -12,26 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from pydantic import Field, PositiveInt +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.llms.typing import GenerateOutput from distilabel.steps.tasks.typing import ChatType, FormattedInput from distilabel.steps.typing import StepOutput class Magpie(Task): - n_turns: Optional[RuntimeParameter[PositiveInt]] = Field( - default=None, - description="If provided, then the number of turns to generate for the conversation.", + n_turns: RuntimeParameter[PositiveInt] = Field( + default=1, + description="The number of turns to generate for the conversation.", ) + def model_post_init(self, _: Any) -> None: + """Checks that the provided `LLM` uses the `MagpieChatTemplateMixin`.""" + super().load() + + if not isinstance(self.llm, MagpieChatTemplateMixin): + raise ValueError( + f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`." + f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin." + ) + + self.llm.use_magpie_template = True + @property def inputs(self) -> List[str]: return [] @@ -41,9 +53,6 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType": @property def outputs(self) -> List[str]: - if self.n_turns is None: - return ["instruction"] - return ["conversation"] def format_output( @@ -63,24 +72,41 @@ def _prepare_inputs_for_instruction_generation( for input in inputs ] - def _format_instruction_generation_output(self, outputs: List["GenerateOutput"]): - instructions = [] - for output in outputs: - if output[0] is None: - instructions.append({"instruction": None}) - else: - parts = output[0].split("\n") - instructions.append({"instruction": parts[0]}) - return instructions + def _append_messages_to_conversations( + self, role: str, messages: List[str], conversations: List["ChatType"] + ) -> List["ChatType"]: + for instruction, conversation in zip(messages, conversations): + conversation.append({"role": role, "content": instruction}) + return conversations def process(self, inputs: StepInput) -> "StepOutput": - inputs_for_instruction_generation = ( - self._prepare_inputs_for_instruction_generation(inputs=inputs) - ) - outputs = self.llm.generate( - inputs=inputs_for_instruction_generation, num_generations=1 - ) - instructions = self._format_instruction_generation_output(outputs=outputs) - - if self.n_turns is None: - yield instructions + conversations = self._prepare_inputs_for_instruction_generation(inputs=inputs) + + for _ in range(self.n_turns): # type: ignore + # Generate instruction or user message + outputs = self.llm.generate( + inputs=conversations, + num_generations=1, + **self.llm.generation_kwargs, # type: ignore + ) + + conversations = self._append_messages_to_conversations( + role="user", + messages=[output[0] for output in outputs], + conversations=conversations, # type: ignore + ) + + # Generate response + outputs = self.llm.generate( + inputs=conversations, + num_generations=1, + **self.llm.generation_kwargs, # type: ignore + ) + + conversations = self._append_messages_to_conversations( + role="assistant", + messages=[output[0] for output in outputs], + conversations=conversations, # type: ignore + ) + + yield [{"conversation": conversation} for conversation in conversations] From 1719d15243496381771ca73fcfdd8af77f551ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 10 Jul 2024 16:54:45 +0200 Subject: [PATCH 04/30] Remove `use_open_ai` and add `MagpieChatTemplateMixin` to `InferenceEndpointsLLM` --- .../llms/huggingface/inference_endpoints.py | 186 +++++++----------- 1 file changed, 70 insertions(+), 116 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 6d4d3d1a5e..05169cbeac 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -14,8 +14,9 @@ import os import random +import sys import warnings -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from pydantic import ( Field, @@ -28,6 +29,7 @@ from typing_extensions import override from distilabel.llms.base import AsyncLLM +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( @@ -42,15 +44,13 @@ if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient - from openai import AsyncOpenAI from transformers import PreTrainedTokenizer -class InferenceEndpointsLLM(AsyncLLM): +class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): """InferenceEndpoints LLM implementation running the async API client. - This LLM will internally use `huggingface_hub.AsyncInferenceClient` or `openai.AsyncOpenAI` - depending on the `use_openai_client` attribute. + This LLM will internally use `huggingface_hub.AsyncInferenceClient`. Attributes: model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which @@ -63,7 +63,6 @@ class InferenceEndpointsLLM(AsyncLLM): tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to `None`, but defining one is recommended to properly format the prompt. model_display_name: the model display name to use for the LLM. Defaults to `None`. - use_openai_client: whether to use the OpenAI client instead of the Hugging Face client. Icon: `:hugging:` @@ -137,7 +136,6 @@ class InferenceEndpointsLLM(AsyncLLM): tokenizer_id: Optional[str] = None model_display_name: Optional[str] = None - use_openai_client: bool = False structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( default=None, @@ -149,7 +147,7 @@ class InferenceEndpointsLLM(AsyncLLM): _model_name: Optional[str] = PrivateAttr(default=None) _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) - _aclient: Optional[Union["AsyncInferenceClient", "AsyncOpenAI"]] = PrivateAttr(...) + _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) @model_validator(mode="after") # type: ignore def only_one_of_model_id_endpoint_name_or_base_url_provided( @@ -182,13 +180,10 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( ) def load(self) -> None: # noqa: C901 - """Loads the either the `AsyncInferenceClient` or the `AsyncOpenAI` client to benefit - from async requests, running the Hugging Face Inference Endpoint underneath via the - `/v1/chat/completions` endpoint, exposed for the models running on TGI using the - `text-generation` task. + """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference + Endpoint. Raises: - ImportError: if the `openai` Python client is not installed. ImportError: if the `huggingface-hub` Python client is not installed. ValueError: if the model is not currently deployed or is not running the TGI framework. ImportError: if the `transformers` Python client is not installed. @@ -234,31 +229,16 @@ def load(self) -> None: # noqa: C901 ) if client.status in ["paused", "scaledToZero"]: client.resume().wait(timeout=300) - elif client.status in ["initializing"]: + elif client.status == "initializing": client.wait(timeout=300) self.base_url = client.url self._model_name = client.repository - if self.use_openai_client: - try: - from openai import AsyncOpenAI - except ImportError as ie: - raise ImportError( - "OpenAI Python client is not installed. Please install it using" - " `pip install openai`." - ) from ie - - self._aclient = AsyncOpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=6, - ) - else: - self._aclient = AsyncInferenceClient( - model=self.base_url, - token=self.api_key.get_secret_value(), - ) + self._aclient = AsyncInferenceClient( + model=self.base_url, + token=self.api_key.get_secret_value(), + ) if self.tokenizer_id: try: @@ -283,43 +263,65 @@ def model_name(self) -> Union[str, None]: # type: ignore or self.base_url ) - async def _openai_agenerate( - self, - input: "StandardInput", - max_new_tokens: int = 128, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - temperature: float = 1.0, - top_p: Optional[float] = None, - stop: Optional[Union[str, List[str]]] = None, - ) -> GenerateOutput: - """Generates completions for the given input using the OpenAI async client.""" - completion = await self._aclient.chat.completions.create( # type: ignore - messages=input, # type: ignore - model="tgi", - max_tokens=max_new_tokens, - n=1, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop, - timeout=50, + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + prompt: str = self._tokenizer.apply_chat_template( # type: ignore + conversation=input, # type: ignore + tokenize=False, + add_generation_prompt=True, ) - if completion.choices[0].message.content is None: - self._logger.warning( # type: ignore - f"⚠️ Received no response using OpenAI client (model: '{self.model_name}')." - f" Finish reason was: {completion.choices[0].finish_reason}" - ) - return [completion.choices[0].message.content] + return super().apply_magpie_pre_query_template(prompt, input) + + def get_structured_output( + self, input: FormattedInput + ) -> Union[Dict[str, Any], None]: + """Gets the structured output (if any) for the given input. + + Args: + input: a single input in chat format to generate responses for. + + Returns: + The structured output that will be passed as `grammer` to the inference endpoint + or `None` if not required. + """ + structured_output = None + + # Specific structured output per input + if isinstance(input, tuple): + input, structured_output = input + structured_output = { + "type": structured_output["format"], + "value": structured_output["schema"], + } + + # Same structured output for all the inputs + if structured_output is None and self.structured_output is not None: + try: + structured_output = { + "type": self.structured_output["format"], + "value": self.structured_output["schema"], + } + except KeyError as e: + raise ValueError( + "To use the structured output you have to inform the `format` and `schema` in " + "the `structured_output` attribute." + ) from e + + return structured_output @validate_call async def agenerate( # type: ignore self, input: FormattedInput, max_new_tokens: int = 128, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, repetition_penalty: Optional[float] = None, temperature: float = 1.0, do_sample: bool = False, @@ -331,21 +333,16 @@ async def agenerate( # type: ignore seed: Optional[int] = None, watermark: bool = False, ) -> GenerateOutput: - """Generates completions for the given input using the OpenAI async client. + """Generates completions for the given input using the async client. Args: input: a single input in chat format to generate responses for. max_new_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`. - frequency_penalty: the repetition penalty to use for the generation. Defaults - to `0.0`. Only applies if `use_openai_client=True`. - presence_penalty: the presence penalty to use for the generation. Defaults to - `0.0`. Only applies if `use_openai_client=True`. repetition_penalty: the repetition penalty to use for the generation. Defaults - to `None`. Only applies if `use_openai_client=False`. + to `None`. temperature: the temperature to use for the generation. Defaults to `1.0`. do_sample: whether to use sampling for the generation. Defaults to `False`. - Only applies if `use_openai_client=False`. top_k: the top-k value to use for the generation. Defaults to `0.8`, since neither `0.0` nor `1.0` are valid values in TGI. top_p: the top-p value to use for the generation. Defaults to `1.0`. @@ -373,55 +370,12 @@ async def agenerate( # type: ignore ) stop_sequences = stop_sequences[:4] - structured_output = None - if isinstance(input, tuple): - input, structured_output = input - structured_output = { - "type": structured_output["format"], - "value": structured_output["schema"], - } - - # NOTE: `self.structured_output` applies to all the generations, while `structured_output` i.e. the - # value included within the tuple provided as `input` to this method, is intended to be different per - # each input, so those should not be used together. Meaning that it should be either provided at attribute - # level i.e. self, or via a column within each input i.e. row. - if structured_output is None and self.structured_output is not None: - try: - structured_output = { - "type": self.structured_output["format"], - "value": self.structured_output["schema"], - } - except KeyError as e: - raise ValueError( - "To use the structured output you have to inform the `format` and `schema` in " - "the `structured_output` attribute." - ) from e - - if self.use_openai_client: - return await self._openai_agenerate( - input=input, - max_new_tokens=max_new_tokens, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop_sequences, - ) - - if self._tokenizer is not None: - prompt = self._tokenizer.apply_chat_template( # type: ignore - conversation=input, # type: ignore - tokenize=False, - add_generation_prompt=True, - ) - else: - # TODO: should we apply a default chat template here instead? e.g. ChatML - prompt = "\n".join([message["content"] for message in input]) + structured_output = self.get_structured_output(input) completion = None try: completion = await self._aclient.text_generation( # type: ignore - prompt=prompt, # type: ignore + prompt=self.prepare_input(input), # type: ignore max_new_tokens=max_new_tokens, do_sample=do_sample, typical_p=typical_p, @@ -435,7 +389,7 @@ async def agenerate( # type: ignore grammar=structured_output, # type: ignore # NOTE: here to ensure that the cache is not used and a different response is # generated every time - seed=seed or random.randint(0, 2147483647), + seed=seed or random.randint(0, sys.maxsize), ) except Exception as e: self._logger.warning( # type: ignore From 775ca4ed7e3a6bd1c21d8ffc678cc3a59e2699b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 10 Jul 2024 16:55:10 +0200 Subject: [PATCH 05/30] Add `MagpieChatTemplateMixin` to `vLLM` --- src/distilabel/llms/vllm.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 42c6b5aeea..e5ec27816d 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -31,6 +31,7 @@ from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -39,11 +40,13 @@ from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM + from distilabel.steps.tasks.typing import StandardInput + SamplingParams = None -class vLLM(LLM, CudaDevicePlacementMixin): +class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): """`vLLM` library LLM implementation. Attributes: @@ -213,15 +216,22 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def prepare_input(self, input: "FormattedInput") -> str: - """Prepares the input by applying the chat template to the input, which is formatted - as an OpenAI conversation, and adding the generation prompt. + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. """ - return self._tokenizer.apply_chat_template( # type: ignore + prompt = self._tokenizer.apply_chat_template( # type: ignore input, # type: ignore tokenize=False, add_generation_prompt=True, # type: ignore ) + return super().apply_magpie_pre_query_template(prompt, input) def _prepare_batches( self, inputs: List[FormattedInput] From 9ff6eebfde4bc904f6e25d8ba7a71ea868aa9ab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 10 Jul 2024 18:26:09 +0200 Subject: [PATCH 06/30] Add `MagpieGenerator` task --- .../llms/huggingface/transformers.py | 15 +- src/distilabel/llms/mixins/magpie.py | 78 +++++-- src/distilabel/steps/tasks/__init__.py | 2 + src/distilabel/steps/tasks/magpie/base.py | 204 ++++++++++++++---- .../steps/tasks/magpie/generator.py | 76 ++++++- 5 files changed, 313 insertions(+), 62 deletions(-) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index bde07342a4..3b4baa630f 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -158,15 +158,21 @@ def model_name(self) -> str: return self.model def prepare_input(self, input: "StandardInput") -> str: - """Prepares the input by applying the chat template to the input, which is formatted - as an OpenAI conversation, and adding the generation prompt. + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. """ - prompt = self._pipeline.tokenizer.apply_chat_template( # type: ignore + prompt: str = self._pipeline.tokenizer.apply_chat_template( # type: ignore input, # type: ignore tokenize=False, add_generation_prompt=True, ) - return super().apply_pre_query_template(prompt, input) + return super().apply_magpie_pre_query_template(prompt, input) @validate_call def generate( # type: ignore @@ -211,6 +217,7 @@ def generate( # type: ignore do_sample=do_sample, num_return_sequences=num_generations, prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, + pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore ) return [ [generation["generated_text"] for generation in output] diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py index 94fb7655fd..2020774432 100644 --- a/src/distilabel/llms/mixins/magpie.py +++ b/src/distilabel/llms/mixins/magpie.py @@ -12,23 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Literal, Union -from pydantic import BaseModel +from pydantic import BaseModel, field_validator, model_validator +from typing_extensions import Self if TYPE_CHECKING: from distilabel.steps.tasks.typing import StandardInput +MagpieAvailablePreQueryTemplates = Literal["llama3", "qwen2"] +"""The available predefined pre-query templates.""" -class MagpieChatTemplateMixin(BaseModel): - model: str - use_magpie_template: bool = False - # TODO: harcoded to llama 3 - pre_query_template: str = ( - "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" - ) +MAGPIE_PRE_QUERY_TEMPLATES: Dict[MagpieAvailablePreQueryTemplates, str] = { + "llama3": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "qwen2": "<|im_start|>user\n", +} - def apply_pre_query_template(self, prompt: str, input: "StandardInput") -> str: - if not self.use_magpie_template or input[-1]["role"] == "assistant": + +class MagpieChatTemplateMixin(BaseModel, validate_assignment=True): + """A simple mixin that adds the required logic to apply the pre-query template that + allows to an instruct fine-tuned LLM to generate user instructions as described in + the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. + + This mixin is meant to be used in combination with the [Magpie][distilabel.steps.tasks.magpie.base.Magpie] + task. + + Attributes: + use_magpie_template: a flag used to enable/disable applying the pre-query template. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2", ... + or a pre-query template. + + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + """ + + use_magpie_template: bool = True + magpie_pre_query_template: Union[MagpieAvailablePreQueryTemplates, str, None] = None + + @field_validator("magpie_pre_query_template") + @classmethod + def magpie_pre_query_template_validator(cls, value: str) -> str: + """Resolves the pre-query template alias if it exists, otherwise, returns the + value with no modification.""" + if value in MAGPIE_PRE_QUERY_TEMPLATES: + return MAGPIE_PRE_QUERY_TEMPLATES[value] + return value + + @model_validator(mode="after") + def use_magpie_template_validation(self) -> Self: + """Checks that there is a pre-query template set if Magpie is going to be used.""" + if self.use_magpie_template and self.magpie_pre_query_template is None: + raise ValueError( + f"Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is" + f" `None`. To use Magpie with `{self.__class__.__name__}` you need to set" + f" the `magpie_pre_query_template` attribute." + ) + return self + + def apply_magpie_pre_query_template( + self, prompt: str, input: "StandardInput" + ) -> str: + """Applies the pre-query template to the prompt if Magpie is going to be used. + + Args: + prompt: the prompt to which the pre-query template has to be applied. + input: the list with the chat items that were used to generate the prompt. + + Returns: + The prompt with the pre-query template applied if needed. + """ + if not self.use_magpie_template or input[-1]["role"] == "user": return prompt - return prompt + self.pre_query_template + return prompt + self.magpie_pre_query_template # type: ignore diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 2534b5d6ba..0b3a69596b 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -36,6 +36,7 @@ InstructionBacktranslation, ) from distilabel.steps.tasks.magpie.base import Magpie +from distilabel.steps.tasks.magpie.generator import MagpieGenerator from distilabel.steps.tasks.pair_rm import PairRM from distilabel.steps.tasks.prometheus_eval import PrometheusEval from distilabel.steps.tasks.quality_scorer import QualityScorer @@ -66,6 +67,7 @@ "MonolingualTripletGenerator", "InstructionBacktranslation", "Magpie", + "MagpieGenerator", "PairRM", "PrometheusEval", "QualityScorer", diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 68c7cd6c2d..6067375c90 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from pydantic import Field, PositiveInt +from distilabel.llms.base import LLM from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin -from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.mixins.runtime_parameters import ( + RuntimeParameter, + RuntimeParametersMixin, +) from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task @@ -25,61 +29,92 @@ from distilabel.steps.tasks.typing import ChatType, FormattedInput from distilabel.steps.typing import StepOutput +MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( + "You are a helpful Al assistant. The user will engage in a multi−round conversation" + " with you,asking initial questions and following up with additional related questions." + " Your goal is to provide thorough, relevant and insightful responses to help the user" + " with their queries." +) -class Magpie(Task): - n_turns: RuntimeParameter[PositiveInt] = Field( - default=1, - description="The number of turns to generate for the conversation.", - ) - def model_post_init(self, _: Any) -> None: - """Checks that the provided `LLM` uses the `MagpieChatTemplateMixin`.""" - super().load() +class MagpieBase(RuntimeParametersMixin): + """Base class defining the generation logic of Magpie method. - if not isinstance(self.llm, MagpieChatTemplateMixin): - raise ValueError( - f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`." - f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin." - ) + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + """ - self.llm.use_magpie_template = True + llm: LLM - @property - def inputs(self) -> List[str]: - return [] - - def format_input(self, input: Dict[str, Any]) -> "ChatType": - return [] - - @property - def outputs(self) -> List[str]: - return ["conversation"] - - def format_output( - self, - output: Union[str, None], - input: Union[Dict[str, Any], None] = None, - ) -> Dict[str, Any]: - return {} + n_turns: RuntimeParameter[PositiveInt] = Field( + default=1, + description="The number of turns to generate for the conversation.", + ) + system_prompt: Optional[RuntimeParameter[str]] = Field( + default=None, + description="An optional system prompt that can be used to steer the LLM to generate" + " content of certain topic, guide the style, etc.", + ) def _prepare_inputs_for_instruction_generation( self, inputs: List[Dict[str, Any]] ) -> List["FormattedInput"]: - return [ - [{"role": "system", "content": input["system_prompt"]}] - if "system_prompt" in input - else [] - for input in inputs - ] + """Prepares the inputs adding the system (if required) prompt provided in each row, + or if the conversations to generate have more than one turn, then adding the system + prompt for multi-turn conversation from the paper. + + Args: + inputs: the inputs to prepare. + + Returns: + The prepared inputs. + """ + prepared_inputs = [] + for input in inputs: + conversation = [] + if "system_prompt" in input: + conversation.append( + {"role": "system", "content": input["system_prompt"]} + ) + elif self.system_prompt is not None: + conversation.append({"role": "system", "content": self.system_prompt}) + elif self.n_turns > 1: # type: ignore + conversation.append( + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT} + ) + + prepared_inputs.append(conversation) + + return prepared_inputs def _append_messages_to_conversations( self, role: str, messages: List[str], conversations: List["ChatType"] ) -> List["ChatType"]: + """Appends the outputs generated by the LLM with the specified role to the conversations. + + Args: + role: the role to assign to the message to be appended. + messages: the list of messages generated by the LLM for each conversation. + conversations: the list of conversations to which the messages will be appended. + + Returns: + The updated conversations. + """ for instruction, conversation in zip(messages, conversations): conversation.append({"role": role, "content": instruction}) return conversations - def process(self, inputs: StepInput) -> "StepOutput": + def _generate_with_pre_query_template( + self, inputs: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Generate a list of instructions or conversations of the specified number of turns. + + Args: + inputs: a list of dictionaries that can contain a `system_prompt` key. + + Returns: + The list of generated conversations. + """ conversations = self._prepare_inputs_for_instruction_generation(inputs=inputs) for _ in range(self.n_turns): # type: ignore @@ -96,6 +131,8 @@ def process(self, inputs: StepInput) -> "StepOutput": conversations=conversations, # type: ignore ) + # TODO: handle potential previous `None`s + # Generate response outputs = self.llm.generate( inputs=conversations, @@ -109,4 +146,89 @@ def process(self, inputs: StepInput) -> "StepOutput": conversations=conversations, # type: ignore ) - yield [{"conversation": conversation} for conversation in conversations] + return [{"conversation": conversation} for conversation in conversations] + + +class Magpie(Task, MagpieBase): + """Generates conversations using an instruct fine-tuned LLM. + + Magpie is a neat method that allows generating user instructions with no seed data + or specific system prompt thanks to the autoregressive capabilities of the instruct + fine-tuned LLMs. As they were fine-tuned using a chat template composed by a user message + and a desired assistant output, the instruct fine-tuned LLM learns that after the pre-query + or pre-instruct tokens comes an instruction. If these pre-query tokens are sent to the + LLM without any user message, then the LLM will continue generating tokens as it was + the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. + After this instruct is generated, it can be sent again to the LLM to generate this time + an assistant response. This process can be repeated N times allowing to build a multi-turn + conversation. + + This method was described in the paper 'Magpie: Alignment Data Synthesis from + Scratch by Prompting Aligned LLMs with Nothing'. + + Runtime parameters: + - `n_turns`: the number of turns that the generated conversation will have. + - `system_prompt`: an optional system prompt that can be used to steer the LLM to + generate content of certain topic, guide the style, etc. If the provided inputs + contains a `system_prompt` column, then this runtime parameter will be ignored + and the one from the column will be used. Defaults to `None`. + + Input columns: + - system_prompt (`str`, optional): an optional system prompt that can be provided + to guide the generation of the instruct LLM and steer it to generate instructions + of certain topic. + + Outputs columns: + - conversation (`ChatType`): the generated conversation which is a list of chat + items with a role and a message. + + Categories: + - instruction-generation + - mt-generation + + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + """ + + def model_post_init(self, __context: Any) -> None: + """Checks that the provided `LLM` uses the `MagpieChatTemplateMixin`.""" + super().model_post_init(__context) + + if not isinstance(self.llm, MagpieChatTemplateMixin): + raise ValueError( + f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`." + f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin." + ) + + self.llm.use_magpie_template = True + + @property + def inputs(self) -> List[str]: + return [] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + """Does nothing.""" + return [] + + @property + def outputs(self) -> List[str]: + return ["conversation"] + + def format_output( + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, + ) -> Dict[str, Any]: + """Does nothing.""" + return {} + + def process(self, inputs: StepInput) -> "StepOutput": + """Generate a list of instructions or conversations of the specified number of turns. + + Args: + inputs: a list of dictionaries that can contain a `system_prompt` key. + + Yields: + The list of generated conversations. + """ + yield self._generate_with_pre_query_template(inputs) diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index bf897138bb..369f69ba66 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -12,20 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. -import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + from pydantic import Field +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.base import GeneratorTask +from distilabel.steps.tasks.magpie.base import MagpieBase if TYPE_CHECKING: from distilabel.steps.typing import GeneratorStepOutput -class MagpieInstructionGenerator(GeneratorTask): - num_instructions: RuntimeParameter[int] = Field( - default=None, description="The number of instructions to generate." +class MagpieGenerator(GeneratorTask, MagpieBase): + """Generator task the generates instructions or conversations using Magpie. + + Magpie is a neat method that allows generating user instructions with no seed data + or specific system prompt thanks to the autoregressive capabilities of the instruct + fine-tuned LLMs. As they were fine-tuned using a chat template composed by a user message + and a desired assistant output, the instruct fine-tuned LLM learns that after the pre-query + or pre-instruct tokens comes an instruction. If these pre-query tokens are sent to the + LLM without any user message, then the LLM will continue generating tokens as it was + the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. + After this instruct is generated, it can be sent again to the LLM to generate this time + an assistant response. This process can be repeated N times allowing to build a multi-turn + conversation. + + This method was described in the paper 'Magpie: Alignment Data Synthesis from + Scratch by Prompting Aligned LLMs with Nothing'. + + Runtime parameters: + - `n_turns`: the number of turns that the generated conversation will have. + - `system_prompt`: an optional system prompt that can be used to steer the LLM to + generate content of certain topic, guide the style, etc. Defaults to `None`. + - `num_rows`: the number of rows to be generated. + + Outputs columns: + - conversation (`ChatType`): the generated conversation which is a list of chat + items with a role and a message. + + Categories: + - instruction-generation + - mt-generation + - generator + + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + """ + + # TODO: move this to `GeneratorTask` + num_rows: RuntimeParameter[int] = Field( + default=None, description="The number of rows to generate." ) + def model_post_init(self, __context: Any) -> None: + """Checks that the provided `LLM` uses the `MagpieChatTemplateMixin`.""" + super().model_post_init(__context) + + if not isinstance(self.llm, MagpieChatTemplateMixin): + raise ValueError( + f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`." + f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin." + ) + + self.llm.use_magpie_template = True + def process(self, offset: int = 0) -> "GeneratorStepOutput": - pass + """Generates the desired number of instructions or conversations using Magpie. + + Args: + offset: The offset to start the generation from. Defaults to `0`. + + Yields: + The generated instructions or conversations. + """ + generated = offset + + while generated <= self.num_rows: # type: ignore + conversations = self._generate_with_pre_query_template( + inputs=[{} for _ in range(self.batch_size)] # type: ignore + ) + generated += self.batch_size # type: ignore + yield (conversations, generated == self.num_generations) From dadac54ffe08bd23f42c020f7ff5dff231fb37e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 10 Jul 2024 20:07:44 +0200 Subject: [PATCH 07/30] Fix unit tests --- .../llms/huggingface/inference_endpoints.py | 22 ++- src/distilabel/llms/mixins/magpie.py | 2 +- src/distilabel/pipeline/step_wrapper.py | 2 +- .../huggingface/test_inference_endpoints.py | 128 ++++++------------ .../tasks/structured_outputs/test_outlines.py | 11 +- 5 files changed, 70 insertions(+), 95 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 05169cbeac..2ef5b7d3b1 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -159,9 +159,19 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( if self.base_url and (self.model_id or self.endpoint_name): self._logger.warning( # type: ignore - f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`" - " is also provided, the `base_url` will either be ignored or overwritten with the one generated" - " from either of those args, for serverless or dedicated inference endpoints, respectively." + f"Since the `base_url={self.base_url}` is available and either one of `model_id`" + " or `endpoint_name` is also provided, the `base_url` will either be ignored" + " or overwritten with the one generated from either of those args, for serverless" + " or dedicated inference endpoints, respectively." + ) + + if self.model_id and self.tokenizer_id is None: + self.tokenizer_id = self.model_id + + if self.use_magpie_template and self.tokenizer_id is None: + raise ValueError( + "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please," + " set a `tokenizer_id` and try again." ) if self.base_url and not (self.model_id or self.endpoint_name): @@ -174,9 +184,9 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( return self raise ValidationError( - "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too," - " it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name}," - f" and `base_url`={self.base_url}." + f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is" + f" provided too, it will be overwritten instead. Found `model_id`={self.model_id}," + f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." ) def load(self) -> None: # noqa: C901 diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py index 2020774432..7c05ff6214 100644 --- a/src/distilabel/llms/mixins/magpie.py +++ b/src/distilabel/llms/mixins/magpie.py @@ -48,7 +48,7 @@ class MagpieChatTemplateMixin(BaseModel, validate_assignment=True): - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) """ - use_magpie_template: bool = True + use_magpie_template: bool = False magpie_pre_query_template: Union[MagpieAvailablePreQueryTemplates, str, None] = None @field_validator("magpie_pre_query_template") diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py index 29c2e3e11c..3befd5187d 100644 --- a/src/distilabel/pipeline/step_wrapper.py +++ b/src/distilabel/pipeline/step_wrapper.py @@ -16,7 +16,7 @@ from queue import Queue from typing import Any, Dict, List, Optional, Union, cast -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.pipeline.batch import _Batch from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG from distilabel.pipeline.typing import StepLoadStatus diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index ecc5d97596..cba499e463 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -14,7 +14,7 @@ import random from unittest import mock -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import nest_asyncio import pytest @@ -22,11 +22,30 @@ @patch("huggingface_hub.AsyncInferenceClient") -@patch("openai.AsyncOpenAI") class TestInferenceEndpointsLLM: - def test_load_no_api_key( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + def test_no_tokenizer_magpie_raise_value_error( + self, mock_inference_client: MagicMock ) -> None: + with pytest.raises( + ValueError, + match="`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`", + ): + InferenceEndpointsLLM( + base_url="http://localhost:8000", + use_magpie_template=True, + magpie_pre_query_template="llama3", + ) + + def test_tokenizer_id_set_if_model_id( + self, mock_inference_client: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral" + ) + + assert llm.tokenizer_id == llm.model_id + + def test_load_no_api_key(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" ) @@ -40,12 +59,8 @@ def test_load_no_api_key( ): llm.load() - def test_load_with_cached_token( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: - llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" - ) + def test_load_with_cached_token(self, mock_inference_client: MagicMock) -> None: + llm = InferenceEndpointsLLM(base_url="http://localhost:8000") # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist with ( @@ -58,7 +73,7 @@ def test_load_with_cached_token( llm.load() def test_serverless_inference_endpoints_llm( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" @@ -68,7 +83,7 @@ def test_serverless_inference_endpoints_llm( assert llm.model_name == "distilabel-internal-testing/tiny-random-mistral" def test_dedicated_inference_endpoints_llm( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( endpoint_name="tiny-random-mistral", @@ -79,11 +94,12 @@ def test_dedicated_inference_endpoints_llm( assert llm.model_name == "tiny-random-mistral" def test_dedicated_inference_endpoints_llm_via_url( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( base_url="https://api-inference.huggingface.co/models/distilabel-internal-testing/tiny-random-mistral" ) + llm.load() assert isinstance(llm, InferenceEndpointsLLM) assert ( @@ -93,12 +109,12 @@ def test_dedicated_inference_endpoints_llm_via_url( @pytest.mark.asyncio async def test_agenerate_via_inference_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -113,39 +129,14 @@ async def test_agenerate_via_inference_client( ] ) == [" Aenean hendrerit aliquam velit. ..."] - @pytest.mark.asyncio - async def test_agenerate_via_openai_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: - llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral", - use_openai_client=True, - ) - llm._aclient = mock_openai_client - - mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] - ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) - - assert await llm.agenerate( - input=[ - {"role": "system", "content": ""}, - { - "role": "user", - "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - }, - ] - ) == [" Aenean hendrerit aliquam velit. ..."] - @pytest.mark.asyncio async def test_generate_via_inference_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -165,45 +156,15 @@ async def test_generate_via_inference_client( ] ) == [(" Aenean hendrerit aliquam velit. ...",)] - @pytest.mark.asyncio - async def test_generate_via_openai_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: - llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral", - use_openai_client=True, - ) - llm._aclient = mock_openai_client - - mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] - ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) - - ... - nest_asyncio.apply() - - assert llm.generate( - inputs=[ - [ - {"role": "system", "content": ""}, - { - "role": "user", - "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - }, - ] - ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] - @pytest.mark.asyncio async def test_agenerate_with_structured_output( - self, mock_inference_client: MagicMock, _: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -223,7 +184,7 @@ async def test_agenerate_with_structured_output( ) == [" Aenean hendrerit aliquam velit. ..."] kwargs = { - "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "prompt": "[INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. [/INST]", "max_new_tokens": 128, "do_sample": False, "typical_p": None, @@ -235,15 +196,11 @@ async def test_agenerate_with_structured_output( "return_full_text": False, "watermark": False, "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, - "seed": 478163327, # pre-computed random value with `random.seed(42)` + "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` } - mock_inference_client.text_generation.assert_called_with(**kwargs) + llm._aclient.text_generation.assert_called_with(**kwargs) - def test_serialization( - self, - mock_inference_client: MagicMock, - mock_openai_client: MagicMock, - ) -> None: + def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", ) @@ -253,11 +210,12 @@ def test_serialization( "endpoint_name": None, "endpoint_namespace": None, "base_url": None, - "tokenizer_id": None, + "tokenizer_id": "distilabel-internal-testing/tiny-random-mistral", "generation_kwargs": {}, + "magpie_pre_query_template": None, "structured_output": None, "model_display_name": None, - "use_openai_client": False, + "use_magpie_template": False, "type_info": { "module": "distilabel.llms.huggingface.inference_endpoints", "name": "InferenceEndpointsLLM", diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index e174f53716..0e488eea13 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Type, Union +from typing import Any, Dict, Literal, Type, Union import pytest from distilabel.llms.huggingface.transformers import TransformersLLM @@ -33,6 +33,7 @@ class DummyUserTest(BaseModel): DUMP_JSON = { "cuda_devices": "auto", "generation_kwargs": {}, + "magpie_pre_query_template": None, "structured_output": { "format": "json", "schema": { @@ -57,6 +58,7 @@ class DummyUserTest(BaseModel): "device": None, "device_map": None, "token": None, + "use_magpie_template": False, "type_info": { "module": "distilabel.llms.huggingface.transformers", "name": "TransformersLLM", @@ -66,6 +68,7 @@ class DummyUserTest(BaseModel): DUMP_REGEX = { "cuda_devices": "auto", "generation_kwargs": {}, + "magpie_pre_query_template": None, "structured_output": { "format": "regex", "schema": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", @@ -81,6 +84,7 @@ class DummyUserTest(BaseModel): "device": None, "device_map": None, "token": None, + "use_magpie_template": False, "type_info": { "module": "distilabel.llms.huggingface.transformers", "name": "TransformersLLM", @@ -149,7 +153,10 @@ def test_generation( ], ) def test_serialization( - self, format: str, schema: Union[str, Type[BaseModel]], dump: Dict[str, Any] + self, + format: Literal["json", "regex"], + schema: Union[str, Type[BaseModel]], + dump: Dict[str, Any], ) -> None: llm = TransformersLLM( model="openaccess-ai-collective/tiny-mistral", From 844ec57cc069c49a608433d9fc36adf7491b54c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 10 Jul 2024 20:13:12 +0200 Subject: [PATCH 08/30] Fix docstrings --- src/distilabel/steps/tasks/magpie/base.py | 6 ++---- src/distilabel/steps/tasks/magpie/generator.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 6067375c90..fa9b8fa204 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -161,9 +161,7 @@ class Magpie(Task, MagpieBase): the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. After this instruct is generated, it can be sent again to the LLM to generate this time an assistant response. This process can be repeated N times allowing to build a multi-turn - conversation. - - This method was described in the paper 'Magpie: Alignment Data Synthesis from + conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. Runtime parameters: @@ -178,7 +176,7 @@ class Magpie(Task, MagpieBase): to guide the generation of the instruct LLM and steer it to generate instructions of certain topic. - Outputs columns: + Output columns: - conversation (`ChatType`): the generated conversation which is a list of chat items with a role and a message. diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index 369f69ba66..a89a4abf62 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -37,9 +37,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase): the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. After this instruct is generated, it can be sent again to the LLM to generate this time an assistant response. This process can be repeated N times allowing to build a multi-turn - conversation. - - This method was described in the paper 'Magpie: Alignment Data Synthesis from + conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. Runtime parameters: @@ -48,7 +46,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase): generate content of certain topic, guide the style, etc. Defaults to `None`. - `num_rows`: the number of rows to be generated. - Outputs columns: + Output columns: - conversation (`ChatType`): the generated conversation which is a list of chat items with a role and a message. From 04ecc3a7b1939be2cbb2017960454761fc4587dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 10 Jul 2024 20:20:56 +0200 Subject: [PATCH 09/30] Mock `HF_TOKEN` environment variable --- .../unit/llms/huggingface/test_inference_endpoints.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index cba499e463..d1bb67cff5 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random +from typing import Generator from unittest import mock from unittest.mock import AsyncMock, MagicMock, patch @@ -21,6 +23,12 @@ from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM +@pytest.fixture(autouse=True) +def mock_hf_token_env_variable() -> Generator[None, None, None]: + with patch.dict(os.environ, {"HF_TOKEN": "hf_token"}): + yield + + @patch("huggingface_hub.AsyncInferenceClient") class TestInferenceEndpointsLLM: def test_no_tokenizer_magpie_raise_value_error( @@ -46,6 +54,8 @@ def test_tokenizer_id_set_if_model_id( assert llm.tokenizer_id == llm.model_id def test_load_no_api_key(self, mock_inference_client: MagicMock) -> None: + del os.environ["HF_TOKEN"] + llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" ) From a15752a866ef3ae27a74f71a5c5a7c65991fbe68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 09:47:19 +0200 Subject: [PATCH 10/30] Fix list index out of range --- .../llms/huggingface/inference_endpoints.py | 12 ++++++++---- src/distilabel/llms/huggingface/transformers.py | 12 ++++++++---- src/distilabel/llms/mixins/magpie.py | 2 +- src/distilabel/llms/vllm.py | 12 ++++++++---- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 2ef5b7d3b1..e2fb238599 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -283,10 +283,14 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. """ - prompt: str = self._tokenizer.apply_chat_template( # type: ignore - conversation=input, # type: ignore - tokenize=False, - add_generation_prompt=True, + prompt: str = ( + self._tokenizer.apply_chat_template( # type: ignore + conversation=input, # type: ignore + tokenize=False, + add_generation_prompt=True, + ) + if input + else "" ) return super().apply_magpie_pre_query_template(prompt, input) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 3b4baa630f..5b702adbb8 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -167,10 +167,14 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. """ - prompt: str = self._pipeline.tokenizer.apply_chat_template( # type: ignore - input, # type: ignore - tokenize=False, - add_generation_prompt=True, + prompt: str = ( + self._pipeline.tokenizer.apply_chat_template( # type: ignore + input, # type: ignore + tokenize=False, + add_generation_prompt=True, + ) + if input + else "" ) return super().apply_magpie_pre_query_template(prompt, input) diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py index 7c05ff6214..ede1d1976a 100644 --- a/src/distilabel/llms/mixins/magpie.py +++ b/src/distilabel/llms/mixins/magpie.py @@ -83,6 +83,6 @@ def apply_magpie_pre_query_template( Returns: The prompt with the pre-query template applied if needed. """ - if not self.use_magpie_template or input[-1]["role"] == "user": + if not self.use_magpie_template or (input and input[-1]["role"] == "user"): return prompt return prompt + self.magpie_pre_query_template # type: ignore diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index e5ec27816d..130b3f76ec 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -226,10 +226,14 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. """ - prompt = self._tokenizer.apply_chat_template( # type: ignore - input, # type: ignore - tokenize=False, - add_generation_prompt=True, # type: ignore + prompt: str = ( + self._tokenizer.apply_chat_template( # type: ignore + input, # type: ignore + tokenize=False, + add_generation_prompt=True, # type: ignore + ) + if input + else "" ) return super().apply_magpie_pre_query_template(prompt, input) From b46cd3344b03a6d801f5f6c27fb4ee32739b4834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 10:28:39 +0200 Subject: [PATCH 11/30] Fix `MagpieGenerator` last batch --- .../steps/tasks/magpie/generator.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index a89a4abf62..d2c507bac1 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, List, Union from pydantic import Field @@ -76,6 +76,18 @@ def model_post_init(self, __context: Any) -> None: self.llm.use_magpie_template = True + @property + def outputs(self) -> List[str]: + return ["conversation"] + + def format_output( + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, + ) -> Dict[str, Any]: + """Does nothing.""" + return {} + def process(self, offset: int = 0) -> "GeneratorStepOutput": """Generates the desired number of instructions or conversations using Magpie. @@ -88,8 +100,11 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": generated = offset while generated <= self.num_rows: # type: ignore + rows_to_generate = ( + self.num_rows if self.num_rows < self.batch_size else self.batch_size # type: ignore + ) conversations = self._generate_with_pre_query_template( - inputs=[{} for _ in range(self.batch_size)] # type: ignore + inputs=[{} for _ in range(rows_to_generate)] # type: ignore ) - generated += self.batch_size # type: ignore - yield (conversations, generated == self.num_generations) + generated += rows_to_generate # type: ignore + yield (conversations, generated == self.num_rows) From 75bd827a53e17d8c616441ef430c691c61247f52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 10:48:22 +0200 Subject: [PATCH 12/30] Add `only_instruction` attribute --- src/distilabel/steps/tasks/magpie/base.py | 65 ++++++++++++++----- .../steps/tasks/magpie/generator.py | 19 ++++-- 2 files changed, 64 insertions(+), 20 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index fa9b8fa204..f77addd7d1 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -50,12 +50,24 @@ class MagpieBase(RuntimeParametersMixin): default=1, description="The number of turns to generate for the conversation.", ) + only_instruction: RuntimeParameter[bool] = Field( + default=False, + description="Whether to generate only the instruction. If this argument" + " is `True`, then `n_turns` will be ignored.", + ) system_prompt: Optional[RuntimeParameter[str]] = Field( default=None, description="An optional system prompt that can be used to steer the LLM to generate" " content of certain topic, guide the style, etc.", ) + @property + def outputs(self) -> List[str]: + """Either a multi-turn conversation or the instruction generated.""" + if self.only_instruction: + return ["instruction"] + return ["conversation"] + def _prepare_inputs_for_instruction_generation( self, inputs: List[Dict[str, Any]] ) -> List["FormattedInput"]: @@ -104,18 +116,10 @@ def _append_messages_to_conversations( conversation.append({"role": role, "content": instruction}) return conversations - def _generate_with_pre_query_template( + def _generate_multi_turn_conversation( self, inputs: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: - """Generate a list of instructions or conversations of the specified number of turns. - - Args: - inputs: a list of dictionaries that can contain a `system_prompt` key. - - Returns: - The list of generated conversations. - """ - conversations = self._prepare_inputs_for_instruction_generation(inputs=inputs) + conversations = self._prepare_inputs_for_instruction_generation(inputs) for _ in range(self.n_turns): # type: ignore # Generate instruction or user message @@ -148,6 +152,29 @@ def _generate_with_pre_query_template( return [{"conversation": conversation} for conversation in conversations] + def _generate_with_pre_query_template( + self, inputs: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Generate a list of instructions or conversations of the specified number of turns. + + Args: + inputs: a list of dictionaries that can contain a `system_prompt` key. + + Returns: + The list of generated conversations. + """ + + if self.only_instruction: + prepared_inputs = self._prepare_inputs_for_instruction_generation(inputs) + outputs = self.llm.generate( + inputs=prepared_inputs, + num_generations=1, + **self.llm.generation_kwargs, # type: ignore + ) + return [{"instruction": output[0]} for output in outputs] + + return self._generate_multi_turn_conversation(inputs) + class Magpie(Task, MagpieBase): """Generates conversations using an instruct fine-tuned LLM. @@ -164,8 +191,19 @@ class Magpie(Task, MagpieBase): conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. + Attributes: + n_turns: the number of turns that the generated conversation will have. + only_instruction: whether to generate only the instruction. If this argument is + `True`, then `n_turns` will be ignored. Defaults to `False`. + system_prompt: an optional system prompt that can be used to steer the LLM to generate + content of certain topic, guide the style, etc. If the provided inputs contains + a `system_prompt` column, then this runtime parameter will be ignored and the + one from the column will be used. Defaults to `None`. + Runtime parameters: - `n_turns`: the number of turns that the generated conversation will have. + only_instruction: whether to generate only the instruction. If this argument is + `True`, then `n_turns` will be ignored. Defaults to `False`. - `system_prompt`: an optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc. If the provided inputs contains a `system_prompt` column, then this runtime parameter will be ignored @@ -178,7 +216,8 @@ class Magpie(Task, MagpieBase): Output columns: - conversation (`ChatType`): the generated conversation which is a list of chat - items with a role and a message. + items with a role and a message. Only if `only_instructions=False`. + - instruction (`str`): the generated instructions if `only_instruction=True`. Categories: - instruction-generation @@ -208,10 +247,6 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType": """Does nothing.""" return [] - @property - def outputs(self) -> List[str]: - return ["conversation"] - def format_output( self, output: Union[str, None], diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index d2c507bac1..e8178e6269 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, Union from pydantic import Field @@ -40,8 +40,20 @@ class MagpieGenerator(GeneratorTask, MagpieBase): conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. + Attributes: + n_turns: the number of turns that the generated conversation will have. + only_instruction: whether to generate only the instruction. If this argument is + `True`, then `n_turns` will be ignored. Defaults to `False`. + system_prompt: an optional system prompt that can be used to steer the LLM to generate + content of certain topic, guide the style, etc. If the provided inputs contains + a `system_prompt` column, then this runtime parameter will be ignored and the + one from the column will be used. Defaults to `None`. + num_rows: the number of rows to be generated. + Runtime parameters: - `n_turns`: the number of turns that the generated conversation will have. + - `only_instruction`: whether to generate only the instruction. If this argument + is `True`, then `n_turns` will be ignored. Defaults to `False`. - `system_prompt`: an optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc. Defaults to `None`. - `num_rows`: the number of rows to be generated. @@ -49,6 +61,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase): Output columns: - conversation (`ChatType`): the generated conversation which is a list of chat items with a role and a message. + - instruction (`str`): the generated instructions if `only_instruction=True`. Categories: - instruction-generation @@ -76,10 +89,6 @@ def model_post_init(self, __context: Any) -> None: self.llm.use_magpie_template = True - @property - def outputs(self) -> List[str]: - return ["conversation"] - def format_output( self, output: Union[str, None], From a86f6406b2cc48967962172b1cbcd0028d9a7a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 10:54:20 +0200 Subject: [PATCH 13/30] Update categories --- src/distilabel/steps/tasks/magpie/base.py | 4 ++-- src/distilabel/steps/tasks/magpie/generator.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index f77addd7d1..7b99f9427c 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -220,8 +220,8 @@ class Magpie(Task, MagpieBase): - instruction (`str`): the generated instructions if `only_instruction=True`. Categories: - - instruction-generation - - mt-generation + - text-generation + - instruction References: - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index e8178e6269..ced902f9ca 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -64,8 +64,8 @@ class MagpieGenerator(GeneratorTask, MagpieBase): - instruction (`str`): the generated instructions if `only_instruction=True`. Categories: - - instruction-generation - - mt-generation + - text-generation + - instruction - generator References: From 463f622ebbb3fcfdb5594ba38704646e6dbc2e6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 11:04:50 +0200 Subject: [PATCH 14/30] testing --- src/distilabel/steps/tasks/magpie/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 7b99f9427c..da6ca602e3 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -191,6 +191,8 @@ class Magpie(Task, MagpieBase): conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. + ![magpie](https://magpie-align.github.io/magpie_logo.png) + Attributes: n_turns: the number of turns that the generated conversation will have. only_instruction: whether to generate only the instruction. If this argument is From 953b933e0315fbe146379499b64454d1305c3db2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 11:23:20 +0200 Subject: [PATCH 15/30] Worth trying --- src/distilabel/steps/tasks/magpie/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index da6ca602e3..7b99f9427c 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -191,8 +191,6 @@ class Magpie(Task, MagpieBase): conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. - ![magpie](https://magpie-align.github.io/magpie_logo.png) - Attributes: n_turns: the number of turns that the generated conversation will have. only_instruction: whether to generate only the instruction. If this argument is From 53ff036dad9df2fd717ee204c950ea497ae20f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 13:12:05 +0200 Subject: [PATCH 16/30] Add examples --- src/distilabel/steps/tasks/magpie/base.py | 110 ++++++++++++++++- .../steps/tasks/magpie/generator.py | 114 ++++++++++++++++++ 2 files changed, 223 insertions(+), 1 deletion(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 7b99f9427c..3592cb735f 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -121,7 +121,7 @@ def _generate_multi_turn_conversation( ) -> List[Dict[str, Any]]: conversations = self._prepare_inputs_for_instruction_generation(inputs) - for _ in range(self.n_turns): # type: ignore + for _ in range(self.n_turns - 1): # type: ignore # Generate instruction or user message outputs = self.llm.generate( inputs=conversations, @@ -225,6 +225,114 @@ class Magpie(Task, MagpieBase): References: - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + + Examples: + + Generating instructions with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import Magpie + + magpie = Magpie( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 64, + }, + device="mps", + ), + only_instruction=True, + ) + + magpie.load() + + result = next( + magpie.process( + inputs=[ + { + "system_prompt": "You're a math expert AI assistant that helps students of secondary school to solve calculus problems." + }, + { + "system_prompt": "You're an expert florist AI assistant that helps user to erradicate pests in their crops." + }, + ] + ) + ) + # [ + # {'instruction': "That's me! I'd love some help with solving calculus problems! What kind of calculation are you most effective at? Linear Algebra, derivatives, integrals, optimization?"}, + # {'instruction': 'I was wondering if there are certain flowers and plants that can be used for pest control?'} + # ] + ``` + + Generating conversations with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import Magpie + + magpie = Magpie( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 256, + }, + device="mps", + ), + n_turns=2, + ) + + magpie.load() + + result = next( + magpie.process( + inputs=[ + { + "system_prompt": "You're a math expert AI assistant that helps students of secondary school to solve calculus problems." + }, + { + "system_prompt": "You're an expert florist AI assistant that helps user to erradicate pests in their crops." + }, + ] + ) + ) + # [ + # { + # 'conversation': [ + # {'role': 'system', 'content': "You're a math expert AI assistant that helps students of secondary school to solve calculus problems."}, + # { + # 'role': 'user', + # 'content': 'I\'m having trouble solving the limits of functions in calculus. Could you explain how to work with them? Limits of functions are denoted by lim x→a f(x) or lim x→a [f(x)]. It is read as "the limit as x approaches a of f + # of x".' + # }, + # { + # 'role': 'assistant', + # 'content': 'Limits are indeed a fundamental concept in calculus, and understanding them can be a bit tricky at first, but don\'t worry, I\'m here to help! The notation lim x→a f(x) indeed means "the limit as x approaches a of f of + # x". What it\'s asking us to do is find the' + # } + # ] + # }, + # { + # 'conversation': [ + # {'role': 'system', 'content': "You're an expert florist AI assistant that helps user to erradicate pests in their crops."}, + # { + # 'role': 'user', + # 'content': "As a flower shop owner, I'm noticing some unusual worm-like creatures causing damage to my roses and other flowers. Can you help me identify what the problem is? Based on your expertise as a florist AI assistant, I think it + # might be pests or diseases, but I'm not sure which." + # }, + # { + # 'role': 'assistant', + # 'content': "I'd be delighted to help you investigate the issue! Since you've noticed worm-like creatures damaging your roses and other flowers, I'll take a closer look at the possibilities. Here are a few potential culprits: 1. + # **Aphids**: These small, soft-bodied insects can secrete a sticky substance called" + # } + # ] + # } + # ] + ``` """ def model_post_init(self, __context: Any) -> None: diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index ced902f9ca..78a4e8d820 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -70,6 +70,120 @@ class MagpieGenerator(GeneratorTask, MagpieBase): References: - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + + Examples: + + Generating instructions with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import MagpieGenerator + + generator = MagpieGenerator( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 256, + }, + device="mps", + ), + only_instruction=True, + num_rows=5, + ) + + generator.load() + + result = next(generator.process()) + # ( + # [ + # {"instruction": "I've just bought a new phone and I're excited to start using it."}, + # {"instruction": "What are the most common types of companies that use digital signage?"} + # ], + # True + # ) + ``` + + Generating a conversation with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import MagpieGenerator + + generator = MagpieGenerator( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 64, + }, + device="mps", + ), + n_turns=3, + num_rows=5, + ) + + generator.load() + + result = next(generator.process()) + # ( + # [ + # { + # 'conversation': [ + # { + # 'role': 'system', + # 'content': 'You are a helpful Al assistant. The user will engage in a multi−round conversation with you,asking initial questions and following up with additional related questions. Your goal is to provide thorough, relevant and + # insightful responses to help the user with their queries.' + # }, + # {'role': 'user', 'content': "I'm considering starting a social media campaign for my small business and I're not sure where to start. Can you help?"}, + # { + # 'role': 'assistant', + # 'content': "Exciting endeavor! Creating a social media campaign can be a great way to increase brand awareness, drive website traffic, and ultimately boost sales. I'd be happy to guide you through the process. To get started, + # let's break down the basics. First, we need to identify your goals and target audience. What do" + # }, + # { + # 'role': 'user', + # 'content': "Before I start a social media campaign, what kind of costs ammol should I expect to pay? There are several factors that contribute to the total cost of running a social media campaign. Let me outline some of the main + # expenses you might encounter: 1. Time: As the business owner, you'll likely spend time creating" + # }, + # { + # 'role': 'assistant', + # 'content': 'Time is indeed one of the biggest investments when it comes to running a social media campaign! Besides time, you may also incur costs associated with: 2. Content creation: You might need to hire freelancers or + # agencies to create high-quality content (images, videos, captions) for your social media platforms. 3. Advertising' + # } + # ] + # }, + # { + # 'conversation': [ + # { + # 'role': 'system', + # 'content': 'You are a helpful Al assistant. The user will engage in a multi−round conversation with you,asking initial questions and following up with additional related questions. Your goal is to provide thorough, relevant and + # insightful responses to help the user with their queries.' + # }, + # {'role': 'user', 'content': "I am thinking of buying a new laptop or computer. What are some important factors I should consider when making your decision? I'll make sure to let you know if any other favorites or needs come up!"}, + # { + # 'role': 'assistant', + # 'content': 'Exciting times ahead! When considering a new laptop or computer, there are several key factors to think about to ensure you find the right one for your needs. Here are some crucial ones to get you started: 1. + # **Purpose**: How will you use your laptop or computer? For work, gaming, video editing,' + # }, + # { + # 'role': 'user', + # 'content': 'Let me stop you there. Let\'s explore this "purpose" factor that you mentioned earlier. Can you elaborate more on what type of devices would be suitable for different purposes? For example, if I\'re primarily using my + # laptop for general usage like browsing, email, and word processing, would a budget-friendly laptop be sufficient' + # }, + # { + # 'role': 'assistant', + # 'content': "Understanding your purpose can greatly impact the type of device you'll need. **General Usage (Browsing, Email, Word Processing)**: For casual users who mainly use their laptop for daily tasks, a budget-friendly + # option can be sufficient. Look for laptops with: * Intel Core i3 or i5 processor* " + # } + # ] + # } + # ], + # True + # ) + ``` """ # TODO: move this to `GeneratorTask` From ba85907f39a949c56f00446f2794907d92967507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 15:24:52 +0200 Subject: [PATCH 17/30] Add magpie unit tests --- src/distilabel/steps/tasks/magpie/base.py | 16 +- .../steps/tasks/magpie/generator.py | 9 +- tests/unit/conftest.py | 21 +- tests/unit/llms/mixins/test_magpie.py | 60 ++++ tests/unit/steps/tasks/magpie/__init__.py | 14 + tests/unit/steps/tasks/magpie/test_base.py | 269 ++++++++++++++++++ .../unit/steps/tasks/magpie/test_generator.py | 154 ++++++++++ 7 files changed, 532 insertions(+), 11 deletions(-) create mode 100644 tests/unit/llms/mixins/test_magpie.py create mode 100644 tests/unit/steps/tasks/magpie/__init__.py create mode 100644 tests/unit/steps/tasks/magpie/test_base.py create mode 100644 tests/unit/steps/tasks/magpie/test_generator.py diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 3592cb735f..7a974177fb 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -61,13 +61,6 @@ class MagpieBase(RuntimeParametersMixin): " content of certain topic, guide the style, etc.", ) - @property - def outputs(self) -> List[str]: - """Either a multi-turn conversation or the instruction generated.""" - if self.only_instruction: - return ["instruction"] - return ["conversation"] - def _prepare_inputs_for_instruction_generation( self, inputs: List[Dict[str, Any]] ) -> List["FormattedInput"]: @@ -121,7 +114,7 @@ def _generate_multi_turn_conversation( ) -> List[Dict[str, Any]]: conversations = self._prepare_inputs_for_instruction_generation(inputs) - for _ in range(self.n_turns - 1): # type: ignore + for _ in range(self.n_turns): # type: ignore # Generate instruction or user message outputs = self.llm.generate( inputs=conversations, @@ -355,6 +348,13 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType": """Does nothing.""" return [] + @property + def outputs(self) -> List[str]: + """Either a multi-turn conversation or the instruction generated.""" + if self.only_instruction: + return ["instruction"] + return ["conversation"] + def format_output( self, output: Union[str, None], diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index 78a4e8d820..8d9dca96e5 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from pydantic import Field @@ -211,6 +211,13 @@ def format_output( """Does nothing.""" return {} + @property + def outputs(self) -> List[str]: + """Either a multi-turn conversation or the instruction generated.""" + if self.only_instruction: + return ["instruction"] + return ["conversation"] + def process(self, offset: int = 0) -> "GeneratorStepOutput": """Generates the desired number of instructions or conversations using Magpie. diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index bbe6ca1ed4..adcd690276 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List import pytest -from distilabel.llms.base import AsyncLLM +from distilabel.llms.base import LLM, AsyncLLM +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin if TYPE_CHECKING: from distilabel.llms.typing import GenerateOutput @@ -37,6 +38,22 @@ async def agenerate( return ["output" for _ in range(num_generations)] +class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + def generate( + self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any + ) -> List["GenerateOutput"]: + return [ + ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs)) + ] + + @pytest.fixture def dummy_llm() -> AsyncLLM: return DummyLLM() diff --git a/tests/unit/llms/mixins/test_magpie.py b/tests/unit/llms/mixins/test_magpie.py new file mode 100644 index 0000000000..bc7503fb2c --- /dev/null +++ b/tests/unit/llms/mixins/test_magpie.py @@ -0,0 +1,60 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.llms.mixins.magpie import MAGPIE_PRE_QUERY_TEMPLATES + +from tests.unit.conftest import DummyMagpieLLM + + +class TestMagpieChatTemplateMixin: + def test_magpie_pre_query_template_set(self) -> None: + with pytest.raises( + ValueError, + match="Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is `None`", + ): + DummyMagpieLLM(use_magpie_template=True) + + def test_magpie_pre_query_template_alias_resolved(self) -> None: + llm = DummyMagpieLLM(magpie_pre_query_template="llama3") + assert llm.magpie_pre_query_template == MAGPIE_PRE_QUERY_TEMPLATES["llama3"] + + def test_apply_magpie_pre_query_template(self) -> None: + llm = DummyMagpieLLM(magpie_pre_query_template="") + + assert ( + llm.apply_magpie_pre_query_template( + prompt="Hello hello", input=[] + ) + == "Hello hello" + ) + + llm = DummyMagpieLLM( + use_magpie_template=True, magpie_pre_query_template="" + ) + + assert ( + llm.apply_magpie_pre_query_template( + prompt="Hello hello", input=[] + ) + == "Hello hello" + ) + + assert ( + llm.apply_magpie_pre_query_template( + prompt="Hello helloHey", + input=[{"role": "user", "content": "Hey"}], + ) + == "Hello helloHey" + ) diff --git a/tests/unit/steps/tasks/magpie/__init__.py b/tests/unit/steps/tasks/magpie/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/steps/tasks/magpie/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py new file mode 100644 index 0000000000..77ed178f4c --- /dev/null +++ b/tests/unit/steps/tasks/magpie/test_base.py @@ -0,0 +1,269 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.llms.openai import OpenAILLM +from distilabel.steps.tasks.magpie.base import MAGPIE_MULTI_TURN_SYSTEM_PROMPT, Magpie + +from tests.unit.conftest import DummyMagpieLLM + + +class TestMagpie: + def test_raise_value_error_llm_no_magpie_mixin(self) -> None: + with pytest.raises( + ValueError, + match="`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`", + ): + Magpie(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore + + def test_outputs(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3")) + + assert task.outputs == ["conversation"] + + task = Magpie( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + assert task.outputs == ["instruction"] + + def test_process(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1) + + task.load() + + assert next(task.process(inputs=[{}, {}, {}])) == [ + { + "conversation": [ + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + ] + + def test_process_with_n_turns(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2) + + task.load() + + assert next(task.process(inputs=[{}, {}, {}])) == [ + { + "conversation": [ + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + ] + + def test_process_with_system_prompt_per_row(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2) + + task.load() + + assert next( + task.process( + inputs=[ + {"system_prompt": "You're a math expert assistant."}, + {"system_prompt": "You're a florist expert assistant."}, + {"system_prompt": "You're a plumber expert assistant."}, + ] + ) + ) == [ + { + "conversation": [ + {"role": "system", "content": "You're a math expert assistant."}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": "You're a florist expert assistant."}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": "You're a plumber expert assistant."}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + ] + + def test_process_only_instruction(self) -> None: + task = Magpie( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + task.load() + + assert next(task.process(inputs=[{}, {}, {}])) == [ + {"instruction": "Hello Magpie"}, + {"instruction": "Hello Magpie"}, + {"instruction": "Hello Magpie"}, + ] + + def test_serialization(self) -> None: + task = Magpie( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + assert task.dump() == { + "llm": { + "use_magpie_template": True, + "magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "generation_kwargs": {}, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyMagpieLLM", + }, + }, + "n_turns": 1, + "only_instruction": True, + "system_prompt": None, + "name": "magpie_0", + "resources": { + "replicas": 1, + "cpus": None, + "gpus": None, + "memory": None, + "resources": None, + }, + "input_mappings": {}, + "output_mappings": {}, + "input_batch_size": 50, + "group_generations": False, + "add_raw_output": True, + "num_generations": 1, + "runtime_parameters_info": [ + { + "name": "llm", + "runtime_parameters_info": [ + { + "name": "generation_kwargs", + "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.", + "keys": [{"name": "kwargs", "optional": False}], + } + ], + }, + { + "name": "n_turns", + "optional": True, + "description": "The number of turns to generate for the conversation.", + }, + { + "name": "only_instruction", + "optional": True, + "description": "Whether to generate only the instruction. If this argument is `True`, then `n_turns` will be ignored.", + }, + { + "name": "system_prompt", + "optional": True, + "description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.", + }, + { + "name": "resources", + "runtime_parameters_info": [ + { + "name": "replicas", + "optional": True, + "description": "The number of replicas for the step.", + }, + { + "name": "cpus", + "optional": True, + "description": "The number of CPUs assigned to each step replica.", + }, + { + "name": "gpus", + "optional": True, + "description": "The number of GPUs assigned to each step replica.", + }, + { + "name": "memory", + "optional": True, + "description": "The memory in bytes required for each step replica.", + }, + { + "name": "resources", + "optional": True, + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + }, + ], + }, + { + "name": "input_batch_size", + "optional": True, + "description": "The number of rows that will contain the batches processed by the step.", + }, + { + "name": "add_raw_output", + "optional": True, + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + }, + { + "name": "num_generations", + "optional": True, + "description": "The number of generations to be produced per input.", + }, + ], + "type_info": { + "module": "distilabel.steps.tasks.magpie.base", + "name": "Magpie", + }, + } diff --git a/tests/unit/steps/tasks/magpie/test_generator.py b/tests/unit/steps/tasks/magpie/test_generator.py new file mode 100644 index 0000000000..7ebb815e0d --- /dev/null +++ b/tests/unit/steps/tasks/magpie/test_generator.py @@ -0,0 +1,154 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.llms.openai import OpenAILLM +from distilabel.steps.tasks.magpie.generator import MagpieGenerator + +from tests.unit.conftest import DummyMagpieLLM + + +class TestMagpieGenerator: + def test_raise_value_error_llm_no_magpie_mixin(self) -> None: + with pytest.raises( + ValueError, + match="`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`", + ): + MagpieGenerator(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore + + def test_outputs(self) -> None: + task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3")) + + assert task.outputs == ["conversation"] + + task = MagpieGenerator( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + assert task.outputs == ["instruction"] + + def test_serialization(self) -> None: + task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3")) + + assert task.dump() == { + "llm": { + "use_magpie_template": True, + "magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "generation_kwargs": {}, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyMagpieLLM", + }, + }, + "n_turns": 1, + "only_instruction": False, + "system_prompt": None, + "name": "magpie_generator_0", + "resources": { + "replicas": 1, + "cpus": None, + "gpus": None, + "memory": None, + "resources": None, + }, + "input_mappings": {}, + "output_mappings": {}, + "batch_size": 50, + "group_generations": False, + "add_raw_output": True, + "num_generations": 1, + "num_rows": None, + "runtime_parameters_info": [ + { + "name": "llm", + "runtime_parameters_info": [ + { + "name": "generation_kwargs", + "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.", + "keys": [{"name": "kwargs", "optional": False}], + } + ], + }, + { + "name": "n_turns", + "optional": True, + "description": "The number of turns to generate for the conversation.", + }, + { + "name": "only_instruction", + "optional": True, + "description": "Whether to generate only the instruction. If this argument is `True`, then `n_turns` will be ignored.", + }, + { + "name": "system_prompt", + "optional": True, + "description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.", + }, + { + "name": "resources", + "runtime_parameters_info": [ + { + "name": "replicas", + "optional": True, + "description": "The number of replicas for the step.", + }, + { + "name": "cpus", + "optional": True, + "description": "The number of CPUs assigned to each step replica.", + }, + { + "name": "gpus", + "optional": True, + "description": "The number of GPUs assigned to each step replica.", + }, + { + "name": "memory", + "optional": True, + "description": "The memory in bytes required for each step replica.", + }, + { + "name": "resources", + "optional": True, + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + }, + ], + }, + { + "name": "batch_size", + "optional": True, + "description": "The number of rows that will contain the batches generated by the step.", + }, + { + "name": "add_raw_output", + "optional": True, + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + }, + { + "name": "num_generations", + "optional": True, + "description": "The number of generations to be produced per input.", + }, + { + "name": "num_rows", + "optional": False, + "description": "The number of rows to generate.", + }, + ], + "type_info": { + "module": "distilabel.steps.tasks.magpie.generator", + "name": "MagpieGenerator", + }, + } From b2e88050ea8f2cdbf436439337e45358403fb814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 15:26:21 +0200 Subject: [PATCH 18/30] Fix docstring --- src/distilabel/steps/tasks/magpie/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 7a974177fb..d5db459870 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -195,7 +195,7 @@ class Magpie(Task, MagpieBase): Runtime parameters: - `n_turns`: the number of turns that the generated conversation will have. - only_instruction: whether to generate only the instruction. If this argument is + - `only_instruction`: whether to generate only the instruction. If this argument is `True`, then `n_turns` will be ignored. Defaults to `False`. - `system_prompt`: an optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc. If the provided inputs From e52ae3f261b40d7ba44a030fec27626cc87296d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 15:29:02 +0200 Subject: [PATCH 19/30] Update docstrings --- src/distilabel/llms/huggingface/inference_endpoints.py | 6 ++++++ src/distilabel/llms/huggingface/transformers.py | 6 ++++++ src/distilabel/llms/mixins/magpie.py | 7 ++++--- src/distilabel/llms/vllm.py | 6 ++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index e2fb238599..ccd4e87c06 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -63,6 +63,12 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to `None`, but defining one is recommended to properly format the prompt. model_display_name: the model display name to use for the LLM. Defaults to `None`. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. Icon: `:hugging:` diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 5b702adbb8..86754e8ef1 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -65,6 +65,12 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): local configuration will be used. Defaults to `None`. structured_output: a dictionary containing the structured output configuration or if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. Icon: `:hugging:` diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py index ede1d1976a..8efa3add58 100644 --- a/src/distilabel/llms/mixins/magpie.py +++ b/src/distilabel/llms/mixins/magpie.py @@ -38,11 +38,12 @@ class MagpieChatTemplateMixin(BaseModel, validate_assignment=True): task. Attributes: - use_magpie_template: a flag used to enable/disable applying the pre-query template. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. magpie_pre_query_template: the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow up user message. Valid - values are "llama3", "qwen2", ... - or a pre-query template. + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. References: - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 130b3f76ec..d8e6100a13 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -78,6 +78,12 @@ class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): _tokenizer: the tokenizer instance used to format the prompt before passing it to the `LLM`. This attribute is meant to be used internally and should not be accessed directly. It will be set in the `load` method. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. References: - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py From 5736a257f8780e99f59946520bf7ecfd290bf953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 15:29:47 +0200 Subject: [PATCH 20/30] Apply suggestions from code review Co-authored-by: Agus --- src/distilabel/steps/tasks/magpie/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index d5db459870..9ecbfe2f59 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -31,7 +31,7 @@ MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( "You are a helpful Al assistant. The user will engage in a multi−round conversation" - " with you,asking initial questions and following up with additional related questions." + " with you, asking initial questions and following up with additional related questions." " Your goal is to provide thorough, relevant and insightful responses to help the user" " with their queries." ) @@ -177,7 +177,7 @@ class Magpie(Task, MagpieBase): fine-tuned LLMs. As they were fine-tuned using a chat template composed by a user message and a desired assistant output, the instruct fine-tuned LLM learns that after the pre-query or pre-instruct tokens comes an instruction. If these pre-query tokens are sent to the - LLM without any user message, then the LLM will continue generating tokens as it was + LLM without any user message, then the LLM will continue generating tokens as if it was the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. After this instruct is generated, it can be sent again to the LLM to generate this time an assistant response. This process can be repeated N times allowing to build a multi-turn From 32b1725f0145eb639664526dc2cefde4d9a3b3e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 11 Jul 2024 16:16:01 +0200 Subject: [PATCH 21/30] Update to `huggingface_hub >= 0.22.0` --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 263f1aaac2..7593c1c6f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ anthropic = ["anthropic >= 0.20.0"] argilla = ["argilla >= 1.29.0"] cohere = ["cohere >= 5.2.0"] groq = ["groq >= 0.4.1"] -hf-inference-endpoints = ["huggingface_hub >= 0.19.0"] +hf-inference-endpoints = ["huggingface_hub >= 0.22.0"] hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"] instructor = ["instructor >= 1.2.3"] litellm = ["litellm >= 1.30.0"] From e8991335457feebf9ba1b6188bc051f5783b2bdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 12 Jul 2024 12:12:43 +0200 Subject: [PATCH 22/30] Add generation with `chat_completion` --- .../llms/huggingface/inference_endpoints.py | 181 ++++++++++++++---- 1 file changed, 149 insertions(+), 32 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index ccd4e87c06..a71d353599 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -26,7 +26,7 @@ model_validator, validate_call, ) -from typing_extensions import override +from typing_extensions import Annotated, override from distilabel.llms.base import AsyncLLM from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin @@ -300,7 +300,7 @@ def prepare_input(self, input: "StandardInput") -> str: ) return super().apply_magpie_pre_query_template(prompt, input) - def get_structured_output( + def _get_structured_output( self, input: FormattedInput ) -> Union[Dict[str, Any], None]: """Gets the structured output (if any) for the given input. @@ -318,16 +318,16 @@ def get_structured_output( if isinstance(input, tuple): input, structured_output = input structured_output = { - "type": structured_output["format"], - "value": structured_output["schema"], + "type": structured_output["format"], # type: ignore + "value": structured_output["schema"], # type: ignore } # Same structured output for all the inputs if structured_output is None and self.structured_output is not None: try: structured_output = { - "type": self.structured_output["format"], - "value": self.structured_output["schema"], + "type": self.structured_output["format"], # type: ignore + "value": self.structured_output["schema"], # type: ignore } except KeyError as e: raise ValueError( @@ -337,12 +337,125 @@ def get_structured_output( return structured_output + async def _generate_with_text_generation( + self, + input: FormattedInput, + max_new_tokens: int = 128, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + temperature: float = 1.0, + do_sample: bool = False, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + stop_sequences: Union[List[str], None] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + watermark: bool = False, + ) -> Union[str, None]: + structured_output = self._get_structured_output(input) + + completion = None + try: + completion = await self._aclient.text_generation( # type: ignore + prompt=self.prepare_input(input), # type: ignore + max_new_tokens=max_new_tokens, + do_sample=do_sample, + typical_p=typical_p, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + watermark=watermark, + grammar=structured_output, # type: ignore + # NOTE: here to ensure that the cache is not used and a different response is + # generated every time + seed=seed or random.randint(0, sys.maxsize), + ) + except Exception as e: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." + f" Finish reason was: {e}" + ) + return completion + + async def _generate_with_chat_completion( + self, + input: "StandardInput", + max_new_tokens: int = 128, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + stop_sequences: Union[List[str], None] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + watermark: bool = False, + ) -> Union[str, None]: + message = None + try: + completion = await self._aclient.chat_completion( # type: ignore + messages=input, # type: ignore + max_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + temperature=temperature, + top_p=top_p, + # NOTE: here to ensure that the cache is not used and a different response is + # generated every time + seed=seed or random.randint(0, sys.maxsize), + ) + choice = completion.choices[0] + if (message := choice.message.content) is None: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." + f" Finish reason was: {choice.finish_reason}" + ) + except Exception as e: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." + f" Finish reason was: {e}" + ) + return message + + def _check_stop_sequences( + self, + stop_sequences: Optional[Union[str, List[str]]] = None, + ) -> Union[List[str], None]: + """Checks that no more than 4 stop sequences are provided. + + Args: + stop_sequences: the stop sequences to be checked. + + Returns: + The stop sequences. + """ + if stop_sequences is not None: + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] + if len(stop_sequences) > 4: + warnings.warn( + "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.", + UserWarning, + stacklevel=2, + ) + stop_sequences = stop_sequences[:4] + return stop_sequences + @validate_call async def agenerate( # type: ignore self, input: FormattedInput, max_new_tokens: int = 128, repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, + logit_bias: Optional[List[float]] = None, temperature: float = 1.0, do_sample: bool = False, top_k: Optional[int] = None, @@ -361,6 +474,11 @@ async def agenerate( # type: ignore Defaults to `128`. repetition_penalty: the repetition penalty to use for the generation. Defaults to `None`. + frequence_penalty: a value between `-2.0` and `2.0`. Positive values penalize + new tokens based on their existing frequency in the text so far, decreasing + model's likelihood to repeat the same line verbatim. Defauls to `None`. + logit_bias: modify the likelihood of specified tokens appearing in the completion. + Defaults to `None`. temperature: the temperature to use for the generation. Defaults to `1.0`. do_sample: whether to use sampling for the generation. Defaults to `False`. top_k: the top-k value to use for the generation. Defaults to `0.8`, since neither @@ -379,42 +497,41 @@ async def agenerate( # type: ignore Returns: A list of lists of strings containing the generated responses for each input. """ - if stop_sequences is not None: - if isinstance(stop_sequences, str): - stop_sequences = [stop_sequences] - if len(stop_sequences) > 4: - warnings.warn( - "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.", - UserWarning, - stacklevel=2, + stop_sequences = self._check_stop_sequences(stop_sequences) + + if self.tokenizer_id is None: + return [ + await self._generate_with_chat_completion( + input=input, # type: ignore + max_new_tokens=max_new_tokens, + typical_p=typical_p, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + seed=seed, + watermark=watermark, ) - stop_sequences = stop_sequences[:4] - - structured_output = self.get_structured_output(input) + ] - completion = None - try: - completion = await self._aclient.text_generation( # type: ignore - prompt=self.prepare_input(input), # type: ignore + return [ + await self._generate_with_text_generation( + input=input, max_new_tokens=max_new_tokens, do_sample=do_sample, typical_p=typical_p, repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, temperature=temperature, top_p=top_p, top_k=top_k, stop_sequences=stop_sequences, return_full_text=return_full_text, + seed=seed, watermark=watermark, - grammar=structured_output, # type: ignore - # NOTE: here to ensure that the cache is not used and a different response is - # generated every time - seed=seed or random.randint(0, sys.maxsize), ) - except Exception as e: - self._logger.warning( # type: ignore - f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." - f" Finish reason was: {e}" - ) - - return [completion] + ] From 87371ec9d0d7ce0f62d910b22a5d4dca0f85e673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 11:29:21 +0200 Subject: [PATCH 23/30] Update `agenerate` arguments --- .../llms/huggingface/inference_endpoints.py | 109 ++++++++++++------ 1 file changed, 74 insertions(+), 35 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index a71d353599..84ab32b93f 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -386,17 +386,16 @@ async def _generate_with_chat_completion( self, input: "StandardInput", max_new_tokens: int = 128, - repetition_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, temperature: float = 1.0, - top_k: Optional[int] = None, + tool_choice: Optional[Dict[str, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, top_p: Optional[float] = None, - typical_p: Optional[float] = None, - stop_sequences: Union[List[str], None] = None, - return_full_text: bool = False, - seed: Optional[int] = None, - watermark: bool = False, ) -> Union[str, None]: message = None try: @@ -405,11 +404,16 @@ async def _generate_with_chat_completion( max_tokens=max_new_tokens, frequency_penalty=frequency_penalty, logit_bias=logit_bias, - temperature=temperature, - top_p=top_p, + presence_penalty=presence_penalty, # NOTE: here to ensure that the cache is not used and a different response is # generated every time seed=seed or random.randint(0, sys.maxsize), + stop=stop_sequences, + temperature=temperature, + tool_choice=tool_choice, # type: ignore + tool_prompt=tool_prompt, + tools=tools, # type: ignore + top_p=top_p, ) choice = completion.choices[0] if (message := choice.message.content) is None: @@ -453,46 +457,82 @@ async def agenerate( # type: ignore self, input: FormattedInput, max_new_tokens: int = 128, - repetition_penalty: Optional[float] = None, frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, logit_bias: Optional[List[float]] = None, + presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, temperature: float = 1.0, + tool_choice: Optional[Dict[str, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_p: Optional[float] = None, do_sample: bool = False, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, top_k: Optional[int] = None, - top_p: Optional[float] = None, typical_p: Optional[float] = None, - stop_sequences: Optional[Union[str, List[str]]] = None, - return_full_text: bool = False, - seed: Optional[int] = None, watermark: bool = False, ) -> GenerateOutput: - """Generates completions for the given input using the async client. + """Generates completions for the given input using the async client. This method + uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`. + `chat_completion` method will be used only if no `tokenizer_id` has been specified. + Some arguments of this function are specific to the `text_generation` method, while + some others are specific to the `chat_completion` method. Args: input: a single input in chat format to generate responses for. max_new_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`. - repetition_penalty: the repetition penalty to use for the generation. Defaults - to `None`. frequence_penalty: a value between `-2.0` and `2.0`. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing model's likelihood to repeat the same line verbatim. Defauls to `None`. logit_bias: modify the likelihood of specified tokens appearing in the completion. + This argument is exclusive to the `chat_completion` method and will be used + only if `tokenizer_id` is `None`. Defaults to `None`. - temperature: the temperature to use for the generation. Defaults to `1.0`. - do_sample: whether to use sampling for the generation. Defaults to `False`. - top_k: the top-k value to use for the generation. Defaults to `0.8`, since neither - `0.0` nor `1.0` are valid values in TGI. - top_p: the top-p value to use for the generation. Defaults to `1.0`. - typical_p: the typical-p value to use for the generation. Defaults to `0.5`. + presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize + new tokens based on whether they appear in the text so far, increasing the + model likelihood to talk about new topics. This argument is exclusive to + the `chat_completion` method and will be used only if `tokenizer_id` is + `None`. Defauls to `None`. + seed: the seed to use for the generation. Defaults to `None`. stop_sequences: either a single string or a list of strings containing the sequences to stop the generation at. Defaults to `None`, but will be set to the `tokenizer.eos_token` if available. - return_full_text: whether to return the full text of the completion or just the - generated text. Defaults to `False`, meaning that only the generated text will be - returned. - seed: the seed to use for the generation. Defaults to `None`. - watermark: whether to add the watermark to the generated text. Defaults to `None`. + temperature: the temperature to use for the generation. Defaults to `1.0`. + tool_choice: the name of the tool the model should call. It can be a dictionary + like `{"function_name": "my_tool"}`. If not provided, then the model will + automatically choose which tool to use. This argument is exclusive to the + `chat_completion` method and will be used only if `tokenizer_id` is `None`. + Defaults to `None`. + tool_prompt: A prompt to be appended before the tools. This argument is exclusive + to the `chat_completion` method and will be used only if `tokenizer_id` + is `None`. Defauls to `None`. + tools: a list of tools definitions that the LLM can use. + This argument is exclusive to the `chat_completion` method and will be used + only if `tokenizer_id` is `None`. Defaults to `None`. + top_p: the top-p value to use for the generation. Defaults to `1.0`. + do_sample: whether to use sampling for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` is not + `None`. Defaults to `False`. + repetition_penalty: the repetition penalty to use for the generation. This argument + is exclusive of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + return_full_text: whether to return the full text of the completion or just + the generated text. Defaults to `False`, meaning that only the generated + text will be returned. This argument is exclusive of the `text_generation` + method and will be only used if `tokenizer_id` is not `None`. + top_k: the top-k value to use for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `0.8`, since neither `0.0` nor `1.0` are valid + values in TGI. + typical_p: the typical-p value to use for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + watermark: whether to add the watermark to the generated text. This argument + is exclusive of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. Returns: A list of lists of strings containing the generated responses for each input. @@ -504,17 +544,16 @@ async def agenerate( # type: ignore await self._generate_with_chat_completion( input=input, # type: ignore max_new_tokens=max_new_tokens, - typical_p=typical_p, - repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, + presence_penalty=presence_penalty, + seed=seed, + stop_sequences=stop_sequences, temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, top_p=top_p, - top_k=top_k, - stop_sequences=stop_sequences, - return_full_text=return_full_text, - seed=seed, - watermark=watermark, ) ] From bf350e314f9b16ce1245a63baf85486ec413213d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 11:32:45 +0200 Subject: [PATCH 24/30] Update unit tests --- src/distilabel/llms/huggingface/inference_endpoints.py | 4 ++-- tests/unit/llms/huggingface/test_inference_endpoints.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 84ab32b93f..e11df895be 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -369,11 +369,11 @@ async def _generate_with_text_generation( top_k=top_k, stop_sequences=stop_sequences, return_full_text=return_full_text, - watermark=watermark, - grammar=structured_output, # type: ignore # NOTE: here to ensure that the cache is not used and a different response is # generated every time seed=seed or random.randint(0, sys.maxsize), + watermark=watermark, + grammar=structured_output, # type: ignore ) except Exception as e: self._logger.warning( # type: ignore diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index d1bb67cff5..4ffbb72263 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -199,16 +199,17 @@ async def test_agenerate_with_structured_output( "do_sample": False, "typical_p": None, "repetition_penalty": None, + "frequency_penalty": None, "temperature": 1.0, "top_p": None, "top_k": None, "stop_sequences": None, "return_full_text": False, + "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` "watermark": False, "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, - "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` } - llm._aclient.text_generation.assert_called_with(**kwargs) + llm._aclient.text_generation.assert_called_with(**kwargs) # type: ignore def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( From 360433f907deb04bcad8a78a2a9db517d27e2ae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 13:13:09 +0200 Subject: [PATCH 25/30] Fix `tools` were not being used --- .../llms/huggingface/inference_endpoints.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index e11df895be..4cd4c7ca65 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -16,7 +16,7 @@ import random import sys import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from pydantic import ( Field, @@ -171,9 +171,6 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( " or dedicated inference endpoints, respectively." ) - if self.model_id and self.tokenizer_id is None: - self.tokenizer_id = self.model_id - if self.use_magpie_template and self.tokenizer_id is None: raise ValueError( "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please," @@ -392,7 +389,7 @@ async def _generate_with_chat_completion( seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, temperature: float = 1.0, - tool_choice: Optional[Dict[str, str]] = None, + tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, top_p: Optional[float] = None, @@ -463,7 +460,7 @@ async def agenerate( # type: ignore seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, temperature: float = 1.0, - tool_choice: Optional[Dict[str, str]] = None, + tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, top_p: Optional[float] = None, @@ -502,10 +499,9 @@ async def agenerate( # type: ignore `tokenizer.eos_token` if available. temperature: the temperature to use for the generation. Defaults to `1.0`. tool_choice: the name of the tool the model should call. It can be a dictionary - like `{"function_name": "my_tool"}`. If not provided, then the model will - automatically choose which tool to use. This argument is exclusive to the - `chat_completion` method and will be used only if `tokenizer_id` is `None`. - Defaults to `None`. + like `{"function_name": "my_tool"}` or "auto". If not provided, then the + model won't use any tool. This argument is exclusive to the `chat_completion` + method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. tool_prompt: A prompt to be appended before the tools. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defauls to `None`. From ef68210e699b9f9484e0512194b58e8f679632ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 13:37:57 +0200 Subject: [PATCH 26/30] Update unit tests --- src/distilabel/llms/azure.py | 2 +- .../llms/huggingface/inference_endpoints.py | 31 +++++++++++++++++ .../huggingface/test_inference_endpoints.py | 33 ++++++++++++++++--- 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/src/distilabel/llms/azure.py b/src/distilabel/llms/azure.py index ebcb5ef9ea..80c0807572 100644 --- a/src/distilabel/llms/azure.py +++ b/src/distilabel/llms/azure.py @@ -45,7 +45,7 @@ class AzureOpenAILLM(OpenAILLM): `None` if not set. Icon: - `:simple-microsoftazure:` + `:material-microsoft-azure:` Examples: diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 4cd4c7ca65..90c0eac16c 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -69,6 +69,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): sent to the LLM to generate an instruction or a follow up user message. Valid values are "llama3", "qwen2" or another pre-query template provided. Defaults to `None`. + structured_output: Icon: `:hugging:` @@ -119,6 +120,29 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) ``` + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import InferenceEndpointsLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + api_key="api.key", + structured_output={"format": "json", "schema": User.model_json_schema()} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` """ model_id: Optional[str] = None @@ -177,6 +201,13 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( " set a `tokenizer_id` and try again." ) + if ( + self.model_id + and self.tokenizer_id is None + and self.structured_output is not None + ): + self.tokenizer_id = self.model_id + if self.base_url and not (self.model_id or self.endpoint_name): return self diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 4ffbb72263..5f7f70f834 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -44,11 +44,30 @@ def test_no_tokenizer_magpie_raise_value_error( magpie_pre_query_template="llama3", ) - def test_tokenizer_id_set_if_model_id( + def test_tokenizer_id_set_if_model_id_and_structured_output( self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", + structured_output={ # type: ignore + "title": "MMORPG Character", + "type": "object", + "properties": { + "name": {"type": "string", "description": "Character's name"}, + "level": { + "type": "integer", + "minimum": 1, + "maximum": 100, + "description": "Character's level", + }, + "health": { + "type": "integer", + "minimum": 1, + "description": "Character's current health", + }, + }, + "required": ["name", "level", "health"], + }, ) assert llm.tokenizer_id == llm.model_id @@ -118,11 +137,12 @@ def test_dedicated_inference_endpoints_llm_via_url( ) @pytest.mark.asyncio - async def test_agenerate_via_inference_client( + async def test_agenerate_with_text_generation( self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) llm.load() @@ -140,11 +160,12 @@ async def test_agenerate_via_inference_client( ) == [" Aenean hendrerit aliquam velit. ..."] @pytest.mark.asyncio - async def test_generate_via_inference_client( + async def test_generate_with_text_generation( self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) llm.load() @@ -172,6 +193,7 @@ async def test_agenerate_with_structured_output( ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, ) llm.load() @@ -214,6 +236,7 @@ async def test_agenerate_with_structured_output( def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) _dump = { From 9ee70966a84d85b99e8bd2a216220a74a560e1b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 13:59:18 +0200 Subject: [PATCH 27/30] Fix list of tuples instead of list of list --- src/distilabel/llms/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index 2a64e77847..07fba6788d 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -329,7 +329,10 @@ async def _agenerate( for _ in range(num_generations) ] outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)] - return list(grouper(outputs, n=num_generations, incomplete="ignore")) + return [ + list(group) + for group in grouper(outputs, n=num_generations, incomplete="ignore") + ] def generate( self, From 3863eb3837502f61576643b2f3bc62ccb20b51a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 14:00:58 +0200 Subject: [PATCH 28/30] Add missing docstring --- src/distilabel/llms/huggingface/inference_endpoints.py | 6 ++++-- src/distilabel/utils/itertools.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 90c0eac16c..3ae794e453 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -69,7 +69,9 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): sent to the LLM to generate an instruction or a follow up user message. Valid values are "llama3", "qwen2" or another pre-query template provided. Defaults to `None`. - structured_output: + structured_output: a dictionary containing the structured output configuration or + if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. + Defaults to None. Icon: `:hugging:` @@ -141,7 +143,7 @@ class User(BaseModel): llm.load() - output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]]) ``` """ diff --git a/src/distilabel/utils/itertools.py b/src/distilabel/utils/itertools.py index 88ce86cc4e..2555f3b262 100644 --- a/src/distilabel/utils/itertools.py +++ b/src/distilabel/utils/itertools.py @@ -13,7 +13,7 @@ # limitations under the License. from itertools import zip_longest -from typing import Any, Iterable, List, Literal, TypeVar +from typing import Any, Iterable, Literal, Tuple, TypeVar T = TypeVar("T") @@ -26,7 +26,7 @@ def grouper( *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: Any = None, -) -> Iterable[List[T]]: +) -> Iterable[Tuple[T]]: "Collect data into non-overlapping fixed-length chunks or blocks." # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError From 87c11cc04e448fac869f8909562e5412aff38700 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 14:55:24 +0200 Subject: [PATCH 29/30] Add `chat_completion` unit tests --- .../huggingface/test_inference_endpoints.py | 110 ++++++++++++++---- 1 file changed, 89 insertions(+), 21 deletions(-) diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 5f7f70f834..436815b0e5 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -21,6 +21,12 @@ import nest_asyncio import pytest from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM +from huggingface_hub import ( + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputMessage, + ChatCompletionOutputUsage, +) @pytest.fixture(autouse=True) @@ -49,25 +55,7 @@ def test_tokenizer_id_set_if_model_id_and_structured_output( ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - structured_output={ # type: ignore - "title": "MMORPG Character", - "type": "object", - "properties": { - "name": {"type": "string", "description": "Character's name"}, - "level": { - "type": "integer", - "minimum": 1, - "maximum": 100, - "description": "Character's level", - }, - "health": { - "type": "integer", - "minimum": 1, - "description": "Character's current health", - }, - }, - "required": ["name", "level", "health"], - }, + structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, ) assert llm.tokenizer_id == llm.model_id @@ -160,9 +148,89 @@ async def test_agenerate_with_text_generation( ) == [" Aenean hendrerit aliquam velit. ..."] @pytest.mark.asyncio - async def test_generate_with_text_generation( + async def test_agenerate_with_chat_completion( + self, mock_inference_client: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + ) + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="length", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=" Aenean hendrerit aliquam velit. ...", + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + object="chat.completion", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) + ) + + assert await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) == [" Aenean hendrerit aliquam velit. ..."] + + @pytest.mark.asyncio + async def test_agenerate_with_chat_completion_fails( self, mock_inference_client: MagicMock ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + ) + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + object="chat.completion", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) + ) + + assert await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) == [None] + + @pytest.mark.asyncio + async def test_generate(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", tokenizer_id="distilabel-internal-testing/tiny-random-mistral", @@ -185,7 +253,7 @@ async def test_generate_with_text_generation( }, ] ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] + ) == [[" Aenean hendrerit aliquam velit. ..."]] @pytest.mark.asyncio async def test_agenerate_with_structured_output( From cd3cc5d679d26beff54e9dd6a75c2a1612838611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Jul 2024 15:04:25 +0200 Subject: [PATCH 30/30] Fix `GroqLLM.generate` unit test after updating `_agenerate` --- tests/unit/llms/test_groq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index 7607ab2cb2..8789b56a6f 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -104,7 +104,7 @@ async def test_generate(self, mock_groq: MagicMock) -> None: }, ] ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] + ) == [[" Aenean hendrerit aliquam velit. ..."]] @pytest.mark.parametrize( "structured_output, dump",