Skip to content

Commit

Permalink
Add default RequestOptions.timeout of 300 seconds (google-gemini#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored and G.Dev.Ssomsak committed Jun 21, 2024
1 parent 33628ae commit 9f5fecb
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
12 changes: 5 additions & 7 deletions Sources/GoogleAI/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,18 @@ protocol GenerativeAIRequest: Encodable {
/// Configuration parameters for sending requests to the backend.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct RequestOptions {
/// The request’s timeout interval in seconds; if not specified uses the default value for a
/// `URLRequest`.
let timeout: TimeInterval?
/// The request’s timeout interval in seconds.
let timeout: TimeInterval

/// The API version to use in requests to the backend.
let apiVersion: String

/// Initializes a request options object.
///
/// - Parameters:
/// - timeout The request’s timeout interval in seconds; if not specified uses the default value
/// for a `URLRequest`.
/// - apiVersion The API version to use in requests to the backend; defaults to "v1beta".
public init(timeout: TimeInterval? = nil, apiVersion: String = "v1beta") {
/// - timeout: The request’s timeout interval in seconds; defaults to 300 seconds (5 minutes).
/// - apiVersion: The API version to use in requests to the backend; defaults to "v1beta".
public init(timeout: TimeInterval = 300.0, apiVersion: String = "v1beta") {
self.timeout = timeout
self.apiVersion = apiVersion
}
Expand Down
5 changes: 1 addition & 4 deletions Sources/GoogleAI/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,7 @@ struct GenerativeAIService {
let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
urlRequest.httpBody = try encoder.encode(request)

if let timeoutInterval = request.options.timeout {
urlRequest.timeoutInterval = timeoutInterval
}
urlRequest.timeoutInterval = request.options.timeout

return urlRequest
}
Expand Down
51 changes: 49 additions & 2 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,20 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.candidates.count, 1)
}

func testGenerateContent_requestOptions_defaultTimeout() async throws {
let expectedTimeout = 300.0 // Default in timeout in RequestOptions()
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json",
timeout: expectedTimeout
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
}

// MARK: - Generate Content (Streaming)

func testGenerateContentStream_failureInvalidAPIKey() async throws {
Expand Down Expand Up @@ -967,6 +981,25 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(responses, 1)
}

func testGenerateContentStream_requestOptions_defaultTimeout() async throws {
let expectedTimeout = 300.0 // Default in timeout in RequestOptions()
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt",
timeout: expectedTimeout
)

var responses = 0
let stream = model.generateContentStream(testPrompt)
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
}

XCTAssertEqual(responses, 1)
}

// MARK: - Count Tokens

func testCountTokens_succeeds() async throws {
Expand Down Expand Up @@ -1019,6 +1052,20 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.totalTokens, 6)
}

func testCountTokens_requestOptions_defaultTimeout() async throws {
let expectedTimeout = 300.0
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "success-total-tokens",
withExtension: "json",
timeout: expectedTimeout
)

let response = try await model.countTokens(testPrompt)

XCTAssertEqual(response.totalTokens, 6)
}

// MARK: - Model Resource Name

func testModelResourceName_noPrefix() async throws {
Expand Down Expand Up @@ -1067,8 +1114,8 @@ final class GenerativeModelTests: XCTestCase {
private func httpRequestHandler(forResource name: String,
withExtension ext: String,
statusCode: Int = 200,
timeout: TimeInterval = URLRequest
.defaultTimeoutInterval()) throws -> ((URLRequest) throws -> (
timeout: TimeInterval = RequestOptions()
.timeout) throws -> ((URLRequest) throws -> (
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
)) {
Expand Down

0 comments on commit 9f5fecb

Please sign in to comment.