Skip to content

Commit

Permalink
[Vertex AI] Refactor BlockReason as a struct and add new values (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Oct 9, 2024
1 parent 27cffd9 commit 4b263b6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
2 changes: 2 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
filter. (#13863)
- [added] Added new `FinishReason` values `.blocklist`, `.prohibitedContent`,
`.spii` and `.malformedFunctionCall` that may be reported. (#13860)
- [added] Added new `BlockReason` values `.blocklist` and `.prohibitedContent`
that may be reported when a prompt is blocked. (#13861)

# 11.3.0
- [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)
Expand Down
55 changes: 33 additions & 22 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,43 @@ public struct FinishReason: DecodableProtoEnum, Hashable, Sendable {
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct PromptFeedback: Sendable {
/// A type describing possible reasons to block a prompt.
public enum BlockReason: String, Sendable {
/// The block reason is unknown.
case unknown = "UNKNOWN"
public struct BlockReason: DecodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case safety = "SAFETY"
case other = "OTHER"
case blocklist = "BLOCKLIST"
case prohibitedContent = "PROHIBITED_CONTENT"
}

/// The prompt was blocked because it was deemed unsafe.
case safety = "SAFETY"
public static var safety: BlockReason {
return self.init(kind: .safety)
}

/// All other block reasons.
case other = "OTHER"
public static var other: BlockReason {
return self.init(kind: .other)
}

/// The prompt was blocked because it contained terms from the terminology blocklist.
public static var blocklist: BlockReason {
return self.init(kind: .blocklist)
}

/// The prompt was blocked due to prohibited content.
public static var prohibitedContent: BlockReason {
return self.init(kind: .prohibitedContent)
}

/// Returns the raw string representation of the `BlockReason` value.
///
/// > Note: This value directly corresponds to the values in the [REST
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#BlockedReason).
public let rawValue: String

var unrecognizedValueMessageCode: VertexLog.MessageCode {
.generateContentResponseUnrecognizedBlockReason
}
}

/// The reason a prompt was blocked, if it was blocked.
Expand Down Expand Up @@ -383,23 +411,6 @@ extension Citation: Decodable {
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension PromptFeedback.BlockReason: Decodable {
public init(from decoder: Decoder) throws {
let value = try decoder.singleValueContainer().decode(String.self)
guard let decodedBlockReason = PromptFeedback.BlockReason(rawValue: value) else {
VertexLog.error(
code: .generateContentResponseUnrecognizedBlockReason,
"Unrecognized BlockReason with value \"\(value)\"."
)
self = .unknown
return
}

self = decodedBlockReason
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension PromptFeedback: Decodable {
enum CodingKeys: CodingKey {
Expand Down
3 changes: 2 additions & 1 deletion FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -627,13 +627,14 @@ final class GenerativeModelTests: XCTestCase {
forResource: "unary-failure-unknown-enum-prompt-blocked",
withExtension: "json"
)
let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GenerateContentError.promptBlocked(response) {
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertEqual(promptFeedback.blockReason, .unknown)
XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
} catch {
XCTFail("Should throw a promptBlocked")
}
Expand Down

0 comments on commit 4b263b6

Please sign in to comment.