Merge pull request #20 from NAPTlME/bedrock-token-count-callbacks
Bedrock token count callbacks
3coins authored May 21, 2024
2 parents 622756c + bfa0871 commit 4a88b7f
Showing 5 changed files with 268 additions and 66 deletions.
2 changes: 1 addition & 1 deletion libs/aws/Makefile
@@ -18,7 +18,7 @@ test tests integration_test integration_tests:
 PYTHON_FILES=.
 MYPY_CACHE=.mypy_cache
 lint format: PYTHON_FILES=.
-lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/aws --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/aws --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
 lint_package: PYTHON_FILES=langchain_aws
 lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test
35 changes: 24 additions & 11 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -35,7 +35,10 @@
 from langchain_core.tools import BaseTool
 
 from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
-from langchain_aws.llms.bedrock import BedrockBase
+from langchain_aws.llms.bedrock import (
+    BedrockBase,
+    _combine_generation_info_for_llm_result,
+)
 from langchain_aws.utils import (
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
@@ -383,7 +386,13 @@ def _stream(
             **kwargs,
         ):
             delta = chunk.text
-            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    content=delta, response_metadata=chunk.generation_info
+                )
+                if chunk.generation_info is not None
+                else AIMessageChunk(content=delta)
+            )
 
     def _generate(
         self,
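The change above attaches chunk.generation_info (when present) as response_metadata on each streamed AIMessageChunk, so token counts can be observed while the response streams. The snippet below is a minimal consumer-side sketch, illustrative only and not part of this commit; it assumes AWS credentials are configured, the model id is available in your account, and the provider reports counts under a "usage" key in its generation info.

# Illustrative only; not code from this commit.
from langchain_aws import ChatBedrock

llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0", streaming=True)

usage_totals: dict = {}
for chunk in llm.stream("Write a haiku about rivers."):
    # response_metadata is empty except on chunks where the provider
    # returned generation_info (often the final chunk).
    usage = chunk.response_metadata.get("usage") or {}
    for token_type, count in usage.items():
        usage_totals[token_type] = usage_totals.get(token_type, 0) + count
print(usage_totals)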
@@ -393,11 +402,18 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
-        llm_output: Dict[str, Any] = {"model_id": self.model_id}
-        usage_info: Dict[str, Any] = {}
+        llm_output: Dict[str, Any] = {}
+        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+            self._get_provider(), "stop_reason"
+        )
         if self.streaming:
+            response_metadata: List[Dict[str, Any]] = []
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
+                response_metadata.append(chunk.message.response_metadata)
+            llm_output = _combine_generation_info_for_llm_result(
+                response_metadata, provider_stop_reason_code
+            )
         else:
             provider = self._get_provider()
             prompt, system, formatted_messages = None, None, None
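The helper used here, _combine_generation_info_for_llm_result, is defined in langchain_aws/llms/bedrock.py and is not shown in this diff. The sketch below is a rough, hypothetical stand-in for the role it plays in this branch: folding the collected per-chunk response_metadata into a single llm_output with summed token counts and a normalized stop reason.

# Hypothetical sketch; the real helper in langchain_aws/llms/bedrock.py may differ.
from collections import defaultdict
from typing import Any, Dict, List


def combine_generation_info(
    chunks_metadata: List[Dict[str, Any]], stop_reason_key: str
) -> Dict[str, Any]:
    usage: Dict[str, int] = defaultdict(int)
    combined: Dict[str, Any] = {}
    for metadata in chunks_metadata:
        # Sum any token counts a chunk reports.
        for token_type, count in (metadata.get("usage") or {}).items():
            usage[token_type] += count
        # Keep the last stop reason seen, under a normalized key.
        if stop_reason_key in metadata:
            combined["stop_reason"] = metadata[stop_reason_key]
    combined["usage"] = dict(usage)
    return combined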
@@ -420,7 +436,7 @@
             if stop:
                 params["stop_sequences"] = stop
 
-            completion, usage_info = self._prepare_input_and_invoke(
+            completion, llm_output = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -429,14 +445,11 @@
                 **params,
             )
 
-        llm_output["usage"] = usage_info
-
+        llm_output["model_id"] = self.model_id
         return ChatResult(
             generations=[
                 ChatGeneration(
-                    message=AIMessage(
-                        content=completion, additional_kwargs={"usage": usage_info}
-                    )
+                    message=AIMessage(content=completion, additional_kwargs=llm_output)
                 )
             ],
             llm_output=llm_output,
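With usage and model_id now carried on llm_output (and mirrored into the message's additional_kwargs), a standard callback handler can total token counts per call. The handler below is a minimal sketch under that assumption, not code from this commit.

# Illustrative handler; relies only on the public callback API.
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class BedrockTokenCountHandler(BaseCallbackHandler):
    """Accumulates token usage reported in llm_output across calls."""

    def __init__(self) -> None:
        self.totals: Dict[str, int] = {}

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        usage = (response.llm_output or {}).get("usage", {})
        for token_type, count in usage.items():
            self.totals[token_type] = self.totals.get(token_type, 0) + count


# Hypothetical usage:
#   llm = ChatBedrock(model_id="...", callbacks=[BedrockTokenCountHandler()])
#   llm.invoke("Hello")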
@@ -447,7 +460,7 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.pop("usage", {})
+            usage = output.get("usage", {})
             for token_type, token_count in usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
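Switching pop to get leaves the usage entry on each per-generation output while the loop still sums the counts into final_usage. The standalone re-creation below illustrates the aggregation; the final step that stores the summed totals is assumed, not shown in this hunk.

# Standalone illustration of the aggregation loop shown above; not library code.
from collections import defaultdict
from typing import Any, Dict, List, Optional


def combine_llm_outputs(llm_outputs: List[Optional[dict]]) -> dict:
    final_usage: Dict[str, int] = defaultdict(int)
    final_output: Dict[str, Any] = {}
    for output in llm_outputs:
        output = output or {}
        usage = output.get("usage", {})  # .get(): each output keeps its own usage
        for token_type, token_count in usage.items():
            final_usage[token_type] += token_count
        final_output.update(output)
    final_output["usage"] = dict(final_usage)  # assumed final step
    return final_output


print(combine_llm_outputs([
    {"usage": {"prompt_tokens": 10, "completion_tokens": 4}},
    {"usage": {"prompt_tokens": 8, "completion_tokens": 3}},
]))
# {'usage': {'prompt_tokens': 18, 'completion_tokens': 7}}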