From aec906fe6c5590e4be34c169f3b9a89ead3143a5 Mon Sep 17 00:00:00 2001 From: Jeff Park Date: Thu, 14 Nov 2024 12:20:32 -0500 Subject: [PATCH 1/3] refactor: simplify content processing in anthropic formatter --- .../langchain_google_vertexai/model_garden.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 7afa8071..12c93e17 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -205,14 +205,18 @@ def _format_params( def _format_output(self, data: Any, **kwargs: Any) -> ChatResult: data_dict = data.model_dump() - content = [c for c in data_dict["content"] if c["type"] != "tool_use"] - content = content[0]["text"] if len(content) == 1 else content + content = data_dict["content"] llm_output = { k: v for k, v in data_dict.items() if k not in ("content", "role", "type") } - tool_calls = _extract_tool_calls(data_dict["content"]) - if tool_calls: - msg = AIMessage(content=content, tool_calls=tool_calls) + if len(content) == 1 and content[0]["type"] == "text": + msg = AIMessage(content=content[0]["text"]) + elif any(block["type"] == "tool_use" for block in content): + tool_calls = _extract_tool_calls(content) + msg = AIMessage( + content=content, + tool_calls=tool_calls, + ) else: msg = AIMessage(content=content) # Collect token usage From b9ac975ceb297ec4321630ebf5c6417e5be54f85 Mon Sep 17 00:00:00 2001 From: Jeff Park Date: Fri, 15 Nov 2024 21:06:25 -0500 Subject: [PATCH 2/3] chore: adding test and extract_tool_call for updated anthropic format options --- .../_anthropic_parsers.py | 19 +++++--- .../langchain_google_vertexai/model_garden.py | 9 +++- .../tests/unit_tests/test_chat_models.py | 47 +++++++++++++++++++ 3 files changed, 67 insertions(+), 8 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py b/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py index a31759bf..d979f20e 100644 --- a/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py +++ b/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Type, Union from langchain_core.messages import AIMessage, ToolCall from langchain_core.messages.tool import tool_call @@ -55,11 +55,18 @@ def _pydantic_parse(self, tool_call: dict) -> BaseModel: return cls_(**tool_call["args"]) -def _extract_tool_calls(content: List[dict]) -> List[ToolCall]: - tool_calls = [] - for block in content: - if block["type"] == "tool_use": +def _extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]: + """Extract tool calls from a list of content blocks.""" + if isinstance(content, list): + tool_calls = [] + for block in content: + if isinstance(block, str): + continue + if block["type"] != "tool_use": + continue tool_calls.append( tool_call(name=block["name"], args=block["input"], id=block["id"]) ) - return tool_calls + return tool_calls + else: + return [] diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 12c93e17..ab553f6e 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -153,15 +153,20 @@ def validate_environment(self) -> Self: AsyncAnthropicVertex, ) + if self.project is None: + raise ValueError("project is required for ChatAnthropicVertex") + + project_id: str = self.project + self.client = AnthropicVertex( - project_id=self.project, + project_id=project_id, region=self.location, max_retries=self.max_retries, access_token=self.access_token, credentials=self.credentials, ) self.async_client = AsyncAnthropicVertex( - project_id=self.project, + project_id=project_id, region=self.location, max_retries=self.max_retries, access_token=self.access_token, diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index e9da74ea..06567e71 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -41,6 +41,7 @@ _parse_examples, _parse_response_candidate, ) +from langchain_google_vertexai.model_garden import ChatAnthropicVertex def test_init() -> None: @@ -1067,3 +1068,49 @@ def test_init_client_with_custom_api() -> None: transport = mock_prediction_service.call_args.kwargs["transport"] assert client_options.api_endpoint == "https://example.com" assert transport == "rest" + + +def test_anthropic_format_output() -> None: + """Test format output handles different content structures correctly.""" + + @dataclass + class Usage: + input_tokens: int + output_tokens: int + + @dataclass + class Message: + def model_dump(self): + return { + "content": [ + { + "type": "tool_use", + "id": "123", + "name": "calculator", + "input": {"number": 42}, + } + ], + "model": "baz", + "role": "assistant", + "usage": Usage(input_tokens=2, output_tokens=1), + "type": "message", + } + + usage: Usage + + test_msg = Message(usage=Usage(input_tokens=2, output_tokens=1)) + + model = ChatAnthropicVertex(project="test-project", location="test-location") + result = model._format_output(test_msg) + + message = result.generations[0].message + assert isinstance(message, AIMessage) + assert message.content == test_msg.model_dump()["content"] + assert len(message.tool_calls) == 1 + assert message.tool_calls[0]["name"] == "calculator" + assert message.tool_calls[0]["args"] == {"number": 42} + assert message.usage_metadata == { + "input_tokens": 2, + "output_tokens": 1, + "total_tokens": 3, + } From c90f5d8f0a1847cccf7be78a4a31a149473a94cc Mon Sep 17 00:00:00 2001 From: Jeff Park Date: Fri, 15 Nov 2024 22:12:57 -0500 Subject: [PATCH 3/3] chore: fix test_anthropic_tool_calling integration test --- libs/vertexai/tests/integration_tests/test_model_garden.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/vertexai/tests/integration_tests/test_model_garden.py b/libs/vertexai/tests/integration_tests/test_model_garden.py index a0946a6c..6b90d591 100644 --- a/libs/vertexai/tests/integration_tests/test_model_garden.py +++ b/libs/vertexai/tests/integration_tests/test_model_garden.py @@ -168,7 +168,7 @@ async def test_anthropic_async() -> None: def _check_tool_calls(response: BaseMessage, expected_name: str) -> None: """Check tool calls are as expected.""" assert isinstance(response, AIMessage) - assert isinstance(response.content, str) + assert isinstance(response.content, list) tool_calls = response.tool_calls assert len(tool_calls) == 1 tool_call = tool_calls[0]