
Return generations with list of strings and token count from _raw_response
plaguss committed Oct 15, 2024
1 parent 394984f commit 9003fa0
Showing 19 changed files with 238 additions and 124 deletions.
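
Every `agenerate` method touched in this commit converges on the same return shape: a `generations` list (one string, JSON string, or `None` per generation) plus a `statistics` dict with token counts. A minimal sketch of that shape, assuming the typed alias is the `GenerateOutput` imported from `distilabel.llms.typing` in the diffs below:

```python
# Sketch of the dict every agenerate() below now returns. The real typed
# definition is the GenerateOutput alias in distilabel.llms.typing; this
# helper only illustrates the assumed structure.
from typing import Any, Dict, List, Optional


def make_generate_output(
    generations: List[Optional[str]],
    input_tokens: int,
    output_tokens: int,
) -> Dict[str, Any]:
    """Wrap raw generations and token usage in the new output format."""
    return {
        "generations": generations,  # always a list, even for a single generation
        "statistics": {
            "input_tokens": input_tokens,    # prompt-side token count
            "output_tokens": output_tokens,  # completion-side token count
        },
    }


print(make_generate_output(["Hello!"], input_tokens=12, output_tokens=3))
```
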
21 changes: 5 additions & 16 deletions src/distilabel/llms/anthropic.py
@@ -24,12 +24,10 @@
get_type_hints,
)

import orjson
from httpx import AsyncClient
from pydantic import Field, PrivateAttr, SecretStr, validate_call

from distilabel.llms.base import AsyncLLM
from distilabel.llms.statistics import compute_tokens
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import (
@@ -42,7 +40,6 @@

from anthropic import AsyncAnthropic
from anthropic.types import Message
from tokenizers import Tokenizer


_ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"
@@ -148,7 +145,6 @@ class User(BaseModel):

_api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
_aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)
_tokenizer: "Tokenizer" = PrivateAttr(...)

def _check_model_exists(self) -> None:
"""Checks if the specified model exists in the available models."""
Expand Down Expand Up @@ -205,10 +201,6 @@ def load(self) -> None:
if structured_output := result.get("structured_output"):
self.structured_output = structured_output

from anthropic._tokenizers import sync_get_tokenizer

self._tokenizer = sync_get_tokenizer()

@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
@@ -275,15 +267,12 @@ async def agenerate( # type: ignore
**kwargs
) # type: ignore
if structured_output:
str_response = completion.model_dump_json()
raw_response = completion._raw_response
return {
"generations": str_response,
"generations": [completion.model_dump_json()],
"statistics": {
"input_tokens": compute_tokens(input, self._tokenizer.encode),
"output_tokens": compute_tokens(
orjson.dumps(str_response).decode("utf-8"),
self._tokenizer.encode,
),
"input_tokens": raw_response.usage.input_tokens,
"output_tokens": raw_response.usage.output_tokens,
},
}

@@ -293,7 +282,7 @@ async def agenerate( # type: ignore
f" Finish reason was: {completion.stop_reason}"
)
return {
"generations": content,
"generations": [content],
"statistics": {
"input_tokens": completion.usage.input_tokens,
"output_tokens": completion.usage.output_tokens,
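
The Anthropic branch above stops re-tokenizing the prompt locally and instead reads usage straight from the API response. A sketch with a stand-in object (the real fields live on `anthropic.types.Message.usage`):

```python
# Stand-in for an anthropic.types.Message: usage.input_tokens/output_tokens
# replace the old local tokenizer counts. Values are illustrative only.
from types import SimpleNamespace

completion = SimpleNamespace(
    content=[SimpleNamespace(text="The answer is 4.")],
    usage=SimpleNamespace(input_tokens=21, output_tokens=7),
)

result = {
    "generations": [completion.content[0].text],
    "statistics": {
        "input_tokens": completion.usage.input_tokens,
        "output_tokens": completion.usage.output_tokens,
    },
}
print(result)
```
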
12 changes: 5 additions & 7 deletions src/distilabel/llms/cohere.py
@@ -37,7 +37,7 @@
)

if TYPE_CHECKING:
from cohere import AsyncClient, ChatMessage, NonStreamedChatResponse
from cohere import AsyncClient, ChatMessage, Message
from pydantic import BaseModel


@@ -287,15 +285,13 @@ async def agenerate( # type: ignore
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore

response: Union[
"NonStreamedChatResponse", "BaseModel"
] = await self._aclient.chat(**kwargs) # type: ignore
response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs) # type: ignore

if structured_output:
# TODO: Refactor the dict response, it's quite similar in many LLMs
str_response = response.model_dump_json()
return {
"generations": str_response,
"generations": [str_response],
"statistics": {
"input_tokens": compute_tokens(input, self._tokenizer.encode),
"output_tokens": compute_tokens(
@@ -311,15 +309,15 @@ async def agenerate( # type: ignore
f" Finish reason was: {response.finish_reason}"
)
return {
"generations": None,
"generations": [None],
"statistics": {
"input_tokens": compute_tokens(input, self._tokenizer.encode),
"output_tokens": 0,
},
}

return {
"generations": text,
"generations": [text],
"statistics": {
"input_tokens": compute_tokens(input, self._tokenizer.encode),
"output_tokens": compute_tokens(text, self._tokenizer.encode),
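
For the structured-output path used here and in the other providers, the parsed pydantic model is serialized and wrapped in a single-element `generations` list. A sketch with an illustrative schema (the `User` model below is hypothetical):

```python
# Sketch of how a structured-output result becomes a generations entry:
# the parsed pydantic model is dumped to a JSON string. The User schema is
# only an example; any instructor response model works the same way.
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


parsed = User(name="Ada", age=36)  # stand-in for the parsed completion
generations = [parsed.model_dump_json()]
print(generations)  # a list containing one JSON string
```
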
16 changes: 9 additions & 7 deletions src/distilabel/llms/groq.py
@@ -225,19 +225,21 @@ async def agenerate( # type: ignore
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output)

generations = []
completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
if structured_output:
generations.append(completion.model_dump_json())
raw_response = completion._raw_response
return {
"generations": generations,
"generations": [completion.model_dump_json()],
"statistics": {
# TODO: Need a way of knowing the tokenizer.
"input_tokens": 0,
"output_tokens": 0,
"input_tokens": raw_response.usage.prompt_tokens
if raw_response.usage
else 0,
"output_tokens": raw_response.usage.completion_tokens
if raw_response.usage
else 0,
},
}

generations = []
for choice in completion.choices:
if (content := choice.message.content) is None:
self._logger.warning( # type: ignore
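
When instructor wraps a Groq completion, the underlying OpenAI-style response is kept on the private `_raw_response` attribute, and its `usage` block may be absent, hence the guards above. A sketch with stand-in objects:

```python
# Stand-in for instructor's completion._raw_response: usage carries
# prompt_tokens / completion_tokens, but may be None, so fall back to 0.
from types import SimpleNamespace

raw_response = SimpleNamespace(
    usage=SimpleNamespace(prompt_tokens=34, completion_tokens=11)
)

statistics = {
    "input_tokens": raw_response.usage.prompt_tokens if raw_response.usage else 0,
    "output_tokens": raw_response.usage.completion_tokens if raw_response.usage else 0,
}
print(statistics)  # {'input_tokens': 34, 'output_tokens': 11}
```
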
30 changes: 27 additions & 3 deletions src/distilabel/llms/litellm.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import orjson
from pydantic import Field, PrivateAttr, validate_call

from distilabel.llms.base import AsyncLLM
@@ -194,6 +195,7 @@ async def agenerate( # type: ignore # noqa: C901
A list of lists of strings containing the generated responses for each input.
"""
import litellm
from litellm import token_counter

structured_output = None
if isinstance(input, tuple):
@@ -256,10 +258,24 @@ async def _call_aclient_until_n_choices() -> List["Choices"]:
raise e

generations = []
input_tokens = token_counter(model=self.model, messages=input)
output_tokens = 0

if self.structured_output:
generations.append([choice.model_dump_json() for choice in choices])
return generations
for choice in choices:
generations.append(choice.model_dump_json())
output_tokens += token_counter(
model=self.model,
text=orjson.dumps(choice.model_dump_json()).decode("utf-8"),
)

return {
"generations": generations,
"statistics": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
},
}

for choice in choices:
if (content := choice.message.content) is None:
@@ -268,4 +284,12 @@ async def _call_aclient_until_n_choices() -> List["Choices"]:
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
return generations
output_tokens += token_counter(model=self.model, text=content)

return {
"generations": generations,
"statistics": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
},
}
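
LiteLLM has no single raw usage object across providers, so the diff counts tokens with `litellm.token_counter`: once over the input messages and once per generated string. A sketch, assuming the `litellm` package is installed and using an example model name:

```python
# Count input tokens from the chat messages and output tokens from the text,
# mirroring the litellm branch above. The model name is only an example.
from litellm import token_counter

messages = [{"role": "user", "content": "Write a haiku about autumn."}]
generated = "Leaves drift on cold wind."

input_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
output_tokens = token_counter(model="gpt-3.5-turbo", text=generated)
print(input_tokens, output_tokens)
```
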
18 changes: 15 additions & 3 deletions src/distilabel/llms/mistral.py
@@ -221,8 +221,14 @@ async def agenerate( # type: ignore
completion = await self._aclient.chat.complete_async(**kwargs) # type: ignore

if structured_output:
generations.append(completion.model_dump_json())
return generations
raw_response = completion._raw_response
return {
"generations": [completion.model_dump_json()],
"statistics": {
"input_tokens": raw_response.usage.prompt_tokens,
"output_tokens": raw_response.usage.completion_tokens,
},
}

for choice in completion.choices:
if (content := choice.message.content) is None:
@@ -231,4 +237,10 @@ async def agenerate( # type: ignore
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
return generations
return {
"generations": generations,
"statistics": {
"input_tokens": completion.usage.prompt_tokens,
"output_tokens": completion.usage.completion_tokens,
},
}
12 changes: 11 additions & 1 deletion src/distilabel/llms/ollama.py
@@ -159,6 +159,8 @@ async def agenerate( # type: ignore
A list of strings as completion for the given input.
"""
text = None
input_tokens = 0
output_tokens = 0
try:
completion: Dict[str, Any] = await self._aclient.chat( # type: ignore
model=self.model,
@@ -169,10 +171,18 @@ async def agenerate( # type: ignore
keep_alive=keep_alive,
)
text = completion["message"]["content"]
input_tokens = completion["prompt_eval_count"]
output_tokens = completion["eval_count"]
except Exception as e:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
f" Finish reason was: {e}"
)

return [text]
return {
"generations": [text],
"statistics": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
},
}
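
Ollama returns a plain dict, so the usage fields are read directly from it (`prompt_eval_count` / `eval_count`), defaulting to 0 when the request fails. A sketch with a literal response dict:

```python
# Literal stand-in for the dict returned by the Ollama chat endpoint.
completion = {
    "message": {"role": "assistant", "content": "Paris."},
    "prompt_eval_count": 18,  # tokens evaluated for the prompt
    "eval_count": 2,          # tokens generated in the reply
}

result = {
    "generations": [completion["message"]["content"]],
    "statistics": {
        "input_tokens": completion["prompt_eval_count"],
        "output_tokens": completion["eval_count"],
    },
}
print(result)
```
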
38 changes: 7 additions & 31 deletions src/distilabel/llms/openai.py
@@ -22,7 +22,6 @@
from distilabel import envs
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.llms.base import AsyncLLM
from distilabel.llms.statistics import compute_tokens
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
@@ -32,8 +31,6 @@
from openai.types import Batch as OpenAIBatch
from openai.types import FileObject as OpenAIFileObject
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel
from tiktoken.core import Encoding


_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
@@ -170,7 +167,6 @@ class User(BaseModel):
_api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME)
_client: "OpenAI" = PrivateAttr(None)
_aclient: "AsyncOpenAI" = PrivateAttr(None)
_tokenizer: "Encoding" = PrivateAttr(None)

def load(self) -> None:
"""Loads the `AsyncOpenAI` client to benefit from async requests."""
@@ -213,10 +209,6 @@ def load(self) -> None:
self._aclient = result.get("client") # type: ignore
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
# It must be version 0.8.0 at least.
import tiktoken

self._tokenizer = tiktoken.encoding_for_model(self.model)

def unload(self) -> None:
"""Set clients to `None` as they both contain `thread._RLock` which cannot be pickled
@@ -315,36 +307,20 @@ async def agenerate( # type: ignore

completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
if structured_output:
# Note: Instructor extracts the content from the structured output, so we need to
# add the token count
generation = self._generations_from_structured_output(completion)

return {
"generations": generation,
"generations": [completion.model_dump_json()],
"statistics": {
"input_tokens": compute_tokens(input, self._tokenizer.encode),
"output_tokens": compute_tokens(
orjson.dumps(generation).decode("utf-8"), self._tokenizer.encode
),
"input_tokens": completion._raw_response.usage.prompt_tokens
if completion._raw_response
else 0,
"output_tokens": completion._raw_response.usage.completion_tokens
if completion._raw_response
else 0,
},
}

return self._generations_from_openai_completion(completion)

def _generations_from_structured_output(
self, completion: "BaseModel"
) -> "GenerateOutput":
"""Get the generations from the structured output object.
Args:
completion: an instance of `pydantic.BaseModel` with the content of the structuted
output.
Returns:
A list with the content of the structured output.
"""
return [completion.model_dump_json()]

def _generations_from_openai_completion(
self, completion: "OpenAIChatCompletion"
) -> "GenerateOutput":
10 changes: 5 additions & 5 deletions src/distilabel/llms/statistics.py
@@ -27,12 +27,12 @@ def compute_tokens(
tokenizer: A callable function that take str and returns the tokenized version of the text.
Returns:
int: _description_
The number of tokens.
"""
if isinstance(text_or_messages, str):
text = text_or_messages
else:
if isinstance(text_or_messages, list):
# If it's a list of messages, concatenate the content of each message
text = " ".join([message["content"] for message in text_or_messages])
else:
text = text_or_messages

return len(tokenizer(text)) if text else 0
return len(tokenizer(text))
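
`compute_tokens` stays as the fallback for clients that expose a tokenizer but no usage metadata (e.g. the Cohere branch above). A usage sketch, with `str.split` standing in for a real `encode` callable:

```python
# Local re-implementation of compute_tokens for illustration; the real one
# lives in distilabel.llms.statistics. str.split stands in for a tokenizer's
# encode method.
def compute_tokens(text_or_messages, tokenizer):
    if isinstance(text_or_messages, list):
        # A list of chat messages: concatenate each message's content.
        text = " ".join(message["content"] for message in text_or_messages)
    else:
        text = text_or_messages
    return len(tokenizer(text))


messages = [{"role": "user", "content": "Name three prime numbers."}]
print(compute_tokens(messages, str.split))       # input-side count
print(compute_tokens("2, 3 and 5.", str.split))  # output-side count
```
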
12 changes: 11 additions & 1 deletion src/distilabel/llms/vertexai.py
@@ -160,15 +160,25 @@ async def agenerate( # type: ignore
)

text = None
input_tokens = 0
output_tokens = 0
try:
text = content.candidates[0].text
input_tokens = content.usage_metadata.prompt_token_count
output_tokens = content.usage_metadata.candidates_token_count
except ValueError:
self._logger.warning( # type: ignore
f"Received no response using VertexAI client (model: '{self.model}')."
f" Finish reason was: '{content.candidates[0].finish_reason}'."
)

return [text]
return {
"generations": [text],
"statistics": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
},
}


def _is_gemini_model(model: str) -> bool:
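
The Vertex AI branch pulls token counts from the response's `usage_metadata`; when the candidate has no text a `ValueError` is caught and the counts stay at 0. A sketch with a stand-in for the response object:

```python
# Stand-in for a Vertex AI GenerationResponse: usage_metadata exposes
# prompt_token_count and candidates_token_count. Values are illustrative.
from types import SimpleNamespace

content = SimpleNamespace(
    candidates=[SimpleNamespace(text="42")],
    usage_metadata=SimpleNamespace(prompt_token_count=9, candidates_token_count=1),
)

result = {
    "generations": [content.candidates[0].text],
    "statistics": {
        "input_tokens": content.usage_metadata.prompt_token_count,
        "output_tokens": content.usage_metadata.candidates_token_count,
    },
}
print(result)
```
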
1 change: 0 additions & 1 deletion tests/unit/llms/huggingface/test_inference_endpoints.py
@@ -235,7 +235,6 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None:
)

nest_asyncio.apply()

assert llm.generate(
inputs=[
[
