Skip to content

Commit

Permalink
Renamed to ChatBedrock, fixed exports.
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Apr 4, 2024
1 parent c2ca4b1 commit 4eae07c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 19 deletions.
3 changes: 3 additions & 0 deletions libs/aws/langchain_aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_aws.chat_models import BedrockChat, ChatBedrock
from langchain_aws.llms import Bedrock, BedrockLLM, SagemakerEndpoint
from langchain_aws.retrievers import (
AmazonKendraRetriever,
Expand All @@ -7,6 +8,8 @@
__all__ = [
"Bedrock",
"BedrockLLM",
"BedrockChat",
"ChatBedrock",
"SagemakerEndpoint",
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
Expand Down
3 changes: 3 additions & 0 deletions libs/aws/langchain_aws/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Public chat-model API: ChatBedrock plus the deprecated BedrockChat alias.
from langchain_aws.chat_models.bedrock import BedrockChat, ChatBedrock

__all__ = ["BedrockChat", "ChatBedrock"]
8 changes: 7 additions & 1 deletion libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
Expand Down Expand Up @@ -260,7 +261,7 @@ def format_messages(
_message_type_lookups = {"human": "user", "ai": "assistant"}


class BedrockChat(BaseChatModel, BedrockBase):
class ChatBedrock(BaseChatModel, BedrockBase):
"""A chat model that uses the Bedrock API."""

@property
Expand Down Expand Up @@ -397,3 +398,8 @@ def get_token_ids(self, text: str) -> List[int]:
return get_token_ids_anthropic(text)
else:
return super().get_token_ids(text)


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatBedrock")
class BedrockChat(ChatBedrock):
    """Deprecated alias for :class:`ChatBedrock`, kept for backward compatibility.

    Emits a deprecation warning on use; scheduled for removal in 0.2.0.
    """

    pass
34 changes: 16 additions & 18 deletions libs/aws/tests/integration_tests/chat_models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, cast

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
Expand All @@ -11,17 +10,17 @@
)
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_aws.chat_models.bedrock import BedrockChat
from langchain_aws.chat_models.bedrock import ChatBedrock
from tests.callbacks import FakeCallbackHandler


@pytest.fixture
def chat() -> BedrockChat:
return BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0}) # type: ignore[call-arg]
def chat() -> ChatBedrock:
    """Provide a Claude v2 chat model with temperature pinned to 0 for determinism."""
    model = ChatBedrock(  # type: ignore[call-arg]
        model_id="anthropic.claude-v2",
        model_kwargs={"temperature": 0},
    )
    return model


@pytest.mark.scheduled
def test_chat_bedrock(chat: BedrockChat) -> None:
def test_chat_bedrock(chat: ChatBedrock) -> None:
"""Test BedrockChat wrapper."""
system = SystemMessage(content="You are a helpful assistant.")
human = HumanMessage(content="Hello")
Expand All @@ -31,7 +30,7 @@ def test_chat_bedrock(chat: BedrockChat) -> None:


@pytest.mark.scheduled
def test_chat_bedrock_generate(chat: BedrockChat) -> None:
def test_chat_bedrock_generate(chat: ChatBedrock) -> None:
"""Test BedrockChat wrapper with generate."""
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
Expand All @@ -45,7 +44,7 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:


@pytest.mark.scheduled
def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
def test_chat_bedrock_generate_with_token_usage(chat: ChatBedrock) -> None:
"""Test BedrockChat wrapper with generate."""
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
Expand All @@ -62,7 +61,7 @@ def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
def test_chat_bedrock_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
chat = BedrockChat( # type: ignore[call-arg]
chat = ChatBedrock( # type: ignore[call-arg]
model_id="anthropic.claude-v2",
streaming=True,
callbacks=[callback_handler],
Expand Down Expand Up @@ -90,10 +89,9 @@ def on_llm_end(
self.saved_things["generation"] = args[0]

callback = _FakeCallback()
callback_manager = CallbackManager([callback])
chat = BedrockChat( # type: ignore[call-arg]
chat = ChatBedrock( # type: ignore[call-arg]
model_id="anthropic.claude-v2",
callback_manager=callback_manager,
callbacks=[callback],
)
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
Expand All @@ -102,7 +100,7 @@ def on_llm_end(


@pytest.mark.scheduled
def test_bedrock_streaming(chat: BedrockChat) -> None:
def test_bedrock_streaming(chat: ChatBedrock) -> None:
"""Test streaming tokens from OpenAI."""

full = None
Expand All @@ -113,23 +111,23 @@ def test_bedrock_streaming(chat: BedrockChat) -> None:


@pytest.mark.scheduled
async def test_bedrock_astream(chat: BedrockChat) -> None:
async def test_bedrock_astream(chat: ChatBedrock) -> None:
    """Test that ChatBedrock.astream yields chunks with string content."""

    async for token in chat.astream("I'm Pickle Rick"):
        assert isinstance(token.content, str)


@pytest.mark.scheduled
async def test_bedrock_abatch(chat: BedrockChat) -> None:
async def test_bedrock_abatch(chat: ChatBedrock) -> None:
    """Test that ChatBedrock.abatch returns a string-content message per prompt."""
    result = await chat.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token.content, str)


@pytest.mark.scheduled
async def test_bedrock_abatch_tags(chat: BedrockChat) -> None:
async def test_bedrock_abatch_tags(chat: ChatBedrock) -> None:
"""Test batch tokens from BedrockChat."""
result = await chat.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
Expand All @@ -139,22 +137,22 @@ async def test_bedrock_abatch_tags(chat: BedrockChat) -> None:


@pytest.mark.scheduled
def test_bedrock_batch(chat: BedrockChat) -> None:
def test_bedrock_batch(chat: ChatBedrock) -> None:
    """Test that ChatBedrock.batch returns a string-content message per prompt."""
    result = chat.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token.content, str)


@pytest.mark.scheduled
async def test_bedrock_ainvoke(chat: BedrockChat) -> None:
async def test_bedrock_ainvoke(chat: ChatBedrock) -> None:
    """Test that ChatBedrock.ainvoke honors config tags and returns string content."""
    result = await chat.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result.content, str)


@pytest.mark.scheduled
def test_bedrock_invoke(chat: BedrockChat) -> None:
def test_bedrock_invoke(chat: ChatBedrock) -> None:
    """Test that ChatBedrock.invoke honors config tags and returns string content."""
    result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result.content, str)
Expand Down

0 comments on commit 4eae07c

Please sign in to comment.