First version of async llms with statistics
plaguss committed Oct 14, 2024
1 parent fe5d4c5 commit 394984f
Showing 8 changed files with 161 additions and 46 deletions.
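In short: `agenerate` in the async LLM subclasses now returns a dict with the `generations` plus a `statistics` entry carrying `input_tokens`/`output_tokens` counts, instead of a bare list of strings, and the unit tests are updated to assert on the new shape.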
39 changes: 32 additions & 7 deletions src/distilabel/llms/anthropic.py
@@ -24,10 +24,12 @@
     get_type_hints,
 )
 
+import orjson
 from httpx import AsyncClient
 from pydantic import Field, PrivateAttr, SecretStr, validate_call
 
 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.statistics import compute_tokens
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.tasks.typing import (
@@ -36,7 +38,11 @@
 )
 
 if TYPE_CHECKING:
+    from pydantic import BaseModel
+
     from anthropic import AsyncAnthropic
+    from anthropic.types import Message
+    from tokenizers import Tokenizer
 
 
 _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"
@@ -142,6 +148,7 @@ class User(BaseModel):
 
     _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
     _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)
+    _tokenizer: "Tokenizer" = PrivateAttr(...)
 
     def _check_model_exists(self) -> None:
         """Checks if the specified model exists in the available models."""
@@ -198,6 +205,10 @@ def load(self) -> None:
         if structured_output := result.get("structured_output"):
             self.structured_output = structured_output
 
+        from anthropic._tokenizers import sync_get_tokenizer
+
+        self._tokenizer = sync_get_tokenizer()
+
     @property
     def model_name(self) -> str:
         """Returns the model name used for the LLM."""
@@ -260,17 +271,31 @@ async def agenerate(  # type: ignore
         if structured_output:
             kwargs = self._prepare_kwargs(kwargs, structured_output)
 
-        generations = []
-
-        completion = await self._aclient.messages.create(**kwargs)  # type: ignore
+        completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
+            **kwargs
+        )  # type: ignore
         if structured_output:
-            generations.append(completion.model_dump_json())
-            return generations
+            str_response = completion.model_dump_json()
+            return {
+                "generations": str_response,
+                "statistics": {
+                    "input_tokens": compute_tokens(input, self._tokenizer.encode),
+                    "output_tokens": compute_tokens(
+                        orjson.dumps(str_response).decode("utf-8"),
+                        self._tokenizer.encode,
+                    ),
+                },
+            }
 
         if (content := completion.content[0].text) is None:
             self._logger.warning(
                 f"Received no response using Anthropic client (model: '{self.model}')."
                 f" Finish reason was: {completion.stop_reason}"
             )
-        generations.append(content)
-        return generations
+        return {
+            "generations": content,
+            "statistics": {
+                "input_tokens": completion.usage.input_tokens,
+                "output_tokens": completion.usage.output_tokens,
+            },
+        }
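Note: the `compute_tokens` helper imported above comes from `distilabel.llms.statistics`, presumably one of the changed files not rendered in this view, so its implementation is not visible here. Based solely on how it is called (with either a chat-formatted input or a plain string, plus a tokenizer's `encode`), a minimal sketch might look like this; the body and names are assumptions, not the actual module:

# Hypothetical sketch: the real distilabel.llms.statistics module is not
# visible in this diff.
from typing import Callable, List, Sized, Union

ChatType = List[dict]  # e.g. [{"role": "user", "content": "..."}]


def compute_tokens(
    text_or_messages: Union[str, ChatType],
    encode: Callable[[str], Sized],
) -> int:
    if isinstance(text_or_messages, str):
        text = text_or_messages
    else:
        # Normalize a chat-formatted input to plain text before encoding.
        text = " ".join(message["content"] for message in text_or_messages)
    return len(encode(text))  # assumes the returned encoding supports len()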
4 changes: 2 additions & 2 deletions src/distilabel/llms/base.py
@@ -466,9 +466,9 @@ async def _agenerate(
             for input in inputs
             for _ in range(num_generations)
         ]
-        outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
+        outputs = await asyncio.gather(*tasks)
         return [
-            list(group)
+            list(group)[0]
             for group in grouper(outputs, n=num_generations, incomplete="ignore")
         ]

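Context for the change above: `_agenerate` schedules one task per input and generation, so `asyncio.gather` now yields one `{"generations": ..., "statistics": ...}` dict per task, and the comprehension regroups them per input, keeping only the first dict of each group of `num_generations` results. A standalone illustration, assuming `grouper` behaves like `more_itertools.grouper`:

# Illustration only: two inputs x two generations, flattened in
# task-submission order, then regrouped per input.
from more_itertools import grouper

outputs = ["in0-gen0", "in0-gen1", "in1-gen0", "in1-gen1"]
regrouped = [list(group)[0] for group in grouper(outputs, 2, incomplete="ignore")]
print(regrouped)  # ['in0-gen0', 'in1-gen0'] -> first result of each group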
46 changes: 40 additions & 6 deletions src/distilabel/llms/cohere.py
@@ -23,9 +23,12 @@
     Union,
 )
 
+import orjson
 from pydantic import Field, PrivateAttr, SecretStr, validate_call
+from tokenizers import Tokenizer
 
 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.statistics import compute_tokens
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.tasks.typing import (
@@ -34,7 +37,8 @@
 )
 
 if TYPE_CHECKING:
-    from cohere import AsyncClient, ChatMessage
+    from cohere import AsyncClient, ChatMessage, NonStreamedChatResponse
+    from pydantic import BaseModel
 
 
 _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY"
@@ -135,6 +139,7 @@ class User(BaseModel):
 
     _ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
     _aclient: "AsyncClient" = PrivateAttr(...)
+    _tokenizer: "Tokenizer" = PrivateAttr(...)
 
     @property
     def model_name(self) -> str:
@@ -172,6 +177,10 @@ def load(self) -> None:
         if structured_output := result.get("structured_output"):
             self.structured_output = structured_output
 
+        from cohere.manually_maintained.tokenizers import get_hf_tokenizer
+
+        self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)
+
     def _format_chat_to_cohere(
         self, input: "FormattedInput"
     ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
@@ -278,16 +287,41 @@ async def agenerate(  # type: ignore
         if structured_output:
             kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore
 
-        response = await self._aclient.chat(**kwargs)  # type: ignore
+        response: Union[
+            "NonStreamedChatResponse", "BaseModel"
+        ] = await self._aclient.chat(**kwargs)  # type: ignore
 
         if structured_output:
-            return [response.model_dump_json()]
+            # TODO: Refactor the dict response, it's quite similar in many LLMs
+            str_response = response.model_dump_json()
+            return {
+                "generations": str_response,
+                "statistics": {
+                    "input_tokens": compute_tokens(input, self._tokenizer.encode),
+                    "output_tokens": compute_tokens(
+                        orjson.dumps(str_response).decode("utf-8"),
+                        self._tokenizer.encode,
+                    ),
+                },
+            }
 
         if (text := response.text) == "":
             self._logger.warning(  # type: ignore
                 f"Received no response using Cohere client (model: '{self.model}')."
                 f" Finish reason was: {response.finish_reason}"
            )
-            return [None]
-
-        return [text]
+            return {
+                "generations": None,
+                "statistics": {
+                    "input_tokens": compute_tokens(input, self._tokenizer.encode),
+                    "output_tokens": 0,
+                },
+            }
+
+        return {
+            "generations": text,
+            "statistics": {
+                "input_tokens": compute_tokens(input, self._tokenizer.encode),
+                "output_tokens": compute_tokens(text, self._tokenizer.encode),
+            },
+        }
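Note: the Cohere response carries no usage metadata in this code path, so both sides are counted locally with the model's Hugging Face tokenizer fetched by `get_hf_tokenizer`. The counting itself can be reproduced with any `tokenizers.Tokenizer`; the model name below is illustrative (it is the one the updated tests use):

# Counting tokens with a Hugging Face `tokenizers` tokenizer.
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer.encode("Aenean hendrerit aliquam velit...")
print(len(encoding.ids))  # the count reported as output_tokens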
22 changes: 20 additions & 2 deletions src/distilabel/llms/groq.py
@@ -229,7 +229,14 @@ async def agenerate(  # type: ignore
         completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
         if structured_output:
             generations.append(completion.model_dump_json())
-            return generations
+            return {
+                "generations": generations,
+                "statistics": {
+                    # TODO: Need a way of knowing the tokenizer.
+                    "input_tokens": 0,
+                    "output_tokens": 0,
+                },
+            }
 
         for choice in completion.choices:
             if (content := choice.message.content) is None:
@@ -238,4 +245,15 @@ async def agenerate(  # type: ignore
                 f" Finish reason was: {choice.finish_reason}"
             )
         generations.append(content)
-        return generations
+
+        return {
+            "generations": generations,
+            "statistics": {
+                "input_tokens": completion.usage.prompt_tokens
+                if completion.usage
+                else 0,
+                "output_tokens": completion.usage.completion_tokens
+                if completion.usage
+                else 0,
+            },
+        }
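Groq's completions expose OpenAI-style usage counters, so no local tokenizer is needed there. Across the three clients the return value converges on a single shape; `GenerateOutput` (imported from `distilabel.llms.typing`, not rendered in this view) would presumably describe something like the TypedDict below. The names here are assumptions; note that the Anthropic path returns `generations` as a bare string while Cohere and Groq return lists:

# Hypothetical description of the new return shape; the actual
# GenerateOutput definition is in a file not shown in this diff.
from typing import List, Optional, TypedDict, Union


class TokenStatistics(TypedDict):
    input_tokens: int
    output_tokens: int


class GenerateOutput(TypedDict):
    generations: Union[Optional[str], List[Optional[str]]]
    statistics: TokenStatistics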
19 changes: 15 additions & 4 deletions tests/unit/llms/test_anthropic.py
@@ -37,12 +37,14 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None:
         llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")  # type: ignore
         llm._aclient = mock_anthropic
 
-        mocked_completion = Mock()
-        mocked_completion.content = [Mock(text="Aenean hendrerit aliquam velit...")]
+        mocked_completion = Mock(
+            content=[Mock(text="Aenean hendrerit aliquam velit...")],
+            usage=Mock(input_tokens=100, output_tokens=100),
+        )
 
         llm._aclient.messages.create = AsyncMock(return_value=mocked_completion)
 
-        await llm.agenerate(
+        result = await llm.agenerate(
             input=[
                 {"role": "system", "content": ""},
                 {
@@ -51,6 +53,10 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None:
                 },
             ]
         )
+        assert result == {
+            "generations": "Aenean hendrerit aliquam velit...",
+            "statistics": {"input_tokens": 100, "output_tokens": 100},
+        }
 
     @pytest.mark.asyncio
     async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
@@ -64,6 +70,9 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
             },
         )  # type: ignore
         llm._aclient = mock_openai
+        from anthropic._tokenizers import sync_get_tokenizer
+
+        llm._tokenizer = sync_get_tokenizer()
 
         sample_user = DummyUserDetail(name="John Doe", age=30)
 
@@ -78,7 +87,9 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
                 },
             ]
         )
-        assert generation[0] == sample_user.model_dump_json()
+        generations = generation["generations"]
+        assert generations == sample_user.model_dump_json()
+        assert generation["statistics"] == {"input_tokens": 20, "output_tokens": 11}
 
     @pytest.mark.skipif(
         sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
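The `{"input_tokens": 20, "output_tokens": 11}` assertion follows from the `agenerate` change above: the structured-output path counts tokens with Anthropic's bundled sync tokenizer, over the chat input on one side and the serialized `DummyUserDetail` JSON on the other, so the counts are deterministic for a fixed prompt.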
38 changes: 21 additions & 17 deletions tests/unit/llms/test_cohere.py
@@ -19,6 +19,7 @@
 
 import nest_asyncio
 import pytest
+from tokenizers import Tokenizer
 
 from distilabel.llms.cohere import CohereLLM
 
@@ -50,16 +51,12 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None:
         llm = CohereLLM(model="command-r")
         llm._aclient = mock_async_client  # type: ignore
 
-        mocked_completion = mock.Mock(
-            choices=[
-                mock.Mock(
-                    message=mock.Mock(content=" Aenean hendrerit aliquam velit. ...")
-                )
-            ]
-        )
+        mocked_completion = mock.Mock(text="Aenean hendrerit aliquam velit...")
         llm._aclient.chat = mock.AsyncMock(return_value=mocked_completion)
 
-        await llm.agenerate(
+        llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
+
+        result = await llm.agenerate(
             input=[
                 {"role": "system", "content": ""},
                 {
@@ -68,6 +65,10 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None:
                 },
             ]
         )
+        assert result == {
+            "generations": ["Aenean hendrerit aliquam velit..."],
+            "statistics": {"input_tokens": 23, "output_tokens": 16},
+        }
 
     @pytest.mark.skipif(
         sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
@@ -89,6 +90,7 @@ async def test_agenerate_structured(
         sample_user = DummyUserDetail(name="John Doe", age=30)
 
         llm._aclient.chat = mock.AsyncMock(return_value=sample_user)
+        llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
 
         generation = await llm.agenerate(
             input=[
@@ -99,25 +101,23 @@ async def test_agenerate_structured(
                 },
             ]
         )
-        assert generation == [sample_user.model_dump_json()]
+        assert generation == {
+            "generations": [sample_user.model_dump_json()],
+            "statistics": {"input_tokens": 23, "output_tokens": 26},
+        }
 
     @pytest.mark.asyncio
     async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
         llm = CohereLLM(model="command-r")
         llm._aclient = mock_async_client  # type: ignore
 
-        mocked_completion = mock.Mock(
-            choices=[
-                mock.Mock(
-                    message=mock.Mock(content=" Aenean hendrerit aliquam velit. ...")
-                )
-            ]
-        )
+        mocked_completion = mock.Mock(text="Aenean hendrerit aliquam velit...")
         llm._aclient.chat = mock.AsyncMock(return_value=mocked_completion)
 
+        llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
         nest_asyncio.apply()
 
-        llm.generate(
+        result = llm.generate(
             inputs=[
                 [
                     {"role": "system", "content": ""},
@@ -128,6 +128,10 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
                 ]
             ]
         )
+        assert result == {
+            "generations": ["Aenean hendrerit aliquam velit..."],
+            "statistics": {"input_tokens": 23, "output_tokens": 16},
+        }
 
     @pytest.mark.parametrize(
         "structured_output, dump",
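These tests can be run locally with, e.g., `pytest tests/unit/llms/test_cohere.py -k agenerate`; note that `Tokenizer.from_pretrained("bert-base-uncased")` needs network access (or a local cache) the first time it runs.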
(The remaining 2 changed files are not rendered in this view.)
