// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseVertexAI
import GenerativeAIUIComponents
import SwiftUI

/// Chat-style screen for the function calling sample.
///
/// Shows the messages published by a `FunctionCallingViewModel` (taken from the
/// environment), an error row when the view model reports an error, and an input field
/// that sends the prompt or stops the request that is in flight.
struct FunctionCallingScreen: View {
  @EnvironmentObject
  var viewModel: FunctionCallingViewModel

  @State
  private var userPrompt = "What is 100 Euros in U.S. Dollars?"

  enum FocusedField: Hashable {
    case message
  }

  @FocusState
  var focusedField: FocusedField?

  /// Identifier of the error row; used as the `scrollTo` target.
  private static let errorViewID = "errorView"

  var body: some View {
    VStack {
      ScrollViewReader { scrollViewProxy in
        List {
          Text("Interact with a currency conversion API using function calling in Gemini.")
          ForEach(viewModel.messages) { message in
            MessageView(message: message)
          }
          if let error = viewModel.error {
            ErrorView(error: error)
              // `ScrollViewProxy.scrollTo` resolves its target by `id`; a `tag` is not
              // consulted, so the row must carry an explicit identifier.
              .id(Self.errorViewID)
          }
        }
        .listStyle(.plain)
        .onChange(of: viewModel.messages, perform: { _ in
          if viewModel.hasError {
            scrollToBottom(using: scrollViewProxy, target: Self.errorViewID)
          } else if let lastMessage = viewModel.messages.last {
            scrollToBottom(using: scrollViewProxy, target: lastMessage.id)
          }
        })
      }
      InputField("Message...", text: $userPrompt) {
        Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill")
          .font(.title)
      }
      .focused($focusedField, equals: .message)
      .onSubmit { sendOrStop() }
    }
    .toolbar {
      ToolbarItem(placement: .primaryAction) {
        Button(action: newChat) {
          Image(systemName: "square.and.pencil")
        }
      }
    }
    .navigationTitle("Function Calling")
    .onAppear {
      focusedField = .message
    }
  }

  /// Scrolls the list to the row identified by `target` and returns focus to the input
  /// field.
  ///
  /// - Parameters:
  ///   - proxy: The proxy of the enclosing `ScrollViewReader`.
  ///   - target: The `id` of the row to reveal.
  private func scrollToBottom<ID: Hashable>(using proxy: ScrollViewProxy, target: ID) {
    // Wait for a short moment to make sure we can actually scroll to the bottom.
    DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) {
      withAnimation {
        proxy.scrollTo(target, anchor: .bottom)
      }
      focusedField = .message
    }
  }

  /// Clears the input field and hands the prompt to the view model.
  ///
  /// A prompt that is empty or consists only of whitespace is ignored and left in place.
  private func sendMessage() {
    let prompt = userPrompt
    guard !prompt.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }
    userPrompt = ""

    Task {
      await viewModel.sendMessage(prompt, streaming: true)
    }
  }

  /// Stops the running request when the view model is busy; otherwise sends the prompt.
  private func sendOrStop() {
    if viewModel.busy {
      viewModel.stop()
    } else {
      sendMessage()
    }
  }

  /// Asks the view model to discard the conversation and start over.
  private func newChat() {
    viewModel.startNewChat()
  }
}

struct FunctionCallingScreen_Previews: PreviewProvider {
  /// Wraps the screen so the view model from the environment can be seeded with
  /// `ChatMessage.samples` when the preview appears.
  struct ContainerView: View {
    @EnvironmentObject
    var viewModel: FunctionCallingViewModel

    var body: some View {
      FunctionCallingScreen()
        .onAppear {
          viewModel.messages = ChatMessage.samples
        }
    }
  }

  static var previews: some View {
    NavigationStack {
      // Go through `ContainerView` so the sample messages are actually loaded.
      ContainerView().environmentObject(FunctionCallingViewModel())
    }
  }
}
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseVertexAI
import Foundation
import UIKit

