Make ClientvLLM.model_name a cached_property (#862)

* Update `ClientvLLM.model_name` to `cached_property`

* Fix unit test

gabrielmbmb authored Aug 8, 2024 · 1 parent 1a39e01 · commit 2ded30f
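
Context for the change: `model_name` is resolved by calling `self._client.models.list()`, a network request to the vLLM server's OpenAI-compatible `/v1/models` endpoint. As a plain `property` that request was repeated on every attribute access; as a `functools.cached_property` it runs once per instance and the result is reused.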
Showing 2 changed files with 5 additions and 5 deletions.

src/distilabel/llms/vllm.py (3 additions, 2 deletions):

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -497,8 +498,8 @@ def load(self) -> None:
             self.tokenizer, revision=self.tokenizer_revision
         )
 
-    @property
-    def model_name(self) -> str:
+    @cached_property
+    def model_name(self) -> str:  # type: ignore
         """Returns the name of the model served with vLLM server."""
         models = self._client.models.list()
         return models.data[0].id
```
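
For reference, `functools.cached_property` computes the value on first access, stores it in the instance's `__dict__`, and serves the stored value on all later reads; the new `# type: ignore` suggests the override clashes with how the type checker sees the parent class's `property` declaration. A minimal sketch of the caching behavior (the `FakeClient` class and its counter are illustrative, not distilabel code):

```python
from functools import cached_property


class FakeClient:
    """Stand-in for a client whose model name costs a network round-trip."""

    def __init__(self) -> None:
        self.list_calls = 0

    @cached_property
    def model_name(self) -> str:
        # Imagine this querying the server's /v1/models endpoint.
        self.list_calls += 1
        return "llama"


client = FakeClient()
assert client.model_name == "llama"  # first read executes the method body
assert client.model_name == "llama"  # later reads come from __dict__
assert client.list_calls == 1  # the "endpoint" was queried exactly once
```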
tests/unit/llms/test_vllm.py (2 additions, 3 deletions):

```diff
@@ -179,15 +179,14 @@ def test_prepare_batches_and_sort_back(
 @mock.patch("openai.AsyncOpenAI")
 class TestClientvLLM:
     def test_clientvllm_model_name(
-        self, _openai_mock: mock.MagicMock, _async_openai_mock: mock.MagicMock
+        self, _: mock.MagicMock, openai_mock: mock.MagicMock
     ) -> None:
         llm = ClientvLLM(
             base_url="http://localhost:8000/v1",
             tokenizer="google-bert/bert-base-uncased",
         )
 
-        llm.load()
-
+        llm._client = mock.MagicMock()
         llm._client.models.list.return_value = SyncPage[Model](  # type: ignore
             data=[Model(id="llama", created=1234, object="model", owned_by="")],
             object="model",
```
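
On the changed test signature: when `mock.patch` decorators are stacked (the two mock parameters imply a second class-level patch besides the visible `openai.AsyncOpenAI` one), the mocks are injected bottom-up, so the parameter right after `self` belongs to the decorator closest to the definition. A stdlib-only sketch of that ordering, with patch targets chosen purely for illustration:

```python
from unittest import mock


@mock.patch("os.getcwd")  # outer decorator: its mock is injected second
@mock.patch("os.getpid")  # inner decorator: its mock is injected first
def check_order(getpid_mock: mock.MagicMock, getcwd_mock: mock.MagicMock) -> None:
    import os

    getpid_mock.return_value = 42
    getcwd_mock.return_value = "/fake"
    assert os.getpid() == 42
    assert os.getcwd() == "/fake"


check_order()  # the decorators supply both mocks; the caller passes nothing
```

Naming the unused slot `_`, as the updated test does, keeps that order explicit without implying the mock matters to the test body.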
