Commit
Merge pull request #36 from shasheene/master
Applies formatting changes, fixes VTT output
abhirooptalasila authored Sep 9, 2021
2 parents 47ddcf1 + 40485c5 commit b053484
Showing 7 changed files with 115 additions and 95 deletions.
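For context on the VTT fix: WebVTT output must begin with a WEBVTT header and use dot-separated millisecond timestamps, unlike SRT, which uses commas and no header. A minimal hand-written cue illustrating the format (not taken from this diff):

WEBVTT

1
00:00:01.000 --> 00:00:04.500
first inferred subtitle line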
1 change: 1 addition & 0 deletions autosub/__init__.py
@@ -0,0 +1 @@

11 changes: 6 additions & 5 deletions autosub/audioProcessing.py
@@ -13,9 +13,10 @@ def extract_audio(input_file, audio_file_name):
input_file: input video file
audio_file_name: save audio WAV file with same filename as video file
"""

try:
command = ["ffmpeg", "-hide_banner", "-loglevel", "warning", "-i", input_file, "-ac", "1", "-ar", "16000", "-vn", "-f", "wav", audio_file_name]
command = ["ffmpeg", "-hide_banner", "-loglevel", "warning", "-i", input_file, "-ac", "1", "-ar", "16000",
"-vn", "-f", "wav", audio_file_name]
ret = subprocess.call(command)
print("Extracted audio to audio/{}".format(audio_file_name.split("/")[-1]))
except Exception as e:
@@ -26,15 +27,15 @@ def extract_audio(input_file, audio_file_name):
def convert_samplerate(audio_path, desired_sample_rate):
"""Convert extracted audio to the format expected by DeepSpeech
***WON'T be called as extract_audio() converts the audio to 16kHz while saving***
Args:
audio_path: audio file path
- desired_sample_rate: DeepSpeech expects 16kHz
+ desired_sample_rate: DeepSpeech expects 16kHz
Returns:
numpy buffer: audio signal stored in numpy array
"""

sox_cmd = "sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer \
--endian little --compression 0.0 --no-dither norm -3.0 - ".format(
quote(audio_path), desired_sample_rate)
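The ffmpeg flags above downmix to mono (-ac 1), resample to 16 kHz (-ar 16000), drop the video stream (-vn) and force WAV output (-f wav), which is why convert_samplerate() is never reached. A self-contained sketch of the same call, with hypothetical file paths:

import subprocess

def extract_audio_sketch(input_file, audio_file_name):
    # Mono 16 kHz WAV is the input format DeepSpeech expects,
    # so no separate resampling step is needed afterwards
    command = ["ffmpeg", "-hide_banner", "-loglevel", "warning", "-i", input_file,
               "-ac", "1", "-ar", "16000", "-vn", "-f", "wav", audio_file_name]
    return subprocess.call(command)

extract_audio_sketch("input.mp4", "audio/input.wav")  # hypothetical paths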
48 changes: 25 additions & 23 deletions autosub/featureExtraction.py
@@ -1,18 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math
import numpy as np
from scipy.fftpack import fft
from scipy.signal import lfilter
from scipy.fftpack.realtransforms import dct

eps = 0.00000001


def zero_crossing_rate(frame):
"""Computes zero crossing rate of frame
"""

count = len(frame)
count_zero = np.sum(np.abs(np.diff(np.sign(frame)))) / 2
return np.float64(count_zero) / np.float64(count - 1.0)
@@ -21,14 +20,14 @@ def zero_crossing_rate(frame):
def energy(frame):
"""Computes signal energy of frame
"""

return np.sum(frame ** 2) / np.float64(len(frame))


def energy_entropy(frame, n_short_blocks=10):
"""Computes entropy of energy
"""

# total frame energy
frame_energy = np.sum(frame ** 2)
frame_length = len(frame)
@@ -44,7 +43,7 @@ def energy_entropy(frame, n_short_blocks=10):

# Compute entropy of the normalized sub-frame energies:
entropy = -np.sum(s * np.log2(s + eps))

return entropy


@@ -54,7 +53,7 @@ def energy_entropy(frame, n_short_blocks=10):
def spectral_centroid_spread(fft_magnitude, sampling_rate):
"""Computes spectral centroid of frame (given abs(FFT))
"""

ind = (np.arange(1, len(fft_magnitude) + 1)) * \
(sampling_rate / (2.0 * len(fft_magnitude)))

@@ -79,7 +78,7 @@ def spectral_centroid_spread(fft_magnitude, sampling_rate):
def spectral_entropy(signal, n_short_blocks=10):
"""Computes the spectral entropy
"""

# number of frame samples
num_frames = len(signal)

@@ -105,12 +104,12 @@ def spectral_entropy(signal, n_short_blocks=10):

def spectral_flux(fft_magnitude, previous_fft_magnitude):
"""Computes the spectral flux feature of the current frame
Args:
fft_magnitude : the abs(fft) of the current frame
previous_fft_magnitude : the abs(fft) of the previous frame
"""

# compute the spectral flux as the sum of square distances:
fft_sum = np.sum(fft_magnitude + eps)
previous_fft_sum = np.sum(previous_fft_magnitude + eps)
@@ -124,24 +123,25 @@ def spectral_flux(fft_magnitude, previous_fft_magnitude):
def spectral_rolloff(signal, c):
"""Computes spectral roll-off
"""

energy = np.sum(signal ** 2)
fft_length = len(signal)
threshold = c * energy
- # Find the spectral rolloff as the frequency position
+ # Find the spectral rolloff as the frequency position
# where the respective spectral energy is equal to c*totalEnergy
cumulative_sum = np.cumsum(signal ** 2) + eps
a = np.nonzero(cumulative_sum > threshold)[0]
if len(a) > 0:
sp_rolloff = np.float64(a[0]) / (float(fft_length))
else:
sp_rolloff = 0.0

return sp_rolloff


def mfcc_filter_banks(sampling_rate, num_fft, lowfreq=133.33, linc=200 / 3,
logsc=1.0711703, num_lin_filt=13, num_log_filt=27):
"""Computes the triangular filterbank for MFCC computation
"""Computes the triangular filterbank for MFCC computation
(used in the stFeatureExtraction function before the stMFCC function call)
This function is taken from the scikits.talkbox library (MIT Licence):
https://pypi.python.org/pypi/scikits.talkbox
@@ -189,13 +189,13 @@ def mfcc(fft_magnitude, fbank, num_mfcc_feats):
Args:
fft_magnitude : fft magnitude abs(FFT)
fbank : filter bank (see mfccInitFilterBanks)
Returns:
ceps : MFCCs (13 element vector)
- Note: MFCC calculation is, in general, taken from the
+ Note: MFCC calculation is, in general, taken from the
scikits.talkbox library (MIT Licence),
- # with a small number of modifications to make it more
+ # with a small number of modifications to make it more
compact and suitable for the pyAudioAnalysis Lib
"""

@@ -208,7 +208,7 @@ def chroma_features_init(num_fft, sampling_rate):
"""This function initializes the chroma matrices used in the calculation
of the chroma features
"""

freqs = np.array([((f + 1) * sampling_rate) /
(2 * num_fft) for f in range(num_fft)])
cp = 27.50
@@ -265,8 +265,10 @@ def chroma_features(signal, sampling_rate, num_fft):

return chroma_names, final_matrix


""" Windowing and feature extraction """


def feature_extraction(signal, sampling_rate, window, step, deltas=True):
"""This function implements the shor-term windowing process.
For each short-term window a set of features is extracted.
@@ -278,11 +280,11 @@ def feature_extraction(signal, sampling_rate, window, step, deltas=True):
window : the short-term window size (in samples)
step : the short-term window step (in samples)
deltas : (opt) True/False if delta features are to be computed
Returns:
- features (numpy.ndarray) : contains features
+ features (numpy.ndarray) : contains features
(n_feats x numOfShortTermWindows)
- feature_names (numpy.ndarray) : contains feature names
+ feature_names (numpy.ndarray) : contains feature names
(n_feats x numOfShortTermWindows)
"""

@@ -409,5 +411,5 @@ def feature_extraction(signal, sampling_rate, window, step, deltas=True):
fft_magnitude_previous = fft_magnitude.copy()

features = np.concatenate(features, 1)
- return features, feature_names

+ return features, feature_names
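As a sanity check on the frame-level features above: zero_crossing_rate() returns the fraction of adjacent-sample pairs whose sign differs. A minimal self-contained check on a synthetic tone (the 16 kHz rate and 800-sample frame are arbitrary choices for this example):

import numpy as np

def zero_crossing_rate(frame):
    # Fraction of adjacent-sample pairs whose sign differs
    count = len(frame)
    count_zero = np.sum(np.abs(np.diff(np.sign(frame)))) / 2
    return np.float64(count_zero) / np.float64(count - 1.0)

# A 100 Hz sine at 16 kHz crosses zero 200 times per second, so a
# 0.05 s frame (800 samples) should give roughly 10 / 799 ~= 0.0125
t = np.arange(800) / 16000.0
frame = np.sin(2 * np.pi * 100 * t)
print(zero_crossing_rate(frame))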
52 changes: 29 additions & 23 deletions autosub/main.py
@@ -1,35 +1,37 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

+ import argparse
import os
import re
+ import shutil
import sys
import wave
- import shutil
- import argparse
import subprocess

import numpy as np
- from deepspeech import Model
from tqdm import tqdm
+ from deepspeech import Model, version
+ from segmentAudio import silenceRemoval

from audioProcessing import extract_audio, convert_samplerate
- from segmentAudio import silenceRemoval
from writeToFile import write_to_file

# Line count for SRT file
line_count = 1


def sort_alphanumeric(data):
"""Sort function to sort os.listdir() alphanumerically
- Helps to process audio files sequentially after splitting
+ Helps to process audio files sequentially after splitting
Args:
data : file name
"""

convert = lambda text: int(text) if text.isdigit() else text.lower()
- alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted(data, key = alphanum_key)
+ alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+
+ return sorted(data, key=alphanum_key)


def ds_process_audio(ds, audio_file, output_file_handle_dict, split_duration):
@@ -42,12 +44,12 @@ def ds_process_audio(ds, audio_file, output_file_handle_dict, split_duration):
output_file_handle_dict : Mapping of subtitle format (e.g. 'srt') to open file_handle
split_duration: for long audio segments, split the subtitle based on this number of seconds
"""

global line_count
fin = wave.open(audio_file, 'rb')
fs_orig = fin.getframerate()
desired_sample_rate = ds.sampleRate()

# Check if the sampling rate is the required rate (16000)
# won't be carried out as FFmpeg already converts to 16kHz
if fs_orig != desired_sample_rate:
@@ -58,7 +60,7 @@ def ds_process_audio(ds, audio_file, output_file_handle_dict, split_duration):
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

fin.close()

# Perform inference on audio segment
metadata = ds.sttWithMetadata(audio)

@@ -81,13 +83,14 @@ def ds_process_audio(ds, audio_file, output_file_handle_dict, split_duration):
cues += [float(limits[0]) + token.start_time]
# time duration is exceeded and at the next word boundary
needs_split = ((token.start_time - previous_end_time) > split_duration) and token.text == " "
- is_final_character = current_token_index+1 == num_tokens
+ is_final_character = current_token_index + 1 == num_tokens
# Write out the line
if needs_split or is_final_character:
# Determine the timestamps
split_limits = [float(limits[0]) + previous_end_time, float(limits[0]) + token.start_time]
# Convert character list to string. Upper bound has plus 1 as python list slices are [inclusive, exclusive]
- split_inferred_text = ''.join([x.text for x in metadata.transcripts[0].tokens[split_start_index:current_token_index+1]])
+ split_inferred_text = ''.join(
+     [x.text for x in metadata.transcripts[0].tokens[split_start_index:current_token_index + 1]])
write_to_file(output_file_handle_dict, split_inferred_text, line_count, split_limits, cues)
# Reset and update indexes for the next subtitle split
previous_end_time = token.start_time
@@ -113,13 +116,13 @@ def main():
print("Scorer: ", os.path.join(os.getcwd(), x))
ds_scorer = os.path.join(os.getcwd(), x)

- # Load DeepSpeech model
+ # Load DeepSpeech model
try:
ds = Model(ds_model)
except:
print("Invalid model file. Exiting\n")
sys.exit(1)

try:
ds.enableExternalScorer(ds_scorer)
except:
@@ -130,8 +133,11 @@ def main():
parser.add_argument('--file', required=True,
help='Input video file')
parser.add_argument('--format', choices=supported_output_formats, nargs='+',
- help='Create only certain output formats rather than all formats', default=supported_output_formats)
- parser.add_argument('--split-duration', type=float, help='Split run-on sentences exceeding this duration (in seconds) into multiple subtitles', default=5)
+ help='Create only certain output formats rather than all formats',
+ default=supported_output_formats)
+ parser.add_argument('--split-duration', type=float,
+     help='Split run-on sentences exceeding this duration (in seconds) into multiple subtitles',
+     default=5)
args = parser.parse_args()

if os.path.isfile(args.file):
@@ -140,7 +146,7 @@ def main():
else:
print(args.file, ": No such file exists")
sys.exit(1)

base_directory = os.getcwd()
output_directory = os.path.join(base_directory, "output")
audio_directory = os.path.join(base_directory, "audio")
@@ -170,15 +176,15 @@ def main():

# Extract audio from input video file
extract_audio(input_file, audio_file_name)

print("Splitting on silent parts in audio file")
silenceRemoval(audio_file_name)

print("\nRunning inference:")

for file in tqdm(sort_alphanumeric(os.listdir(audio_directory))):
audio_segment_path = os.path.join(audio_directory, file)

# Don't run inference on the original audio file
if audio_segment_path.split(os.sep)[-1] != audio_file_name.split(os.sep)[-1]:
ds_process_audio(ds, audio_segment_path, output_file_handle_dict, split_duration=args.split_duration)
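One note on sort_alphanumeric() above: plain lexicographic sorting would place seg10.wav before seg2.wav and scramble the segment order fed to inference. A quick check with hypothetical file names:

import re

def sort_alphanumeric(data):
    # Split names into text/number runs so numeric parts compare as integers
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(data, key=alphanum_key)

print(sort_alphanumeric(["seg10.wav", "seg2.wav", "seg1.wav"]))
# ['seg1.wav', 'seg2.wav', 'seg10.wav']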