Skip to content

Commit

Permalink
Fix #360; Compare SortaGrad vs Existing Curriculum Learning
Browse files Browse the repository at this point in the history
  • Loading branch information
andi4191 committed Jul 5, 2017
1 parent 08ff89c commit 6d8b85c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
9 changes: 9 additions & 0 deletions DeepSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@
tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

# Sorta Grad Curriculum Learning

tf.app.flags.DEFINE_boolean ('sorta_grad', False, 'enable sorta grad curriculum learning - defaults to False')

for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var)

Expand Down Expand Up @@ -1222,6 +1226,7 @@ def _next_epoch(self):
if result:
# Increment the epoch index - shared among train and test 'state'
self._epoch += 1
data_sets.train.update_files_list(self._epoch, FLAGS.sorta_grad)
return result

def _end_training(self):
Expand Down Expand Up @@ -1407,6 +1412,7 @@ def train(server=None):
global_step = tf.Variable(0, trainable=False, name='global_step')

# Read all data sets
global data_sets
data_sets = read_data_sets(FLAGS.train_files.split(','),
FLAGS.dev_files.split(','),
FLAGS.test_files.split(','),
Expand All @@ -1423,6 +1429,9 @@ def train(server=None):
# Get the data sets
switchable_data_set = SwitchableDataSet(data_sets)

# Pre-processing step: apply the SortaGrad ordering as requested by the FLAG parameter
data_sets.train.update_files_list(0, FLAGS.sorta_grad)

# Create the optimizer
optimizer = create_optimizer()

Expand Down
15 changes: 11 additions & 4 deletions util/data_set_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, files_list, thread_count, batch_size, numcep, numcontext, nex
self._thread_count = thread_count
self._files_list = self._create_files_list(files_list)
self._next_index = next_index
self._files = files_list

def _get_device_count(self):
available_gpus = get_available_gpus()
Expand All @@ -74,13 +75,19 @@ def start_queue_threads(self, session, coord):
def close_queue(self, session):
    """Close this data set's input queue so any pending dequeue operations
    unblock and reader threads can shut down cleanly.

    session -- TensorFlow session used to run the queue's close op.
    """
    session.run(self._close_op)

def _create_files_list(self, files_list):
def update_files_list(self, epoch, sorta_grad = True):
    """Recompute the cached epoch work list from the stored data frame.

    epoch      -- epoch index about to start (0 is the SortaGrad epoch).
    sorta_grad -- whether SortaGrad curriculum ordering is enabled.
    """
    refreshed = self._create_files_list(files_list=self._files,
                                        epoch=epoch,
                                        sorta_grad=sorta_grad)
    self._files_list = refreshed

def _create_files_list(self, files_list, epoch = 0, sorta_grad=True):
# 1. Sort by wav filesize
# 2. Select just wav filename and transcript columns
# 3. Return a NumPy representation
return files_list.sort_values(by="wav_filesize") \
.ix[:, ["wav_filename", "transcript"]] \
.values
if epoch == 0 and sorta_grad is True:
return files_list.sort_values(by="wav_filesize").ix[:,["wav_filename", "transcript"]].values
else:
# Randomizing the data frame
files_list.sample(frac=1)
return files_list.ix[:,["wav_filename", "transcript"]].values

def _indices(self):
index = -1
Expand Down

0 comments on commit 6d8b85c

Please sign in to comment.