Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function calling sample in Vertex AI #12745

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseVertexAI
import GenerativeAIUIComponents
import SwiftUI

/// Chat-style screen that demonstrates function calling in Gemini using a
/// currency conversion example.
///
/// The screen lists the messages published by `FunctionCallingViewModel`, shows the
/// view model's error (if any) as the last row, and offers an input field for sending
/// a prompt or stopping the request that is currently running.
struct FunctionCallingScreen: View {
  @EnvironmentObject
  var viewModel: FunctionCallingViewModel

  /// Contents of the input field, pre-filled with an example question.
  @State
  private var userPrompt = "What is 100 Euros in U.S. Dollars?"

  enum FocusedField: Hashable {
    case message
  }

  @FocusState
  var focusedField: FocusedField?

  /// Identifier of the error row. It is attached with `.id(_:)` because that is the
  /// identifier `ScrollViewProxy.scrollTo(_:anchor:)` looks up; `.tag(_:)` only
  /// participates in selection and is not a scroll target.
  private static let errorViewID = "errorView"

  var body: some View {
    VStack {
      ScrollViewReader { scrollViewProxy in
        List {
          Text("Interact with a currency conversion API using function calling in Gemini.")
          ForEach(viewModel.messages) { message in
            MessageView(message: message)
          }
          if let error = viewModel.error {
            ErrorView(error: error)
              .id(Self.errorViewID)
          }
        }
        .listStyle(.plain)
        .onChange(of: viewModel.messages, perform: { _ in
          // Keep the newest content visible: the error row when there is an error,
          // otherwise the last message. Nothing to do for an empty, error-free list.
          if viewModel.hasError {
            scrollToBottom(using: scrollViewProxy, target: Self.errorViewID)
          } else if let lastMessage = viewModel.messages.last {
            scrollToBottom(using: scrollViewProxy, target: lastMessage.id)
          }
        })
      }
      InputField("Message...", text: $userPrompt) {
        Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill")
          .font(.title)
      }
      .focused($focusedField, equals: .message)
      .onSubmit { sendOrStop() }
    }
    .toolbar {
      ToolbarItem(placement: .primaryAction) {
        Button(action: newChat) {
          Image(systemName: "square.and.pencil")
        }
      }
    }
    .navigationTitle("Function Calling")
    .onAppear {
      focusedField = .message
    }
  }

  /// Scrolls the list to the row identified by `target` and returns focus to the
  /// input field.
  ///
  /// - Parameters:
  ///   - proxy: The proxy of the enclosing `ScrollViewReader`.
  ///   - target: Identifier of the row that should end up at the bottom.
  private func scrollToBottom<ID: Hashable>(using proxy: ScrollViewProxy, target: ID) {
    // Wait for a short moment to make sure we can actually scroll to the bottom.
    DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
      withAnimation {
        proxy.scrollTo(target, anchor: .bottom)
      }
      focusedField = .message
    }
  }

  /// Clears the input field and hands the entered prompt to the view model.
  private func sendMessage() {
    Task {
      let prompt = userPrompt
      userPrompt = ""
      await viewModel.sendMessage(prompt, streaming: true)
    }
  }

  /// Stops the running request when the view model is busy; otherwise sends the prompt.
  private func sendOrStop() {
    if viewModel.busy {
      viewModel.stop()
    } else {
      sendMessage()
    }
  }

  /// Discards the current conversation and starts a new one.
  private func newChat() {
    viewModel.startNewChat()
  }
}

/// Xcode preview for `FunctionCallingScreen`, populated with sample chat messages.
struct FunctionCallingScreen_Previews: PreviewProvider {
  /// Wraps the screen so that the shared view model can be seeded with
  /// `ChatMessage.samples` as soon as the screen appears.
  struct ContainerView: View {
    @EnvironmentObject
    var viewModel: FunctionCallingViewModel

    var body: some View {
      FunctionCallingScreen()
        .onAppear {
          viewModel.messages = ChatMessage.samples
        }
    }
  }