/// View model of the function calling sample.
///
/// Owns a chat with a generative model that declares a `get_exchange_rate` function,
/// executes the function calls the model requests, sends the results back, and publishes
/// the conversation as `ChatMessage` values.
@MainActor
class FunctionCallingViewModel: ObservableObject {
  /// This array holds both the user's and the system's chat messages
  @Published var messages = [ChatMessage]()

  /// Indicates we're waiting for the model to finish
  @Published var busy = false

  /// The error produced by the most recent request, if any.
  @Published var error: Error?
  var hasError: Bool {
    return error != nil
  }

  /// Function calls pending processing
  private var functionCalls = [FunctionCall]()

  private var model: GenerativeModel
  private var chat: Chat

  /// The task running the current request; cancelled by `stop()` and by a newer request.
  private var chatTask: Task<Void, Never>?

  /// Exchange rates served by `getExchangeRate(args:)`, keyed by source currency and then
  /// by target currency.
  private static let exchangeRates: [String: [String: Double]] = [
    "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
    "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
    "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
    "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
    "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
    "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
  ]

  init() {
    let supportedCurrencies = ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]

    model = VertexAI.vertexAI().generativeModel(
      modelName: "gemini-1.0-pro",
      tools: [Tool(functionDeclarations: [
        FunctionDeclaration(
          name: "get_exchange_rate",
          description: "Get the exchange rate for currencies between countries",
          parameters: [
            "currency_from": Schema(
              type: .string,
              format: "enum",
              description: "The currency to convert from in ISO 4217 format",
              enumValues: supportedCurrencies
            ),
            "currency_to": Schema(
              type: .string,
              format: "enum",
              description: "The currency to convert to in ISO 4217 format",
              enumValues: supportedCurrencies
            ),
          ],
          requiredParameters: ["currency_from", "currency_to"]
        ),
      ])]
    )
    chat = model.startChat()
  }

  /// Starts a request that sends `text` to the model, cancelling any request that is
  /// still running.
  ///
  /// The work runs in a separate task stored in `chatTask`; this method returns once the
  /// task has been created, not when the response has arrived.
  ///
  /// - Parameters:
  ///   - text: The user's prompt.
  ///   - streaming: Whether to consume the response as a stream of chunks.
  func sendMessage(_ text: String, streaming: Bool = true) async {
    error = nil
    let previousTask = chatTask
    previousTask?.cancel()

    chatTask = Task {
      // Let the cancelled request finish its own cleanup first, so that it can neither
      // reset `busy` nor remove a message on behalf of this request.
      _ = await previousTask?.value

      busy = true
      defer {
        busy = false
      }

      // Calls left over from a request that failed or was stopped must not be answered
      // in place of the new prompt.
      functionCalls.removeAll()

      // first, add the user's message to the chat
      let userMessage = ChatMessage(message: text, participant: .user)
      messages.append(userMessage)

      // add a pending message while we're waiting for a response from the backend
      let systemMessage = ChatMessage.pending(participant: .system)
      messages.append(systemMessage)

      do {
        // Keep going for as long as the model answers with further function calls.
        repeat {
          if streaming {
            try await internalSendMessageStreaming(text)
          } else {
            try await internalSendMessage(text)
          }
        } while !functionCalls.isEmpty
      } catch {
        // A request that was stopped on purpose is not reported as an error.
        if !Task.isCancelled {
          self.error = error
          print(error.localizedDescription)
        }
        // Remove this request's reply by identity: `messages` may have been emptied by
        // `startNewChat()` in the meantime, in which case `removeLast()` would trap.
        messages.removeAll { $0.id == systemMessage.id }
      }
    }
  }

  /// Stops the current request and replaces the conversation with a fresh, empty chat.
  func startNewChat() {
    stop()
    error = nil
    chat = model.startChat()
    messages.removeAll()
    functionCalls.removeAll()
  }

  /// Cancels the current request, if any, and clears the published error.
  func stop() {
    chatTask?.cancel()
    error = nil
  }

  /// Sends either `text` or, when function calls are pending, their responses, and
  /// applies every streamed chunk to `messages`.
  private func internalSendMessageStreaming(_ text: String) async throws {
    let functionResponses = try await processFunctionCalls()
    let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
    if functionResponses.isEmpty {
      responseStream = chat.sendMessageStream(text)
    } else {
      showFunctionResponses(functionResponses)
      responseStream = chat.sendMessageStream(functionResponses.modelContent())
    }
    for try await chunk in responseStream {
      try applyResponse(chunk)
    }
  }

  /// Non-streaming counterpart of `internalSendMessageStreaming(_:)`.
  private func internalSendMessage(_ text: String) async throws {
    let functionResponses = try await processFunctionCalls()
    let response: GenerateContentResponse
    if functionResponses.isEmpty {
      response = try await chat.sendMessage(text)
    } else {
      showFunctionResponses(functionResponses)
      response = try await chat.sendMessage(functionResponses.modelContent())
    }
    try applyResponse(response)
  }

  /// Applies a model response to `messages` and queues the function calls it contains.
  ///
  /// A response that cannot be handled is reported through `error` instead of
  /// terminating the app.
  func processResponseContent(content: GenerateContentResponse) {
    do {
      try applyResponse(content)
    } catch {
      self.error = error
    }
  }

  /// Throwing implementation behind `processResponseContent(content:)`.
  ///
  /// Text parts are appended to the last message, which is then marked as no longer
  /// pending. Function call parts are shown in the chat and appended to `functionCalls`.
  ///
  /// - Throws: `FunctionCallingError.noCandidate` when the response has no candidates;
  ///   `FunctionCallingError.unsupportedContent` for data or function response parts.
  private func applyResponse(_ content: GenerateContentResponse) throws {
    guard let candidate = content.candidates.first else {
      throw FunctionCallingError.noCandidate
    }

    for part in candidate.content.parts {
      switch part {
      case let .text(text):
        // replace pending message with backend response
        guard let replyIndex = messages.indices.last else { continue }
        messages[replyIndex].message += text
        messages[replyIndex].pending = false
      case let .functionCall(functionCall):
        insertBeforeReply(functionCall.chatMessage())
        functionCalls.append(functionCall)
      case .data, .functionResponse:
        throw FunctionCallingError.unsupportedContent
      }
    }
  }

  /// Shows each function response in the chat, above the model's reply.
  private func showFunctionResponses(_ functionResponses: [FunctionResponse]) {
    for functionResponse in functionResponses {
      insertBeforeReply(functionResponse.chatMessage())
    }
  }

  /// Inserts `message` directly before the last message, or at the start when `messages`
  /// is empty.
  private func insertBeforeReply(_ message: ChatMessage) {
    messages.insert(message, at: max(messages.count - 1, 0))
  }

  /// Executes all queued function calls and empties the queue.
  ///
  /// A call to a function this sample does not implement is answered with an `error`
  /// entry, the same way `getExchangeRate(args:)` reports invalid currencies.
  ///
  /// - Returns: One response per queued call, in the order the calls were queued.
  func processFunctionCalls() async throws -> [FunctionResponse] {
    var functionResponses = [FunctionResponse]()
    for functionCall in functionCalls {
      switch functionCall.name {
      case "get_exchange_rate":
        let exchangeRates = getExchangeRate(args: functionCall.args)
        functionResponses.append(FunctionResponse(
          name: "get_exchange_rate",
          response: exchangeRates
        ))
      default:
        functionResponses.append(FunctionResponse(
          name: functionCall.name,
          response: ["error": .string("Unknown function named \"\(functionCall.name)\".")]
        ))
      }
    }
    functionCalls = []

    return functionResponses
  }

  // MARK: - Callable Functions

  /// Looks up the exchange rate between two currencies.
  ///
  /// - Parameter args: The arguments of a `get_exchange_rate` call; `currency_from` and
  ///   `currency_to` are read as strings.
  /// - Returns: `["rates": <rate>]` on success, or `["error": <message>]` when an
  ///   argument is missing or there is no data for a currency.
  func getExchangeRate(args: JSONObject) -> JSONObject {
    // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
    guard case let .string(from) = args["currency_from"] else {
      return ["error": .string("Missing `currency_from` parameter.")]
    }
    guard case let .string(to) = args["currency_to"] else {
      return ["error": .string("Missing `currency_to` parameter.")]
    }

    // 2. Get the exchange rate
    guard let fromRates = Self.exchangeRates[from] else {
      return ["error": .string("No data for currency \(from).")]
    }
    guard let toRate = fromRates[to] else {
      return ["error": .string("No data for currency \(to).")]
    }

    // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
    return ["rates": .number(toRate)]
  }
}

