Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FunctionCall decoding #114

Merged
merged 3 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ public class Chat {
}

parts.append(part)

case .functionCall:
// TODO(andrewheard): Add function call to the chat history when encoding is implemented.
fatalError("Function calling not yet implemented in chat.")
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// A predicted function call returned from the model.
public struct FunctionCall: Equatable {
/// The name of the function to call.
let name: String

/// The function parameters and values.
let args: JSONObject
}

extension FunctionCall: Decodable {
enum CodingKeys: CodingKey {
case name
case args
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
name = try container.decode(String.self, forKey: .name)
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
self.args = args
} else {
args = JSONObject()
}
}
}
11 changes: 10 additions & 1 deletion Sources/GoogleAI/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public struct ModelContent: Codable, Equatable {
enum CodingKeys: String, CodingKey {
case text
case inlineData
case functionCall
}

enum InlineDataKeys: String, CodingKey {
Expand All @@ -38,6 +39,9 @@ public struct ModelContent: Codable, Equatable {
/// Data with a specified media type. Not all media types may be supported by the AI model.
case data(mimetype: String, Data)

/// A predicted function call returned from the model.
case functionCall(FunctionCall)

// MARK: Convenience Initializers

/// Convenience function for populating a Part with JPEG data.
Expand All @@ -64,6 +68,9 @@ public struct ModelContent: Codable, Equatable {
)
try inlineDataContainer.encode(mimetype, forKey: .mimeType)
try inlineDataContainer.encode(bytes, forKey: .bytes)
case .functionCall:
// TODO(andrewheard): Encode FunctionCalls when when encoding is implemented.
fatalError("FunctionCall encoding not implemented.")
}
}

Expand All @@ -79,10 +86,12 @@ public struct ModelContent: Codable, Equatable {
let mimetype = try dataContainer.decode(String.self, forKey: .mimeType)
let bytes = try dataContainer.decode(Data.self, forKey: .bytes)
self = .data(mimetype: mimetype, bytes)
} else if let functionCall = try? values.decode(FunctionCall.self, forKey: .functionCall) {
andrewheard marked this conversation as resolved.
Show resolved Hide resolved
self = .functionCall(functionCall)
} else {
throw DecodingError.dataCorrupted(.init(
codingPath: [CodingKeys.text, CodingKeys.inlineData],
debugDescription: "Neither text or inline data was found."
debugDescription: "No text, inline data or function call was found."
))
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "current_time"
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
]
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "current_time",
"args": {}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "sum",
"args": {
"y": 5,
"x": 4
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
]
}
67 changes: 67 additions & 0 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,73 @@ final class GenerativeModelTests: XCTestCase {
_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_success_functionCall_emptyArguments() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-function-call-empty-arguments",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
guard case let .functionCall(functionCall) = part else {
XCTFail("Part is not a FunctionCall.")
return
}
XCTAssertEqual(functionCall.name, "current_time")
XCTAssertTrue(functionCall.args.isEmpty)
}

func testGenerateContent_success_functionCall_noArguments() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-function-call-no-arguments",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
guard case let .functionCall(functionCall) = part else {
XCTFail("Part is not a FunctionCall.")
return
}
XCTAssertEqual(functionCall.name, "current_time")
XCTAssertTrue(functionCall.args.isEmpty)
}

func testGenerateContent_success_functionCall_withArguments() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-function-call-with-arguments",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
guard case let .functionCall(functionCall) = part else {
XCTFail("Part is not a FunctionCall.")
return
}
XCTAssertEqual(functionCall.name, "sum")
XCTAssertEqual(functionCall.args.count, 2)
let argX = try XCTUnwrap(functionCall.args["x"])
XCTAssertEqual(argX, .number(4))
let argY = try XCTUnwrap(functionCall.args["y"])
XCTAssertEqual(argY, .number(5))
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down
Loading