Merge branch 'develop' into embedding-dataset-tasks
gabrielmbmb committed Jun 3, 2024
2 parents 2e5d9f8 + e61b598 commit b7cd785
Showing 24 changed files with 426 additions and 61 deletions.
2 changes: 1 addition & 1 deletion docs/sections/pipeline_samples/examples/index.md
@@ -10,7 +10,7 @@ Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`

This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema.

-It makes use of a local model which can be downlaoded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].
+It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].

??? Run
8 changes: 4 additions & 4 deletions src/distilabel/llms/anthropic.py
@@ -33,7 +33,7 @@
from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:

@@ -163,7 +163,7 @@ def model_name(self) -> str:
    @validate_call
    async def agenerate(  # type: ignore
        self,
-        input: ChatType,
+        input: StandardInput,
        max_tokens: int = 128,
        stop_sequences: Union[List[str], None] = None,
        temperature: float = 1.0,

@@ -223,7 +223,7 @@ async def agenerate(  # type: ignore
    @override
    def generate(
        self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:

@@ -232,7 +232,7 @@ def generate(
        """

        async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
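Across this commit the provider integrations rename the `ChatType` alias to `StandardInput`. A minimal sketch of the shape both names refer to, assuming the OpenAI-style conversation format used throughout `distilabel` (the `ChatItem` definition is illustrative, not copied from this diff):

```python
from typing import List
from typing_extensions import TypedDict

class ChatItem(TypedDict):
    role: str  # "system", "user" or "assistant"
    content: str

# Previously exported as `ChatType`; this commit renames it to `StandardInput`.
StandardInput = List[ChatItem]

conversation: StandardInput = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name one use of synthetic data."},
]
```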
24 changes: 16 additions & 8 deletions src/distilabel/llms/base.py
@@ -37,7 +37,7 @@
    InstructorStructuredOutputType,
)
from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import FormattedInput, StandardInput
from distilabel.utils.docstring import Docstring

if in_notebook():

@@ -94,7 +94,7 @@ def model_name(self) -> str:
    @abstractmethod
    def generate(
        self,
-        inputs: List["ChatType"],
+        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:

@@ -187,7 +187,9 @@ def generate_parsed_docstring(self) -> "Docstring":
        """
        return parse_google_docstring(self.generate)

-    def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
+    def get_last_hidden_states(
+        self, inputs: List["StandardInput"]
+    ) -> List["HiddenState"]:
        """Method to get the last hidden states of the model for a list of inputs.

        Args:

@@ -231,6 +233,7 @@ class AsyncLLM(LLM):
    """

    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
+    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> List[inspect.Parameter]:

@@ -257,14 +260,16 @@ def event_loop(self) -> "asyncio.AbstractEventLoop":
            self._event_loop = asyncio.get_running_loop()
            if self._event_loop.is_closed():
                self._event_loop = asyncio.new_event_loop()  # type: ignore
+                self._new_event_loop = True
        except RuntimeError:
            self._event_loop = asyncio.new_event_loop()
+            self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
-        self, input: "ChatType", num_generations: int = 1, **kwargs: Any
+        self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
    ) -> List[Union[str, None]]:
        """Method to generate a `num_generations` responses for a given input asynchronously,
        and executed concurrently in `generate` method.

@@ -273,7 +278,7 @@ async def agenerate(

    def generate(
        self,
-        inputs: List["ChatType"],
+        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:

@@ -282,7 +287,7 @@ def generate(
        """

        async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["FormattedInput"], **kwargs: Any
        ) -> List[List[Union[str, None]]]:
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [

@@ -301,8 +306,11 @@ def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        if sys.meta_path is None:
            return
-        if self.event_loop is not None:
-            self.event_loop.close()

+        if self._new_event_loop:
+            if self._event_loop.is_running():
+                self._event_loop.stop()
+            self._event_loop.close()

    @staticmethod
    def _prepare_structured_output(
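The `base.py` hunks also fix event-loop teardown in `AsyncLLM`: the new `_new_event_loop` flag records whether the instance created its own loop, and `__del__` now stops and closes only loops it owns instead of closing whatever loop it borrowed from the caller. A simplified, self-contained sketch of the pattern (not the actual class):

```python
import asyncio
from typing import Optional

class LoopOwner:
    """Sketch of the loop-ownership pattern `AsyncLLM` adopts in this diff."""

    def __init__(self) -> None:
        self._event_loop: Optional[asyncio.AbstractEventLoop] = None
        self._new_event_loop = False  # True only when *we* created the loop

    @property
    def event_loop(self) -> asyncio.AbstractEventLoop:
        try:
            # Reuse the caller's loop when one is running in this thread.
            self._event_loop = asyncio.get_running_loop()
            if self._event_loop.is_closed():
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        except RuntimeError:  # no running loop in this thread
            self._event_loop = asyncio.new_event_loop()
            self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    def __del__(self) -> None:
        # Tear down only a loop this instance created; closing a borrowed
        # loop would break the caller, which is the bug the diff avoids.
        if self._new_event_loop and self._event_loop is not None:
            if self._event_loop.is_running():
                self._event_loop.stop()
            self._event_loop.close()
```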
10 changes: 5 additions & 5 deletions src/distilabel/llms/cohere.py
@@ -30,7 +30,7 @@

from distilabel.llms.base import AsyncLLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:

@@ -132,7 +132,7 @@ def load(self) -> None:
        self.structured_output = structured_output

    def _format_chat_to_cohere(
-        self, input: "ChatType"
+        self, input: "StandardInput"
    ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
        """Formats the chat input to the Cohere Chat API conversational format.

@@ -169,7 +169,7 @@ def _format_chat_to_cohere(
    @validate_call
    async def agenerate(  # type: ignore
        self,
-        input: ChatType,
+        input: StandardInput,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,

@@ -241,15 +241,15 @@ async def agenerate(  # type: ignore
    @override
    def generate(
        self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously awaiting for the response of each input sent to `agenerate`."""

        async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
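`_format_chat_to_cohere` adapts the OpenAI-style `StandardInput` to Cohere's Chat API, which separates a preamble, a chat history, and the latest message. A rough sketch of that mapping using plain dicts (the real method returns `cohere.ChatMessage` objects, and the exact role mapping here is an assumption inferred from the signature above):

```python
from typing import Dict, List, Optional, Tuple

def format_chat_to_cohere(
    conversation: List[Dict[str, str]],
) -> Tuple[Optional[str], List[Dict[str, str]], str]:
    """Split an OpenAI-style conversation into (preamble, chat_history, message)."""
    preamble: Optional[str] = None
    history: List[Dict[str, str]] = []
    for turn in conversation[:-1]:
        if turn["role"] == "system":
            preamble = turn["content"]  # the system turn becomes Cohere's preamble
        else:
            history.append(
                {
                    "role": "USER" if turn["role"] == "user" else "CHATBOT",
                    "message": turn["content"],
                }
            )
    # Cohere's Chat API receives the latest turn as the `message` argument.
    return preamble, history, conversation[-1]["content"]
```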
8 changes: 4 additions & 4 deletions src/distilabel/llms/groq.py
@@ -22,7 +22,7 @@
from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.steps.base import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:

@@ -131,7 +131,7 @@ def model_name(self) -> str:
    @validate_call
    async def agenerate(  # type: ignore
        self,
-        input: ChatType,
+        input: StandardInput,
        seed: Optional[int] = None,
        max_new_tokens: int = 128,
        temperature: float = 1.0,

@@ -188,7 +188,7 @@ async def agenerate(  # type: ignore
    @override
    def generate(
        self,
-        inputs: List["ChatType"],
+        inputs: List["StandardInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:

@@ -197,7 +197,7 @@ def generate(
        """

        async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["StandardInput"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
36 changes: 27 additions & 9 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -16,6 +16,7 @@
import os
import random
import warnings
+from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union

from pydantic import (

@@ -31,7 +32,7 @@
from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import FormattedInput, Grammar, StandardInput
from distilabel.utils.itertools import grouper

if TYPE_CHECKING:

@@ -148,6 +149,11 @@ class InferenceEndpointsLLM(AsyncLLM):
    model_display_name: Optional[str] = None
    use_openai_client: bool = False

+    grammar: Optional[RuntimeParameter[Grammar]] = Field(
+        default=None,
+        description="The grammar to use across all the generations.",
+    )
+
    _model_name: Optional[str] = PrivateAttr(default=None)
    _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None)
    _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME)

@@ -201,6 +207,7 @@ def load(self) -> None:  # noqa: C901
            from huggingface_hub import (
                AsyncInferenceClient,
                InferenceClient,
+                constants,
                get_inference_endpoint,
            )
        except ImportError as ie:

@@ -210,10 +217,14 @@
            ) from ie

        if self.api_key is None:
-            raise ValueError(
-                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
-                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
-            )
+            if not Path(constants.HF_TOKEN_PATH).exists():
+                raise ValueError(
+                    f"To use `{self.__class__.__name__}` an API key must be provided via"
+                    " `api_key` attribute or runtime parameter, set the environment variable"
+                    f" `{self._api_key_env_var}` or use the `huggingface-hub` CLI to login"
+                    " with `huggingface-cli login`."
+                )
+            self.api_key = SecretStr(open(constants.HF_TOKEN_PATH).read().strip())

        if self.model_id is not None:
            client = InferenceClient()

@@ -290,7 +301,7 @@ def model_name(self) -> Union[str, None]:  # type: ignore

    async def _openai_agenerate(
        self,
-        input: "ChatType",
+        input: "StandardInput",
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,

@@ -322,7 +333,7 @@ async def _openai_agenerate(
    @validate_call
    async def agenerate(  # type: ignore
        self,
-        input: ChatType,
+        input: "FormattedInput",
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,

@@ -379,6 +390,10 @@ async def agenerate(  # type: ignore
            )
            stop_sequences = stop_sequences[:4]

+        grammar = None
+        if isinstance(input, tuple):
+            input, grammar = input
+
        if self.use_openai_client:
            return await self._openai_agenerate(
                input=input,

@@ -413,6 +428,9 @@ async def agenerate(  # type: ignore
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                watermark=watermark,
+                # NOTE: `self.grammar` applies to all the generations, while `grammar` is intended
+                # to be different per each input, and those are not intended to be used together
+                grammar=grammar or self.grammar,  # type: ignore
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, 2147483647),

@@ -429,7 +447,7 @@ async def agenerate(  # type: ignore
    @override
    def generate(
        self,
-        inputs: List["ChatType"],
+        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:

@@ -438,7 +456,7 @@ def generate(
        """

        async def agenerate(
-            inputs: List["ChatType"], **kwargs: Any
+            inputs: List["FormattedInput"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
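The new structured-output support works at two levels: the `grammar` runtime parameter applies to every generation, while a per-input grammar can ride along as a `(conversation, grammar)` tuple (the `FormattedInput` case), with the per-input value winning via `grammar=grammar or self.grammar`. A hedged usage sketch; the payload shape follows the `huggingface_hub` `text_generation` grammar convention and the model name is an arbitrary example, neither taken verbatim from this diff:

```python
from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical example model
    # Instance-level grammar: used by every generation that doesn't bring its own.
    grammar={"type": "regex", "value": r"\d{4}"},
)
llm.load()  # assumes a token via `api_key`, the env var, or `huggingface-cli login`

conversation = [{"role": "user", "content": "In which year was Python 3 released?"}]

# Plain input -> falls back to `self.grammar`.
llm.generate(inputs=[conversation])

# Tuple input -> the per-input grammar overrides the instance-level one.
json_grammar = {
    "type": "json",
    "value": {"type": "object", "properties": {"year": {"type": "integer"}}},
}
llm.generate(inputs=[(conversation, json_grammar)])
```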
10 changes: 6 additions & 4 deletions src/distilabel/llms/huggingface/transformers.py
@@ -21,7 +21,7 @@
from distilabel.llms.chat_templates import CHATML_TEMPLATE
from distilabel.llms.mixins import CudaDevicePlacementMixin
from distilabel.llms.typing import GenerateOutput
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput

if TYPE_CHECKING:
    from transformers import Pipeline

@@ -130,7 +130,7 @@ def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

-    def prepare_input(self, input: "ChatType") -> str:
+    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input by applying the chat template to the input, which is formatted
        as an OpenAI conversation, and adding the generation prompt.
        """

@@ -143,7 +143,7 @@ def prepare_input(self, input: "ChatType") -> str:
    @validate_call
    def generate(  # type: ignore
        self,
-        inputs: List[ChatType],
+        inputs: List[StandardInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        temperature: float = 0.1,

@@ -189,7 +189,9 @@ def generate(  # type: ignore
            for output in outputs
        ]

-    def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
+    def get_last_hidden_states(
+        self, inputs: List["StandardInput"]
+    ) -> List["HiddenState"]:
        """Gets the last `hidden_states` of the model for the given inputs. It doesn't
        execute the task head.
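`prepare_input` renders a `StandardInput` conversation into a single prompt string with the tokenizer's chat template before generation. A standalone approximation with `transformers` (the model name is an arbitrary example; this mirrors rather than calls the method above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

conversation = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Summarize this commit in one line."},
]

# Roughly what `prepare_input` does: render the OpenAI-style messages with the
# model's chat template and append the generation prompt for the assistant turn.
prompt = tokenizer.apply_chat_template(
    conversation,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```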
4 changes: 2 additions & 2 deletions src/distilabel/llms/litellm.py
@@ -20,7 +20,7 @@
from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput

if TYPE_CHECKING:
    from litellm import Choices

@@ -90,7 +90,7 @@ def model_name(self) -> str:
    @validate_call
    async def agenerate(  # type: ignore
        self,
-        input: ChatType,
+        input: StandardInput,
        num_generations: int = 1,
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
4 changes: 2 additions & 2 deletions src/distilabel/llms/llamacpp.py
@@ -19,7 +19,7 @@
from distilabel.llms.base import LLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import ChatType
+from distilabel.steps.tasks.typing import StandardInput

if TYPE_CHECKING:
    from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList

@@ -128,7 +128,7 @@ def model_name(self) -> str:
    @validate_call
    def generate(  # type: ignore
        self,
-        inputs: List[ChatType],
+        inputs: List[StandardInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,