From 7f64bb17089d736bf791c8cab1c86c78562dcefe Mon Sep 17 00:00:00 2001
From: Zhongqiang Huang
Date: Wed, 24 Jul 2024 15:48:44 -0400
Subject: [PATCH] Revert change to enable push_to_hub at the subset level

---
 ultravox/tools/ds_tool/ds_tool.py | 59 +++++++++++++++++--------------
 1 file changed, 32 insertions(+), 27 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 5a7068db..5b2ff6f4 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -177,26 +177,17 @@ def __post_init__(self):
 def main(args: DatasetToolArgs):
     ds_name = args.dataset_name
     print(f'Loading dataset "{ds_name}" for task {args.task}')
-    ds: datasets.DatasetDict = datasets.load_dataset(
+    data_dict: datasets.DatasetDict = datasets.load_dataset(
         ds_name, args.dataset_subset, split=args.dataset_split
     )
 
-    if isinstance(ds, datasets.Dataset):
-        ds = datasets.DatasetDict({args.upload_split: ds})
+    if isinstance(data_dict, datasets.Dataset):
+        data_dict = datasets.DatasetDict({args.upload_split: data_dict})
 
-    if len(ds) > 1 and args.upload_split:
+    if len(data_dict) > 1 and args.upload_split:
         raise ValueError("Cannot upload multiple splits to a single split")
 
-    hub_args: Dict[str, Any] = {
-        "config_name": args.upload_subset or "default",
-        "token": args.token or os.environ.get("HF_TOKEN"),
-        "revision": args.upload_branch,
-        "private": args.private,
-    }
-    if args.num_shards is not None:
-        hub_args["num_shards"] = args.num_shards
-
-    for split, ds_split in ds.items():
+    for split, ds_split in data_dict.items():
         print(
             f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
         )
@@ -204,23 +195,37 @@ def main(args: DatasetToolArgs):
             ds_split = ds_split.shuffle(seed=args.shuffle_seed)
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
-        ds_split = args.task.map_split(
+        data_dict[split] = args.task.map_split(
             ds_split, args.num_workers, args.writer_batch_size
         )
+
+    hub_args: Dict[str, Any] = {
+        "config_name": args.upload_subset if args.upload_subset else "default",
+        "token": args.token or os.environ.get("HF_TOKEN"),
+        "revision": args.upload_branch,
+        "private": args.private,
+    }
 
-        upload_split = args.upload_split or split
-
-        try:
-            ds_split.push_to_hub(args.upload_name, split=upload_split, **hub_args)
-        except Exception as e:
-            print(f"Failed to push to hub: {e}")
-
-            # If the push fails or upload_name is not specified, save the data locally.
-            output_name = f"{args.upload_subset}-{upload_split}-00000-of-00001.parquet"
-            ds_split.to_parquet(output_name)
+    if args.num_shards is not None:
+        hub_args["num_shards"] = args.num_shards
+
+    try:
+        if args.dataset_split:
+            upload_split = args.upload_split or args.dataset_split
+            data_dict[args.dataset_split].push_to_hub(
+                args.upload_name, split=upload_split, **hub_args
+            )
+        else:
+            data_dict.push_to_hub(args.upload_name, **hub_args)
+    except Exception as e:
+        print(f"Failed to push to hub: {e}")
+
+        # If the push fails or upload_name is not specified, save the data locally.
+        for split in data_dict.keys():
+            output_name = f"{split}-00000-of-00001.parquet"
+            data_dict[split].to_parquet(output_name)
             print(f"Saved to {output_name}")
-            print(f"Sample {0} of {args.upload_subset}: {ds[0]}")
-
+            print(f"Sample {0} of {split}: {data_dict[split][0]}")
 
 if __name__ == "__main__":
     main(simple_parsing.parse(DatasetToolArgs))
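
Note (illustration, not part of the patch): the post-revert flow builds hub_args once,
pushes the whole DatasetDict in a single call, and falls back to writing one parquet
shard per split if the push fails. Below is a minimal standalone sketch of that
pattern; the repo id, branch, and toy data are hypothetical placeholders, and only
the `datasets` calls and keyword arguments already shown in the diff are assumed.

import os

import datasets


def push_with_fallback(data_dict: datasets.DatasetDict, upload_name: str) -> None:
    # Mirrors hub_args in the patch; the values here are placeholders.
    hub_args = {
        "config_name": "default",             # subset (config) name on the Hub
        "token": os.environ.get("HF_TOKEN"),  # optional auth token
        "revision": "main",                   # hypothetical upload branch
        "private": False,
    }
    try:
        # Push every split of the DatasetDict in one call, as the reverted code does.
        data_dict.push_to_hub(upload_name, **hub_args)
    except Exception as e:
        print(f"Failed to push to hub: {e}")
        # Fallback: save each split locally as a single parquet shard.
        for split in data_dict.keys():
            output_name = f"{split}-00000-of-00001.parquet"
            data_dict[split].to_parquet(output_name)
            print(f"Saved to {output_name}")


if __name__ == "__main__":
    dd = datasets.DatasetDict(
        {"train": datasets.Dataset.from_dict({"text": ["a", "b"]})}
    )
    push_with_fallback(dd, "username/my-dataset")  # hypothetical repo id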