Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix default_metadata for files API. #548

Merged
merged 7 commits into from
Sep 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import os
import contextlib
import inspect
import dataclasses
import pathlib
import types
from typing import Any, cast
from collections.abc import Sequence
import httplib2
Expand All @@ -30,6 +30,21 @@
__version__ = "0.0.0"

USER_AGENT = "genai-py"

#### Caution! ####
# - It would make sense for the discovery URL to respect the client_options.endpoint setting.
# - That would make testing Files on the staging server possible.
# - We tried fixing this once, but broke colab in the process because their endpoint didn't forward the discovery
# requests. https://github.com/google-gemini/generative-ai-python/pull/333
# - Kaggle would have a similar problem (b/362278209).
# - I think their proxy would forward the discovery traffic.
# - But they don't need to intercept the files-service at all, and uploads of large files could overload them.
# - Do the scotty uploads go to the same domain?
# - If you do route the discovery call to kaggle, be sure to attach the default_metadata (they need it).
# - One solution to all this would be if configure could take overrides per service.
# - set client_options.endpoint, but use a different endpoint for file service? It's not clear how best to do that
# through the file service.
##################
GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest"


Expand All @@ -50,7 +65,7 @@ def __init__(self, *args, **kwargs):
self._discovery_api = None
super().__init__(*args, **kwargs)

def _setup_discovery_api(self):
def _setup_discovery_api(self, metadata: dict | Sequence[tuple[str, str]] = ()):
api_key = self._client_options.api_key
if api_key is None:
raise ValueError(
Expand All @@ -61,6 +76,7 @@ def _setup_discovery_api(self):
http=httplib2.Http(),
postproc=lambda resp, content: (resp, content),
uri=f"{GENAI_API_DISCOVERY_URL}?version=v1beta&key={api_key}",
headers=dict(metadata),
)
response, content = request.execute()
request.http.close()
Expand All @@ -78,9 +94,10 @@ def create_file(
name: str | None = None,
display_name: str | None = None,
resumable: bool = True,
metadata: Sequence[tuple[str, str]] = (),
) -> protos.File:
if self._discovery_api is None:
self._setup_discovery_api()
self._setup_discovery_api(metadata)

file = {}
if name is not None:
Expand All @@ -92,6 +109,8 @@ def create_file(
filename=path, mimetype=mime_type, resumable=resumable
)
request = self._discovery_api.media().upload(body={"file": file}, media_body=media)
for key, value in metadata:
request.headers[key] = value
result = request.execute()

return self.get_file({"name": result["file"]["name"]})
Expand Down Expand Up @@ -226,16 +245,14 @@ def make_client(self, name):
def keep(name, f):
if name.startswith("_"):
return False
elif name == "create_file":
return False
elif not isinstance(f, types.FunctionType):
return False
elif isinstance(f, classmethod):

if not callable(f):
return False
elif isinstance(f, staticmethod):

if "metadata" not in inspect.signature(f).parameters.keys():
return False
else:
return True

return True

def add_default_metadata_wrapper(f):
def call(*args, metadata=(), **kwargs):
Expand All @@ -244,7 +261,7 @@ def call(*args, metadata=(), **kwargs):

return call

for name, value in cls.__dict__.items():
for name, value in inspect.getmembers(cls):
if not keep(name, value):
continue
f = getattr(client, name)
Expand Down
Loading