Skip to content

Commit

Permalink
remove sprint cache for ogg-zip
Browse files Browse the repository at this point in the history
  • Loading branch information
JackTemaki committed Jul 24, 2024
1 parent bd0eee4 commit 52da103
Showing 1 changed file with 3 additions and 148 deletions.
151 changes: 3 additions & 148 deletions tools/bliss-to-ogg-zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
from decimal import Decimal
import tempfile
import gzip
import numpy
import xml.etree.ElementTree as ElementTree
import zipfile
import shutil
from subprocess import check_call
from glob import glob

import _setup_returnn_env # noqa
import returnn.sprint.cache


class BlissItem:
Expand Down Expand Up @@ -117,110 +114,6 @@ def iter_bliss(filename):
name_tree = name_tree[:-1]
count_tree = count_tree[:-1]


class SprintCacheHandler:
"""
This is just to apply the same silence trimming on the raw audio samples
which was applied on the features in the Sprint cache.
We can reconstruct this information because the Sprint cache also has the exact timing information.
"""

def __init__(self, opt, bliss_opt, raw_sample_rate, feat_sample_rate):
"""
:param str opt: either filename or filename pattern
:param str bliss_opt: either filename or filename pattern
:param int raw_sample_rate:
:param int feat_sample_rate:
"""
self.sprint_cache = self._load_sprint_cache(opt)
self.seg_times = self._collect_seg_times_from_bliss(bliss_opt)
self.raw_sample_rate = raw_sample_rate
self.feat_sample_rate = feat_sample_rate
self.pp_counter = 0

@staticmethod
def _load_sprint_cache(opt):
"""
:param str opt: either filename or filename pattern
:rtype: SprintCache.FileArchiveBundle|SprintCache.FileArchive
"""
if "*" in opt:
sprint_cache_fns = glob(opt)
assert sprint_cache_fns, "nothing found under sprint cache pattern %r" % (opt,)
sprint_cache = returnn.sprint.cache.FileArchiveBundle()
for fn in sprint_cache_fns:
print("Load Sprint cache:", fn)
sprint_cache.add_bundle_or_archive(fn)
else:
print("Load Sprint cache:", opt)
sprint_cache = returnn.sprint.cache.open_file_archive(opt, must_exists=True)
return sprint_cache

@staticmethod
def _collect_seg_times_from_bliss(opt):
"""
:param str opt: either filename or filename pattern
:rtype: dict[str,(Decimal,Decimal)]
"""
if "*" in opt:
items = []
fns = glob(opt)
assert fns, "nothing found under Bliss XML cache pattern %r" % (opt,)
for fn in fns:
print("Load Bliss XML:", fn)
items.extend(iter_bliss(fn))
else:
print("Load Bliss XML:", opt)
items = list(iter_bliss(opt))
return {seq.segment_name: (seq.start_time, seq.end_time) for seq in items}

# noinspection PyUnusedLocal
def feature_post_process(self, feature_data, seq_name, **kwargs):
"""
:param numpy.ndarray feature_data:
:param str seq_name:
:return: features
:rtype: numpy.ndarray
"""
assert feature_data.shape[1] == 1 # raw audio
self.pp_counter += 1
assert self.raw_sample_rate % self.feat_sample_rate == 0
num_frames_per_feat = self.raw_sample_rate // self.feat_sample_rate
assert num_frames_per_feat % 2 == 0
allowed_variance_num_frames = num_frames_per_feat // 2 # allow some variance
times, data = self.sprint_cache.read(seq_name, "feat")
assert len(times) == len(data)
prev_end_frame = None
res_feature_data = []
seq_time_offset = float(self.seg_times[seq_name][0])
for (start_time, end_time), feat in zip(times, data):
start_time -= seq_time_offset
end_time -= seq_time_offset
center_time = (start_time + end_time) / 2.0
start_frame = int(center_time * self.raw_sample_rate) - num_frames_per_feat // 2
assert 0 <= start_frame < feature_data.shape[0]
if prev_end_frame is not None:
if (
prev_end_frame - allowed_variance_num_frames
<= start_frame
<= prev_end_frame + allowed_variance_num_frames
):
start_frame = prev_end_frame
assert start_frame >= prev_end_frame
end_frame = start_frame + num_frames_per_feat
if feature_data.shape[0] < end_frame <= feature_data.shape[0] + allowed_variance_num_frames:
res_feature_data.append(feature_data[start_frame:])
res_feature_data.append(numpy.zeros((end_frame - feature_data.shape[0], 1), dtype=feature_data.dtype))
else:
assert end_frame <= feature_data.shape[0]
res_feature_data.append(feature_data[start_frame:end_frame])
prev_end_frame = end_frame
res_feature_data = numpy.concatenate(res_feature_data, axis=0)
assert res_feature_data.shape[0] % num_frames_per_feat == 0
assert res_feature_data.shape[0] // num_frames_per_feat == len(data)
return res_feature_data


