diff --git a/FirebaseVertexAI/Sources/Safety.swift b/FirebaseVertexAI/Sources/Safety.swift
index d810613aecb..aa6264a3455 100644
--- a/FirebaseVertexAI/Sources/Safety.swift
+++ b/FirebaseVertexAI/Sources/Safety.swift
@@ -108,9 +108,24 @@ public struct SafetySetting {
     let rawValue: String
   }
 
+  /// The method of computing whether the ``SafetySetting/HarmBlockThreshold`` has been exceeded.
+  public struct HarmBlockMethod: EncodableProtoEnum, Sendable {
+    enum Kind: String {
+      case severity = "SEVERITY"
+      case probability = "PROBABILITY"
+    }
+
+    public static let severity = HarmBlockMethod(kind: .severity)
+
+    public static let probability = HarmBlockMethod(kind: .probability)
+
+    let rawValue: String
+  }
+
   enum CodingKeys: String, CodingKey {
     case harmCategory = "category"
     case threshold
+    case method
   }
 
   /// The category this safety setting should be applied to.
@@ -119,10 +134,14 @@ public struct SafetySetting {
   /// The threshold describing what content should be blocked.
   public let threshold: HarmBlockThreshold
 
+  public let method: HarmBlockMethod?
+
   /// Initializes a new safety setting with the given category and threshold.
-  public init(harmCategory: HarmCategory, threshold: HarmBlockThreshold) {
+  public init(harmCategory: HarmCategory, threshold: HarmBlockThreshold,
+              method: HarmBlockMethod? = nil) {
     self.harmCategory = harmCategory
     self.threshold = threshold
+    self.method = method
   }
 }
 
diff --git a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift
index fee87108da7..b884a41e9a5 100644
--- a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift
+++ b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift
@@ -30,8 +30,8 @@ final class IntegrationTests: XCTestCase {
     parts: "You are a friendly and helpful assistant."
   )
   let safetySettings = [
-    SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
-    SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
+    SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove, method: .probability),
+    SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove, method: .severity),
     SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
     SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
     SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
@@ -89,11 +89,11 @@ final class IntegrationTests: XCTestCase {
       modelName: "gemini-1.5-pro",
       generationConfig: generationConfig,
       safetySettings: [
-        SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
+        SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove, method: .severity),
         SafetySetting(harmCategory: .hateSpeech, threshold: .blockMediumAndAbove),
         SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockOnlyHigh),
         SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone),
-        SafetySetting(harmCategory: .civicIntegrity, threshold: .off),
+        SafetySetting(harmCategory: .civicIntegrity, threshold: .off, method: .probability),
       ],
       toolConfig: .init(functionCallingConfig: .auto()),
       systemInstruction: systemInstruction
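
A minimal usage sketch for reviewers (not part of the patch), using only the API shown in the diff above; the `import FirebaseVertexAI` module name is inferred from the source paths:

```swift
import FirebaseVertexAI

// Sketch only: `method` defaults to nil, so existing SafetySetting call sites
// keep compiling; each category can opt into a different block method.
let safetySettings = [
  // Compare the harassment threshold against probability scores.
  SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove, method: .probability),
  // Compare the hate speech threshold against severity scores.
  SafetySetting(harmCategory: .hateSpeech, threshold: .blockMediumAndAbove, method: .severity),
  // No method given: `method` stays nil, matching behavior before this change.
  SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh),
]
```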