Skip to content

Commit 0a880cc

Browse files
authored
[Ai] Add workaround for invalid SafetyRating from the backend. (#6925)
Due to a bug in the backend, it's possible that we receive an invalid `SafetyRating` value, without either category or probability. We return null in those cases to enable filtering by the higher level types.
1 parent 4b12b33 commit 0a880cc

File tree

4 files changed

+48
-12
lines changed

4 files changed

+48
-12
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt

+19-11
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ internal constructor(
5151
val groundingMetadata: GroundingMetadata? = null,
5252
) {
5353
internal fun toPublic(): Candidate {
54-
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
54+
val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty()
5555
val citations = citationMetadata?.toPublic()
5656
val finishReason = finishReason?.toPublic()
5757

@@ -120,23 +120,31 @@ internal constructor(
120120
internal data class Internal
121121
@JvmOverloads
122122
constructor(
123-
val category: HarmCategory.Internal,
124-
val probability: HarmProbability.Internal,
123+
val category: HarmCategory.Internal? = null,
124+
val probability: HarmProbability.Internal? = null,
125125
val blocked: Boolean? = null, // TODO(): any reason not to default to false?
126126
val probabilityScore: Float? = null,
127127
val severity: HarmSeverity.Internal? = null,
128128
val severityScore: Float? = null,
129129
) {
130130

131131
internal fun toPublic() =
132-
SafetyRating(
133-
category = category.toPublic(),
134-
probability = probability.toPublic(),
135-
probabilityScore = probabilityScore ?: 0f,
136-
blocked = blocked,
137-
severity = severity?.toPublic(),
138-
severityScore = severityScore
139-
)
132+
// Due to a bug in the backend, it's possible that we receive
133+
// an invalid `SafetyRating` value, without either category or
134+
// probability. We return null in those cases to enable
135+
// filtering by the higher level types.
136+
if (category == null || probability == null) {
137+
null
138+
} else {
139+
SafetyRating(
140+
category = category.toPublic(),
141+
probability = probability.toPublic(),
142+
probabilityScore = probabilityScore ?: 0f,
143+
blocked = blocked,
144+
severity = severity?.toPublic(),
145+
severityScore = severityScore
146+
)
147+
}
140148
}
141149
}
142150

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/PromptFeedback.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class PromptFeedback(
4242
) {
4343

4444
internal fun toPublic(): PromptFeedback {
45-
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
45+
val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty()
4646
return PromptFeedback(blockReason?.toPublic(), safetyRatings, blockReasonMessage)
4747
}
4848
}

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt

+15
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ import kotlinx.coroutines.flow.collect
3636
import kotlinx.coroutines.flow.toList
3737
import kotlinx.coroutines.withTimeout
3838
import org.junit.Test
39+
import org.junit.runner.RunWith
40+
import org.robolectric.RobolectricTestRunner
3941

42+
@RunWith(RobolectricTestRunner::class)
4043
internal class VertexAIStreamingSnapshotTests {
4144
private val testTimeout = 5.seconds
4245

@@ -85,6 +88,18 @@ internal class VertexAIStreamingSnapshotTests {
8588
}
8689
}
8790

91+
@Test
92+
fun `invalid safety ratings during image generation`() =
93+
goldenVertexStreamingFile("streaming-success-image-invalid-safety-ratings.txt") {
94+
val responses = model.generateContentStream("prompt")
95+
96+
withTimeout(testTimeout) {
97+
val responseList = responses.toList()
98+
99+
responseList.isEmpty() shouldBe false
100+
}
101+
}
102+
88103
@Test
89104
fun `unknown enum in finish reason`() =
90105
goldenVertexStreamingFile("streaming-failure-unknown-finish-enum.txt") {

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt

+13
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ import kotlinx.serialization.json.jsonObject
5555
import kotlinx.serialization.json.jsonPrimitive
5656
import org.json.JSONArray
5757
import org.junit.Test
58+
import org.junit.runner.RunWith
59+
import org.robolectric.RobolectricTestRunner
5860

5961
@OptIn(PublicPreviewAPI::class)
62+
@RunWith(RobolectricTestRunner::class)
6063
internal class VertexAIUnarySnapshotTests {
6164
private val testTimeout = 5.seconds
6265

@@ -125,6 +128,16 @@ internal class VertexAIUnarySnapshotTests {
125128
}
126129
}
127130

131+
@Test
132+
fun `invalid safety ratings during image generation`() =
133+
goldenVertexUnaryFile("unary-success-image-invalid-safety-ratings.json") {
134+
withTimeout(testTimeout) {
135+
val response = model.generateContent("prompt")
136+
137+
response.candidates.isEmpty() shouldBe false
138+
}
139+
}
140+
128141
@Test
129142
fun `unknown enum in finish reason`() =
130143
goldenVertexUnaryFile("unary-failure-unknown-enum-finish-reason.json") {

0 commit comments

Comments
 (0)