Skip to content

Commit

Permalink
Add function calling sample in Vertex AI (#12745)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Apr 10, 2024
1 parent 110f391 commit 4c07757
Show file tree
Hide file tree
Showing 4 changed files with 432 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseVertexAI
import GenerativeAIUIComponents
import SwiftUI

/// Chat screen that demonstrates function calling with a currency-conversion tool.
///
/// The conversation (`messages`), the busy flag and the latest error are read from
/// the `FunctionCallingViewModel` environment object. The screen keeps the newest
/// row (or the error row, when there is one) scrolled into view and keeps keyboard
/// focus on the message field.
struct FunctionCallingScreen: View {
  @EnvironmentObject
  var viewModel: FunctionCallingViewModel

  /// Text in the input field, pre-filled with an example prompt.
  @State
  private var userPrompt = "What is 100 Euros in U.S. Dollars?"

  enum FocusedField: Hashable {
    case message
  }

  @FocusState
  var focusedField: FocusedField?

  /// Identity of the error row, shared by the row itself and the scroll request.
  private static let errorViewID = "errorView"

  var body: some View {
    VStack {
      ScrollViewReader { scrollViewProxy in
        List {
          Text("Interact with a currency conversion API using function calling in Gemini.")
          ForEach(viewModel.messages) { message in
            MessageView(message: message)
          }
          if let error = viewModel.error {
            ErrorView(error: error)
              // `ScrollViewProxy.scrollTo` looks rows up by their `id`; a `tag`
              // is not consulted, so the row must be identified with `.id`.
              .id(Self.errorViewID)
          }
        }
        .listStyle(.plain)
        .onChange(of: viewModel.messages, perform: { _ in
          if viewModel.hasError {
            scrollToBottom(Self.errorViewID, using: scrollViewProxy)
          } else if let lastMessage = viewModel.messages.last {
            scrollToBottom(lastMessage.id, using: scrollViewProxy)
          }
        })
      }
      InputField("Message...", text: $userPrompt) {
        Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill")
          .font(.title)
      }
      .focused($focusedField, equals: .message)
      .onSubmit { sendOrStop() }
    }
    .toolbar {
      ToolbarItem(placement: .primaryAction) {
        Button(action: newChat) {
          Image(systemName: "square.and.pencil")
        }
      }
    }
    .navigationTitle("Function Calling")
    .onAppear {
      focusedField = .message
    }
  }

  /// Scrolls the row identified by `id` to the bottom of the list and returns
  /// focus to the message field.
  ///
  /// - Parameters:
  ///   - id: Identity of the row to reveal.
  ///   - proxy: Proxy of the enclosing `ScrollViewReader`.
  private func scrollToBottom<ID: Hashable>(_ id: ID, using proxy: ScrollViewProxy) {
    // Wait for a short moment to make sure we can actually scroll to the bottom.
    DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
      withAnimation {
        proxy.scrollTo(id, anchor: .bottom)
      }
      focusedField = .message
    }
  }

  /// Sends the current prompt as a streaming request and clears the input field.
  private func sendMessage() {
    Task {
      let prompt = userPrompt
      userPrompt = ""
      await viewModel.sendMessage(prompt, streaming: true)
    }
  }

  /// Submit action: stops the in-flight request when the model is busy,
  /// otherwise sends the prompt.
  private func sendOrStop() {
    if viewModel.busy {
      viewModel.stop()
    } else {
      sendMessage()
    }
  }

  /// Toolbar action: discards the current conversation.
  private func newChat() {
    viewModel.startNewChat()
  }
}

/// Xcode preview for `FunctionCallingScreen`, populated with sample messages.
struct FunctionCallingScreen_Previews: PreviewProvider {
  /// Wraps the screen so the shared view model can be seeded with
  /// `ChatMessage.samples` once the screen appears.
  struct ContainerView: View {
    @EnvironmentObject
    var viewModel: FunctionCallingViewModel

