Add Magpie and MagpieGenerator tasks #778

Merged
merged 32 commits on Jul 15, 2024
Changes from 16 commits

Commits (32)
e1b4c8b
Move `CudaDevicePlacementMixin` to new module
gabrielmbmb Jul 9, 2024
d557542
Initial work for implementing Magpie
gabrielmbmb Jul 9, 2024
755f7ec
Simplify magpie implementation
gabrielmbmb Jul 9, 2024
1719d15
Remove `use_open_ai` and add `MagpieChatTemplateMixin` to
gabrielmbmb Jul 10, 2024
775ca4e
Add `MagpieChatTemplateMixin` to `vLLM`
gabrielmbmb Jul 10, 2024
9ff6eeb
Add `MagpieGenerator` task
gabrielmbmb Jul 10, 2024
7b9bde5
Move `CudaDevicePlacementMixins` to new subpackage
gabrielmbmb Jul 10, 2024
dadac54
Fix unit tests
gabrielmbmb Jul 10, 2024
844ec57
Fix docstrings
gabrielmbmb Jul 10, 2024
04ecc3a
Mock `HF_TOKEN` environment variable
gabrielmbmb Jul 10, 2024
a15752a
Fix list index out of range
gabrielmbmb Jul 11, 2024
b46cd33
Fix `MagpieGenerator` last batch
gabrielmbmb Jul 11, 2024
75bd827
Add `only_instruction` attribute
gabrielmbmb Jul 11, 2024
a86f640
Update categories
gabrielmbmb Jul 11, 2024
463f622
testing
gabrielmbmb Jul 11, 2024
953b933
Worth trying
gabrielmbmb Jul 11, 2024
53ff036
Add examples
gabrielmbmb Jul 11, 2024
ba85907
Add magpie unit tests
gabrielmbmb Jul 11, 2024
b2e8805
Fix docstring
gabrielmbmb Jul 11, 2024
e52ae3f
Update docstrings
gabrielmbmb Jul 11, 2024
5736a25
Apply suggestions from code review
gabrielmbmb Jul 11, 2024
32b1725
Update to `huggingface_hub >= 0.22.0`
gabrielmbmb Jul 11, 2024
e899133
Add generation with `chat_completion`
gabrielmbmb Jul 12, 2024
91d05e2
Merge branch 'magpie' of https://github.com/argilla-io/rlxf into magpie
gabrielmbmb Jul 12, 2024
87371ec
Update `agenerate` arguments
gabrielmbmb Jul 15, 2024
bf350e3
Update unit tests
gabrielmbmb Jul 15, 2024
360433f
Fix `tools` were not being used
gabrielmbmb Jul 15, 2024
ef68210
Update unit tests
gabrielmbmb Jul 15, 2024
9ee7096
Fix list of tuples instead of list of list
gabrielmbmb Jul 15, 2024
3863eb3
Add missing docstring
gabrielmbmb Jul 15, 2024
87c11cc
Add `chat_completion` unit tests
gabrielmbmb Jul 15, 2024
cd3cc5d
Fix `GroqLLM.generate` unit test after updating `_agenerate`
gabrielmbmb Jul 15, 2024
2 changes: 1 addition & 1 deletion src/distilabel/llms/__init__.py
@@ -22,7 +22,7 @@
from distilabel.llms.litellm import LiteLLM
from distilabel.llms.llamacpp import LlamaCppLLM
from distilabel.llms.mistral import MistralLLM
from distilabel.llms.mixins import CudaDevicePlacementMixin
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.llms.moa import MixtureOfAgentsLLM
from distilabel.llms.ollama import OllamaLLM
from distilabel.llms.openai import OpenAILLM
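For reference, a minimal sketch of how downstream imports look after this move (grounded only in the `__init__.py` change above; nothing else about the public API is assumed):

# New canonical location of the mixin after this PR
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin

# The re-export from the package root keeps working, since `__init__.py` only
# updates the path it forwards
from distilabel.llms import CudaDevicePlacementMixin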
212 changes: 90 additions & 122 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -14,8 +14,9 @@

import os
import random
import sys
import warnings
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import (
Field,
@@ -28,6 +29,7 @@
from typing_extensions import override

from distilabel.llms.base import AsyncLLM
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import (
@@ -42,15 +44,13 @@

if TYPE_CHECKING:
from huggingface_hub import AsyncInferenceClient
from openai import AsyncOpenAI
from transformers import PreTrainedTokenizer


class InferenceEndpointsLLM(AsyncLLM):
class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
"""InferenceEndpoints LLM implementation running the async API client.

This LLM will internally use `huggingface_hub.AsyncInferenceClient` or `openai.AsyncOpenAI`
depending on the `use_openai_client` attribute.
This LLM will internally use `huggingface_hub.AsyncInferenceClient`.

Attributes:
model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which
@@ -63,7 +63,6 @@ class InferenceEndpointsLLM(AsyncLLM):
tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub.
Defaults to `None`, but defining one is recommended to properly format the prompt.
model_display_name: the model display name to use for the LLM. Defaults to `None`.
use_openai_client: whether to use the OpenAI client instead of the Hugging Face client.

Icon:
`:hugging:`
@@ -137,7 +136,6 @@ class InferenceEndpointsLLM(AsyncLLM):

tokenizer_id: Optional[str] = None
model_display_name: Optional[str] = None
use_openai_client: bool = False

structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field(
default=None,
@@ -149,7 +147,7 @@
_model_name: Optional[str] = PrivateAttr(default=None)
_tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None)
_api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME)
_aclient: Optional[Union["AsyncInferenceClient", "AsyncOpenAI"]] = PrivateAttr(...)
_aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...)

@model_validator(mode="after") # type: ignore
def only_one_of_model_id_endpoint_name_or_base_url_provided(
@@ -161,9 +159,19 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided(

if self.base_url and (self.model_id or self.endpoint_name):
self._logger.warning( # type: ignore
f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`"
" is also provided, the `base_url` will either be ignored or overwritten with the one generated"
" from either of those args, for serverless or dedicated inference endpoints, respectively."
f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
" or `endpoint_name` is also provided, the `base_url` will either be ignored"
" or overwritten with the one generated from either of those args, for serverless"
" or dedicated inference endpoints, respectively."
)

if self.model_id and self.tokenizer_id is None:
self.tokenizer_id = self.model_id

if self.use_magpie_template and self.tokenizer_id is None:
raise ValueError(
"`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
" set a `tokenizer_id` and try again."
)

if self.base_url and not (self.model_id or self.endpoint_name):
@@ -176,19 +184,16 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided(
return self

raise ValidationError(
"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too,"
" it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name},"
f" and `base_url`={self.base_url}."
f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
)
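A hedged sketch of how the new validation rules behave; the model ID is a placeholder, and the constructor keywords are assumed to be regular pydantic fields, with `use_magpie_template` coming from the `MagpieChatTemplateMixin` referenced in this diff:

# `tokenizer_id` now falls back to `model_id` when it is not set, so the Magpie
# template check passes as long as a model ID is given.
llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model ID
    use_magpie_template=True,
)
assert llm.tokenizer_id == "meta-llama/Meta-Llama-3-8B-Instruct"

# Without any tokenizer (e.g. only `endpoint_name` is provided), enabling the
# Magpie template raises the ValueError added in this validator:
# InferenceEndpointsLLM(endpoint_name="my-endpoint", use_magpie_template=True)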

def load(self) -> None: # noqa: C901
"""Loads the either the `AsyncInferenceClient` or the `AsyncOpenAI` client to benefit
from async requests, running the Hugging Face Inference Endpoint underneath via the
`/v1/chat/completions` endpoint, exposed for the models running on TGI using the
`text-generation` task.
"""Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference
Endpoint.

Raises:
ImportError: if the `openai` Python client is not installed.
ImportError: if the `huggingface-hub` Python client is not installed.
ValueError: if the model is not currently deployed or is not running the TGI framework.
ImportError: if the `transformers` Python client is not installed.
@@ -234,31 +239,16 @@ def load(self) -> None: # noqa: C901
)
if client.status in ["paused", "scaledToZero"]:
client.resume().wait(timeout=300)
elif client.status in ["initializing"]:
elif client.status == "initializing":
client.wait(timeout=300)

self.base_url = client.url
self._model_name = client.repository

if self.use_openai_client:
try:
from openai import AsyncOpenAI
except ImportError as ie:
raise ImportError(
"OpenAI Python client is not installed. Please install it using"
" `pip install openai`."
) from ie

self._aclient = AsyncOpenAI(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(),
max_retries=6,
)
else:
self._aclient = AsyncInferenceClient(
model=self.base_url,
token=self.api_key.get_secret_value(),
)
self._aclient = AsyncInferenceClient(
model=self.base_url,
token=self.api_key.get_secret_value(),
)

if self.tokenizer_id:
try:
@@ -283,43 +273,69 @@ def model_name(self) -> Union[str, None]: # type: ignore
or self.base_url
)

async def _openai_agenerate(
self,
input: "StandardInput",
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: Optional[float] = None,
stop: Optional[Union[str, List[str]]] = None,
) -> GenerateOutput:
"""Generates completions for the given input using the OpenAI async client."""
completion = await self._aclient.chat.completions.create( # type: ignore
messages=input, # type: ignore
model="tgi",
max_tokens=max_new_tokens,
n=1,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
stop=stop,
timeout=50,
)
if completion.choices[0].message.content is None:
self._logger.warning( # type: ignore
f"⚠️ Received no response using OpenAI client (model: '{self.model_name}')."
f" Finish reason was: {completion.choices[0].finish_reason}"
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.

Args:
input: the input list containing chat items.

Returns:
The prompt to send to the LLM.
"""
prompt: str = (
self._tokenizer.apply_chat_template( # type: ignore
conversation=input, # type: ignore
tokenize=False,
add_generation_prompt=True,
)
return [completion.choices[0].message.content]
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
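A hedged illustration of `prepare_input`, reusing the `llm` instance from the sketch above after `llm.load()` has been called; the exact prompt string depends on the tokenizer's chat template, so the comment is illustrative only:

# The chat template is applied without tokenizing and with a generation prompt,
# then the result is handed to `apply_magpie_pre_query_template`, which prepends
# or appends the Magpie pre-query template when `use_magpie_template` is enabled.
conversation = [{"role": "user", "content": "What is the capital of France?"}]
prompt = llm.prepare_input(conversation)
# e.g. "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n..." for a Llama 3 tokenizer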

def get_structured_output(
self, input: FormattedInput
) -> Union[Dict[str, Any], None]:
"""Gets the structured output (if any) for the given input.

Args:
input: a single input in chat format to generate responses for.

Returns:
The structured output that will be passed as `grammar` to the inference endpoint
or `None` if not required.
"""
structured_output = None

# Specific structured output per input
if isinstance(input, tuple):
input, structured_output = input
structured_output = {
"type": structured_output["format"],
"value": structured_output["schema"],
}

# Same structured output for all the inputs
if structured_output is None and self.structured_output is not None:
try:
structured_output = {
"type": self.structured_output["format"],
"value": self.structured_output["schema"],
}
except KeyError as e:
raise ValueError(
"To use the structured output you have to inform the `format` and `schema` in "
"the `structured_output` attribute."
) from e

return structured_output
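A hedged sketch of the two ways structured output can reach `get_structured_output` (the schema below is illustrative):

# Per-row: the input arrives as a (conversation, structured_output) tuple and is
# mapped to the {"type": ..., "value": ...} dict sent to the endpoint as `grammar`.
row = (
    [{"role": "user", "content": "Return a JSON object with a `name` field."}],
    {"format": "json", "schema": {"type": "object", "properties": {"name": {"type": "string"}}}},
)
grammar = llm.get_structured_output(row)
# grammar == {"type": "json", "value": {"type": "object", ...}}

# Attribute-level: setting the same `format`/`schema` keys on `llm.structured_output`
# applies one grammar to every input instead.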

@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: Optional[float] = None,
temperature: float = 1.0,
do_sample: bool = False,
@@ -331,21 +347,16 @@ async def agenerate( # type: ignore
seed: Optional[int] = None,
watermark: bool = False,
) -> GenerateOutput:
"""Generates completions for the given input using the OpenAI async client.
"""Generates completions for the given input using the async client.

Args:
input: a single input in chat format to generate responses for.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
frequency_penalty: the repetition penalty to use for the generation. Defaults
to `0.0`. Only applies if `use_openai_client=True`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`. Only applies if `use_openai_client=True`.
repetition_penalty: the repetition penalty to use for the generation. Defaults
to `None`. Only applies if `use_openai_client=False`.
to `None`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
do_sample: whether to use sampling for the generation. Defaults to `False`.
Only applies if `use_openai_client=False`.
top_k: the top-k value to use for the generation. Defaults to `0.8`, since neither
`0.0` nor `1.0` are valid values in TGI.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
@@ -373,55 +384,12 @@ async def agenerate( # type: ignore
)
stop_sequences = stop_sequences[:4]

structured_output = None
if isinstance(input, tuple):
input, structured_output = input
structured_output = {
"type": structured_output["format"],
"value": structured_output["schema"],
}

# NOTE: `self.structured_output` applies to all the generations, while `structured_output` i.e. the
# value included within the tuple provided as `input` to this method, is intended to be different per
# each input, so those should not be used together. Meaning that it should be either provided at attribute
# level i.e. self, or via a column within each input i.e. row.
if structured_output is None and self.structured_output is not None:
try:
structured_output = {
"type": self.structured_output["format"],
"value": self.structured_output["schema"],
}
except KeyError as e:
raise ValueError(
"To use the structured output you have to inform the `format` and `schema` in "
"the `structured_output` attribute."
) from e

if self.use_openai_client:
return await self._openai_agenerate(
input=input,
max_new_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
stop=stop_sequences,
)

if self._tokenizer is not None:
prompt = self._tokenizer.apply_chat_template( # type: ignore
conversation=input, # type: ignore
tokenize=False,
add_generation_prompt=True,
)
else:
# TODO: should we apply a default chat template here instead? e.g. ChatML
prompt = "\n".join([message["content"] for message in input])
structured_output = self.get_structured_output(input)

completion = None
try:
completion = await self._aclient.text_generation( # type: ignore
prompt=prompt, # type: ignore
prompt=self.prepare_input(input), # type: ignore
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
@@ -435,7 +403,7 @@ async def agenerate( # type: ignore
grammar=structured_output, # type: ignore
# NOTE: here to ensure that the cache is not used and a different response is
# generated every time
seed=seed or random.randint(0, 2147483647),
seed=seed or random.randint(0, sys.maxsize),
)
except Exception as e:
self._logger.warning( # type: ignore
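Finally, a hedged end-to-end sketch of exercising the updated `agenerate` path; the model ID and the asyncio wiring are illustrative and not taken from this PR, and a valid Hugging Face token is assumed to be available in the environment:

import asyncio

from distilabel.llms import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder; `tokenizer_id` defaults to it
)
llm.load()  # checks the model is deployed, creates the `AsyncInferenceClient`, loads the tokenizer

generation = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Write a haiku about GPUs."}],
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
    )
)
print(generation)  # a list containing the generated completion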