From 2334a56728e6118a7c0456d69f0d7a64001a3dd1 Mon Sep 17 00:00:00 2001
From: Zhongqiang Huang
Date: Wed, 24 Jul 2024 15:48:44 -0400
Subject: [PATCH] Revert change to enable push_to_hub at the subset level

---
 ultravox/tools/ds_tool/ds_tool.py | 69 +++++++++++++++++--------------
 1 file changed, 37 insertions(+), 32 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 5a7068db..71c4f1f3 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -7,8 +7,6 @@
 import jinja2
 import openai
 import simple_parsing
-from jinja2 import StrictUndefined
-from jinja2 import TemplateError
 
 from ultravox.data import text_proc
 from ultravox.tools.ds_tool import caching
@@ -51,8 +49,10 @@ def _map_sample(self, sample):
         # using a Jinja template for some added flexibility
         # The {{ var }} syntax is how Jinja denotes variables
         try:
-            text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
-        except TemplateError as e:
+            text = jinja2.Template(
+                "{{" + self.column_name + "}}", undefined=jinja2.StrictUndefined
+            ).render(**sample)
+        except jinja2.TemplateError as e:
             print(f"Error rendering template: {e}")
             print(f"column_name: {self.column_name}")
             print(f"sample keys: {list(sample.keys())}")
@@ -100,10 +100,10 @@ def _map_sample(self, sample):
         # using a Jinja template for some added flexibility, template can include variables and functions
         # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
         try:
-            rendered = jinja2.Template(self.template, undefined=StrictUndefined).render(
-                **sample, json_dump=json.dumps, text_proc=text_proc
-            )
-        except TemplateError as e:
+            rendered = jinja2.Template(
+                self.template, undefined=jinja2.StrictUndefined
+            ).render(**sample, json_dump=json.dumps, text_proc=text_proc)
+        except jinja2.TemplateError as e:
             print(f"Error rendering template: {e}")
             print(f"template: {self.template}")
             print(f"sample keys: {list(sample.keys())}")
@@ -177,26 +177,17 @@ def __post_init__(self):
 
 def main(args: DatasetToolArgs):
     ds_name = args.dataset_name
     print(f'Loading dataset "{ds_name}" for task {args.task}')
-    ds: datasets.DatasetDict = datasets.load_dataset(
+    data_dict: datasets.DatasetDict = datasets.load_dataset(
         ds_name, args.dataset_subset, split=args.dataset_split
     )
-    if isinstance(ds, datasets.Dataset):
-        ds = datasets.DatasetDict({args.upload_split: ds})
+    if isinstance(data_dict, datasets.Dataset):
+        data_dict = datasets.DatasetDict({args.upload_split: data_dict})
 
-    if len(ds) > 1 and args.upload_split:
+    if len(data_dict) > 1 and args.upload_split:
         raise ValueError("Cannot upload multiple splits to a single split")
 
-    hub_args: Dict[str, Any] = {
-        "config_name": args.upload_subset or "default",
-        "token": args.token or os.environ.get("HF_TOKEN"),
-        "revision": args.upload_branch,
-        "private": args.private,
-    }
-    if args.num_shards is not None:
-        hub_args["num_shards"] = args.num_shards
-
-    for split, ds_split in ds.items():
+    for split, ds_split in data_dict.items():
         print(
             f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
         )
@@ -204,22 +195,36 @@ def main(args: DatasetToolArgs):
         ds_split = ds_split.shuffle(seed=args.shuffle_seed)
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
-        ds_split = args.task.map_split(
+        data_dict[split] = args.task.map_split(
             ds_split, args.num_workers, args.writer_batch_size
         )
 
-        upload_split = args.upload_split or split
+    hub_args: Dict[str, Any] = {
+        "config_name": args.upload_subset or "default",
+        "token": args.token or os.environ.get("HF_TOKEN"),
+        "revision": args.upload_branch,
+        "private": args.private,
+    }
 
-        try:
-            ds_split.push_to_hub(args.upload_name, split=upload_split, **hub_args)
-        except Exception as e:
-            print(f"Failed to push to hub: {e}")
+    if args.num_shards is not None:
+        hub_args["num_shards"] = args.num_shards
 
-            # If the push fails or upload_name is not specified, save the data locally.
-            output_name = f"{args.upload_subset}-{upload_split}-00000-of-00001.parquet"
-            ds_split.to_parquet(output_name)
+    try:
+        if args.dataset_split:
+            data_dict[args.dataset_split].push_to_hub(
+                args.upload_name, split=args.upload_split, **hub_args
+            )
+        else:
+            data_dict.push_to_hub(args.upload_name, **hub_args)
+    except Exception as e:
+        print(f"Failed to push to hub: {e}")
+
+        # If the push fails or upload_name is not specified, save the data locally.
+        for split in data_dict.keys():
+            output_name = f"{split}-00000-of-00001.parquet"
+            data_dict[split].to_parquet(output_name)
             print(f"Saved to {output_name}")
-            print(f"Sample {0} of {args.upload_subset}: {ds[0]}")
+            print(f"Sample {0} of {split}: {data_dict[split][0]}")
 
 
 if __name__ == "__main__":