Skip to content

Commit

Permalink
Search grounding (#558)
Browse files Browse the repository at this point in the history
* Updated tests and current progress on adding search grounding.

* Update google/generativeai/types/content_types.py

Co-authored-by: Mark Daoust <[email protected]>

* Update tests/test_content.py

Co-authored-by: Mark Daoust <[email protected]>

* Update search grounding

* update content_types

* Update and add aditional test cases

* update test case on empty_dictionary_with_dynamic_retrieval_config

* Update test cases and _make_search_grounding

* fix tests

Change-Id: Ib9e19d78861da180f713e09ec93d366d5d7b5762

* Remove print statement

* Fix tuned model tests

Change-Id: I5ace9222954be7d903ebbdabab9efc663fa79174

* Fix tests

Change-Id: Ifa610965c5d6c38123080a7e16416ac325418285

* format

Change-Id: Iab48a9400d53f3cbdc5ca49c73df4f6a186a867b

* fix typing

Change-Id: If892b20ca29d1afb82c48ae1a49bef58e0421bab

* Format

Change-Id: I51a51150879adb3d4b6b00323e0d8eaf4c0b2515

---------

Co-authored-by: Mark Daoust <[email protected]>
  • Loading branch information
shilpakancharla and MarkDaoust authored Sep 24, 2024
1 parent 6c8dad1 commit d5103eb
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 13 deletions.
104 changes: 97 additions & 7 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@
"FunctionLibraryType",
]

Mode = protos.DynamicRetrievalConfig.Mode

ModeOptions = Union[int, str, Mode]

_MODE: dict[ModeOptions, Mode] = {
Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
0: Mode.MODE_UNSPECIFIED,
"mode_unspecified": Mode.MODE_UNSPECIFIED,
"unspecified": Mode.MODE_UNSPECIFIED,
Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC,
1: Mode.MODE_DYNAMIC,
"mode_dynamic": Mode.MODE_DYNAMIC,
"dynamic": Mode.MODE_DYNAMIC,
}


def to_mode(x: ModeOptions) -> Mode:
if isinstance(x, str):
x = x.lower()
return _MODE[x]


def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
# If the image is a local file, return a file-based blob without any modification.
Expand Down Expand Up @@ -644,16 +665,54 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
return fd.to_proto()


class DynamicRetrievalConfigDict(TypedDict):
mode: protos.DynamicRetrievalConfig.mode
dynamic_threshold: float


DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict]


class GoogleSearchRetrievalDict(TypedDict):
dynamic_retrieval_config: DynamicRetrievalConfig


GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict]


def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
if isinstance(gsr, protos.GoogleSearchRetrieval):
return gsr
elif isinstance(gsr, Mapping):
drc = gsr.get("dynamic_retrieval_config", None)
if drc is not None and isinstance(drc, Mapping):
mode = drc.get("mode", None)
if mode is not None:
mode = to_mode(mode)
gsr = gsr.copy()
gsr["dynamic_retrieval_config"]["mode"] = mode
return protos.GoogleSearchRetrieval(gsr)
else:
raise TypeError(
"Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
f"However, received an object of type: {type(gsr)}.\n"
f"Object Value: {gsr}"
)


class Tool:
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""

def __init__(
self,
*,
function_declarations: Iterable[FunctionDeclarationType] | None = None,
google_search_retrieval: GoogleSearchRetrievalType | None = None,
code_execution: protos.CodeExecution | None = None,
):
# The main path doesn't use this but is seems useful.
if function_declarations:
if function_declarations is not None:
self._function_declarations = [
_make_function_declaration(f) for f in function_declarations
]
Expand All @@ -668,15 +727,25 @@ def __init__(
self._function_declarations = []
self._index = {}

if google_search_retrieval is not None:
self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)
else:
self._google_search_retrieval = None

self._proto = protos.Tool(
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
google_search_retrieval=google_search_retrieval,
code_execution=code_execution,
)

@property
def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
return self._function_declarations

@property
def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
return self._google_search_retrieval

@property
def code_execution(self) -> protos.CodeExecution:
return self._proto.code_execution
Expand Down Expand Up @@ -705,7 +774,7 @@ class ToolDict(TypedDict):


ToolType = Union[
Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
]


Expand All @@ -717,20 +786,41 @@ def _make_tool(tool: ToolType) -> Tool:
code_execution = tool.code_execution
else:
code_execution = None
return Tool(function_declarations=tool.function_declarations, code_execution=code_execution)

if "google_search_retrieval" in tool:
google_search_retrieval = tool.google_search_retrieval
else:
google_search_retrieval = None

return Tool(
function_declarations=tool.function_declarations,
google_search_retrieval=google_search_retrieval,
code_execution=code_execution,
)
elif isinstance(tool, dict):
if "function_declarations" in tool or "code_execution" in tool:
if (
"function_declarations" in tool
or "google_search_retrieval" in tool
or "code_execution" in tool
):
return Tool(**tool)
else:
fd = tool
return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
elif isinstance(tool, str):
if tool.lower() == "code_execution":
return Tool(code_execution=protos.CodeExecution())
# Check to see if one of the mode enums matches
elif tool.lower() == "google_search_retrieval":
return Tool(google_search_retrieval=protos.GoogleSearchRetrieval())
else:
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
raise ValueError(
"The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
)
elif isinstance(tool, protos.CodeExecution):
return Tool(code_execution=tool)
elif isinstance(tool, protos.GoogleSearchRetrieval):
return Tool(google_search_retrieval=tool)
elif isinstance(tool, Iterable):
return Tool(function_declarations=tool)
else:
Expand Down Expand Up @@ -786,7 +876,7 @@ def to_proto(self):

def _make_tools(tools: ToolsType) -> list[Tool]:
if isinstance(tools, str):
if tools.lower() == "code_execution":
if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval":
return [_make_tool(tools)]
else:
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
Expand Down
78 changes: 72 additions & 6 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,78 @@ def no_args():
["empty_dictionary_list", [{"code_execution": {}}]],
)
def test_code_execution(self, tools):
if isinstance(tools, Iterable):
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
else:
t = content_types._make_tool(tools) # Pass code execution into tools
self.assertIsInstance(t.code_execution, protos.CodeExecution)
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)

@parameterized.named_parameters(
["string", "google_search_retrieval"],
["empty_dictionary", {"google_search_retrieval": {}}],
[
"empty_dictionary_with_dynamic_retrieval_config",
{"google_search_retrieval": {"dynamic_retrieval_config": {}}},
],
[
"dictionary_with_mode_integer",
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}},
],
[
"dictionary_with_mode_string",
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": "DYNAMIC"}}},
],
[
"dictionary_with_dynamic_retrieval_config",
{
"google_search_retrieval": {
"dynamic_retrieval_config": {"mode": "unspecified", "dynamic_threshold": 0.5}
}
},
],
[
"proto_object",
protos.GoogleSearchRetrieval(
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
)
),
],
[
"proto_passed_in",
protos.Tool(
google_search_retrieval=protos.GoogleSearchRetrieval(
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
)
)
),
],
[
"proto_object_list",
[
protos.GoogleSearchRetrieval(
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
)
)
],
],
[
"proto_passed_in_list",
[
protos.Tool(
google_search_retrieval=protos.GoogleSearchRetrieval(
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
)
)
)
],
],
)
def test_search_grounding(self, tools):
if self._testMethodName == "test_search_grounding_empty_dictionary":
pass
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)

def test_two_fun_is_one_tool(self):
def a():
Expand Down

0 comments on commit d5103eb

Please sign in to comment.