Skip to content

Commit

Permalink
langchain-google-vertexai: Update vision_models.py to allow kwargs us…
Browse files Browse the repository at this point in the history
…age (#473)
  • Loading branch information
TommasoPetrolito authored Oct 3, 2024
1 parent 3219420 commit 9ebd8c1
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions libs/vertexai/langchain_google_vertexai/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation
from pydantic import BaseModel, ConfigDict, Field
from vertexai.preview.vision_models import ( # type: ignore[import-untyped]
from vertexai.vision_models import ( # type: ignore[import-untyped]
GeneratedImage,
Image,
ImageGenerationModel,
ImageTextModel,
)
from vertexai.vision_models import Image, ImageTextModel # type: ignore[import-untyped]

from langchain_google_vertexai._image_utils import (
ImageBytesLoader,
Expand Down Expand Up @@ -97,7 +98,7 @@ def _default_params(self) -> Dict[str, Any]:
def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
params = self._default_params
for key, value in kwargs.items():
if key in params and value is not None:
if value is not None:
params[key] = value
return params

Expand All @@ -110,6 +111,7 @@ def _get_captions(
image: Image,
number_of_results: Optional[int] = None,
language: Optional[str] = None,
**kwargs,
) -> List[str]:
"""Uses the sdk methods to generate a list of captions.
Expand All @@ -123,7 +125,7 @@ def _get_captions(
"""
with telemetry.tool_context_manager(self._user_agent):
params = self._prepare_params(
number_of_results=number_of_results, language=language
number_of_results=number_of_results, language=language, **kwargs
)
captions = self.client.get_captions(image=image, **params)
return captions
Expand Down Expand Up @@ -224,7 +226,7 @@ def _generate(
"{'type': 'image_url', 'image_url': {'image': <image_str>}}"
)

captions = self._get_captions(image, **messages[0].additional_kwargs)
captions = self._get_captions(image, **messages[0].additional_kwargs, **kwargs)

generations = [
ChatGeneration(message=AIMessage(content=caption)) for caption in captions
Expand Down Expand Up @@ -287,7 +289,7 @@ def _generate(
)

answers = self._ask_questions(
image=image, query=user_question, **messages[0].additional_kwargs
image=image, query=user_question, **messages[0].additional_kwargs, **kwargs
)

generations = [
Expand Down Expand Up @@ -361,7 +363,7 @@ def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
mapping = {"number_of_results": "number_of_images"}
for key, value in kwargs.items():
key = mapping.get(key, key)
if key in params and value is not None:
if value is not None:
params[key] = value
return {k: v for k, v in params.items() if v is not None}

Expand Down Expand Up @@ -477,7 +479,7 @@ def _generate(
)

image_str_list = self._generate_images(
prompt=user_query, **messages[0].additional_kwargs
prompt=user_query, **messages[0].additional_kwargs, **kwargs
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
Expand Down Expand Up @@ -526,7 +528,10 @@ def _generate(
)

image_str_list = self._edit_images(
image_str=image_str, prompt=user_query, **messages[0].additional_kwargs
image_str=image_str,
prompt=user_query,
**messages[0].additional_kwargs,
**kwargs,
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
Expand Down

0 comments on commit 9ebd8c1

Please sign in to comment.