diff --git a/pyproject.toml b/pyproject.toml index 263f1aaac2..7593c1c6f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ anthropic = ["anthropic >= 0.20.0"] argilla = ["argilla >= 1.29.0"] cohere = ["cohere >= 5.2.0"] groq = ["groq >= 0.4.1"] -hf-inference-endpoints = ["huggingface_hub >= 0.19.0"] +hf-inference-endpoints = ["huggingface_hub >= 0.22.0"] hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"] instructor = ["instructor >= 1.2.3"] litellm = ["litellm >= 1.30.0"] diff --git a/src/distilabel/llms/__init__.py b/src/distilabel/llms/__init__.py index 3e50ddefaa..5f6ab9abbf 100644 --- a/src/distilabel/llms/__init__.py +++ b/src/distilabel/llms/__init__.py @@ -22,7 +22,7 @@ from distilabel.llms.litellm import LiteLLM from distilabel.llms.llamacpp import LlamaCppLLM from distilabel.llms.mistral import MistralLLM -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.moa import MixtureOfAgentsLLM from distilabel.llms.ollama import OllamaLLM from distilabel.llms.openai import OpenAILLM diff --git a/src/distilabel/llms/azure.py b/src/distilabel/llms/azure.py index ebcb5ef9ea..80c0807572 100644 --- a/src/distilabel/llms/azure.py +++ b/src/distilabel/llms/azure.py @@ -45,7 +45,7 @@ class AzureOpenAILLM(OpenAILLM): `None` if not set. Icon: - `:simple-microsoftazure:` + `:material-microsoft-azure:` Examples: diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index 2a64e77847..07fba6788d 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -329,7 +329,10 @@ async def _agenerate( for _ in range(num_generations) ] outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)] - return list(grouper(outputs, n=num_generations, incomplete="ignore")) + return [ + list(group) + for group in grouper(outputs, n=num_generations, incomplete="ignore") + ] def generate( self, diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 6d4d3d1a5e..3ae794e453 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -14,8 +14,9 @@ import os import random +import sys import warnings -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from pydantic import ( Field, @@ -25,9 +26,10 @@ model_validator, validate_call, ) -from typing_extensions import override +from typing_extensions import Annotated, override from distilabel.llms.base import AsyncLLM +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( @@ -42,15 +44,13 @@ if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient - from openai import AsyncOpenAI from transformers import PreTrainedTokenizer -class InferenceEndpointsLLM(AsyncLLM): +class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): """InferenceEndpoints LLM implementation running the async API client. - This LLM will internally use `huggingface_hub.AsyncInferenceClient` or `openai.AsyncOpenAI` - depending on the `use_openai_client` attribute. + This LLM will internally use `huggingface_hub.AsyncInferenceClient`. 
Attributes: model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which @@ -63,7 +63,15 @@ class InferenceEndpointsLLM(AsyncLLM): tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to `None`, but defining one is recommended to properly format the prompt. model_display_name: the model display name to use for the LLM. Defaults to `None`. - use_openai_client: whether to use the OpenAI client instead of the Hugging Face client. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. + structured_output: a dictionary containing the structured output configuration or + if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. + Defaults to None. Icon: `:hugging:` @@ -114,6 +122,29 @@ class InferenceEndpointsLLM(AsyncLLM): output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) ``` + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import InferenceEndpointsLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + api_key="api.key", + structured_output={"format": "json", "schema": User.model_json_schema()} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]]) + ``` """ model_id: Optional[str] = None @@ -137,7 +168,6 @@ class InferenceEndpointsLLM(AsyncLLM): tokenizer_id: Optional[str] = None model_display_name: Optional[str] = None - use_openai_client: bool = False structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( default=None, @@ -149,7 +179,7 @@ class InferenceEndpointsLLM(AsyncLLM): _model_name: Optional[str] = PrivateAttr(default=None) _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) - _aclient: Optional[Union["AsyncInferenceClient", "AsyncOpenAI"]] = PrivateAttr(...) + _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) @model_validator(mode="after") # type: ignore def only_one_of_model_id_endpoint_name_or_base_url_provided( @@ -161,11 +191,25 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( if self.base_url and (self.model_id or self.endpoint_name): self._logger.warning( # type: ignore - f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`" - " is also provided, the `base_url` will either be ignored or overwritten with the one generated" - " from either of those args, for serverless or dedicated inference endpoints, respectively." + f"Since the `base_url={self.base_url}` is available and either one of `model_id`" + " or `endpoint_name` is also provided, the `base_url` will either be ignored" + " or overwritten with the one generated from either of those args, for serverless" + " or dedicated inference endpoints, respectively." + ) + + if self.use_magpie_template and self.tokenizer_id is None: + raise ValueError( + "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. 
Please," + " set a `tokenizer_id` and try again." ) + if ( + self.model_id + and self.tokenizer_id is None + and self.structured_output is not None + ): + self.tokenizer_id = self.model_id + if self.base_url and not (self.model_id or self.endpoint_name): return self @@ -176,19 +220,16 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( return self raise ValidationError( - "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too," - " it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name}," - f" and `base_url`={self.base_url}." + f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is" + f" provided too, it will be overwritten instead. Found `model_id`={self.model_id}," + f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." ) def load(self) -> None: # noqa: C901 - """Loads the either the `AsyncInferenceClient` or the `AsyncOpenAI` client to benefit - from async requests, running the Hugging Face Inference Endpoint underneath via the - `/v1/chat/completions` endpoint, exposed for the models running on TGI using the - `text-generation` task. + """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference + Endpoint. Raises: - ImportError: if the `openai` Python client is not installed. ImportError: if the `huggingface-hub` Python client is not installed. ValueError: if the model is not currently deployed or is not running the TGI framework. ImportError: if the `transformers` Python client is not installed. @@ -234,31 +275,16 @@ def load(self) -> None: # noqa: C901 ) if client.status in ["paused", "scaledToZero"]: client.resume().wait(timeout=300) - elif client.status in ["initializing"]: + elif client.status == "initializing": client.wait(timeout=300) self.base_url = client.url self._model_name = client.repository - if self.use_openai_client: - try: - from openai import AsyncOpenAI - except ImportError as ie: - raise ImportError( - "OpenAI Python client is not installed. Please install it using" - " `pip install openai`." - ) from ie - - self._aclient = AsyncOpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=6, - ) - else: - self._aclient = AsyncInferenceClient( - model=self.base_url, - token=self.api_key.get_secret_value(), - ) + self._aclient = AsyncInferenceClient( + model=self.base_url, + token=self.api_key.get_secret_value(), + ) if self.tokenizer_id: try: @@ -283,113 +309,55 @@ def model_name(self) -> Union[str, None]: # type: ignore or self.base_url ) - async def _openai_agenerate( - self, - input: "StandardInput", - max_new_tokens: int = 128, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - temperature: float = 1.0, - top_p: Optional[float] = None, - stop: Optional[Union[str, List[str]]] = None, - ) -> GenerateOutput: - """Generates completions for the given input using the OpenAI async client.""" - completion = await self._aclient.chat.completions.create( # type: ignore - messages=input, # type: ignore - model="tgi", - max_tokens=max_new_tokens, - n=1, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop, - timeout=50, - ) - if completion.choices[0].message.content is None: - self._logger.warning( # type: ignore - f"⚠️ Received no response using OpenAI client (model: '{self.model_name}')." 
- f" Finish reason was: {completion.choices[0].finish_reason}" + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + prompt: str = ( + self._tokenizer.apply_chat_template( # type: ignore + conversation=input, # type: ignore + tokenize=False, + add_generation_prompt=True, ) - return [completion.choices[0].message.content] + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) - @validate_call - async def agenerate( # type: ignore - self, - input: FormattedInput, - max_new_tokens: int = 128, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repetition_penalty: Optional[float] = None, - temperature: float = 1.0, - do_sample: bool = False, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - stop_sequences: Optional[Union[str, List[str]]] = None, - return_full_text: bool = False, - seed: Optional[int] = None, - watermark: bool = False, - ) -> GenerateOutput: - """Generates completions for the given input using the OpenAI async client. + def _get_structured_output( + self, input: FormattedInput + ) -> Union[Dict[str, Any], None]: + """Gets the structured output (if any) for the given input. Args: input: a single input in chat format to generate responses for. - max_new_tokens: the maximum number of new tokens that the model will generate. - Defaults to `128`. - frequency_penalty: the repetition penalty to use for the generation. Defaults - to `0.0`. Only applies if `use_openai_client=True`. - presence_penalty: the presence penalty to use for the generation. Defaults to - `0.0`. Only applies if `use_openai_client=True`. - repetition_penalty: the repetition penalty to use for the generation. Defaults - to `None`. Only applies if `use_openai_client=False`. - temperature: the temperature to use for the generation. Defaults to `1.0`. - do_sample: whether to use sampling for the generation. Defaults to `False`. - Only applies if `use_openai_client=False`. - top_k: the top-k value to use for the generation. Defaults to `0.8`, since neither - `0.0` nor `1.0` are valid values in TGI. - top_p: the top-p value to use for the generation. Defaults to `1.0`. - typical_p: the typical-p value to use for the generation. Defaults to `0.5`. - stop_sequences: either a single string or a list of strings containing the sequences - to stop the generation at. Defaults to `None`, but will be set to the - `tokenizer.eos_token` if available. - return_full_text: whether to return the full text of the completion or just the - generated text. Defaults to `False`, meaning that only the generated text will be - returned. - seed: the seed to use for the generation. Defaults to `None`. - watermark: whether to add the watermark to the generated text. Defaults to `None`. Returns: - A list of lists of strings containing the generated responses for each input. + The structured output that will be passed as `grammer` to the inference endpoint + or `None` if not required. 
""" - if stop_sequences is not None: - if isinstance(stop_sequences, str): - stop_sequences = [stop_sequences] - if len(stop_sequences) > 4: - warnings.warn( - "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.", - UserWarning, - stacklevel=2, - ) - stop_sequences = stop_sequences[:4] - structured_output = None + + # Specific structured output per input if isinstance(input, tuple): input, structured_output = input structured_output = { - "type": structured_output["format"], - "value": structured_output["schema"], + "type": structured_output["format"], # type: ignore + "value": structured_output["schema"], # type: ignore } - # NOTE: `self.structured_output` applies to all the generations, while `structured_output` i.e. the - # value included within the tuple provided as `input` to this method, is intended to be different per - # each input, so those should not be used together. Meaning that it should be either provided at attribute - # level i.e. self, or via a column within each input i.e. row. + # Same structured output for all the inputs if structured_output is None and self.structured_output is not None: try: structured_output = { - "type": self.structured_output["format"], - "value": self.structured_output["schema"], + "type": self.structured_output["format"], # type: ignore + "value": self.structured_output["schema"], # type: ignore } except KeyError as e: raise ValueError( @@ -397,50 +365,241 @@ async def agenerate( # type: ignore "the `structured_output` attribute." ) from e - if self.use_openai_client: - return await self._openai_agenerate( - input=input, - max_new_tokens=max_new_tokens, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop_sequences, - ) + return structured_output - if self._tokenizer is not None: - prompt = self._tokenizer.apply_chat_template( # type: ignore - conversation=input, # type: ignore - tokenize=False, - add_generation_prompt=True, - ) - else: - # TODO: should we apply a default chat template here instead? e.g. ChatML - prompt = "\n".join([message["content"] for message in input]) + async def _generate_with_text_generation( + self, + input: FormattedInput, + max_new_tokens: int = 128, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + temperature: float = 1.0, + do_sample: bool = False, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + stop_sequences: Union[List[str], None] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + watermark: bool = False, + ) -> Union[str, None]: + structured_output = self._get_structured_output(input) completion = None try: completion = await self._aclient.text_generation( # type: ignore - prompt=prompt, # type: ignore + prompt=self.prepare_input(input), # type: ignore max_new_tokens=max_new_tokens, do_sample=do_sample, typical_p=typical_p, repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, temperature=temperature, top_p=top_p, top_k=top_k, stop_sequences=stop_sequences, return_full_text=return_full_text, + # NOTE: here to ensure that the cache is not used and a different response is + # generated every time + seed=seed or random.randint(0, sys.maxsize), watermark=watermark, grammar=structured_output, # type: ignore + ) + except Exception as e: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." 
+ f" Finish reason was: {e}" + ) + return completion + + async def _generate_with_chat_completion( + self, + input: "StandardInput", + max_new_tokens: int = 128, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: float = 1.0, + tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_p: Optional[float] = None, + ) -> Union[str, None]: + message = None + try: + completion = await self._aclient.chat_completion( # type: ignore + messages=input, # type: ignore + max_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, # NOTE: here to ensure that the cache is not used and a different response is # generated every time - seed=seed or random.randint(0, 2147483647), + seed=seed or random.randint(0, sys.maxsize), + stop=stop_sequences, + temperature=temperature, + tool_choice=tool_choice, # type: ignore + tool_prompt=tool_prompt, + tools=tools, # type: ignore + top_p=top_p, ) + choice = completion.choices[0] + if (message := choice.message.content) is None: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." + f" Finish reason was: {choice.finish_reason}" + ) except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) + return message - return [completion] + def _check_stop_sequences( + self, + stop_sequences: Optional[Union[str, List[str]]] = None, + ) -> Union[List[str], None]: + """Checks that no more than 4 stop sequences are provided. + + Args: + stop_sequences: the stop sequences to be checked. + + Returns: + The stop sequences. + """ + if stop_sequences is not None: + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] + if len(stop_sequences) > 4: + warnings.warn( + "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.", + UserWarning, + stacklevel=2, + ) + stop_sequences = stop_sequences[:4] + return stop_sequences + + @validate_call + async def agenerate( # type: ignore + self, + input: FormattedInput, + max_new_tokens: int = 128, + frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, + logit_bias: Optional[List[float]] = None, + presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: float = 1.0, + tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_p: Optional[float] = None, + do_sample: bool = False, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + top_k: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + ) -> GenerateOutput: + """Generates completions for the given input using the async client. This method + uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`. + `chat_completion` method will be used only if no `tokenizer_id` has been specified. 
+ Some arguments of this function are specific to the `text_generation` method, while + some others are specific to the `chat_completion` method. + + Args: + input: a single input in chat format to generate responses for. + max_new_tokens: the maximum number of new tokens that the model will generate. + Defaults to `128`. + frequence_penalty: a value between `-2.0` and `2.0`. Positive values penalize + new tokens based on their existing frequency in the text so far, decreasing + model's likelihood to repeat the same line verbatim. Defauls to `None`. + logit_bias: modify the likelihood of specified tokens appearing in the completion. + This argument is exclusive to the `chat_completion` method and will be used + only if `tokenizer_id` is `None`. + Defaults to `None`. + presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize + new tokens based on whether they appear in the text so far, increasing the + model likelihood to talk about new topics. This argument is exclusive to + the `chat_completion` method and will be used only if `tokenizer_id` is + `None`. Defauls to `None`. + seed: the seed to use for the generation. Defaults to `None`. + stop_sequences: either a single string or a list of strings containing the sequences + to stop the generation at. Defaults to `None`, but will be set to the + `tokenizer.eos_token` if available. + temperature: the temperature to use for the generation. Defaults to `1.0`. + tool_choice: the name of the tool the model should call. It can be a dictionary + like `{"function_name": "my_tool"}` or "auto". If not provided, then the + model won't use any tool. This argument is exclusive to the `chat_completion` + method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. + tool_prompt: A prompt to be appended before the tools. This argument is exclusive + to the `chat_completion` method and will be used only if `tokenizer_id` + is `None`. Defauls to `None`. + tools: a list of tools definitions that the LLM can use. + This argument is exclusive to the `chat_completion` method and will be used + only if `tokenizer_id` is `None`. Defaults to `None`. + top_p: the top-p value to use for the generation. Defaults to `1.0`. + do_sample: whether to use sampling for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` is not + `None`. Defaults to `False`. + repetition_penalty: the repetition penalty to use for the generation. This argument + is exclusive of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + return_full_text: whether to return the full text of the completion or just + the generated text. Defaults to `False`, meaning that only the generated + text will be returned. This argument is exclusive of the `text_generation` + method and will be only used if `tokenizer_id` is not `None`. + top_k: the top-k value to use for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `0.8`, since neither `0.0` nor `1.0` are valid + values in TGI. + typical_p: the typical-p value to use for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + watermark: whether to add the watermark to the generated text. This argument + is exclusive of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. 
Defaults to `None`. + + Returns: + A list of lists of strings containing the generated responses for each input. + """ + stop_sequences = self._check_stop_sequences(stop_sequences) + + if self.tokenizer_id is None: + return [ + await self._generate_with_chat_completion( + input=input, # type: ignore + max_new_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, + seed=seed, + stop_sequences=stop_sequences, + temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, + top_p=top_p, + ) + ] + + return [ + await self._generate_with_text_generation( + input=input, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + typical_p=typical_p, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + seed=seed, + watermark=watermark, + ) + ] diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 6e7736d006..86754e8ef1 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -19,7 +19,8 @@ from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput @@ -32,7 +33,7 @@ from distilabel.llms.typing import HiddenState -class TransformersLLM(LLM, CudaDevicePlacementMixin): +class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): """Hugging Face `transformers` library LLM implementation using the text generation pipeline. @@ -64,6 +65,12 @@ class TransformersLLM(LLM, CudaDevicePlacementMixin): local configuration will be used. Defaults to `None`. structured_output: a dictionary containing the structured output configuration or if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. Icon: `:hugging:` @@ -157,14 +164,25 @@ def model_name(self) -> str: return self.model def prepare_input(self, input: "StandardInput") -> str: - """Prepares the input by applying the chat template to the input, which is formatted - as an OpenAI conversation, and adding the generation prompt. + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. 
""" - return self._pipeline.tokenizer.apply_chat_template( # type: ignore - input, # type: ignore - tokenize=False, - add_generation_prompt=True, + prompt: str = ( + self._pipeline.tokenizer.apply_chat_template( # type: ignore + input, # type: ignore + tokenize=False, + add_generation_prompt=True, + ) + if input + else "" ) + return super().apply_magpie_pre_query_template(prompt, input) @validate_call def generate( # type: ignore @@ -209,6 +227,7 @@ def generate( # type: ignore do_sample=do_sample, num_return_sequences=num_generations, prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, + pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore ) return [ [generation["generated_text"] for generation in output] diff --git a/src/distilabel/llms/mixins/__init__.py b/src/distilabel/llms/mixins/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/llms/mixins/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/llms/mixins.py b/src/distilabel/llms/mixins/cuda_device_placement.py similarity index 100% rename from src/distilabel/llms/mixins.py rename to src/distilabel/llms/mixins/cuda_device_placement.py diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/llms/mixins/magpie.py new file mode 100644 index 0000000000..8efa3add58 --- /dev/null +++ b/src/distilabel/llms/mixins/magpie.py @@ -0,0 +1,89 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Literal, Union + +from pydantic import BaseModel, field_validator, model_validator +from typing_extensions import Self + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import StandardInput + +MagpieAvailablePreQueryTemplates = Literal["llama3", "qwen2"] +"""The available predefined pre-query templates.""" + +MAGPIE_PRE_QUERY_TEMPLATES: Dict[MagpieAvailablePreQueryTemplates, str] = { + "llama3": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "qwen2": "<|im_start|>user\n", +} + + +class MagpieChatTemplateMixin(BaseModel, validate_assignment=True): + """A simple mixin that adds the required logic to apply the pre-query template that + allows to an instruct fine-tuned LLM to generate user instructions as described in + the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'. 
+ + This mixin is meant to be used in combination with the [Magpie][distilabel.steps.tasks.magpie.base.Magpie] + task. + + Attributes: + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. + + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + """ + + use_magpie_template: bool = False + magpie_pre_query_template: Union[MagpieAvailablePreQueryTemplates, str, None] = None + + @field_validator("magpie_pre_query_template") + @classmethod + def magpie_pre_query_template_validator(cls, value: str) -> str: + """Resolves the pre-query template alias if it exists, otherwise, returns the + value with no modification.""" + if value in MAGPIE_PRE_QUERY_TEMPLATES: + return MAGPIE_PRE_QUERY_TEMPLATES[value] + return value + + @model_validator(mode="after") + def use_magpie_template_validation(self) -> Self: + """Checks that there is a pre-query template set if Magpie is going to be used.""" + if self.use_magpie_template and self.magpie_pre_query_template is None: + raise ValueError( + f"Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is" + f" `None`. To use Magpie with `{self.__class__.__name__}` you need to set" + f" the `magpie_pre_query_template` attribute." + ) + return self + + def apply_magpie_pre_query_template( + self, prompt: str, input: "StandardInput" + ) -> str: + """Applies the pre-query template to the prompt if Magpie is going to be used. + + Args: + prompt: the prompt to which the pre-query template has to be applied. + input: the list with the chat items that were used to generate the prompt. + + Returns: + The prompt with the pre-query template applied if needed. + """ + if not self.use_magpie_template or (input and input[-1]["role"] == "user"): + return prompt + return prompt + self.magpie_pre_query_template # type: ignore diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index ee124c7dfa..d8e6100a13 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -30,7 +30,8 @@ from distilabel.llms.base import LLM from distilabel.llms.chat_templates import CHATML_TEMPLATE -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -39,11 +40,13 @@ from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM + from distilabel.steps.tasks.typing import StandardInput + SamplingParams = None -class vLLM(LLM, CudaDevicePlacementMixin): +class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): """`vLLM` library LLM implementation. Attributes: @@ -75,6 +78,12 @@ class vLLM(LLM, CudaDevicePlacementMixin): _tokenizer: the tokenizer instance used to format the prompt before passing it to the `LLM`. This attribute is meant to be used internally and should not be accessed directly. It will be set in the `load` method. 
+ use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. References: - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py @@ -213,15 +222,26 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def prepare_input(self, input: "FormattedInput") -> str: - """Prepares the input by applying the chat template to the input, which is formatted - as an OpenAI conversation, and adding the generation prompt. + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. """ - return self._tokenizer.apply_chat_template( # type: ignore - input, # type: ignore - tokenize=False, - add_generation_prompt=True, # type: ignore + prompt: str = ( + self._tokenizer.apply_chat_template( # type: ignore + input, # type: ignore + tokenize=False, + add_generation_prompt=True, # type: ignore + ) + if input + else "" ) + return super().apply_magpie_pre_query_template(prompt, input) def _prepare_batches( self, inputs: List[FormattedInput] @@ -304,14 +324,13 @@ def generate( # type: ignore if extra_sampling_params is None: extra_sampling_params = {} structured_output = None - needs_sorting = False if isinstance(inputs[0], tuple): prepared_batches, sorted_indices = self._prepare_batches(inputs) - needs_sorting = True else: # Simulate a batch without the structured output content prepared_batches = [([self.prepare_input(input) for input in inputs], None)] + sorted_indices = None # In case we have a single structured output for the dataset, we can logits_processors = None @@ -348,7 +367,7 @@ def generate( # type: ignore # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) - if needs_sorting: + if sorted_indices is not None: batched_outputs = _sort_batches( batched_outputs, sorted_indices, num_generations=num_generations ) diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py index 29c2e3e11c..3befd5187d 100644 --- a/src/distilabel/pipeline/step_wrapper.py +++ b/src/distilabel/pipeline/step_wrapper.py @@ -16,7 +16,7 @@ from queue import Queue from typing import Any, Dict, List, Optional, Union, cast -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.pipeline.batch import _Batch from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG from distilabel.pipeline.typing import StepLoadStatus diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index b2456d7824..0b3a69596b 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -35,6 +35,8 @@ from distilabel.steps.tasks.instruction_backtranslation import ( InstructionBacktranslation, ) +from distilabel.steps.tasks.magpie.base import Magpie +from distilabel.steps.tasks.magpie.generator import MagpieGenerator from distilabel.steps.tasks.pair_rm import PairRM from 
distilabel.steps.tasks.prometheus_eval import PrometheusEval from distilabel.steps.tasks.quality_scorer import QualityScorer @@ -64,6 +66,8 @@ "GenerateTextRetrievalData", "MonolingualTripletGenerator", "InstructionBacktranslation", + "Magpie", + "MagpieGenerator", "PairRM", "PrometheusEval", "QualityScorer", diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index bc8c5d2eff..0ab8b9ab9b 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -280,6 +280,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore instructions = [] mutation_no = 0 + # TODO: update to take into account `offset` iter_no = 0 while len(instructions) < self.num_instructions: prompts = self._apply_random_mutation(iter_no=iter_no) diff --git a/src/distilabel/steps/tasks/magpie/__init__.py b/src/distilabel/steps/tasks/magpie/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/steps/tasks/magpie/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py new file mode 100644 index 0000000000..9ecbfe2f59 --- /dev/null +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -0,0 +1,375 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from pydantic import Field, PositiveInt + +from distilabel.llms.base import LLM +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin +from distilabel.mixins.runtime_parameters import ( + RuntimeParameter, + RuntimeParametersMixin, +) +from distilabel.steps.base import StepInput +from distilabel.steps.tasks.base import Task + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType, FormattedInput + from distilabel.steps.typing import StepOutput + +MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( + "You are a helpful Al assistant. The user will engage in a multi−round conversation" + " with you, asking initial questions and following up with additional related questions." + " Your goal is to provide thorough, relevant and insightful responses to help the user" + " with their queries." +) + + +class MagpieBase(RuntimeParametersMixin): + """Base class defining the generation logic of Magpie method. 
+ + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + """ + + llm: LLM + + n_turns: RuntimeParameter[PositiveInt] = Field( + default=1, + description="The number of turns to generate for the conversation.", + ) + only_instruction: RuntimeParameter[bool] = Field( + default=False, + description="Whether to generate only the instruction. If this argument" + " is `True`, then `n_turns` will be ignored.", + ) + system_prompt: Optional[RuntimeParameter[str]] = Field( + default=None, + description="An optional system prompt that can be used to steer the LLM to generate" + " content of certain topic, guide the style, etc.", + ) + + def _prepare_inputs_for_instruction_generation( + self, inputs: List[Dict[str, Any]] + ) -> List["FormattedInput"]: + """Prepares the inputs adding the system (if required) prompt provided in each row, + or if the conversations to generate have more than one turn, then adding the system + prompt for multi-turn conversation from the paper. + + Args: + inputs: the inputs to prepare. + + Returns: + The prepared inputs. + """ + prepared_inputs = [] + for input in inputs: + conversation = [] + if "system_prompt" in input: + conversation.append( + {"role": "system", "content": input["system_prompt"]} + ) + elif self.system_prompt is not None: + conversation.append({"role": "system", "content": self.system_prompt}) + elif self.n_turns > 1: # type: ignore + conversation.append( + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT} + ) + + prepared_inputs.append(conversation) + + return prepared_inputs + + def _append_messages_to_conversations( + self, role: str, messages: List[str], conversations: List["ChatType"] + ) -> List["ChatType"]: + """Appends the outputs generated by the LLM with the specified role to the conversations. + + Args: + role: the role to assign to the message to be appended. + messages: the list of messages generated by the LLM for each conversation. + conversations: the list of conversations to which the messages will be appended. + + Returns: + The updated conversations. + """ + for instruction, conversation in zip(messages, conversations): + conversation.append({"role": role, "content": instruction}) + return conversations + + def _generate_multi_turn_conversation( + self, inputs: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + conversations = self._prepare_inputs_for_instruction_generation(inputs) + + for _ in range(self.n_turns): # type: ignore + # Generate instruction or user message + outputs = self.llm.generate( + inputs=conversations, + num_generations=1, + **self.llm.generation_kwargs, # type: ignore + ) + + conversations = self._append_messages_to_conversations( + role="user", + messages=[output[0] for output in outputs], + conversations=conversations, # type: ignore + ) + + # TODO: handle potential previous `None`s + + # Generate response + outputs = self.llm.generate( + inputs=conversations, + num_generations=1, + **self.llm.generation_kwargs, # type: ignore + ) + + conversations = self._append_messages_to_conversations( + role="assistant", + messages=[output[0] for output in outputs], + conversations=conversations, # type: ignore + ) + + return [{"conversation": conversation} for conversation in conversations] + + def _generate_with_pre_query_template( + self, inputs: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Generate a list of instructions or conversations of the specified number of turns. 
+ + Args: + inputs: a list of dictionaries that can contain a `system_prompt` key. + + Returns: + The list of generated conversations. + """ + + if self.only_instruction: + prepared_inputs = self._prepare_inputs_for_instruction_generation(inputs) + outputs = self.llm.generate( + inputs=prepared_inputs, + num_generations=1, + **self.llm.generation_kwargs, # type: ignore + ) + return [{"instruction": output[0]} for output in outputs] + + return self._generate_multi_turn_conversation(inputs) + + +class Magpie(Task, MagpieBase): + """Generates conversations using an instruct fine-tuned LLM. + + Magpie is a neat method that allows generating user instructions with no seed data + or specific system prompt thanks to the autoregressive capabilities of the instruct + fine-tuned LLMs. As they were fine-tuned using a chat template composed by a user message + and a desired assistant output, the instruct fine-tuned LLM learns that after the pre-query + or pre-instruct tokens comes an instruction. If these pre-query tokens are sent to the + LLM without any user message, then the LLM will continue generating tokens as if it was + the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. + After this instruct is generated, it can be sent again to the LLM to generate this time + an assistant response. This process can be repeated N times allowing to build a multi-turn + conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from + Scratch by Prompting Aligned LLMs with Nothing'. + + Attributes: + n_turns: the number of turns that the generated conversation will have. + only_instruction: whether to generate only the instruction. If this argument is + `True`, then `n_turns` will be ignored. Defaults to `False`. + system_prompt: an optional system prompt that can be used to steer the LLM to generate + content of certain topic, guide the style, etc. If the provided inputs contains + a `system_prompt` column, then this runtime parameter will be ignored and the + one from the column will be used. Defaults to `None`. + + Runtime parameters: + - `n_turns`: the number of turns that the generated conversation will have. + - `only_instruction`: whether to generate only the instruction. If this argument is + `True`, then `n_turns` will be ignored. Defaults to `False`. + - `system_prompt`: an optional system prompt that can be used to steer the LLM to + generate content of certain topic, guide the style, etc. If the provided inputs + contains a `system_prompt` column, then this runtime parameter will be ignored + and the one from the column will be used. Defaults to `None`. + + Input columns: + - system_prompt (`str`, optional): an optional system prompt that can be provided + to guide the generation of the instruct LLM and steer it to generate instructions + of certain topic. + + Output columns: + - conversation (`ChatType`): the generated conversation which is a list of chat + items with a role and a message. Only if `only_instructions=False`. + - instruction (`str`): the generated instructions if `only_instruction=True`. 
+ + Categories: + - text-generation + - instruction + + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + + Examples: + + Generating instructions with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import Magpie + + magpie = Magpie( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 64, + }, + device="mps", + ), + only_instruction=True, + ) + + magpie.load() + + result = next( + magpie.process( + inputs=[ + { + "system_prompt": "You're a math expert AI assistant that helps students of secondary school to solve calculus problems." + }, + { + "system_prompt": "You're an expert florist AI assistant that helps user to erradicate pests in their crops." + }, + ] + ) + ) + # [ + # {'instruction': "That's me! I'd love some help with solving calculus problems! What kind of calculation are you most effective at? Linear Algebra, derivatives, integrals, optimization?"}, + # {'instruction': 'I was wondering if there are certain flowers and plants that can be used for pest control?'} + # ] + ``` + + Generating conversations with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import Magpie + + magpie = Magpie( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 256, + }, + device="mps", + ), + n_turns=2, + ) + + magpie.load() + + result = next( + magpie.process( + inputs=[ + { + "system_prompt": "You're a math expert AI assistant that helps students of secondary school to solve calculus problems." + }, + { + "system_prompt": "You're an expert florist AI assistant that helps user to erradicate pests in their crops." + }, + ] + ) + ) + # [ + # { + # 'conversation': [ + # {'role': 'system', 'content': "You're a math expert AI assistant that helps students of secondary school to solve calculus problems."}, + # { + # 'role': 'user', + # 'content': 'I\'m having trouble solving the limits of functions in calculus. Could you explain how to work with them? Limits of functions are denoted by lim x→a f(x) or lim x→a [f(x)]. It is read as "the limit as x approaches a of f + # of x".' + # }, + # { + # 'role': 'assistant', + # 'content': 'Limits are indeed a fundamental concept in calculus, and understanding them can be a bit tricky at first, but don\'t worry, I\'m here to help! The notation lim x→a f(x) indeed means "the limit as x approaches a of f of + # x". What it\'s asking us to do is find the' + # } + # ] + # }, + # { + # 'conversation': [ + # {'role': 'system', 'content': "You're an expert florist AI assistant that helps user to erradicate pests in their crops."}, + # { + # 'role': 'user', + # 'content': "As a flower shop owner, I'm noticing some unusual worm-like creatures causing damage to my roses and other flowers. Can you help me identify what the problem is? Based on your expertise as a florist AI assistant, I think it + # might be pests or diseases, but I'm not sure which." + # }, + # { + # 'role': 'assistant', + # 'content': "I'd be delighted to help you investigate the issue! Since you've noticed worm-like creatures damaging your roses and other flowers, I'll take a closer look at the possibilities. 
Here are a few potential culprits: 1. + # **Aphids**: These small, soft-bodied insects can secrete a sticky substance called" + # } + # ] + # } + # ] + ``` + """ + + def model_post_init(self, __context: Any) -> None: + """Checks that the provided `LLM` uses the `MagpieChatTemplateMixin`.""" + super().model_post_init(__context) + + if not isinstance(self.llm, MagpieChatTemplateMixin): + raise ValueError( + f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`." + f"`{self.llm.__class__.__name__}` doesn't use the aforementioned mixin." + ) + + self.llm.use_magpie_template = True + + @property + def inputs(self) -> List[str]: + return [] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + """Does nothing.""" + return [] + + @property + def outputs(self) -> List[str]: + """Either a multi-turn conversation or the instruction generated.""" + if self.only_instruction: + return ["instruction"] + return ["conversation"] + + def format_output( + self, + output: Union[str, None], + input: Union[Dict[str, Any], None] = None, + ) -> Dict[str, Any]: + """Does nothing.""" + return {} + + def process(self, inputs: StepInput) -> "StepOutput": + """Generate a list of instructions or conversations of the specified number of turns. + + Args: + inputs: a list of dictionaries that can contain a `system_prompt` key. + + Yields: + The list of generated conversations. + """ + yield self._generate_with_pre_query_template(inputs) diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py new file mode 100644 index 0000000000..8d9dca96e5 --- /dev/null +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -0,0 +1,240 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from pydantic import Field + +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps.tasks.base import GeneratorTask +from distilabel.steps.tasks.magpie.base import MagpieBase + +if TYPE_CHECKING: + from distilabel.steps.typing import GeneratorStepOutput + + +class MagpieGenerator(GeneratorTask, MagpieBase): + """Generator task the generates instructions or conversations using Magpie. + + Magpie is a neat method that allows generating user instructions with no seed data + or specific system prompt thanks to the autoregressive capabilities of the instruct + fine-tuned LLMs. As they were fine-tuned using a chat template composed by a user message + and a desired assistant output, the instruct fine-tuned LLM learns that after the pre-query + or pre-instruct tokens comes an instruction. If these pre-query tokens are sent to the + LLM without any user message, then the LLM will continue generating tokens as it was + the user. This trick allows "extracting" instructions from the instruct fine-tuned LLM. 
+ After this instruct is generated, it can be sent again to the LLM to generate this time + an assistant response. This process can be repeated N times allowing to build a multi-turn + conversation. This method was described in the paper 'Magpie: Alignment Data Synthesis from + Scratch by Prompting Aligned LLMs with Nothing'. + + Attributes: + n_turns: the number of turns that the generated conversation will have. + only_instruction: whether to generate only the instruction. If this argument is + `True`, then `n_turns` will be ignored. Defaults to `False`. + system_prompt: an optional system prompt that can be used to steer the LLM to generate + content of certain topic, guide the style, etc. If the provided inputs contains + a `system_prompt` column, then this runtime parameter will be ignored and the + one from the column will be used. Defaults to `None`. + num_rows: the number of rows to be generated. + + Runtime parameters: + - `n_turns`: the number of turns that the generated conversation will have. + - `only_instruction`: whether to generate only the instruction. If this argument + is `True`, then `n_turns` will be ignored. Defaults to `False`. + - `system_prompt`: an optional system prompt that can be used to steer the LLM to + generate content of certain topic, guide the style, etc. Defaults to `None`. + - `num_rows`: the number of rows to be generated. + + Output columns: + - conversation (`ChatType`): the generated conversation which is a list of chat + items with a role and a message. + - instruction (`str`): the generated instructions if `only_instruction=True`. + + Categories: + - text-generation + - instruction + - generator + + References: + - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464) + + Examples: + + Generating instructions with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import MagpieGenerator + + generator = MagpieGenerator( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 256, + }, + device="mps", + ), + only_instruction=True, + num_rows=5, + ) + + generator.load() + + result = next(generator.process()) + # ( + # [ + # {"instruction": "I've just bought a new phone and I're excited to start using it."}, + # {"instruction": "What are the most common types of companies that use digital signage?"} + # ], + # True + # ) + ``` + + Generating a conversation with Llama 3 8B Instruct and TransformersLLM: + + ```python + from distilabel.llms import TransformersLLM + from distilabel.steps.tasks import MagpieGenerator + + generator = MagpieGenerator( + llm=TransformersLLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + magpie_pre_query_template="llama3", + generation_kwargs={ + "temperature": 1.0, + "max_new_tokens": 64, + }, + device="mps", + ), + n_turns=3, + num_rows=5, + ) + + generator.load() + + result = next(generator.process()) + # ( + # [ + # { + # 'conversation': [ + # { + # 'role': 'system', + # 'content': 'You are a helpful Al assistant. The user will engage in a multi−round conversation with you,asking initial questions and following up with additional related questions. Your goal is to provide thorough, relevant and + # insightful responses to help the user with their queries.' 
+ # }, + # {'role': 'user', 'content': "I'm considering starting a social media campaign for my small business and I're not sure where to start. Can you help?"}, + # { + # 'role': 'assistant', + # 'content': "Exciting endeavor! Creating a social media campaign can be a great way to increase brand awareness, drive website traffic, and ultimately boost sales. I'd be happy to guide you through the process. To get started, + # let's break down the basics. First, we need to identify your goals and target audience. What do" + # }, + # { + # 'role': 'user', + # 'content': "Before I start a social media campaign, what kind of costs ammol should I expect to pay? There are several factors that contribute to the total cost of running a social media campaign. Let me outline some of the main + # expenses you might encounter: 1. Time: As the business owner, you'll likely spend time creating" + # }, + # { + # 'role': 'assistant', + # 'content': 'Time is indeed one of the biggest investments when it comes to running a social media campaign! Besides time, you may also incur costs associated with: 2. Content creation: You might need to hire freelancers or + # agencies to create high-quality content (images, videos, captions) for your social media platforms. 3. Advertising' + # } + # ] + # }, + # { + # 'conversation': [ + # { + # 'role': 'system', + # 'content': 'You are a helpful Al assistant. The user will engage in a multi−round conversation with you,asking initial questions and following up with additional related questions. Your goal is to provide thorough, relevant and + # insightful responses to help the user with their queries.' + # }, + # {'role': 'user', 'content': "I am thinking of buying a new laptop or computer. What are some important factors I should consider when making your decision? I'll make sure to let you know if any other favorites or needs come up!"}, + # { + # 'role': 'assistant', + # 'content': 'Exciting times ahead! When considering a new laptop or computer, there are several key factors to think about to ensure you find the right one for your needs. Here are some crucial ones to get you started: 1. + # **Purpose**: How will you use your laptop or computer? For work, gaming, video editing,' + # }, + # { + # 'role': 'user', + # 'content': 'Let me stop you there. Let\'s explore this "purpose" factor that you mentioned earlier. Can you elaborate more on what type of devices would be suitable for different purposes? For example, if I\'re primarily using my + # laptop for general usage like browsing, email, and word processing, would a budget-friendly laptop be sufficient' + # }, + # { + # 'role': 'assistant', + # 'content': "Understanding your purpose can greatly impact the type of device you'll need. **General Usage (Browsing, Email, Word Processing)**: For casual users who mainly use their laptop for daily tasks, a budget-friendly + # option can be sufficient. Look for laptops with: * Intel Core i3 or i5 processor* " + # } + # ] + # } + # ], + # True + # ) + ``` + """ + + # TODO: move this to `GeneratorTask` + num_rows: RuntimeParameter[int] = Field( + default=None, description="The number of rows to generate." + ) + + def model_post_init(self, __context: Any) -> None: + """Checks that the provided `LLM` uses the `MagpieChatTemplateMixin`.""" + super().model_post_init(__context) + + if not isinstance(self.llm, MagpieChatTemplateMixin): + raise ValueError( + f"`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`." 
+                f" `{self.llm.__class__.__name__}` doesn't use the aforementioned mixin."
+            )
+
+        self.llm.use_magpie_template = True
+
+    def format_output(
+        self,
+        output: Union[str, None],
+        input: Union[Dict[str, Any], None] = None,
+    ) -> Dict[str, Any]:
+        """Does nothing."""
+        return {}
+
+    @property
+    def outputs(self) -> List[str]:
+        """Either a multi-turn conversation or the instruction generated."""
+        if self.only_instruction:
+            return ["instruction"]
+        return ["conversation"]
+
+    def process(self, offset: int = 0) -> "GeneratorStepOutput":
+        """Generates the desired number of instructions or conversations using Magpie.
+
+        Args:
+            offset: The offset to start the generation from. Defaults to `0`.
+
+        Yields:
+            The generated instructions or conversations.
+        """
+        generated = offset
+
+        while generated < self.num_rows:  # type: ignore
+            # Generate at most `batch_size` rows per batch and never more than the
+            # remaining `num_rows`, so the last batch is flagged correctly.
+            rows_to_generate = min(self.num_rows - generated, self.batch_size)  # type: ignore
+            conversations = self._generate_with_pre_query_template(
+                inputs=[{} for _ in range(rows_to_generate)]  # type: ignore
+            )
+            generated += rows_to_generate  # type: ignore
+            yield (conversations, generated >= self.num_rows)  # type: ignore
diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py
index 4f92cdc057..ae9fd9519e 100644
--- a/src/distilabel/steps/tasks/typing.py
+++ b/src/distilabel/steps/tasks/typing.py
@@ -38,7 +38,7 @@ class OutlinesStructuredOutputType(TypedDict, total=False):
     as obtained from `model_to_schema(BaseModel)`, if "regex", it should be a regex
     pattern as a string.
     """
-    whitespace_pattern: Optional[Union[str, List[str]]] = None
+    whitespace_pattern: Optional[Union[str, List[str]]]
     """If "json" corresponds to a string or a list of strings with a pattern (doesn't
     impact string literals). For example, to allow only a single space or newline with
diff --git a/src/distilabel/utils/itertools.py b/src/distilabel/utils/itertools.py
index 88ce86cc4e..2555f3b262 100644
--- a/src/distilabel/utils/itertools.py
+++ b/src/distilabel/utils/itertools.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from itertools import zip_longest
-from typing import Any, Iterable, List, Literal, TypeVar
+from typing import Any, Iterable, Literal, Tuple, TypeVar
 
 T = TypeVar("T")
 
@@ -26,7 +26,7 @@ def grouper(
     *,
     incomplete: Literal["fill", "strict", "ignore"] = "fill",
     fillvalue: Any = None,
-) -> Iterable[List[T]]:
+) -> Iterable[Tuple[T]]:
     "Collect data into non-overlapping fixed-length chunks or blocks."
     # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
     # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index bbe6ca1ed4..adcd690276 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List import pytest -from distilabel.llms.base import AsyncLLM +from distilabel.llms.base import LLM, AsyncLLM +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin if TYPE_CHECKING: from distilabel.llms.typing import GenerateOutput @@ -37,6 +38,22 @@ async def agenerate( return ["output" for _ in range(num_generations)] +class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + def generate( + self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any + ) -> List["GenerateOutput"]: + return [ + ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs)) + ] + + @pytest.fixture def dummy_llm() -> AsyncLLM: return DummyLLM() diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index ecc5d97596..436815b0e5 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -12,21 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random +from typing import Generator from unittest import mock -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import nest_asyncio import pytest from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM +from huggingface_hub import ( + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputMessage, + ChatCompletionOutputUsage, +) + + +@pytest.fixture(autouse=True) +def mock_hf_token_env_variable() -> Generator[None, None, None]: + with patch.dict(os.environ, {"HF_TOKEN": "hf_token"}): + yield @patch("huggingface_hub.AsyncInferenceClient") -@patch("openai.AsyncOpenAI") class TestInferenceEndpointsLLM: - def test_load_no_api_key( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + def test_no_tokenizer_magpie_raise_value_error( + self, mock_inference_client: MagicMock + ) -> None: + with pytest.raises( + ValueError, + match="`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`", + ): + InferenceEndpointsLLM( + base_url="http://localhost:8000", + use_magpie_template=True, + magpie_pre_query_template="llama3", + ) + + def test_tokenizer_id_set_if_model_id_and_structured_output( + self, mock_inference_client: MagicMock ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, + ) + + assert llm.tokenizer_id == llm.model_id + + def test_load_no_api_key(self, mock_inference_client: MagicMock) -> None: + del os.environ["HF_TOKEN"] + llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" ) @@ -40,12 +76,8 @@ def test_load_no_api_key( ): llm.load() - def test_load_with_cached_token( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: - llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" - ) + def test_load_with_cached_token(self, mock_inference_client: MagicMock) -> None: + llm = InferenceEndpointsLLM(base_url="http://localhost:8000") # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist with ( @@ -58,7 +90,7 @@ def test_load_with_cached_token( llm.load() def 
test_serverless_inference_endpoints_llm( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" @@ -68,7 +100,7 @@ def test_serverless_inference_endpoints_llm( assert llm.model_name == "distilabel-internal-testing/tiny-random-mistral" def test_dedicated_inference_endpoints_llm( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( endpoint_name="tiny-random-mistral", @@ -79,11 +111,12 @@ def test_dedicated_inference_endpoints_llm( assert llm.model_name == "tiny-random-mistral" def test_dedicated_inference_endpoints_llm_via_url( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( base_url="https://api-inference.huggingface.co/models/distilabel-internal-testing/tiny-random-mistral" ) + llm.load() assert isinstance(llm, InferenceEndpointsLLM) assert ( @@ -92,13 +125,14 @@ def test_dedicated_inference_endpoints_llm_via_url( ) @pytest.mark.asyncio - async def test_agenerate_via_inference_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + async def test_agenerate_with_text_generation( + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -114,23 +148,39 @@ async def test_agenerate_via_inference_client( ) == [" Aenean hendrerit aliquam velit. ..."] @pytest.mark.asyncio - async def test_agenerate_via_openai_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + async def test_agenerate_with_chat_completion( + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - use_openai_client=True, ) - llm._aclient = mock_openai_client - - mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="length", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=" Aenean hendrerit aliquam velit. ...", + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + object="chat.completion", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) assert await llm.agenerate( input=[ - {"role": "system", "content": ""}, { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", @@ -139,48 +189,58 @@ async def test_agenerate_via_openai_client( ) == [" Aenean hendrerit aliquam velit. 
..."] @pytest.mark.asyncio - async def test_generate_via_inference_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + async def test_agenerate_with_chat_completion_fails( + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_inference_client - - llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + object="chat.completion", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) ) - nest_asyncio.apply() - - assert llm.generate( - inputs=[ - [ - {"role": "system", "content": ""}, - { - "role": "user", - "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - }, - ] + assert await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] + ) == [None] @pytest.mark.asyncio - async def test_generate_via_openai_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: + async def test_generate(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - use_openai_client=True, + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_openai_client + llm.load() - mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + llm._aclient.text_generation = AsyncMock( + return_value=" Aenean hendrerit aliquam velit. ..." ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) - ... nest_asyncio.apply() assert llm.generate( @@ -193,17 +253,18 @@ async def test_generate_via_openai_client( }, ] ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] + ) == [[" Aenean hendrerit aliquam velit. ..."]] @pytest.mark.asyncio async def test_agenerate_with_structured_output( - self, mock_inference_client: MagicMock, _: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -223,29 +284,27 @@ async def test_agenerate_with_structured_output( ) == [" Aenean hendrerit aliquam velit. ..."] kwargs = { - "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "prompt": "[INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
[/INST]", "max_new_tokens": 128, "do_sample": False, "typical_p": None, "repetition_penalty": None, + "frequency_penalty": None, "temperature": 1.0, "top_p": None, "top_k": None, "stop_sequences": None, "return_full_text": False, + "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` "watermark": False, "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, - "seed": 478163327, # pre-computed random value with `random.seed(42)` } - mock_inference_client.text_generation.assert_called_with(**kwargs) + llm._aclient.text_generation.assert_called_with(**kwargs) # type: ignore - def test_serialization( - self, - mock_inference_client: MagicMock, - mock_openai_client: MagicMock, - ) -> None: + def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) _dump = { @@ -253,11 +312,12 @@ def test_serialization( "endpoint_name": None, "endpoint_namespace": None, "base_url": None, - "tokenizer_id": None, + "tokenizer_id": "distilabel-internal-testing/tiny-random-mistral", "generation_kwargs": {}, + "magpie_pre_query_template": None, "structured_output": None, "model_display_name": None, - "use_openai_client": False, + "use_magpie_template": False, "type_info": { "module": "distilabel.llms.huggingface.inference_endpoints", "name": "InferenceEndpointsLLM", diff --git a/tests/unit/llms/mixins/__init__.py b/tests/unit/llms/mixins/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/llms/mixins/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/llms/test_mixins.py b/tests/unit/llms/mixins/test_cuda_device_placement.py similarity index 97% rename from tests/unit/llms/test_mixins.py rename to tests/unit/llms/mixins/test_cuda_device_placement.py index c0c7b10671..80690bbf41 100644 --- a/tests/unit/llms/test_mixins.py +++ b/tests/unit/llms/mixins/test_cuda_device_placement.py @@ -19,7 +19,7 @@ import pytest from distilabel.llms.base import LLM -from distilabel.llms.mixins import CudaDevicePlacementMixin +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin if TYPE_CHECKING: from distilabel.steps.tasks.typing import ChatType diff --git a/tests/unit/llms/mixins/test_magpie.py b/tests/unit/llms/mixins/test_magpie.py new file mode 100644 index 0000000000..bc7503fb2c --- /dev/null +++ b/tests/unit/llms/mixins/test_magpie.py @@ -0,0 +1,60 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.llms.mixins.magpie import MAGPIE_PRE_QUERY_TEMPLATES + +from tests.unit.conftest import DummyMagpieLLM + + +class TestMagpieChatTemplateMixin: + def test_magpie_pre_query_template_set(self) -> None: + with pytest.raises( + ValueError, + match="Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is `None`", + ): + DummyMagpieLLM(use_magpie_template=True) + + def test_magpie_pre_query_template_alias_resolved(self) -> None: + llm = DummyMagpieLLM(magpie_pre_query_template="llama3") + assert llm.magpie_pre_query_template == MAGPIE_PRE_QUERY_TEMPLATES["llama3"] + + def test_apply_magpie_pre_query_template(self) -> None: + llm = DummyMagpieLLM(magpie_pre_query_template="") + + assert ( + llm.apply_magpie_pre_query_template( + prompt="Hello hello", input=[] + ) + == "Hello hello" + ) + + llm = DummyMagpieLLM( + use_magpie_template=True, magpie_pre_query_template="" + ) + + assert ( + llm.apply_magpie_pre_query_template( + prompt="Hello hello", input=[] + ) + == "Hello hello" + ) + + assert ( + llm.apply_magpie_pre_query_template( + prompt="Hello helloHey", + input=[{"role": "user", "content": "Hey"}], + ) + == "Hello helloHey" + ) diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index 7607ab2cb2..8789b56a6f 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -104,7 +104,7 @@ async def test_generate(self, mock_groq: MagicMock) -> None: }, ] ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] + ) == [[" Aenean hendrerit aliquam velit. ..."]] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/steps/tasks/magpie/__init__.py b/tests/unit/steps/tasks/magpie/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/steps/tasks/magpie/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py new file mode 100644 index 0000000000..77ed178f4c --- /dev/null +++ b/tests/unit/steps/tasks/magpie/test_base.py @@ -0,0 +1,269 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from distilabel.llms.openai import OpenAILLM +from distilabel.steps.tasks.magpie.base import MAGPIE_MULTI_TURN_SYSTEM_PROMPT, Magpie + +from tests.unit.conftest import DummyMagpieLLM + + +class TestMagpie: + def test_raise_value_error_llm_no_magpie_mixin(self) -> None: + with pytest.raises( + ValueError, + match="`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`", + ): + Magpie(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore + + def test_outputs(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3")) + + assert task.outputs == ["conversation"] + + task = Magpie( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + assert task.outputs == ["instruction"] + + def test_process(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1) + + task.load() + + assert next(task.process(inputs=[{}, {}, {}])) == [ + { + "conversation": [ + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + ] + + def test_process_with_n_turns(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2) + + task.load() + + assert next(task.process(inputs=[{}, {}, {}])) == [ + { + "conversation": [ + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + ] + + def test_process_with_system_prompt_per_row(self) -> None: + task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2) + + task.load() + + assert next( + task.process( + inputs=[ + {"system_prompt": "You're a math expert assistant."}, + {"system_prompt": "You're a florist expert assistant."}, + {"system_prompt": "You're a plumber expert assistant."}, + ] + ) + ) == [ + { + "conversation": [ + {"role": "system", "content": "You're a math expert assistant."}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": 
"assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": "You're a florist expert assistant."}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + { + "conversation": [ + {"role": "system", "content": "You're a plumber expert assistant."}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie"}, + {"role": "assistant", "content": "Hello Magpie"}, + ], + }, + ] + + def test_process_only_instruction(self) -> None: + task = Magpie( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + task.load() + + assert next(task.process(inputs=[{}, {}, {}])) == [ + {"instruction": "Hello Magpie"}, + {"instruction": "Hello Magpie"}, + {"instruction": "Hello Magpie"}, + ] + + def test_serialization(self) -> None: + task = Magpie( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + assert task.dump() == { + "llm": { + "use_magpie_template": True, + "magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "generation_kwargs": {}, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyMagpieLLM", + }, + }, + "n_turns": 1, + "only_instruction": True, + "system_prompt": None, + "name": "magpie_0", + "resources": { + "replicas": 1, + "cpus": None, + "gpus": None, + "memory": None, + "resources": None, + }, + "input_mappings": {}, + "output_mappings": {}, + "input_batch_size": 50, + "group_generations": False, + "add_raw_output": True, + "num_generations": 1, + "runtime_parameters_info": [ + { + "name": "llm", + "runtime_parameters_info": [ + { + "name": "generation_kwargs", + "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.", + "keys": [{"name": "kwargs", "optional": False}], + } + ], + }, + { + "name": "n_turns", + "optional": True, + "description": "The number of turns to generate for the conversation.", + }, + { + "name": "only_instruction", + "optional": True, + "description": "Whether to generate only the instruction. 
If this argument is `True`, then `n_turns` will be ignored.", + }, + { + "name": "system_prompt", + "optional": True, + "description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.", + }, + { + "name": "resources", + "runtime_parameters_info": [ + { + "name": "replicas", + "optional": True, + "description": "The number of replicas for the step.", + }, + { + "name": "cpus", + "optional": True, + "description": "The number of CPUs assigned to each step replica.", + }, + { + "name": "gpus", + "optional": True, + "description": "The number of GPUs assigned to each step replica.", + }, + { + "name": "memory", + "optional": True, + "description": "The memory in bytes required for each step replica.", + }, + { + "name": "resources", + "optional": True, + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + }, + ], + }, + { + "name": "input_batch_size", + "optional": True, + "description": "The number of rows that will contain the batches processed by the step.", + }, + { + "name": "add_raw_output", + "optional": True, + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + }, + { + "name": "num_generations", + "optional": True, + "description": "The number of generations to be produced per input.", + }, + ], + "type_info": { + "module": "distilabel.steps.tasks.magpie.base", + "name": "Magpie", + }, + } diff --git a/tests/unit/steps/tasks/magpie/test_generator.py b/tests/unit/steps/tasks/magpie/test_generator.py new file mode 100644 index 0000000000..7ebb815e0d --- /dev/null +++ b/tests/unit/steps/tasks/magpie/test_generator.py @@ -0,0 +1,154 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
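The new `tests/unit/steps/tasks/magpie/test_generator.py` below only exercises `outputs` and serialization, so as a reading aid here is a minimal usage sketch, not part of this patch, of how `MagpieGenerator.process` is driven: it yields `(batch, last_batch)` tuples following the `GeneratorStepOutput` convention. The sketch reuses the `DummyMagpieLLM` helper added to `tests/unit/conftest.py` above; the `num_rows`/`batch_size` values are arbitrary, and running it assumes the repository root is on `PYTHONPATH` so the test helpers are importable.

```python
from distilabel.steps.tasks.magpie.generator import MagpieGenerator

from tests.unit.conftest import DummyMagpieLLM

# Build the generator with the dummy LLM defined in `tests/unit/conftest.py`.
task = MagpieGenerator(
    llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
    only_instruction=True,
    num_rows=4,  # arbitrary values, chosen so that more than one batch is yielded
    batch_size=2,
)
task.load()

rows = []
for batch, last_batch in task.process():
    rows.extend(batch)
    if last_batch:  # the pipeline stops requesting batches once this flag is True
        break

print(len(rows), rows[0])  # e.g. 4 {'instruction': 'Hello Magpie'}
```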
+ +import pytest +from distilabel.llms.openai import OpenAILLM +from distilabel.steps.tasks.magpie.generator import MagpieGenerator + +from tests.unit.conftest import DummyMagpieLLM + + +class TestMagpieGenerator: + def test_raise_value_error_llm_no_magpie_mixin(self) -> None: + with pytest.raises( + ValueError, + match="`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`", + ): + MagpieGenerator(llm=OpenAILLM(model="gpt-4", api_key="fake")) # type: ignore + + def test_outputs(self) -> None: + task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3")) + + assert task.outputs == ["conversation"] + + task = MagpieGenerator( + llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), + only_instruction=True, + ) + + assert task.outputs == ["instruction"] + + def test_serialization(self) -> None: + task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3")) + + assert task.dump() == { + "llm": { + "use_magpie_template": True, + "magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n", + "generation_kwargs": {}, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyMagpieLLM", + }, + }, + "n_turns": 1, + "only_instruction": False, + "system_prompt": None, + "name": "magpie_generator_0", + "resources": { + "replicas": 1, + "cpus": None, + "gpus": None, + "memory": None, + "resources": None, + }, + "input_mappings": {}, + "output_mappings": {}, + "batch_size": 50, + "group_generations": False, + "add_raw_output": True, + "num_generations": 1, + "num_rows": None, + "runtime_parameters_info": [ + { + "name": "llm", + "runtime_parameters_info": [ + { + "name": "generation_kwargs", + "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.", + "keys": [{"name": "kwargs", "optional": False}], + } + ], + }, + { + "name": "n_turns", + "optional": True, + "description": "The number of turns to generate for the conversation.", + }, + { + "name": "only_instruction", + "optional": True, + "description": "Whether to generate only the instruction. 
If this argument is `True`, then `n_turns` will be ignored.", + }, + { + "name": "system_prompt", + "optional": True, + "description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.", + }, + { + "name": "resources", + "runtime_parameters_info": [ + { + "name": "replicas", + "optional": True, + "description": "The number of replicas for the step.", + }, + { + "name": "cpus", + "optional": True, + "description": "The number of CPUs assigned to each step replica.", + }, + { + "name": "gpus", + "optional": True, + "description": "The number of GPUs assigned to each step replica.", + }, + { + "name": "memory", + "optional": True, + "description": "The memory in bytes required for each step replica.", + }, + { + "name": "resources", + "optional": True, + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + }, + ], + }, + { + "name": "batch_size", + "optional": True, + "description": "The number of rows that will contain the batches generated by the step.", + }, + { + "name": "add_raw_output", + "optional": True, + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + }, + { + "name": "num_generations", + "optional": True, + "description": "The number of generations to be produced per input.", + }, + { + "name": "num_rows", + "optional": False, + "description": "The number of rows to generate.", + }, + ], + "type_info": { + "module": "distilabel.steps.tasks.magpie.generator", + "name": "MagpieGenerator", + }, + } diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index e174f53716..0e488eea13 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Type, Union +from typing import Any, Dict, Literal, Type, Union import pytest from distilabel.llms.huggingface.transformers import TransformersLLM @@ -33,6 +33,7 @@ class DummyUserTest(BaseModel): DUMP_JSON = { "cuda_devices": "auto", "generation_kwargs": {}, + "magpie_pre_query_template": None, "structured_output": { "format": "json", "schema": { @@ -57,6 +58,7 @@ class DummyUserTest(BaseModel): "device": None, "device_map": None, "token": None, + "use_magpie_template": False, "type_info": { "module": "distilabel.llms.huggingface.transformers", "name": "TransformersLLM", @@ -66,6 +68,7 @@ class DummyUserTest(BaseModel): DUMP_REGEX = { "cuda_devices": "auto", "generation_kwargs": {}, + "magpie_pre_query_template": None, "structured_output": { "format": "regex", "schema": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", @@ -81,6 +84,7 @@ class DummyUserTest(BaseModel): "device": None, "device_map": None, "token": None, + "use_magpie_template": False, "type_info": { "module": "distilabel.llms.huggingface.transformers", "name": "TransformersLLM", @@ -149,7 +153,10 @@ def test_generation( ], ) def test_serialization( - self, format: str, schema: Union[str, Type[BaseModel]], dump: Dict[str, Any] + self, + format: Literal["json", "regex"], + schema: Union[str, Type[BaseModel]], + dump: Dict[str, Any], ) -> None: llm = TransformersLLM( model="openaccess-ai-collective/tiny-mistral",