diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 0b701708..7c5003a4 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -84,27 +84,27 @@ def _stream_response_to_generation_chunk( ) -> GenerationChunk: """Convert a stream response to a generation chunk.""" if messages_api: - match stream_response.get("type"): - case "message_start": - usage_info = stream_response.get("message", {}).get("usage", None) - generation_info = {"usage": usage_info} - return GenerationChunk(text="", generation_info=generation_info) - case "content_block_delta": - if not stream_response["delta"]: - return GenerationChunk(text="") - return GenerationChunk( - text=stream_response["delta"]["text"], - generation_info=dict( - stop_reason=stream_response.get("stop_reason", None), - ), - ) - case "message_delta": - usage_info = stream_response.get("usage", None) - stop_reason = stream_response.get("delta", {}).get("stop_reason") - generation_info = {"stop_reason": stop_reason, "usage": usage_info} - return GenerationChunk(text="", generation_info=generation_info) - case _: - return None + msg_type = stream_response.get("type") + if msg_type == "message_start": + usage_info = stream_response.get("message", {}).get("usage", None) + generation_info = {"usage": usage_info} + return GenerationChunk(text="", generation_info=generation_info) + elif msg_type == "content_block_delta": + if not stream_response["delta"]: + return GenerationChunk(text="") + return GenerationChunk( + text=stream_response["delta"]["text"], + generation_info=dict( + stop_reason=stream_response.get("stop_reason", None), + ), + ) + elif msg_type == "message_delta": + usage_info = stream_response.get("usage", None) + stop_reason = stream_response.get("delta", {}).get("stop_reason") + generation_info = {"stop_reason": stop_reason, "usage": usage_info} + return GenerationChunk(text="", generation_info=generation_info) + else: + return None else: # chunk obj format varies with provider generation_info = {k: v for k, v in stream_response.items() if k != output_key}