Extending ds_tool for SODA conversational dataset #32

Merged: 29 commits (Jun 24, 2024)

Changes from 8 commits

Commits (29)
6dc479f
extending ds_tool for soda conversational dataset
farzadab Jun 15, 2024
a08671d
add jmespath to requirements
farzadab Jun 15, 2024
bf1c9d8
formatting
farzadab Jun 15, 2024
94c019b
Merge remote-tracking branch 'origin/main' into farzad-soda
farzadab Jun 17, 2024
adf6f09
caching for TTS generations
farzadab Jun 19, 2024
d40ccc9
ignore parquet files
farzadab Jun 19, 2024
8b0cdf9
Jinja template support
farzadab Jun 19, 2024
23f6a2d
cleaning up templates + add README for documenting steps
farzadab Jun 19, 2024
bbd1457
remove unneeded jmespath
farzadab Jun 19, 2024
7ea545b
revert exact poetry.lock
farzadab Jun 19, 2024
5372282
newline at file end
farzadab Jun 19, 2024
37ade43
fix comment
farzadab Jun 19, 2024
fa24f7c
applying comments: minor fixes
farzadab Jun 21, 2024
b028473
dataset_creation scripts
farzadab Jun 21, 2024
10d74ca
improvement comment
farzadab Jun 21, 2024
2e3bf75
improve comments
farzadab Jun 21, 2024
24bfaac
revert to original upload split code with fix
farzadab Jun 21, 2024
6cde01f
json_mode for ds_tool
farzadab Jun 21, 2024
e9d4f08
fix json_mode in soda creation docs
farzadab Jun 21, 2024
f083a3c
escape quotes for SODA template
farzadab Jun 21, 2024
85ba7bd
cache for textgen
farzadab Jun 21, 2024
02a04de
moving both wrappers into the same place
farzadab Jun 21, 2024
1a32b14
move escape_quotes inside the jinja template
farzadab Jun 21, 2024
7714045
test for double quotes in SODA template
farzadab Jun 21, 2024
140e0bf
formatting
farzadab Jun 21, 2024
0b79843
using json.dumps for better JSON escaping
farzadab Jun 21, 2024
0ecf472
move ds_tool and related files into their own ds_tool folder
farzadab Jun 21, 2024
8cc4f45
add language_model name to cache unique_id
farzadab Jun 21, 2024
117a598
rename: audio_one_but_last -> audio_second_last_turn
farzadab Jun 24, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
runs/
outputs/
wandb/
*.parquet

artifacts/

13 changes: 1 addition & 12 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ openai = "~1.33.0"
jiwer = "~3.0.4"
tensorboardx = "~2.6.2.2"
wandb = "~0.17.1"
jmespath = "^1.0.1"

[tool.poetry.group.dev.dependencies]
black = "~24.4.2"
33 changes: 33 additions & 0 deletions ultravox/tools/ds_templates/README.md
@@ -0,0 +1,33 @@
# Synthetically enhanced datasets

## BoolQ

```bash
just ds_tool textgen -d google/boolq -u fixie-ai/boolq-audio -c explanation -T @ultravox/tools/ds_templates/boolq_template.jinja --token $HF_WRITE_TOKEN
```
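
As a rough sketch of what this textgen step does per row (not part of the PR; the row values below are made up, and the standalone script is illustrative rather than ds_tool itself): the Jinja template is rendered against the row's fields, the result is sent to the chat model as a single user message, and the completion is stored in the new `explanation` column.

```python
# Illustrative sketch of the BoolQ textgen step (not ds_tool itself).
import jinja2
import openai

sample = {
    "question": "can you see the great wall of china from space",
    "passage": "The Great Wall of China is only faintly visible from low Earth orbit ...",
    "answer": False,
}

# Render the prompt template against one dataset row.
template_str = open("ultravox/tools/ds_templates/boolq_template.jinja").read()
prompt = jinja2.Template(template_str).render(**sample)

# Ask the language model for the synthetic "explanation" column.
client = openai.Client()  # assumes OPENAI_API_KEY is set
response = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": prompt}],
    max_tokens=128,
    temperature=0,
)
sample["explanation"] = response.choices[0].message.content
```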

## SODA

The SODA dataset was slightly modified as follows to create the `fixie-ai/soda-audio` dataset, which is used for training a voice-text model:

### Alternative last turn (Llama3-8b)

```bash
just ds_tool textgen -d allenai/soda --shuffle True -s test -n 1000 -u fixie-ai/soda -c alt_last_turn \
-T @ultravox/tools/ds_templates/soda_template.jinja -b https://api.fireworks.ai/inference/v1 \
-k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN

just ds_tool textgen -d allenai/soda --shuffle True -s validation -n 1000 -u fixie-ai/soda -c alt_last_turn \
-T @ultravox/tools/ds_templates/soda_template.jinja -b https://api.fireworks.ai/inference/v1 \
-k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN

just ds_tool textgen -d allenai/soda --shuffle True -s train -n 100000 -u fixie-ai/soda -c alt_last_turn \
-T @ultravox/tools/ds_templates/soda_template.jinja -b https://api.fireworks.ai/inference/v1 \
-k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN
```
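
To make the template's role concrete, here is a small sketch (toy dialogue, not from the dataset) of how `soda_template.jinja` renders a row into a JSON list of chat turns: the original last turn is dropped and the rendered conversation ends on a `user` turn, so the model's reply becomes the synthetic `alt_last_turn` column.

```python
# Sketch of how the SODA template turns a row into chat messages (toy dialogue).
import json
import jinja2

sample = {"dialogue": ["Hey, how was the trip?", "Long, but worth it.", "Glad to hear it!"]}

template_str = open("ultravox/tools/ds_templates/soda_template.jinja").read()
turns = json.loads(jinja2.Template(template_str).render(**sample))

# The final original turn is dropped; the rendered conversation ends on a
# "user" turn, so the LLM completion becomes the alternative last turn.
assert turns[-1]["role"] == "user"
print(turns)
```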

### TTS for the turn before last

```bash
just ds_tool tts -d fixie-ai/soda -u fixie-ai/soda-copy -c "dialogue[-2]" -a audio_one_but_last -i eleven -V random --token $HF_WRITE_TOKEN
```
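
A brief sketch of how the `-c "dialogue[-2]"` argument is resolved, based on the `TtsTask` change in this PR (the sample row is made up): the column argument is wrapped in a Jinja template and rendered against each row, so any Jinja expression over the row's fields works, not only a plain column name.

```python
# Sketch of how the -c "dialogue[-2]" column expression is evaluated per row.
import jinja2

sample = {"dialogue": ["Hey, how was the trip?", "Long, but worth it.", "Glad to hear it!"]}
column_name = "dialogue[-2]"

text = jinja2.Template("{{" + column_name + "}}").render(**sample)
print(text)  # "Long, but worth it." -- this string is what gets sent to the TTS client
```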
8 changes: 8 additions & 0 deletions ultravox/tools/ds_templates/boolq_template.jinja
@@ -0,0 +1,8 @@
{# This template was used to create the synthetic "explanation" field in `fixie-ai/boolq-audio` -#}
Passage: {{ passage }}

Question: {{ question }}

Answer: {{ answer }}

Provide a short explanation to the question given the passage that provides a rationale for the answer.
8 changes: 8 additions & 0 deletions ultravox/tools/ds_templates/soda_template.jinja
@@ -0,0 +1,8 @@
{# This template was used to create the synthetic "alt_last_turn" field (Llama3-8B alternative to the last turn) in `fixie-ai/soda-audio` -#}
[
{ "role": "system", "content": "Follow the flow of the conversation and respond just like a human would in the same situation."},
{% for turn in dialogue[:-1] %}
{ "role": {% if loop.revindex0 % 2 == 0 %} "user" {% else %} "assistant" {% endif %}, "content": "{{ turn }}" }
{%- if not loop.last %},{% endif %}
{% endfor %}
]
49 changes: 32 additions & 17 deletions ultravox/tools/ds_tool.py
@@ -1,8 +1,10 @@
import dataclasses
import json
import os
from typing import Any, Dict, Optional, Union

import datasets
import jinja2
import openai
import simple_parsing

@@ -11,18 +13,11 @@
tts_client: tts.Client
chat_client: openai.Client

DEFAULT_TEXTGEN_TEMPLATE = """Passage: {passage}

Question: {question}

Answer: {answer}

Provide a short explanation to the question given the passage that provides a rationale for the answer."""


@dataclasses.dataclass
class TtsTask:
implementation: str = simple_parsing.field(default="azure", alias="-i")
# Column name containing the text to convert to audio. It can be a JMESPath expression.
column_name: str = simple_parsing.field(default="question", alias="-c")
audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
@@ -42,16 +37,18 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas
)

def _map_sample(self, sample):
text = sample[self.column_name]
# using a Jinja template for some added flexibility
# The {{ var }} syntax is how Jinja denotes variables
text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
text = text["text"] if isinstance(text, dict) else text
sample[self.audio_column_name] = tts_client.tts(text, self.voice)
return sample


@dataclasses.dataclass
class TextGenerationTask:
new_column_name: str = simple_parsing.field(default="explanation", alias="-c")
template: str = simple_parsing.field(default=DEFAULT_TEXTGEN_TEMPLATE, alias="-T")
new_column_name: str = simple_parsing.field(alias="-c")
template: str = simple_parsing.field(alias="-T")

language_model: str = simple_parsing.field(default="gpt-4o", alias="-m")
base_url: Optional[str] = simple_parsing.field(default=None, alias="-b")
@@ -72,10 +69,20 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas
return ds_split.map(self._map_sample, num_proc=num_proc)

def _map_sample(self, sample):
input_text = self.template.format(**sample)
rendered = jinja2.Template(self.template).render(**sample)

try:
turns = json.loads(rendered)
assert isinstance(turns, list)
assert all(isinstance(turn, dict) for turn in turns)
assert len(turns) > 0
assert turns[-1].get("role", None) == "user"
except json.JSONDecodeError:
turns = [{"role": "user", "content": rendered}]

response = chat_client.chat.completions.create(
model=self.language_model,
messages=[{"role": "user", "content": input_text}],
messages=turns,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
@@ -84,7 +91,6 @@ def _map_sample(self, sample):


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# Example usages:
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
@@ -94,6 +100,8 @@ class DatasetToolArgs:
dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S")
dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")

shuffle: bool = simple_parsing.field(default=False)
shuffle_seed: int = simple_parsing.field(default=42)
num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
num_workers: int = simple_parsing.field(default=16, alias="-w")

@@ -122,6 +130,8 @@ def main(args: DatasetToolArgs):

for split, ds_split in data_dict.items():
print(f'Processing split "{split}"...')
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
data_dict[split] = args.task.map_split(ds_split, args.num_workers)
@@ -138,9 +148,13 @@ def main(args: DatasetToolArgs):

try:
if args.dataset_split:
data_dict[args.dataset_split].push_to_hub(args.upload_name, **hub_args)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
# load the full dataset, otherwise the existing splits will be overwritten
upload_ds = datasets.load_dataset(args.upload_name, args.dataset_subset)
for split in data_dict.keys():
upload_ds[split] = data_dict[split]
data_dict = upload_ds

data_dict.push_to_hub(args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

@@ -149,6 +163,7 @@ def main(args: DatasetToolArgs):
output_name = f"{split}-00000-of-00001.parquet"
data_dict[split].to_parquet(output_name)
print(f"Saved to {output_name}")
print(f"Sample {0} of {split}: {data_dict[split][0]}")


if __name__ == "__main__":
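
For reference, a minimal sketch of what the new `--shuffle` / `-n` path in `main` amounts to per split (dataset name and counts mirror the README commands above):

```python
# Minimal sketch of the new --shuffle / -n behaviour added in main():
# shuffling happens before sampling, so -n selects a random subset of rows
# rather than the first N.
import datasets

ds_split = datasets.load_dataset("allenai/soda", split="test")
ds_split = ds_split.shuffle(seed=42)     # args.shuffle_seed defaults to 42
ds_split = ds_split.select(range(1000))  # args.num_samples, i.e. -n 1000
```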
39 changes: 35 additions & 4 deletions ultravox/tools/tts.py
@@ -1,5 +1,7 @@
import abc
import hashlib
import io
import logging
import os
from typing import Any, Dict, Optional
from xml.sax import saxutils
@@ -132,7 +134,7 @@ def tts(self, text: str, voice: Optional[str] = None):
i = np.random.randint(len(self.ALL_VOICES)) + os.getpid()
voice = self.ALL_VOICES[i % len(self.ALL_VOICES)]
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000"
print("url", url)
logging.debug(f"url {url}")
headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]}
body = {
"text": text,
@@ -145,9 +147,38 @@ def tts(self, text: str, voice: Optional[str] = None):
return self._handle_pcm_response(self._post(url, headers, body))


class CachedClientWrapper:
def __init__(self, client: Client, provider: str):
super().__init__()
self._client = client
self._base_path = os.path.join(".cache/ds_tool", provider)

def tts(self, text: str, voice: Optional[str] = None):
path = os.path.join(self._base_path, voice or "default")
text_hash = hashlib.sha256(text.encode()).hexdigest()
os.makedirs(path, exist_ok=True)

cache_path = os.path.join(path, f"{text_hash}.wav")

if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
return f.read()

pcm = self._client.tts(text, voice)

with open(cache_path, "wb") as f:
f.write(pcm)

return pcm


def create_client(implementation: str, sample_rate: int):
client: Client
if implementation == "azure":
return AzureTts(sample_rate=sample_rate)
client = AzureTts(sample_rate=sample_rate)
elif implementation == "eleven":
return ElevenTts(sample_rate=sample_rate)
raise ValueError(f"Unknown TTS implementation: {implementation}")
client = ElevenTts(sample_rate=sample_rate)
else:
raise ValueError(f"Unknown TTS implementation: {implementation}")

return CachedClientWrapper(client, provider=implementation)
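
To illustrate the new caching layer, a hedged usage sketch (the import path is assumed from this repo's layout, and `ELEVEN_API_KEY` must be set in the environment):

```python
# Hedged usage sketch of the cached TTS client added above.
from ultravox.tools import tts  # import path assumed from this repo's layout

client = tts.create_client("eleven", sample_rate=16000)  # returns a CachedClientWrapper
pcm = client.tts("Long, but worth it.")         # first call hits the ElevenLabs API
pcm_cached = client.tts("Long, but worth it.")  # repeat call is read back from
                                                # .cache/ds_tool/eleven/default/<sha256>.wav
```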