From 4b263b6ce5f30fc642c3f6511027d797ec036469 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 9 Oct 2024 18:53:17 -0400 Subject: [PATCH] [Vertex AI] Refactor `BlockReason` as a `struct` and add new values (#13861) --- FirebaseVertexAI/CHANGELOG.md | 2 + .../Sources/GenerateContentResponse.swift | 55 +++++++++++-------- .../Tests/Unit/GenerativeModelTests.swift | 3 +- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index b6444fb81f0..04ae2573ecd 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -52,6 +52,8 @@ filter. (#13863) - [added] Added new `FinishReason` values `.blocklist`, `.prohibitedContent`, `.spii` and `.malformedFunctionCall` that may be reported. (#13860) +- [added] Added new `BlockReason` values `.blocklist` and `.prohibitedContent` + that may be reported when a prompt is blocked. (#13861) # 11.3.0 - [added] Added `Decodable` conformance for `FunctionResponse`. (#13606) diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift index 079b759d26f..94b45cb45b1 100644 --- a/FirebaseVertexAI/Sources/GenerateContentResponse.swift +++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift @@ -221,15 +221,43 @@ public struct FinishReason: DecodableProtoEnum, Hashable, Sendable { @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct PromptFeedback: Sendable { /// A type describing possible reasons to block a prompt. - public enum BlockReason: String, Sendable { - /// The block reason is unknown. - case unknown = "UNKNOWN" + public struct BlockReason: DecodableProtoEnum, Hashable, Sendable { + enum Kind: String { + case safety = "SAFETY" + case other = "OTHER" + case blocklist = "BLOCKLIST" + case prohibitedContent = "PROHIBITED_CONTENT" + } /// The prompt was blocked because it was deemed unsafe. - case safety = "SAFETY" + public static var safety: BlockReason { + return self.init(kind: .safety) + } /// All other block reasons. - case other = "OTHER" + public static var other: BlockReason { + return self.init(kind: .other) + } + + /// The prompt was blocked because it contained terms from the terminology blocklist. + public static var blocklist: BlockReason { + return self.init(kind: .blocklist) + } + + /// The prompt was blocked due to prohibited content. + public static var prohibitedContent: BlockReason { + return self.init(kind: .prohibitedContent) + } + + /// Returns the raw string representation of the `BlockReason` value. + /// + /// > Note: This value directly corresponds to the values in the [REST + /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#BlockedReason). + public let rawValue: String + + var unrecognizedValueMessageCode: VertexLog.MessageCode { + .generateContentResponseUnrecognizedBlockReason + } } /// The reason a prompt was blocked, if it was blocked. @@ -383,23 +411,6 @@ extension Citation: Decodable { } } -@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension PromptFeedback.BlockReason: Decodable { - public init(from decoder: Decoder) throws { - let value = try decoder.singleValueContainer().decode(String.self) - guard let decodedBlockReason = PromptFeedback.BlockReason(rawValue: value) else { - VertexLog.error( - code: .generateContentResponseUnrecognizedBlockReason, - "Unrecognized BlockReason with value \"\(value)\"." - ) - self = .unknown - return - } - - self = decodedBlockReason - } -} - @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) extension PromptFeedback: Decodable { enum CodingKeys: CodingKey { diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index ae0f55361fa..5474a4675ef 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -627,13 +627,14 @@ final class GenerativeModelTests: XCTestCase { forResource: "unary-failure-unknown-enum-prompt-blocked", withExtension: "json" ) + let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON") do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw") } catch let GenerateContentError.promptBlocked(response) { let promptFeedback = try XCTUnwrap(response.promptFeedback) - XCTAssertEqual(promptFeedback.blockReason, .unknown) + XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason) } catch { XCTFail("Should throw a promptBlocked") }