Skip to content

Commit

Permalink
Filter out audio in map sample (#72)
Browse files Browse the repository at this point in the history
* First

* One liner

* update name
  • Loading branch information
liPatrick authored Aug 8, 2024
1 parent 2e3a49e commit 0e3201e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def map_split(
ds_split: datasets.Dataset,
num_proc: int,
writer_batch_size: int,
excluded_fields: List[str],
exclude_fields: List[str],
) -> datasets.Dataset:
print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...')
ds_split = ds_split.map(
Expand Down Expand Up @@ -128,8 +128,9 @@ def _map_sample(self, sample, exclude_fields):
# We need to filter out the audio before the sample is passed into the jinja template
# or it will get loaded into memory and spike usage.
filtered_sample = {
k: v for k, v in sample.items() if k not in exclude_fields
k: sample[k] for k in sample.keys() if k not in exclude_fields
}

rendered = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)
Expand Down

0 comments on commit 0e3201e

Please sign in to comment.