diff --git a/docs/sections/pipeline_samples/examples/index.md b/docs/sections/pipeline_samples/examples/index.md index efa512cfeb..aa74004357 100644 --- a/docs/sections/pipeline_samples/examples/index.md +++ b/docs/sections/pipeline_samples/examples/index.md @@ -10,7 +10,7 @@ Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `dis This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema. - It makes use of a local model which can be downlaoded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM]. + It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM]. ??? Run diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py index fb4f6dc03c..af0fdbc76e 100644 --- a/src/distilabel/llms/anthropic.py +++ b/src/distilabel/llms/anthropic.py @@ -33,7 +33,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput from distilabel.utils.itertools import grouper if TYPE_CHECKING: @@ -163,7 +163,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, max_tokens: int = 128, stop_sequences: Union[List[str], None] = None, temperature: float = 1.0, @@ -223,7 +223,7 @@ async def agenerate( # type: ignore @override def generate( self, - inputs: List["ChatType"], + inputs: List["StandardInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -232,7 +232,7 @@ def generate( """ async def agenerate( - inputs: List["ChatType"], **kwargs: Any + inputs: List["StandardInput"], **kwargs: Any ) -> "GenerateOutput": """Internal function to parallelize the asynchronous generation of responses.""" tasks = [ diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index b61c72cf02..a320642615 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -37,7 +37,7 @@ InstructorStructuredOutputType, ) from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType - from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.tasks.typing import FormattedInput, StandardInput from distilabel.utils.docstring import Docstring if in_notebook(): @@ -94,7 +94,7 @@ def model_name(self) -> str: @abstractmethod def generate( self, - inputs: List["ChatType"], + inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -187,7 +187,9 @@ def generate_parsed_docstring(self) -> "Docstring": """ return parse_google_docstring(self.generate) - def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]: + def get_last_hidden_states( + self, inputs: List["StandardInput"] + ) -> List["HiddenState"]: """Method to get the last hidden states of the model for a list of inputs. 
Args: @@ -264,7 +266,7 @@ def event_loop(self) -> "asyncio.AbstractEventLoop": @abstractmethod async def agenerate( - self, input: "ChatType", num_generations: int = 1, **kwargs: Any + self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any ) -> List[Union[str, None]]: """Method to generate a `num_generations` responses for a given input asynchronously, and executed concurrently in `generate` method. @@ -273,7 +275,7 @@ async def agenerate( def generate( self, - inputs: List["ChatType"], + inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -282,7 +284,7 @@ def generate( """ async def agenerate( - inputs: List["ChatType"], **kwargs: Any + inputs: List["FormattedInput"], **kwargs: Any ) -> List[List[Union[str, None]]]: """Internal function to parallelize the asynchronous generation of responses.""" tasks = [ diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index a49b203f3d..c4a9c361c5 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -30,7 +30,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput from distilabel.utils.itertools import grouper if TYPE_CHECKING: @@ -132,7 +132,7 @@ def load(self) -> None: self.structured_output = structured_output def _format_chat_to_cohere( - self, input: "ChatType" + self, input: "StandardInput" ) -> Tuple[Union[str, None], List["ChatMessage"], str]: """Formats the chat input to the Cohere Chat API conversational format. @@ -169,7 +169,7 @@ def _format_chat_to_cohere( @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, temperature: Optional[float] = None, max_tokens: Optional[int] = None, k: Optional[int] = None, @@ -241,7 +241,7 @@ async def agenerate( # type: ignore @override def generate( self, - inputs: List["ChatType"], + inputs: List["StandardInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -249,7 +249,7 @@ def generate( synchronously awaiting for the response of each input sent to `agenerate`.""" async def agenerate( - inputs: List["ChatType"], **kwargs: Any + inputs: List["StandardInput"], **kwargs: Any ) -> "GenerateOutput": """Internal function to parallelize the asynchronous generation of responses.""" tasks = [ diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py index 4905f82839..75fb8d5b32 100644 --- a/src/distilabel/llms/groq.py +++ b/src/distilabel/llms/groq.py @@ -22,7 +22,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.steps.base import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput from distilabel.utils.itertools import grouper if TYPE_CHECKING: @@ -131,7 +131,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, seed: Optional[int] = None, max_new_tokens: int = 128, temperature: float = 1.0, @@ -188,7 +188,7 @@ async def agenerate( # type: ignore @override def generate( self, - inputs: List["ChatType"], + inputs: List["StandardInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -197,7 +197,7 @@ def generate( """ async def agenerate( - inputs: List["ChatType"], **kwargs: Any + inputs: 
List["StandardInput"], **kwargs: Any ) -> "GenerateOutput": """Internal function to parallelize the asynchronous generation of responses.""" tasks = [ diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 4570b93d1d..201f8237aa 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -31,7 +31,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import FormattedInput, Grammar, StandardInput from distilabel.utils.itertools import grouper if TYPE_CHECKING: @@ -148,6 +148,11 @@ class InferenceEndpointsLLM(AsyncLLM): model_display_name: Optional[str] = None use_openai_client: bool = False + grammar: Optional[RuntimeParameter[Grammar]] = Field( + default=None, + description="The grammar to use across all the generations.", + ) + _model_name: Optional[str] = PrivateAttr(default=None) _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) @@ -290,7 +295,7 @@ def model_name(self) -> Union[str, None]: # type: ignore async def _openai_agenerate( self, - input: "ChatType", + input: "StandardInput", max_new_tokens: int = 128, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, @@ -322,7 +327,7 @@ async def _openai_agenerate( @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: "FormattedInput", max_new_tokens: int = 128, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, @@ -379,6 +384,10 @@ async def agenerate( # type: ignore ) stop_sequences = stop_sequences[:4] + grammar = None + if isinstance(input, tuple): + input, grammar = input + if self.use_openai_client: return await self._openai_agenerate( input=input, @@ -413,6 +422,9 @@ async def agenerate( # type: ignore stop_sequences=stop_sequences, return_full_text=return_full_text, watermark=watermark, + # NOTE: `self.grammar` applies to all the generations, while `grammar` is intended + # to be different per each input, and those are not intended to be used together + grammar=grammar or self.grammar, # type: ignore # NOTE: here to ensure that the cache is not used and a different response is # generated every time seed=seed or random.randint(0, 2147483647), @@ -429,7 +441,7 @@ async def agenerate( # type: ignore @override def generate( self, - inputs: List["ChatType"], + inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -438,7 +450,7 @@ def generate( """ async def agenerate( - inputs: List["ChatType"], **kwargs: Any + inputs: List["FormattedInput"], **kwargs: Any ) -> "GenerateOutput": """Internal function to parallelize the asynchronous generation of responses.""" tasks = [ diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 19ac43b41b..1f654b7a7a 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -21,7 +21,7 @@ from distilabel.llms.chat_templates import CHATML_TEMPLATE from distilabel.llms.mixins import CudaDevicePlacementMixin from distilabel.llms.typing import GenerateOutput -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import 
StandardInput if TYPE_CHECKING: from transformers import Pipeline @@ -130,7 +130,7 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def prepare_input(self, input: "ChatType") -> str: + def prepare_input(self, input: "StandardInput") -> str: """Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt. """ @@ -143,7 +143,7 @@ def prepare_input(self, input: "ChatType") -> str: @validate_call def generate( # type: ignore self, - inputs: List[ChatType], + inputs: List[StandardInput], num_generations: int = 1, max_new_tokens: int = 128, temperature: float = 0.1, @@ -189,7 +189,9 @@ def generate( # type: ignore for output in outputs ] - def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]: + def get_last_hidden_states( + self, inputs: List["StandardInput"] + ) -> List["HiddenState"]: """Gets the last `hidden_states` of the model for the given inputs. It doesn't execute the task head. diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py index d3660c5ea0..c664133012 100644 --- a/src/distilabel/llms/litellm.py +++ b/src/distilabel/llms/litellm.py @@ -20,7 +20,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from litellm import Choices @@ -90,7 +90,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, num_generations: int = 1, functions: Optional[List] = None, function_call: Optional[str] = None, diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py index f8f50ff154..94548baa69 100644 --- a/src/distilabel/llms/llamacpp.py +++ b/src/distilabel/llms/llamacpp.py @@ -19,7 +19,7 @@ from distilabel.llms.base import LLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList @@ -128,7 +128,7 @@ def model_name(self) -> str: @validate_call def generate( # type: ignore self, - inputs: List[ChatType], + inputs: List[StandardInput], num_generations: int = 1, max_new_tokens: int = 128, frequency_penalty: float = 0.0, diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py index dd96cae91f..8eafae87f7 100644 --- a/src/distilabel/llms/mistral.py +++ b/src/distilabel/llms/mistral.py @@ -22,7 +22,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput from distilabel.utils.itertools import grouper if TYPE_CHECKING: @@ -129,7 +129,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, max_new_tokens: Optional[int] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, @@ -180,7 +180,7 @@ async def agenerate( # type: ignore @override def generate( self, - inputs: List["ChatType"], + inputs: 
List["StandardInput"], num_generations: int = 1, **kwargs: Any, ) -> List["GenerateOutput"]: @@ -189,7 +189,7 @@ def generate( """ async def agenerate( - inputs: List["ChatType"], **kwargs: Any + inputs: List["StandardInput"], **kwargs: Any ) -> "GenerateOutput": """Internal function to parallelize the asynchronous generation of responses.""" tasks = [ diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/llms/ollama.py index fb06f1eed3..491e273279 100644 --- a/src/distilabel/llms/ollama.py +++ b/src/distilabel/llms/ollama.py @@ -19,7 +19,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from ollama import AsyncClient @@ -117,7 +117,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, num_generations: int = 1, format: Literal["", "json"] = "", # TODO: include relevant options from `Options` in `agenerate` method. diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index 6dedc2387c..a659ae5499 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -20,7 +20,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from openai import AsyncOpenAI @@ -129,7 +129,7 @@ def model_name(self) -> str: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, num_generations: int = 1, max_new_tokens: int = 128, frequency_penalty: float = 0.0, diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/llms/vertexai.py index 28ceee3a0c..34cc9484f8 100644 --- a/src/distilabel/llms/vertexai.py +++ b/src/distilabel/llms/vertexai.py @@ -18,7 +18,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from vertexai.generative_models import Content, GenerativeModel @@ -87,7 +87,7 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def _chattype_to_content(self, input: "ChatType") -> List["Content"]: + def _chattype_to_content(self, input: "StandardInput") -> List["Content"]: """Converts a chat type to a list of content items expected by the API. 
Args: @@ -114,7 +114,7 @@ def _chattype_to_content(self, input: "ChatType") -> List["Content"]: @validate_call async def agenerate( # type: ignore self, - input: ChatType, + input: StandardInput, num_generations: int = 1, temperature: Optional[float] = None, top_p: Optional[float] = None, diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 00b3807465..373fb241df 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -21,7 +21,7 @@ from distilabel.llms.mixins import CudaDevicePlacementMixin from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -153,7 +153,7 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def prepare_input(self, input: "ChatType") -> str: + def prepare_input(self, input: "StandardInput") -> str: """Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt. """ @@ -166,7 +166,7 @@ def prepare_input(self, input: "ChatType") -> str: @validate_call def generate( # type: ignore self, - inputs: List[ChatType], + inputs: List[StandardInput], num_generations: int = 1, max_new_tokens: int = 128, frequency_penalty: float = 0.0, diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 39138918e8..bc77a53536 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -46,6 +46,7 @@ ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME, ) +from distilabel.utils.dicts import flatten_dict from distilabel.utils.files import list_files_in_dir from distilabel.utils.logging import setup_logging, stop_logging from distilabel.utils.serialization import ( @@ -1556,7 +1557,7 @@ def cache(self, path: "StrOrPath") -> None: batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name batch_manager_step_dir.mkdir(parents=True, exist_ok=True) - # Store each `_BatchManagerStep` `_Batch`es in a separete file + # Store each `_BatchManagerStep` `_Batch`es in a separate file for buffered_step_name in step_dump["data"]: step_batches_dir = batch_manager_step_dir / buffered_step_name step_batches_dir.mkdir(parents=True, exist_ok=True) @@ -1718,7 +1719,16 @@ def _write(self, step_name: str) -> None: ) step_parquet_dir.mkdir() - table = pa.Table.from_pylist(self._buffers[step_name]) + try: + table = pa.Table.from_pylist(self._buffers[step_name]) + except pa.lib.ArrowInvalid as pae: + if ( + repr(pae) + != "ArrowInvalid('cannot mix struct and non-struct, non-null values')" + ): + raise pae + flattened_buffers = [flatten_dict(buf) for buf in self._buffers[step_name]] + table = pa.Table.from_pylist(flattened_buffers) last_schema = self._buffer_last_schema.get(step_name) if last_schema is None: diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 9fcb882c15..f785f9eb6a 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -30,6 +30,7 @@ from distilabel.steps.tasks.prometheus_eval import PrometheusEval from distilabel.steps.tasks.quality_scorer import QualityScorer from distilabel.steps.tasks.self_instruct import SelfInstruct +from distilabel.steps.tasks.structured_generation import StructuredGeneration from distilabel.steps.tasks.text_generation import 
ChatGeneration, TextGeneration from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback @@ -53,6 +54,7 @@ "PrometheusEval", "QualityScorer", "SelfInstruct", + "StructuredGeneration", "TextGeneration", "UltraFeedback", ] diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index 120f2758fe..7e2cfc2520 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from distilabel.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.tasks.typing import FormattedInput from distilabel.steps.typing import StepOutput @@ -110,7 +110,7 @@ def _output_on_failure( """ # Create a dictionary with the outputs of the task (every output set to None) outputs = {output: None for output in self.outputs} - outputs["model_name"] = self.llm.model_name + outputs["model_name"] = self.llm.model_name # type: ignore outputs = self._maybe_add_raw_output( outputs, output, add_raw_output=self.add_raw_output ) @@ -142,12 +142,12 @@ class Task(_Task, Step): """ @abstractmethod - def format_input(self, input: Dict[str, Any]) -> "ChatType": + def format_input(self, input: Dict[str, Any]) -> "FormattedInput": """Abstract method to format the inputs of the task. It needs to receive an input as a Python dictionary, and generates an OpenAI chat-like list of dicts.""" pass - def _format_inputs(self, inputs: List[Dict[str, Any]]) -> List["ChatType"]: + def _format_inputs(self, inputs: List[Dict[str, Any]]) -> List["FormattedInput"]: """Formats the inputs of the task using the `format_input` method. Args: diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py new file mode 100644 index 0000000000..ca43f9beba --- /dev/null +++ b/src/distilabel/steps/tasks/structured_generation.py @@ -0,0 +1,104 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Any, Dict, List, Union + +from distilabel.steps.tasks.base import Task +from distilabel.steps.tasks.typing import StructuredInput + + +class StructuredGeneration(Task): + """Generate structured content for a given `instruction` using an `LLM`. + + `StructuredGeneration` is a pre-defined task that defines the `instruction` and the `grammar` + as the inputs, and `generation` as the output. This task is used to generate structured content based on + the input instruction and following the schema provided within the `grammar` column for each + `instruction`. The `model_name` is also returned as part of the output in order to enhance it. + + Attributes: + use_system_prompt: Whether to use the system prompt in the generation. Defaults to `False`, + which means that if the column `system_prompt` is defined within the input batch, then + the `system_prompt` will be used, otherwise, it will be ignored.
+ + Input columns: + - instruction (`str`): The instruction to generate structured content from. + - grammar (`Dict[str, Any]`): The grammar to generate structured content from. It should be a + Python dictionary with the keys `type` and `value`, where `type` should be one of `json` or + `regex`, and the `value` should be either the JSON schema or the regex pattern, respectively. + + Output columns: + - generation (`str`): The generated text matching the provided schema, if possible. + - model_name (`str`): The name of the model used to generate the text. + + Categories: + - outlines + - structured-generation + + Examples: + ```python + from distilabel.steps.tasks import StructuredGeneration + + task = StructuredGeneration(llm=LLM(...)) + ``` + """ + + use_system_prompt: bool = False + + @property + def inputs(self) -> List[str]: + """The input for the task are the `instruction` and the `grammar`. + Optionally, if the `use_system_prompt` flag is set to True, then the + `system_prompt` will be used too.""" + columns = ["instruction", "grammar"] + if self.use_system_prompt: + columns = ["system_prompt"] + columns + return columns + + def format_input(self, input: Dict[str, Any]) -> StructuredInput: + """The input is formatted as a `ChatType` assuming that the instruction + is the first interaction from the user within a conversation.""" + if not isinstance(input["instruction"], str): + raise ValueError( + f"Input `instruction` must be a string. Got: {input['instruction']}." + ) + + messages = [{"role": "user", "content": input["instruction"]}] + if self.use_system_prompt: + if "system_prompt" in input: + messages.insert( + 0, {"role": "system", "content": input["system_prompt"]} + ) + else: + warnings.warn( + "`use_system_prompt` is set to `True`, but no `system_prompt` in input batch, so it will be ignored.", + UserWarning, + stacklevel=2, + ) + + return (messages, input.get("grammar", None)) # type: ignore + + @property + def outputs(self) -> List[str]: + """The output for the task is the `generation` and the `model_name`.""" + return ["generation", "model_name"] + + def format_output( + self, output: Union[str, None], input: Dict[str, Any] + ) -> Dict[str, Any]: + """The output is formatted as a dictionary with the `generation`. The `model_name` + will be automatically included within the `process` method of `Task`. Note that even + if the `grammar` is defined to produce a JSON schema, this method will return the raw + output i.e. a string without any parsing.""" + return {"generation": output} diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index ece5344caf..28c207c287 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -37,7 +37,7 @@ class TextGeneration(Task): Output columns: - generation (`str`): The generated text. - - model_name (`str`): The model name used to generate the text. + - model_name (`str`): The name of the model used to generate the text. Categories: - text-generation diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py index cbd6ffc09c..71e068cab1 100644 --- a/src/distilabel/steps/tasks/typing.py +++ b/src/distilabel/steps/tasks/typing.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List +from typing import Any, Dict, List, Literal, Tuple, Union from typing_extensions import TypedDict @@ -24,3 +24,16 @@ class ChatItem(TypedDict): ChatType = List[ChatItem] """ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format.""" + + +class Grammar(TypedDict): + type: Literal["json", "regex"] + value: Union[str, Dict[str, Any]] + + +StandardInput = ChatType +"""StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`.""" +StructuredInput = Tuple[StandardInput, Union[Grammar, None]] +"""StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it.""" +FormattedInput = Union[StandardInput, StructuredInput] +"""FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s.""" diff --git a/src/distilabel/utils/dicts.py b/src/distilabel/utils/dicts.py index 0ce96334f9..53d33d47f5 100644 --- a/src/distilabel/utils/dicts.py +++ b/src/distilabel/utils/dicts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from collections import defaultdict from typing import Any, Dict, List, TypeVar @@ -33,3 +34,7 @@ def combine_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]: for key, value in d.items(): combined_dict[key].append(value) return dict(combined_dict) + + +def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]: + return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()} diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 9caccf43c4..554cc44fec 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random from unittest.mock import AsyncMock, MagicMock, Mock, patch import nest_asyncio @@ -145,6 +146,7 @@ async def test_generate_via_openai_client( ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) + ... nest_asyncio.apply() assert llm.generate( @@ -159,6 +161,50 @@ async def test_generate_via_openai_client( ] ) == [(" Aenean hendrerit aliquam velit. ...",)] + @pytest.mark.asyncio + async def test_agenerate_with_grammar( + self, mock_inference_client: MagicMock, _: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + grammar={"type": "regex", "value": r"\b[A-Z][a-z]*\b"}, + ) + llm._aclient = mock_inference_client + + llm._aclient.text_generation = AsyncMock( + return_value=" Aenean hendrerit aliquam velit. ..." + ) + + # Since there's a pseudo-random number within the generation kwargs, we set the seed + # here first to ensure reproducibility within the tests + random.seed(42) + + assert await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ] + ) == [" Aenean hendrerit aliquam velit. 
..."] + + kwargs = { + "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "max_new_tokens": 128, + "do_sample": False, + "typical_p": None, + "repetition_penalty": None, + "temperature": 1.0, + "top_p": None, + "top_k": None, + "stop_sequences": None, + "return_full_text": False, + "watermark": False, + "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, + "seed": 478163327, # pre-computed random value with `random.seed(42)` + } + mock_inference_client.text_generation.assert_called_with(**kwargs) + def test_serialization( self, mock_inference_client: MagicMock, mock_openai_client: MagicMock ) -> None: @@ -173,6 +219,7 @@ def test_serialization( "base_url": None, "tokenizer_id": None, "generation_kwargs": {}, + "grammar": None, "model_display_name": None, "use_openai_client": False, "structured_output": None, diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py new file mode 100644 index 0000000000..c4766aaa57 --- /dev/null +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -0,0 +1,124 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, List + +from distilabel.llms.base import LLM +from distilabel.llms.typing import GenerateOutput +from distilabel.pipeline.local import Pipeline +from distilabel.steps.tasks.structured_generation import StructuredGeneration +from distilabel.steps.tasks.typing import StructuredInput +from typing_extensions import override + + +class DummyStructuredLLM(LLM): + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + @override + def generate( # type: ignore + self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any + ) -> List["GenerateOutput"]: + return [ + [json.dumps({"test": "output"}) for _ in range(num_generations)] + for _ in inputs + ] + + +class TestStructuredGeneration: + def test_format_input(self) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + llm = DummyStructuredLLM() + task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline) + + # 1. Including the `grammar` field within the input + assert task.format_input( + { + "instruction": "test", + "system_prompt": "test", + "grammar": {"type": "regex", "value": r"[a-zA-Z]+"}, + } + ) == ( + [{"role": "user", "content": "test"}], + {"type": "regex", "value": r"[a-zA-Z]+"}, + ) + + # 2. 
Not including the `grammar` field within the input + assert task.format_input({"instruction": "test", "system_prompt": "test"}) == ( + [{"role": "user", "content": "test"}], + None, + ) + + def test_format_input_with_system_prompt(self) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + llm = DummyStructuredLLM() + task = StructuredGeneration( + name="task", + llm=llm, + pipeline=pipeline, + use_system_prompt=True, + ) + + assert task.format_input({"instruction": "test", "system_prompt": "test"}) == ( + [ + {"role": "system", "content": "test"}, + {"role": "user", "content": "test"}, + ], + None, + ) + + def test_process(self) -> None: + pipeline = Pipeline(name="unit-test-pipeline") + llm = DummyStructuredLLM() + task = StructuredGeneration(name="task", llm=llm, pipeline=pipeline) + assert next( + task.process( + [ + { + "instruction": "test", + "grammar": { + "type": "json", + "value": { + "properties": { + "test": {"title": "Test", "type": "string"} + }, + "required": ["test"], + "title": "Test", + "type": "object", + }, + }, + } + ] + ) + ) == [ + { + "instruction": "test", + "grammar": { + "type": "json", + "value": { + "properties": {"test": {"title": "Test", "type": "string"}}, + "required": ["test"], + "title": "Test", + "type": "object", + }, + }, + "generation": '{"test": "output"}', + "model_name": "test", + } + ] diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py index cfccb1585f..9c3ff44d57 100644 --- a/tests/unit/test_imports.py +++ b/tests/unit/test_imports.py @@ -77,6 +77,7 @@ def test_imports() -> None: PrometheusEval, QualityScorer, SelfInstruct, + StructuredGeneration, TextGeneration, UltraFeedback, )
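To make the new pieces easier to follow, below is a minimal sketch, distilled from the unit tests added in this diff (`tests/unit/steps/tasks/test_structured_generation.py`), of how `StructuredGeneration`, the `StructuredInput`/`Grammar` typing aliases, and the `flatten_dict` fallback fit together. The `EchoJSONLLM` class, its `model_name`, and the example instructions are illustrative stand-ins rather than part of the PR; also note that within this diff only `InferenceEndpointsLLM` actually unpacks the per-row grammar and forwards it to the underlying `text_generation` call.

```python
import json
from typing import Any, List

from distilabel.llms.base import LLM
from distilabel.llms.typing import GenerateOutput
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from distilabel.steps.tasks.typing import StructuredInput
from distilabel.utils.dicts import flatten_dict


class EchoJSONLLM(LLM):
    """Hypothetical stand-in LLM that always returns the same JSON string."""

    def load(self) -> None:
        pass

    @property
    def model_name(self) -> str:
        return "echo-json"

    def generate(  # type: ignore
        self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any
    ) -> List["GenerateOutput"]:
        # A real backend would honour the per-input grammar; this dummy echoes a fixed object.
        return [
            [json.dumps({"test": "output"}) for _ in range(num_generations)]
            for _ in inputs
        ]


pipeline = Pipeline(name="structured-generation-sketch")
task = StructuredGeneration(name="task", llm=EchoJSONLLM(), pipeline=pipeline)

# `format_input` now returns a `StructuredInput`: the chat messages plus the row's `grammar`.
messages, grammar = task.format_input(
    {"instruction": "Name a month", "grammar": {"type": "regex", "value": r"[A-Z][a-z]+"}}
)
print(messages)  # [{'role': 'user', 'content': 'Name a month'}]
print(grammar)   # {'type': 'regex', 'value': '[A-Z][a-z]+'}

# Processing a batch adds `generation` (a raw string) and `model_name`, keeping `grammar` as-is.
row = next(
    task.process(
        [{"instruction": "Name a month", "grammar": {"type": "json", "value": {"type": "object"}}}]
    )
)[0]
print(row["generation"], row["model_name"])  # {"test": "output"} echo-json

# `flatten_dict` is the fallback `_WriteBuffer._write` reaches for when pyarrow refuses to mix
# struct and non-struct values in a column: dict-valued fields such as `grammar` are JSON-encoded.
print(flatten_dict(row)["grammar"])  # '{"type": "json", "value": {"type": "object"}}'
```

As the `format_output` docstring above notes, `generation` stays a raw string even for `json` grammars, so parsing and validation of the emitted JSON are left to downstream steps.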