Update ds_tool.py #52

Merged · 17 commits · Jul 25, 2024
Changes from 8 commits
2 changes: 2 additions & 0 deletions ultravox/data/text_proc.py
@@ -1,5 +1,7 @@
import os
import sys

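# Note: setting sys.modules["tkinter"] to None makes any subsequent `import tkinter` raise
# ImportError, presumably so optional GUI code paths pulled in via nltk/truecase are skipped
# when running headless.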
sys.modules["tkinter"] = None # type: ignore
import nltk # needed for truecase
import truecase

3 changes: 3 additions & 0 deletions ultravox/tools/ds_tool/continuation.jinja
@@ -0,0 +1,3 @@
Continue the following text using fewer than 50 words:

{{ text }}
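For reference, a minimal sketch of how a template like this gets rendered by the textgen task (the sample dict is illustrative; ds_tool passes the whole sample to jinja2, as in the diff below):

import jinja2
from jinja2 import StrictUndefined

with open("ultravox/tools/ds_tool/continuation.jinja") as f:
    template = f.read()

sample = {"text": "The quick brown fox jumps over"}
# A missing "text" key would raise with StrictUndefined instead of rendering an empty prompt.
prompt = jinja2.Template(template, undefined=StrictUndefined).render(**sample)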
120 changes: 82 additions & 38 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -1,13 +1,16 @@
import dataclasses
import json
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import datasets
import jinja2
import openai
import simple_parsing
from jinja2 import StrictUndefined
from jinja2 import TemplateError

from ultravox.data.text_proc import format_asr_text
from ultravox.tools.ds_tool import caching
from ultravox.tools.ds_tool import tts

@@ -23,6 +26,8 @@ class TtsTask:
audio_column_name: Optional[str] = simple_parsing.field(default=None, alias="-a")
voice: Optional[str] = simple_parsing.field(default=None, alias="-V")
sample_rate: int = simple_parsing.field(default=16000, alias="-r")
write_batch_size: int = 1000
format_fields: List[str] = simple_parsing.field(default_factory=list)

def __post_init__(self):
# The TTS client is separate from the task to avoid pickling issues when multiprocessing.
@@ -36,14 +41,29 @@ def __post_init__(self):

def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...')
return ds_split.map(self._map_sample, num_proc=num_proc).cast_column(
return ds_split.map(
self._map_sample, num_proc=num_proc, writer_batch_size=self.write_batch_size
).cast_column(
self.audio_column_name, datasets.Audio(sampling_rate=self.sample_rate)
)

def _map_sample(self, sample):
for field in self.format_fields:
sample[field] = format_asr_text(sample[field])
# using a Jinja template for some added flexibility
# The {{ var }} syntax is how Jinja denotes variables
text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)

try:
text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
except TemplateError as e:
print(f"Error rendering template: {e}")
print(f"column_name: {self.column_name}")
print(f"sample keys: {list(sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure column '{self.column_name}' exists in the sample."
) from e

text = text["text"] if isinstance(text, dict) else text
sample[self.audio_column_name] = tts_client.tts(text, self.voice)
return sample
@@ -60,6 +80,8 @@ class TextGenerationTask:
api_key: Optional[str] = simple_parsing.field(default=None, alias="-k")
max_tokens: int = 128
temperature: float = 0
write_batch_size: int = 1000
format_fields: List[str] = simple_parsing.field(default_factory=list)

def __post_init__(self):
# The OAI client is separate from the task to avoid pickling issues when multiprocessing.
@@ -75,10 +97,25 @@ def __post_init__(self):

def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
return ds_split.map(self._map_sample, num_proc=num_proc)
return ds_split.map(
self._map_sample, num_proc=num_proc, writer_batch_size=self.write_batch_size
)

def _map_sample(self, sample):
rendered = jinja2.Template(self.template).render(**sample, json_dump=json.dumps)
for field in self.format_fields:
sample[field] = format_asr_text(sample[field])

try:
rendered = jinja2.Template(self.template, undefined=StrictUndefined).render(
**sample, json_dump=json.dumps
)
except TemplateError as e:
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(sample.keys())}")
raise ValueError(
"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e

if self.json_mode:
turns = json.loads(rendered)
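As context for the switch to StrictUndefined above: with Jinja's default undefined behavior a missing key silently renders as an empty string, while StrictUndefined raises an UndefinedError (a TemplateError subclass) that the except clause can report. A minimal illustration (keys are made up, not part of the diff):

import jinja2
from jinja2 import StrictUndefined, TemplateError

tmpl = "Continue the following text using fewer than 50 words:\n\n{{ text }}"
jinja2.Template(tmpl).render(other="x")  # "text" missing: renders with a blank, no error
try:
    jinja2.Template(tmpl, undefined=StrictUndefined).render(other="x")
except TemplateError as e:
    print(type(e).__name__, e)  # UndefinedError: 'text' is undefined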
@@ -102,6 +139,10 @@ def _map_sample(self, sample):
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
# just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
# just ds_tool textgen --new_column_name continuation --dataset_name openslr/librispeech_asr --dataset_subset clean --dataset_split train.360 \
# --shuffle --format_fields text --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \
# --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \
# --template @ultravox/tools/ds_tool/continuation.jinja --max_tokens 64 --num_workers 30 --write_batch_size 30
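# (The librispeech example above runs format_asr_text over the "text" column, generates a new
# "continuation" column from the continuation.jinja template via the Fireworks endpoint, and
# pushes the processed train.360 split to fixie-ai/librispeech_asr.)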
@dataclasses.dataclass
class DatasetToolArgs:
# HF source dataset parameters
@@ -117,9 +158,10 @@ class DatasetToolArgs:

# HF destination dataset parameters
upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
# e.g., if the original split is "train" but we want to upload it as "validation"
upload_subset: Optional[str] = simple_parsing.field(default=None)
upload_split: Optional[str] = simple_parsing.field(default=None)
upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N")
private: bool = simple_parsing.field(default=False)
token: Optional[str] = None
@@ -131,55 +173,57 @@ class DatasetToolArgs:
)

def __post_init__(self):
assert (
not self.upload_split or self.dataset_split
), "Must specify dataset_split when using upload_split"
assert self.dataset_subset, "dataset_subset must be specified"
if not self.upload_subset:
self.upload_subset = self.dataset_subset
if self.dataset_split and not self.upload_split:
self.upload_split = self.dataset_split


def main(args: DatasetToolArgs):
ds_name = args.dataset_name
print(f'Loading dataset "{ds_name}" for task {args.task}')
data_dict: datasets.DatasetDict = datasets.load_dataset(
ds: datasets.DatasetDict = datasets.load_dataset(
ds_name, args.dataset_subset, split=args.dataset_split
)
if args.dataset_split:
data_dict = datasets.DatasetDict(**{args.dataset_split: data_dict})

for split, ds_split in data_dict.items():
print(f'Processing split "{split}"...')
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
data_dict[split] = args.task.map_split(ds_split, args.num_workers)
if isinstance(ds, datasets.Dataset):
ds = datasets.DatasetDict({args.upload_split: ds})

if len(ds) > 1 and args.upload_split:
raise ValueError("Cannot upload multiple splits to a single split")

token = args.token or os.environ.get("HF_TOKEN")
hub_args: Dict[str, Any] = {
"config_name": args.dataset_subset or "default",
"token": token,
"config_name": args.upload_subset,
"token": args.token or os.environ.get("HF_TOKEN"),
"revision": args.upload_branch,
"private": args.private,
"split": args.upload_split,
}
if args.num_shards is not None:
hub_args["num_shards"] = {split: args.num_shards for split in data_dict.keys()}
hub_args["num_shards"] = args.num_shards

try:
if args.dataset_split:
data_dict[args.dataset_split].push_to_hub(
args.upload_name, split=args.dataset_split, **hub_args
)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

# If the push fails or upload_name is not specified, save the data locally.
for split in data_dict.keys():
output_name = f"{split}-00000-of-00001.parquet"
data_dict[split].to_parquet(output_name)
for split, ds_split in ds.items():
print(
f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
)
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
ds_split = args.task.map_split(ds_split, args.num_workers)

upload_split = args.upload_split or split

try:
ds_split.push_to_hub(args.upload_name, split=upload_split, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

# If the push fails or upload_name is not specified, save the data locally.
output_name = f"{args.upload_subset}-{upload_split}-00000-of-00001.parquet"
ds_split.to_parquet(output_name)
print(f"Saved to {output_name}")
print(f"Sample {0} of {split}: {data_dict[split][0]}")
print(f"Sample {0} of {args.upload_subset}: {ds[0]}")


if __name__ == "__main__":
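For orientation, a minimal sketch of the per-split flow that the reworked main() follows (dataset, subset, and repo names are illustrative, borrowed from the example commands above):

import datasets

ds = datasets.load_dataset("openslr/librispeech_asr", "clean", split="train.360")
if isinstance(ds, datasets.Dataset):  # a single requested split comes back as a Dataset
    ds = datasets.DatasetDict({"train.360": ds})

for split, ds_split in ds.items():
    # shuffle / num_samples selection / task.map_split would happen here
    ds_split.push_to_hub(
        "fixie-ai/librispeech_asr",  # upload_name (illustrative)
        split=split,
        config_name="clean",  # upload_subset
        private=True,
    )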