Model loading, test-time augmentation.
- Set up UNet2DS and the unet2ds_nf example to load the architecture and weights from the same HDF5 file.
- Implemented 8x test-time augmentation for UNet2DS. This improved the test score from 0.535 to 0.542 (sketched below).
alexklibisz committed Jul 11, 2017
1 parent 87869d9 commit f1b33bf
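
The 8x test-time augmentation from the second bullet shows up in the new UNet2DS.predict in the diff below: the model is run on each invertible transform of the input window, each prediction is mapped back through the matching inverse transform, and the eight results are averaged before thresholding. A minimal sketch of that loop, assuming square windows and the (name, augmentation, inverse) triples from INVERTIBLE_2D_AUGMENTATIONS; the standalone predict_augmented wrapper is illustrative, not part of the commit:

import numpy as np
from deepcalcium.utils.data_utils import INVERTIBLE_2D_AUGMENTATIONS

def predict_augmented(model, s_batch):
    """Average predictions over the 8 invertible augmentations.
    s_batch: (1, height, width). Windows are assumed square, so
    rotations and transposes preserve the shape."""
    mp = np.zeros(s_batch.shape[1:])
    for _, aug, inv in INVERTIBLE_2D_AUGMENTATIONS:
        # Predict on the augmented batch, then undo the augmentation
        # on the prediction before averaging.
        mp += inv(model.predict(aug(s_batch)))[0] / len(INVERTIBLE_2D_AUGMENTATIONS)
    return mp.round()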
Showing 5 changed files with 247 additions and 202 deletions.
256 changes: 119 additions & 137 deletions deepcalcium/models/neurons/unet_2d_summary.py
@@ -7,6 +7,7 @@
from itertools import cycle
from keras.callbacks import Callback, ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau
from keras.optimizers import Adam
from keras.losses import binary_crossentropy
from keras_contrib.callbacks import DeadReluDetector
from math import ceil
from os import path, mkdir, remove
@@ -23,8 +24,9 @@

from deepcalcium.utils.runtime import funcname
from deepcalcium.datasets.nf import nf_mask_metrics
from deepcalcium.utils.keras_helpers import MetricsPlotCallback
from deepcalcium.utils.keras_helpers import MetricsPlotCallback, F1, prec, reca, dice, dicesq, dice_loss, dicesq_loss, posyt, posyp, load_model_with_new_input_shape
from deepcalcium.utils.visuals import mask_outlines
from deepcalcium.utils.data_utils import INVERTIBLE_2D_AUGMENTATIONS


class ValidationMetricsCB(Callback):
@@ -64,11 +66,14 @@ def on_epoch_end(self, epoch, logs={}):
logger.info('\n')
tic = time()

# Save weights from the training model and load them into validation model.
path_weights = '%s/weights.tmp' % self.cpdir
self.model.save_weights(path_weights)
self.model_val.load_weights(path_weights)
remove(path_weights)
# Transfer weights from the training model to the validation model.
self.model_val.set_weights(self.model.get_weights())

# # Save weights from the training model and load them into validation model.
# path_weights = '%s/weights.tmp' % self.cpdir
# self.model.save_weights(path_weights)
# self.model_val.load_weights(path_weights)
# remove(path_weights)

# Tracking precision, recall, f1 values.
pp, rr, ff = [], [], []
@@ -207,7 +212,6 @@ def conv_layer(nb_filters, x):
x = Conv2D(2, 1)(x)
x = Activation('softmax')(x)
x = Lambda(lambda x: x[:, :, :, -1])(x)
# x = Lambda(lambda x: x[:, :, :, -1], output_shape=window_shape)(x)

return Model(inputs=inputs, outputs=x)

@@ -239,16 +243,21 @@ def __init__(self, cpdir='%s/.deep-calcium-datasets/tmp' % path.expanduser('~'),
if not path.exists(self.cpdir):
mkdir(self.cpdir)

def fit(self, datasets, weights_path=None, shape_trn=(96, 96), shape_val=(512, 512), batch_size_trn=32,
cobj = [F1, prec, reca, dice, dicesq, posyt, posyp, dice_loss, dicesq_loss]
self.custom_objects = {x.__name__: x for x in cobj}

def fit(self, datasets, model_path=None, proceed=False, shape_trn=(96, 96), shape_val=(512, 512), batch_size_trn=32,
batch_size_val=1, nb_steps_trn=200, nb_epochs=20, prop_trn=0.75, prop_val=0.25, keras_callbacks=[],
optimizer=Adam(0.002), loss='binary_crossentropy'):
optimizer=Adam(0.002), loss=binary_crossentropy):
"""Constructs network based on parameters and trains with the given data.
# Arguments
datasets: List of HDF5 datasets. Each of these will be passed to self.series_summary_func and
self.mask_summary_func to compute its series and mask summaries, so the HDF5 structure
should be compatible with those functions.
weights_path: filesystem path to weights that should be loaded into the network.
model_path: filesystem path to serialized model that should be loaded into the network.
proceed: whether to continue training where the model left off or start over. Only relevant when a
model_path is given because it uses the saved optimizer state.
shape_trn: (height, width) shape of the windows cropped for training.
shape_val: (height, width) shape of the windows used for validation.
batch_size_trn: Batch size used for training.
@@ -258,92 +267,67 @@ def fit(self, datasets, weights_path=None, shape_trn=(96, 96), shape_val=(512, 5
prop_val: Proportion of each summary image used to validate, cropped from the bottom of the image.
keras_callbacks: List of callbacks appended to internal callbacks for training.
optimizer: Instantiated keras optimizer.
loss: Loss function, currently either binary_crossentropy or dice_squared from https://arxiv.org/abs/1606.04797.
loss: Loss function, one of binary_crossentropy, dice_loss, or dicesq_loss from https://arxiv.org/abs/1606.04797.
# Returns
history: the Keras training history as a dictionary of metrics and their values after each epoch.
"""

# Error check.
assert len(shape_trn) == 2
assert len(shape_val) == 2
assert shape_trn[0] == shape_trn[1]
assert shape_val[0] == shape_val[1]
assert 0 < prop_trn < 1
assert 0 < prop_val < 1
assert loss in {'binary_crossentropy', 'dice_squared'}

logger = logging.getLogger(funcname())

# Define, compile neural net.
model = self.net_builder(shape_trn)
model_val = self.net_builder(shape_val)
json.dump(model.to_json(), open('%s/model.json' % self.cpdir, 'w'), indent=2)

# Metric: True positive proportion.
def ytpos(yt, yp):
size = K.sum(K.ones_like(yt))
return K.sum(yt) / (size + K.epsilon())

# Metric: Predicted positive proportion.
def yppos(yt, yp):
size = K.sum(K.ones_like(yp))
return K.sum(K.round(yp)) / (size + K.epsilon())

# Metric: Binary pixel-wise precision.
def prec(yt, yp):
yp = K.round(yp)
tp = K.sum(yt * yp)
fp = K.sum(K.clip(yp - yt, 0, 1))
return tp / (tp + fp + K.epsilon())

# Metric: Binary pixel-wise recall.
def reca(yt, yp):
yp = K.round(yp)
tp = K.sum(yt * yp)
fn = K.sum(K.clip(yt - yp, 0, 1))
return tp / (tp + fn + K.epsilon())

# Metric: Squared dice coefficient from VNet paper.
def dice_squared(yt, yp):
nmr = 2 * K.sum(yt * yp)
dnm = K.sum(yt**2) + K.sum(yp**2) + K.epsilon()
return (nmr / dnm)

def dice_squared_loss(yt, yp):
return 1 - dice_squared(yt, yp)

if loss == 'dice_squared':
loss = dice_squared_loss
assert not (proceed and not model_path)

losses = {
'binary_crossentropy': binary_crossentropy,
'dice_loss': dice_loss,
'dicesq_loss': dicesq_loss
}
assert loss in losses.keys() or loss in losses.values()
loss = losses[loss] if type(loss) == str else loss

# Load network from disk.
if model_path:
lmwnis = load_model_with_new_input_shape
model = lmwnis(model_path, shape_trn, compile=proceed,
custom_objects=self.custom_objects)
model_val = lmwnis(model_path, shape_val, compile=False,
custom_objects=self.custom_objects)

# Define, compile network.
else:
loss = 'binary_crossentropy'
model = self.net_builder(shape_trn)
model_val = self.net_builder(shape_val)
model.summary()

model.compile(optimizer=optimizer, loss=loss,
metrics=[dice_squared, ytpos, yppos, prec, reca])
model.summary()

if weights_path is not None:
model.load_weights(weights_path)
logger.info('Loaded weights from %s.' % weights_path)
if not proceed:
model.compile(optimizer=optimizer, loss=loss,
metrics=[F1, prec, reca, dice, dicesq, posyt, posyp])

# Pre-compute summaries once to avoid problems with accessing HDF5.
S_summ = [self.series_summary_func(ds) for ds in datasets]
M_summ = [self.mask_summary_func(ds) for ds in datasets]

# Define generators for training and validation data.
y_coords_trn = [(0, int(s.shape[0] * prop_trn)) for s in S_summ]
gen_trn = self.batch_gen_fit(
S_summ, M_summ, y_coords_trn, batch_size_trn, shape_trn, nb_max_augment=15)
yctrn = [(0, int(s.shape[0] * prop_trn)) for s in S_summ]
gen_trn = self.batch_gen(S_summ, M_summ, yctrn, batch_size_trn,
shape_trn, nb_max_augment=15)

# Validation setup.
y_coords_val = [(s.shape[0] - int(s.shape[0] * prop_val), s.shape[0])
for s in S_summ]

ycval = [(s.shape[0] - int(s.shape[0] * prop_val), s.shape[0]) for s in S_summ]
names = [ds.attrs['name'] for ds in datasets]

callbacks = [
ValidationMetricsCB(model_val, S_summ, M_summ,
names, y_coords_val, self.cpdir),
ValidationMetricsCB(model_val, S_summ, M_summ, names, ycval, self.cpdir),
CSVLogger('%s/metrics.csv' % self.cpdir),
MetricsPlotCallback('%s/metrics.png' % self.cpdir,
'%s/metrics.csv' % self.cpdir),
ModelCheckpoint('%s/weights_val_nf_f1_mean.hdf5' % self.cpdir, mode='max',
ModelCheckpoint('%s/model_val_nf_f1_mean.hdf5' % self.cpdir, mode='max',
monitor='val_nf_f1_mean', save_best_only=True, verbose=1),
EarlyStopping(monitor='val_nf_f1_mean', min_delta=1e-3,
patience=10, verbose=1, mode='max'),
@@ -360,7 +344,7 @@ def dice_squared_loss(yt, yp):

return trained.history

def batch_gen_fit(self, S_summ, M_summ, y_coords, batch_size, window_shape, nb_max_augment=0):
def batch_gen(self, S_summ, M_summ, y_coords, batch_size, window_shape, nb_max_augment=0):
"""Builds and yields batches of image windows and corresponding mask windows for training.
Includes random data augmentation.
@@ -460,94 +444,92 @@ def stretch(a, b):

yield s_batch, m_batch

def evaluate(self, datasets, weights_path=None, window_shape=(512, 512), save=False):
"""Evaluates predicted masks vs. true masks for the given sequences."""

logger = logging.getLogger(funcname())

model = self.net_builder(window_shape)
if weights_path is not None:
model.load_weights(weights_path)
logger.info('Loaded weights from %s.' % weights_path)
def predict(self, datasets, model_path, window_shape=(512, 512), print_scores=False, save=False, augmentation=False):
"""Make predictions on the given datasets. Currently uses batches of 1.
# Currently only supporting full-sized windows.
assert window_shape == (512, 512), 'TODO: implement variable window sizes.'

# Padding helper.
_, hw, ww = model.input_shape
pad = lambda x: np.pad(
x, ((0, hw - x.shape[0]), (0, ww - x.shape[1])), mode='reflect')

# Evaluate each sequence, mask pair.
mean_prec, mean_reca, mean_comb = 0., 0., 0.
for ds in datasets:
name = ds.attrs['name']
s = self.series_summary_func(ds)
m = self.mask_summary_func(ds)
hs, ws = s.shape

# Pad and make prediction.
s_batch = np.zeros((1, ) + window_shape)
s_batch[0] = pad(s)
mp = model.predict(s_batch)[0, :hs, :ws].round()

# Track scores.
prec, reca, incl, excl, comb = nf_mask_metrics(m, mp)
logger.info('%s: prec=%.3lf, reca=%.3lf, incl=%.3lf, excl=%.3lf, comb=%.3lf' % (
name, prec, reca, incl, excl, comb))
mean_prec += prec / len(datasets)
mean_reca += reca / len(datasets)
mean_comb += comb / len(datasets)

# Save mask and prediction.
if save:
imsave('%s/%s_mp.png' % (self.cpdir, name),
mask_outlines(s, [m, mp], ['blue', 'red']))

logger.info('Mean prec=%.3lf, reca=%.3lf, comb=%.3lf' %
(mean_prec, mean_reca, mean_comb))

return mean_comb
Arguments:
datasets: List of HDF5 datasets. Each of these will be passed to self.series_summary_func and
self.mask_summary_func to compute its series and mask summaries, so the HDF5 structure
should be compatible with those functions.
model_path: Path to the serialized Keras model HDF5 file. This file should include both the
architecture and the weights.
window_shape: Tuple window shape used for making predictions. Summary images smaller
than this are padded up to match this shape.
print_scores: Flag to print the Neurofinder evaluation metrics. Only works when the datasets include
ground-truth masks.
save: Flag to save the predictions as PNGs with outlines around the predicted neurons in red. If
the ground-truth masks are given, it will also show outlines around the ground-truth neurons.
augmentation: Flag to perform 8x test-time augmentation. Predictions are made for each of the
augmentations, the augmentation is inverted to its original orientation, and the average
of all the augmentations is used as the prediction. In practice, this improved a
Neurofinder submission from 0.5356 to 0.542.
Returns:
Mp: list of the predicted masks stored as Numpy arrays.
def predict(self, datasets, weights_path=None, window_shape=(512, 512), batch_size=10, save=False):
"""Predicts masks for the given sequences. Optionally saves the masks. Returns the masks as numpy arrays in order corresponding the given sequences."""
"""

logger = logging.getLogger(funcname())

model = self.net_builder(window_shape)
if weights_path is not None:
model.load_weights(weights_path)
logger.info('Loaded weights from %s.' % weights_path)
model = load_model_with_new_input_shape(model_path, window_shape, compile=False,
custom_objects=self.custom_objects)
logger.info('Loaded model from %s.' % model_path)

# Currently only supporting full-sized windows.
assert window_shape == (512, 512), 'TODO: implement variable window sizes.'

# Padding helper.
_, hw, ww = model.input_shape
pad = lambda x: np.pad(x, ((0, hw - x.shape[0]), (0, ww - x.shape[1])), 'reflect')
def pad(x):
_, hw, ww = model.input_shape
return np.pad(x, ((0, hw - x.shape[0]), (0, ww - x.shape[1])), mode='reflect')

# Store predictions.
# Store predicted masks and scores.
Mp = []
mean_prec, mean_reca, mean_comb = 0., 0., 0.

# Evaluate each sequence, mask pair.
mean_prec, mean_reca, mean_comb = 0., 0., 0.
for ds in datasets:
name = ds.attrs['name']
s = self.series_summary_func(ds)
hs, ws = s.shape

# Pad and make prediction.
# Pad and make prediction(s).
s_batch = np.zeros((1, ) + window_shape)
s_batch[0] = pad(s)
mp = model.predict(s_batch)[0, :hs, :ws].round()
assert mp.shape == s.shape

if augmentation:
mp = np.zeros_like(s)
for _, aug, inv in INVERTIBLE_2D_AUGMENTATIONS:
mpaug = model.predict(aug(s_batch))
mp += inv(mpaug)[0, :hs, :ws] / len(INVERTIBLE_2D_AUGMENTATIONS)
mp = mp.round()

else:
mp = model.predict(s_batch)[0, :hs, :ws].round()

Mp.append(mp)

# Save prediction.
if save:
# Track scores.
if print_scores:
m = self.mask_summary_func(ds)
prec, reca, incl, excl, comb = nf_mask_metrics(m, mp)
logger.info('%s: prec=%.3lf, reca=%.3lf, incl=%.3lf, excl=%.3lf, comb=%.3lf' % (
name, prec, reca, incl, excl, comb))
mean_prec += prec / len(datasets)
mean_reca += reca / len(datasets)
mean_comb += comb / len(datasets)

# Save mask and prediction.
if save and 'masks' in ds:
m = self.mask_summary_func(ds)
outlined = mask_outlines(s, [m, mp], ['blue', 'red'])
imsave('%s/%s_mp.png' % (self.cpdir, name), outlined)

elif save:
outlined = mask_outlines(s, [mp], ['red'])
imsave('%s/%s_mp.png' % (self.cpdir, name), outlined)

logger.info('%s prediction complete.' % name)
if print_scores:
logger.info('Mean prec=%.3lf, reca=%.3lf, comb=%.3lf' %
(mean_prec, mean_reca, mean_comb))

return Mp
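
The helper doing the heavy lifting above, load_model_with_new_input_shape, is imported from deepcalcium.utils.keras_helpers but not shown in this diff. A plausible sketch of such a helper, assuming the HDF5 file was written by keras.models.save_model (architecture JSON in the model_config attribute, weights in the model_weights group) and a fully-convolutional network whose weights are independent of the spatial input size:

import json
import h5py
from keras.models import model_from_json

def load_model_with_new_input_shape(model_path, input_shape, compile=False,
                                    custom_objects=None):
    """Rebuild the stored architecture with a new input shape and load
    the stored weights. Sketch only: compile=True would also restore
    the saved optimizer state, which is omitted here."""
    with h5py.File(model_path, 'r') as f:
        raw = f.attrs['model_config']
    config = json.loads(raw.decode('utf-8') if hasattr(raw, 'decode') else raw)
    # Patch the InputLayer's stored (batch, height, width) shape.
    config['config']['layers'][0]['config']['batch_input_shape'] = \
        [None] + list(input_shape)
    model = model_from_json(json.dumps(config), custom_objects=custom_objects)
    # Keras' load_weights accepts a full-model HDF5 and reads its
    # 'model_weights' group.
    model.load_weights(model_path)
    return model

This is what lets fit and predict instantiate the same saved model with (96, 96) training windows and (512, 512) validation windows.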
2 changes: 1 addition & 1 deletion deepcalcium/utils/data_utils.py
@@ -3,7 +3,7 @@
# Augmentations that can be applied to a batch of 2D images and inverted.
# Structure is the augmentation name, the augmentation, and the inverse
# of the augmentation. Intended for test-time augmentation for segmentation.
INVERTIBLE_2D_BATCH_AUGMENTATIONS = [
INVERTIBLE_2D_AUGMENTATIONS = [
('identity',
lambda x: x,
lambda x: x),
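
The diff truncates INVERTIBLE_2D_AUGMENTATIONS after the identity entry, and the remaining seven triples are not shown in this commit. A hedged reconstruction of what they plausibly look like — not the committed code — covering the eight symmetries of a square for batches shaped (n, height, width):

import numpy as np

INVERTIBLE_2D_AUGMENTATIONS_SKETCH = [
    ('identity', lambda x: x, lambda x: x),
    # Rotations about the spatial axes; rot90 and rot270 invert each other.
    ('rot90', lambda x: np.rot90(x, 1, axes=(1, 2)),
              lambda x: np.rot90(x, 3, axes=(1, 2))),
    ('rot180', lambda x: np.rot90(x, 2, axes=(1, 2)),
               lambda x: np.rot90(x, 2, axes=(1, 2))),
    ('rot270', lambda x: np.rot90(x, 3, axes=(1, 2)),
               lambda x: np.rot90(x, 1, axes=(1, 2))),
    # Flips and diagonal reflections are their own inverses.
    ('vflip', lambda x: x[:, ::-1, :], lambda x: x[:, ::-1, :]),
    ('hflip', lambda x: x[:, :, ::-1], lambda x: x[:, :, ::-1]),
    ('transpose', lambda x: x.transpose(0, 2, 1),
                  lambda x: x.transpose(0, 2, 1)),
    ('anti-transpose',
     lambda x: np.rot90(x, 2, axes=(1, 2)).transpose(0, 2, 1),
     lambda x: np.rot90(x, 2, axes=(1, 2)).transpose(0, 2, 1)),
]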