Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search grounding #558

Merged
merged 17 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 97 additions & 7 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@
"FunctionLibraryType",
]

# Short alias for the proto enum that selects the dynamic-retrieval mode.
Mode = protos.DynamicRetrievalConfig.Mode

# Everything accepted where a mode is expected: the enum member, an int, or a string.
ModeOptions = Union[int, str, Mode]

# Lookup table from each accepted spelling of a mode to its enum member.
# String keys are lowercase because `to_mode` lowercases string input before
# looking it up here.
_MODE: dict[ModeOptions, Mode] = {
    Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
    0: Mode.MODE_UNSPECIFIED,
    "mode_unspecified": Mode.MODE_UNSPECIFIED,
    "unspecified": Mode.MODE_UNSPECIFIED,
    Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC,
    1: Mode.MODE_DYNAMIC,
    "mode_dynamic": Mode.MODE_DYNAMIC,
    "dynamic": Mode.MODE_DYNAMIC,
}


def to_mode(x: ModeOptions) -> Mode:
    """Normalize a mode given as an enum member, int, or string to a `Mode`.

    Strings are matched case-insensitively; every other value is looked up
    as-is. A value with no entry in `_MODE` raises `KeyError`.
    """
    key = x.lower() if isinstance(x, str) else x
    return _MODE[key]


def pil_to_blob(img):
# When you load an image with PIL you get a subclass of PIL.Image
Expand Down Expand Up @@ -650,16 +671,54 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
return fd.to_proto()


class DynamicRetrievalConfigDict(TypedDict):
    """Dictionary form of `protos.DynamicRetrievalConfig`."""

    # Any spelling `to_mode` understands (enum member, int, or string);
    # `_make_google_search_retrieval` normalizes it before building the proto.
    mode: ModeOptions
    # Passed through unchanged to the proto's `dynamic_threshold` field.
    dynamic_threshold: float


# A dynamic-retrieval config is either the proto message or its dictionary form.
DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict]


class GoogleSearchRetrievalDict(TypedDict):
    """Dictionary form of `protos.GoogleSearchRetrieval`."""

    # Nested retrieval settings, given as a proto message or as a dictionary.
    dynamic_retrieval_config: DynamicRetrievalConfig


# Input accepted by `_make_google_search_retrieval`: the proto message or its dictionary form.
GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict]


def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
    """Convert `gsr` into a `protos.GoogleSearchRetrieval`.

    Args:
        gsr: Either a `protos.GoogleSearchRetrieval`, which is returned as-is,
            or a mapping shaped like `GoogleSearchRetrievalDict`. In the
            mapping form, `dynamic_retrieval_config["mode"]` may be any value
            accepted by `to_mode` (enum member, int, or case-insensitive
            string).

    Returns:
        The `protos.GoogleSearchRetrieval` built from (or identical to) `gsr`.

    Raises:
        TypeError: If `gsr` is neither a `protos.GoogleSearchRetrieval` nor a
            mapping.
        KeyError: If a mode is supplied that `to_mode` does not recognize.
    """
    if isinstance(gsr, protos.GoogleSearchRetrieval):
        return gsr

    if not isinstance(gsr, Mapping):
        raise TypeError(
            "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
            f"However, received an object of type: {type(gsr)}.\n"
            f"Object Value: {gsr}"
        )

    drc = gsr.get("dynamic_retrieval_config", None)
    if isinstance(drc, Mapping):
        mode = drc.get("mode", None)
        if mode is not None:
            # Build fresh dicts at both levels instead of assigning into the
            # input: a shallow copy of the outer mapping would still share the
            # nested config, so writing the normalized mode into it would
            # silently modify the caller's data. This also works for
            # read-only mappings that have no `.copy()` or item assignment.
            normalized_drc = {**drc, "mode": to_mode(mode)}
            gsr = {**gsr, "dynamic_retrieval_config": normalized_drc}

    return protos.GoogleSearchRetrieval(gsr)


class Tool:
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""

def __init__(
shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved
self,
shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved
*,
function_declarations: Iterable[FunctionDeclarationType] | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that only one of these is set?

google_search_retrieval: GoogleSearchRetrievalType | None = None,
code_execution: protos.CodeExecution | None = None,
):
# The main path doesn't use this but it seems useful.
if function_declarations:
if function_declarations is not None:
self._function_declarations = [
_make_function_declaration(f) for f in function_declarations
]
Expand All @@ -674,15 +733,25 @@ def __init__(
self._function_declarations = []
self._index = {}

if google_search_retrieval is not None:
self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)
else:
self._google_search_retrieval = None

self._proto = protos.Tool(
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
google_search_retrieval=google_search_retrieval,
code_execution=code_execution,
)

@property
def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
return self._function_declarations

