Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datasets: Handle converting int16 audio data in VoiceSample. #26

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,18 @@ def __post_init__(self):
if self.audio is not None:
if self.audio.dtype == np.float64:
self.audio = self.audio.astype(np.float32)
elif self.audio.dtype == np.int16:
self.audio = self.audio.astype(np.float32) / np.float32(32768.0)
shaper marked this conversation as resolved.
Show resolved Hide resolved
shaper marked this conversation as resolved.
Show resolved Hide resolved
elif self.audio.dtype == np.int32:
self.audio = self.audio.astype(np.float32) / np.float32(2147483648.0)
assert (
self.audio.dtype == np.float32
), f"Unexpected audio dtype: {self.audio.dtype}"
assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}"

messages: List[Dict[str, str]]
"""List of messages, each with a "role" and "content" field."""
audio: Optional[np.ndarray] = None
audio: Optional[np.typing.NDArray[np.float32]] = None
"""Audio data as float32 PCM @ `sample_rate`."""
sample_rate: int = SAMPLE_RATE
"""Audio sample rate in Hz."""
Expand Down
63 changes: 55 additions & 8 deletions ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
from typing import Optional
from typing import Optional, Union

import datasets as hf_datasets
import numpy as np
import pytest
import torch
from torch.utils import data
from transformers.feature_extraction_utils import BatchFeature
Expand Down Expand Up @@ -135,18 +136,39 @@ def _create_sine_wave(
duration: float = 1.0,
sample_rate: int = 16000,
amplitude: float = 0.1,
):
target_dtype: str = "float32",
) -> Union[
np.typing.NDArray[np.float32],
np.typing.NDArray[np.float64],
np.typing.NDArray[np.int16],
np.typing.NDArray[np.int32],
]:
t = np.arange(sample_rate * duration, dtype=np.float32) / sample_rate
return amplitude * np.sin(2 * np.pi * freq * t)


def test_create_sample():
# Create a PCM sine wave at 440 Hz, as int16.
array = _create_sine_wave()
wave = amplitude * np.sin(2 * np.pi * freq * t)
match target_dtype:
case "int16":
wave = np.int16(wave * 32767)
case "int32":
wave = np.int32(wave * 2147483647)
case "float32":
# Already float32, nothing needed.
pass
case "float64":
wave = wave.astype(np.float64)
case _:
raise ValueError(f"Unsupported dtype: {target_dtype}")
return wave


def _create_and_validate_sample(target_dtype: str = "float32"):
# Create a sine wave at 440 Hz with a duration of 1.0 second, sampled at 16
# kHz, with an amplitude of 0.1, and the specified dtype.
array = _create_sine_wave(target_dtype=target_dtype)
sample = datasets.VoiceSample.from_prompt_and_raw(
"Transcribe <|audio|>", array, 16000
)
assert sample.sample_rate == 16000
assert sample.audio is not None, "sample.audio should not be None"
assert len(sample.audio) == 16000
assert sample.audio.dtype == np.float32
assert sample.messages == [
Expand All @@ -156,7 +178,32 @@ def test_create_sample():
json = sample.to_json()
sample2 = datasets.VoiceSample.from_json(json)
assert sample2.sample_rate == sample.sample_rate
assert sample2.audio is not None, "sample2.audio should not be None"
assert len(sample2.audio) == len(sample.audio)
assert sample2.audio.dtype == sample.audio.dtype
assert sample2.messages == sample.messages
assert np.allclose(sample2.audio, sample.audio, rtol=0.0001, atol=0.0001)


def test_create_sample__int16():
    """Run the shared create-and-validate check with int16 input audio."""
    _create_and_validate_sample(target_dtype="int16")


def test_create_sample__int32():
    """Run the shared create-and-validate check with int32 input audio."""
    _create_and_validate_sample(target_dtype="int32")


def test_create_sample__float32():
    """Run the shared create-and-validate check with float32 input audio."""
    _create_and_validate_sample(target_dtype="float32")


def test_create_sample__float64():
    """Run the shared create-and-validate check with float64 input audio."""
    _create_and_validate_sample(target_dtype="float64")


def test_create_sample__raises_on_unsupported_dtype():
    """Expect an AssertionError when a sample is built from uint8 audio."""
    # Build the input outside the `pytest.raises` block so that only the
    # VoiceSample construction can satisfy the expected AssertionError.
    # `np.zeros` gives a deterministic buffer; `np.ndarray(shape=...)` would
    # hand back uninitialized memory.
    array = np.zeros(shape=(16000,), dtype=np.uint8)
    with pytest.raises(AssertionError):
        datasets.VoiceSample.from_prompt_and_raw(
            "Transcribe <|audio|>", array, 16000
        )
Loading