diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 531999f55..f3db610e1 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -72,6 +72,27 @@ "FunctionLibraryType", ] +Mode = protos.DynamicRetrievalConfig.Mode + +ModeOptions = Union[int, str, Mode] + +_MODE: dict[ModeOptions, Mode] = { + Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED, + 0: Mode.MODE_UNSPECIFIED, + "mode_unspecified": Mode.MODE_UNSPECIFIED, + "unspecified": Mode.MODE_UNSPECIFIED, + Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC, + 1: Mode.MODE_DYNAMIC, + "mode_dynamic": Mode.MODE_DYNAMIC, + "dynamic": Mode.MODE_DYNAMIC, +} + + +def to_mode(x: ModeOptions) -> Mode: + if isinstance(x, str): + x = x.lower() + return _MODE[x] + def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: # If the image is a local file, return a file-based blob without any modification. @@ -644,16 +665,54 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F return fd.to_proto() +class DynamicRetrievalConfigDict(TypedDict): + mode: protos.DynamicRetrievalConfig.mode + dynamic_threshold: float + + +DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict] + + +class GoogleSearchRetrievalDict(TypedDict): + dynamic_retrieval_config: DynamicRetrievalConfig + + +GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict] + + +def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): + if isinstance(gsr, protos.GoogleSearchRetrieval): + return gsr + elif isinstance(gsr, Mapping): + drc = gsr.get("dynamic_retrieval_config", None) + if drc is not None and isinstance(drc, Mapping): + mode = drc.get("mode", None) + if mode is not None: + mode = to_mode(mode) + gsr = gsr.copy() + gsr["dynamic_retrieval_config"]["mode"] = mode + return protos.GoogleSearchRetrieval(gsr) + else: + raise TypeError( + "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n" + f"However, received an object of type: {type(gsr)}.\n" + f"Object Value: {gsr}" + ) + + class Tool: - """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects, + protos.CodeExecution object, and protos.GoogleSearchRetrieval object.""" def __init__( self, + *, function_declarations: Iterable[FunctionDeclarationType] | None = None, + google_search_retrieval: GoogleSearchRetrievalType | None = None, code_execution: protos.CodeExecution | None = None, ): # The main path doesn't use this but is seems useful. - if function_declarations: + if function_declarations is not None: self._function_declarations = [ _make_function_declaration(f) for f in function_declarations ] @@ -668,8 +727,14 @@ def __init__( self._function_declarations = [] self._index = {} + if google_search_retrieval is not None: + self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval) + else: + self._google_search_retrieval = None + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations], + google_search_retrieval=google_search_retrieval, code_execution=code_execution, ) @@ -677,6 +742,10 @@ def __init__( def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations + @property + def google_search_retrieval(self) -> protos.GoogleSearchRetrieval: + return self._google_search_retrieval + @property def code_execution(self) -> protos.CodeExecution: return self._proto.code_execution @@ -705,7 +774,7 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] @@ -717,9 +786,23 @@ def _make_tool(tool: ToolType) -> Tool: code_execution = tool.code_execution else: code_execution = None - return Tool(function_declarations=tool.function_declarations, code_execution=code_execution) + + if "google_search_retrieval" in tool: + google_search_retrieval = tool.google_search_retrieval + else: + google_search_retrieval = None + + return Tool( + function_declarations=tool.function_declarations, + google_search_retrieval=google_search_retrieval, + code_execution=code_execution, + ) elif isinstance(tool, dict): - if "function_declarations" in tool or "code_execution" in tool: + if ( + "function_declarations" in tool + or "google_search_retrieval" in tool + or "code_execution" in tool + ): return Tool(**tool) else: fd = tool @@ -727,10 +810,17 @@ def _make_tool(tool: ToolType) -> Tool: elif isinstance(tool, str): if tool.lower() == "code_execution": return Tool(code_execution=protos.CodeExecution()) + # Check to see if one of the mode enums matches + elif tool.lower() == "google_search_retrieval": + return Tool(google_search_retrieval=protos.GoogleSearchRetrieval()) else: - raise ValueError("The only string that can be passed as a tool is 'code_execution'.") + raise ValueError( + "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval." + ) elif isinstance(tool, protos.CodeExecution): return Tool(code_execution=tool) + elif isinstance(tool, protos.GoogleSearchRetrieval): + return Tool(google_search_retrieval=tool) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -786,7 +876,7 @@ def to_proto(self): def _make_tools(tools: ToolsType) -> list[Tool]: if isinstance(tools, str): - if tools.lower() == "code_execution": + if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval": return [_make_tool(tools)] else: raise ValueError("The only string that can be passed as a tool is 'code_execution'.") diff --git a/tests/test_content.py b/tests/test_content.py index 52e78f349..2031e40ae 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -435,12 +435,78 @@ def no_args(): ["empty_dictionary_list", [{"code_execution": {}}]], ) def test_code_execution(self, tools): - if isinstance(tools, Iterable): - t = content_types._make_tools(tools) - self.assertIsInstance(t[0].code_execution, protos.CodeExecution) - else: - t = content_types._make_tool(tools) # Pass code execution into tools - self.assertIsInstance(t.code_execution, protos.CodeExecution) + t = content_types._make_tools(tools) + self.assertIsInstance(t[0].code_execution, protos.CodeExecution) + + @parameterized.named_parameters( + ["string", "google_search_retrieval"], + ["empty_dictionary", {"google_search_retrieval": {}}], + [ + "empty_dictionary_with_dynamic_retrieval_config", + {"google_search_retrieval": {"dynamic_retrieval_config": {}}}, + ], + [ + "dictionary_with_mode_integer", + {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}}, + ], + [ + "dictionary_with_mode_string", + {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": "DYNAMIC"}}}, + ], + [ + "dictionary_with_dynamic_retrieval_config", + { + "google_search_retrieval": { + "dynamic_retrieval_config": {"mode": "unspecified", "dynamic_threshold": 0.5} + } + }, + ], + [ + "proto_object", + protos.GoogleSearchRetrieval( + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) + ), + ], + [ + "proto_passed_in", + protos.Tool( + google_search_retrieval=protos.GoogleSearchRetrieval( + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) + ) + ), + ], + [ + "proto_object_list", + [ + protos.GoogleSearchRetrieval( + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) + ) + ], + ], + [ + "proto_passed_in_list", + [ + protos.Tool( + google_search_retrieval=protos.GoogleSearchRetrieval( + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) + ) + ) + ], + ], + ) + def test_search_grounding(self, tools): + if self._testMethodName == "test_search_grounding_empty_dictionary": + pass + t = content_types._make_tools(tools) + self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval) def test_two_fun_is_one_tool(self): def a():