From 610807f002a677c0dc31b8e1f4d4a5a718b455b3 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 14 Oct 2024 17:26:58 -0400 Subject: [PATCH 1/2] [Vertex AI] Add `blockReasonMessage` to `PromptFeedback` --- .../Sources/GenerateContentResponse.swift | 9 +++- .../Tests/Unit/GenerativeModelTests.swift | 49 ++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift index 07491765f90..d786676ec80 100644 --- a/FirebaseVertexAI/Sources/GenerateContentResponse.swift +++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift @@ -235,12 +235,17 @@ public struct PromptFeedback: Sendable { /// The reason a prompt was blocked, if it was blocked. public let blockReason: BlockReason? + /// A human-readable description of the ``blockReason``. + public let blockReasonMessage: String? + /// The safety ratings of the prompt. public let safetyRatings: [SafetyRating] /// Initializer for SwiftUI previews or tests. - public init(blockReason: BlockReason?, safetyRatings: [SafetyRating]) { + public init(blockReason: BlockReason?, blockReasonMessage: String? = nil, + safetyRatings: [SafetyRating]) { self.blockReason = blockReason + self.blockReasonMessage = blockReasonMessage self.safetyRatings = safetyRatings } } @@ -387,6 +392,7 @@ extension Citation: Decodable { extension PromptFeedback: Decodable { enum CodingKeys: CodingKey { case blockReason + case blockReasonMessage case safetyRatings } @@ -396,6 +402,7 @@ extension PromptFeedback: Decodable { PromptFeedback.BlockReason.self, forKey: .blockReason ) + blockReasonMessage = try container.decodeIfPresent(String.self, forKey: .blockReasonMessage) if let safetyRatings = try container.decodeIfPresent( [SafetyRating].self, forKey: .safetyRatings diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index 5ffa94daf64..7b6b1c15336 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -619,6 +619,29 @@ final class GenerativeModelTests: XCTestCase { XCTFail("Should throw") } catch let GenerateContentError.promptBlocked(response) { XCTAssertNil(response.text) + let promptFeedback = try XCTUnwrap(response.promptFeedback) + XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety) + XCTAssertNil(promptFeedback.blockReasonMessage) + } catch { + XCTFail("Should throw a promptBlocked") + } + } + + func testGenerateContent_failure_promptBlockedSafetyWithMessage() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-failure-prompt-blocked-safety-with-message", + withExtension: "json" + ) + + do { + _ = try await model.generateContent(testPrompt) + XCTFail("Should throw") + } catch let GenerateContentError.promptBlocked(response) { + XCTAssertNil(response.text) + let promptFeedback = try XCTUnwrap(response.promptFeedback) + XCTAssertEqual(promptFeedback.blockReason, PromptFeedback.BlockReason.safety) + XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons") } catch { XCTFail("Should throw a promptBlocked") } @@ -909,7 +932,31 @@ final class GenerativeModelTests: XCTestCase { XCTFail("Content shouldn't be shown, this shouldn't happen.") } } catch let GenerateContentError.promptBlocked(response) { - XCTAssertEqual(response.promptFeedback?.blockReason, .safety) + let promptFeedback = try XCTUnwrap(response.promptFeedback) + XCTAssertEqual(promptFeedback.blockReason, .safety) + XCTAssertNil(promptFeedback.blockReasonMessage) + return + } + + XCTFail("Should have caught an error.") + } + + func testGenerateContentStream_failurePromptBlockedSafetyWithMessage() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-failure-prompt-blocked-safety-with-message", + withExtension: "txt" + ) + + do { + let stream = try model.generateContentStream("Hi") + for try await _ in stream { + XCTFail("Content shouldn't be shown, this shouldn't happen.") + } + } catch let GenerateContentError.promptBlocked(response) { + let promptFeedback = try XCTUnwrap(response.promptFeedback) + XCTAssertEqual(promptFeedback.blockReason, .safety) + XCTAssertEqual(promptFeedback.blockReasonMessage, "Reasons") return } From 8f9cf2d0b80e0428a7d8c6732e2ef2b9926c42ee Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 14 Oct 2024 17:32:00 -0400 Subject: [PATCH 2/2] Add changelog entry --- FirebaseVertexAI/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index c3b6b5462ef..d0d95a82070 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -70,6 +70,8 @@ `.spii` and `.malformedFunctionCall` that may be reported. (#13860) - [added] Added new `BlockReason` values `.blocklist` and `.prohibitedContent` that may be reported when a prompt is blocked. (#13861) +- [added] Added the `PromptFeedback` property `blockReasonMessage` that *may* be + provided alongside the `blockReason`. (#13891) # 11.3.0 - [added] Added `Decodable` conformance for `FunctionResponse`. (#13606)