diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 86b683be..c3dc1fc8 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -19,9 +19,9 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM
+from langchain_core.language_models import LLM, BaseLanguageModel
 from langchain_core.outputs import GenerationChunk
-from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env

 from langchain_aws.utils import (
@@ -292,7 +292,7 @@ async def aprepare_output_stream(
         )


-class BedrockBase(BaseModel, ABC):
+class BedrockBase(BaseLanguageModel, ABC):
     """Base class for Bedrock models."""

     client: Any = Field(exclude=True)  #: :meta private:
@@ -325,7 +325,7 @@ class BedrockBase(BaseModel, ABC):
     equivalent to the modelId property in the list-foundation-models api. For custom
     and provisioned models, an ARN value is expected."""

-    model_kwargs: Optional[Dict] = None
+    model_kwargs: Optional[Dict[str, Any]] = None
     """Keyword arguments to pass to the model."""

     endpoint_url: Optional[str] = None
@@ -440,11 +440,14 @@ def validate_environment(cls, values: Dict) -> Dict:
         return values

     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
+    def _identifying_params(self) -> Dict[str, Any]:
         _model_kwargs = self.model_kwargs or {}
         return {
-            **{"model_kwargs": _model_kwargs},
+            "model_id": self.model_id,
+            "provider": self._get_provider(),
+            "stream": self.streaming,
+            "guardrails": self.guardrails,
+            **_model_kwargs,
         }

     def _get_provider(self) -> str:
@@ -617,7 +620,8 @@ def _prepare_input_and_invoke_stream(

             # stop sequence from _generate() overrides
             # stop sequences in the class attribute
-            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+            if k := self.provider_stop_sequence_key_name_map.get(provider):
+                _model_kwargs[k] = stop

         if provider == "cohere":
             _model_kwargs["stream"] = True
@@ -679,7 +683,8 @@ async def _aprepare_input_and_invoke_stream(
                 raise ValueError(
                     f"Stop sequence key name for {provider} is not supported."
                 )
-            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+            if k := self.provider_stop_sequence_key_name_map.get(provider):
+                _model_kwargs[k] = stop

         if provider == "cohere":
             _model_kwargs["stream"] = True