diff --git a/ultravox/data/text_proc.py b/ultravox/data/text_proc.py
index ed517aac..89349ec9 100644
--- a/ultravox/data/text_proc.py
+++ b/ultravox/data/text_proc.py
@@ -1,5 +1,9 @@
 import os
+import sys
 
+# Temporary fix for an issue where importing NLTK breaks PyTorch multiprocessing on MacOS.
+# For more details, see: https://github.com/nltk/nltk/issues/2949
+sys.modules["tkinter"] = None  # type: ignore
 import nltk  # needed for truecase
 import truecase
 
diff --git a/ultravox/tools/ds_tool/continuation.jinja b/ultravox/tools/ds_tool/continuation.jinja
new file mode 100644
index 00000000..de1e113e
--- /dev/null
+++ b/ultravox/tools/ds_tool/continuation.jinja
@@ -0,0 +1,3 @@
+Continue the following text using less than 50 words:
+
+{{ text_proc.format_asr_text(text) }}
diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index c3aa0dda..71c4f1f3 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -8,6 +8,7 @@
 import openai
 import simple_parsing
 
+from ultravox.data import text_proc
 from ultravox.tools.ds_tool import caching
 from ultravox.tools.ds_tool import tts
 
@@ -34,17 +35,31 @@ def __post_init__(self):
             provider=self.implementation,
         )
 
-    def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
+    def map_split(
+        self, ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int
+    ) -> datasets.Dataset:
         print(f'TTS mapping "{self.column_name}" to "{self.audio_column_name}"...')
-        return ds_split.map(self._map_sample, num_proc=num_proc).cast_column(
+        return ds_split.map(
+            self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
+        ).cast_column(
             self.audio_column_name, datasets.Audio(sampling_rate=self.sample_rate)
         )
 
     def _map_sample(self, sample):
         # using a Jinja template for some added flexibility
         # The {{ var }} syntax is how Jinja denotes variables
-        text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
-        text = text["text"] if isinstance(text, dict) else text
+        try:
+            text = jinja2.Template(
+                "{{" + self.column_name + "}}", undefined=jinja2.StrictUndefined
+            ).render(**sample)
+        except jinja2.TemplateError as e:
+            print(f"Error rendering template: {e}")
+            print(f"column_name: {self.column_name}")
+            print(f"sample keys: {list(sample.keys())}")
+            raise ValueError(
+                f"Template rendering failed. Make sure column_name exists in the sample."
+            ) from e
+
         sample[self.audio_column_name] = tts_client.tts(text, self.voice)
         return sample
 
@@ -73,12 +88,28 @@ def __post_init__(self):
         with open(self.template[1:], "r") as template_file:
             self.template = template_file.read()
 
-    def map_split(self, ds_split: datasets.Dataset, num_proc: int) -> datasets.Dataset:
+    def map_split(
+        self, ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int
+    ) -> datasets.Dataset:
         print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
-        return ds_split.map(self._map_sample, num_proc=num_proc)
+        return ds_split.map(
+            self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
+        )
 
     def _map_sample(self, sample):
-        rendered = jinja2.Template(self.template).render(**sample, json_dump=json.dumps)
+        # using a Jinja template for some added flexibility, template can include variables and functions
+        # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
+        try:
+            rendered = jinja2.Template(
+                self.template, undefined=jinja2.StrictUndefined
+            ).render(**sample, json_dump=json.dumps, text_proc=text_proc)
+        except jinja2.TemplateError as e:
+            print(f"Error rendering template: {e}")
+            print(f"template: {self.template}")
+            print(f"sample keys: {list(sample.keys())}")
+            raise ValueError(
+                f"Template rendering failed. Make sure all keys in the template exist in the sample."
+            ) from e
 
         if self.json_mode:
             turns = json.loads(rendered)
@@ -102,6 +133,10 @@ def _map_sample(self, sample):
 # just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -c question -a audio --token $HF_WRITE_TOKEN
 # just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -c explanation -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
 # just ds_tool textgen -d ylacombe/expresso -u fixie-ai/expresso -c continuation -T @expresso_template.txt
+# just ds_tool textgen --new_column_name continuation --dataset_name openslr/librispeech_asr --dataset_subset clean --dataset_split train.360 \
+#   --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \
+#   --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \
+#   --template @ultravox/tools/ds_tool/continuation.jinja --max_tokens 64 --num_workers 30 --writer_batch_size 30
 @dataclasses.dataclass
 class DatasetToolArgs:
     # HF source dataset parameters
@@ -114,12 +149,14 @@ class DatasetToolArgs:
     shuffle_seed: int = simple_parsing.field(default=42)
     num_samples: Optional[int] = simple_parsing.field(default=None, alias="-n")
     num_workers: int = simple_parsing.field(default=16, alias="-w")
+    writer_batch_size: int = simple_parsing.field(default=1000)
 
     # HF destination dataset parameters
     upload_name: Optional[str] = simple_parsing.field(default=None, alias="-u")
+    upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
     # eg if the original split="train", but we want to upload it as "validation"
+    upload_subset: Optional[str] = simple_parsing.field(default=None)
     upload_split: Optional[str] = simple_parsing.field(default=None)
-    upload_branch: Optional[str] = simple_parsing.field(default="main", alias="-B")
     num_shards: Optional[int] = simple_parsing.field(default=None, alias="-N")
     private: bool = simple_parsing.field(default=False)
     token: Optional[str] = None
@@ -131,9 +168,10 @@ class DatasetToolArgs:
     )
 
     def __post_init__(self):
-        assert (
-            not self.upload_split or self.dataset_split
-        ), "Must specify dataset_split when using upload_split"
+        if not self.upload_subset and self.dataset_subset:
+            self.upload_subset = self.dataset_subset
+        if self.dataset_split and not self.upload_split:
+            self.upload_split = self.dataset_split
 
 
 def main(args: DatasetToolArgs):
@@ -142,32 +180,39 @@ def main(args: DatasetToolArgs):
     data_dict: datasets.DatasetDict = datasets.load_dataset(
         ds_name, args.dataset_subset, split=args.dataset_split
     )
-    if args.dataset_split:
-        data_dict = datasets.DatasetDict(**{args.dataset_split: data_dict})
+
+    if isinstance(data_dict, datasets.Dataset):
+        data_dict = datasets.DatasetDict({args.upload_split: data_dict})
+
+    if len(data_dict) > 1 and args.upload_split:
+        raise ValueError("Cannot upload multiple splits to a single split")
 
     for split, ds_split in data_dict.items():
-        print(f'Processing split "{split}"...')
+        print(
+            f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
+        )
         if args.shuffle:
             ds_split = ds_split.shuffle(seed=args.shuffle_seed)
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
-        data_dict[split] = args.task.map_split(ds_split, args.num_workers)
+        data_dict[split] = args.task.map_split(
+            ds_split, args.num_workers, args.writer_batch_size
+        )
 
-    token = args.token or os.environ.get("HF_TOKEN")
     hub_args: Dict[str, Any] = {
-        "config_name": args.dataset_subset or "default",
-        "token": token,
+        "config_name": args.upload_subset or "default",
+        "token": args.token or os.environ.get("HF_TOKEN"),
         "revision": args.upload_branch,
         "private": args.private,
-        "split": args.upload_split,
     }
+
     if args.num_shards is not None:
-        hub_args["num_shards"] = {split: args.num_shards for split in data_dict.keys()}
+        hub_args["num_shards"] = args.num_shards
     try:
         if args.dataset_split:
             data_dict[args.dataset_split].push_to_hub(
-                args.upload_name, split=args.dataset_split, **hub_args
+                args.upload_name, split=args.upload_split, **hub_args
             )
         else:
             data_dict.push_to_hub(args.upload_name, **hub_args)
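
Note on the `undefined=jinja2.StrictUndefined` change in both `_map_sample` methods: with Jinja's default `Undefined`, a missing or misspelled column silently renders as an empty string, so a bad `--column_name` or template variable would previously produce empty prompts; with `StrictUndefined`, rendering raises and the new `except jinja2.TemplateError` blocks can report the offending keys. The snippet below is a minimal standalone sketch of that behavior, not part of the patch; the `sample` dict and the misspelled variable name are made up for illustration.

import jinja2

sample = {"text": "hello world"}  # hypothetical dataset row

# Default behavior: a typo in the variable name is swallowed and renders as "".
print(repr(jinja2.Template("{{ txet }}").render(**sample)))  # -> ''

# StrictUndefined: the same typo raises, so ds_tool fails loudly per sample.
try:
    jinja2.Template("{{ txet }}", undefined=jinja2.StrictUndefined).render(**sample)
except jinja2.TemplateError as e:  # UndefinedError is a TemplateError subclass
    print(f"render failed: {e}")  # -> render failed: 'txet' is undefined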