
Openai computed tokens
plaguss committed Oct 14, 2024
1 parent a51ce59 commit fe5d4c5
Showing 4 changed files with 113 additions and 13 deletions.
6 changes: 4 additions & 2 deletions src/distilabel/llms/huggingface/transformers.py
@@ -246,11 +246,13 @@ def generate(  # type: ignore
                 "generations": output,
                 "statistics": {
                     "input_tokens": [
-                        compute_tokens(row["content"], self._pipeline.tokenizer)
+                        compute_tokens(
+                            row["content"], self._pipeline.tokenizer.encode
+                        )
                         for row in input
                     ],
                     "output_tokens": [
-                        compute_tokens(row, self._pipeline.tokenizer)
+                        compute_tokens(row, self._pipeline.tokenizer.encode)
                         for row in output
                     ],
                 },
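Note: the fix above passes the tokenizer's `encode` method instead of the tokenizer object itself, matching the `Callable[[str], List[int]]` contract of `compute_tokens`. A minimal sketch of the resulting call, using an arbitrary `transformers` tokenizer purely for illustration:

from transformers import AutoTokenizer

from distilabel.llms.statistics import compute_tokens

# Any Callable[[str], List[int]] works; a tokenizer's `encode` method qualifies.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
n_tokens = compute_tokens("Aenean hendrerit aliquam velit.", tokenizer.encode)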
35 changes: 32 additions & 3 deletions src/distilabel/llms/openai.py
@@ -22,6 +22,7 @@
 from distilabel import envs
 from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.statistics import compute_tokens
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
@@ -32,6 +33,7 @@
     from openai.types import FileObject as OpenAIFileObject
     from openai.types.chat import ChatCompletion as OpenAIChatCompletion
     from pydantic import BaseModel
+    from tiktoken.core import Encoding


 _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
@@ -168,6 +170,7 @@ class User(BaseModel):
     _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME)
     _client: "OpenAI" = PrivateAttr(None)
     _aclient: "AsyncOpenAI" = PrivateAttr(None)
+    _tokenizer: "Encoding" = PrivateAttr(None)

     def load(self) -> None:
         """Loads the `AsyncOpenAI` client to benefit from async requests."""
@@ -210,6 +213,10 @@ def load(self) -> None:
         self._aclient = result.get("client")  # type: ignore
         if structured_output := result.get("structured_output"):
             self.structured_output = structured_output
+        # `tiktoken` must be version 0.8.0 at least.
+        import tiktoken
+
+        self._tokenizer = tiktoken.encoding_for_model(self.model)

     def unload(self) -> None:
         """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled
@@ -307,9 +314,20 @@ async def agenerate(  # type: ignore
             kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

         completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
+
         if structured_output:
-            return self._generations_from_structured_output(completion)
+            # Note: Instructor extracts the content from the structured output, so we need to
+            # add the token count
+            generation = self._generations_from_structured_output(completion)
+
+            return {
+                "generations": generation,
+                "statistics": {
+                    "input_tokens": compute_tokens(input, self._tokenizer.encode),
+                    "output_tokens": compute_tokens(
+                        orjson.dumps(generation).decode("utf-8"), self._tokenizer.encode
+                    ),
+                },
+            }

         return self._generations_from_openai_completion(completion)
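Because Instructor hands back the parsed generations rather than the raw completion, `completion.usage` is not consulted on this path; instead the input and the `orjson`-serialized generations are tokenized locally with the `tiktoken` encoder set up in `load()`. A hedged sketch of that accounting in isolation (model name and sample generation are illustrative):

import orjson
import tiktoken

tokenizer = tiktoken.encoding_for_model("gpt-4o")  # illustrative model name
generation = ['{"name": "John Doe", "age": 30}']   # e.g. a `model_dump_json()` string

# Serialize the parsed generations back to a string, then count its tokens.
output_tokens = len(tokenizer.encode(orjson.dumps(generation).decode("utf-8")))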

@@ -346,7 +364,18 @@ def _generations_from_openai_completion(
                     f" Finish reason was: {choice.finish_reason}"
                 )
             generations.append(content)
-        return generations
+
+        return {
+            "generations": generations,
+            "statistics": {
+                "input_tokens": completion.usage.prompt_tokens
+                if completion.usage
+                else 0,
+                "output_tokens": completion.usage.completion_tokens
+                if completion.usage
+                else 0,
+            },
+        }

     def offline_batch_generate(
         self,
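On the regular path the counts come straight from the API: the OpenAI SDK exposes them on `completion.usage`, which may be `None`, hence the `else 0` fallbacks above. A minimal sketch against the real client (model name is illustrative; assumes `OPENAI_API_KEY` is set):

from openai import OpenAI

client = OpenAI()
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi."}],
)
# Mirrors the fallback used in `_generations_from_openai_completion`.
statistics = {
    "input_tokens": completion.usage.prompt_tokens if completion.usage else 0,
    "output_tokens": completion.usage.completion_tokens if completion.usage else 0,
}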
25 changes: 22 additions & 3 deletions src/distilabel/llms/statistics.py
@@ -12,8 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, List
+from typing import Callable, List, Union

+from distilabel.steps.tasks.typing import ChatType

-def compute_tokens(text: str, tokenizer: Callable[[str], List[int]]) -> int:
-    return len(tokenizer.encode(text)) if text else 0
+
+def compute_tokens(
+    text_or_messages: Union[str, ChatType], tokenizer: Callable[[str], List[int]]
+) -> int:
+    """Helper function to count the number of tokens in a text or list of messages.
+
+    Args:
+        text_or_messages: Either a string response or a list of messages.
+        tokenizer: A callable that takes a string and returns its tokenized version.
+
+    Returns:
+        The number of tokens in the text or messages.
+    """
+    if isinstance(text_or_messages, str):
+        text = text_or_messages
+    else:
+        # If it's a list of messages, concatenate the content of each message
+        text = " ".join([message["content"] for message in text_or_messages])
+
+    return len(tokenizer(text)) if text else 0
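Usage of the widened helper, with a `transformers` tokenizer as an illustrative stand-in for any `Callable[[str], List[int]]`: a plain string is tokenized directly, while a list of messages has its `content` fields joined with spaces first.

from transformers import AutoTokenizer

from distilabel.llms.statistics import compute_tokens

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary model choice
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a haiku."},
]
compute_tokens(messages, tokenizer.encode)          # contents joined, then tokenized
compute_tokens("A plain string", tokenizer.encode)  # tokenized as-is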
60 changes: 55 additions & 5 deletions tests/unit/llms/test_openai.py
@@ -65,11 +65,14 @@ async def test_agenerate(
         llm._aclient = async_openai_mock

         mocked_completion = Mock(
-            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
+            choices=[
+                Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))
+            ],
+            usage=Mock(prompt_tokens=100, completion_tokens=100),
         )
         llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)

-        await llm.agenerate(
+        result = await llm.agenerate(
             input=[
                 {"role": "system", "content": ""},
                 {
@@ -78,6 +81,10 @@
                 },
             ]
         )
+        assert result == {
+            "generations": [" Aenean hendrerit aliquam velit. ..."],
+            "statistics": {"input_tokens": 100, "output_tokens": 100},
+        }

     @pytest.mark.asyncio
     async def test_agenerate_structured(
@@ -93,6 +100,9 @@ async def test_agenerate_structured(
             },
         )  # type: ignore
         llm._aclient = async_openai_mock
+        import tiktoken
+
+        llm._tokenizer = tiktoken.encoding_for_model(self.model_id)

         sample_user = DummyUserDetail(name="John Doe", age=30)

@@ -107,7 +117,10 @@
                 },
             ]
         )
-        assert generation[0] == sample_user.model_dump_json()
+        assert isinstance(generation, dict)
+        generations = generation["generations"]
+        assert generations[0] == sample_user.model_dump_json()
+        assert generation["statistics"] == {"input_tokens": 10, "output_tokens": 12}

     @pytest.mark.skipif(
         sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
@@ -206,6 +219,11 @@ def test_check_and_get_batch_results(
                                     },
                                 }
                             ],
+                            "usage": {
+                                "prompt_tokens": 100,
+                                "completion_tokens": 100,
+                                "total_tokens": 200,
+                            },
                         },
                     },
                 },
@@ -228,6 +246,11 @@
                                     },
                                 }
                             ],
+                            "usage": {
+                                "prompt_tokens": 100,
+                                "completion_tokens": 100,
+                                "total_tokens": 200,
+                            },
                         },
                     },
                 },
@@ -236,7 +259,23 @@
         llm.load()

         outputs = llm._check_and_get_batch_results()
-        assert outputs == [["output 1"], ["output 2"]]
+
+        assert outputs == [
+            {
+                "generations": ["output 1"],
+                "statistics": {
+                    "input_tokens": 100,
+                    "output_tokens": 100,
+                },
+            },
+            {
+                "generations": ["output 2"],
+                "statistics": {
+                    "input_tokens": 100,
+                    "output_tokens": 100,
+                },
+            },
+        ]

     def test_check_and_get_batch_results_raises_valueerror(
         self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
@@ -322,12 +361,23 @@ def test_parse_output(
                                 },
                             }
                         ],
+                        "usage": {
+                            "prompt_tokens": 100,
+                            "completion_tokens": 100,
+                            "total_tokens": 200,
+                        },
                     },
                 }
             }
         )

-        assert result == [" Aenean hendrerit aliquam velit. ..."]
+        assert result == {
+            "generations": [" Aenean hendrerit aliquam velit. ..."],
+            "statistics": {
+                "input_tokens": 100,
+                "output_tokens": 100,
+            },
+        }

     def test_retrieve_batch_results(
         self, _async_openai_mock: MagicMock, openai_mock: MagicMock
