From 902252bb612b96c6bc79fc95594f35c1b8217521 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 16 Oct 2024 19:11:05 -0700
Subject: [PATCH] Fixes token logging in callbacks when streaming=True is
 used. (#241)

Fixes #240
Fixes #217

### Code to verify

```python
from langchain_aws import ChatBedrock
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.prompts import ChatPromptTemplate

streaming = True


class MyCustomHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(f"My custom handler, token: {token}")


prompt = ChatPromptTemplate.from_messages(
    ["Tell me a joke about {animal} in a few words."]
)

model = ChatBedrock(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    streaming=streaming,
    callbacks=[MyCustomHandler()],
)

chain = prompt | model

response = chain.invoke({"animal": "bears"})
```

### Output

```
My custom handler, token: 
My custom handler, token: Bear
My custom handler, token: -
My custom handler, token: ly funny
My custom handler, token: .
My custom handler, token: 
My custom handler, token: 
```
---
 libs/aws/langchain_aws/chat_models/bedrock.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 23a842a2..54d34ffb 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -477,14 +477,19 @@ def _stream(
             **kwargs,
         ):
             if isinstance(chunk, AIMessageChunk):
-                yield ChatGenerationChunk(message=chunk)
+                generation_chunk = ChatGenerationChunk(message=chunk)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk
             else:
                 delta = chunk.text
                 if generation_info := chunk.generation_info:
                     usage_metadata = generation_info.pop("usage_metadata", None)
                 else:
                     usage_metadata = None
-                yield ChatGenerationChunk(
+                generation_chunk = ChatGenerationChunk(
                     message=AIMessageChunk(
                         content=delta,
                         response_metadata=chunk.generation_info,
@@ -493,6 +498,11 @@ def _stream(
                     if chunk.generation_info is not None
                     else AIMessageChunk(content=delta)
                 )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk
 
     def _generate(
         self,
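
As an extra, hypothetical check (not part of the original patch): `chain.stream()` drives the same `_stream` code path this diff touches, so with the fix applied the handler should fire once per chunk there as well. A minimal sketch, reusing the `MyCustomHandler` class, model id, and prompt from the verification snippet above:

```python
from langchain_aws import ChatBedrock
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    ["Tell me a joke about {animal} in a few words."]
)
model = ChatBedrock(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    streaming=True,
    callbacks=[MyCustomHandler()],  # handler from the verification snippet
)
chain = prompt | model

# stream() consumes the chunks yielded by _stream(), so on_llm_new_token
# fires for every chunk while the same text also arrives here
# incrementally as AIMessageChunk objects.
for chunk in chain.stream({"animal": "bears"}):
    print(chunk.content, end="", flush=True)
print()
```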