Merge pull request #32 from fedor-intercom/add_llama3_support
Adding support for Llama3 models in BedrockChat
3coins authored May 10, 2024
2 parents 22fd6d3 + 1f212f4 commit 2efb770
Showing 3 changed files with 62 additions and 4 deletions.
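
The change routes Meta model IDs containing "llama3" to a new Llama 3 prompt format, while other Meta models keep the existing Llama 2 formatting. A minimal usage sketch of the new path (a sketch only, assuming AWS credentials and Bedrock access to the model; region_name is illustrative):

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage

# Any "meta.llama3-..." model ID now gets the
# <|begin_of_text|>/<|start_header_id|> prompt format when converted.
chat = ChatBedrock(
    model_id="meta.llama3-8b-instruct-v1:0",
    region_name="us-east-1",  # illustrative; use your configured region
)

response = chat.invoke([HumanMessage(content="Hello")])
print(response.content)
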
46 changes: 42 additions & 4 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -64,6 +64,41 @@ def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
    )


def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage):
        message_text = (
            f"<|start_header_id|>{message.role}"
            f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    elif isinstance(message, HumanMessage):
        message_text = (
            f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    elif isinstance(message, AIMessage):
        message_text = (
            f"<|start_header_id|>assistant"
            f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    elif isinstance(message, SystemMessage):
        message_text = (
            f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
        )
    else:
        raise ValueError(f"Got unknown type {message}")

    return message_text


def convert_messages_to_prompt_llama3(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a prompt for llama."""

    return "\n".join(
        ["<|begin_of_text|>"]
        + [_convert_one_message_to_text_llama3(message) for message in messages]
        + ["<|start_header_id|>assistant<|end_header_id|>\n\n"]
    )


def _convert_one_message_to_text_anthropic(
    message: BaseMessage,
    human_prompt: str,
@@ -243,12 +278,15 @@ class ChatPromptAdapter:

    @classmethod
    def convert_messages_to_prompt(
-        cls, provider: str, messages: List[BaseMessage]
+        cls, provider: str, messages: List[BaseMessage], model: str
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
-            prompt = convert_messages_to_prompt_llama(messages=messages)
+            if "llama3" in model:
+                prompt = convert_messages_to_prompt_llama3(messages=messages)
+            else:
+                prompt = convert_messages_to_prompt_llama(messages=messages)
        elif provider == "mistral":
            prompt = convert_messages_to_prompt_mistral(messages=messages)
        elif provider == "amazon":
@@ -333,7 +371,7 @@ def _stream(
                system = self.system_prompt_with_tools
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
            )

        for chunk in self._prepare_input_and_invoke_stream(
@@ -376,7 +414,7 @@ def _generate(
                system = self.system_prompt_with_tools
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
+                provider=provider, messages=messages, model=self._get_model()
            )

        if stop:
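
For reference, a small sketch of what the new Llama 3 conversion produces for a system + user exchange (expected output reasoned from the functions above, not captured from a live run):

from langchain_aws.chat_models.bedrock import convert_messages_to_prompt_llama3
from langchain_core.messages import HumanMessage, SystemMessage

prompt = convert_messages_to_prompt_llama3(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
    ]
)
print(prompt)
# <|begin_of_text|>
# <|start_header_id|>system<|end_header_id|>You are a helpful assistant.<|eot_id|>
# <|start_header_id|>user<|end_header_id|>Hello<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>
# (the final assistant header is followed by two newlines, cueing the model to reply)
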
3 changes: 3 additions & 0 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -463,6 +463,9 @@ def _get_provider(self) -> str:

        return self.model_id.split(".")[0]

    def _get_model(self) -> str:
        return self.model_id.split(".")[1]

    @property
    def _model_is_anthropic(self) -> bool:
        return self._get_provider() == "anthropic"
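
The new _get_model mirrors the existing _get_provider: the Bedrock model ID is split on dots, with the first segment naming the provider and the second the model. A quick illustration of how the chat adapter then detects Llama 3 (variable names here are just for the example):

model_id = "meta.llama3-8b-instruct-v1:0"

provider = model_id.split(".")[0]  # "meta"
model = model_id.split(".")[1]  # "llama3-8b-instruct-v1:0"

# ChatPromptAdapter.convert_messages_to_prompt picks the Llama 3 prompt
# format when "llama3" appears in the model segment, otherwise Llama 2.
assert provider == "meta" and "llama3" in model
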
17 changes: 17 additions & 0 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,4 +1,5 @@
"""Test Bedrock chat model."""

from typing import Any, cast

import pytest
@@ -74,6 +75,22 @@ def test_chat_bedrock_streaming() -> None:
    assert isinstance(response, BaseMessage)


@pytest.mark.scheduled
def test_chat_bedrock_streaming_llama3() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    chat = ChatBedrock(  # type: ignore[call-arg]
        model_id="meta.llama3-8b-instruct-v1:0",
        streaming=True,
        callbacks=[callback_handler],
        verbose=True,
    )
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, BaseMessage)


@pytest.mark.scheduled
def test_chat_bedrock_streaming_generation_info() -> None:
"""Test that generation info is preserved when streaming."""
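
The new integration test mirrors the existing streaming test but targets the Llama 3 model ID. Outside the test harness, streaming against the same model might look like this sketch (assumes Bedrock access; the test uses the callable chat([message]) style, while stream is shown here to surface the chunks directly):

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage

chat = ChatBedrock(  # type: ignore[call-arg]
    model_id="meta.llama3-8b-instruct-v1:0",
    streaming=True,
)

# Each chunk is a partial message emitted as Bedrock streams the response.
for chunk in chat.stream([HumanMessage(content="Hello")]):
    print(chunk.content, end="", flush=True)
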
