From eece50afd1c5f6ace69d3ccc3ebeec63a7b57bec Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 3 Apr 2024 20:01:18 -0700 Subject: [PATCH 1/4] Added Bedrock LLM. --- libs/aws/langchain_aws/__init__.py | 4 +- libs/aws/langchain_aws/llms/__init__.py | 3 +- libs/aws/langchain_aws/llms/bedrock.py | 919 ++++++++++++++++++ libs/aws/langchain_aws/utils.py | 33 + libs/aws/tests/unit_tests/llms/__init__.py | 0 .../aws/tests/unit_tests/llms/test_bedrock.py | 308 ++++++ 6 files changed, 1265 insertions(+), 2 deletions(-) create mode 100644 libs/aws/langchain_aws/llms/bedrock.py create mode 100644 libs/aws/langchain_aws/utils.py create mode 100644 libs/aws/tests/unit_tests/llms/__init__.py create mode 100644 libs/aws/tests/unit_tests/llms/test_bedrock.py diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index 00e29d09..610bbc0e 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,10 +1,12 @@ -from langchain_aws.llms import SagemakerEndpoint +from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint from langchain_aws.retrievers import ( AmazonKendraRetriever, AmazonKnowledgeBasesRetriever, ) __all__ = [ + "Bedrock", + "BedrockLLM", "SagemakerEndpoint", "AmazonKendraRetriever", "AmazonKnowledgeBasesRetriever", diff --git a/libs/aws/langchain_aws/llms/__init__.py b/libs/aws/langchain_aws/llms/__init__.py index 1c1157b6..4f7facff 100644 --- a/libs/aws/langchain_aws/llms/__init__.py +++ b/libs/aws/langchain_aws/llms/__init__.py @@ -1,3 +1,4 @@ +from langchain_aws.llms.bedrock import ALTERNATION_ERROR, Bedrock, BedrockLLM from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint -__all__ = ["SagemakerEndpoint"] +__all__ = ["ALTERNATION_ERROR", "Bedrock", "BedrockLLM", "SagemakerEndpoint"] diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py new file mode 100644 index 00000000..86b683be --- /dev/null +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -0,0 +1,919 @@ +import asyncio +import json +import warnings +from abc import ABC +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, +) + +from langchain_core._api.deprecation import deprecated +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain_core.utils import get_from_dict_or_env + +from langchain_aws.utils import ( + enforce_stop_tokens, + get_num_tokens_anthropic, + get_token_ids_anthropic, +) + +AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace" +GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment" +HUMAN_PROMPT = "\n\nHuman:" +ASSISTANT_PROMPT = "\n\nAssistant:" +ALTERNATION_ERROR = ( + "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'." 
+) + + +def _add_newlines_before_ha(input_text: str) -> str: + new_text = input_text + for word in ["Human:", "Assistant:"]: + new_text = new_text.replace(word, "\n\n" + word) + for i in range(2): + new_text = new_text.replace("\n\n\n" + word, "\n\n" + word) + return new_text + + +def _human_assistant_format(input_text: str) -> str: + if input_text.count("Human:") == 0 or ( + input_text.find("Human:") > input_text.find("Assistant:") + and "Assistant:" in input_text + ): + input_text = HUMAN_PROMPT + " " + input_text # SILENT CORRECTION + if input_text.count("Assistant:") == 0: + input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION + if input_text[: len("Human:")] == "Human:": + input_text = "\n\n" + input_text + input_text = _add_newlines_before_ha(input_text) + count = 0 + # track alternation + for i in range(len(input_text)): + if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT: + if count % 2 == 0: + count += 1 + else: + warnings.warn(ALTERNATION_ERROR + f" Received {input_text}") + if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT: + if count % 2 == 1: + count += 1 + else: + warnings.warn(ALTERNATION_ERROR + f" Received {input_text}") + + if count % 2 == 1: # Only saw Human, no Assistant + input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION + + return input_text + + +def _stream_response_to_generation_chunk( + stream_response: Dict[str, Any], +) -> GenerationChunk: + """Convert a stream response to a generation chunk.""" + if not stream_response["delta"]: + return GenerationChunk(text="") + return GenerationChunk( + text=stream_response["delta"]["text"], + generation_info=dict( + finish_reason=stream_response.get("stop_reason", None), + ), + ) + + +class LLMInputOutputAdapter: + """Adapter class to prepare the inputs from Langchain to a format + that LLM model expects. 
+ + It also provides helper function to extract + the generated text from the model response.""" + + provider_to_output_key_map = { + "anthropic": "completion", + "amazon": "outputText", + "cohere": "text", + "meta": "generation", + "mistral": "outputs", + } + + @classmethod + def prepare_input( + cls, + provider: str, + model_kwargs: Dict[str, Any], + prompt: Optional[str] = None, + system: Optional[str] = None, + messages: Optional[List[Dict]] = None, + ) -> Dict[str, Any]: + input_body = {**model_kwargs} + if provider == "anthropic": + if messages: + input_body["anthropic_version"] = "bedrock-2023-05-31" + input_body["messages"] = messages + if system: + input_body["system"] = system + if "max_tokens" not in input_body: + input_body["max_tokens"] = 1024 + if prompt: + input_body["prompt"] = _human_assistant_format(prompt) + if "max_tokens_to_sample" not in input_body: + input_body["max_tokens_to_sample"] = 1024 + elif provider in ("ai21", "cohere", "meta", "mistral"): + input_body["prompt"] = prompt + elif provider == "amazon": + input_body = dict() + input_body["inputText"] = prompt + input_body["textGenerationConfig"] = {**model_kwargs} + else: + input_body["inputText"] = prompt + + return input_body + + @classmethod + def prepare_output(cls, provider: str, response: Any) -> dict: + text = "" + if provider == "anthropic": + response_body = json.loads(response.get("body").read().decode()) + if "completion" in response_body: + text = response_body.get("completion") + elif "content" in response_body: + content = response_body.get("content") + text = content[0].get("text") + else: + response_body = json.loads(response.get("body").read()) + + if provider == "ai21": + text = response_body.get("completions")[0].get("data").get("text") + elif provider == "cohere": + text = response_body.get("generations")[0].get("text") + elif provider == "meta": + text = response_body.get("generation") + elif provider == "mistral": + text = response_body.get("outputs")[0].get("text") + else: + text = response_body.get("results")[0].get("outputText") + + headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) + completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) + return { + "text": text, + "body": response_body, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + + @classmethod + def prepare_output_stream( + cls, + provider: str, + response: Any, + stop: Optional[List[str]] = None, + messages_api: bool = False, + ) -> Iterator[GenerationChunk]: + stream = response.get("body") + + if not stream: + return + + if messages_api: + output_key = "message" + else: + output_key = cls.provider_to_output_key_map.get(provider, "") + + if not output_key: + raise ValueError( + f"Unknown streaming response output key for provider: {provider}" + ) + + for event in stream: + chunk = event.get("chunk") + if not chunk: + continue + + chunk_obj = json.loads(chunk.get("bytes").decode()) + + if provider == "cohere" and ( + chunk_obj["is_finished"] or chunk_obj[output_key] == "" + ): + return + + elif ( + provider == "mistral" + and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop" + ): + return + + elif messages_api and (chunk_obj.get("type") == "content_block_stop"): + return + + if messages_api and chunk_obj.get("type") in ( + "message_start", + "content_block_start", + "content_block_delta", + ): + 
if chunk_obj.get("type") == "content_block_delta": + chk = _stream_response_to_generation_chunk(chunk_obj) + yield chk + else: + continue + else: + # chunk obj format varies with provider + yield GenerationChunk( + text=( + chunk_obj[output_key] + if provider != "mistral" + else chunk_obj[output_key][0]["text"] + ), + generation_info={ + GUARDRAILS_BODY_KEY: ( + chunk_obj.get(GUARDRAILS_BODY_KEY) + if GUARDRAILS_BODY_KEY in chunk_obj + else None + ), + }, + ) + + @classmethod + async def aprepare_output_stream( + cls, provider: str, response: Any, stop: Optional[List[str]] = None + ) -> AsyncIterator[GenerationChunk]: + stream = response.get("body") + + if not stream: + return + + output_key = cls.provider_to_output_key_map.get(provider, None) + + if not output_key: + raise ValueError( + f"Unknown streaming response output key for provider: {provider}" + ) + + for event in stream: + chunk = event.get("chunk") + if not chunk: + continue + + chunk_obj = json.loads(chunk.get("bytes").decode()) + + if provider == "cohere" and ( + chunk_obj["is_finished"] or chunk_obj[output_key] == "" + ): + return + + if ( + provider == "mistral" + and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop" + ): + return + + yield GenerationChunk( + text=( + chunk_obj[output_key] + if provider != "mistral" + else chunk_obj[output_key][0]["text"] + ) + ) + + +class BedrockBase(BaseModel, ABC): + """Base class for Bedrock models.""" + + client: Any = Field(exclude=True) #: :meta private: + + region_name: Optional[str] = None + """The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable + or region specified in ~/.aws/config in case it is not provided here. + """ + + credentials_profile_name: Optional[str] = Field(default=None, exclude=True) + """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which + has either access keys or role information specified. + If not specified, the default credential profile or, if on an EC2 instance, + credentials from IMDS will be used. + See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + """ + + config: Any = None + """An optional botocore.config.Config instance to pass to the client.""" + + provider: Optional[str] = None + """The model provider, e.g., amazon, cohere, ai21, etc. When not supplied, provider + is extracted from the first part of the model_id e.g. 'amazon' in + 'amazon.titan-text-express-v1'. This value should be provided for model ids that do + not have the provider in them, e.g., custom and provisioned models that have an ARN + associated with them.""" + + model_id: str + """Id of the model to call, e.g., amazon.titan-text-express-v1, this is + equivalent to the modelId property in the list-foundation-models api. For custom and + provisioned models, an ARN value is expected.""" + + model_kwargs: Optional[Dict] = None + """Keyword arguments to pass to the model.""" + + endpoint_url: Optional[str] = None + """Needed if you don't want to default to us-east-1 endpoint""" + + streaming: bool = False + """Whether to stream the results.""" + + provider_stop_sequence_key_name_map: Mapping[str, str] = { + "anthropic": "stop_sequences", + "amazon": "stopSequences", + "ai21": "stop_sequences", + "cohere": "stop_sequences", + "mistral": "stop_sequences", + } + + guardrails: Optional[Mapping[str, Any]] = { + "id": None, + "version": None, + "trace": False, + } + """ + An optional dictionary to configure guardrails for Bedrock. 
+ + This field 'guardrails' consists of two keys: 'id' and 'version', + which should be strings, but are initialized to None. It's used to + determine if specific guardrails are enabled and properly set. + + Type: + Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys. + + Example: + llm = Bedrock(model_id="", client=, + model_kwargs={}, + guardrails={ + "id": "", + "version": ""}) + + To enable tracing for guardrails, set the 'trace' key to True and pass a callback handler to the + 'run_manager' parameter of the 'generate', '_call' methods. + + Example: + llm = Bedrock(model_id="", client=, + model_kwargs={}, + guardrails={ + "id": "", + "version": "", + "trace": True}, + callbacks=[BedrockAsyncCallbackHandler()]) + + [https://python.langchain.com/docs/modules/callbacks/] for more information on callback handlers. + + class BedrockAsyncCallbackHandler(AsyncCallbackHandler): + async def on_llm_error( + self, + error: BaseException, + **kwargs: Any, + ) -> Any: + reason = kwargs.get("reason") + if reason == "GUARDRAIL_INTERVENED": + ...Logic to handle guardrail intervention... + """ # noqa: E501 + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that AWS credentials to and python package exists in environment.""" + + # Skip creating new client if passed in constructor + if values["client"] is not None: + return values + + try: + import boto3 + + if values["credentials_profile_name"] is not None: + session = boto3.Session(profile_name=values["credentials_profile_name"]) + else: + # use default credentials + session = boto3.Session() + + values["region_name"] = get_from_dict_or_env( + values, + "region_name", + "AWS_DEFAULT_REGION", + default=session.region_name, + ) + + client_params = {} + if values["region_name"]: + client_params["region_name"] = values["region_name"] + if values["endpoint_url"]: + client_params["endpoint_url"] = values["endpoint_url"] + if values["config"]: + client_params["config"] = values["config"] + + values["client"] = session.client("bedrock-runtime", **client_params) + + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + except ValueError as e: + raise ValueError(f"Error raised by bedrock service: {e}") + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + f"profile name are valid. Bedrock error: {e}" + ) from e + + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} + return { + **{"model_kwargs": _model_kwargs}, + } + + def _get_provider(self) -> str: + if self.provider: + return self.provider + if self.model_id.startswith("arn"): + raise ValueError( + "Model provider should be supplied when passing a model ARN as " + "model_id" + ) + + return self.model_id.split(".")[0] + + @property + def _model_is_anthropic(self) -> bool: + return self._get_provider() == "anthropic" + + @property + def _guardrails_enabled(self) -> bool: + """ + Determines if guardrails are enabled and correctly configured. + Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version' keys. + Checks if 'guardrails.trace' is true. + + Returns: + bool: True if guardrails are correctly configured, False otherwise. + Raises: + TypeError: If 'guardrails' lacks 'id' or 'version' keys. 
+ """ + try: + return ( + isinstance(self.guardrails, dict) + and bool(self.guardrails["id"]) + and bool(self.guardrails["version"]) + ) + + except KeyError as e: + raise TypeError( + "Guardrails must be a dictionary with 'id' and 'version' keys." + ) from e + + def _get_guardrails_canonical(self) -> Dict[str, Any]: + """ + The canonical way to pass in guardrails to the bedrock service + adheres to the following format: + + "amazon-bedrock-guardrailDetails": { + "guardrailId": "string", + "guardrailVersion": "string" + } + """ + return { + "amazon-bedrock-guardrailDetails": { + "guardrailId": self.guardrails.get("id"), # type: ignore[union-attr] + "guardrailVersion": self.guardrails.get("version"), # type: ignore[union-attr] + } + } + + def _prepare_input_and_invoke( + self, + prompt: Optional[str] = None, + system: Optional[str] = None, + messages: Optional[List[Dict]] = None, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Tuple[str, Dict[str, Any]]: + _model_kwargs = self.model_kwargs or {} + + provider = self._get_provider() + params = {**_model_kwargs, **kwargs} + if self._guardrails_enabled: + params.update(self._get_guardrails_canonical()) + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, + model_kwargs=params, + prompt=prompt, + system=system, + messages=messages, + ) + body = json.dumps(input_body) + accept = "application/json" + contentType = "application/json" + + request_options = { + "body": body, + "modelId": self.model_id, + "accept": accept, + "contentType": contentType, + } + + if self._guardrails_enabled: + request_options["guardrail"] = "ENABLED" + if self.guardrails.get("trace"): # type: ignore[union-attr] + request_options["trace"] = "ENABLED" + + try: + response = self.client.invoke_model(**request_options) + + text, body, usage_info = LLMInputOutputAdapter.prepare_output( + provider, response + ).values() + + except Exception as e: + raise ValueError(f"Error raised by bedrock service: {e}") + + if stop is not None: + text = enforce_stop_tokens(text, stop) + + # Verify and raise a callback error if any intervention occurs or a signal is + # sent from a Bedrock service, + # such as when guardrails are triggered. + services_trace = self._get_bedrock_services_signal(body) # type: ignore[arg-type] + + if services_trace.get("signal") and run_manager is not None: + run_manager.on_llm_error( + Exception( + f"Error raised by bedrock service: {services_trace.get('reason')}" + ), + **services_trace, + ) + + return text, usage_info + + def _get_bedrock_services_signal(self, body: dict) -> dict: + """ + This function checks the response body for an interrupt flag or message that indicates + whether any of the Bedrock services have intervened in the processing flow. It is + primarily used to identify modifications or interruptions imposed by these services + during the request-response cycle with a Large Language Model (LLM). 
+ """ # noqa: E501 + + if ( + self._guardrails_enabled + and self.guardrails.get("trace") # type: ignore[union-attr] + and self._is_guardrails_intervention(body) + ): + return { + "signal": True, + "reason": "GUARDRAIL_INTERVENED", + "trace": body.get(AMAZON_BEDROCK_TRACE_KEY), + } + + return { + "signal": False, + "reason": None, + "trace": None, + } + + def _is_guardrails_intervention(self, body: dict) -> bool: + return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED" + + def _prepare_input_and_invoke_stream( + self, + prompt: Optional[str] = None, + system: Optional[str] = None, + messages: Optional[List[Dict]] = None, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + _model_kwargs = self.model_kwargs or {} + provider = self._get_provider() + + if stop: + if provider not in self.provider_stop_sequence_key_name_map: + raise ValueError( + f"Stop sequence key name for {provider} is not supported." + ) + + # stop sequence from _generate() overrides + # stop sequences in the class attribute + _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop + + if provider == "cohere": + _model_kwargs["stream"] = True + + params = {**_model_kwargs, **kwargs} + + if self._guardrails_enabled: + params.update(self._get_guardrails_canonical()) + + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, + prompt=prompt, + system=system, + messages=messages, + model_kwargs=params, + ) + body = json.dumps(input_body) + + request_options = { + "body": body, + "modelId": self.model_id, + "accept": "application/json", + "contentType": "application/json", + } + + if self._guardrails_enabled: + request_options["guardrail"] = "ENABLED" + if self.guardrails.get("trace"): # type: ignore[union-attr] + request_options["trace"] = "ENABLED" + + try: + response = self.client.invoke_model_with_response_stream(**request_options) + + except Exception as e: + raise ValueError(f"Error raised by bedrock service: {e}") + + for chunk in LLMInputOutputAdapter.prepare_output_stream( + provider, response, stop, True if messages else False + ): + yield chunk + # verify and raise callback error if any middleware intervened + self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type] + + if run_manager is not None: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + async def _aprepare_input_and_invoke_stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + _model_kwargs = self.model_kwargs or {} + provider = self._get_provider() + + if stop: + if provider not in self.provider_stop_sequence_key_name_map: + raise ValueError( + f"Stop sequence key name for {provider} is not supported." 
+ ) + _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop + + if provider == "cohere": + _model_kwargs["stream"] = True + + params = {**_model_kwargs, **kwargs} + input_body = LLMInputOutputAdapter.prepare_input( + provider=provider, prompt=prompt, model_kwargs=params + ) + body = json.dumps(input_body) + + response = await asyncio.get_running_loop().run_in_executor( + None, + lambda: self.client.invoke_model_with_response_stream( + body=body, + modelId=self.model_id, + accept="application/json", + contentType="application/json", + ), + ) + + async for chunk in LLMInputOutputAdapter.aprepare_output_stream( + provider, response, stop + ): + yield chunk + if run_manager is not None and asyncio.iscoroutinefunction( + run_manager.on_llm_new_token + ): + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + elif run_manager is not None: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) # type: ignore[unused-coroutine] + + +class BedrockLLM(LLM, BedrockBase): + """Bedrock models. + + To authenticate, the AWS client uses the following methods to + automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + If a specific credential profile should be used, you must pass + the name of the profile from the ~/.aws/credentials file that is to be used. + + Make sure the credentials / roles used have the required policies to + access the Bedrock service. + """ + + """ + Example: + .. code-block:: python + + from bedrock_langchain.bedrock_llm import BedrockLLM + + llm = BedrockLLM( + credentials_profile_name="default", + model_id="amazon.titan-text-express-v1", + streaming=True + ) + + """ + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + model_id = values["model_id"] + if model_id.startswith("anthropic.claude-3"): + raise ValueError( + "Claude v3 models are not supported by this LLM." + "Please use `from langchain_community.chat_models import BedrockChat` " + "instead." + ) + return super().validate_environment(values) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "amazon_bedrock" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by Langchain.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "llms", "bedrock"] + + @property + def lc_attributes(self) -> Dict[str, Any]: + attributes: Dict[str, Any] = {} + + if self.region_name: + attributes["region_name"] = self.region_name + + return attributes + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Call out to Bedrock service with streaming. + + Args: + prompt (str): The prompt to pass into the model + stop (Optional[List[str]], optional): Stop sequences. These will + override any stop sequences in the `model_kwargs` attribute. + Defaults to None. + run_manager (Optional[CallbackManagerForLLMRun], optional): Callback + run managers used to process the output. Defaults to None. + + Returns: + Iterator[GenerationChunk]: Generator that yields the streamed responses. + + Yields: + Iterator[GenerationChunk]: Responses from the model. 
+ """ + return self._prepare_input_and_invoke_stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ) + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to Bedrock service model. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = llm("Tell me a joke.") + """ + + if self.streaming: + completion = "" + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + + text, _ = self._prepare_input_and_invoke( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ) + return text + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncGenerator[GenerationChunk, None]: + """Call out to Bedrock service with streaming. + + Args: + prompt (str): The prompt to pass into the model + stop (Optional[List[str]], optional): Stop sequences. These will + override any stop sequences in the `model_kwargs` attribute. + Defaults to None. + run_manager (Optional[CallbackManagerForLLMRun], optional): Callback + run managers used to process the output. Defaults to None. + + Yields: + AsyncGenerator[GenerationChunk, None]: Generator that asynchronously yields + the streamed responses. + """ + async for chunk in self._aprepare_input_and_invoke_stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + yield chunk + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to Bedrock service model. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = await llm._acall("Tell me a joke.") + """ + + if not self.streaming: + raise ValueError("Streaming must be set to True for async operations. ") + + chunks = [ + chunk.text + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ) + ] + return "".join(chunks) + + def get_num_tokens(self, text: str) -> int: + if self._model_is_anthropic: + return get_num_tokens_anthropic(text) + else: + return super().get_num_tokens(text) + + def get_token_ids(self, text: str) -> List[int]: + if self._model_is_anthropic: + return get_token_ids_anthropic(text) + else: + return super().get_token_ids(text) + + +@deprecated(since="0.1.0", removal="0.2.0", alternative="BedrockLLM") +class Bedrock(BedrockLLM): + pass diff --git a/libs/aws/langchain_aws/utils.py b/libs/aws/langchain_aws/utils.py new file mode 100644 index 00000000..ff9188a2 --- /dev/null +++ b/libs/aws/langchain_aws/utils.py @@ -0,0 +1,33 @@ +import re +from typing import Any, List + + +def enforce_stop_tokens(text: str, stop: List[str]) -> str: + """Cut off the text as soon as any stop words occur.""" + return re.split("|".join(stop), text, maxsplit=1)[0] + + +def _get_anthropic_client() -> Any: + try: + import anthropic + except ImportError: + raise ImportError( + "Could not import anthropic python package. 
" + "This is needed in order to accurately tokenize the text " + "for anthropic models. Please install it with `pip install anthropic`." + ) + return anthropic.Anthropic() + + +def get_num_tokens_anthropic(text: str) -> int: + """Get the number of tokens in a string of text.""" + client = _get_anthropic_client() + return client.count_tokens(text=text) + + +def get_token_ids_anthropic(text: str) -> List[int]: + """Get the token ids for a string of text.""" + client = _get_anthropic_client() + tokenizer = client.get_tokenizer() + encoded_text = tokenizer.encode(text) + return encoded_text.ids diff --git a/libs/aws/tests/unit_tests/llms/__init__.py b/libs/aws/tests/unit_tests/llms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/tests/unit_tests/llms/test_bedrock.py b/libs/aws/tests/unit_tests/llms/test_bedrock.py new file mode 100644 index 00000000..3a7a0d41 --- /dev/null +++ b/libs/aws/tests/unit_tests/llms/test_bedrock.py @@ -0,0 +1,308 @@ +import json +from typing import AsyncGenerator, Dict +from unittest.mock import MagicMock, patch + +import pytest + +from langchain_aws import BedrockLLM +from langchain_aws.llms.bedrock import ( + ALTERNATION_ERROR, + _human_assistant_format, +) + +TEST_CASES = { + """Hey""": """ + +Human: Hey + +Assistant:""", + """ + +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """ +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """ + +Human: Human: Hello + +Assistant:""": ( + "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'." + ), + """Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""", + """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""": """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant:""", + """ + +Human: Hello + +Assistant: Hello + +Human: Hello + +Assistant: Hello + +Assistant: Hello""": ALTERNATION_ERROR, + """ + +Human: Hi. + +Assistant: Hi. + +Human: Hi. + +Human: Hi. + +Assistant:""": ALTERNATION_ERROR, + """ +Human: Hello""": """ + +Human: Hello + +Assistant:""", + """ + +Human: Hello +Hello + +Assistant""": """ + +Human: Hello +Hello + +Assistant + +Assistant:""", + """Hello + +Assistant:""": """ + +Human: Hello + +Assistant:""", + """Hello + +Human: Hello + +""": """Hello + +Human: Hello + + + +Assistant:""", + """ + +Human: Assistant: Hello""": """ + +Human: + +Assistant: Hello""", + """ + +Human: Human + +Assistant: Assistant + +Human: Assistant + +Assistant: Human""": """ + +Human: Human + +Assistant: Assistant + +Human: Assistant + +Assistant: Human""", + """ +Assistant: Hello there, your name is: + +Human. + +Human: Hello there, your name is: + +Assistant.""": """ + +Human: + +Assistant: Hello there, your name is: + +Human. + +Human: Hello there, your name is: + +Assistant. + +Assistant:""", + """ + +Human: Human: Hi + +Assistant: Hi""": ALTERNATION_ERROR, + """Human: Hi + +Human: Hi""": ALTERNATION_ERROR, + """ + +Assistant: Hi + +Human: Hi""": """ + +Human: + +Assistant: Hi + +Human: Hi + +Assistant:""", + """ + +Human: Hi + +Assistant: Yo + +Human: Hey + +Assistant: Sup + +Human: Hi + +Assistant: Hi +Human: Hi +Assistant:""": """ + +Human: Hi + +Assistant: Yo + +Human: Hey + +Assistant: Sup + +Human: Hi + +Assistant: Hi + +Human: Hi + +Assistant:""", + """ + +Hello. + +Human: Hello. + +Assistant:""": """ + +Hello. + +Human: Hello. 
+ +Assistant:""", +} + + +def test__human_assistant_format() -> None: + for input_text, expected_output in TEST_CASES.items(): + if expected_output == ALTERNATION_ERROR: + with pytest.warns(UserWarning, match=ALTERNATION_ERROR): + _human_assistant_format(input_text) + else: + output = _human_assistant_format(input_text) + assert output == expected_output + + +# Sample mock streaming response data +MOCK_STREAMING_RESPONSE = [ + {"chunk": {"bytes": b'{"text": "nice"}'}}, + {"chunk": {"bytes": b'{"text": " to meet"}'}}, + {"chunk": {"bytes": b'{"text": " you"}'}}, +] + + +async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]: + for item in MOCK_STREAMING_RESPONSE: + yield item + + +@pytest.mark.asyncio +async def test_bedrock_async_streaming_call() -> None: + # Mock boto3 import + mock_boto3 = MagicMock() + mock_boto3.Session.return_value.client.return_value = ( + MagicMock() + ) # Mocking the client method of the Session object + + with patch.dict( + "sys.modules", {"boto3": mock_boto3} + ): # Mocking boto3 at the top level using patch.dict + # Mock the `BedrockLLM` class's method that invokes the model + mock_invoke_method = MagicMock(return_value=async_gen_mock_streaming_response()) + with patch.object( + BedrockLLM, "_aprepare_input_and_invoke_stream", mock_invoke_method + ): + # Instantiate the Bedrock LLM + llm = BedrockLLM( + client=None, + model_id="anthropic.claude-v2", + streaming=True, + ) + # Call the _astream method + chunks = [ + json.loads(chunk["chunk"]["bytes"])["text"] # type: ignore + async for chunk in llm._astream("Hey, how are you?") + ] + + # Assertions + assert len(chunks) == 3 + assert chunks[0] == "nice" + assert chunks[1] == " to meet" + assert chunks[2] == " you" From 278276a750f1d94f546728bbca9a96469581222e Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 3 Apr 2024 21:32:48 -0700 Subject: [PATCH 2/4] Added bedrock chat, fixed usage and tests. 
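This patch adds ChatPromptAdapter, provider-specific message formatting, and a BedrockChat chat model built on BedrockBase. A minimal usage sketch, mirroring the new integration tests; it assumes AWS credentials are configured and that the anthropic.claude-v2 model is enabled for the account:

    from langchain_core.messages import HumanMessage, SystemMessage

    from langchain_aws.chat_models.bedrock import BedrockChat

    # The Bedrock runtime client is created from the default boto3 session
    # when one is not passed in explicitly.
    chat = BedrockChat(
        model_id="anthropic.claude-v2",
        model_kwargs={"temperature": 0},
    )

    response = chat.invoke(
        [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello"),
        ]
    )
    print(response.content)                     # generated text
    print(response.additional_kwargs["usage"])  # token counts reported by Bedrock

Token usage is also aggregated into llm_output["usage"] by _combine_llm_outputs, so generate() results carry prompt, completion, and total token counts.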
--- libs/aws/langchain_aws/chat_models/bedrock.py | 399 ++++++++++++++++++ libs/aws/langchain_aws/llms/__init__.py | 15 +- libs/aws/tests/callbacks.py | 391 +++++++++++++++++ .../chat_models/test_bedrock.py | 162 +++++++ 4 files changed, 965 insertions(+), 2 deletions(-) create mode 100644 libs/aws/langchain_aws/chat_models/bedrock.py create mode 100644 libs/aws/tests/callbacks.py create mode 100644 libs/aws/tests/integration_tests/chat_models/test_bedrock.py diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py new file mode 100644 index 00000000..4213fcfc --- /dev/null +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -0,0 +1,399 @@ +import re +from collections import defaultdict +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast + +from langchain_core.callbacks import ( + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import Extra + +from langchain_aws.llms.bedrock import BedrockBase +from langchain_aws.utils import ( + get_num_tokens_anthropic, + get_token_ids_anthropic, +) + + +def _convert_one_message_to_text_llama(message: BaseMessage) -> str: + if isinstance(message, ChatMessage): + message_text = f"\n\n{message.role.capitalize()}: {message.content}" + elif isinstance(message, HumanMessage): + message_text = f"[INST] {message.content} [/INST]" + elif isinstance(message, AIMessage): + message_text = f"{message.content}" + elif isinstance(message, SystemMessage): + message_text = f"<> {message.content} <>" + else: + raise ValueError(f"Got unknown type {message}") + return message_text + + +def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str: + """Convert a list of messages to a prompt for llama.""" + + return "\n".join( + [_convert_one_message_to_text_llama(message) for message in messages] + ) + + +def _convert_one_message_to_text_anthropic( + message: BaseMessage, + human_prompt: str, + ai_prompt: str, +) -> str: + content = cast(str, message.content) + if isinstance(message, ChatMessage): + message_text = f"\n\n{message.role.capitalize()}: {content}" + elif isinstance(message, HumanMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, AIMessage): + message_text = f"{ai_prompt} {content}" + elif isinstance(message, SystemMessage): + message_text = content + else: + raise ValueError(f"Got unknown type {message}") + return message_text + + +def convert_messages_to_prompt_anthropic( + messages: List[BaseMessage], + *, + human_prompt: str = "\n\nHuman:", + ai_prompt: str = "\n\nAssistant:", +) -> str: + """Format a list of messages into a full prompt for the Anthropic model + Args: + messages (List[BaseMessage]): List of BaseMessage to combine. + human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:". + ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:". + Returns: + str: Combined string with necessary human_prompt and ai_prompt tags. 
+    """
+
+    messages = messages.copy()  # don't mutate the original list
+    if not isinstance(messages[-1], AIMessage):
+        messages.append(AIMessage(content=""))
+
+    text = "".join(
+        _convert_one_message_to_text_anthropic(message, human_prompt, ai_prompt)
+        for message in messages
+    )
+
+    # trim off the trailing ' ' that might come from the "Assistant: "
+    return text.rstrip()
+
+
+def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"[INST] {message.content} [/INST]"
+    elif isinstance(message, AIMessage):
+        message_text = f"{message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"<<SYS>> {message.content} <</SYS>>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for mistral."""
+    return "\n".join(
+        [_convert_one_message_to_text_mistral(message) for message in messages]
+    )
+
+
+def _format_image(image_url: str) -> Dict:
+    """
+    Formats an image of format data:image/jpeg;base64,{b64_string}
+    to a dict for anthropic api
+
+    {
+      "type": "base64",
+      "media_type": "image/jpeg",
+      "data": "/9j/4AAQSkZJRg...",
+    }
+
+    And throws an error if it's not a b64 image
+    """
+    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
+    match = re.match(regex, image_url)
+    if match is None:
+        raise ValueError(
+            "Anthropic only supports base64-encoded images currently."
+            " Example: data:image/png;base64,'/9j/4AAQSk'..."
+        )
+    return {
+        "type": "base64",
+        "media_type": match.group("media_type"),
+        "data": match.group("data"),
+    }
+
+
+def _format_anthropic_messages(
+    messages: List[BaseMessage],
+) -> Tuple[Optional[str], List[Dict]]:
+    """Format messages for anthropic."""
+
+    """
+    [
+                {
+                    "role": _message_type_lookups[m.type],
+                    "content": [_AnthropicMessageContent(text=m.content).dict()],
+                }
+                for m in messages
+            ]
+    """
+    system: Optional[str] = None
+    formatted_messages: List[Dict] = []
+    for i, message in enumerate(messages):
+        if message.type == "system":
+            if i != 0:
+                raise ValueError("System message must be at beginning of message list.")
+            if not isinstance(message.content, str):
+                raise ValueError(
+                    "System message must be a string, "
+                    f"instead was: {type(message.content)}"
+                )
+            system = message.content
+            continue
+
+        role = _message_type_lookups[message.type]
+        content: Union[str, List[Dict]]
+
+        if not isinstance(message.content, str):
+            # parse as dict
+            assert isinstance(
+                message.content, list
+            ), "Anthropic message content must be str or list of dicts"
+
+            # populate content
+            content = []
+            for item in message.content:
+                if isinstance(item, str):
+                    content.append(
+                        {
+                            "type": "text",
+                            "text": item,
+                        }
+                    )
+                elif isinstance(item, dict):
+                    if "type" not in item:
+                        raise ValueError("Dict content item must have a type key")
+                    if item["type"] == "image_url":
+                        # convert format
+                        source = _format_image(item["image_url"]["url"])
+                        content.append(
+                            {
+                                "type": "image",
+                                "source": source,
+                            }
+                        )
+                    else:
+                        content.append(item)
+                else:
+                    raise ValueError(
+                        f"Content items must be str or dict, instead was: {type(item)}"
+                    )
+        else:
+            content = message.content
+
+        formatted_messages.append(
+            {
+                "role": role,
+                "content": content,
+            }
+        )
+    return system, formatted_messages
+
+
+class ChatPromptAdapter:
+    """Adapter class to prepare the inputs
from Langchain to prompt format + that Chat model expects. + """ + + @classmethod + def convert_messages_to_prompt( + cls, provider: str, messages: List[BaseMessage] + ) -> str: + if provider == "anthropic": + prompt = convert_messages_to_prompt_anthropic(messages=messages) + elif provider == "meta": + prompt = convert_messages_to_prompt_llama(messages=messages) + elif provider == "mistral": + prompt = convert_messages_to_prompt_mistral(messages=messages) + elif provider == "amazon": + prompt = convert_messages_to_prompt_anthropic( + messages=messages, + human_prompt="\n\nUser:", + ai_prompt="\n\nBot:", + ) + else: + raise NotImplementedError( + f"Provider {provider} model does not support chat." + ) + return prompt + + @classmethod + def format_messages( + cls, provider: str, messages: List[BaseMessage] + ) -> Tuple[Optional[str], List[Dict]]: + if provider == "anthropic": + return _format_anthropic_messages(messages) + + raise NotImplementedError( + f"Provider {provider} not supported for format_messages" + ) + + +_message_type_lookups = {"human": "user", "ai": "assistant"} + + +class BedrockChat(BaseChatModel, BedrockBase): + """A chat model that uses the Bedrock API.""" + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "amazon_bedrock_chat" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by Langchain.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "chat_models", "bedrock"] + + @property + def lc_attributes(self) -> Dict[str, Any]: + attributes: Dict[str, Any] = {} + + if self.region_name: + attributes["region_name"] = self.region_name + + return attributes + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + provider = self._get_provider() + prompt, system, formatted_messages = None, None, None + + if provider == "anthropic": + system, formatted_messages = ChatPromptAdapter.format_messages( + provider, messages + ) + else: + prompt = ChatPromptAdapter.convert_messages_to_prompt( + provider=provider, messages=messages + ) + + for chunk in self._prepare_input_and_invoke_stream( + prompt=prompt, + system=system, + messages=formatted_messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ): + delta = chunk.text + yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + completion = "" + llm_output: Dict[str, Any] = {"model_id": self.model_id} + usage_info = {} + if self.streaming: + for chunk in self._stream(messages, stop, run_manager, **kwargs): + completion += chunk.text + else: + provider = self._get_provider() + prompt, system, formatted_messages = None, None, None + params: Dict[str, Any] = {**kwargs} + + if provider == "anthropic": + system, formatted_messages = ChatPromptAdapter.format_messages( + provider, messages + ) + else: + prompt = ChatPromptAdapter.convert_messages_to_prompt( + provider=provider, messages=messages + ) + + if stop: + params["stop_sequences"] = stop + + completion, usage_info = self._prepare_input_and_invoke( + 
prompt=prompt, + stop=stop, + run_manager=run_manager, + system=system, + messages=formatted_messages, + **params, + ) + + llm_output["usage"] = usage_info + + return ChatResult( + generations=[ + ChatGeneration( + message=AIMessage( + content=completion, additional_kwargs={"usage": usage_info} + ) + ) + ], + llm_output=llm_output, + ) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + final_usage: Dict[str, int] = defaultdict(int) + final_output = {} + for output in llm_outputs: + output = output or {} + usage = output.pop("usage", {}) + for token_type, token_count in usage.items(): + final_usage[token_type] += token_count + final_output.update(output) + final_output["usage"] = final_usage + return final_output + + def get_num_tokens(self, text: str) -> int: + if self._model_is_anthropic: + return get_num_tokens_anthropic(text) + else: + return super().get_num_tokens(text) + + def get_token_ids(self, text: str) -> List[int]: + if self._model_is_anthropic: + return get_token_ids_anthropic(text) + else: + return super().get_token_ids(text) diff --git a/libs/aws/langchain_aws/llms/__init__.py b/libs/aws/langchain_aws/llms/__init__.py index 4f7facff..3255494a 100644 --- a/libs/aws/langchain_aws/llms/__init__.py +++ b/libs/aws/langchain_aws/llms/__init__.py @@ -1,4 +1,15 @@ -from langchain_aws.llms.bedrock import ALTERNATION_ERROR, Bedrock, BedrockLLM +from langchain_aws.llms.bedrock import ( + ALTERNATION_ERROR, + Bedrock, + BedrockBase, + BedrockLLM, +) from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint -__all__ = ["ALTERNATION_ERROR", "Bedrock", "BedrockLLM", "SagemakerEndpoint"] +__all__ = [ + "ALTERNATION_ERROR", + "Bedrock", + "BedrockBase", + "BedrockLLM", + "SagemakerEndpoint", +] diff --git a/libs/aws/tests/callbacks.py b/libs/aws/tests/callbacks.py new file mode 100644 index 00000000..66b54256 --- /dev/null +++ b/libs/aws/tests/callbacks.py @@ -0,0 +1,391 @@ +"""A fake callback handler for testing purposes.""" +from itertools import chain +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler +from langchain_core.messages import BaseMessage +from langchain_core.pydantic_v1 import BaseModel + + +class BaseFakeCallbackHandler(BaseModel): + """Base fake callback handler for testing.""" + + starts: int = 0 + ends: int = 0 + errors: int = 0 + text: int = 0 + ignore_llm_: bool = False + ignore_chain_: bool = False + ignore_agent_: bool = False + ignore_retriever_: bool = False + ignore_chat_model_: bool = False + + # to allow for similar callback handlers that are not technicall equal + fake_id: Union[str, None] = None + + # add finer-grained counters for easier debugging of failing tests + chain_starts: int = 0 + chain_ends: int = 0 + llm_starts: int = 0 + llm_ends: int = 0 + llm_streams: int = 0 + tool_starts: int = 0 + tool_ends: int = 0 + agent_actions: int = 0 + agent_ends: int = 0 + chat_model_starts: int = 0 + retriever_starts: int = 0 + retriever_ends: int = 0 + retriever_errors: int = 0 + retries: int = 0 + + +class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): + """Base fake callback handler mixin for testing.""" + + def on_llm_start_common(self) -> None: + self.llm_starts += 1 + self.starts += 1 + + def on_llm_end_common(self) -> None: + self.llm_ends += 1 + self.ends += 1 + + def on_llm_error_common(self) -> None: + self.errors += 1 + + def on_llm_new_token_common(self) -> None: + self.llm_streams += 1 + + def 
on_retry_common(self) -> None: + self.retries += 1 + + def on_chain_start_common(self) -> None: + self.chain_starts += 1 + self.starts += 1 + + def on_chain_end_common(self) -> None: + self.chain_ends += 1 + self.ends += 1 + + def on_chain_error_common(self) -> None: + self.errors += 1 + + def on_tool_start_common(self) -> None: + self.tool_starts += 1 + self.starts += 1 + + def on_tool_end_common(self) -> None: + self.tool_ends += 1 + self.ends += 1 + + def on_tool_error_common(self) -> None: + self.errors += 1 + + def on_agent_action_common(self) -> None: + self.agent_actions += 1 + self.starts += 1 + + def on_agent_finish_common(self) -> None: + self.agent_ends += 1 + self.ends += 1 + + def on_chat_model_start_common(self) -> None: + self.chat_model_starts += 1 + self.starts += 1 + + def on_text_common(self) -> None: + self.text += 1 + + def on_retriever_start_common(self) -> None: + self.starts += 1 + self.retriever_starts += 1 + + def on_retriever_end_common(self) -> None: + self.ends += 1 + self.retriever_ends += 1 + + def on_retriever_error_common(self) -> None: + self.errors += 1 + self.retriever_errors += 1 + + +class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake callback handler for testing.""" + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + + @property + def ignore_retriever(self) -> bool: + """Whether to ignore retriever callbacks.""" + return self.ignore_retriever_ + + def on_llm_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_start_common() + + def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_new_token_common() + + def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_end_common() + + def on_llm_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_error_common() + + def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + + def on_chain_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_start_common() + + def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_end_common() + + def on_chain_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_error_common() + + def on_tool_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_start_common() + + def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_end_common() + + def on_tool_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_error_common() + + def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_action_common() + + def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_finish_common() + + def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_text_common() + + def on_retriever_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_start_common() + + def on_retriever_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_end_common() + + def on_retriever_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_error_common() + + def __deepcopy__(self, memo: dict) 
-> "FakeCallbackHandler": + return self + + +class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + assert all(isinstance(m, BaseMessage) for m in chain(*messages)) + self.on_chat_model_start_common() + + +class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake async callback handler for testing.""" + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + + async def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + + async def on_llm_start( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_start_common() + + async def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_new_token_common() + + async def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_end_common() + + async def on_llm_error( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_error_common() + + async def on_chain_start( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_start_common() + + async def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_end_common() + + async def on_chain_error( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_error_common() + + async def on_tool_start( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_start_common() + + async def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_end_common() + + async def on_tool_error( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_error_common() + + async def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_action_common() + + async def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_finish_common() + + async def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_text_common() + + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + return self diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py new file mode 100644 index 00000000..437389ff --- /dev/null +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -0,0 +1,162 @@ +"""Test Bedrock chat model.""" +from typing import Any, cast + +import pytest +from langchain_core.callbacks import CallbackManager +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, LLMResult + +from langchain_aws.chat_models.bedrock import BedrockChat +from tests.callbacks import FakeCallbackHandler + + +@pytest.fixture +def chat() -> BedrockChat: + return BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0}) # type: ignore[call-arg] + + +@pytest.mark.scheduled +def test_chat_bedrock(chat: BedrockChat) -> None: + """Test BedrockChat wrapper.""" + system = SystemMessage(content="You are a helpful assistant.") + human = 
HumanMessage(content="Hello") + response = chat([system, human]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +@pytest.mark.scheduled +def test_chat_bedrock_generate(chat: BedrockChat) -> None: + """Test BedrockChat wrapper with generate.""" + message = HumanMessage(content="Hello") + response = chat.generate([[message], [message]]) + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content + + +@pytest.mark.scheduled +def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None: + """Test BedrockChat wrapper with generate.""" + message = HumanMessage(content="Hello") + response = chat.generate([[message], [message]]) + assert isinstance(response, LLMResult) + assert isinstance(response.llm_output, dict) + + usage = response.llm_output["usage"] + assert usage["prompt_tokens"] == 20 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] > 0 + + +@pytest.mark.scheduled +def test_chat_bedrock_streaming() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + chat = BedrockChat( # type: ignore[call-arg] + model_id="anthropic.claude-v2", + streaming=True, + callbacks=[callback_handler], + verbose=True, + ) + message = HumanMessage(content="Hello") + response = chat([message]) + assert callback_handler.llm_streams > 0 + assert isinstance(response, BaseMessage) + + +@pytest.mark.scheduled +def test_chat_bedrock_streaming_generation_info() -> None: + """Test that generation info is preserved when streaming.""" + + class _FakeCallback(FakeCallbackHandler): + saved_things: dict = {} + + def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + # Save the generation + self.saved_things["generation"] = args[0] + + callback = _FakeCallback() + callback_manager = CallbackManager([callback]) + chat = BedrockChat( # type: ignore[call-arg] + model_id="anthropic.claude-v2", + callback_manager=callback_manager, + ) + list(chat.stream("hi")) + generation = callback.saved_things["generation"] + # `Hello!` is two tokens, assert that that is what is returned + assert generation.generations[0][0].text == "Hello!" 
+ + +@pytest.mark.scheduled +def test_bedrock_streaming(chat: BedrockChat) -> None: + """Test streaming tokens from OpenAI.""" + + full = None + for token in chat.stream("I'm Pickle Rick"): + full = token if full is None else full + token # type: ignore[operator] + assert isinstance(token.content, str) + assert isinstance(cast(AIMessageChunk, full).content, str) + + +@pytest.mark.scheduled +async def test_bedrock_astream(chat: BedrockChat) -> None: + """Test streaming tokens from OpenAI.""" + + async for token in chat.astream("I'm Pickle Rick"): + assert isinstance(token.content, str) + + +@pytest.mark.scheduled +async def test_bedrock_abatch(chat: BedrockChat) -> None: + """Test streaming tokens from BedrockChat.""" + result = await chat.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token.content, str) + + +@pytest.mark.scheduled +async def test_bedrock_abatch_tags(chat: BedrockChat) -> None: + """Test batch tokens from BedrockChat.""" + result = await chat.abatch( + ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} + ) + for token in result: + assert isinstance(token.content, str) + + +@pytest.mark.scheduled +def test_bedrock_batch(chat: BedrockChat) -> None: + """Test batch tokens from BedrockChat.""" + result = chat.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) + for token in result: + assert isinstance(token.content, str) + + +@pytest.mark.scheduled +async def test_bedrock_ainvoke(chat: BedrockChat) -> None: + """Test invoke tokens from BedrockChat.""" + result = await chat.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) + assert isinstance(result.content, str) + + +@pytest.mark.scheduled +def test_bedrock_invoke(chat: BedrockChat) -> None: + """Test invoke tokens from BedrockChat.""" + result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) + assert isinstance(result.content, str) + assert "usage" in result.additional_kwargs # type: ignore[attr-defined] + assert result.additional_kwargs["usage"]["prompt_tokens"] == 13 # type: ignore[attr-defined] From c2ca4b1e1c86fd48526a93ca4806dca02a3ee0b0 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 3 Apr 2024 21:37:22 -0700 Subject: [PATCH 3/4] Fixed liniting. --- libs/aws/langchain_aws/chat_models/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 4213fcfc..72f725c7 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -331,7 +331,7 @@ def _generate( ) -> ChatResult: completion = "" llm_output: Dict[str, Any] = {"model_id": self.model_id} - usage_info = {} + usage_info: Dict[str, Any] = {} if self.streaming: for chunk in self._stream(messages, stop, run_manager, **kwargs): completion += chunk.text From 4eae07ce19a1f76235751eca1bb1f98a87fe1517 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 3 Apr 2024 21:49:55 -0700 Subject: [PATCH 4/4] Renamed to ChatBedrock, fixed exports. 

---
 libs/aws/langchain_aws/__init__.py             |  3 ++
 .../aws/langchain_aws/chat_models/__init__.py  |  3 ++
 libs/aws/langchain_aws/chat_models/bedrock.py  |  8 ++++-
 .../chat_models/test_bedrock.py                | 34 +++++++++----------
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py
index 610bbc0e..e4aef32d 100644
--- a/libs/aws/langchain_aws/__init__.py
+++ b/libs/aws/langchain_aws/__init__.py
@@ -1,3 +1,4 @@
+from langchain_aws.chat_models import BedrockChat, ChatBedrock
 from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint
 from langchain_aws.retrievers import (
     AmazonKendraRetriever,
@@ -7,6 +8,8 @@
 __all__ = [
     "Bedrock",
     "BedrockLLM",
+    "BedrockChat",
+    "ChatBedrock",
     "SagemakerEndpoint",
     "AmazonKendraRetriever",
     "AmazonKnowledgeBasesRetriever",
diff --git a/libs/aws/langchain_aws/chat_models/__init__.py b/libs/aws/langchain_aws/chat_models/__init__.py
index e69de29b..e334788a 100644
--- a/libs/aws/langchain_aws/chat_models/__init__.py
+++ b/libs/aws/langchain_aws/chat_models/__init__.py
@@ -0,0 +1,3 @@
+from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock
+
+__all__ = ["BedrockChat", "ChatBedrock"]
diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 72f725c7..5fa7182e 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -2,6 +2,7 @@
 from collections import defaultdict
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
 
+from langchain_core._api.deprecation import deprecated
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -260,7 +261,7 @@ def format_messages(
 _message_type_lookups = {"human": "user", "ai": "assistant"}
 
 
-class BedrockChat(BaseChatModel, BedrockBase):
+class ChatBedrock(BaseChatModel, BedrockBase):
     """A chat model that uses the Bedrock API."""
 
     @property
@@ -397,3 +398,8 @@ def get_token_ids(self, text: str) -> List[int]:
             return get_token_ids_anthropic(text)
         else:
             return super().get_token_ids(text)
+
+
+@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
+class BedrockChat(ChatBedrock):
+    pass
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 437389ff..31f00caa 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -2,7 +2,6 @@
 from typing import Any, cast
 
 import pytest
-from langchain_core.callbacks import CallbackManager
 from langchain_core.messages import (
     AIMessageChunk,
     BaseMessage,
@@ -11,17 +10,17 @@
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 
-from langchain_aws.chat_models.bedrock import BedrockChat
+from langchain_aws.chat_models.bedrock import ChatBedrock
 from tests.callbacks import FakeCallbackHandler
 
 
 @pytest.fixture
-def chat() -> BedrockChat:
-    return BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0})  # type: ignore[call-arg]
+def chat() -> ChatBedrock:
+    return ChatBedrock(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0})  # type: ignore[call-arg]
 
 
 @pytest.mark.scheduled
-def test_chat_bedrock(chat: BedrockChat) -> None:
+def test_chat_bedrock(chat: ChatBedrock) -> None:
     """Test BedrockChat wrapper."""
     system = SystemMessage(content="You are a helpful assistant.")
     human = HumanMessage(content="Hello")
@@ -31,7 +30,7 @@ def test_chat_bedrock(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-def test_chat_bedrock_generate(chat: BedrockChat) -> None:
+def test_chat_bedrock_generate(chat: ChatBedrock) -> None:
     """Test BedrockChat wrapper with generate."""
     message = HumanMessage(content="Hello")
     response = chat.generate([[message], [message]])
@@ -45,7 +44,7 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
+def test_chat_bedrock_generate_with_token_usage(chat: ChatBedrock) -> None:
     """Test BedrockChat wrapper with token usage."""
     message = HumanMessage(content="Hello")
     response = chat.generate([[message], [message]])
@@ -62,7 +61,7 @@ def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
 def test_chat_bedrock_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
-    chat = BedrockChat(  # type: ignore[call-arg]
+    chat = ChatBedrock(  # type: ignore[call-arg]
         model_id="anthropic.claude-v2",
         streaming=True,
         callbacks=[callback_handler],
@@ -90,10 +89,9 @@ def on_llm_end(
             self.saved_things["generation"] = args[0]
 
     callback = _FakeCallback()
-    callback_manager = CallbackManager([callback])
-    chat = BedrockChat(  # type: ignore[call-arg]
+    chat = ChatBedrock(  # type: ignore[call-arg]
         model_id="anthropic.claude-v2",
-        callback_manager=callback_manager,
+        callbacks=[callback],
     )
     list(chat.stream("hi"))
     generation = callback.saved_things["generation"]
@@ -102,7 +100,7 @@ def on_llm_end(
 
 
 @pytest.mark.scheduled
-def test_bedrock_streaming(chat: BedrockChat) -> None:
+def test_bedrock_streaming(chat: ChatBedrock) -> None:
     """Test streaming tokens from BedrockChat."""
 
     full = None
@@ -113,7 +111,7 @@ def test_bedrock_streaming(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-async def test_bedrock_astream(chat: BedrockChat) -> None:
+async def test_bedrock_astream(chat: ChatBedrock) -> None:
     """Test streaming tokens from BedrockChat."""
 
     async for token in chat.astream("I'm Pickle Rick"):
@@ -121,7 +119,7 @@ async def test_bedrock_astream(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-async def test_bedrock_abatch(chat: BedrockChat) -> None:
+async def test_bedrock_abatch(chat: ChatBedrock) -> None:
     """Test batch tokens from BedrockChat."""
     result = await chat.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
     for token in result:
@@ -129,7 +127,7 @@ async def test_bedrock_abatch(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-async def test_bedrock_abatch_tags(chat: BedrockChat) -> None:
+async def test_bedrock_abatch_tags(chat: ChatBedrock) -> None:
     """Test batch tokens from BedrockChat."""
     result = await chat.abatch(
         ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
@@ -139,7 +137,7 @@ async def test_bedrock_abatch_tags(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-def test_bedrock_batch(chat: BedrockChat) -> None:
+def test_bedrock_batch(chat: ChatBedrock) -> None:
     """Test batch tokens from BedrockChat."""
     result = chat.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
     for token in result:
@@ -147,14 +145,14 @@ def test_bedrock_batch(chat: BedrockChat) -> None:
 
 
 @pytest.mark.scheduled
-async def test_bedrock_ainvoke(chat: BedrockChat) -> None:
+async def test_bedrock_ainvoke(chat: ChatBedrock) -> None:
     """Test invoke tokens from BedrockChat."""
     result = await chat.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
 
 
 @pytest.mark.scheduled
-def test_bedrock_invoke(chat: BedrockChat) -> None:
+def test_bedrock_invoke(chat: ChatBedrock) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
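
For reviewers, a minimal usage sketch of the surface introduced by this series; it is not part of the patch. It assumes AWS credentials and Amazon Bedrock model access are already configured for boto3, and it reuses the `anthropic.claude-v2` model id from the tests above; any accessible Bedrock model id would work the same way.

    from langchain_core.messages import HumanMessage

    from langchain_aws import BedrockChat, ChatBedrock

    # ChatBedrock is the new name exported from langchain_aws.chat_models.
    chat = ChatBedrock(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0})
    result = chat.invoke([HumanMessage(content="Hello")])
    print(result.content)
    # Token counts are surfaced in additional_kwargs, as exercised by test_bedrock_invoke.
    print(result.additional_kwargs.get("usage"))

    # BedrockChat still imports, but it is a deprecated alias of ChatBedrock
    # (removal planned for 0.2.0) and emits a deprecation warning when instantiated.
    legacy = BedrockChat(model_id="anthropic.claude-v2")

Because BedrockChat subclasses ChatBedrock, existing imports keep working while new code is steered toward the ChatBedrock name.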