Skip to content

Commit

Permalink
Add support for blockReasonMessage (#6268)
Browse files Browse the repository at this point in the history
b/367673161
  • Loading branch information
rlazo committed Sep 18, 2024
1 parent 48d460c commit d153670
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ internal object FinishReasonSerializer :
internal data class PromptFeedback(
val blockReason: BlockReason? = null,
val safetyRatings: List<SafetyRating>? = null,
val blockReasonMessage: String? = null,
)

@Serializable(BlockReasonSerializer::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ internal fun com.google.firebase.vertexai.common.server.PromptFeedback.toPublic(
return com.google.firebase.vertexai.type.PromptFeedback(
blockReason?.toPublic(),
safetyRatings,
blockReasonMessage
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ package com.google.firebase.vertexai.type
*
* @param blockReason The reason that content was blocked, if at all.
* @param safetyRatings A list of relevant [SafetyRating].
* @param blockReasonMessage A message describing the reason that content was blocked, if any.
*/
class PromptFeedback(
val blockReason: BlockReason?,
val safetyRatings: List<SafetyRating>,
val blockReasonMessage: String?
)

/** Describes why content was blocked. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ internal class StreamingSnapshotTests {
}
}

@Test
fun `prompt blocked for safety with message`() =
goldenStreamingFile("streaming-failure-prompt-blocked-safety-with-message.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val exception = shouldThrow<PromptBlockedException> { responses.collect() }
exception.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY
exception.response.promptFeedback?.blockReasonMessage shouldBe "Reasons"
}
}

@Test
fun `empty content`() =
goldenStreamingFile("streaming-failure-empty-content.txt") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,18 @@ internal class UnarySnapshotTests {
}
}

@Test
fun `prompt blocked for safety with message`() =
goldenUnaryFile("unary-failure-prompt-blocked-safety-with-message.json") {
withTimeout(testTimeout) {
shouldThrow<PromptBlockedException> { model.generateContent("prompt") } should
{
it.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY
it.response.promptFeedback?.blockReasonMessage shouldContain "Reasons"
}
}
}

@Test
fun `empty content`() =
goldenUnaryFile("unary-failure-empty-content.json") {
Expand Down

0 comments on commit d153670

Please sign in to comment.