diff --git a/pedalboard/juce_overrides/juce_PatchedMP3AudioFormat.cpp b/pedalboard/juce_overrides/juce_PatchedMP3AudioFormat.cpp index bbbe4708..c161711c 100644 --- a/pedalboard/juce_overrides/juce_PatchedMP3AudioFormat.cpp +++ b/pedalboard/juce_overrides/juce_PatchedMP3AudioFormat.cpp @@ -814,6 +814,18 @@ struct MP3Frame { return ParseSuccessful::yes; } + int numSamples() const { + switch (layer) { + case 1: + return 384; + case 3: + if (lsf) + return 576; + default: + return 1152; + } + } + int layer, frameSize, numChannels, single; int lsf; // 0 = mpeg-1, 1 = mpeg-2/LSF bool mpeg25; // true = mpeg-2.5, false = mpeg-1/2 @@ -3282,6 +3294,7 @@ class PatchedMP3Reader : public AudioFormatReader { usesFloatingPointData = true; sampleRate = stream.frame.getFrequency(); numChannels = (unsigned int)stream.frame.numChannels; + samplesPerFrame = stream.frame.numSamples(); lengthInSamples = findLength(streamPos); } } @@ -3295,12 +3308,12 @@ class PatchedMP3Reader : public AudioFormatReader { } if (currentPosition != startSampleInFile) { - if (!stream.seek((int)(startSampleInFile / 1152 - 1))) { + if (!stream.seek((int)(startSampleInFile / samplesPerFrame - 1))) { currentPosition = -1; createEmptyDecodedData(); } else { decodedStart = decodedEnd = 0; - const int64 streamPos = stream.currentFrameIndex * 1152; + const int64 streamPos = stream.currentFrameIndex * samplesPerFrame; int toSkip = (int)(startSampleInFile - streamPos); jassert(toSkip >= 0); @@ -3356,6 +3369,7 @@ class PatchedMP3Reader : public AudioFormatReader { private: PatchedMP3Stream stream; int64 currentPosition; + int samplesPerFrame; enum { decodedDataSize = 1152 }; float decoded0[decodedDataSize], decoded1[decodedDataSize]; int decodedStart, decodedEnd; @@ -3370,6 +3384,12 @@ class PatchedMP3Reader : public AudioFormatReader { bool readNextBlock() { for (int attempts = 10; --attempts >= 0;) { int samplesDone = 0; + + if (stream.stream.isExhausted()) { + createEmptyDecodedData(); + return true; + } + const int result = stream.decodeNextBlock(decoded0, decoded1, samplesDone); @@ -3426,7 +3446,7 @@ class PatchedMP3Reader : public AudioFormatReader { } } - return numFrames * 1152; + return numFrames * samplesPerFrame; } JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR(PatchedMP3Reader) diff --git a/tests/audio/correct/sample_mono_22050Hz.mp3 b/tests/audio/correct/sample_mono_22050Hz.mp3 new file mode 100644 index 00000000..90425b20 Binary files /dev/null and b/tests/audio/correct/sample_mono_22050Hz.mp3 differ diff --git a/tests/audio/correct/sample_mono_22050Hz.reencoded.mp3 b/tests/audio/correct/sample_mono_22050Hz.reencoded.mp3 new file mode 100644 index 00000000..b04603c1 Binary files /dev/null and b/tests/audio/correct/sample_mono_22050Hz.reencoded.mp3 differ diff --git a/tests/test_io.py b/tests/test_io.py index 6eff7a25..741ee48f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -32,8 +32,10 @@ EXPECTED_DURATION_SECONDS = 5 EXPECT_LENGTH_TO_BE_EXACT = {"wav", "aiff", "caf", "ogg", "m4a", "mp4"} +MP3_FRAME_LENGTH_SAMPLES = 1152 TEST_AUDIO_FILES = { + 22050: glob.glob(os.path.join(os.path.dirname(__file__), "audio", "correct", "*22050*")), 44100: glob.glob(os.path.join(os.path.dirname(__file__), "audio", "correct", "*44100*")), 48000: glob.glob(os.path.join(os.path.dirname(__file__), "audio", "correct", "*48000*")), } @@ -830,7 +832,7 @@ def test_write_empty_file(extension: str, samplerate: float, num_channels: int): # The built-in JUCE MP3 reader (only used on Linux and Windows) # reads zero-length MP3 files as having exactly one frame. if "mp3" in extension and platform.system() != "Darwin": - assert af.frames <= 1152 + assert af.frames <= MP3_FRAME_LENGTH_SAMPLES contents = af.read(af.frames) np.testing.assert_allclose(np.zeros_like(contents), contents) else: @@ -1068,3 +1070,67 @@ def test_seek_accuracy(quality: int, chunk_duration: int, granularity: int, exte f" {offset:,}" ), ) + + +@pytest.mark.parametrize( + "audio_filename,samplerate", + [(a, s) for a, s in FILENAMES_AND_SAMPLERATES if s == 22050 and ".mp3" in a], +) +def test_22050Hz_mono_mp3(audio_filename: str, samplerate: float): + """ + File size estimation was broken for 22kHz mono MP3 files. + This test should catch that kind of problem. + """ + af = pedalboard.io.ReadableAudioFile(audio_filename, cross_platform_formats_only=True) + assert af.duration < 30.5 + assert af.samplerate == samplerate + data_read_all_at_once = af.read(af.frames) + + chunk_size = MP3_FRAME_LENGTH_SAMPLES + chunks = [] + af.seek(0) + while af.tell() < af.frames: + chunks.append(af.read(chunk_size)) + data_read_in_chunks = np.concatenate(chunks, axis=1) + np.testing.assert_allclose(data_read_all_at_once, data_read_in_chunks) + + +@pytest.mark.parametrize("quality", [f"V{x}" for x in range(0, 10)] + [320, 64]) +@pytest.mark.parametrize( + "samplerate", [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000] +) +@pytest.mark.parametrize("num_channels", [1, 2]) +def test_mp3_at_all_samplerates(quality: str, samplerate: float, num_channels: int): + secs = 2 + # Make an audio signal that is equal parts noise and silence to make sure + # we end up with a mixture of bitrates in the file: + signal = np.concatenate( + [np.random.rand(samplerate * secs) - 0.5, np.zeros(samplerate * secs)] + ).astype(np.float32) + if num_channels == 2: + signal = np.stack([signal] * num_channels) + else: + signal = np.expand_dims(signal, 0) + + buf = io.BytesIO() + buf.name = "test.mp3" + with pedalboard.io.AudioFile( + buf, "w", samplerate, num_channels=num_channels, quality=quality + ) as f: + f.write(signal) + + read_buf = io.BytesIO(buf.getvalue()) + + with pedalboard.io.ReadableAudioFile(read_buf, cross_platform_formats_only=True) as af: + # Allow for up to two MP3 frames of padding: + assert af.frames <= (signal.shape[-1] + MP3_FRAME_LENGTH_SAMPLES * 2) + assert af.frames >= signal.shape[-1] + # MP3 is lossy, so we can't expect the waveforms to be comparable; + # but at least make sure that the first half of the signal is loud + # and the second half is silent: + assert np.amax(np.mean(af.read(samplerate * secs), axis=0)) >= np.amax( + signal[:, : samplerate * secs] + ) + # skip a couple MP3 frames: + af.read(MP3_FRAME_LENGTH_SAMPLES * 2) + assert np.amax(np.mean(af.read(samplerate * secs), axis=0)) < 0.01