Skip to content

Commit

Permalink
langgraph: use create_model_v2 and fallback on create_model (#1708)
Browse files Browse the repository at this point in the history
* langgraph: use create_model_v2 and fallback on create_model

* fix

* code review

* fix
  • Loading branch information
vbarda authored Sep 16, 2024
1 parent 289e72b commit ab8aeea
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
6 changes: 3 additions & 3 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.base import RunnableLike
from langchain_core.runnables.utils import create_model
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1

Expand Down Expand Up @@ -47,6 +46,7 @@
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import coerce_to_runnable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -784,12 +784,12 @@ def _get_schema(
if len(keys) == 1 and keys[0] == "__root__":
return create_model( # type: ignore[call-overload]
name,
__root__=(channels[keys[0]].UpdateType, None),
root=(channels[keys[0]].UpdateType, None),
)
else:
return create_model( # type: ignore[call-overload]
name,
**{
field_definitions={
k: (
channels[k].UpdateType,
(
Expand Down
8 changes: 5 additions & 3 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
create_model,
get_function_nonlocals,
get_unique_config_specs,
)
Expand Down Expand Up @@ -91,6 +90,7 @@
patch_config,
patch_configurable,
)
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable

WriteValue = Union[Callable[[Input], Output], Any]
Expand Down Expand Up @@ -309,7 +309,7 @@ def get_input_schema(
else:
return create_model( # type: ignore[call-overload]
self.get_name("Input"),
**{
field_definitions={
k: (self.channels[k].UpdateType, None)
for k in self.input_channels or self.channels.keys()
},
Expand All @@ -329,7 +329,9 @@ def get_output_schema(
else:
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (self.channels[k].ValueType, None) for k in self.output_channels},
field_definitions={
k: (self.channels[k].ValueType, None) for k in self.output_channels
},
)

@property
Expand Down
37 changes: 37 additions & 0 deletions libs/langgraph/langgraph/utils/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1


def create_model(
model_name: str,
*,
field_definitions: Optional[Dict[str, Any]] = None,
root: Optional[Any] = None,
) -> Union[BaseModel, BaseModelV1]:
"""Create a pydantic model with the given field definitions.
Args:
model_name: The name of the model.
field_definitions: The field definitions for the model.
root: Type for a root model (RootModel)
"""
try:
# for langchain-core >= 0.3.0
from langchain_core.runnables.pydantic import create_model_v2

return create_model_v2(
model_name,
field_definitions=field_definitions,
root=root,
)
except ImportError:
# for langchain-core < 0.3.0
from langchain_core.runnables.utils import create_model

v1_kwargs = {}
if root is not None:
v1_kwargs["__root__"] = root

return create_model(model_name, **v1_kwargs, **(field_definitions or {}))

0 comments on commit ab8aeea

Please sign in to comment.