Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cerebras: pydantic compat, release 0.2.0 #3

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 35 additions & 52 deletions libs/cerebras/langchain_cerebras/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Wrapper around Cerebras' Chat Completions API."""

import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast

import openai
Expand All @@ -11,16 +10,17 @@
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
from_env,
secret_from_env,
)

# We ignore the "unused imports" here since we want to reexport these from this package.
from langchain_openai.chat_models.base import (
BaseChatOpenAI,
)
from pydantic import Field, SecretStr, model_validator
from typing_extensions import Self

CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1/"

Expand Down Expand Up @@ -314,88 +314,71 @@ def _get_ls_params(
default_factory=secret_from_env("CEREBRAS_API_KEY", default=None),
)
"""Automatically inferred from env are `CEREBRAS_API_KEY` if not provided."""
cerebras_api_base: Optional[str] = Field(
default=CEREBRAS_BASE_URL, alias="base_url"
cerebras_api_base: str = Field(
default_factory=from_env("CEREBRAS_API_BASE", default=CEREBRAS_BASE_URL),
alias="base_url",
)

cerebras_proxy: Optional[str] = None
cerebras_proxy: str = Field(default_factory=from_env("CEREBRAS_PROXY", default=""))

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

values["cerebras_api_base"] = os.getenv(
"CEREBRAS_API_BASE", values["cerebras_api_base"]
)

values["cerebras_proxy"] = get_from_dict_or_env(
values, "cerebras_proxy", "CEREBRAS_PROXY", default=""
)

client_params = {
"api_key": (
values["cerebras_api_key"].get_secret_value()
if values["cerebras_api_key"]
self.cerebras_api_key.get_secret_value()
if self.cerebras_api_key
else None
),
# Ensure we always fallback to the Cerebras API url.
"base_url": (
values["cerebras_api_base"]
if values["cerebras_api_base"]
else CEREBRAS_BASE_URL
),
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"base_url": self.cerebras_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}

if values["cerebras_proxy"] and (
values["http_client"] or values["http_async_client"]
):
cerebras_proxy = values["cerebras_proxy"]
http_client = values["http_client"]
http_async_client = values["http_async_client"]
if self.cerebras_proxy and (self.http_client or self.http_async_client):
raise ValueError(
"Cannot specify 'cerebras_proxy' if one of "
"'http_client'/'http_async_client' is already specified. Received:\n"
f"{cerebras_proxy=}\n{http_client=}\n{http_async_client=}"
f"{self.cerebras_proxy=}\n{self.http_client=}\n{self.http_async_client=}"
)
if not values.get("client"):
if values["cerebras_proxy"] and not values["http_client"]:
if not self.client:
if self.cerebras_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
values["http_client"] = httpx.Client(proxy=values["cerebras_proxy"])
sync_specific = {"http_client": values["http_client"]}
values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
values["client"] = values["root_client"].chat.completions
if not values.get("async_client"):
if values["cerebras_proxy"] and not values["http_async_client"]:
self.http_client = httpx.Client(proxy=self.cerebras_proxy)
sync_specific = {"http_client": self.http_client}
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore
self.client = self.root_client.chat.completions
if not self.async_client:
if self.cerebras_proxy and not self.http_async_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
values["http_async_client"] = httpx.AsyncClient(
proxy=values["cerebras_proxy"]
)
async_specific = {"http_client": values["http_async_client"]}
values["root_async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
self.http_async_client = httpx.AsyncClient(proxy=self.cerebras_proxy)
async_specific = {"http_client": self.http_async_client}
self.root_async_client = openai.AsyncOpenAI(
**client_params, # type: ignore
**async_specific, # type: ignore
)
values["async_client"] = values["root_async_client"].chat.completions
return values
self.async_client = self.root_async_client.chat.completions
return self

# Patch tool calling w/ streaming.
def _stream(
Expand Down
Loading
Loading