Skip to content

Commit

Permalink
Merge pull request #57 from 3coins/fix-stop-reason-error
Browse files Browse the repository at this point in the history
Fixes error when stop reason missing in response.
  • Loading branch information
3coins authored May 23, 2024
2 parents 9adc14d + 83bb6b0 commit 57ea56b
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 1 deletion.
2 changes: 2 additions & 0 deletions libs/aws/langchain_aws/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Bedrock,
BedrockBase,
BedrockLLM,
LLMInputOutputAdapter,
)
from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint

Expand All @@ -11,5 +12,6 @@
"Bedrock",
"BedrockBase",
"BedrockLLM",
"LLMInputOutputAdapter",
"SagemakerEndpoint",
]
2 changes: 1 addition & 1 deletion libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
"stop_reason": response_body["stop_reason"],
"stop_reason": response_body.get("stop_reason"),
}

@classmethod
Expand Down
141 changes: 141 additions & 0 deletions libs/aws/tests/unit_tests/llms/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# type:ignore

import json
from typing import AsyncGenerator, Dict
from unittest.mock import MagicMock, patch
Expand All @@ -7,6 +9,7 @@
from langchain_aws import BedrockLLM
from langchain_aws.llms.bedrock import (
ALTERNATION_ERROR,
LLMInputOutputAdapter,
_human_assistant_format,
)

Expand Down Expand Up @@ -306,3 +309,141 @@ async def test_bedrock_async_streaming_call() -> None:
assert chunks[0] == "nice"
assert chunks[1] == " to meet"
assert chunks[2] == " you"


@pytest.fixture
def mistral_response():
    """Fake Bedrock response with a Mistral-style body and token-count headers.

    The mocked body's ``read()`` returns encoded JSON holding the generated
    text under ``outputs[0]["text"]``. The payload carries no
    ``stop_reason`` key.
    """
    payload = json.dumps(
        {"outputs": [{"text": "This is the Mistral output text."}]}
    ).encode()

    stream = MagicMock()
    stream.read.return_value = payload

    # Token counts are given as strings, the way header values are.
    token_headers = {
        "x-amzn-bedrock-input-token-count": "18",
        "x-amzn-bedrock-output-token-count": "28",
    }
    return {
        "body": stream,
        "ResponseMetadata": {"HTTPHeaders": token_headers},
    }


@pytest.fixture
def cohere_response():
    """Fake Bedrock response with a Cohere-style body and token-count headers.

    The mocked body's ``read()`` returns encoded JSON holding the generated
    text under ``generations[0]["text"]``. The payload carries no
    ``stop_reason`` key.
    """
    payload = json.dumps(
        {"generations": [{"text": "This is the Cohere output text."}]}
    ).encode()

    stream = MagicMock()
    stream.read.return_value = payload

    # Token counts are given as strings, the way header values are.
    token_headers = {
        "x-amzn-bedrock-input-token-count": "12",
        "x-amzn-bedrock-output-token-count": "22",
    }
    return {
        "body": stream,
        "ResponseMetadata": {"HTTPHeaders": token_headers},
    }


@pytest.fixture
def anthropic_response():
    """Fake Bedrock response with an Anthropic-style body and no stop reason.

    The mocked body's ``read()`` returns encoded JSON holding the generated
    text under ``completion``. The ``stop_reason`` key is deliberately
    absent from the payload.
    """
    payload = json.dumps({"completion": "This is the output text."}).encode()

    stream = MagicMock()
    stream.read.return_value = payload

    # Token counts are given as strings, the way header values are.
    token_headers = {
        "x-amzn-bedrock-input-token-count": "10",
        "x-amzn-bedrock-output-token-count": "20",
    }
    return {
        "body": stream,
        "ResponseMetadata": {"HTTPHeaders": token_headers},
    }


@pytest.fixture
def ai21_response():
    """Fake Bedrock response with an AI21-style body and token-count headers.

    The mocked body's ``read()`` returns encoded JSON holding the generated
    text under ``completions[0]["data"]["text"]``. The payload carries no
    ``stop_reason`` key.
    """
    payload = json.dumps(
        {"completions": [{"data": {"text": "This is the AI21 output text."}}]}
    ).encode()

    stream = MagicMock()
    stream.read.return_value = payload

    # Token counts are given as strings, the way header values are.
    token_headers = {
        "x-amzn-bedrock-input-token-count": "15",
        "x-amzn-bedrock-output-token-count": "25",
    }
    return {
        "body": stream,
        "ResponseMetadata": {"HTTPHeaders": token_headers},
    }


