diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py
index 4d81991e8..d2c7b83ed 100644
--- a/src/distilabel/llms/huggingface/transformers.py
+++ b/src/distilabel/llms/huggingface/transformers.py
@@ -246,11 +246,13 @@ def generate(  # type: ignore
                 "generations": output,
                 "statistics": {
                     "input_tokens": [
-                        compute_tokens(row["content"], self._pipeline.tokenizer)
+                        compute_tokens(
+                            row["content"], self._pipeline.tokenizer.encode
+                        )
                         for row in input
                     ],
                     "output_tokens": [
-                        compute_tokens(row, self._pipeline.tokenizer)
+                        compute_tokens(row, self._pipeline.tokenizer.encode)
                         for row in output
                     ],
                 },
diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py
index 48cac8a50..b696763e3 100644
--- a/src/distilabel/llms/openai.py
+++ b/src/distilabel/llms/openai.py
@@ -22,6 +22,7 @@
 from distilabel import envs
 from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.statistics import compute_tokens
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
@@ -32,6 +33,7 @@
     from openai.types import FileObject as OpenAIFileObject
     from openai.types.chat import ChatCompletion as OpenAIChatCompletion
     from pydantic import BaseModel
+    from tiktoken.core import Encoding
 
 
 _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
@@ -168,6 +170,7 @@ class User(BaseModel):
     _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME)
     _client: "OpenAI" = PrivateAttr(None)
     _aclient: "AsyncOpenAI" = PrivateAttr(None)
+    _tokenizer: "Encoding" = PrivateAttr(None)
 
     def load(self) -> None:
         """Loads the `AsyncOpenAI` client to benefit from async requests."""
@@ -210,6 +213,10 @@ def load(self) -> None:
             self._aclient = result.get("client")  # type: ignore
             if structured_output := result.get("structured_output"):
                 self.structured_output = structured_output
+        # It must be version 0.8.0 at least.
+        import tiktoken
+
+        self._tokenizer = tiktoken.encoding_for_model(self.model)
 
     def unload(self) -> None:
         """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled
@@ -307,9 +314,20 @@ async def agenerate(  # type: ignore
             kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore
 
         completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
 
-        if structured_output:
-            return self._generations_from_structured_output(completion)
+        # Note: Instructor extracts the content from the structured output, so we need to
+        # add the token count
+        generation = self._generations_from_structured_output(completion)
+
+        return {
+            "generations": generation,
+            "statistics": {
+                "input_tokens": compute_tokens(input, self._tokenizer.encode),
+                "output_tokens": compute_tokens(
+                    orjson.dumps(generation).decode("utf-8"), self._tokenizer.encode
+                ),
+            },
+        }
 
         return self._generations_from_openai_completion(completion)
 
@@ -346,7 +364,18 @@ def _generations_from_openai_completion(
                     f" Finish reason was: {choice.finish_reason}"
                 )
             generations.append(content)
-        return generations
+
+        return {
+            "generations": generations,
+            "statistics": {
+                "input_tokens": completion.usage.prompt_tokens
+                if completion.usage
+                else 0,
+                "output_tokens": completion.usage.completion_tokens
+                if completion.usage
+                else 0,
+            },
+        }
 
     def offline_batch_generate(
         self,
diff --git a/src/distilabel/llms/statistics.py b/src/distilabel/llms/statistics.py
index 4dd8714ed..8af0094b1 100644
--- a/src/distilabel/llms/statistics.py
+++ b/src/distilabel/llms/statistics.py
@@ -12,8 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, List
+from typing import Callable, List, Union
 
+from distilabel.steps.tasks.typing import ChatType
 
-def compute_tokens(text: str, tokenizer: Callable[[str], List[int]]) -> int:
-    return len(tokenizer.encode(text)) if text else 0
+
+def compute_tokens(
+    text_or_messages: Union[str, ChatType], tokenizer: Callable[[str], List[int]]
+) -> int:
+    """Helper function to count the number of tokens in a text or list of messages.
+
+    Args:
+        text_or_messages: Either a string response or a list of messages.
+        tokenizer: A callable that takes a str and returns the tokenized version of the text.
+
+    Returns:
+        int: The number of tokens in the text, or 0 if the text is empty.
+    """
+    if isinstance(text_or_messages, str):
+        text = text_or_messages
+    else:
+        # If it's a list of messages, concatenate the content of each message
+        text = " ".join([message["content"] for message in text_or_messages])
+
+    return len(tokenizer(text)) if text else 0
diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py
index 03fb94c1d..c0c7fb427 100644
--- a/tests/unit/llms/test_openai.py
+++ b/tests/unit/llms/test_openai.py
@@ -65,11 +65,14 @@ async def test_agenerate(
         llm._aclient = async_openai_mock
 
         mocked_completion = Mock(
-            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
+            choices=[
+                Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))
+            ],
+            usage=Mock(prompt_tokens=100, completion_tokens=100),
         )
         llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
 
-        await llm.agenerate(
+        result = await llm.agenerate(
             input=[
                 {"role": "system", "content": ""},
                 {
@@ -78,6 +81,10 @@ async def test_agenerate(
                 },
             ]
         )
+        assert result == {
..."], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } @pytest.mark.asyncio async def test_agenerate_structured( @@ -93,6 +100,9 @@ async def test_agenerate_structured( }, ) # type: ignore llm._aclient = async_openai_mock + import tiktoken + + llm._tokenizer = tiktoken.encoding_for_model(self.model_id) sample_user = DummyUserDetail(name="John Doe", age=30) @@ -107,7 +117,10 @@ async def test_agenerate_structured( }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert isinstance(generation, dict) + generations = generation["generations"] + assert generations[0] == sample_user.model_dump_json() + assert generation["statistics"] == {"input_tokens": 10, "output_tokens": 12} @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" @@ -206,6 +219,11 @@ def test_check_and_get_batch_results( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, }, }, @@ -228,6 +246,11 @@ def test_check_and_get_batch_results( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, }, }, @@ -236,7 +259,23 @@ def test_check_and_get_batch_results( llm.load() outputs = llm._check_and_get_batch_results() - assert outputs == [["output 1"], ["output 2"]] + + assert outputs == [ + { + "generations": ["output 1"], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + }, + { + "generations": ["output 2"], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + }, + ] def test_check_and_get_batch_results_raises_valueerror( self, _async_openai_mock: MagicMock, _openai_mock: MagicMock @@ -322,12 +361,23 @@ def test_parse_output( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, } } ) - assert result == [" Aenean hendrerit aliquam velit. ..."] + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + } def test_retrieve_batch_results( self, _async_openai_mock: MagicMock, openai_mock: MagicMock