/// Failures raised while applying a model response to the conversation.
private enum FunctionCallingError: LocalizedError {
  /// The response contained no candidates.
  case noCandidate
  /// The response contained a part this sample does not handle.
  case unsupportedContent

  var errorDescription: String? {
    switch self {
    case .noCandidate:
      return "No candidate."
    case .unsupportedContent:
      return "Unsupported response content."
    }
  }
}

/// Encodes `value` as pretty-printed JSON for display in a chat message.
///
/// Encoding failures are rendered into the returned text because the result is only
/// shown to the user.
private func prettyPrintedJSON<Value: Encodable>(_ value: Value) -> String {
  let encoder = JSONEncoder()
  encoder.outputFormatting = .prettyPrinted

  do {
    return String(decoding: try encoder.encode(value), as: UTF8.self)
  } catch {
    return "JSON Encoding Failed: \(error.localizedDescription)"
  }
}

private extension FunctionCall {
  /// A system chat message showing this call as JSON.
  func chatMessage() -> ChatMessage {
    let json = prettyPrintedJSON(self)
    let messageText = "Function call requested by model:\n```\n\(json)\n```"

    return ChatMessage(message: messageText, participant: .system)
  }
}

private extension FunctionResponse {
  /// A user chat message showing this response as JSON.
  func chatMessage() -> ChatMessage {
    let json = prettyPrintedJSON(self)
    let messageText = "Function response returned by app:\n```\n\(json)\n```"

    return ChatMessage(message: messageText, participant: .user)
  }
}

