Add StructuredGeneration task and support for grammar in InferenceEndpointsLLM #680

Merged · 13 commits · Jun 3, 2024
2 changes: 1 addition & 1 deletion docs/sections/pipeline_samples/examples/index.md
@@ -10,7 +10,7 @@ Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`

 This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema.

-It makes use of a local model which can be downlaoded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].
+It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].

 ??? Run
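For context, the example this doc change points at combines `LlamaCppLLM` with `outlines`-backed structured outputs. A minimal sketch of that usage, assuming a locally downloaded GGUF file and a `{"format": ..., "schema": ...}` shape for `structured_output`; the `Character` model and the file name are placeholders, not code from this PR:

```python
# Hypothetical sketch, not the repository's example script.
from pydantic import BaseModel

from distilabel.llms import LlamaCppLLM


class Character(BaseModel):
    name: str
    role: str
    strength: int


llm = LlamaCppLLM(
    # Placeholder file name; the docs explain how to download a model with curl.
    model_path="./openhermes-2.5-mistral-7b.Q4_K_M.gguf",
    # Assumed structured-output payload: constrain generations to the JSON schema.
    structured_output={"format": "json", "schema": Character},
)
llm.load()

# Inputs are OpenAI-style conversations.
generations = llm.generate(
    inputs=[[{"role": "user", "content": "Create an RPG character."}]]
)
print(generations)
```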
8 changes: 4 additions & 4 deletions src/distilabel/llms/anthropic.py
@@ -33,7 +33,7 @@
 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 from distilabel.utils.itertools import grouper
 
 if TYPE_CHECKING:
@@ -163,7 +163,7 @@ def model_name(self) -> str:
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: StandardInput,
         max_tokens: int = 128,
         stop_sequences: Union[List[str], None] = None,
         temperature: float = 1.0,
@@ -223,7 +223,7 @@ async def agenerate(  # type: ignore
     @override
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
@@ -232,7 +232,7 @@ def generate(
         """
 
         async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
         ) -> "GenerateOutput":
             """Internal function to parallelize the asynchronous generation of responses."""
             tasks = [
14 changes: 8 additions & 6 deletions src/distilabel/llms/base.py
@@ -37,7 +37,7 @@
     InstructorStructuredOutputType,
 )
 from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import FormattedInput, StandardInput
 from distilabel.utils.docstring import Docstring
 
 if in_notebook():
@@ -94,7 +94,7 @@ def model_name(self) -> str:
     @abstractmethod
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["FormattedInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
@@ -187,7 +187,9 @@ def generate_parsed_docstring(self) -> "Docstring":
         """
         return parse_google_docstring(self.generate)
 
-    def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
+    def get_last_hidden_states(
+        self, inputs: List["StandardInput"]
+    ) -> List["HiddenState"]:
         """Method to get the last hidden states of the model for a list of inputs.
 
         Args:
@@ -264,7 +266,7 @@ def event_loop(self) -> "asyncio.AbstractEventLoop":
 
     @abstractmethod
     async def agenerate(
-        self, input: "ChatType", num_generations: int = 1, **kwargs: Any
+        self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
     ) -> List[Union[str, None]]:
         """Method to generate a `num_generations` responses for a given input asynchronously,
         and executed concurrently in `generate` method.
@@ -273,7 +275,7 @@ async def agenerate(
 
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["FormattedInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
@@ -282,7 +284,7 @@ def generate(
         """
 
         async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["FormattedInput"], **kwargs: Any
         ) -> List[List[Union[str, None]]]:
             """Internal function to parallelize the asynchronous generation of responses."""
             tasks = [
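The renames above lean on new aliases in `distilabel.steps.tasks.typing` that this diff imports but never shows. A plausible reconstruction of that module, hedged: only the names are confirmed by the imports here, and the exact shapes are inferred from the tuple that `InferenceEndpointsLLM.agenerate` unpacks further down:

```python
# Assumed contents of distilabel/steps/tasks/typing.py after this PR;
# the shapes below are a sketch, not the PR's verbatim code.
from typing import Any, Dict, List, Literal, Tuple, TypedDict, Union


class ChatItem(TypedDict):
    role: str
    content: str


# A plain OpenAI-style conversation: what every LLM already accepted.
ChatType = List[ChatItem]
StandardInput = ChatType


class Grammar(TypedDict):
    # Assumed: either a JSON schema or a regex constraining the generation.
    type: Literal["json", "regex"]
    value: Union[str, Dict[str, Any]]


# A formatted input is either a plain conversation, or a conversation paired
# with a per-input grammar: the tuple that `agenerate` checks for below.
StructuredInput = Tuple[StandardInput, Union[Grammar, None]]
FormattedInput = Union[StandardInput, StructuredInput]
```

Under that reading, `StandardInput` keeps the old `ChatType` contract for LLMs without grammar support, while `FormattedInput` widens the base `generate`/`agenerate` signatures only where structured generation can actually be honored.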
10 changes: 5 additions & 5 deletions src/distilabel/llms/cohere.py
@@ -30,7 +30,7 @@
 
 from distilabel.llms.base import AsyncLLM
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 from distilabel.utils.itertools import grouper
 
 if TYPE_CHECKING:
@@ -132,7 +132,7 @@ def load(self) -> None:
         self.structured_output = structured_output
 
     def _format_chat_to_cohere(
-        self, input: "ChatType"
+        self, input: "StandardInput"
     ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
         """Formats the chat input to the Cohere Chat API conversational format.
@@ -169,7 +169,7 @@ def _format_chat_to_cohere(
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: StandardInput,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         k: Optional[int] = None,
@@ -241,15 +241,15 @@ async def agenerate(  # type: ignore
     @override
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
         """Method to generate a list of responses asynchronously, returning the output
         synchronously awaiting for the response of each input sent to `agenerate`."""
 
         async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
         ) -> "GenerateOutput":
             """Internal function to parallelize the asynchronous generation of responses."""
             tasks = [
8 changes: 4 additions & 4 deletions src/distilabel/llms/groq.py
@@ -22,7 +22,7 @@
 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.steps.base import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 from distilabel.utils.itertools import grouper
 
 if TYPE_CHECKING:
@@ -131,7 +131,7 @@ def model_name(self) -> str:
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: StandardInput,
         seed: Optional[int] = None,
         max_new_tokens: int = 128,
         temperature: float = 1.0,
@@ -188,7 +188,7 @@ async def agenerate(  # type: ignore
     @override
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
@@ -197,7 +197,7 @@ def generate(
         """
 
         async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
         ) -> "GenerateOutput":
             """Internal function to parallelize the asynchronous generation of responses."""
             tasks = [
22 changes: 17 additions & 5 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -31,7 +31,7 @@
 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import FormattedInput, Grammar, StandardInput
 from distilabel.utils.itertools import grouper
 
 if TYPE_CHECKING:
@@ -148,6 +148,11 @@ class InferenceEndpointsLLM(AsyncLLM):
     model_display_name: Optional[str] = None
     use_openai_client: bool = False
 
+    grammar: Optional[RuntimeParameter[Grammar]] = Field(
+        default=None,
+        description="The grammar to use across all the generations.",
+    )
+
     _model_name: Optional[str] = PrivateAttr(default=None)
     _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None)
     _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME)
@@ -290,7 +295,7 @@ def model_name(self) -> Union[str, None]:  # type: ignore
 
     async def _openai_agenerate(
         self,
-        input: "ChatType",
+        input: "StandardInput",
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
@@ -322,7 +327,7 @@ async def _openai_agenerate(
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: "FormattedInput",
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
@@ -379,6 +384,10 @@ async def agenerate(  # type: ignore
             )
             stop_sequences = stop_sequences[:4]
 
+        grammar = None
+        if isinstance(input, tuple):
+            input, grammar = input
+
         if self.use_openai_client:
             return await self._openai_agenerate(
                 input=input,
@@ -413,6 +422,9 @@ async def agenerate(  # type: ignore
             stop_sequences=stop_sequences,
             return_full_text=return_full_text,
             watermark=watermark,
+            # NOTE: `self.grammar` applies to all the generations, while `grammar` is intended
+            # to be different per each input, and those are not intended to be used together
+            grammar=grammar or self.grammar,  # type: ignore
             # NOTE: here to ensure that the cache is not used and a different response is
             # generated every time
             seed=seed or random.randint(0, 2147483647),
@@ -429,7 +441,7 @@ async def agenerate(  # type: ignore
     @override
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["FormattedInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
@@ -438,7 +450,7 @@ def generate(
         """
 
         async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["FormattedInput"], **kwargs: Any
         ) -> "GenerateOutput":
             """Internal function to parallelize the asynchronous generation of responses."""
             tasks = [
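Putting the pieces of this file together: `grammar` set on the LLM applies to every generation, a `(conversation, grammar)` tuple overrides it per input, and `grammar=grammar or self.grammar` gives the per-input value precedence. A hedged usage sketch; the model id and grammar payloads below are illustrative placeholders, not taken from the PR:

```python
# Hypothetical usage of the new grammar support; payload shapes are assumed.
from distilabel.llms import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    # Runtime parameter applied across all generations.
    grammar={"type": "regex", "value": r"\d{4}"},
)
llm.load()

conversation = [
    {"role": "user", "content": "In which year did the French Revolution start?"}
]

# Plain input: falls back to `self.grammar` (the regex above).
plain = llm.generate(inputs=[conversation])

# (conversation, grammar) tuple: this grammar wins for this input only.
json_grammar = {
    "type": "json",
    "value": {"type": "object", "properties": {"year": {"type": "integer"}}},
}
structured = llm.generate(inputs=[(conversation, json_grammar)])
```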
10 changes: 6 additions & 4 deletions src/distilabel/llms/huggingface/transformers.py
@@ -21,7 +21,7 @@
 from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins import CudaDevicePlacementMixin
 from distilabel.llms.typing import GenerateOutput
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 
 if TYPE_CHECKING:
     from transformers import Pipeline
@@ -130,7 +130,7 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model
 
-    def prepare_input(self, input: "ChatType") -> str:
+    def prepare_input(self, input: "StandardInput") -> str:
         """Prepares the input by applying the chat template to the input, which is formatted
         as an OpenAI conversation, and adding the generation prompt.
         """
@@ -143,7 +143,7 @@ def prepare_input(self, input: "ChatType") -> str:
     @validate_call
     def generate(  # type: ignore
         self,
-        inputs: List[ChatType],
+        inputs: List[StandardInput],
         num_generations: int = 1,
         max_new_tokens: int = 128,
         temperature: float = 0.1,
@@ -189,7 +189,9 @@ def generate(  # type: ignore
             for output in outputs
         ]
 
-    def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
+    def get_last_hidden_states(
+        self, inputs: List["StandardInput"]
+    ) -> List["HiddenState"]:
         """Gets the last `hidden_states` of the model for the given inputs. It doesn't
         execute the task head.
4 changes: 2 additions & 2 deletions src/distilabel/llms/litellm.py
@@ -20,7 +20,7 @@
 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 
 if TYPE_CHECKING:
     from litellm import Choices
@@ -90,7 +90,7 @@ def model_name(self) -> str:
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: StandardInput,
         num_generations: int = 1,
         functions: Optional[List] = None,
         function_call: Optional[str] = None,
4 changes: 2 additions & 2 deletions src/distilabel/llms/llamacpp.py
@@ -19,7 +19,7 @@
 from distilabel.llms.base import LLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 
 if TYPE_CHECKING:
     from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList
@@ -128,7 +128,7 @@ def model_name(self) -> str:
     @validate_call
     def generate(  # type: ignore
         self,
-        inputs: List[ChatType],
+        inputs: List[StandardInput],
         num_generations: int = 1,
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
8 changes: 4 additions & 4 deletions src/distilabel/llms/mistral.py
@@ -22,7 +22,7 @@
 from distilabel.llms.base import AsyncLLM
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 from distilabel.utils.itertools import grouper
 
 if TYPE_CHECKING:
@@ -129,7 +129,7 @@ def model_name(self) -> str:
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: StandardInput,
         max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -180,7 +180,7 @@ async def agenerate(  # type: ignore
     @override
     def generate(
         self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
         num_generations: int = 1,
         **kwargs: Any,
     ) -> List["GenerateOutput"]:
@@ -189,7 +189,7 @@ def generate(
         """
 
         async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
         ) -> "GenerateOutput":
             """Internal function to parallelize the asynchronous generation of responses."""
             tasks = [
4 changes: 2 additions & 2 deletions src/distilabel/llms/ollama.py
@@ -19,7 +19,7 @@
 
 from distilabel.llms.base import AsyncLLM
 from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
 
 if TYPE_CHECKING:
     from ollama import AsyncClient
@@ -117,7 +117,7 @@ def model_name(self) -> str:
     @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: ChatType,
+        input: StandardInput,
         num_generations: int = 1,
         format: Literal["", "json"] = "",
         # TODO: include relevant options from `Options` in `agenerate` method.
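The new `StructuredGeneration` task named in the PR title is not part of this excerpt. Given the tuple handling in `agenerate`, its `format_input` plausibly pairs each row's conversation with that row's grammar; the sketch below captures only that idea, with assumed column names and a free-standing function in place of the real task method:

```python
# Illustrative only, not the task's real code: how a task could emit the
# (conversation, grammar) tuple that `InferenceEndpointsLLM.agenerate` unpacks.
from typing import Any, Dict, List, Tuple


def format_input(row: Dict[str, Any]) -> Tuple[List[Dict[str, str]], Dict[str, Any]]:
    """Assumed columns: 'instruction' and 'grammar' (one grammar per row)."""
    conversation = [{"role": "user", "content": row["instruction"]}]
    return (conversation, row["grammar"])
```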