Skip to content

Commit

Permalink
fix: avg log probs nan values (#550)
Browse files Browse the repository at this point in the history
Co-authored-by: Leonid Kuligin <[email protected]>
  • Loading branch information
eliasecchig and lkuligin authored Oct 11, 2024
1 parent b33cc37 commit 1ad60d2
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
9 changes: 7 additions & 2 deletions libs/vertexai/langchain_google_vertexai/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utilities to init Vertex AI."""

import dataclasses
import math
import re
from enum import Enum, auto
from importlib import metadata
Expand Down Expand Up @@ -177,9 +178,13 @@ def get_generation_info(
candidate.finish_reason.name if candidate.finish_reason else None
),
}

if hasattr(candidate, "avg_logprobs") and candidate.avg_logprobs is not None:
info["avg_logprobs"] = candidate.avg_logprobs
if (
isinstance(candidate.avg_logprobs, float)
and not math.isnan(candidate.avg_logprobs)
and candidate.avg_logprobs > 0
):
info["avg_logprobs"] = candidate.avg_logprobs

try:
if candidate.grounding_metadata:
Expand Down
60 changes: 59 additions & 1 deletion libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,63 @@ def test_langgraph_example() -> None:
assert isinstance(step2, AIMessage)


@pytest.mark.asyncio
@pytest.mark.release
async def test_astream_events_langgraph_example() -> None:
    """Stream v2 events for a tool-calling chat and check the final output.

    The conversation first asks to multiply 2 and 3, then corrects itself to
    3 and 3; the model's last event must therefore carry a ``multiply``
    function call.
    """
    model = ChatVertexAI(
        model_name="gemini-1.5-flash-002",
        max_output_tokens=8192,
        temperature=0.2,
    )

    def _two_int_tool(name: str, description: str) -> dict:
        # Both declared tools share an identical two-integer parameter schema.
        return {
            "name": name,
            "description": description,
            "parameters": {
                "properties": {
                    "a": {"description": "first int", "type": "integer"},
                    "b": {"description": "second int", "type": "integer"},
                },
                "required": ["a", "b"],
                "type": "object",
            },
        }

    conversation = [
        SystemMessage(
            content=(
                "You are a helpful assistant tasked with performing "
                "arithmetic on a set of inputs."
            )
        ),
        HumanMessage(content="Multiply 2 and 3"),
        HumanMessage(content="No, actually multiply 3 and 3!"),
    ]
    event_stream = model.astream_events(
        conversation,
        tools=[
            {
                "function_declarations": [
                    _two_int_tool("add", "Adds a and b."),
                    _two_int_tool("multiply", "Multiply a and b."),
                ]
            }
        ],
        version="v2",
    )
    collected = [event async for event in event_stream]
    assert len(collected) > 0
    # The last event's payload holds the model's final answer; it should be
    # a function call naming the corrected operation.
    final_output = collected[-1]["data"]["output"]
    assert final_output.additional_kwargs["function_call"]["name"] == "multiply"


@pytest.mark.xfail(reason="can't create service account key on gcp")
@pytest.mark.release
def test_init_from_credentials_obj() -> None:
Expand All @@ -986,7 +1043,8 @@ def test_response_metadata_avg_logprobs() -> None:
llm = ChatVertexAI(model="gemini-1.5-flash")
response = llm.invoke("Hello!")
probs = response.response_metadata.get("avg_logprobs")
assert isinstance(probs, float)
if probs is not None:
assert isinstance(probs, float)


@pytest.fixture
Expand Down

0 comments on commit 1ad60d2

Please sign in to comment.