Update ds_tool.py #52

Merged: 17 commits, Jul 25, 2024

Changes from all commits:
4 changes: 4 additions & 0 deletions ultravox/data/text_proc.py
@@ -1,5 +1,9 @@
import os
import sys

# Temporary fix for an issue where importing NLTK breaks PyTorch multiprocessing on MacOS.
# For more details, see: https://github.com/nltk/nltk/issues/2949
sys.modules["tkinter"] = None # type: ignore
import nltk # needed for truecase
import truecase

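An aside on why this workaround is effective (a minimal standalone sketch, not part of the diff): putting None into sys.modules for a module name makes any later import of that name raise ImportError, so NLTK skips its optional tkinter-dependent code paths instead of pulling tkinter into every worker process.

```python
import sys

# A None entry in sys.modules makes the import machinery raise
# ImportError (ModuleNotFoundError) for that name from then on.
sys.modules["tkinter"] = None  # type: ignore

try:
    import tkinter  # noqa: F401
except ImportError as e:
    print(f"tkinter import blocked: {e}")
```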
3 changes: 3 additions & 0 deletions ultravox/tools/ds_tool/continuation.jinja
@@ -0,0 +1,3 @@
Continue the following text using less than 50 words:

{{ text_proc.format_asr_text(text) }}
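For illustration, a rough sketch of how this template gets filled in per dataset row. The actual rendering happens in TextGenerationTask._map_sample further down; the sample dict here is made up.

```python
import jinja2

from ultravox.data import text_proc

template = open("ultravox/tools/ds_tool/continuation.jinja").read()
sample = {"text": "THE QUICK BROWN FOX JUMPED OVER THE LAZY DOG"}  # hypothetical ASR row

# Mirrors the render call in ds_tool.py: the whole sample plus the text_proc module
# are exposed to the template, and StrictUndefined turns missing keys into errors.
prompt = jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
    **sample, text_proc=text_proc
)
print(prompt)
```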
87 changes: 66 additions & 21 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -8,6 +8,7 @@
import openai
import simple_parsing

from ultravox.data import text_proc
from ultravox.tools.ds_tool import caching
from ultravox.tools.ds_tool import tts

@@ -34,17 +35,31 @@ def __post_init__(self):
provider=self.implementation,
)

def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
def map_split(
self, ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int
) -> datasets.Dataset:
print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...')
return ds_split.map(self._map_sample, num_proc=num_proc).cast_column(
return ds_split.map(
self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
).cast_column(
self.audio_column_name, datasets.Audio(sampling_rate=self.sample_rate)
)

def _map_sample(self, sample):
# using a Jinja template for some added flexibility
# The {{ var }} syntax is how Jinja denotes variables
text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
text = text["text"] if isinstance(text, dict) else text
try:
text = jinja2.Template(
"{{" + self.column_name + "}}", undefined=jinja2.StrictUndefined
).render(**sample)
except jinja2.TemplateError as e:
print(f"Error rendering template: {e}")
print(f"column_name: {self.column_name}")
print(f"sample keys: {list(sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure column_name exists in the sample."
) from e

sample[self.audio_column_name] = tts_client.tts(text, self.voice)
return sample
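The value of switching to StrictUndefined here is that a typo'd or missing column now surfaces as an exception rather than an empty string being handed to TTS. A minimal sketch of that failure mode (toy sample, hypothetical column name):

```python
import jinja2

sample = {"question": "Is the sky blue?"}  # row that has no "sentence" column

try:
    # The default Undefined would silently render "" here; StrictUndefined raises instead.
    jinja2.Template("{{sentence}}", undefined=jinja2.StrictUndefined).render(**sample)
except jinja2.TemplateError as e:
    print(f"caught: {e!r}")  # e.g. UndefinedError("'sentence' is undefined")
```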

@@ -73,12 +88,28 @@ def __post_init__(self):
with open(self.template[1:], "r") as template_file:
self.template = template_file.read()

def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
def map_split(
self, ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int
) -> datasets.Dataset:
print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
return ds_split.map(self._map_sample, num_proc=num_proc)
return ds_split.map(
self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
)

def _map_sample(self, sample):
rendered = jinja2.Template(self.template).render(**sample, json_dump=json.dumps)
# using a Jinja template for some added flexibility, template can include variables and functions
# e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
try:
rendered = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**sample, json_dump=json.dumps, text_proc=text_proc)
except jinja2.TemplateError as e:
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e

if self.json_mode:
turns = json.loads(rendered)
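A toy example (not a template from the repo) of why json_dump is exposed to templates: in json_mode the rendered string must be valid JSON, and json.dumps handles the quoting and escaping of interpolated values before json.loads parses the result back.

```python
import json

import jinja2

template = '{"role": "user", "content": {{ json_dump(question) }}}'
sample = {"question": 'Summarize: "Hello, world."'}

rendered = jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
    **sample, json_dump=json.dumps
)
turns = json.loads(rendered)
print(turns)  # {'role': 'user', 'content': 'Summarize: "Hello, world."'}
```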
@@ -102,6 +133,10 @@ def _map_sample(self, sample):
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
# just ds_tool textgen --new_column_name continuation --dataset_name openslr/librispeech_asr --dataset_subset clean --dataset_split train.360 \
# --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \
# --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \
# --template @ultravox/tools/ds_tool/continuation.jinja --max_tokens 64 --num_workers 30 --writer_batch_size 30
@dataclasses.dataclass
class DatasetToolArgs:
# HF source dataset parameters
@@ -114,12 +149,14 @@ class DatasetToolArgs:
shuffle_seed: int = simple_parsing.field(default=42)
num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
num_workers: int = simple_parsing.field(default=16, alias="-w")
writer_batch_size: int = simple_parsing.field(default=1000)

# HF destination dataset parameters
upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
# eg if the original split="train", but we want to upload it as "validation"
upload_subset: Optional[str] = simple_parsing.field(default=None)
upload_split: Optional[str] = simple_parsing.field(default=None)
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N")
private: bool = simple_parsing.field(default=False)
token: Optional[str] = None
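The new writer_batch_size knob is forwarded to datasets.Dataset.map, where it controls how many processed rows are buffered before being flushed to the Arrow cache file; smaller values keep memory flat when rows carry large audio arrays. A minimal sketch with toy data:

```python
import datasets

ds = datasets.Dataset.from_dict({"text": ["a", "b", "c", "d"]})

# Each worker flushes its output to disk every 2 rows instead of the default 1000.
ds = ds.map(lambda row: {"text": row["text"].upper()}, num_proc=2, writer_batch_size=2)
print(ds["text"])  # ['A', 'B', 'C', 'D']
```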
@@ -131,9 +168,10 @@ class DatasetToolArgs:
)

def __post_init__(self):
assert (
not self.upload_split or self.dataset_split
), "Must specify dataset_split when using upload_split"
if not self.upload_subset and self.dataset_subset:
self.upload_subset = self.dataset_subset
if self.dataset_split and not self.upload_split:
self.upload_split = self.dataset_split


def main(args: DatasetToolArgs):
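The new __post_init__ above replaces the old assertion with defaulting: upload_subset and upload_split fall back to their dataset_* counterparts when not given. A standalone sketch of just that behavior (simplified; the real dataclass has many more fields):

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class _Args:
    dataset_subset: Optional[str] = None
    dataset_split: Optional[str] = None
    upload_subset: Optional[str] = None
    upload_split: Optional[str] = None

    def __post_init__(self):
        if not self.upload_subset and self.dataset_subset:
            self.upload_subset = self.dataset_subset
        if self.dataset_split and not self.upload_split:
            self.upload_split = self.dataset_split


args = _Args(dataset_subset="clean", dataset_split="train.360")
print(args.upload_subset, args.upload_split)  # clean train.360
```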
@@ -142,32 +180,39 @@ def main(args: DatasetToolArgs):
data_dict: datasets.DatasetDict = datasets.load_dataset(
ds_name, args.dataset_subset, split=args.dataset_split
)
if args.dataset_split:
data_dict = datasets.DatasetDict(**{args.dataset_split: data_dict})

if isinstance(data_dict, datasets.Dataset):
data_dict = datasets.DatasetDict({args.upload_split: data_dict})

if len(data_dict) > 1 and args.upload_split:
raise ValueError("Cannot upload multiple splits to a single split")

for split, ds_split in data_dict.items():
print(f'Processing split "{split}"...')
print(
f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
)
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
data_dict[split] = args.task.map_split(ds_split, args.num_workers)
data_dict[split] = args.task.map_split(
ds_split, args.num_workers, args.writer_batch_size
)

token = args.token or os.environ.get("HF_TOKEN")
hub_args: Dict[str, Any] = {
"config_name": args.dataset_subset or "default",
"token": token,
"config_name": args.upload_subset or "default",
"token": args.token or os.environ.get("HF_TOKEN"),
"revision": args.upload_branch,
"private": args.private,
"split": args.upload_split,
}

if args.num_shards is not None:
hub_args["num_shards"] = {split: args.num_shards for split in data_dict.keys()}
hub_args["num_shards"] = args.num_shards

try:
if args.dataset_split:
data_dict[args.dataset_split].push_to_hub(
args.upload_name, split=args.dataset_split, **hub_args
args.upload_name, split=args.upload_split, **hub_args
)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
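The reworked main() normalizes the loaded data into a DatasetDict keyed by the upload split, so the same processing and push path handles both a single named split and a full dataset. A rough sketch of that normalization (toy in-memory dataset standing in for load_dataset):

```python
import datasets

upload_split = "validation"  # e.g. re-labelling an original "train" split on upload

ds = datasets.Dataset.from_dict({"text": ["hello", "world"]})  # stand-in for load_dataset(..., split=...)
data_dict = datasets.DatasetDict({upload_split: ds})

for split, ds_split in data_dict.items():
    print(f'Processing split "{split}" containing {len(ds_split)} samples')
```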