Commit 56f8aa5
updated llm clients to also abide by the global caching parameter.
djl11 committed Nov 7, 2024
1 parent d5ed783 commit 56f8aa5
Showing 2 changed files with 10 additions and 14 deletions.
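
The substance of the change: the per-call `cache` argument now defaults to `None` instead of `False`, and the clients consult the global caching flag (via the newly imported `_get_caching()`) whenever the caller leaves it unset. Below is a minimal, self-contained sketch of the resulting precedence, assuming a module-level flag behind `_get_caching()`; the `get_global_caching` and `should_cache` names are illustrative stand-ins, not part of the library:

from typing import Optional

# Illustrative stand-in for the module-level flag behind unify's
# _get_caching() (imported for the first time in this commit).
_GLOBAL_CACHING = False


def get_global_caching() -> bool:
    return _GLOBAL_CACHING


def should_cache(cache: Optional[bool]) -> bool:
    # Same expression the commit adds to uni_llm.py:
    #     cache is True or _get_caching() and cache is None
    # `and` binds tighter than `or`, so it reads as:
    #     (cache is True) or (global caching on AND cache left as None)
    # i.e. explicit True always caches, explicit False never does,
    # and the default None defers to the global setting.
    return cache is True or get_global_caching() and cache is None
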
2 changes: 1 addition & 1 deletion unify/universal_api/clients/multi_llm.py
@@ -67,7 +67,7 @@ def __init__(
         stateful: bool = False,
         return_full_completion: bool = False,
         traced: bool = False,
-        cache: bool = False,
+        cache: bool = None,
         # passthrough arguments
         extra_headers: Optional[Headers] = None,
         extra_query: Optional[Query] = None,
22 changes: 9 additions & 13 deletions unify/universal_api/clients/uni_llm.py
@@ -24,7 +24,7 @@
 from openai.types.chat.completion_create_params import ResponseFormat
 from typing_extensions import Self
 from unify import BASE_URL, LOCAL_MODELS
-from ...utils._caching import _get_cache, _write_to_cache
+from ...utils._caching import _get_cache, _write_to_cache, _get_caching
 from ..clients.base import _Client
 from ..types import Prompt
 from ..utils.endpoint_metrics import Metrics
@@ -70,7 +70,7 @@ def __init__(
         stateful: bool = False,
         return_full_completion: bool = False,
         traced: bool = False,
-        cache: bool = False,
+        cache: Optional[bool] = None,
         # passthrough arguments
         extra_headers: Optional[Headers] = None,
         extra_query: Optional[Query] = None,
@@ -813,7 +813,7 @@ def _generate_non_stream(
             log_response_body=log_response_body,
         )
         chat_completion = None
-        if cache:
+        if cache is True or _get_caching() and cache is None:
             chat_completion = _get_cache(fn_name="chat.completions.create", kw=kw)
         if chat_completion is None:
             try:
@@ ... @@
                 print(f"done (thread {threading.get_ident()})")
         except openai.APIStatusError as e:
             raise Exception(e.message)
-        if cache:
+        if cache is True or _get_caching() and cache is None:
             _write_to_cache(
                 fn_name="chat.completions.create",
                 kw=kw,
@@ -1044,14 +1044,10 @@ async def _generate_non_stream(
             log_query_body=log_query_body,
             log_response_body=log_response_body,
         )
-        chat_completion = (
-            _get_cache(
-                fn_name="chat.completions.create",
-                kw=kw,
-            )
-            if cache
-            else None
-        )
+        if cache is True or _get_caching() and cache is None:
+            chat_completion = _get_cache(fn_name="chat.completions.create", kw=kw)
+        else:
+            chat_completion = None
         if chat_completion is None:
             try:
                 if endpoint in LOCAL_MODELS:
@@ ... @@
             )
         except openai.APIStatusError as e:
             raise Exception(e.message)
-        if cache:
+        if cache is True or _get_caching() and cache is None:
             _write_to_cache(
                 fn_name="chat.completions.create",
                 kw=kw,
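
Spelled out as a quick check (this assumes the `should_cache` sketch above is in scope; the truth values follow directly from the condition added in the diff):

# Global caching disabled (matches the previous default behaviour).
assert should_cache(None) is False    # default argument: nothing is cached
assert should_cache(True) is True     # explicit opt-in still caches
assert should_cache(False) is False

# Global caching enabled: clients left at cache=None now abide by it.
_GLOBAL_CACHING = True
assert should_cache(None) is True
assert should_cache(True) is True
assert should_cache(False) is False   # explicit opt-out still wins

The practical upshot is that callers who never touch `cache` now pick up whatever the global setting says, while callers who passed an explicit boolean keep their previous behaviour.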
