Skip to content

Commit

Permalink
Extending ds_tool for SODA conversational dataset (#32)
Browse files Browse the repository at this point in the history
* extending ds_tool for soda conversational dataset

* caching for TTS generations

* dataset_creation scripts

* escape quotes for SODA template
  • Loading branch information
farzadab authored Jun 24, 2024
1 parent 4202b56 commit cbe5f6e
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
runs/
outputs/
wandb/
*.parquet

artifacts/

Expand Down
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ eval *FLAGS:
poetry run python -m ultravox.tools.eval_tool {{FLAGS}}

tts *FLAGS:
poetry run python -m ultravox.tools.ds_tool tts {{FLAGS}}
poetry run python -m ultravox.tools.ds_tool.ds_tool tts {{FLAGS}}

ds_tool *FLAGS:
poetry run python -m ultravox.tools.ds_tool {{FLAGS}}
poetry run python -m ultravox.tools.ds_tool.ds_tool {{FLAGS}}

mds *FLAGS:
poetry run python -m ultravox.tools.mds_tool {{FLAGS}}
Expand Down
12 changes: 12 additions & 0 deletions scripts/dataset_creation/boolq_audio.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
# Steps to reproduce the fixie-ai/boolq-audio dataset

# Step 0: create the `fixie-ai/boolq-audio` dataset in the UI or using huggingface_hub.create_repo

# Step 1: Create a plausible explanation for the answer
# This explanation is only used in the `-extended` version of the dataset and is used mainly for better training.
just ds_tool textgen -d google/boolq -u fixie-ai/boolq-audio -c explanation -T @ultravox/tools/ds_tool/boolq_template.jinja --token $HF_WRITE_TOKEN -N 8

# Step 2: TTS the question into the audio input for the model
# Note: the original dataset was not created using this script. This is just an example of how to create the audio version of the dataset
just ds_tool tts -d fixie-ai/boolq-audio -u fixie-ai/boolq-audio -c question -a audio -i azure --token $HF_WRITE_TOKEN -N 8
24 changes: 24 additions & 0 deletions scripts/dataset_creation/soda_audio.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
# Steps to reproduce the fixie-ai/soda-audio dataset

# Step 0: create the `fixie-ai/soda-audio` dataset in the UI or using huggingface_hub.create_repo

# Step 1: Create an alternative last turn using Llama3-8b Instruct model
# We want the model to generate the same response whether the input is audio or text

just ds_tool textgen -d allenai/soda --shuffle True -s test -n 1000 -u fixie-ai/soda-audio -c alt_last_turn \
-T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \
-k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN

just ds_tool textgen -d allenai/soda --shuffle True -s validation -n 1000 -u fixie-ai/soda-audio -c alt_last_turn \
-T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \
-k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN

just ds_tool textgen -d allenai/soda --shuffle True -s train -n 100000 -u fixie-ai/soda-audio -c alt_last_turn \
-T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \
-k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN


# Step 2: TTS the turn before last: audio input for the model

just ds_tool tts -d fixie-ai/soda-audio -u fixie-ai/soda-audio -c "dialogue[-2]" -a audio_second_last_turn -i eleven -V random --token $HF_WRITE_TOKEN -N 8
8 changes: 8 additions & 0 deletions ultravox/tools/ds_tool/boolq_explanation.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{# This template was used to create the synthetic "explanation" field in `fixie-ai/boolq-audio` -#}
Passage: {{ passage }}

Question: {{ question }}

Answer: {{ answer }}

Provide a short explanation to the question given the passage that provides a rationale for the answer.
61 changes: 61 additions & 0 deletions ultravox/tools/ds_tool/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import hashlib
import json
import os
from typing import Optional

import openai

from ultravox.tools.ds_tool import tts


class CachingChatWrapper:
def __init__(self, client: openai.Client, unique_id: str):
super().__init__()
self._client = client
self._base_path = os.path.join(
".cache/ds_tool/textgen",
unique_id.replace("https://", "").replace("/", "__"),
)
os.makedirs(self._base_path, exist_ok=True)

def chat_completion(self, **kwargs) -> str:
text_hash = hashlib.sha256(json.dumps(kwargs).encode()).hexdigest()

cache_path = os.path.join(self._base_path, f"{text_hash}.txt")

if os.path.exists(cache_path):
with open(cache_path, "r") as f:
return f.read()

response = self._client.chat.completions.create(**kwargs)
text = response.choices[0].message.content

with open(cache_path, "w") as f:
f.write(text)

return text


class CachingTtsWrapper:
def __init__(self, client: tts.Client, provider: str):
super().__init__()
self._client = client
self._base_path = os.path.join(".cache/ds_tool/tts", provider)

def tts(self, text: str, voice: Optional[str] = None) -> bytes:
path = os.path.join(self._base_path, voice or "default")
text_hash = hashlib.sha256(text.encode()).hexdigest()
os.makedirs(path, exist_ok=True)

cache_path = os.path.join(path, f"{text_hash}.wav")

if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
return f.read()

wav = self._client.tts(text, voice)

with open(cache_path, "wb") as f:
f.write(wav)

return wav
65 changes: 43 additions & 22 deletions ultravox/tools/ds_tool.py → ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import dataclasses
import json
import os
from typing import Any, Dict, Optional, Union

import datasets
import jinja2
import openai
import simple_parsing

from ultravox.tools import tts
from ultravox.tools.ds_tool import caching
from ultravox.tools.ds_tool import tts

tts_client: tts.Client
chat_client: openai.Client

DEFAULT_TEXTGEN_TEMPLATE = """Passage: {passage}
Question: {question}
Answer: {answer}
Provide a short explanation to the question given the passage that provides a rationale for the answer."""
tts_client: caching.CachingTtsWrapper
chat_client: caching.CachingChatWrapper


@dataclasses.dataclass
class TtsTask:
implementation: str = simple_parsing.field(default="azure", alias="-i")
# Column name containing the text to convert to audio. It can be a Jinja variable expression.
column_name: str = simple_parsing.field(default="question", alias="-c")
audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
Expand All @@ -33,7 +29,10 @@ def __post_init__(self):
global tts_client
if self.audio_column_name is None:
self.audio_column_name = f"{self.column_name}_audio"
tts_client = tts.create_client(self.implementation, self.sample_rate)
tts_client = caching.CachingTtsWrapper(
tts.create_client(self.implementation, self.sample_rate),
provider=self.implementation,
)

def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...')
Expand All @@ -42,16 +41,19 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas
)

def _map_sample(self, sample):
text = sample[self.column_name]
# using a Jinja template for some added flexibility
# The {{ var }} syntax is how Jinja denotes variables
text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
text = text["text"] if isinstance(text, dict) else text
sample[self.audio_column_name] = tts_client.tts(text, self.voice)
return sample


@dataclasses.dataclass
class TextGenerationTask:
new_column_name: str = simple_parsing.field(default="explanation", alias="-c")
template: str = simple_parsing.field(default=DEFAULT_TEXTGEN_TEMPLATE, alias="-T")
new_column_name: str = simple_parsing.field(alias="-c")
template: str = simple_parsing.field(alias="-T")
json_mode: bool = simple_parsing.field(default=False, alias="-j")

language_model: str = simple_parsing.field(default="gpt-4o", alias="-m")
base_url: Optional[str] = simple_parsing.field(default=None, alias="-b")
Expand All @@ -62,7 +64,11 @@ class TextGenerationTask:
def __post_init__(self):
# The OAI client is separate from the task to avoid pickling issues when multiprocessing.
global chat_client
chat_client = openai.Client(base_url=self.base_url, api_key=self.api_key)
# Caching the client to avoid repeated calls to the API if the tool fails.
chat_client = caching.CachingChatWrapper(
openai.Client(base_url=self.base_url, api_key=self.api_key),
unique_id=f"{self.base_url}__{self.language_model}",
)
if self.template.startswith("@"):
with open(self.template[1:], "r") as template_file:
self.template = template_file.read()
Expand All @@ -72,19 +78,27 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas
return ds_split.map(self._map_sample, num_proc=num_proc)

def _map_sample(self, sample):
input_text = self.template.format(**sample)
response = chat_client.chat.completions.create(
rendered = jinja2.Template(self.template).render(**sample, json_dump=json.dumps)

if self.json_mode:
turns = json.loads(rendered)
assert isinstance(turns, list)
assert all(isinstance(turn, dict) for turn in turns)
assert len(turns) > 0
assert turns[-1].get("role", None) == "user"
else:
turns = [{"role": "user", "content": rendered}]

sample[self.new_column_name] = chat_client.chat_completion(
model=self.language_model,
messages=[{"role": "user", "content": input_text}],
messages=turns,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
sample[self.new_column_name] = response.choices[0].message.content
return sample


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# Example usages:
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
Expand All @@ -94,6 +108,8 @@ class DatasetToolArgs:
dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S")
dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")

shuffle: bool = simple_parsing.field(default=False)
shuffle_seed: int = simple_parsing.field(default=42)
num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
num_workers: int = simple_parsing.field(default=16, alias="-w")

Expand Down Expand Up @@ -122,6 +138,8 @@ def main(args: DatasetToolArgs):

for split, ds_split in data_dict.items():
print(f'Processing split "{split}"...')
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
data_dict[split] = args.task.map_split(ds_split, args.num_workers)
Expand All @@ -138,7 +156,9 @@ def main(args: DatasetToolArgs):

try:
if args.dataset_split:
data_dict[args.dataset_split].push_to_hub(args.upload_name, **hub_args)
data_dict[args.dataset_split].push_to_hub(
args.upload_name, split=args.dataset_split, **hub_args
)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
except Exception as e:
Expand All @@ -149,6 +169,7 @@ def main(args: DatasetToolArgs):
output_name = f"{split}-00000-of-00001.parquet"
data_dict[split].to_parquet(output_name)
print(f"Saved to {output_name}")
print(f"Sample {0} of {split}: {data_dict[split][0]}")


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions ultravox/tools/ds_tool/soda_alt_last_turn.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{# This template was used to create the synthetic "alt_last_turn" field (Llama3-8B alternative to the last turn) in `fixie-ai/soda-audio` -#}
[
{ "role": "system", "content": "Follow the flow of the conversation and respond just like a human would in the same situation."},
{% for turn in dialogue[:-1] %}
{ "role": {% if loop.revindex0 % 2 == 0 %} "user" {% else %} "assistant" {% endif %}, "content": {{ json_dump(turn) }} }
{%- if not loop.last %},{% endif %}
{% endfor %}
]
26 changes: 26 additions & 0 deletions ultravox/tools/ds_tool/template_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json

import jinja2


def test_quotes():
with open("tools/ds_tool/soda_alt_last_turn.jinja", "r") as template_file:
template = template_file.read()

dialogue = [
'Have you ever used a double quote (")',
"Of course, what about a single quote (')?",
'"Yes, I have."',
"last turn is ignored!",
]

messages = json.loads(
jinja2.Template(template).render(dialogue=dialogue, json_dump=json.dumps)
)
assert isinstance(messages, list)
assert all(isinstance(turn, dict) for turn in messages)
assert messages[-1]["role"] == "user"

assert len(messages) == 4
assert messages[0]["role"] == "system"
assert [x["content"] for x in messages[1:]] == dialogue[:-1]
5 changes: 2 additions & 3 deletions ultravox/tools/tts.py → ultravox/tools/ds_tool/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def __init__(self, sample_rate: int = 16000):
self._sample_rate = sample_rate

@abc.abstractmethod
def tts(self, text: str, voice: Optional[str] = None):
def tts(self, text: str, voice: Optional[str] = None) -> bytes:
raise NotImplementedError

def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]):
response = self._session.post(url, headers=headers, json=json)
response.raise_for_status()
return response

def _handle_pcm_response(self, response: requests.Response):
def _handle_pcm_response(self, response: requests.Response) -> bytes:
pcm_array = np.frombuffer(response.content, dtype=np.int16)
wav_bytes = io.BytesIO()
sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav")
Expand Down Expand Up @@ -132,7 +132,6 @@ def tts(self, text: str, voice: Optional[str] = None):
i = np.random.randint(len(self.ALL_VOICES)) + os.getpid()
voice = self.ALL_VOICES[i % len(self.ALL_VOICES)]
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000"
print("url", url)
headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]}
body = {
"text": text,
Expand Down

0 comments on commit cbe5f6e

Please sign in to comment.