private extension [FunctionResponse] {
  /// Wraps each response in its own `ModelContent` with the `function` role.
  func modelContent() -> [ModelContent] {
    return map { functionResponse in
      ModelContent(
        role: "function",
        parts: [ModelContent.Part.functionResponse(functionResponse)]
      )
    }
  }
}
}; 868A33682BB476FA00304BB1 /* FirebaseVertexAI-Preview in Frameworks */ = {isa = PBXBuildFile; productRef = 868A33672BB476FA00304BB1 /* FirebaseVertexAI-Preview */; }; 869200B32B879C4F00482873 /* GoogleService-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 869200B22B879C4F00482873 /* GoogleService-Info.plist */; }; + 86C1F4832BC726150026816F /* FunctionCallingScreen.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86C1F47E2BC726150026816F /* FunctionCallingScreen.swift */; }; + 86C1F4842BC726150026816F /* FunctionCallingViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86C1F4802BC726150026816F /* FunctionCallingViewModel.swift */; }; 88263BF02B239C09008AB09B /* ErrorView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 88263BEE2B239BFE008AB09B /* ErrorView.swift */; }; 88263BF12B239C11008AB09B /* ErrorDetailsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 889873842B208563005B4896 /* ErrorDetailsView.swift */; }; 8848C8332B0D04BC007B434F /* GenerativeAISampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8848C8322B0D04BC007B434F /* GenerativeAISampleApp.swift */; }; @@ -31,6 +33,8 @@ /* Begin PBXFileReference section */ 869200B22B879C4F00482873 /* GoogleService-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "GoogleService-Info.plist"; sourceTree = ""; }; + 86C1F47E2BC726150026816F /* FunctionCallingScreen.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FunctionCallingScreen.swift; sourceTree = ""; }; + 86C1F4802BC726150026816F /* FunctionCallingViewModel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FunctionCallingViewModel.swift; sourceTree = ""; }; 8802666F2B0FC39000CF7CB6 /* PhotoReasoningViewModel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PhotoReasoningViewModel.swift; sourceTree = ""; }; 
880266752B0FC39000CF7CB6 /* PhotoReasoningScreen.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PhotoReasoningScreen.swift; sourceTree = ""; }; 88209C1B2B0FBDC300F64795 /* SummarizeScreen.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SummarizeScreen.swift; sourceTree = ""; }; @@ -71,6 +75,31 @@ /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ + 86C1F47F2BC726150026816F /* Screens */ = { + isa = PBXGroup; + children = ( + 86C1F47E2BC726150026816F /* FunctionCallingScreen.swift */, + ); + path = Screens; + sourceTree = ""; + }; + 86C1F4812BC726150026816F /* ViewModels */ = { + isa = PBXGroup; + children = ( + 86C1F4802BC726150026816F /* FunctionCallingViewModel.swift */, + ); + path = ViewModels; + sourceTree = ""; + }; + 86C1F4822BC726150026816F /* FunctionCallingSample */ = { + isa = PBXGroup; + children = ( + 86C1F4812BC726150026816F /* ViewModels */, + 86C1F47F2BC726150026816F /* Screens */, + ); + path = FunctionCallingSample; + sourceTree = ""; + }; 8802666E2B0FC39000CF7CB6 /* ViewModels */ = { isa = PBXGroup; children = ( @@ -119,6 +148,7 @@ 8848C8452B0D051E007B434F /* GenerativeAITextSample */, 8848C8572B0D056C007B434F /* GenerativeAIMultimodalSample */, 88E10F432B110D5300C08E95 /* ChatSample */, + 86C1F4822BC726150026816F /* FunctionCallingSample */, 8848C8302B0D04BC007B434F /* Products */, 88209C222B0FBE1700F64795 /* Frameworks */, ); @@ -327,7 +357,9 @@ isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( + 86C1F4832BC726150026816F /* FunctionCallingScreen.swift in Sources */, 886F95DF2B17D5010036F07A /* BouncingDots.swift in Sources */, + 86C1F4842BC726150026816F /* FunctionCallingViewModel.swift in Sources */, 886F95DE2B17D5010036F07A /* ChatMessage.swift in Sources */, 88263BF12B239C11008AB09B /* ErrorDetailsView.swift in Sources */, 8848C8352B0D04BC007B434F /* ContentView.swift in Sources */, diff --git 
a/FirebaseVertexAI/Sample/GenerativeAISample/ContentView.swift b/FirebaseVertexAI/Sample/GenerativeAISample/ContentView.swift index 34331bf6eb7..2841d634a6c 100644 --- a/FirebaseVertexAI/Sample/GenerativeAISample/ContentView.swift +++ b/FirebaseVertexAI/Sample/GenerativeAISample/ContentView.swift @@ -18,6 +18,9 @@ struct ContentView: View { @StateObject var viewModel = ConversationViewModel() + @StateObject + var functionCallingViewModel = FunctionCallingViewModel() + var body: some View { NavigationStack { List { @@ -37,6 +40,11 @@ struct ContentView: View { } label: { Label("Chat", systemImage: "ellipsis.message.fill") } + NavigationLink { + FunctionCallingScreen().environmentObject(functionCallingViewModel) + } label: { + Label("Function Calling", systemImage: "function") + } } .navigationTitle("Generative AI Samples") }