    var body: some View {
      FunctionCallingScreen()
        .onAppear {
          viewModel.messages = ChatMessage.samples
        }
    }
  }

  static var previews: some View {
    NavigationStack {
      // Preview the container rather than the bare screen; otherwise the sample
      // messages are never loaded and `ContainerView` is dead code. The
      // environment object is inherited by the wrapped screen as well.
      ContainerView().environmentObject(FunctionCallingViewModel())
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseVertexAI
import Foundation
import UIKit

/// View model for a chat session in which the model may call the app-provided
/// `get_exchange_rate` function.
///
/// A user message starts a round trip: the text is sent to the model, any function
/// calls found in the response are executed locally, and their results are sent
/// back until a response arrives that requests no further calls. Every step is
/// mirrored into `messages` so the UI can display it.
@MainActor
class FunctionCallingViewModel: ObservableObject {
  /// This array holds both the user's and the system's chat messages
  @Published var messages = [ChatMessage]()

  /// Indicates we're waiting for the model to finish
  @Published var busy = false

  /// The failure of the most recent round trip, if any. Cleared by
  /// `sendMessage`, `stop()` and `startNewChat()`.
  @Published var error: Error?
  var hasError: Bool {
    return error != nil
  }

  /// Function calls pending processing
  private var functionCalls = [FunctionCall]()

  private var model: GenerativeModel
  private var chat: Chat

  /// The task running the current round trip, kept so it can be cancelled.
  private var chatTask: Task<Void, Never>?

  init() {
    model = VertexAI.vertexAI().generativeModel(
      modelName: "gemini-1.0-pro",
      tools: [Tool(functionDeclarations: [
        FunctionDeclaration(
          name: "get_exchange_rate",
          description: "Get the exchange rate for currencies between countries",
          parameters: [
            "currency_from": Schema(
              type: .string,
              format: "enum",
              description: "The currency to convert from in ISO 4217 format",
              enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
            ),
            "currency_to": Schema(
              type: .string,
              format: "enum",
              description: "The currency to convert to in ISO 4217 format",
              enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
            ),
          ],
          requiredParameters: ["currency_from", "currency_to"]
        ),
      ])]
    )
    chat = model.startChat()
  }

  /// Appends the user's message plus a pending placeholder to `messages` and
  /// starts a round trip with the model, cancelling any round trip in flight.
  ///
  /// The call returns once the work has been scheduled; progress is reported
  /// through `messages`, `busy` and `error`.
  ///
  /// - Parameters:
  ///   - text: The user's prompt.
  ///   - streaming: Whether to use the streaming API (`true`) or wait for the
  ///     complete response (`false`).
  func sendMessage(_ text: String, streaming: Bool = true) async {
    error = nil
    chatTask?.cancel()
    // Calls left over from an interrupted round trip would otherwise be answered
    // in place of sending `text`.
    functionCalls.removeAll()

    busy = true

    // first, add the user's message to the chat
    messages.append(ChatMessage(message: text, participant: .user))

    // add a pending message while we're waiting for a response from the backend
    let systemMessage = ChatMessage.pending(participant: .system)
    let placeholderID = systemMessage.id
    messages.append(systemMessage)

    chatTask = Task {
      defer {
        // A cancelled task has either been superseded by a newer one (which owns
        // `busy` now) or was stopped by `stop()`, which already reset the flag.
        if !Task.isCancelled {
          busy = false
        }
      }

      do {
        repeat {
          try Task.checkCancellation()
          if streaming {
            try await internalSendMessageStreaming(text)
          } else {
            try await internalSendMessage(text)
          }
        } while !functionCalls.isEmpty

        // The round trip is over; if no text ever arrived, stop showing the
        // placeholder as pending.
        if let index = messages.firstIndex(where: { $0.id == placeholderID && $0.pending }) {
          messages[index].pending = false
        }
      } catch {
        if Task.isCancelled {
          // Stopped on purpose: not an error. Only drop the placeholder if it
          // never received any text. It is looked up by id because `messages`
          // may already have been cleared or belong to a newer request.
          messages.removeAll { $0.id == placeholderID && $0.pending }
        } else {
          self.error = error
          print(error.localizedDescription)
          messages.removeAll { $0.id == placeholderID }
        }
      }
    }
  }

  /// Cancels any round trip in flight and starts an empty conversation.
  func startNewChat() {
    stop()
    error = nil
    chat = model.startChat()
    messages.removeAll()
  }

  /// Cancels the round trip in flight, if any, and resets the busy/error state.
  func stop() {
    chatTask?.cancel()
    chatTask = nil
    functionCalls.removeAll()
    busy = false
    error = nil
  }

  /// One streaming exchange: answers pending function calls if there are any,
  /// otherwise sends `text`, then processes each response chunk as it arrives.
  private func internalSendMessageStreaming(_ text: String) async throws {
    let functionResponses = try await processFunctionCalls()
    try Task.checkCancellation()
    let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
    if functionResponses.isEmpty {
      responseStream = chat.sendMessageStream(text)
    } else {
      for functionResponse in functionResponses {
        insertBeforeLastMessage(functionResponse.chatMessage())
      }
      responseStream = chat.sendMessageStream(functionResponses.modelContent())
    }
    for try await chunk in responseStream {
      // Never touch `messages` once cancelled; they may have been cleared or
      // handed over to a newer request in the meantime.
      try Task.checkCancellation()
      processResponseContent(content: chunk)
    }
  }

  /// One non-streaming exchange: answers pending function calls if there are
  /// any, otherwise sends `text`, then processes the complete response.
  private func internalSendMessage(_ text: String) async throws {
    let functionResponses = try await processFunctionCalls()
    try Task.checkCancellation()
    let response: GenerateContentResponse
    if functionResponses.isEmpty {
      response = try await chat.sendMessage(text)
    } else {
      for functionResponse in functionResponses {
        insertBeforeLastMessage(functionResponse.chatMessage())
      }
      response = try await chat.sendMessage(functionResponses.modelContent())
    }
    // Never touch `messages` once cancelled (see the streaming variant).
    try Task.checkCancellation()
    processResponseContent(content: response)
  }

  /// Applies the first candidate of `content` to the conversation: text is
  /// appended to the last message, function calls are shown in the chat and
  /// queued for `processFunctionCalls()`.
  ///
  /// Responses without a candidate and unsupported part kinds are logged and
  /// skipped rather than crashing the app.
  func processResponseContent(content: GenerateContentResponse) {
    guard let candidate = content.candidates.first else {
      print("Ignoring a response without candidates.")
      return
    }

    for part in candidate.content.parts {
      switch part {
      case let .text(text):
        // replace pending message with backend response
        guard let lastIndex = messages.indices.last else { continue }
        messages[lastIndex].message += text
        messages[lastIndex].pending = false
      case let .functionCall(functionCall):
        insertBeforeLastMessage(functionCall.chatMessage())
        functionCalls.append(functionCall)
      case .data, .functionResponse:
        print("Ignoring unsupported response content.")
      }
    }
  }

  /// Executes every queued function call and empties the queue.
  ///
  /// A call to a function this app does not implement is answered with an
  /// `"error"` payload so the model can react to it.
  ///
  /// - Returns: One response per queued call, in queue order.
  func processFunctionCalls() async throws -> [FunctionResponse] {
    var functionResponses = [FunctionResponse]()
    for functionCall in functionCalls {
      switch functionCall.name {
      case "get_exchange_rate":
        let exchangeRates = getExchangeRate(args: functionCall.args)
        functionResponses.append(FunctionResponse(
          name: "get_exchange_rate",
          response: exchangeRates
        ))
      default:
        functionResponses.append(FunctionResponse(
          name: functionCall.name,
          response: ["error": .string("Unknown function named \"\(functionCall.name)\".")]
        ))
      }
    }
    functionCalls = []

    return functionResponses
  }

  /// Inserts `message` directly above the last message (the placeholder or
  /// in-progress response), or at the start when the list is empty.
  private func insertBeforeLastMessage(_ message: ChatMessage) {
    messages.insert(message, at: max(messages.count - 1, 0))
  }

  // MARK: - Callable Functions

  /// Looks up a hard-coded exchange rate.
  ///
  /// - Parameter args: Arguments supplied by the model; expected to contain the
  ///   string values `currency_from` and `currency_to`.
  /// - Returns: `["rates": <rate>]` on success, or `["error": <description>]`
  ///   when an argument is missing or a currency is not in the table.
  func getExchangeRate(args: JSONObject) -> JSONObject {
    // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
    guard case let .string(from) = args["currency_from"] else {
      return ["error": .string("Missing `currency_from` parameter.")]
    }
    guard case let .string(to) = args["currency_to"] else {
      return ["error": .string("Missing `currency_to` parameter.")]
    }

    // 2. Get the exchange rate
    let allRates: [String: [String: Double]] = [
      "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
      "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
      "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
      "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
      "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
      "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
    ]
    guard let fromRates = allRates[from] else {
      return ["error": .string("No data for currency \(from).")]
    }
    // The table has no entry from a currency to itself; that rate is 1.
    if from == to {
      return ["rates": .number(1)]
    }
    guard let toRate = fromRates[to] else {
      return ["error": .string("No data for currency \(to).")]
    }

    // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
    return ["rates": .number(toRate)]
  }
}

private extension FunctionCall {
  /// Builds the chat bubble that shows this function call.
  ///
  /// The call is rendered as pretty-printed JSON inside a fenced block and
  /// attributed to the system participant. Traps if the call cannot be
  /// JSON-encoded or the encoded bytes cannot be read as UTF-8.
  func chatMessage() -> ChatMessage {
    let prettyEncoder = JSONEncoder()
    prettyEncoder.outputFormatting = .prettyPrinted

    switch Result(catching: { try prettyEncoder.encode(self) }) {
    case let .failure(error):
      fatalError("JSON Encoding Failed: \(error.localizedDescription)")
    case let .success(encoded):
      guard let prettyJSON = String(data: encoded, encoding: .utf8) else {
        fatalError("Failed to convert JSON data to a String.")
      }
      return ChatMessage(
        message: "Function call requested by model:\n```\n\(prettyJSON)\n```",
        participant: .system
      )
    }
  }
}

private extension FunctionResponse {
  /// Builds the chat bubble that shows this function response.
  ///
  /// The response is rendered as pretty-printed JSON inside a fenced block and
  /// attributed to the user participant. Traps if the response cannot be
  /// JSON-encoded or the encoded bytes cannot be read as UTF-8.
  func chatMessage() -> ChatMessage {
    let prettyEncoder = JSONEncoder()
    prettyEncoder.outputFormatting = .prettyPrinted

    switch Result(catching: { try prettyEncoder.encode(self) }) {
    case let .failure(error):
      fatalError("JSON Encoding Failed: \(error.localizedDescription)")
    case let .success(encoded):
      guard let prettyJSON = String(data: encoded, encoding: .utf8) else {
        fatalError("Failed to convert JSON data to a String.")
      }
      return ChatMessage(
        message: "Function response returned by app:\n```\n\(prettyJSON)\n```",
        participant: .user
      )
    }
  }
}

private extension [FunctionResponse] {
  /// Wraps each function response in its own `ModelContent` with the
  /// `"function"` role, preserving order.
  func modelContent() -> [ModelContent] {
    map { functionResponse in
      ModelContent(
        role: "function",
        parts: [ModelContent.Part.functionResponse(functionResponse)]
      )
    }
  }
}
Loading

0 comments on commit 4c07757

Please sign in to comment.