Skip to content

Commit

Permalink
Revert change to enable push_to_hub at the subset level
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhongqiang Huang committed Jul 24, 2024
1 parent 2189851 commit 94b86ff
Showing 1 changed file with 32 additions and 27 deletions.
59 changes: 32 additions & 27 deletions ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,50 +177,55 @@ def __post_init__(self):
def main(args: DatasetToolArgs):
ds_name = args.dataset_name
print(f'Loading dataset "{ds_name}" for task {args.task}')
ds: datasets.DatasetDict = datasets.load_dataset(
data_dict: datasets.DatasetDict = datasets.load_dataset(
ds_name, args.dataset_subset, split=args.dataset_split
)

if isinstance(ds, datasets.Dataset):
ds = datasets.DatasetDict({args.upload_split: ds})
if isinstance(data_dict, datasets.Dataset):
data_dict = datasets.DatasetDict({args.upload_split: data_dict})

if len(ds) > 1 and args.upload_split:
if len(data_dict) > 1 and args.upload_split:
raise ValueError("Cannot upload multiple splits to a single split")

hub_args: Dict[str, Any] = {
"config_name": args.upload_subset or "default",
"token": args.token or os.environ.get("HF_TOKEN"),
"revision": args.upload_branch,
"private": args.private,
}
if args.num_shards is not None:
hub_args["num_shards"] = args.num_shards

for split, ds_split in ds.items():
for split, ds_split in data_dict.items():
print(
f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
)
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
ds_split = args.task.map_split(
data_dict[split] = args.task.map_split(
ds_split, args.num_workers, args.writer_batch_size
)

hub_args: Dict[str, Any] = {
"config_name": args.upload_subset if args.upload_subset else "default",
"token": args.token or os.environ.get("HF_TOKEN"),
"revision": args.upload_branch,
"private": args.private,
}

upload_split = args.upload_split or split

try:
ds_split.push_to_hub(args.upload_name, split=upload_split, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

# If the push fails or upload_name is not specified, save the data locally.
output_name = f"{args.upload_subset}-{upload_split}-00000-of-00001.parquet"
ds_split.to_parquet(output_name)
if args.num_shards is not None:
hub_args["num_shards"] = args.num_shards

try:
if args.dataset_split:
upload_split = args.upload_split or args.dataset_split
data_dict[args.dataset_split].push_to_hub(
args.upload_name, split=upload_split, **hub_args
)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

# If the push fails or upload_name is not specified, save the data locally.
for split in data_dict.keys():
output_name = f"{split}-00000-of-00001.parquet"
data_dict[split].to_parquet(output_name)
print(f"Saved to {output_name}")
print(f"Sample {0} of {args.upload_subset}: {ds[0]}")

print(f"Sample {0} of {split}: {data_dict[split][0]}")

if __name__ == "__main__":
main(simple_parsing.parse(DatasetToolArgs))

0 comments on commit 94b86ff

Please sign in to comment.