Skip to content

Commit

Permalink
upstage[major]: upgrade pydantic and release 0.3 (#26)
Browse files Browse the repository at this point in the history
* upstage[major]: support pydantic v2 and langchain-core v0.3

* update deps and increment version

* set protected namespaces on embeddings

* fix lint errors

* increment version

* lock

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
ccurme and baskaryan authored Sep 16, 2024
1 parent 8b9af58 commit eddddbf
Show file tree
Hide file tree
Showing 14 changed files with 403 additions and 435 deletions.
104 changes: 64 additions & 40 deletions libs/upstage/langchain_upstage/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from operator import itemgetter
from typing import (
Expand All @@ -11,6 +13,7 @@
Tuple,
Type,
Union,
cast,
overload,
)

Expand All @@ -32,10 +35,9 @@
PydanticToolsParser,
)
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import from_env, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai.chat_models.base import (
BaseChatOpenAI,
Expand All @@ -45,7 +47,9 @@
_DictOrPydanticClass,
_is_pydantic_class,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from tokenizers import Tokenizer
from typing_extensions import Self

from langchain_upstage.document_parse import UpstageDocumentParseLoader

Expand Down Expand Up @@ -103,10 +107,23 @@ def _get_ls_params(

model_name: str = Field(default="solar-1-mini-chat", alias="model")
"""Model name to use."""
upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
upstage_api_key: SecretStr = Field(
default_factory=secret_from_env(
"UPSTAGE_API_KEY",
error_message=(
"You must specify an api key. "
"You can pass it an argument as `api_key=...` or "
"set the environment variable `UPSTAGE_API_KEY`."
),
),
alias="api_key",
)
"""Automatically inferred from env are `UPSTAGE_API_KEY` if not provided."""
upstage_api_base: Optional[str] = Field(
default="https://api.upstage.ai/v1/solar", alias="base_url"
default_factory=from_env(
"UPSTAGE_API_BASE", default="https://api.upstage.ai/v1/solar"
),
alias="base_url",
)
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
Expand All @@ -121,45 +138,38 @@ def _get_ls_params(
tokenizer_name: Optional[str] = "upstage/solar-pro-preview-tokenizer"
"""huggingface tokenizer name. Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-pro-preview-tokenizer"""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

values["upstage_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "upstage_api_key", "UPSTAGE_API_KEY")
)
values["upstage_api_base"] = values["upstage_api_base"] or os.getenv(
"UPSTAGE_API_BASE"
)

client_params = {
client_params: dict = {
"api_key": (
values["upstage_api_key"].get_secret_value()
if values["upstage_api_key"]
self.upstage_api_key.get_secret_value()
if self.upstage_api_key
else None
),
"base_url": values["upstage_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"base_url": self.upstage_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}

if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
return values
return self

def _get_tokenizer(self) -> Tokenizer:
self.tokenizer_name = SOLAR_TOKENIZERS.get(self.model_name, self.tokenizer_name)
Expand Down Expand Up @@ -222,9 +232,8 @@ def _generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.client.create(messages=message_dicts, **params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = self.client.create(**payload)
return self._create_chat_result(response)

async def _agenerate(
Expand All @@ -246,9 +255,8 @@ async def _agenerate(
)
return await agenerate_from_stream(stream_iter)

message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await self.async_client.create(messages=message_dicts, **params)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = await self.async_client.create(**payload)
return self._create_chat_result(response)

def _using_doc_parsing_model(self, kwargs: Dict[str, Any]) -> bool:
Expand Down Expand Up @@ -281,6 +289,22 @@ def _parse_documents(self, file_path: str) -> str:
document_contents += f"{file_title}:\n{doc.page_content}\n\n"
return document_contents

def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
messages = self._convert_input(input_).to_messages()
if stop is not None:
kwargs["stop"] = stop
return {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
**kwargs,
}

# TODO: Fix typing.
@overload # type: ignore[override]
def with_structured_output(
Expand Down Expand Up @@ -346,7 +370,7 @@ def with_structured_output(
.. code-block:: python
from langchain_upstage import ChatUpstage
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
Expand All @@ -372,7 +396,7 @@ class AnswerWithJustification(BaseModel):
.. code-block:: python
from langchain_upstage import ChatUpstage
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
Expand Down Expand Up @@ -400,8 +424,8 @@ class AnswerWithJustification(BaseModel):
.. code-block:: python
from langchain_upstage import ChatUpstage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
Expand Down Expand Up @@ -432,7 +456,7 @@ class AnswerWithJustification(BaseModel):
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
tools=[cast(type, schema)], first_tool_only=True
)
else:
output_parser = JsonOutputKeyToolsParser(
Expand Down
95 changes: 49 additions & 46 deletions libs/upstage/langchain_upstage/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import os
import warnings
from typing import (
Any,
Expand All @@ -16,18 +17,15 @@

import openai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from pydantic import (
BaseModel,
Extra,
ConfigDict,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
model_validator,
)
from typing_extensions import Self

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,10 +58,23 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Not yet supported.
"""
upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""API Key for Solar API."""
upstage_api_base: str = Field(
default="https://api.upstage.ai/v1/solar", alias="base_url"
upstage_api_key: SecretStr = Field(
default_factory=secret_from_env(
"UPSTAGE_API_KEY",
error_message=(
"You must specify an api key. "
"You can pass it an argument as `api_key=...` or "
"set the environment variable `UPSTAGE_API_KEY`."
),
),
alias="api_key",
)
"""Automatically inferred from env are `UPSTAGE_API_KEY` if not provided."""
upstage_api_base: Optional[str] = Field(
default_factory=from_env(
"UPSTAGE_API_BASE", default="https://api.upstage.ai/v1/solar"
),
alias="base_url",
)
"""Endpoint URL to use."""
embedding_ctx_length: int = 4096
Expand Down Expand Up @@ -112,12 +123,15 @@ class UpstageEmbeddings(BaseModel, Embeddings):
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""

class Config:
extra = Extra.forbid
allow_population_by_field_name = True
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
protected_namespaces=(),
)

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
Expand All @@ -142,42 +156,31 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""

upstage_api_key = get_from_dict_or_env(
values, "upstage_api_key", "UPSTAGE_API_KEY"
)
values["upstage_api_key"] = (
convert_to_secret_str(upstage_api_key) if upstage_api_key else None
)
values["upstage_api_base"] = values["upstage_api_base"] or os.getenv(
"UPSTAGE_API_BASE"
)
client_params = {
client_params: dict = {
"api_key": (
values["upstage_api_key"].get_secret_value()
if values["upstage_api_key"]
self.upstage_api_key.get_secret_value()
if self.upstage_api_key
else None
),
"base_url": values["upstage_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"base_url": self.upstage_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).embeddings
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(**client_params, **sync_specific).embeddings
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).embeddings
return values
return self

@property
def _invocation_params(self) -> Dict[str, Any]:
Expand Down
Loading

0 comments on commit eddddbf

Please sign in to comment.