Commit
Extending `ds_tool` for SODA conversational dataset (#32)
* extending ds_tool for soda conversational dataset
* caching for TTS generations
* dataset_creation scripts
* escape quotes for SODA template
Showing 10 changed files with 187 additions and 27 deletions.
@@ -4,6 +4,7 @@
runs/
outputs/
wandb/
*.parquet

artifacts/
@@ -0,0 +1,12 @@
#!/bin/bash
# Steps to reproduce the fixie-ai/boolq-audio dataset

# Step 0: create the `fixie-ai/boolq-audio` dataset in the UI or using huggingface_hub.create_repo

# Step 1: Create a plausible explanation for the answer
# This explanation is only used in the `-extended` version of the dataset and is used mainly for better training.
just ds_tool textgen -d google/boolq -u fixie-ai/boolq-audio -c explanation -T @ultravox/tools/ds_tool/boolq_template.jinja --token $HF_WRITE_TOKEN -N 8

# Step 2: TTS the question into the audio input for the model
# Note: the original dataset was not created using this script. This is just an example of how to create the audio version of the dataset
just ds_tool tts -d fixie-ai/boolq-audio -u fixie-ai/boolq-audio -c question -a audio -i azure --token $HF_WRITE_TOKEN -N 8
@@ -0,0 +1,24 @@
#!/bin/bash
# Steps to reproduce the fixie-ai/soda-audio dataset

# Step 0: create the `fixie-ai/soda-audio` dataset in the UI or using huggingface_hub.create_repo

# Step 1: Create an alternative last turn using Llama3-8b Instruct model
# We want the model to generate the same response whether the input is audio or text

just ds_tool textgen -d allenai/soda --shuffle True -s test -n 1000 -u fixie-ai/soda-audio -c alt_last_turn \
    -T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \
    -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN

just ds_tool textgen -d allenai/soda --shuffle True -s validation -n 1000 -u fixie-ai/soda-audio -c alt_last_turn \
    -T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \
    -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN

just ds_tool textgen -d allenai/soda --shuffle True -s train -n 100000 -u fixie-ai/soda-audio -c alt_last_turn \
    -T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \
    -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN


# Step 2: TTS the turn before last: audio input for the model

just ds_tool tts -d fixie-ai/soda-audio -u fixie-ai/soda-audio -c "dialogue[-2]" -a audio_second_last_turn -i eleven -V random --token $HF_WRITE_TOKEN -N 8
@@ -0,0 +1,8 @@
{# This template was used to create the synthetic "explanation" field in `fixie-ai/boolq-audio` -#}
Passage: {{ passage }}

Question: {{ question }}

Answer: {{ answer }}

Provide a short explanation to the question given the passage that provides a rationale for the answer.
@@ -0,0 +1,61 @@
import hashlib
import json
import os
from typing import Optional

import openai

from ultravox.tools.ds_tool import tts


class CachingChatWrapper:
    def __init__(self, client: openai.Client, unique_id: str):
        super().__init__()
        self._client = client
        self._base_path = os.path.join(
            ".cache/ds_tool/textgen",
            unique_id.replace("https://", "").replace("/", "__"),
        )
        os.makedirs(self._base_path, exist_ok=True)

    def chat_completion(self, **kwargs) -> str:
        text_hash = hashlib.sha256(json.dumps(kwargs).encode()).hexdigest()

        cache_path = os.path.join(self._base_path, f"{text_hash}.txt")

        if os.path.exists(cache_path):
            with open(cache_path, "r") as f:
                return f.read()

        response = self._client.chat.completions.create(**kwargs)
        text = response.choices[0].message.content

        with open(cache_path, "w") as f:
            f.write(text)

        return text


class CachingTtsWrapper:
    def __init__(self, client: tts.Client, provider: str):
        super().__init__()
        self._client = client
        self._base_path = os.path.join(".cache/ds_tool/tts", provider)

    def tts(self, text: str, voice: Optional[str] = None) -> bytes:
        path = os.path.join(self._base_path, voice or "default")
        text_hash = hashlib.sha256(text.encode()).hexdigest()
        os.makedirs(path, exist_ok=True)

        cache_path = os.path.join(path, f"{text_hash}.wav")

        if os.path.exists(cache_path):
            with open(cache_path, "rb") as f:
                return f.read()

        wav = self._client.tts(text, voice)

        with open(cache_path, "wb") as f:
            f.write(wav)

        return wav
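
Review note: the sketch below illustrates how the two cache keys in this file behave. It is not part of the commit. `FakeTtsClient`, `cached_tts`, `cache_key`, and `stable_cache_key` are hypothetical names introduced only for illustration. It assumes the same hashing as the wrappers above, and shows that the chat key depends on keyword-argument order (because `json.dumps` preserves insertion order), while a `sort_keys=True` variant would not.

```python
import hashlib
import json
import os
import tempfile


def cache_key(**kwargs) -> str:
    # Same hashing as chat_completion in the diff: kwargs order affects the key.
    return hashlib.sha256(json.dumps(kwargs).encode()).hexdigest()


def stable_cache_key(**kwargs) -> str:
    # Hypothetical variant: sort_keys=True makes the key order-independent.
    return hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()


class FakeTtsClient:
    """Stand-in for a real TTS client (assumption); counts how often it is called."""

    def __init__(self):
        self.calls = 0

    def tts(self, text, voice=None):
        self.calls += 1
        return f"{voice}:{text}".encode()


def cached_tts(client, base_path, text, voice=None):
    # Mirrors CachingTtsWrapper.tts: one file per (voice, sha256(text)).
    path = os.path.join(base_path, voice or "default")
    os.makedirs(path, exist_ok=True)
    text_hash = hashlib.sha256(text.encode()).hexdigest()
    cache_path = os.path.join(path, f"{text_hash}.wav")
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return f.read()
    wav = client.tts(text, voice)
    with open(cache_path, "wb") as f:
        f.write(wav)
    return wav


if __name__ == "__main__":
    client = FakeTtsClient()
    with tempfile.TemporaryDirectory() as cache_dir:
        cached_tts(client, cache_dir, "hello", "v1")
        cached_tts(client, cache_dir, "hello", "v1")  # served from disk
        print("client calls:", client.calls)
```

One consequence worth noting: with the unsorted key, two calls that pass the same arguments in a different order would miss each other's cache entries. Whether that matters depends on how `chat_completion` is called elsewhere in `ds_tool`.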
@@ -0,0 +1,8 @@
{# This template was used to create the synthetic "alt_last_turn" field (Llama3-8B alternative to the last turn) in `fixie-ai/soda-audio` -#}
[
{ "role": "system", "content": "Follow the flow of the conversation and respond just like a human would in the same situation."},
{% for turn in dialogue[:-1] %}
{ "role": {% if loop.revindex0 % 2 == 0 %} "user" {% else %} "assistant" {% endif %}, "content": {{ json_dump(turn) }} }
{%- if not loop.last %},{% endif %}
{% endfor %}
]
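
Review note: the role assignment in this template counts from the end of the conversation, so the last turn that is kept always becomes `"user"`. The sketch below mimics that logic in plain Python. It is not part of the commit; `assign_roles` is a hypothetical helper, and it omits the system message that the template prepends.

```python
def assign_roles(dialogue):
    """Drop the last turn, then alternate roles so the final kept turn is 'user'."""
    context = dialogue[:-1]
    messages = []
    for index, turn in enumerate(context):
        # Equivalent of Jinja's loop.revindex0: iterations remaining after this one.
        revindex0 = len(context) - 1 - index
        role = "user" if revindex0 % 2 == 0 else "assistant"
        messages.append({"role": role, "content": turn})
    return messages


if __name__ == "__main__":
    for message in assign_roles(["Hi", "Hello!", "How are you?", "dropped"]):
        print(message["role"], "->", message["content"])
```

Counting from the end means a dialogue with an even number of kept turns starts with `"assistant"`, which keeps the model's expected reply aligned with the assistant role regardless of dialogue length.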
@@ -0,0 +1,26 @@
import json

import jinja2


def test_quotes():
    with open("tools/ds_tool/soda_alt_last_turn.jinja", "r") as template_file:
        template = template_file.read()

    dialogue = [
        'Have you ever used a double quote (")',
        "Of course, what about a single quote (')?",
        '"Yes, I have."',
        "last turn is ignored!",
    ]

    messages = json.loads(
        jinja2.Template(template).render(dialogue=dialogue, json_dump=json.dumps)
    )
    assert isinstance(messages, list)
    assert all(isinstance(turn, dict) for turn in messages)
    assert messages[-1]["role"] == "user"

    assert len(messages) == 4
    assert messages[0]["role"] == "system"
    assert [x["content"] for x in messages[1:]] == dialogue[:-1]
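
Review note: the test above exercises the "escape quotes" part of this commit. The sketch below shows, using only the standard library, why the template passes each turn through `json_dump` rather than wrapping it in literal quotes. It is an illustration, not code from the commit; `is_valid_json` is a hypothetical helper.

```python
import json

turn = 'Have you ever used a double quote (")'

# Naive interpolation: the embedded quote ends the JSON string early.
naive = '{ "role": "user", "content": "%s" }' % turn

# json.dumps emits a quoted and escaped JSON string literal.
escaped = '{ "role": "user", "content": %s }' % json.dumps(turn)


def is_valid_json(text):
    """Return True if text parses as JSON, False otherwise."""
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False


if __name__ == "__main__":
    print("naive parses:", is_valid_json(naive))
    print("escaped parses:", is_valid_json(escaped))
```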