Skip to content

Commit

Permalink
aws[patch]: Add Support to pass in AWS credentials to bedrock models …
Browse files Browse the repository at this point in the history
…directly (#197)

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
langchain-infra and baskaryan authored Sep 19, 2024
1 parent cf435cf commit c419f76
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 55 deletions.
3 changes: 3 additions & 0 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,9 @@ def _as_converse(self) -> ChatBedrockConverse:
model=self.model_id,
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
config=self.config,
provider=self.provider or "",
base_url=self.endpoint_url,
Expand Down
117 changes: 87 additions & 30 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import secret_from_env
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_aws.function_calling import ToolsOutputParser
Expand Down Expand Up @@ -307,8 +308,46 @@ class Joke(BaseModel):
Profile should either have access keys or role information specified.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

aws_access_key_id: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None)
)
"""AWS access key id.
If provided, aws_secret_access_key must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If not provided, will be read from 'AWS_ACCESS_KEY_ID' environment variable.
"""

aws_secret_access_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None)
)
"""AWS secret_access_key.
If provided, aws_access_key_id must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If not provided, will be read from 'AWS_SECRET_ACCESS_KEY' environment variable.
"""

aws_session_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None)
)
"""AWS session token.
If provided, aws_access_key_id and aws_secret_access_key must
also be provided. Not required unless using temporary credentials.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If not provided, will be read from 'AWS_SESSION_TOKEN' environment variable.
"""

provider: str = ""
Expand Down Expand Up @@ -358,6 +397,14 @@ class Joke(BaseModel):
populate_by_name=True,
)

@property
def lc_secrets(self) -> Dict[str, str]:
return {
"aws_access_key_id": "AWS_ACCESS_KEY_ID",
"aws_secret_access_key": "AWS_SECRET_ACCESS_KEY",
"aws_session_token": "AWS_SESSION_TOKEN",
}

@model_validator(mode="before")
@classmethod
def set_disable_streaming(cls, values: Dict) -> Any:
Expand All @@ -376,6 +423,7 @@ def set_disable_streaming(cls, values: Dict) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""

# As of 08/05/24 only claude-3 and mistral-large models support tool choice:
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
if self.supports_tool_choice_values is None:
Expand All @@ -386,44 +434,53 @@ def validate_environment(self) -> Self:
else:
self.supports_tool_choice_values = ()

# Skip creating new client if passed in constructor
if self.client is None:
try:
if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
session = boto3.Session()
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
except Exception as e:
creds = {
"aws_access_key_id": self.aws_access_key_id,
"aws_secret_access_key": self.aws_secret_access_key,
"aws_session_token": self.aws_session_token,
}
if creds["aws_access_key_id"] and creds["aws_secret_access_key"]:
session_params = {
k: v.get_secret_value() for k, v in creds.items() if v
}
elif any(creds.values()):
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e
f"If any of aws_access_key_id, aws_secret_access_key, or "
f"aws_session_token are specified then both aws_access_key_id and "
f"aws_secret_access_key must be specified. Only received "
f"{(k for k, v in creds.items() if v)}."
)
elif self.credentials_profile_name is not None:
session_params = {"profile_name": self.credentials_profile_name}
else:
# use default credentials
session_params = {}

self.region_name = (
self.region_name
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)
try:
session = boto3.Session(**session_params)

client_params = {}
if self.region_name:
client_params["region_name"] = self.region_name
if self.endpoint_url:
client_params["endpoint_url"] = self.endpoint_url
if self.config:
client_params["config"] = self.config
self.region_name = (
self.region_name
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

try:
client_params = {
"endpoint_url": self.endpoint_url,
"config": self.config,
"region_name": self.region_name,
}
client_params = {k: v for k, v in client_params.items() if v}
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
raise ValueError(f"Error raised by bedrock service:\n\n{e}") from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
f"profile name are valid. Bedrock error:\n\n{e}"
) from e

return self
Expand Down
102 changes: 78 additions & 24 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Union,
)

