diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 52153e39..bcf407ab 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -445,7 +445,7 @@ def _generate(
                 **params,
             )
 
-        llm_output["model_id"] = self.model_id
+        llm_output["model_name"] = self.model_id
         return ChatResult(
             generations=[
                 ChatGeneration(
@@ -460,11 +460,11 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.get("usage", {})
-            for token_type, token_count in usage.items():
+            token_usage = output.get("usage", {})
+            for token_type, token_count in token_usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
-        final_output["usage"] = final_usage
+        final_output["token_usage"] = final_usage
         return final_output
 
     def get_num_tokens(self, text: str) -> int:
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index c20e3365..ffc32d47 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -89,7 +89,7 @@ def _stream_response_to_generation_chunk(
     if msg_type == "message_start":
         usage_info = stream_response.get("message", {}).get("usage", None)
         usage_info = _nest_usage_info_token_counts(usage_info)
-        generation_info = {"usage": usage_info}
+        generation_info = {"token_usage": usage_info}
         return GenerationChunk(text="", generation_info=generation_info)
     elif msg_type == "content_block_delta":
         if not stream_response["delta"]:
@@ -104,7 +104,7 @@
         usage_info = stream_response.get("usage", None)
         usage_info = _nest_usage_info_token_counts(usage_info)
         stop_reason = stream_response.get("delta", {}).get("stop_reason")
-        generation_info = {"stop_reason": stop_reason, "usage": usage_info}
+        generation_info = {"stop_reason": stop_reason, "token_usage": usage_info}
         return GenerationChunk(text="", generation_info=generation_info)
     else:
         return None
@@ -171,7 +171,7 @@ def _combine_generation_info_for_llm_result(
         total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"]
     )
 
-    return {"usage": total_usage_info, "stop_reason": stop_reason}
+    return {"token_usage": total_usage_info, "stop_reason": stop_reason}
 
 
 class LLMInputOutputAdapter:
@@ -252,7 +252,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
         return {
             "text": text,
             "body": response_body,
-            "usage": {
+            "token_usage": {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
                 "total_tokens": prompt_tokens + completion_tokens,
@@ -631,7 +631,7 @@ def _prepare_input_and_invoke(
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
 
-        llm_output = {"usage": usage_info, "stop_reason": stop_reason}
+        llm_output = {"token_usage": usage_info, "stop_reason": stop_reason}
 
         # Verify and raise a callback error if any intervention occurs or a signal is
         # sent from a Bedrock service,
diff --git a/libs/aws/tests/callbacks.py b/libs/aws/tests/callbacks.py
index 3a3902a0..b64815b9 100644
--- a/libs/aws/tests/callbacks.py
+++ b/libs/aws/tests/callbacks.py
@@ -286,10 +286,10 @@ def on_llm_end(
         **kwargs: Any,
     ) -> Any:
         if response.llm_output is not None:
-            self.input_token_count += response.llm_output.get("usage", {}).get(
+            self.input_token_count += response.llm_output.get("token_usage", {}).get(
                 "prompt_tokens", None
             )
-            self.output_token_count += response.llm_output.get("usage", {}).get(
+            self.output_token_count += response.llm_output.get("token_usage", {}).get(
                 "completion_tokens", None
             )
             self.stop_reason = response.llm_output.get("stop_reason", None)
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
index 6d1fb57a..7219af94 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py
@@ -54,10 +54,10 @@ def test_chat_bedrock_generate_with_token_usage(chat: ChatBedrock) -> None:
     assert isinstance(response, LLMResult)
     assert isinstance(response.llm_output, dict)
 
-    usage = response.llm_output["usage"]
-    assert usage["prompt_tokens"] == 20
-    assert usage["completion_tokens"] > 0
-    assert usage["total_tokens"] > 0
+    token_usage = response.llm_output["token_usage"]
+    assert token_usage["prompt_tokens"] == 20
+    assert token_usage["completion_tokens"] > 0
+    assert token_usage["total_tokens"] > 0
 
 
 @pytest.mark.scheduled
@@ -176,8 +176,8 @@ def test_bedrock_invoke(chat: ChatBedrock) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
-    assert "usage" in result.additional_kwargs
-    assert result.additional_kwargs["usage"]["prompt_tokens"] == 13
+    assert "token_usage" in result.additional_kwargs
+    assert result.additional_kwargs["token_usage"]["prompt_tokens"] == 13
 
 
 @pytest.mark.scheduled
diff --git a/libs/aws/tests/unit_tests/llms/test_bedrock.py b/libs/aws/tests/unit_tests/llms/test_bedrock.py
index 7693cb19..d5a0cab5 100644
--- a/libs/aws/tests/unit_tests/llms/test_bedrock.py
+++ b/libs/aws/tests/unit_tests/llms/test_bedrock.py
@@ -405,18 +405,18 @@ def response_with_stop_reason():
 def test_prepare_output_for_mistral(mistral_response):
     result = LLMInputOutputAdapter.prepare_output("mistral", mistral_response)
     assert result["text"] == "This is the Mistral output text."
-    assert result["usage"]["prompt_tokens"] == 18
-    assert result["usage"]["completion_tokens"] == 28
-    assert result["usage"]["total_tokens"] == 46
+    assert result["token_usage"]["prompt_tokens"] == 18
+    assert result["token_usage"]["completion_tokens"] == 28
+    assert result["token_usage"]["total_tokens"] == 46
     assert result["stop_reason"] is None
 
 
 def test_prepare_output_for_cohere(cohere_response):
     result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
     assert result["text"] == "This is the Cohere output text."
-    assert result["usage"]["prompt_tokens"] == 12
-    assert result["usage"]["completion_tokens"] == 22
-    assert result["usage"]["total_tokens"] == 34
+    assert result["token_usage"]["prompt_tokens"] == 12
+    assert result["token_usage"]["completion_tokens"] == 22
+    assert result["token_usage"]["total_tokens"] == 34
     assert result["stop_reason"] is None
 
 
@@ -425,25 +425,25 @@ def test_prepare_output_with_stop_reason(response_with_stop_reason):
         "anthropic", response_with_stop_reason
     )
     assert result["text"] == "This is the output text."
-    assert result["usage"]["prompt_tokens"] == 10
-    assert result["usage"]["completion_tokens"] == 20
-    assert result["usage"]["total_tokens"] == 30
+    assert result["token_usage"]["prompt_tokens"] == 10
+    assert result["token_usage"]["completion_tokens"] == 20
+    assert result["token_usage"]["total_tokens"] == 30
    assert result["stop_reason"] == "length"
 
 
 def test_prepare_output_for_anthropic(anthropic_response):
     result = LLMInputOutputAdapter.prepare_output("anthropic", anthropic_response)
     assert result["text"] == "This is the output text."
-    assert result["usage"]["prompt_tokens"] == 10
-    assert result["usage"]["completion_tokens"] == 20
-    assert result["usage"]["total_tokens"] == 30
+    assert result["token_usage"]["prompt_tokens"] == 10
+    assert result["token_usage"]["completion_tokens"] == 20
+    assert result["token_usage"]["total_tokens"] == 30
     assert result["stop_reason"] is None
 
 
 def test_prepare_output_for_ai21(ai21_response):
     result = LLMInputOutputAdapter.prepare_output("ai21", ai21_response)
     assert result["text"] == "This is the AI21 output text."
-    assert result["usage"]["prompt_tokens"] == 15
-    assert result["usage"]["completion_tokens"] == 25
-    assert result["usage"]["total_tokens"] == 40
+    assert result["token_usage"]["prompt_tokens"] == 15
+    assert result["token_usage"]["completion_tokens"] == 25
+    assert result["token_usage"]["total_tokens"] == 40
     assert result["stop_reason"] is None
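For reference, a minimal sketch (not part of the diff) of how callers read the renamed keys after this change, mirroring the updated tests. The model_id and prompts are placeholders, and the printed counts depend on the model's tokenizer:

    from langchain_aws import ChatBedrock
    from langchain_core.messages import HumanMessage

    # Placeholder model_id; any Bedrock chat model available to the account works.
    chat = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")

    # Per-message token counts now live under "token_usage" (previously "usage").
    message = chat.invoke("I'm Pickle Rick")
    print(message.additional_kwargs["token_usage"]["prompt_tokens"])

    # Aggregated counts on LLMResult.llm_output are renamed the same way.
    result = chat.generate([[HumanMessage(content="Hello")]])
    token_usage = result.llm_output["token_usage"]
    print(
        token_usage["prompt_tokens"],
        token_usage["completion_tokens"],
        token_usage["total_tokens"],
    )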