  static var previews: some View {
    NavigationStack {
      // Preview the container rather than the bare screen; otherwise the sample
      // messages are never loaded. The environment object injected here is visible
      // to both `ContainerView` and the `FunctionCallingScreen` it hosts.
      ContainerView().environmentObject(FunctionCallingViewModel())
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseVertexAI
import Foundation
import UIKit

/// View model for the function calling sample.
///
/// It owns a chat session with a generative model that has a `get_exchange_rate`
/// function declared as a tool, publishes the conversation as `ChatMessage` values,
/// and answers the model's function calls with locally computed results.
@MainActor
class FunctionCallingViewModel: ObservableObject {
  /// This array holds both the user's and the system's chat messages
  @Published var messages = [ChatMessage]()

  /// Indicates we're waiting for the model to finish
  @Published var busy = false

  /// The error raised by the most recent request, or `nil` when there is none.
  @Published var error: Error?
  var hasError: Bool {
    return error != nil
  }

  /// Function calls pending processing
  private var functionCalls = [FunctionCall]()

  private var model: GenerativeModel
  private var chat: Chat

  /// The task running the current request, kept so that it can be cancelled.
  private var chatTask: Task<Void, Never>?

  /// Creates the generative model with the `get_exchange_rate` tool and starts a chat.
  init() {
    model = VertexAI.vertexAI().generativeModel(
      modelName: "gemini-1.0-pro",
      tools: [Tool(functionDeclarations: [
        FunctionDeclaration(
          name: "get_exchange_rate",
          description: "Get the exchange rate for currencies between countries",
          parameters: [
            "currency_from": Schema(
              type: .string,
              format: "enum",
              description: "The currency to convert from in ISO 4217 format",
              enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
            ),
            "currency_to": Schema(
              type: .string,
              format: "enum",
              description: "The currency to convert to in ISO 4217 format",
              enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
            ),
          ],
          requiredParameters: ["currency_from", "currency_to"]
        ),
      ])]
    )
    chat = model.startChat()
  }

  /// Sends `text` to the model and appends the conversation turns to `messages`.
  ///
  /// Any request that is still running is cancelled first. The work itself runs in an
  /// unstructured task stored in `chatTask`; this method returns without awaiting it.
  /// The request is repeated for as long as the model keeps asking for function calls.
  ///
  /// - Parameters:
  ///   - text: The user's prompt.
  ///   - streaming: Whether to use the streaming API (`true`, the default) or wait for
  ///     the complete response.
  func sendMessage(_ text: String, streaming: Bool = true) async {
    error = nil
    chatTask?.cancel()
    // Function calls left over from a cancelled or failed request must not be answered
    // in place of the new prompt.
    functionCalls.removeAll()

    chatTask = Task {
      busy = true
      defer {
        busy = false
      }

      // first, add the user's message to the chat
      let userMessage = ChatMessage(message: text, participant: .user)
      messages.append(userMessage)

      // add a pending message while we're waiting for a response from the backend
      let systemMessage = ChatMessage.pending(participant: .system)
      messages.append(systemMessage)

      do {
        repeat {
          if streaming {
            try await internalSendMessageStreaming(text)
          } else {
            try await internalSendMessage(text)
          }
        } while !functionCalls.isEmpty
      } catch {
        self.error = error
        print(error.localizedDescription)
        // Remove this request's placeholder. Looking it up by id (instead of calling
        // `removeLast()`) cannot trap when `startNewChat()` emptied the list meanwhile.
        if let index = messages.lastIndex(where: { $0.id == systemMessage.id }) {
          messages.remove(at: index)
        }
        functionCalls.removeAll()
      }
    }
  }

  /// Cancels the running request and starts a fresh chat with an empty message list.
  func startNewChat() {
    stop()
    error = nil
    functionCalls.removeAll()
    chat = model.startChat()
    messages.removeAll()
  }

  /// Cancels the running request (if any) and clears the current error.
  func stop() {
    chatTask?.cancel()
    error = nil
  }

  /// Performs one streaming round trip: answers pending function calls if there are
  /// any, otherwise sends `text`, and processes every chunk of the response.
  private func internalSendMessageStreaming(_ text: String) async throws {
    let functionResponses = try await processFunctionCalls()
    let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
    if functionResponses.isEmpty {
      responseStream = chat.sendMessageStream(text)
    } else {
      for functionResponse in functionResponses {
        insertBeforeLastMessage(functionResponse.chatMessage())
      }
      responseStream = chat.sendMessageStream(functionResponses.modelContent())
    }
    for try await chunk in responseStream {
      processResponseContent(content: chunk)
    }
  }

  /// Performs one non-streaming round trip: answers pending function calls if there are
  /// any, otherwise sends `text`, and processes the complete response.
  private func internalSendMessage(_ text: String) async throws {
    let functionResponses = try await processFunctionCalls()
    let response: GenerateContentResponse
    if functionResponses.isEmpty {
      response = try await chat.sendMessage(text)
    } else {
      for functionResponse in functionResponses {
        insertBeforeLastMessage(functionResponse.chatMessage())
      }
      response = try await chat.sendMessage(functionResponses.modelContent())
    }
    processResponseContent(content: response)
  }

  /// Inserts `message` directly in front of the last message, or appends it when the
  /// list is empty (where `messages.count - 1` would be an invalid index).
  private func insertBeforeLastMessage(_ message: ChatMessage) {
    messages.insert(message, at: max(messages.count - 1, 0))
  }

  /// Applies the first candidate of `content` to the conversation.
  ///
  /// Text parts are appended to the last message, which is then no longer marked as
  /// pending. Function call parts are shown in the chat and queued in `functionCalls`.
  /// A response without candidates, or with any other kind of part, is a fatal error.
  func processResponseContent(content: GenerateContentResponse) {
    guard let candidate = content.candidates.first else {
      fatalError("No candidate.")
    }

    for part in candidate.content.parts {
      switch part {
      case let .text(text):
        // replace pending message with backend response; there is nothing to update
        // when the list was emptied while the request was in flight
        guard let lastIndex = messages.indices.last else { continue }
        messages[lastIndex].message += text
        messages[lastIndex].pending = false
      case let .functionCall(functionCall):
        insertBeforeLastMessage(functionCall.chatMessage())
        functionCalls.append(functionCall)
      case .data, .functionResponse:
        fatalError("Unsupported response content.")
      }
    }
  }

  /// Executes all queued function calls and empties the queue.
  ///
  /// A call to a function this app does not implement is answered with an `"error"`
  /// payload rather than crashing, because the function name comes from the model.
  ///
  /// - Returns: One `FunctionResponse` per queued call, in queue order.
  func processFunctionCalls() async throws -> [FunctionResponse] {
    var functionResponses = [FunctionResponse]()
    for functionCall in functionCalls {
      switch functionCall.name {
      case "get_exchange_rate":
        let exchangeRates = getExchangeRate(args: functionCall.args)
        functionResponses.append(FunctionResponse(
          name: "get_exchange_rate",
          response: exchangeRates
        ))
      default:
        functionResponses.append(FunctionResponse(
          name: functionCall.name,
          response: ["error": .string("Unknown function named \"\(functionCall.name)\".")]
        ))
      }
    }
    functionCalls = []

    return functionResponses
  }

  // MARK: - Callable Functions

  /// Looks up the exchange rate between two currencies in a hard-coded table.
  ///
  /// - Parameter args: The arguments supplied by the model; expected to contain the
  ///   string values `currency_from` and `currency_to`.
  /// - Returns: `["rates": <rate>]` on success, or `["error": <message>]` when an
  ///   argument is missing or a currency is not in the table.
  func getExchangeRate(args: JSONObject) -> JSONObject {
    // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`).
    //    The arguments are model output, so problems are reported back as an error
    //    payload instead of terminating the app.
    guard case let .string(from) = args["currency_from"] else {
      return ["error": .string("Missing `currency_from` parameter.")]
    }
    guard case let .string(to) = args["currency_to"] else {
      return ["error": .string("Missing `currency_to` parameter.")]
    }

    // 2. Get the exchange rate
    let allRates: [String: [String: Double]] = [
      "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
      "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
      "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
      "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
      "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
      "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
    ]
    guard let fromRates = allRates[from] else {
      return ["error": .string("No data for currency \(from).")]
    }
    // The table has no entry for converting a currency to itself; that rate is 1.
    if from == to {
      return ["rates": .number(1.0)]
    }
    guard let toRate = fromRates[to] else {
      return ["error": .string("No data for currency \(to).")]
    }

    // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
    return ["rates": .number(toRate)]
  }
}

private extension FunctionCall {
  /// Renders this function call as a system chat message whose text contains the
  /// call, pretty-printed as JSON inside a fenced code block.
  ///
  /// Failing to encode the call, or to read the encoded data as UTF-8, is a fatal error.
  func chatMessage() -> ChatMessage {
    let prettyEncoder = JSONEncoder()
    prettyEncoder.outputFormatting = .prettyPrinted

    do {
      let encoded = try prettyEncoder.encode(self)
      guard let prettyJSON = String(data: encoded, encoding: .utf8) else {
        fatalError("Failed to convert JSON data to a String.")
      }
      return ChatMessage(
        message: "Function call requested by model:\n```\n\(prettyJSON)\n```",
        participant: .system
      )
    } catch {
      fatalError("JSON Encoding Failed: \(error.localizedDescription)")
    }
  }
}

private extension FunctionResponse {
  /// Renders this function response as a user chat message whose text contains the
  /// response, pretty-printed as JSON inside a fenced code block.
  ///
  /// Failing to encode the response, or to read the encoded data as UTF-8, is a fatal
  /// error.
  func chatMessage() -> ChatMessage {
    let prettyEncoder = JSONEncoder()
    prettyEncoder.outputFormatting = .prettyPrinted

    do {
      let encoded = try prettyEncoder.encode(self)
      guard let prettyJSON = String(data: encoded, encoding: .utf8) else {
        fatalError("Failed to convert JSON data to a String.")
      }
      return ChatMessage(
        message: "Function response returned by app:\n```\n\(prettyJSON)\n```",
        participant: .user
      )
    } catch {
      fatalError("JSON Encoding Failed: \(error.localizedDescription)")
    }
  }
}

private extension [FunctionResponse] {
  /// Wraps every function response in its own `ModelContent` with the role
  /// `"function"`, preserving the order of the array.
  func modelContent() -> [ModelContent] {
    var contents = [ModelContent]()
    for response in self {
      let content = ModelContent(
        role: "function",
        parts: [ModelContent.Part.functionResponse(response)]
      )
      contents.append(content)
    }
    return contents
  }
}
Loading
Loading