[Langchain-AWS-67] Refactor ChatBedrock._combine_llm_outputs() to mat… #68

Closed
libs/aws/langchain_aws/chat_models/bedrock.py (8 changes: 4 additions & 4 deletions)
@@ -445,7 +445,7 @@ def _generate(
             **params,
         )
 
-        llm_output["model_id"] = self.model_id
+        llm_output["model_name"] = self.model_id
         return ChatResult(
             generations=[
                 ChatGeneration(
@@ -460,11 +460,11 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         final_output = {}
         for output in llm_outputs:
            output = output or {}
-            usage = output.get("usage", {})
-            for token_type, token_count in usage.items():
+            token_usage = output.get("usage", {})
+            for token_type, token_count in token_usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
-        final_output["usage"] = final_usage
+        final_output["token_usage"] = final_usage
Collaborator

This is a breaking change.

What we probably want to do is use the new AIMessage.token_usage attribute, which is meant to be the standardized way to record token usage: langchain-ai/langchain#21944

Contributor

+1. The standard that LangChain is moving toward is the usage_metadata attribute on AIMessage, which is just a dict with the keys "input_tokens", "output_tokens", and "total_tokens". If you're able to access the usage data within the _generate method, you can construct this dict and then populate usage_metadata on the AIMessage it generates. Note that we'd need to bump langchain-core >= 0.2.2 in the dependencies.
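For illustration, a minimal sketch of what that could look like, assuming the prompt and completion token counts have already been pulled out of the Bedrock response inside _generate (the variable names and values below are placeholders, not part of this PR):

```python
from langchain_core.messages import AIMessage  # requires langchain-core >= 0.2.2

# Placeholder counts; in practice these come from the Bedrock response
# that _generate already has access to.
prompt_tokens = 20
completion_tokens = 42

# usage_metadata is the standardized token-usage field on AIMessage.
message = AIMessage(
    content="generated text",
    usage_metadata={
        "input_tokens": prompt_tokens,
        "output_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    },
)
```

Downstream consumers could then read token counts from message.usage_metadata instead of from llm_output.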

         return final_output
 
     def get_num_tokens(self, text: str) -> int:
libs/aws/langchain_aws/llms/bedrock.py (10 changes: 5 additions & 5 deletions)
@@ -89,7 +89,7 @@ def _stream_response_to_generation_chunk(
     if msg_type == "message_start":
         usage_info = stream_response.get("message", {}).get("usage", None)
         usage_info = _nest_usage_info_token_counts(usage_info)
-        generation_info = {"usage": usage_info}
+        generation_info = {"token_usage": usage_info}
         return GenerationChunk(text="", generation_info=generation_info)
     elif msg_type == "content_block_delta":
         if not stream_response["delta"]:
@@ -104,7 +104,7 @@
         usage_info = stream_response.get("usage", None)
         usage_info = _nest_usage_info_token_counts(usage_info)
         stop_reason = stream_response.get("delta", {}).get("stop_reason")
-        generation_info = {"stop_reason": stop_reason, "usage": usage_info}
+        generation_info = {"stop_reason": stop_reason, "token_usage": usage_info}
         return GenerationChunk(text="", generation_info=generation_info)
     else:
         return None
@@ -171,7 +171,7 @@ def _combine_generation_info_for_llm_result(
         total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"]
     )
 
-    return {"usage": total_usage_info, "stop_reason": stop_reason}
+    return {"token_usage": total_usage_info, "stop_reason": stop_reason}
 
 
 class LLMInputOutputAdapter:
@@ -252,7 +252,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
         return {
             "text": text,
             "body": response_body,
-            "usage": {
+            "token_usage": {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
                 "total_tokens": prompt_tokens + completion_tokens,
@@ -631,7 +631,7 @@ def _prepare_input_and_invoke(
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
 
-        llm_output = {"usage": usage_info, "stop_reason": stop_reason}
+        llm_output = {"token_usage": usage_info, "stop_reason": stop_reason}
 
         # Verify and raise a callback error if any intervention occurs or a signal is
         # sent from a Bedrock service,
libs/aws/tests/callbacks.py (4 changes: 2 additions & 2 deletions)
@@ -286,10 +286,10 @@ def on_llm_end(
         **kwargs: Any,
     ) -> Any:
         if response.llm_output is not None:
-            self.input_token_count += response.llm_output.get("usage", {}).get(
+            self.input_token_count += response.llm_output.get("token_usage", {}).get(
                 "prompt_tokens", None
             )
-            self.output_token_count += response.llm_output.get("usage", {}).get(
+            self.output_token_count += response.llm_output.get("token_usage", {}).get(
                 "completion_tokens", None
             )
             self.stop_reason = response.llm_output.get("stop_reason", None)
libs/aws/tests/integration_tests/chat_models/test_bedrock.py (12 changes: 6 additions & 6 deletions)
@@ -54,10 +54,10 @@ def test_chat_bedrock_generate_with_token_usage(chat: ChatBedrock) -> None:
     assert isinstance(response, LLMResult)
     assert isinstance(response.llm_output, dict)
 
-    usage = response.llm_output["usage"]
-    assert usage["prompt_tokens"] == 20
-    assert usage["completion_tokens"] > 0
-    assert usage["total_tokens"] > 0
+    token_usage = response.llm_output["token_usage"]
+    assert token_usage["prompt_tokens"] == 20
+    assert token_usage["completion_tokens"] > 0
+    assert token_usage["total_tokens"] > 0
 
 
 @pytest.mark.scheduled
@@ -176,8 +176,8 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
-    assert "usage" in result.additional_kwargs
-    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13
+    assert "token_usage" in result.additional_kwargs
+    assert result.additional_kwargs["token_usage"]["prompt_tokens"] == 13
 
 
 @pytest.mark.scheduled
libs/aws/tests/unit_tests/llms/test_bedrock.py (30 changes: 15 additions & 15 deletions)
@@ -405,18 +405,18 @@ def response_with_stop_reason():
 def test_prepare_output_for_mistral(mistral_response):
     result = LLMInputOutputAdapter.prepare_output("mistral", mistral_response)
     assert result["text"] == "This is the Mistral output text."
-    assert result["usage"]["prompt_tokens"] == 18
-    assert result["usage"]["completion_tokens"] == 28
-    assert result["usage"]["total_tokens"] == 46
+    assert result["token_usage"]["prompt_tokens"] == 18
+    assert result["token_usage"]["completion_tokens"] == 28
+    assert result["token_usage"]["total_tokens"] == 46
     assert result["stop_reason"] is None
 
 
 def test_prepare_output_for_cohere(cohere_response):
     result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
     assert result["text"] == "This is the Cohere output text."
-    assert result["usage"]["prompt_tokens"] == 12
-    assert result["usage"]["completion_tokens"] == 22
-    assert result["usage"]["total_tokens"] == 34
+    assert result["token_usage"]["prompt_tokens"] == 12
+    assert result["token_usage"]["completion_tokens"] == 22
+    assert result["token_usage"]["total_tokens"] == 34
     assert result["stop_reason"] is None


@@ -425,25 +425,25 @@ def test_prepare_output_with_stop_reason(response_with_stop_reason):
         "anthropic", response_with_stop_reason
     )
     assert result["text"] == "This is the output text."
-    assert result["usage"]["prompt_tokens"] == 10
-    assert result["usage"]["completion_tokens"] == 20
-    assert result["usage"]["total_tokens"] == 30
+    assert result["token_usage"]["prompt_tokens"] == 10
+    assert result["token_usage"]["completion_tokens"] == 20
+    assert result["token_usage"]["total_tokens"] == 30
     assert result["stop_reason"] == "length"
 
 
 def test_prepare_output_for_anthropic(anthropic_response):
     result = LLMInputOutputAdapter.prepare_output("anthropic", anthropic_response)
     assert result["text"] == "This is the output text."
-    assert result["usage"]["prompt_tokens"] == 10
-    assert result["usage"]["completion_tokens"] == 20
-    assert result["usage"]["total_tokens"] == 30
+    assert result["token_usage"]["prompt_tokens"] == 10
+    assert result["token_usage"]["completion_tokens"] == 20
+    assert result["token_usage"]["total_tokens"] == 30
     assert result["stop_reason"] is None
 
 
 def test_prepare_output_for_ai21(ai21_response):
     result = LLMInputOutputAdapter.prepare_output("ai21", ai21_response)
     assert result["text"] == "This is the AI21 output text."
-    assert result["usage"]["prompt_tokens"] == 15
-    assert result["usage"]["completion_tokens"] == 25
-    assert result["usage"]["total_tokens"] == 40
+    assert result["token_usage"]["prompt_tokens"] == 15
+    assert result["token_usage"]["completion_tokens"] == 25
+    assert result["token_usage"]["total_tokens"] == 40
     assert result["stop_reason"] is None