Merge pull request #16 from langchain-ai/erick/fix-identifying-params
fix identifying params
efriis authored Apr 16, 2024
2 parents 688a3a5 + 8d2d6da commit 6042d9d
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -19,9 +19,9 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM
+from langchain_core.language_models import LLM, BaseLanguageModel
 from langchain_core.outputs import GenerationChunk
-from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
 from langchain_aws.utils import (
@@ -292,7 +292,7 @@ async def aprepare_output_stream(
         )
 
 
-class BedrockBase(BaseModel, ABC):
+class BedrockBase(BaseLanguageModel, ABC):
     """Base class for Bedrock models."""
 
     client: Any = Field(exclude=True) #: :meta private:
@@ -325,7 +325,7 @@ class BedrockBase(BaseModel, ABC):
     equivalent to the modelId property in the list-foundation-models api. For custom and
     provisioned models, an ARN value is expected."""
 
-    model_kwargs: Optional[Dict] = None
+    model_kwargs: Optional[Dict[str, Any]] = None
     """Keyword arguments to pass to the model."""
 
     endpoint_url: Optional[str] = None
@@ -440,11 +440,14 @@ def validate_environment(cls, values: Dict) -> Dict:
         return values
 
     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
+    def _identifying_params(self) -> Dict[str, Any]:
         _model_kwargs = self.model_kwargs or {}
         return {
-            **{"model_kwargs": _model_kwargs},
+            "model_id": self.model_id,
+            "provider": self._get_provider(),
+            "stream": self.streaming,
+            "guardrails": self.guardrails,
+            **_model_kwargs,
         }
 
     def _get_provider(self) -> str:
@@ -617,7 +620,8 @@ def _prepare_input_and_invoke_stream(
 
             # stop sequence from _generate() overrides
             # stop sequences in the class attribute
-            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+            if k := self.provider_stop_sequence_key_name_map.get(provider):
+                _model_kwargs[k] = stop
 
         if provider == "cohere":
             _model_kwargs["stream"] = True
@@ -679,7 +683,8 @@ async def _aprepare_input_and_invoke_stream(
                 raise ValueError(
                     f"Stop sequence key name for {provider} is not supported."
                 )
-            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+            if k := self.provider_stop_sequence_key_name_map.get(provider):
+                _model_kwargs[k] = stop
 
         if provider == "cohere":
             _model_kwargs["stream"] = True
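For reference, here is a minimal standalone sketch of what the reworked _identifying_params now yields; the helper function name and the example model values are illustrative only and are not part of the change.

# Minimal standalone sketch (not the library code): mirrors what the reworked
# _identifying_params property returns, with model_kwargs flattened into the
# top-level mapping instead of nested under a "model_kwargs" key.
from typing import Any, Dict, Optional


def identifying_params(
    model_id: str,
    provider: str,
    streaming: bool,
    guardrails: Optional[Dict[str, Any]],
    model_kwargs: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    _model_kwargs = model_kwargs or {}
    return {
        "model_id": model_id,
        "provider": provider,
        "stream": streaming,
        "guardrails": guardrails,
        **_model_kwargs,
    }


# Hypothetical example values:
print(identifying_params(
    "anthropic.claude-v2", "anthropic", False, None, {"temperature": 0.1}
))
# {'model_id': 'anthropic.claude-v2', 'provider': 'anthropic',
#  'stream': False, 'guardrails': None, 'temperature': 0.1}

With the flattened form, two instances that differ in model_id, provider, streaming, guardrails, or any model kwarg no longer report identical identifying parameters, which is presumably what downstream caching and serialization rely on.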

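The other functional change is in the two streaming helpers: the provider-specific stop-sequence key is now written only when the map actually defines one. A minimal sketch of that pattern is below; the map contents are assumed examples, not the real class attribute.

# Sketch of the guarded stop-sequence assignment used in both stream helpers.
from typing import Any, Dict, List

provider_stop_sequence_key_name_map: Dict[str, str] = {
    "anthropic": "stop_sequences",  # assumed entry, for illustration only
    "cohere": "stop_sequences",     # assumed entry, for illustration only
}


def apply_stop(model_kwargs: Dict[str, Any], provider: str, stop: List[str]) -> None:
    # Before the change the .get() result was used directly as a key, so a
    # provider with no mapped name could insert a None key; the walrus guard
    # skips the assignment in that case.
    if k := provider_stop_sequence_key_name_map.get(provider):
        model_kwargs[k] = stop


kwargs: Dict[str, Any] = {}
apply_stop(kwargs, "anthropic", ["Human:"])    # sets kwargs["stop_sequences"]
apply_stop(kwargs, "unknown-provider", ["X"])  # skipped, no None key added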