diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 9dc7e5b7..8f0a8a11 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -64,6 +64,41 @@ def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
     )
 
 
+def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = (
+            f"<|start_header_id|>{message.role}"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, HumanMessage):
+        message_text = (
+            f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, AIMessage):
+        message_text = (
+            f"<|start_header_id|>assistant"
+            f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    elif isinstance(message, SystemMessage):
+        message_text = (
+            f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
+        )
+    else:
+        raise ValueError(f"Got unknown type {message}")
+
+    return message_text
+
+
+def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for llama."""
+
+    return "\n".join(
+        ["<|begin_of_text|>"]
+        + [_convert_one_message_to_text_llama3(message) for message in messages]
+        + ["<|start_header_id|>assistant<|end_header_id|>\n\n"]
+    )
+
+
 def _convert_one_message_to_text_anthropic(
     message: BaseMessage,
     human_prompt: str,
@@ -243,12 +278,15 @@ class ChatPromptAdapter:
 
     @classmethod
     def convert_messages_to_prompt(
-        cls, provider: str, messages: List[BaseMessage]
+        cls, provider: str, messages: List[BaseMessage], model: str
     ) -> str:
         if provider == "anthropic":
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
-            prompt = convert_messages_to_prompt_llama(messages=messages)
+            if "llama3" in model:
+                prompt = convert_messages_to_prompt_llama3(messages=messages)
+            else:
+                prompt = convert_messages_to_prompt_llama(messages=messages)
         elif provider == "mistral":
             prompt = convert_messages_to_prompt_mistral(messages=messages)
         elif provider == "amazon":
@@ -333,7 +371,7 @@ def _stream(
             system = self.system_prompt_with_tools
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
             )
 
         for chunk in self._prepare_input_and_invoke_stream(
@@ -376,7 +414,7 @@ def _generate(
             system = self.system_prompt_with_tools
         else:
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
            )
 
         if stop:
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 90b860ab..1f900024 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -463,6 +463,9 @@ def _get_provider(self) -> str:
 
         return self.model_id.split(".")[0]
 
+    def _get_model(self) -> str:
+        return self.model_id.split(".")[1]
+
     @property
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 42882149..cd6265b9 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,4 +1,5 @@
 """Test Bedrock chat model."""
+
 from typing import Any, cast
 
 import pytest
@@ -74,6 +75,22 @@ def test_chat_bedrock_streaming() -> None:
     assert isinstance(response, BaseMessage)
 
 
+@pytest.mark.scheduled
+def test_chat_bedrock_streaming_llama3() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    chat = ChatBedrock(  # type: ignore[call-arg]
+        model_id="meta.llama3-8b-instruct-v1:0",
+        streaming=True,
+        callbacks=[callback_handler],
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, BaseMessage)
+
+
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_generation_info() -> None:
     """Test that generation info is preserved when streaming."""
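For context, a minimal usage sketch (not part of the patch): with a Llama 3 model id such as "meta.llama3-8b-instruct-v1:0", _get_provider() yields "meta" and _get_model() yields "llama3-8b-instruct-v1:0", so ChatPromptAdapter.convert_messages_to_prompt routes to the new Llama 3 template. The import paths below mirror the modules edited above but are assumptions about how the helpers are exposed.

# Minimal sketch, not part of the patch: illustrates the prompt produced by
# the new Llama 3 conversion path. Import paths are assumed from the files
# touched in this diff.
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_aws.chat_models.bedrock import ChatPromptAdapter

# "meta.llama3-8b-instruct-v1:0".split(".") -> provider "meta" and model
# "llama3-8b-instruct-v1:0"; the "llama3" substring selects
# convert_messages_to_prompt_llama3 instead of the legacy llama2 format.
prompt = ChatPromptAdapter.convert_messages_to_prompt(
    provider="meta",
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
    ],
    model="llama3-8b-instruct-v1:0",
)
print(prompt)
# Expected shape, one segment per line:
# <|begin_of_text|>
# <|start_header_id|>system<|end_header_id|>You are a helpful assistant.<|eot_id|>
# <|start_header_id|>user<|end_header_id|>Hello<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>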