Skip to content

Commit

Permalink
Add responseSchema to GenerationConfig (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored May 28, 2024
1 parent edc9de3 commit d8b1fbb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
11 changes: 10 additions & 1 deletion Sources/GoogleAI/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ public struct GenerationConfig {
/// - `application/json`: JSON response in the candidates.
public let responseMIMEType: String?

/// Output response schema of the generated candidate text.
///
/// - Note: This only applies when the specified ``responseMIMEType`` supports a schema; currently
/// this is limited to `application/json`.
public let responseSchema: Schema?

/// Creates a new `GenerationConfig` value.
///
/// - Parameters:
Expand All @@ -80,9 +86,11 @@ public struct GenerationConfig {
/// - maxOutputTokens: See ``maxOutputTokens``.
/// - stopSequences: See ``stopSequences``.
/// - responseMIMEType: See ``responseMIMEType``.
/// - responseSchema: See ``responseSchema``.
public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil,
candidateCount: Int? = nil, maxOutputTokens: Int? = nil,
stopSequences: [String]? = nil, responseMIMEType: String? = nil) {
stopSequences: [String]? = nil, responseMIMEType: String? = nil,
responseSchema: Schema? = nil) {
// Explicit init because otherwise if we re-arrange the above variables it changes the API
// surface.
self.temperature = temperature
Expand All @@ -92,6 +100,7 @@ public struct GenerationConfig {
self.maxOutputTokens = maxOutputTokens
self.stopSequences = stopSequences
self.responseMIMEType = responseMIMEType
self.responseSchema = responseSchema
}
}

Expand Down
42 changes: 39 additions & 3 deletions Tests/GoogleAITests/GenerationConfigTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,20 @@ final class GenerationConfigTests: XCTestCase {
let candidateCount = 2
let maxOutputTokens = 256
let stopSequences = ["END", "DONE"]
let responseMIMEType = "text/plain"
let responseMIMEType = "application/json"
let schemaType = DataType.object
let fieldName = "test-field"
let fieldType = DataType.string
let responseSchema = Schema(type: schemaType, properties: [fieldName: Schema(type: fieldType)])
let generationConfig = GenerationConfig(
temperature: temperature,
topP: topP,
topK: topK,
candidateCount: candidateCount,
maxOutputTokens: maxOutputTokens,
stopSequences: stopSequences,
responseMIMEType: responseMIMEType
responseMIMEType: responseMIMEType,
responseSchema: responseSchema
)

let jsonData = try encoder.encode(generationConfig)
Expand All @@ -67,6 +72,14 @@ final class GenerationConfigTests: XCTestCase {
"candidateCount" : \(candidateCount),
"maxOutputTokens" : \(maxOutputTokens),
"responseMIMEType" : "\(responseMIMEType)",
"responseSchema" : {
"properties" : {
"\(fieldName)" : {
"type" : "\(fieldType.rawValue)"
}
},
"type" : "\(schemaType.rawValue)"
},
"stopSequences" : [
"END",
"DONE"
Expand All @@ -79,7 +92,7 @@ final class GenerationConfigTests: XCTestCase {
}

func testEncodeGenerationConfig_responseMIMEType() throws {
let mimeType = "image/jpeg"
let mimeType = "text/plain"
let generationConfig = GenerationConfig(responseMIMEType: mimeType)

let jsonData = try encoder.encode(generationConfig)
Expand All @@ -91,4 +104,27 @@ final class GenerationConfigTests: XCTestCase {
}
""")
}

func testEncodeGenerationConfig_responseMIMETypeWithSchema() throws {
let mimeType = "application/json"
let schemaType = DataType.array
let arrayItemType = DataType.integer
let schema = Schema(type: schemaType, items: Schema(type: arrayItemType))
let generationConfig = GenerationConfig(responseMIMEType: mimeType, responseSchema: schema)

let jsonData = try encoder.encode(generationConfig)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"responseMIMEType" : "\(mimeType)",
"responseSchema" : {
"items" : {
"type" : "\(arrayItemType.rawValue)"
},
"type" : "\(schemaType.rawValue)"
}
}
""")
}
}

0 comments on commit d8b1fbb

Please sign in to comment.