def longest_common_prefix(strings):
"""
:param list[str]|set[str] strings:
Expand Down Expand Up @@ -271,7 +164,6 @@ def main():
"--no_conversion", help="skip ffmpeg call, assume audio is correct already", action="store_true"
)
arg_parser.add_argument("--no_cleanup", help="don't delete our temp files", action="store_true")
arg_parser.add_argument("--sprint_cache", help="filename of feature cache for synchronization")
arg_parser.add_argument("--raw_sample_rate", help="sample rate of audio input", type=int, default=8000)
arg_parser.add_argument("--feat_sample_rate", help="sample rate of features for sync", type=int, default=100)
arg_parser.add_argument("--ffmpeg_loglevel", help="loglevel for ffmpeg calls", type=str, default="info")
Expand Down Expand Up @@ -324,15 +216,6 @@ def main():
zip_filename = None
print("Dataset name:", name)

sprint_cache_handler = None
if args.sprint_cache:
sprint_cache_handler = SprintCacheHandler(
opt=args.sprint_cache,
bliss_opt=args.bliss_filename,
raw_sample_rate=args.raw_sample_rate,
feat_sample_rate=args.feat_sample_rate,
)

total_duration = Decimal(0)
total_num_chars = 0
temp_dir = tempfile.mkdtemp()
Expand All @@ -355,32 +238,9 @@ def main():
rec_filename_common_postfix
)
rec_name = rec_filename[len(rec_filename_common_prefix) : -len(rec_filename_common_postfix)]
if args.sprint_cache:
wav_tmp_filename = "%s/%s/%s_%s.wav" % (dest_dirname, rec_name, seq.start_time, seq.end_time)
os.makedirs(os.path.dirname(wav_tmp_filename), exist_ok=True)
cmd = ["ffmpeg"]
if args.ffmpeg_acodec:
cmd += ["-acodec", args.ffmpeg_acodec] # https://trac.ffmpeg.org/ticket/2810
cmd += ["-i", rec_filename, "-ss", str(seq.start_time), "-t", str(duration)]
if args.number_of_channels > 0:
cmd += ["-ac", str(args.number_of_channels)]
cmd += [wav_tmp_filename, "-loglevel", args.ffmpeg_loglevel]
print("$ %s" % " ".join(cmd))
check_call(cmd)
import soundfile # pip install pysoundfile

audio, sample_rate = soundfile.read(wav_tmp_filename)
assert sample_rate == args.raw_sample_rate
audio_synced = sprint_cache_handler.feature_post_process(numpy.expand_dims(audio, axis=1), seq.segment_name)
soundfile.write(wav_tmp_filename, audio_synced, args.raw_sample_rate)
source_filename = wav_tmp_filename
start_time = 0
limit_duration = False
else:
soundfile = audio_synced = sample_rate = wav_tmp_filename = None
source_filename = rec_filename
start_time = seq.start_time
limit_duration = True
source_filename = rec_filename
start_time = seq.start_time
limit_duration = True
dest_filename = "%s/%s/%s_%s.ogg" % (dest_dirname, rec_name, seq.start_time, seq.end_time)
os.makedirs(os.path.dirname(dest_filename), exist_ok=True)
if args.no_ogg:
Expand Down Expand Up @@ -410,11 +270,6 @@ def main():
cmd += [dest_filename, "-loglevel", args.ffmpeg_loglevel]
print("$ %s" % " ".join(cmd))
check_call(cmd)
if args.sprint_cache:
audio_ogg, sample_rate_ogg = soundfile.read(dest_filename)
assert len(audio_synced) == len(audio_ogg), "Number of frames in synced wav and converted ogg do not match"
assert sample_rate == sample_rate_ogg, "Sample rates in synced wav and converted ogg do not match"
os.remove(wav_tmp_filename)
dest_meta_file.write(
"{'text': %r, 'speaker_name': %r, 'file': %r, 'seq_name': %r, 'duration': %s},\n"
% (seq.orth, seq.speaker_name, dest_filename[len(dest_dirname) + 1 :], seq.segment_name, duration)
Expand Down

0 comments on commit 52da103

Please sign in to comment.