From 50cc4acff62656018c67c4531fb11c7ef5485028 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Wed, 12 Jun 2024 20:48:44 +0000 Subject: [PATCH] datasets: Handle converting `int{16,32}` audio data in `VoiceSample`. We saw `VoiceSample` failing its assert that audio data is `float32` when playing around with the Gradio infer app and submitting an `mp3` file. We didn't dig deeper into Gradio (I'm sure it's possible to alter/convert there as well), but it seems potentially useful for `VoiceSample` to handle `int16` and `int32` audio data on top of what it already handles. --- ultravox/data/datasets.py | 6 +++- ultravox/data/datasets_test.py | 63 +++++++++++++++++++++++++++++----- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index d945423a..d0277560 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -159,6 +159,10 @@ def __post_init__(self): if self.audio is not None: if self.audio.dtype == np.float64: self.audio = self.audio.astype(np.float32) + elif self.audio.dtype == np.int16: + self.audio = self.audio.astype(np.float32) / np.float32(32768.0) + elif self.audio.dtype == np.int32: + self.audio = self.audio.astype(np.float32) / np.float32(2147483648.0) assert ( self.audio.dtype == np.float32 ), f"Unexpected audio dtype: {self.audio.dtype}" @@ -166,7 +170,7 @@ def __post_init__(self): messages: List[Dict[str, str]] """List of messages, each with a "role" and "content" field.""" - audio: Optional[np.ndarray] = None + audio: Optional[np.typing.NDArray[np.float32]] = None """Audio data as float32 PCM @ `sample_rate`.""" sample_rate: int = SAMPLE_RATE """Audio sample rate in Hz.""" diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 74239145..1bf7ee23 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -1,8 +1,9 @@ import itertools -from typing import Optional +from typing import Optional, Union import datasets as
hf_datasets import numpy as np +import pytest import torch from torch.utils import data from transformers.feature_extraction_utils import BatchFeature @@ -135,18 +136,39 @@ def _create_sine_wave( duration: float = 1.0, sample_rate: int = 16000, amplitude: float = 0.1, -): + target_dtype: str = "float32", +) -> Union[ + np.typing.NDArray[np.float32], + np.typing.NDArray[np.float64], + np.typing.NDArray[np.int16], + np.typing.NDArray[np.int32], +]: t = np.arange(sample_rate * duration, dtype=np.float32) / sample_rate - return amplitude * np.sin(2 * np.pi * freq * t) - - -def test_create_sample(): - # Create a PCM sine wave at 440 Hz, as int16. - array = _create_sine_wave() + wave = amplitude * np.sin(2 * np.pi * freq * t) + match target_dtype: + case "int16": + wave = np.int16(wave * 32767) + case "int32": + wave = np.int32(wave * 2147483647) + case "float32": + # Already float32, nothing needed. + pass + case "float64": + wave = wave.astype(np.float64) + case _: + raise ValueError(f"Unsupported dtype: {target_dtype}") + return wave + + +def _create_and_validate_sample(target_dtype: str = "float32"): + # Create a sine wave at 440 Hz with a duration of 1.0 second, sampled at 16 + # kHz, with an amplitude of 0.1, and the specified dtype. 
+ array = _create_sine_wave(target_dtype=target_dtype) sample = datasets.VoiceSample.from_prompt_and_raw( "Transcribe <|audio|>", array, 16000 ) assert sample.sample_rate == 16000 + assert sample.audio is not None, "sample.audio should not be None" assert len(sample.audio) == 16000 assert sample.audio.dtype == np.float32 assert sample.messages == [ @@ -156,7 +178,32 @@ def test_create_sample(): json = sample.to_json() sample2 = datasets.VoiceSample.from_json(json) assert sample2.sample_rate == sample.sample_rate + assert sample2.audio is not None, "sample2.audio should not be None" assert len(sample2.audio) == len(sample.audio) assert sample2.audio.dtype == sample.audio.dtype assert sample2.messages == sample.messages assert np.allclose(sample2.audio, sample.audio, rtol=0.0001, atol=0.0001) + + +def test_create_sample__int16(): + _create_and_validate_sample("int16") + + +def test_create_sample__int32(): + _create_and_validate_sample("int32") + + +def test_create_sample__float32(): + _create_and_validate_sample("float32") + + +def test_create_sample__float64(): + _create_and_validate_sample("float64") + + +def test_create_sample__raises_on_unsupported_dtype(): + with pytest.raises(AssertionError): + array = np.ndarray(shape=(16000,), dtype=np.uint8) + sample = datasets.VoiceSample.from_prompt_and_raw( + "Transcribe <|audio|>", array, 16000 + )