Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Eleven and Fireworks support to ds_tool #31

Merged
merged 3 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions ultravox/tools/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from ultravox.tools import tts

chat_client = openai.Client()
tts_client = tts.AzureTts()
tts_client: tts.Client
chat_client: openai.Client

DEFAULT_TEXTGEN_TEMPLATE = """Passage: {passage}

Expand All @@ -22,14 +22,18 @@

@dataclasses.dataclass
class TtsTask:
implementation: str = simple_parsing.field(default="azure", alias="-i")
column_name: str = simple_parsing.field(default="question", alias="-c")
audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
sample_rate: int = simple_parsing.field(default=16000, alias="-r")

def __post_init__(self):
    """Derive the default audio column name and build the TTS client.

    The client lives at module scope rather than on the task instance so
    the task object stays picklable when `datasets` uses multiprocessing.
    """
    global tts_client
    if self.audio_column_name is None:
        self.audio_column_name = f"{self.column_name}_audio"
    tts_client = tts.create_client(self.implementation, self.sample_rate)

def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...')
Expand All @@ -40,7 +44,7 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas
def _map_sample(self, sample):
    """Synthesize audio for one dataset row, writing it to the audio column.

    The source column may hold either a plain string or a dict with a
    "text" key (e.g. a HF datasets text feature).
    """
    value = sample[self.column_name]
    if isinstance(value, dict):
        value = value["text"]
    sample[self.audio_column_name] = tts_client.tts(value, self.voice)
    return sample


Expand All @@ -50,10 +54,15 @@ class TextGenerationTask:
template: str = simple_parsing.field(default=DEFAULT_TEXTGEN_TEMPLATE, alias="-T")

language_model: str = simple_parsing.field(default="gpt-4o", alias="-m")
base_url: Optional[str] = simple_parsing.field(default=None, alias="-b")
api_key: Optional[str] = simple_parsing.field(default=None, alias="-k")
max_tokens: int = 128
temperature: float = 0

def __post_init__(self):
    """Build the chat client and resolve the prompt template.

    The OAI client is kept at module scope (not as a field) to avoid
    pickling issues when multiprocessing.
    """
    global chat_client
    chat_client = openai.Client(base_url=self.base_url, api_key=self.api_key)
    # A template of the form "@path" means: read the template text from that file.
    if self.template.startswith("@"):
        with open(self.template[1:], "r") as f:
            self.template = f.read()
Expand All @@ -75,9 +84,10 @@ def _map_sample(self, sample):


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# Ex: just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# Ex: just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/boolq-audio -c explanation
# Ex: just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
# Example usages:
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
@dataclasses.dataclass
class DatasetToolArgs:
dataset_name: str = simple_parsing.field(alias="-d")
Expand All @@ -88,7 +98,7 @@ class DatasetToolArgs:
num_workers: int = simple_parsing.field(default=16, alias="-w")

upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-b")
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N")
private: bool = simple_parsing.field(default=False)

Expand Down
132 changes: 119 additions & 13 deletions ultravox/tools/tts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import abc
import io
import os
from typing import Optional
from typing import Any, Dict, Optional
from xml.sax import saxutils

import numpy as np
import requests
import soundfile as sf

RANDOM_VOICE_KEY = "random"


def _make_ssml(voice: str, text: str):
return f"""
Expand All @@ -17,19 +20,68 @@ def _make_ssml(voice: str, text: str):
</speak>"""


class AzureTts:
DEFAULT_VOICE = "en-US-JennyNeural"

def __init__(self):
class Client(abc.ABC):
    """Base class for TTS clients: one shared HTTP session plus common
    PCM-to-WAV response handling.

    Subclasses implement `tts` and typically finish with
    `self._handle_pcm_response(self._post(...))`.
    """

    def __init__(self, sample_rate: int = 16000):
        self._session = requests.Session()
        # Sample rate used both when requesting audio and when writing the
        # WAV header in _handle_pcm_response; the two must agree.
        self._sample_rate = sample_rate

    @abc.abstractmethod
    def tts(self, text: str, voice: Optional[str] = None):
        """Synthesize `text` with the given voice; returns WAV bytes."""
        raise NotImplementedError

    def _post(self, url: str, headers: Dict[str, str], json: Any):
        """POST the payload and raise on HTTP errors.

        Bug fix: a raw str/bytes payload (e.g. Azure's SSML body) must be
        sent as the request body via `data=`. Passing it through `json=`
        would JSON-encode it into a quoted string, which the service
        rejects. Dict payloads are still sent as JSON.
        """
        if isinstance(json, (str, bytes)):
            response = self._session.post(url, headers=headers, data=json)
        else:
            response = self._session.post(url, headers=headers, json=json)
        response.raise_for_status()
        return response

    def _handle_pcm_response(self, response: requests.Response):
        """Convert a raw 16-bit mono PCM response body into WAV bytes at
        self._sample_rate."""
        pcm_array = np.frombuffer(response.content, dtype=np.int16)
        wav_bytes = io.BytesIO()
        sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav")
        return wav_bytes.getvalue()


class AzureTts(Client):
DEFAULT_VOICE = "en-US-JennyNeural"
ALL_VOICES = [
"en-US-AvaNeural",
"en-US-AndrewNeural",
"en-US-EmmaNeural",
"en-US-BrianNeural",
"en-US-JennyNeural",
"en-US-GuyNeural",
"en-US-AriaNeural",
"en-US-DavisNeural",
"en-US-JaneNeural",
"en-US-JasonNeural",
"en-US-SaraNeural",
"en-US-TonyNeural",
"en-US-NancyNeural",
"en-US-AmberNeural",
"en-US-AnaNeural",
"en-US-AshleyNeural",
"en-US-BrandonNeural",
"en-US-ChristopherNeural",
"en-US-CoraNeural",
"en-US-ElizabethNeural",
"en-US-EricNeural",
"en-US-JacobNeural",
"en-US-MichelleNeural",
"en-US-MonicaNeural",
"en-US-RogerNeural",
]

def tts(self, text: str, voice: Optional[str] = None):
voice = voice or self.DEFAULT_VOICE
if voice == RANDOM_VOICE_KEY:
voice = np.random.choice(self.ALL_VOICES)
assert voice
region = "westus"
api_key = os.environ.get("AZURE_TTS_API_KEY") or os.environ.get(
"AZURE_WESTUS_TTS_API_KEY"
)
output_format = f"raw-{sample_rate // 1000}khz-16bit-mono-pcm"
assert api_key, "Please set the AZURE_TTS_API_KEY environment variable."
output_format = f"raw-{self._sample_rate // 1000}khz-16bit-mono-pcm"
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = {
"Ocp-Apim-Subscription-Key": api_key,
Expand All @@ -38,10 +90,64 @@ def tts(self, text: str, voice: Optional[str] = None, sample_rate: int = 16000):
"User-Agent": "MyTTS",
}
body = _make_ssml(voice, text)
response = self._session.post(url, headers=headers, data=body)
response.raise_for_status()
return self._handle_pcm_response(self._post(url, headers, body))

pcm_array = np.frombuffer(response.content, dtype=np.int16)
wav_bytes = io.BytesIO()
sf.write(wav_bytes, pcm_array, sample_rate, format="wav")
return wav_bytes.getvalue()

class ElevenTts(Client):
    """ElevenLabs TTS client using the streaming PCM endpoint."""

    DEFAULT_VOICE = "21m00Tcm4TlvDq8ikWAM"
    DEFAULT_MODEL = "eleven_multilingual_v2"
    # Voice ids from the ElevenLabs premade voice library.
    ALL_VOICES = [
        "21m00Tcm4TlvDq8ikWAM",
        "29vD33N1CtxCmqQRPOHJ",
        "2EiwWnXFnvU5JabPnv8n",
        "5Q0t7uMcjvnagumLfvZi",
        "AZnzlk1XvdvUeBnXmlld",
        "CYw3kZ02Hs0563khs1Fj",
        "D38z5RcWu1voky8WS1ja",
        "EXAVITQu4vr4xnSDxMaL",
        "ErXwobaYiN019PkySvjV",
        "GBv7mTt0atIp3Br8iCZE",
        "IKne3meq5aSn9XLyUdCD",
        "JBFqnCBsd6RMkjVDRZzb",
        "LcfcDJNUP1GQjkzn1xUU",
        "MF3mGyEYCl7XYWbV9V6O",
        "N2lVS1w4EtoT3dr4eOWO",
        "ODq5zmih8GrVes37Dizd",
        "SOYHLrjzK2X1ezoPC6cr",
        "TX3LPaxmHKxFdv7VOQHJ",
        "ThT5KcBeYPX3keUQqHPh",
        "TxGEqnHWrfWFTfGW9XjX",
        "VR6AewLTigWG4xSOukaG",
        "XB0fDUnXU5powFXDhCwa",
        "Xb7hH8MSUJpSbSDYk0k2",
        "XrExE9yKIg1WjnnlVkGX",
        "ZQe5CZNOzWyzPSCn5a3c",
        "Zlb1dXrM653N07WRdFW3",
    ]

    def tts(self, text: str, voice: Optional[str] = None):
        """Synthesize `text` and return WAV bytes.

        voice: an ElevenLabs voice id, RANDOM_VOICE_KEY for a random pick,
        or None to use DEFAULT_VOICE. Requires the ELEVEN_API_KEY env var.
        """
        voice = voice or self.DEFAULT_VOICE
        if voice == RANDOM_VOICE_KEY:
            # Every process has same random seed, so we mix in the PID here for more variation.
            i = np.random.randint(len(self.ALL_VOICES)) + os.getpid()
            voice = self.ALL_VOICES[i % len(self.ALL_VOICES)]
        # Bug fix: request PCM at the configured rate instead of hard-coded
        # pcm_16000. _handle_pcm_response writes the WAV header at
        # self._sample_rate, so a mismatch would yield wrong-speed audio.
        # (Also dropped a leftover debug print of the URL.)
        url = (
            f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream"
            f"?output_format=pcm_{self._sample_rate}"
        )
        headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]}
        body = {
            "text": text,
            "model_id": self.DEFAULT_MODEL,
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": False,
            },
        }
        return self._handle_pcm_response(self._post(url, headers, body))


def create_client(implementation: str, sample_rate: int):
    """Build a TTS client for the named implementation.

    implementation: "azure" or "eleven".
    sample_rate: output sample rate in Hz, forwarded to the client.
    Raises ValueError for any other implementation name.
    """
    if implementation == "azure":
        return AzureTts(sample_rate=sample_rate)
    if implementation == "eleven":
        return ElevenTts(sample_rate=sample_rate)
    raise ValueError(f"Unknown TTS implementation: {implementation}")
Loading