@property
def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
return self._google_search_retrieval

@property
def code_execution(self) -> protos.CodeExecution:
return self._proto.code_execution
Expand Down Expand Up @@ -711,7 +780,7 @@ class ToolDict(TypedDict):


ToolType = Union[
Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
]


Expand All @@ -723,20 +792,41 @@ def _make_tool(tool: ToolType) -> Tool:
code_execution = tool.code_execution
else:
code_execution = None
shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved
return Tool(function_declarations=tool.function_declarations, code_execution=code_execution)

if "google_search_retrieval" in tool:
google_search_retrieval = tool.google_search_retrieval
else:
google_search_retrieval = None

return Tool(
function_declarations=tool.function_declarations,
google_search_retrieval=google_search_retrieval,
code_execution=code_execution,
)
elif isinstance(tool, dict):
shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved
if "function_declarations" in tool or "code_execution" in tool:
if (
"function_declarations" in tool
or "google_search_retrieval" in tool
or "code_execution" in tool
):
return Tool(**tool)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a comment to explain this else?

fd = tool
return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
elif isinstance(tool, str):
if tool.lower() == "code_execution":
shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved
return Tool(code_execution=protos.CodeExecution())
# Check to see if one of the mode enums matches
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want this block; people shouldn't be passing Mode strings as a tool.

model.generate_content(tools="Dynamic") ??

shilpakancharla marked this conversation as resolved.
Show resolved Hide resolved
elif tool.lower() == "google_search_retrieval":
return Tool(google_search_retrieval=protos.GoogleSearchRetrieval())
else:
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
raise ValueError(
"The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
)
elif isinstance(tool, protos.CodeExecution):
return Tool(code_execution=tool)
elif isinstance(tool, protos.GoogleSearchRetrieval):
return Tool(google_search_retrieval=tool)
elif isinstance(tool, Iterable):
return Tool(function_declarations=tool)
else:
Expand Down Expand Up @@ -792,7 +882,7 @@ def to_proto(self):

def _make_tools(tools: ToolsType) -> list[Tool]:
if isinstance(tools, str):
if tools.lower() == "code_execution":
if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval":
return [_make_tool(tools)]
else:
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
Expand Down
78 changes: 72 additions & 6 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,78 @@ def no_args():
["empty_dictionary_list", [{"code_execution": {}}]],
)
def test_code_execution(self, tools):
if isinstance(tools, Iterable):
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
else:
t = content_types._make_tool(tools) # Pass code execution into tools
self.assertIsInstance(t.code_execution, protos.CodeExecution)
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)

# Each case is a different way of requesting the Google Search retrieval tool:
# the bare string, dictionaries with progressively more of the dynamic
# retrieval config filled in (mode as int, as mixed-case string, with a
# threshold), and proto objects passed directly or wrapped in `protos.Tool`,
# alone or inside a list.
@parameterized.named_parameters(
    ["string", "google_search_retrieval"],
    ["empty_dictionary", {"google_search_retrieval": {}}],
    [
        "empty_dictionary_with_dynamic_retrieval_config",
        {"google_search_retrieval": {"dynamic_retrieval_config": {}}},
    ],
    [
        "dictionary_with_mode_integer",
        {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}},
    ],
    [
        "dictionary_with_mode_string",
        {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": "DYNAMIC"}}},
    ],
    [
        "dictionary_with_dynamic_retrieval_config",
        {
            "google_search_retrieval": {
                "dynamic_retrieval_config": {"mode": "unspecified", "dynamic_threshold": 0.5}
            }
        },
    ],
    [
        "proto_object",
        protos.GoogleSearchRetrieval(
            dynamic_retrieval_config=protos.DynamicRetrievalConfig(
                mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
            )
        ),
    ],
    [
        "proto_passed_in",
        protos.Tool(
            google_search_retrieval=protos.GoogleSearchRetrieval(
                dynamic_retrieval_config=protos.DynamicRetrievalConfig(
                    mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
                )
            )
        ),
    ],
    [
        "proto_object_list",
        [
            protos.GoogleSearchRetrieval(
                dynamic_retrieval_config=protos.DynamicRetrievalConfig(
                    mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
                )
            )
        ],
    ],
    [
        "proto_passed_in_list",
        [
            protos.Tool(
                google_search_retrieval=protos.GoogleSearchRetrieval(
                    dynamic_retrieval_config=protos.DynamicRetrievalConfig(
                        mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
                    )
                )
            )
        ],
    ],
)
def test_search_grounding(self, tools):
    # Whatever form the input takes, `_make_tools` must return a list whose
    # first tool exposes a `protos.GoogleSearchRetrieval`.
    t = content_types._make_tools(tools)
    self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)

def test_two_fun_is_one_tool(self):
def a():
Expand Down
Loading