Skip to content

Commit

Permalink
[Text To Speech] Add Text To Speech implementation (#28)
Browse files Browse the repository at this point in the history
* Add Text To Speech implementation
  • Loading branch information
SwiftBeta authored Nov 11, 2023
1 parent 99073f7 commit b37189e
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import Foundation

/// Audio formats that can be requested for a speech response.
///
/// Each case's raw value is its own name (for example `.mp3` maps to `"mp3"`).
public enum OpenAIAudioResponseType: String {
    case mp3, opus, aac, flac
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import Foundation

/// Selects the text-to-speech model for a request.
///
/// The single `tts` case wraps a concrete `TTS` model identifier.
public enum OpenAITTSModelType {
    case tts(TTS)

    /// The raw string value of the wrapped `TTS` model.
    var name: String {
        switch self {
        case let .tts(variant):
            return variant.rawValue
        }
    }
}

/// Concrete text-to-speech model identifiers; the raw value is the model name string.
public enum TTS: String {
    /// Model identified by the string `"tts-1"`.
    case tts1 = "tts-1"

    /// Model identified by the string `"tts-1-hd"`.
    case tts1HD = "tts-1-hd"
}
10 changes: 10 additions & 0 deletions Sources/SwiftOpenAI/OpenAI/DataModels/Audio/OpenAIVoiceType.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Foundation

/// Voices that can be selected for generated speech.
///
/// Each case's raw value is its own name (for example `.alloy` maps to `"alloy"`).
public enum OpenAIVoiceType: String {
    case alloy, echo, fable, onyx, nova, shimmer
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import Foundation

/// Endpoint description for the `audio/speech` route.
///
/// Holds the speech-generation options and exposes them as the request's
/// parameter dictionary.
struct CreateSpeechEndpoint: Endpoint {
    private let model: OpenAITTSModelType
    private let input: String
    private let voice: OpenAIVoiceType
    private let responseFormat: OpenAIAudioResponseType
    private let speed: Double

    /// The request is always sent with `POST`.
    var method: HTTPMethod { .POST }

    // NOTE(review): kept as a mutable stored property exactly as before —
    // presumably so endpoint routing can rewrite it; confirm against `Endpoint`.
    var path: String = "audio/speech"

    /// Creates the endpoint from the individual speech options.
    ///
    /// - Parameters:
    ///   - model: Model whose `name` is sent under the `"model"` key.
    ///   - input: Text sent under the `"input"` key.
    ///   - voice: Voice whose raw value is sent under the `"voice"` key.
    ///   - responseFormat: Format whose raw value is sent under the `"response_format"` key.
    ///   - speed: Value sent under the `"speed"` key.
    init(model: OpenAITTSModelType,
         input: String,
         voice: OpenAIVoiceType,
         responseFormat: OpenAIAudioResponseType,
         speed: Double) {
        self.model = model
        self.input = input
        self.voice = voice
        self.responseFormat = responseFormat
        self.speed = speed
    }

    /// The five request parameters keyed by their wire names.
    var parameters: [String: Any]? {
        let body: [String: Any] = [
            "model": model.name,
            "input": input,
            "voice": voice.rawValue,
            "response_format": responseFormat.rawValue,
            "speed": speed
        ]
        return body
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ enum OpenAIEndpoints {

case moderations(input: String)

case createSpeech(model: OpenAITTSModelType, input: String, voice: OpenAIVoiceType, responseFormat: OpenAIAudioResponseType, speed: Double)

public var endpoint: Endpoint {
switch self {
case .listModels:
Expand All @@ -40,6 +42,8 @@ enum OpenAIEndpoints {
input: input)
case .moderations(input: let input):
return ModerationEndpoint(input: input)
case .createSpeech(model: let model, input: let input, voice: let voice, responseFormat: let responseFormat, speed: let speed):
return CreateSpeechEndpoint(model: model, input: input, voice: voice, responseFormat: responseFormat, speed: speed)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import Foundation

/// Contract for an object that can perform a create-speech request.
protocol CreateSpeechRequestProtocol {
    /// Performs the create-speech request.
    ///
    /// - Parameters:
    ///   - api: The `API` instance used to perform the request.
    ///   - apiKey: The API key used to authorize the request.
    ///   - model: The text-to-speech model to use.
    ///   - input: The text to convert into speech.
    ///   - voice: The voice to use.
    ///   - responseFormat: The requested audio format.
    ///   - speed: The requested speech speed.
    /// - Returns: The response data, if any.
    /// - Throws: Any error raised while performing the request.
    func execute(
        api: API,
        apiKey: String,
        model: OpenAITTSModelType,
        input: String,
        voice: OpenAIVoiceType,
        responseFormat: OpenAIAudioResponseType,
        speed: Double
    ) async throws -> Data?
}

/// Builds and executes the create-speech request, returning the response data unchanged.
final public class CreateSpeechRequest: CreateSpeechRequestProtocol {
    /// Function shape of `execute`, so callers can inject a replacement implementation.
    public typealias Init = (_ api: API,
                             _ apiKey: String,
                             _ model: OpenAITTSModelType,
                             _ input: String,
                             _ voice: OpenAIVoiceType,
                             _ responseFormat: OpenAIAudioResponseType,
                             _ speed: Double) async throws -> Data?

    public init() { }

    /// Sends the create-speech request and returns the response data.
    ///
    /// - Parameters:
    ///   - api: The `API` instance used to route, build and execute the request.
    ///   - apiKey: Key placed in the `Authorization: Bearer …` header.
    ///   - model: The text-to-speech model to use.
    ///   - input: The text to convert into speech.
    ///   - voice: The voice to use.
    ///   - responseFormat: The requested audio format.
    ///   - speed: The requested speech speed.
    /// - Returns: The data delivered by a successful request.
    /// - Throws: The error carried by a failed request result.
    public func execute(api: API,
                        apiKey: String,
                        model: OpenAITTSModelType,
                        input: String,
                        voice: OpenAIVoiceType,
                        responseFormat: OpenAIAudioResponseType,
                        speed: Double) async throws -> Data? {
        var endpoint = OpenAIEndpoints.createSpeech(model: model, input: input, voice: voice, responseFormat: responseFormat, speed: speed).endpoint
        api.routeEndpoint(&endpoint, environment: OpenAIEnvironmentV1())

        var urlRequest = api.buildURLRequest(endpoint: endpoint)
        api.addHeaders(urlRequest: &urlRequest,
                       headers: ["Content-Type": "application/json",
                                 "Authorization": "Bearer \(apiKey)"])

        let result = await api.execute(with: urlRequest)

        // The payload is handed back untouched, so no JSON decoder is needed here.
        switch result {
        case .success(let data):
            return data
        case .failure(let error):
            throw error
        }
    }
}
43 changes: 40 additions & 3 deletions Sources/SwiftOpenAI/OpenAI/SwiftOpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ protocol OpenAIProtocol {

func createImages(model: OpenAIImageModelType, prompt: String, numberOfImages: Int, size: ImageSize) async throws -> CreateImageDataModel?

func embeddings(model: OpenAIModelType,
input: String) async throws -> EmbeddingResponseDataModel?
func embeddings(model: OpenAIModelType, input: String) async throws -> EmbeddingResponseDataModel?

func moderations(input: String) async throws -> ModerationDataModel?

func createSpeech(model: OpenAITTSModelType, input: String, voice: OpenAIVoiceType, responseFormat: OpenAIAudioResponseType, speed: Double) async throws -> Data?
}

// swiftlint:disable line_length
Expand All @@ -35,6 +36,7 @@ public class SwiftOpenAI: OpenAIProtocol {
private let createImagesRequest: CreateImagesRequest.Init
private let embeddingsRequest: EmbeddingsRequest.Init
private let moderationsRequest: ModerationsRequest.Init
private let createSpeechRequest: CreateSpeechRequest.Init

public init(api: API = API(),
apiKey: String,
Expand All @@ -44,7 +46,8 @@ public class SwiftOpenAI: OpenAIProtocol {
createChatCompletionsStreamRequest: @escaping CreateChatCompletionsStreamRequest.Init = CreateChatCompletionsStreamRequest().execute,
createImagesRequest: @escaping CreateImagesRequest.Init = CreateImagesRequest().execute,
embeddingsRequest: @escaping EmbeddingsRequest.Init = EmbeddingsRequest().execute,
moderationsRequest: @escaping ModerationsRequest.Init = ModerationsRequest().execute) {
moderationsRequest: @escaping ModerationsRequest.Init = ModerationsRequest().execute,
createSpeechRequest: @escaping CreateSpeechRequest.Init = CreateSpeechRequest().execute) {
self.api = api
self.apiKey = apiKey
self.listModelsRequest = listModelsRequest
Expand All @@ -54,6 +57,7 @@ public class SwiftOpenAI: OpenAIProtocol {
self.createImagesRequest = createImagesRequest
self.embeddingsRequest = embeddingsRequest
self.moderationsRequest = moderationsRequest
self.createSpeechRequest = createSpeechRequest
}

/**
Expand Down Expand Up @@ -283,5 +287,38 @@ public class SwiftOpenAI: OpenAIProtocol {
public func moderations(input: String) async throws -> ModerationDataModel? {
    // Forward to the injected moderations request with the stored API client and key.
    let response = try await moderationsRequest(api, apiKey, input)
    return response
}

/**
 Generates speech audio from the given text by forwarding to the configured create-speech request.

 The call passes the stored `api` and `apiKey`, together with the arguments below, to the `createSpeechRequest` closure supplied at initialization.
 The method supports async/await calls.

 - Parameters:
   - model: An `OpenAITTSModelType` selecting the TTS model, for example `.tts(.tts1)`.
   - input: A `String` containing the text to be converted into speech.
   - voice: An `OpenAIVoiceType` specifying the voice for the generated speech.
   - responseFormat: An `OpenAIAudioResponseType` indicating the desired format of the audio response.
   - speed: A `Double` with the requested speech speed.

 - Throws: Whatever error the underlying create-speech request throws.

 - Returns: The optional `Data` produced by the create-speech request.

 Example usage:

     let inputText = "The quick brown fox jumped over the lazy dog."
     do {
         let audioData = try await createSpeech(model: .tts(.tts1), input: inputText, voice: .alloy, responseFormat: .mp3, speed: 1.0)
         // Save, play, or process the audio data as needed
     } catch {
         print("Error: \(error)")
     }
*/
public func createSpeech(model: OpenAITTSModelType, input: String, voice: OpenAIVoiceType, responseFormat: OpenAIAudioResponseType, speed: Double) async throws -> Data? {
    try await createSpeechRequest(api, apiKey, model, input, voice, responseFormat, speed)
}
}
// swiftlint:enable line_length
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import XCTest
@testable import SwiftOpenAI

/// Verifies that the create-speech endpoint exposes the expected path, method and parameters.
final class CreateSpeechEndpointSpec: XCTestCase {
    /// Builds the endpoint for a `tts-1` request and checks every parameter it exposes.
    func testEndpointCreateSpeech_WithModelTTS() throws {
        let model: OpenAITTSModelType = .tts(.tts1)
        let input = "Please create an audio with this input"
        let voice: OpenAIVoiceType = .alloy
        let responseFormat: OpenAIAudioResponseType = .mp3
        let speed = 1.0

        let sut = OpenAIEndpoints.createSpeech(
            model: model,
            input: input,
            voice: voice,
            responseFormat: responseFormat,
            speed: speed
        ).endpoint

        // XCTUnwrap turns a missing or mistyped parameter into a test failure
        // instead of crashing the whole test process.
        let parameters = try XCTUnwrap(sut.parameters)
        let modelParameter = try XCTUnwrap(parameters["model"] as? String)
        let inputParameter = try XCTUnwrap(parameters["input"] as? String)
        let voiceParameter = try XCTUnwrap(parameters["voice"] as? String)
        let responseFormatParameter = try XCTUnwrap(parameters["response_format"] as? String)
        let speedFormatParameter = try XCTUnwrap(parameters["speed"] as? Double)

        XCTAssertEqual(sut.path, "audio/speech")
        XCTAssertEqual(sut.method, .POST)
        XCTAssertEqual(parameters.count, 5)
        XCTAssertEqual(modelParameter, model.name)
        XCTAssertEqual(inputParameter, input)
        XCTAssertEqual(voiceParameter, voice.rawValue)
        XCTAssertEqual(responseFormatParameter, responseFormat.rawValue)
        XCTAssertEqual(speedFormatParameter, speed)
    }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import XCTest
@testable import SwiftOpenAI

/// Verifies the headers attached to a create-speech URL request.
final class CreateSpeechRequestSpec: XCTestCase {
    private let api = API()

    /// Builds the request the same way the production code does and inspects its header fields.
    func testRequest_CreatedWithCorrectHeaders() throws {
        let apiKey = "1234567890"
        let input = "Please create an audio with this input"
        let speed = 1.0

        var endpoint = OpenAIEndpoints.createSpeech(model: OpenAITTSModelType.tts(.tts1),
                                                    input: input,
                                                    voice: OpenAIVoiceType.alloy,
                                                    responseFormat: OpenAIAudioResponseType.mp3,
                                                    speed: speed).endpoint
        api.routeEndpoint(&endpoint, environment: OpenAIEnvironmentV1())

        let requestHeaders = ["Content-Type": "application/json",
                              "Authorization": "Bearer \(apiKey)"]
        var sut = api.buildURLRequest(endpoint: endpoint)
        api.addHeaders(urlRequest: &sut, headers: requestHeaders)

        let headerFields = sut.allHTTPHeaderFields
        XCTAssertEqual(headerFields?.count, 2)
        XCTAssertEqual(headerFields?["Content-Type"], "application/json")
        XCTAssertEqual(headerFields?["Authorization"], "Bearer 1234567890")
    }
}

0 comments on commit b37189e

Please sign in to comment.