From 0e3201e7d4b9328b72a6f7cd009edfe0f73b32b1 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Thu, 8 Aug 2024 11:09:10 -0700 Subject: [PATCH] Filter out audio in map sample (#72) * First * One liner * update name --- ultravox/tools/ds_tool/ds_tool.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 5ca3b878..44c3b866 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -46,7 +46,7 @@ def map_split( ds_split: datasets.Dataset, num_proc: int, writer_batch_size: int, - excluded_fields: List[str], + exclude_fields: List[str], ) -> datasets.Dataset: print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...') ds_split = ds_split.map( @@ -128,8 +128,9 @@ def _map_sample(self, sample, exclude_fields): # We need to filter out the audio before the sample is passed into the jinja template # or it will get loaded into memory and spike usage. filtered_sample = { - k: v for k, v in sample.items() if k not in exclude_fields + k: sample[k] for k in sample.keys() if k not in exclude_fields } + rendered = jinja2.Template( self.template, undefined=jinja2.StrictUndefined ).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)