From cbe5f6ebb80cffa56cd4e5883d25b61ba555c4bd Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 24 Jun 2024 13:31:56 -0700 Subject: [PATCH] Extending `ds_tool` for SODA conversational dataset (#32) * extending ds_tool for soda conversational dataset * caching for TTS generations * dataset_creation scripts * escape quotes for SODA template --- .gitignore | 1 + Justfile | 4 +- scripts/dataset_creation/boolq_audio.sh | 12 ++++ scripts/dataset_creation/soda_audio.sh | 24 +++++++ .../tools/ds_tool/boolq_explanation.jinja | 8 +++ ultravox/tools/ds_tool/caching.py | 61 +++++++++++++++++ ultravox/tools/{ => ds_tool}/ds_tool.py | 65 ++++++++++++------- .../tools/ds_tool/soda_alt_last_turn.jinja | 8 +++ ultravox/tools/ds_tool/template_test.py | 26 ++++++++ ultravox/tools/{ => ds_tool}/tts.py | 5 +- 10 files changed, 187 insertions(+), 27 deletions(-) create mode 100644 scripts/dataset_creation/boolq_audio.sh create mode 100644 scripts/dataset_creation/soda_audio.sh create mode 100644 ultravox/tools/ds_tool/boolq_explanation.jinja create mode 100644 ultravox/tools/ds_tool/caching.py rename ultravox/tools/{ => ds_tool}/ds_tool.py (72%) create mode 100644 ultravox/tools/ds_tool/soda_alt_last_turn.jinja create mode 100644 ultravox/tools/ds_tool/template_test.py rename ultravox/tools/{ => ds_tool}/tts.py (97%) diff --git a/.gitignore b/.gitignore index f6513a5d..11d49f38 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ runs/ outputs/ wandb/ +*.parquet artifacts/ diff --git a/Justfile b/Justfile index 2217404b..9ed906cf 100644 --- a/Justfile +++ b/Justfile @@ -46,10 +46,10 @@ eval *FLAGS: poetry run python -m ultravox.tools.eval_tool {{FLAGS}} tts *FLAGS: - poetry run python -m ultravox.tools.ds_tool tts {{FLAGS}} + poetry run python -m ultravox.tools.ds_tool.ds_tool tts {{FLAGS}} ds_tool *FLAGS: - poetry run python -m ultravox.tools.ds_tool {{FLAGS}} + poetry run python -m ultravox.tools.ds_tool.ds_tool {{FLAGS}} mds *FLAGS: poetry run python -m ultravox.tools.mds_tool {{FLAGS}} diff --git a/scripts/dataset_creation/boolq_audio.sh b/scripts/dataset_creation/boolq_audio.sh new file mode 100644 index 00000000..707a87b4 --- /dev/null +++ b/scripts/dataset_creation/boolq_audio.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Steps to reproduce the fixie-ai/boolq-audio dataset + +# Step 0: create the `fixie-ai/boolq-audio` dataset in the UI or using huggingface_hub.create_repo + +# Step 1: Create a plausible explanation for the answer +# This explanation is only used in the `-extended` version of the dataset and is used mainly for better training. +just ds_tool textgen -d google/boolq -u fixie-ai/boolq-audio -c explanation -T @ultravox/tools/ds_tool/boolq_template.jinja --token $HF_WRITE_TOKEN -N 8 + +# Step 2: TTS the question into the audio input for the model +# Note: the original dataset was not created using this script. This is just an example of how to create the audio version of the dataset +just ds_tool tts -d fixie-ai/boolq-audio -u fixie-ai/boolq-audio -c question -a audio -i azure --token $HF_WRITE_TOKEN -N 8 diff --git a/scripts/dataset_creation/soda_audio.sh b/scripts/dataset_creation/soda_audio.sh new file mode 100644 index 00000000..cb4af43e --- /dev/null +++ b/scripts/dataset_creation/soda_audio.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Steps to reproduce the fixie-ai/soda-audio dataset + +# Step 0: create the `fixie-ai/soda-audio` dataset in the UI or using huggingface_hub.create_repo + +# Step 1: Create an alternative last turn using Llama3-8b Instruct model +# We want the model to generate the same response whether the input is audio or text + +just ds_tool textgen -d allenai/soda --shuffle True -s test -n 1000 -u fixie-ai/soda-audio -c alt_last_turn \ + -T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \ + -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN + +just ds_tool textgen -d allenai/soda --shuffle True -s validation -n 1000 -u fixie-ai/soda-audio -c alt_last_turn \ + -T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \ + -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN + +just ds_tool textgen -d allenai/soda --shuffle True -s train -n 100000 -u fixie-ai/soda-audio -c alt_last_turn \ + -T @ultravox/tools/ds_tool/soda_alt_last_turn.jinja -j -b https://api.fireworks.ai/inference/v1 \ + -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct --token $HF_WRITE_TOKEN + + +# Step 2: TTS the turn before last: audio input for the model + +just ds_tool tts -d fixie-ai/soda-audio -u fixie-ai/soda-audio -c "dialogue[-2]" -a audio_second_last_turn -i eleven -V random --token $HF_WRITE_TOKEN -N 8 diff --git a/ultravox/tools/ds_tool/boolq_explanation.jinja b/ultravox/tools/ds_tool/boolq_explanation.jinja new file mode 100644 index 00000000..49b1d121 --- /dev/null +++ b/ultravox/tools/ds_tool/boolq_explanation.jinja @@ -0,0 +1,8 @@ +{# This template was used to create the synthetic "explanation" field in `fixie-ai/boolq-audio` -#} +Passage: {{ passage }} + +Question: {{ question }} + +Answer: {{ answer }} + +Provide a short explanation to the question given the passage that provides a rationale for the answer. diff --git a/ultravox/tools/ds_tool/caching.py b/ultravox/tools/ds_tool/caching.py new file mode 100644 index 00000000..9f76f4e2 --- /dev/null +++ b/ultravox/tools/ds_tool/caching.py @@ -0,0 +1,61 @@ +import hashlib +import json +import os +from typing import Optional + +import openai + +from ultravox.tools.ds_tool import tts + + +class CachingChatWrapper: + def __init__(self, client: openai.Client, unique_id: str): + super().__init__() + self._client = client + self._base_path = os.path.join( + ".cache/ds_tool/textgen", + unique_id.replace("https://", "").replace("/", "__"), + ) + os.makedirs(self._base_path, exist_ok=True) + + def chat_completion(self, **kwargs) -> str: + text_hash = hashlib.sha256(json.dumps(kwargs).encode()).hexdigest() + + cache_path = os.path.join(self._base_path, f"{text_hash}.txt") + + if os.path.exists(cache_path): + with open(cache_path, "r") as f: + return f.read() + + response = self._client.chat.completions.create(**kwargs) + text = response.choices[0].message.content + + with open(cache_path, "w") as f: + f.write(text) + + return text + + +class CachingTtsWrapper: + def __init__(self, client: tts.Client, provider: str): + super().__init__() + self._client = client + self._base_path = os.path.join(".cache/ds_tool/tts", provider) + + def tts(self, text: str, voice: Optional[str] = None) -> bytes: + path = os.path.join(self._base_path, voice or "default") + text_hash = hashlib.sha256(text.encode()).hexdigest() + os.makedirs(path, exist_ok=True) + + cache_path = os.path.join(path, f"{text_hash}.wav") + + if os.path.exists(cache_path): + with open(cache_path, "rb") as f: + return f.read() + + wav = self._client.tts(text, voice) + + with open(cache_path, "wb") as f: + f.write(wav) + + return wav diff --git a/ultravox/tools/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py similarity index 72% rename from ultravox/tools/ds_tool.py rename to ultravox/tools/ds_tool/ds_tool.py index b4d47c05..36691dc4 100644 --- a/ultravox/tools/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -1,28 +1,24 @@ import dataclasses +import json import os from typing import Any, Dict, Optional, Union import datasets +import jinja2 import openai import simple_parsing -from ultravox.tools import tts +from ultravox.tools.ds_tool import caching +from ultravox.tools.ds_tool import tts -tts_client: tts.Client -chat_client: openai.Client - -DEFAULT_TEXTGEN_TEMPLATE = """Passage: {passage} - -Question: {question} - -Answer: {answer} - -Provide a short explanation to the question given the passage that provides a rationale for the answer.""" +tts_client: caching.CachingTtsWrapper +chat_client: caching.CachingChatWrapper @dataclasses.dataclass class TtsTask: implementation: str = simple_parsing.field(default="azure", alias="-i") + # Column name containing the text to convert to audio. It can be a Jinja variable expression. column_name: str = simple_parsing.field(default="question", alias="-c") audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a") voice: Optional[str] = simple_parsing.field(default=None, alias="-V") @@ -33,7 +29,10 @@ def __post_init__(self): global tts_client if self.audio_column_name is None: self.audio_column_name = f"{self.column_name}_audio" - tts_client = tts.create_client(self.implementation, self.sample_rate) + tts_client = caching.CachingTtsWrapper( + tts.create_client(self.implementation, self.sample_rate), + provider=self.implementation, + ) def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset: print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...') @@ -42,7 +41,9 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas ) def _map_sample(self, sample): - text = sample[self.column_name] + # using a Jinja template for some added flexibility + # The {{ var }} syntax is how Jinja denotes variables + text = jinja2.Template("{{" + self.column_name + "}}").render(**sample) text = text["text"] if isinstance(text, dict) else text sample[self.audio_column_name] = tts_client.tts(text, self.voice) return sample @@ -50,8 +51,9 @@ def _map_sample(self, sample): @dataclasses.dataclass class TextGenerationTask: - new_column_name: str = simple_parsing.field(default="explanation", alias="-c") - template: str = simple_parsing.field(default=DEFAULT_TEXTGEN_TEMPLATE, alias="-T") + new_column_name: str = simple_parsing.field(alias="-c") + template: str = simple_parsing.field(alias="-T") + json_mode: bool = simple_parsing.field(default=False, alias="-j") language_model: str = simple_parsing.field(default="gpt-4o", alias="-m") base_url: Optional[str] = simple_parsing.field(default=None, alias="-b") @@ -62,7 +64,11 @@ class TextGenerationTask: def __post_init__(self): # The OAI client is separate from the task to avoid pickling issues when multiprocessing. global chat_client - chat_client = openai.Client(base_url=self.base_url, api_key=self.api_key) + # Caching the client to avoid repeated calls to the API if the tool fails. + chat_client = caching.CachingChatWrapper( + openai.Client(base_url=self.base_url, api_key=self.api_key), + unique_id=f"{self.base_url}__{self.language_model}", + ) if self.template.startswith("@"): with open(self.template[1:], "r") as template_file: self.template = template_file.read() @@ -72,19 +78,27 @@ def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Datas return ds_split.map(self._map_sample, num_proc=num_proc) def _map_sample(self, sample): - input_text = self.template.format(**sample) - response = chat_client.chat.completions.create( + rendered = jinja2.Template(self.template).render(**sample, json_dump=json.dumps) + + if self.json_mode: + turns = json.loads(rendered) + assert isinstance(turns, list) + assert all(isinstance(turn, dict) for turn in turns) + assert len(turns) > 0 + assert turns[-1].get("role", None) == "user" + else: + turns = [{"role": "user", "content": rendered}] + + sample[self.new_column_name] = chat_client.chat_completion( model=self.language_model, - messages=[{"role": "user", "content": input_text}], + messages=turns, max_tokens=self.max_tokens, temperature=self.temperature, ) - sample[self.new_column_name] = response.choices[0].message.content return sample # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model. -# Example usages: # just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN # just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct # just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt @@ -94,6 +108,8 @@ class DatasetToolArgs: dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S") dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s") + shuffle: bool = simple_parsing.field(default=False) + shuffle_seed: int = simple_parsing.field(default=42) num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n") num_workers: int = simple_parsing.field(default=16, alias="-w") @@ -122,6 +138,8 @@ def main(args: DatasetToolArgs): for split, ds_split in data_dict.items(): print(f'Processing split "{split}"...') + if args.shuffle: + ds_split = ds_split.shuffle(seed=args.shuffle_seed) if args.num_samples: ds_split = ds_split.select(range(args.num_samples)) data_dict[split] = args.task.map_split(ds_split, args.num_workers) @@ -138,7 +156,9 @@ def main(args: DatasetToolArgs): try: if args.dataset_split: - data_dict[args.dataset_split].push_to_hub(args.upload_name, **hub_args) + data_dict[args.dataset_split].push_to_hub( + args.upload_name, split=args.dataset_split, **hub_args + ) else: data_dict.push_to_hub(args.upload_name, **hub_args) except Exception as e: @@ -149,6 +169,7 @@ def main(args: DatasetToolArgs): output_name = f"{split}-00000-of-00001.parquet" data_dict[split].to_parquet(output_name) print(f"Saved to {output_name}") + print(f"Sample {0} of {split}: {data_dict[split][0]}") if __name__ == "__main__": diff --git a/ultravox/tools/ds_tool/soda_alt_last_turn.jinja b/ultravox/tools/ds_tool/soda_alt_last_turn.jinja new file mode 100644 index 00000000..0dc86128 --- /dev/null +++ b/ultravox/tools/ds_tool/soda_alt_last_turn.jinja @@ -0,0 +1,8 @@ +{# This template was used to create the synthetic "alt_last_turn" field (Llama3-8B alternative to the last turn) in `fixie-ai/soda-audio` -#} +[ + { "role": "system", "content": "Follow the flow of the conversation and respond just like a human would in the same situation."}, + {% for turn in dialogue[:-1] %} + { "role": {% if loop.revindex0 % 2 == 0 %} "user" {% else %} "assistant" {% endif %}, "content": {{ json_dump(turn) }} } + {%- if not loop.last %},{% endif %} + {% endfor %} +] diff --git a/ultravox/tools/ds_tool/template_test.py b/ultravox/tools/ds_tool/template_test.py new file mode 100644 index 00000000..14de25a0 --- /dev/null +++ b/ultravox/tools/ds_tool/template_test.py @@ -0,0 +1,26 @@ +import json + +import jinja2 + + +def test_quotes(): + with open("tools/ds_tool/soda_alt_last_turn.jinja", "r") as template_file: + template = template_file.read() + + dialogue = [ + 'Have you ever used a double quote (")', + "Of course, what about a single quote (')?", + '"Yes, I have."', + "last turn is ignored!", + ] + + messages = json.loads( + jinja2.Template(template).render(dialogue=dialogue, json_dump=json.dumps) + ) + assert isinstance(messages, list) + assert all(isinstance(turn, dict) for turn in messages) + assert messages[-1]["role"] == "user" + + assert len(messages) == 4 + assert messages[0]["role"] == "system" + assert [x["content"] for x in messages[1:]] == dialogue[:-1] diff --git a/ultravox/tools/tts.py b/ultravox/tools/ds_tool/tts.py similarity index 97% rename from ultravox/tools/tts.py rename to ultravox/tools/ds_tool/tts.py index 3dc690d5..e6b8b9e4 100644 --- a/ultravox/tools/tts.py +++ b/ultravox/tools/ds_tool/tts.py @@ -26,7 +26,7 @@ def __init__(self, sample_rate: int = 16000): self._sample_rate = sample_rate @abc.abstractmethod - def tts(self, text: str, voice: Optional[str] = None): + def tts(self, text: str, voice: Optional[str] = None) -> bytes: raise NotImplementedError def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]): @@ -34,7 +34,7 @@ def _post(self, url: str, headers: Dict[str, str], json: Dict[str, Any]): response.raise_for_status() return response - def _handle_pcm_response(self, response: requests.Response): + def _handle_pcm_response(self, response: requests.Response) -> bytes: pcm_array = np.frombuffer(response.content, dtype=np.int16) wav_bytes = io.BytesIO() sf.write(wav_bytes, pcm_array, self._sample_rate, format="wav") @@ -132,7 +132,6 @@ def tts(self, text: str, voice: Optional[str] = None): i = np.random.randint(len(self.ALL_VOICES)) + os.getpid() voice = self.ALL_VOICES[i % len(self.ALL_VOICES)] url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice}/stream?output_format=pcm_16000" - print("url", url) headers = {"xi-api-key": os.environ["ELEVEN_API_KEY"]} body = { "text": text,