From 1e954e9fd3918007175f47dd649395f528d08857 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 9 Oct 2024 20:26:15 -0400 Subject: [PATCH] [Vertex AI] Refactored `FunctionCallingConfig.Mode` as a `struct` (#13864) --- .../Sources/FunctionCalling.swift | 58 +++++++++++++------ .../Tests/Integration/IntegrationTests.swift | 5 +- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/FirebaseVertexAI/Sources/FunctionCalling.swift b/FirebaseVertexAI/Sources/FunctionCalling.swift index 69924f3cc4b..9f62eff253e 100644 --- a/FirebaseVertexAI/Sources/FunctionCalling.swift +++ b/FirebaseVertexAI/Sources/FunctionCalling.swift @@ -73,33 +73,53 @@ public struct Tool { /// Configuration for specifying function calling behavior. public struct FunctionCallingConfig { - /// Defines the execution behavior for function calling by defining the - /// execution mode. - public enum Mode: String { - /// The default behavior for function calling. The model calls functions to answer queries at - /// its discretion. - case auto = "AUTO" + /// Defines the execution behavior for function calling by defining the execution mode. + public struct Mode: EncodableProtoEnum { + enum Kind: String { + case auto = "AUTO" + case any = "ANY" + case none = "NONE" + } + + /// The default behavior for function calling. + /// + /// The model calls functions to answer queries at its discretion. + public static var auto: Mode { + return self.init(kind: .auto) + } /// The model always predicts a provided function call to answer every query. - case any = "ANY" - - /// The model will never predict a function call to answer a query. This can also be achieved by - /// not passing any tools to the model. - case none = "NONE" + public static var any: Mode { + return self.init(kind: .any) + } + + /// The model will never predict a function call to answer a query. + /// + /// > Note: This can also be achieved by not passing any ``FunctionDeclaration`` tools + /// > when instantiating the model. + public static var none: Mode { + return self.init(kind: .none) + } + + let rawValue: String } - /// Specifies the mode in which function calling should execute. If - /// unspecified, the default value will be set to AUTO. + /// Specifies the mode in which function calling should execute. let mode: Mode? - /// A set of function names that, when provided, limits the functions the model - /// will call. - /// - /// This should only be set when the Mode is ANY. Function names - /// should match [FunctionDeclaration.name]. With mode set to ANY, model will - /// predict a function call from the set of function names provided. + /// A set of function names that, when provided, limits the functions the model will call. let allowedFunctionNames: [String]? + /// Creates a new `FunctionCallingConfig`. + /// + /// - Parameters: + /// - mode: Specifies the mode in which function calling should execute; if unspecified, the + /// default behavior will be ``Mode/auto``. + /// - allowedFunctionNames: A set of function names that, when provided, limits the functions + /// the model will call. + /// Note: This should only be set when the ``Mode`` is ``Mode/any``. Function names should match + /// `[FunctionDeclaration.name]`. With mode set to ``Mode/any``, the model will predict a + /// function call from the set of function names provided. public init(mode: FunctionCallingConfig.Mode? = nil, allowedFunctionNames: [String]? = nil) { self.mode = mode self.allowedFunctionNames = allowedFunctionNames diff --git a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift index 51241c915c2..80b6f1ab528 100644 --- a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift @@ -59,6 +59,7 @@ final class IntegrationTests: XCTestCase { generationConfig: generationConfig, safetySettings: safetySettings, tools: [], + toolConfig: .init(functionCallingConfig: .init(mode: FunctionCallingConfig.Mode.none)), systemInstruction: systemInstruction ) } @@ -94,6 +95,7 @@ final class IntegrationTests: XCTestCase { SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone), SafetySetting(harmCategory: .civicIntegrity, threshold: .off), ], + toolConfig: .init(functionCallingConfig: .init(mode: .auto)), systemInstruction: systemInstruction ) @@ -135,7 +137,8 @@ final class IntegrationTests: XCTestCase { ) model = vertex.generativeModel( modelName: "gemini-1.5-flash", - tools: [Tool(functionDeclarations: [sumDeclaration])] + tools: [Tool(functionDeclarations: [sumDeclaration])], + toolConfig: .init(functionCallingConfig: .init(mode: .any, allowedFunctionNames: ["sum"])) ) let prompt = "What is 10 + 32?" let sumCall = FunctionCallPart(name: "sum", args: ["x": .number(10), "y": .number(32)])