@pytest.fixture
def response_with_stop_reason():
    """Fake Bedrock response whose body includes an explicit ``stop_reason``.

    The mocked body's ``read()`` returns encoded JSON with the generated text
    under ``completion`` and ``stop_reason`` set to ``"length"``.
    """
    payload = json.dumps(
        {"completion": "This is the output text.", "stop_reason": "length"}
    ).encode()

    stream = MagicMock()
    stream.read.return_value = payload

    # Token counts are given as strings, the way header values are.
    token_headers = {
        "x-amzn-bedrock-input-token-count": "10",
        "x-amzn-bedrock-output-token-count": "20",
    }
    return {
        "body": stream,
        "ResponseMetadata": {"HTTPHeaders": token_headers},
    }


def test_prepare_output_for_mistral(mistral_response):
    """Check text, token usage and a ``None`` stop reason for "mistral"."""
    parsed = LLMInputOutputAdapter.prepare_output("mistral", mistral_response)

    assert parsed["text"] == "This is the Mistral output text."

    # The expected counts mirror the fixture's header values (18 in, 28 out).
    usage = parsed["usage"]
    assert usage["prompt_tokens"] == 18
    assert usage["completion_tokens"] == 28
    assert usage["total_tokens"] == 46

    # The fixture payload has no "stop_reason" key.
    assert parsed["stop_reason"] is None


def test_prepare_output_for_cohere(cohere_response):
    """Check text, token usage and a ``None`` stop reason for "cohere"."""
    parsed = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)

    assert parsed["text"] == "This is the Cohere output text."

    # The expected counts mirror the fixture's header values (12 in, 22 out).
    usage = parsed["usage"]
    assert usage["prompt_tokens"] == 12
    assert usage["completion_tokens"] == 22
    assert usage["total_tokens"] == 34

    # The fixture payload has no "stop_reason" key.
    assert parsed["stop_reason"] is None


def test_prepare_output_with_stop_reason(response_with_stop_reason):
    """Check that a ``stop_reason`` present in the body is passed through."""
    parsed = LLMInputOutputAdapter.prepare_output(
        "anthropic", response_with_stop_reason
    )

    assert parsed["text"] == "This is the output text."

    # The expected counts mirror the fixture's header values (10 in, 20 out).
    usage = parsed["usage"]
    assert usage["prompt_tokens"] == 10
    assert usage["completion_tokens"] == 20
    assert usage["total_tokens"] == 30

    # The fixture payload sets "stop_reason" to "length".
    assert parsed["stop_reason"] == "length"


def test_prepare_output_for_anthropic(anthropic_response):
    """Check text, token usage and a ``None`` stop reason for "anthropic"."""
    parsed = LLMInputOutputAdapter.prepare_output("anthropic", anthropic_response)

    assert parsed["text"] == "This is the output text."

    # The expected counts mirror the fixture's header values (10 in, 20 out).
    usage = parsed["usage"]
    assert usage["prompt_tokens"] == 10
    assert usage["completion_tokens"] == 20
    assert usage["total_tokens"] == 30

    # The fixture payload has no "stop_reason" key.
    assert parsed["stop_reason"] is None


def test_prepare_output_for_ai21(ai21_response):
    """Check text, token usage and a ``None`` stop reason for "ai21"."""
    parsed = LLMInputOutputAdapter.prepare_output("ai21", ai21_response)

    assert parsed["text"] == "This is the AI21 output text."

    # The expected counts mirror the fixture's header values (15 in, 25 out).
    usage = parsed["usage"]
    assert usage["prompt_tokens"] == 15
    assert usage["completion_tokens"] == 25
    assert usage["total_tokens"] == 40

    # The fixture payload has no "stop_reason" key.
    assert parsed["stop_reason"] is None

0 comments on commit 57ea56b

Please sign in to comment.