Revert change to enable push_to_hub at the subset level
Zhongqiang Huang committed Jul 24, 2024
1 parent 2189851 commit 2334a56
Showing 1 changed file with 37 additions and 32 deletions.
69 changes: 37 additions & 32 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -7,8 +7,6 @@
 import jinja2
 import openai
 import simple_parsing
-from jinja2 import StrictUndefined
-from jinja2 import TemplateError
 
 from ultravox.data import text_proc
 from ultravox.tools.ds_tool import caching
@@ -51,8 +49,10 @@ def _map_sample(self, sample):
         # using a Jinja template for some added flexibility
         # The {{ var }} syntax is how Jinja denotes variables
         try:
-            text = jinja2.Template("{{" + self.column_name + "}}").render(**sample)
-        except TemplateError as e:
+            text = jinja2.Template(
+                "{{" + self.column_name + "}}", undefined=jinja2.StrictUndefined
+            ).render(**sample)
+        except jinja2.TemplateError as e:
             print(f"Error rendering template: {e}")
             print(f"column_name: {self.column_name}")
             print(f"sample keys: {list(sample.keys())}")
@@ -100,10 +100,10 @@ def _map_sample(self, sample):
         # using a Jinja template for some added flexibility, template can include variables and functions
         # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
         try:
-            rendered = jinja2.Template(self.template, undefined=StrictUndefined).render(
-                **sample, json_dump=json.dumps, text_proc=text_proc
-            )
-        except TemplateError as e:
+            rendered = jinja2.Template(
+                self.template, undefined=jinja2.StrictUndefined
+            ).render(**sample, json_dump=json.dumps, text_proc=text_proc)
+        except jinja2.TemplateError as e:
             print(f"Error rendering template: {e}")
             print(f"template: {self.template}")
             print(f"sample keys: {list(sample.keys())}")
@@ -177,49 +177,54 @@ def __post_init__(self):
 def main(args: DatasetToolArgs):
     ds_name = args.dataset_name
     print(f'Loading dataset "{ds_name}" for task {args.task}')
-    ds: datasets.DatasetDict = datasets.load_dataset(
+    data_dict: datasets.DatasetDict = datasets.load_dataset(
         ds_name, args.dataset_subset, split=args.dataset_split
     )
 
-    if isinstance(ds, datasets.Dataset):
-        ds = datasets.DatasetDict({args.upload_split: ds})
+    if isinstance(data_dict, datasets.Dataset):
+        data_dict = datasets.DatasetDict({args.upload_split: data_dict})
 
-    if len(ds) > 1 and args.upload_split:
+    if len(data_dict) > 1 and args.upload_split:
         raise ValueError("Cannot upload multiple splits to a single split")
 
-    for split, ds_split in ds.items():
+    hub_args: Dict[str, Any] = {
+        "config_name": args.upload_subset or "default",
+        "token": args.token or os.environ.get("HF_TOKEN"),
+        "revision": args.upload_branch,
+        "private": args.private,
+    }
+    if args.num_shards is not None:
+        hub_args["num_shards"] = args.num_shards
+
+    for split, ds_split in data_dict.items():
         print(
             f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
         )
         if args.shuffle:
             ds_split = ds_split.shuffle(seed=args.shuffle_seed)
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
-        ds_split = args.task.map_split(
+        data_dict[split] = args.task.map_split(
             ds_split, args.num_workers, args.writer_batch_size
         )
 
-        upload_split = args.upload_split or split
-        hub_args: Dict[str, Any] = {
-            "config_name": args.upload_subset or "default",
-            "token": args.token or os.environ.get("HF_TOKEN"),
-            "revision": args.upload_branch,
-            "private": args.private,
-        }
-        if args.num_shards is not None:
-            hub_args["num_shards"] = args.num_shards
-        try:
-            ds_split.push_to_hub(args.upload_name, split=upload_split, **hub_args)
-        except Exception as e:
-            print(f"Failed to push to hub: {e}")
-
-            # If the push fails or upload_name is not specified, save the data locally.
-            output_name = f"{args.upload_subset}-{upload_split}-00000-of-00001.parquet"
-            ds_split.to_parquet(output_name)
-            print(f"Saved to {output_name}")
-            print(f"Sample {0} of {args.upload_subset}: {ds[0]}")
+    try:
+        if args.dataset_split:
+            data_dict[args.dataset_split].push_to_hub(
+                args.upload_name, split=args.upload_split, **hub_args
+            )
+        else:
+            data_dict.push_to_hub(args.upload_name, **hub_args)
+    except Exception as e:
+        print(f"Failed to push to hub: {e}")
+
+        # If the push fails or upload_name is not specified, save the data locally.
+        for split in data_dict.keys():
+            output_name = f"{split}-00000-of-00001.parquet"
+            data_dict[split].to_parquet(output_name)
+            print(f"Saved to {output_name}")
+            print(f"Sample {0} of {split}: {data_dict[split][0]}")
 
 
 if __name__ == "__main__":
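
Note on the control-flow change above: datasets.load_dataset returns a bare Dataset when split= is passed and a DatasetDict otherwise, which is why main() wraps the single-Dataset case into a DatasetDict. The removed code built hub_args and pushed inside the processing loop, once per split; the reverted code builds hub_args once and pushes once after the loop. The following is a minimal sketch of the two push shapes, not part of the commit; the repo id "your-org/your-dataset" is hypothetical, and only push_to_hub keyword arguments that already appear in the diff are used. Both calls need a valid HF token and network access to actually run.

    import datasets

    data_dict = datasets.DatasetDict(
        {"train": datasets.Dataset.from_dict({"text": ["a", "b"]})}
    )

    # Old shape (removed): one push per split, with the split name passed
    # explicitly on each call.
    data_dict["train"].push_to_hub(
        "your-org/your-dataset", config_name="default", split="train"
    )

    # New shape: a single push of the whole DatasetDict uploads all splits
    # under one config name; a per-split push is only used when a specific
    # dataset_split was requested on the command line.
    data_dict.push_to_hub("your-org/your-dataset", config_name="default")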
