Fixes token logging in callbacks when streaming=True is used. (#241)
Fixes #240
Fixes #217

Before this change, the `_stream` path yielded `ChatGenerationChunk`s without notifying the run manager, so `on_llm_new_token` callbacks never fired when `streaming=True`. The chunk is now constructed first, `run_manager.on_llm_new_token` is called with the chunk text, and only then is the chunk yielded.

### Code to verify
```python
from langchain_aws import ChatBedrock
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.prompts import ChatPromptTemplate

streaming = True

class MyCustomHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(f"My custom handler, token: {token}")

prompt = ChatPromptTemplate.from_messages(["Tell me a joke about {animal} in a few words."])

model = ChatBedrock(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    streaming=streaming,
    callbacks=[MyCustomHandler()],
)

chain = prompt | model

response = chain.invoke({"animal": "bears"})
```

### Output
```
My custom handler, token: 
My custom handler, token: Bear
My custom handler, token: -
My custom handler, token: ly funny
My custom handler, token: .
My custom handler, token: 
My custom handler, token: 
```
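
The empty first and last tokens likely correspond to chunks that carry only metadata (message start and stop/usage information) rather than text. For comparison, the same handler can be exercised without a chain by streaming from the model directly; a minimal sketch, assuming the same `model` built above:

```python
# Stream directly from the model; each yielded AIMessageChunk should also
# trigger on_llm_new_token in MyCustomHandler with the same text.
for chunk in model.stream("Tell me a joke about bears in a few words."):
    print(chunk.content, end="", flush=True)
```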
3coins authored Oct 17, 2024
1 parent 48535f0 commit 902252b
Showing 1 changed file with 12 additions and 2 deletions.
libs/aws/langchain_aws/chat_models/bedrock.py (12 additions, 2 deletions):

```diff
@@ -477,14 +477,19 @@ def _stream(
             **kwargs,
         ):
             if isinstance(chunk, AIMessageChunk):
-                yield ChatGenerationChunk(message=chunk)
+                generation_chunk = ChatGenerationChunk(message=chunk)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk
             else:
                 delta = chunk.text
                 if generation_info := chunk.generation_info:
                     usage_metadata = generation_info.pop("usage_metadata", None)
                 else:
                     usage_metadata = None
-                yield ChatGenerationChunk(
+                generation_chunk = ChatGenerationChunk(
                     message=AIMessageChunk(
                         content=delta,
                         response_metadata=chunk.generation_info,
@@ -493,6 +498,11 @@
                     if chunk.generation_info is not None
                     else AIMessageChunk(content=delta)
                 )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk
 
     def _generate(
         self,
```
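
The fix follows the pattern used across LangChain chat-model integrations: build the `ChatGenerationChunk` first, notify the run manager, then yield. A minimal sketch of that pattern in isolation, with `raw_chunks` and `_stream_with_callbacks` as hypothetical stand-ins for the provider-specific Bedrock stream handling:

```python
from typing import Iterator, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk


def _stream_with_callbacks(
    raw_chunks: Iterator[str],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Iterator[ChatGenerationChunk]:
    """Sketch of the pattern this commit applies in ChatBedrock._stream."""
    for delta in raw_chunks:
        generation_chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
        if run_manager:
            # Notifying the run manager before yielding is what makes
            # on_llm_new_token fire for every token when streaming=True.
            run_manager.on_llm_new_token(
                generation_chunk.text, chunk=generation_chunk
            )
        yield generation_chunk
```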
