Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Make generateContentStream/sendMessageStream throws #13573

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
- [fixed] Resolved a decoding error for citations without a `uri` and added
support for decoding `title` fields, which were previously ignored. (#13518)
- [changed] **Breaking Change**: The methods for starting streaming requests
(`generateContentStream` and `sendMessageStream`) and creating a chat instance
(`startChat`) are now asynchronous and must be called with `await`. (#13545)
(`generateContentStream` and `sendMessageStream`) are now throwing and
asynchronous and must be called with `try await`. (#13545, #13573)
- [changed] **Breaking Change**: Creating a chat instance (`startChat`) is now
asynchronous and must be called with `await`. (#13545)

# 10.29.0
- [feature] Added community support for watchOS. (#13215)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ConversationViewModel: ObservableObject {
guard let chat else {
throw ChatError.notInitialized
}
let responseStream = await chat.sendMessageStream(text)
let responseStream = try await chat.sendMessageStream(text)
for try await chunk in responseStream {
messages[messages.count - 1].pending = false
if let text = chunk.text {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ class FunctionCallingViewModel: ObservableObject {
}
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
if functionResponses.isEmpty {
responseStream = await chat.sendMessageStream(text)
responseStream = try await chat.sendMessageStream(text)
} else {
for functionResponse in functionResponses {
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
}
responseStream = await chat.sendMessageStream(functionResponses.modelContent())
responseStream = try await chat.sendMessageStream(functionResponses.modelContent())
}
for try await chunk in responseStream {
processResponseContent(content: chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
}
}

let outputContentStream = await model.generateContentStream(prompt, images)
let outputContentStream = try await model.generateContentStream(prompt, images)

// stream response
for try await outputContent in outputContentStream {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {

let prompt = "Summarize the following text for me: \(inputText)"

let outputContentStream = await model.generateContentStream(prompt)
let outputContentStream = try await model.generateContentStream(prompt)

// stream response
for try await outputContent in outputContentStream {
Expand Down
17 changes: 6 additions & 11 deletions FirebaseVertexAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public actor Chat {
/// - Parameter parts: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try sendMessageStream([ModelContent(parts: parts)])
}
Expand All @@ -95,21 +95,16 @@ public actor Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let resolvedContent: [ModelContent]
do {
resolvedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
}
throw GenerateContentError.internalError(underlying: underlying)
}

return AsyncThrowingStream { continuation in
Expand All @@ -121,7 +116,7 @@ public actor Chat {

// Send the history alongside the new message as context.
let request = history + newContent
let stream = await model.generateContentStream(request)
let stream = try await model.generateContentStream(request)
do {
for try await chunk in stream {
// Capture any content that's streaming. This should be populated if there's no error.
Expand Down
15 changes: 5 additions & 10 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public final actor GenerativeModel {
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try generateContentStream([ModelContent(parts: parts)])
}
Expand All @@ -190,21 +190,16 @@ public final actor GenerativeModel {
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let evaluatedContent: [ModelContent]
do {
evaluatedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
}
throw GenerateContentError.internalError(underlying: underlying)
}

let generateContentRequest = GenerateContentRequest(model: modelResourceName,
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ final class ChatTests: XCTestCase {
)
let chat = Chat(model: model, history: [])
let input = "Test input"
let stream = await chat.sendMessageStream(input)
let stream = try await chat.sendMessageStream(input)

// Ensure the values are parsed correctly
for try await value in stream {
Expand Down
36 changes: 18 additions & 18 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
Expand All @@ -784,7 +784,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
Expand All @@ -807,7 +807,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
Expand All @@ -827,7 +827,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("Content shouldn't be shown, this shouldn't happen.")
}
Expand All @@ -847,7 +847,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("Content shouldn't be shown, this shouldn't happen.")
}
Expand All @@ -866,7 +866,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
do {
for try await content in stream {
XCTAssertNotNil(content.text)
Expand All @@ -887,7 +887,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responses = 0
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
Expand All @@ -904,7 +904,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responses = 0
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
Expand All @@ -921,7 +921,7 @@ final class GenerativeModelTests: XCTestCase {
)

var hadUnknown = false
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
if let ratings = content.candidates.first?.safetyRatings,
Expand All @@ -940,7 +940,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
var citations = [Citation]()
var responses = [GenerateContentResponse]()
for try await content in stream {
Expand Down Expand Up @@ -996,7 +996,7 @@ final class GenerativeModelTests: XCTestCase {
appCheckToken: appCheckToken
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await _ in stream {}
}

Expand All @@ -1018,7 +1018,7 @@ final class GenerativeModelTests: XCTestCase {
appCheckToken: AppCheckInteropFake.placeholderTokenValue
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await _ in stream {}
}

Expand All @@ -1030,7 +1030,7 @@ final class GenerativeModelTests: XCTestCase {
)
var responses = [GenerateContentResponse]()

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}
Expand All @@ -1056,7 +1056,7 @@ final class GenerativeModelTests: XCTestCase {

var responseCount = 0
do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
responseCount += 1
Expand All @@ -1076,7 +1076,7 @@ final class GenerativeModelTests: XCTestCase {
func testGenerateContentStream_nonHTTPResponse() async throws {
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
Expand All @@ -1096,7 +1096,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
Expand All @@ -1120,7 +1120,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
Expand Down Expand Up @@ -1159,7 +1159,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responses = 0
let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
Expand Down
Loading