import boto3
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -25,7 +26,8 @@
from langchain_core.messages import AIMessageChunk, ToolCall
from langchain_core.messages.tool import tool_call, tool_call_chunk
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from pydantic import ConfigDict, Field, model_validator
from langchain_core.utils import secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_aws.function_calling import _tools_in_params
Expand Down Expand Up @@ -463,6 +465,44 @@ class BedrockBase(BaseLanguageModel, ABC):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

aws_access_key_id: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None)
)
"""AWS access key id.
If provided, aws_secret_access_key must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If not provided, will be read from 'AWS_ACCESS_KEY_ID' environment variable.
"""

aws_secret_access_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None)
)
"""AWS secret_access_key.
If provided, aws_access_key_id must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If not provided, will be read from 'AWS_SECRET_ACCESS_KEY' environment variable.
"""

aws_session_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None)
)
"""AWS session token.
If provided, aws_access_key_id and aws_secret_access_key must also be provided.
Not required unless using temporary credentials.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If not provided, will be read from 'AWS_SESSION_TOKEN' environment variable.
"""

config: Any = None
"""An optional botocore.config.Config instance to pass to the client."""

Expand Down Expand Up @@ -550,6 +590,14 @@ async def on_llm_error(
...Logic to handle guardrail intervention...
""" # noqa: E501

@property
def lc_secrets(self) -> Dict[str, str]:
return {
"aws_access_key_id": "AWS_ACCESS_KEY_ID",
"aws_secret_access_key": "AWS_SECRET_ACCESS_KEY",
"aws_session_token": "AWS_SESSION_TOKEN",
}

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""
Expand All @@ -558,43 +606,49 @@ def validate_environment(self) -> Self:
if self.client is not None:
return self

try:
import boto3
creds = {
"aws_access_key_id": self.aws_access_key_id,
"aws_secret_access_key": self.aws_secret_access_key,
"aws_session_token": self.aws_session_token,
}
if creds["aws_access_key_id"] and creds["aws_secret_access_key"]:
session_params = {k: v.get_secret_value() for k, v in creds.items() if v}
elif any(creds.values()):
raise ValueError(
f"If any of aws_access_key_id, aws_secret_access_key, or "
f"aws_session_token are specified then both aws_access_key_id and "
f"aws_secret_access_key must be specified. Only received "
f"{(k for k, v in creds.items() if v)}."
)
elif self.credentials_profile_name is not None:
session_params = {"profile_name": self.credentials_profile_name}
else:
# use default credentials
session_params = {}

if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
# use default credentials
session = boto3.Session()
try:
session = boto3.Session(**session_params)

self.region_name = (
self.region_name
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

client_params = {}
if self.region_name:
client_params["region_name"] = self.region_name
if self.endpoint_url:
client_params["endpoint_url"] = self.endpoint_url
if self.config:
client_params["config"] = self.config

client_params = {
"endpoint_url": self.endpoint_url,
"config": self.config,
"region_name": self.region_name,
}
client_params = {k: v for k, v in client_params.items() if v}
self.client = session.client("bedrock-runtime", **client_params)

except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service: {e}")
raise ValueError(f"Error raised by bedrock service:\n\n{e}") from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
f"profile name are valid. Bedrock error:\n\n{e}"
) from e

return self
Expand Down
24 changes: 23 additions & 1 deletion libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test chat model integration."""

import base64
from typing import Dict, List, Type, cast
from typing import Dict, List, Tuple, Type, cast

import pytest
from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -47,6 +47,28 @@ def standard_chat_model_params(self) -> dict:
"stop": [],
}

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
"""Return env vars, init args, and expected instance attrs for initializing
from env vars."""
return (
{
"AWS_ACCESS_KEY_ID": "key_id",
"AWS_SECRET_ACCESS_KEY": "secret_key",
"AWS_SESSION_TOKEN": "token",
"AWS_DEFAULT_REGION": "region",
},
{
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
},
{
"aws_access_key_id": "key_id",
"aws_secret_access_key": "secret_key",
"aws_session_token": "token",
"region_name": "region",
},
)

@pytest.mark.xfail(reason="Doesn't support streaming init param.")
def test_init_streaming(self) -> None:
super().test_init_streaming()
Expand Down

0 comments on commit c419f76

Please sign in to comment.