Skip to content

Commit

Permalink
[Vertex AI] Refactored FunctionCallingConfig.Mode as a struct (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Oct 10, 2024
1 parent 4b263b6 commit 1e954e9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
58 changes: 39 additions & 19 deletions FirebaseVertexAI/Sources/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,33 +73,53 @@ public struct Tool {

/// Configuration for specifying function calling behavior.
public struct FunctionCallingConfig {
/// Defines the execution behavior for function calling by defining the
/// execution mode.
public enum Mode: String {
/// The default behavior for function calling. The model calls functions to answer queries at
/// its discretion.
case auto = "AUTO"
/// Defines the execution behavior for function calling by defining the execution mode.
public struct Mode: EncodableProtoEnum {
enum Kind: String {
case auto = "AUTO"
case any = "ANY"
case none = "NONE"
}

/// The default behavior for function calling.
///
/// The model calls functions to answer queries at its discretion.
public static var auto: Mode {
return self.init(kind: .auto)
}

/// The model always predicts a provided function call to answer every query.
case any = "ANY"

/// The model will never predict a function call to answer a query. This can also be achieved by
/// not passing any tools to the model.
case none = "NONE"
public static var any: Mode {
return self.init(kind: .any)
}

/// The model will never predict a function call to answer a query.
///
/// > Note: This can also be achieved by not passing any ``FunctionDeclaration`` tools
/// > when instantiating the model.
public static var none: Mode {
return self.init(kind: .none)
}

let rawValue: String
}

/// Specifies the mode in which function calling should execute. If
/// unspecified, the default value will be set to AUTO.
/// Specifies the mode in which function calling should execute.
let mode: Mode?

/// A set of function names that, when provided, limits the functions the model
/// will call.
///
/// This should only be set when the Mode is ANY. Function names
/// should match [FunctionDeclaration.name]. With mode set to ANY, model will
/// predict a function call from the set of function names provided.
/// A set of function names that, when provided, limits the functions the model will call.
let allowedFunctionNames: [String]?

/// Creates a new `FunctionCallingConfig`.
///
/// - Parameters:
/// - mode: Specifies the mode in which function calling should execute; if unspecified, the
/// default behavior will be ``Mode/auto``.
/// - allowedFunctionNames: A set of function names that, when provided, limits the functions
/// the model will call.
/// Note: This should only be set when the ``Mode`` is ``Mode/any``. Function names should match
/// `[FunctionDeclaration.name]`. With mode set to ``Mode/any``, the model will predict a
/// function call from the set of function names provided.
public init(mode: FunctionCallingConfig.Mode? = nil, allowedFunctionNames: [String]? = nil) {
self.mode = mode
self.allowedFunctionNames = allowedFunctionNames
Expand Down
5 changes: 4 additions & 1 deletion FirebaseVertexAI/Tests/Integration/IntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ final class IntegrationTests: XCTestCase {
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: [],
toolConfig: .init(functionCallingConfig: .init(mode: FunctionCallingConfig.Mode.none)),
systemInstruction: systemInstruction
)
}
Expand Down Expand Up @@ -94,6 +95,7 @@ final class IntegrationTests: XCTestCase {
SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone),
SafetySetting(harmCategory: .civicIntegrity, threshold: .off),
],
toolConfig: .init(functionCallingConfig: .init(mode: .auto)),
systemInstruction: systemInstruction
)

Expand Down Expand Up @@ -135,7 +137,8 @@ final class IntegrationTests: XCTestCase {
)
model = vertex.generativeModel(
modelName: "gemini-1.5-flash",
tools: [Tool(functionDeclarations: [sumDeclaration])]
tools: [Tool(functionDeclarations: [sumDeclaration])],
toolConfig: .init(functionCallingConfig: .init(mode: .any, allowedFunctionNames: ["sum"]))
)
let prompt = "What is 10 + 32?"
let sumCall = FunctionCallPart(name: "sum", args: ["x": .number(10), "y": .number(32)])
Expand Down

0 comments on commit 1e954e9

Please sign in to comment.