'''
The Keras-compatible loss function for the SSD model. Currently supports TensorFlow only.

Copyright (C) 2017 Pierluigi Ferrari

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
'''

import tensorflow as tf

class SSDLoss:
    '''
    The SSD loss, see https://arxiv.org/abs/1512.02325.
    '''

    def __init__(self,
                 neg_pos_ratio=3,
                 n_neg_min=0,
                 alpha=1.0):
        '''
        Arguments:
            neg_pos_ratio (int, optional): The maximum ratio of negative (i.e. background)
                to positive ground truth boxes to include in the loss computation.
                There are no actual background ground truth boxes of course, but `y_true`
                contains anchor boxes labeled with the background class. Since
                the number of background boxes in `y_true` will usually exceed
                the number of positive boxes by far, it is necessary to balance
                their influence on the loss. Defaults to 3 following the paper.
            n_neg_min (int, optional): The minimum number of negative ground truth boxes to
                enter the loss computation *per batch*. This argument can be used to make
                sure that the model learns from a minimum number of negatives in batches
                in which there are very few, or even no, positive ground truth boxes.
                It defaults to 0 and, if used, should be set to a value that stands in
                reasonable proportion to the batch size used for training.
            alpha (float, optional): A factor to weight the localization loss in the
                computation of the total loss. Defaults to 1.0 following the paper.
        '''
        self.neg_pos_ratio = neg_pos_ratio
        self.n_neg_min = n_neg_min
        self.alpha = alpha

    def smooth_L1_loss(self, y_true, y_pred):
        '''
        Compute smooth L1 loss, see references.

        Arguments:
            y_true (nD tensor): A TensorFlow tensor of any shape containing the ground truth data.
                In this context, the expected tensor has shape `(batch_size, #boxes, 4)` and
                contains the ground truth bounding box coordinates, where the last dimension
                contains `(xmin, xmax, ymin, ymax)`.
            y_pred (nD tensor): A TensorFlow tensor of identical structure to `y_true` containing
                the predicted data, in this context the predicted bounding box coordinates.

        Returns:
            The smooth L1 loss, an (n-1)D TensorFlow tensor. In this context, a 2D tensor
            of shape `(batch, n_boxes_total)`.

        References:
            https://arxiv.org/abs/1504.08083
        '''
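        # Piecewise definition from the Fast R-CNN paper referenced above:
        # smooth_L1(x) = 0.5 * x**2 if |x| < 1, else |x| - 0.5. It is applied
        # elementwise to the coordinate differences and then summed over the
        # last axis, i.e. over the four box coordinates.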
        absolute_loss = tf.abs(y_true - y_pred)
        square_loss = 0.5 * (y_true - y_pred)**2
        l1_loss = tf.where(tf.less(absolute_loss, 1.0), square_loss, absolute_loss - 0.5)
        return tf.reduce_sum(l1_loss, axis=-1)

    def log_loss(self, y_true, y_pred):
        '''
        Compute the softmax log loss.

        Arguments:
            y_true (nD tensor): A TensorFlow tensor of any shape containing the ground truth data.
                In this context, the expected tensor has shape `(batch_size, #boxes, #classes)`
                and contains the ground truth bounding box categories.
            y_pred (nD tensor): A TensorFlow tensor of identical structure to `y_true` containing
                the predicted data, in this context the predicted bounding box categories.

        Returns:
            The softmax log loss, an (n-1)D TensorFlow tensor. In this context, a 2D tensor
            of shape `(batch, n_boxes_total)`.
        '''
        # Make sure that `y_pred` doesn't contain any zeros (which would break the log function)
        y_pred = tf.maximum(y_pred, 1e-15)
        # Compute the log loss
        log_loss = -tf.reduce_sum(y_true * tf.log(y_pred), axis=-1)
        return log_loss

    def compute_loss(self, y_true, y_pred):
        '''
        Compute the loss of the SSD model prediction against the ground truth.

        Arguments:
            y_true (array): A Numpy array of shape `(batch_size, #boxes, #classes + 12)`,
                where `#boxes` is the total number of boxes that the model predicts
                per image. Be careful to make sure that the index of each given
                box in `y_true` is the same as the index for the corresponding
                box in `y_pred`. The last axis must have length `#classes + 12` and contain
                `[classes one-hot encoded, 4 ground truth box coordinate offsets, 8 arbitrary entries]`
                in this order, including the background class. The last eight entries of the
                last axis are not used by this function and therefore their contents are
                irrelevant; they only exist so that `y_true` has the same shape as `y_pred`,
                where the last four entries of the last axis contain the anchor box
                coordinates, which are needed during inference. Important: Boxes that
                you want the cost function to ignore need to have a one-hot
                class vector of all zeros.
            y_pred (Keras tensor): The model prediction. The shape is identical
                to that of `y_true`, i.e. `(batch_size, #boxes, #classes + 12)`.
                The last axis must contain entries in the format
                `[classes one-hot encoded, 4 predicted box coordinate offsets, 8 arbitrary entries]`.

        Returns:
            A tensor of shape `(batch_size,)` containing the total multitask loss
            (classification plus localization) for each batch item. Keras averages
            these values over the batch to obtain the scalar training loss.
        '''
        self.neg_pos_ratio = tf.constant(self.neg_pos_ratio)
        self.n_neg_min = tf.constant(self.n_neg_min)
        self.alpha = tf.constant(self.alpha)

        batch_size = tf.shape(y_pred)[0] # Output dtype: tf.int32
        n_boxes = tf.shape(y_pred)[1] # Output dtype: tf.int32, note that `n_boxes` in this context denotes the total number of boxes per image, not the number of boxes per cell

        # 1: Compute the losses for class and box predictions for every box
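        # Last-axis layout (as described in the docstring above): indices [0:-12] hold the
        # one-hot class vector (index 0 being the background class), [-12:-8] hold the four
        # box coordinate offsets, and the final eight entries are ignored by this function.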
        classification_loss = tf.to_float(self.log_loss(y_true[:,:,:-12], y_pred[:,:,:-12])) # Output shape: (batch_size, n_boxes)
        localization_loss = tf.to_float(self.smooth_L1_loss(y_true[:,:,-12:-8], y_pred[:,:,-12:-8])) # Output shape: (batch_size, n_boxes)

        # 2: Compute the classification losses for the positive and negative targets

        # Create masks for the positive and negative ground truth classes
        negatives = y_true[:,:,0] # Tensor of shape (batch_size, n_boxes)
        positives = tf.to_float(tf.reduce_max(y_true[:,:,1:-12], axis=-1)) # Tensor of shape (batch_size, n_boxes)

        # Count the number of positive boxes (classes 1 to n) in y_true across the whole batch
        n_positive = tf.reduce_sum(positives)

        # Now mask all negative boxes and sum up the losses for the positive boxes PER batch item
        # (Keras loss functions must output one scalar loss value PER batch item, rather than just
        # one scalar for the entire batch, that's why we're not summing across all axes)
        pos_class_loss = tf.reduce_sum(classification_loss * positives, axis=-1) # Tensor of shape (batch_size,)

        # Compute the classification loss for the negative default boxes (if there are any)

        # First, compute the classification loss for all negative boxes
        neg_class_loss_all = classification_loss * negatives # Tensor of shape (batch_size, n_boxes)
        n_neg_losses = tf.count_nonzero(neg_class_loss_all, dtype=tf.int32) # The number of non-zero loss entries in `neg_class_loss_all`
        # What's the point of `n_neg_losses`? For the next step, which will be to compute which negative boxes enter the classification
        # loss, we don't just want to know how many negative ground truth boxes there are, but for how many of those there actually is
        # a positive (i.e. non-zero) loss. This is necessary because `tf.nn.top_k()` in the function below will pick the top k boxes with
        # the highest losses no matter what, even if it receives a vector where all losses are zero. In the unlikely event that all negative
        # classification losses ARE actually zero though, this behavior might lead to `tf.nn.top_k()` returning the indices of positive
        # boxes, leading to an incorrect negative classification loss computation, and hence an incorrect overall loss computation.
        # We therefore need to make sure that `n_negative_keep`, which assumes the role of the `k` argument in `tf.nn.top_k()`,
        # is at most the number of negative boxes for which there is a positive classification loss.

        # Compute the number of negative examples we want to account for in the loss
        # We'll keep at most `self.neg_pos_ratio` times the number of positives in `y_true`, but at least `self.n_neg_min` (unless `n_neg_losses` is smaller)
        n_negative_keep = tf.minimum(tf.maximum(self.neg_pos_ratio * tf.to_int32(n_positive), self.n_neg_min), n_neg_losses)
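        # Worked example (illustrative numbers only): with `neg_pos_ratio=3`, `n_neg_min=0`,
        # 10 positive boxes in the batch and 500 negative boxes with non-zero loss,
        # `n_negative_keep = min(max(3 * 10, 0), 500) = 30`.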

        # In the unlikely case when either (1) there are no negative ground truth boxes at all
        # or (2) the classification loss for all negative boxes is zero, return zero as the `neg_class_loss`
        def f1():
            return tf.zeros([batch_size])
        # Otherwise compute the negative loss
        def f2():
            # Now we'll identify the top-k (where k == `n_negative_keep`) boxes with the highest confidence loss that
            # belong to the background class in the ground truth data. Note that this doesn't necessarily mean that the model
            # predicted the wrong class for those boxes, it just means that the loss for those boxes is the highest.

            # To do this, we reshape `neg_class_loss_all` to 1D...
            neg_class_loss_all_1D = tf.reshape(neg_class_loss_all, [-1]) # Tensor of shape (batch_size * n_boxes,)
            # ...and then we get the indices for the `n_negative_keep` boxes with the highest loss out of those...
            values, indices = tf.nn.top_k(neg_class_loss_all_1D, n_negative_keep, False) # We don't need sorting
            # ...and with these indices we'll create a mask...
            negatives_keep = tf.scatter_nd(tf.expand_dims(indices, axis=1), updates=tf.ones_like(indices, dtype=tf.int32), shape=tf.shape(neg_class_loss_all_1D)) # Tensor of shape (batch_size * n_boxes,)
            negatives_keep = tf.to_float(tf.reshape(negatives_keep, [batch_size, n_boxes])) # Tensor of shape (batch_size, n_boxes)
            # ...and use it to keep only those boxes and mask all other classification losses
            neg_class_loss = tf.reduce_sum(classification_loss * negatives_keep, axis=-1) # Tensor of shape (batch_size,)
            return neg_class_loss

        neg_class_loss = tf.cond(tf.equal(n_neg_losses, tf.constant(0)), f1, f2)

        class_loss = pos_class_loss + neg_class_loss # Tensor of shape (batch_size,)

        # 3: Compute the localization loss for the positive targets
        # We don't penalize localization loss for negative predicted boxes (obviously: there are no ground truth boxes they would correspond to)
        loc_loss = tf.reduce_sum(localization_loss * positives, axis=-1) # Tensor of shape (batch_size,)

        # 4: Compute the total loss
        total_loss = (class_loss + self.alpha * loc_loss) / tf.maximum(1.0, n_positive) # In case `n_positive == 0`
        # Keras has the annoying habit of dividing the loss by the batch size, which sucks in our case
        # because the relevant criterion to average our loss over is the number of positive boxes in the batch
        # (by which we're dividing in the line above), not the batch size. So in order to revert Keras' averaging
        # over the batch size, we'll have to multiply by it.
        total_loss *= tf.to_float(batch_size)

        return total_loss
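
# -----------------------------------------------------------------------------
# The block below is an illustrative addition, not part of the original loss
# module: a minimal, self-contained smoke test showing how `SSDLoss` is meant
# to be called. It assumes TensorFlow 1.x (matching the `tf.to_float`/`tf.log`
# calls above); the dummy shapes and class counts are made up for the example.
# In actual training you would instead pass the bound method to Keras, e.g.
# `model.compile(optimizer=..., loss=SSDLoss().compute_loss)`.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    batch_size, n_boxes, n_classes = 2, 8, 3  # 3 classes including background

    # Build a dummy ground truth tensor: every box starts as background,
    # then the first three boxes of each image are marked as class-1 positives.
    y_true_np = np.zeros((batch_size, n_boxes, n_classes + 12), dtype=np.float32)
    y_true_np[:, :, 0] = 1.0
    y_true_np[:, :3, 0] = 0.0
    y_true_np[:, :3, 1] = 1.0

    # Dummy predictions: strictly positive "confidences" and random box offsets.
    y_pred_np = np.random.uniform(0.05, 1.0, size=y_true_np.shape).astype(np.float32)

    loss_fn = SSDLoss(neg_pos_ratio=3, n_neg_min=0, alpha=1.0)
    total_loss = loss_fn.compute_loss(tf.constant(y_true_np), tf.constant(y_pred_np))

    with tf.Session() as sess:
        # Prints one loss value per batch item, as expected by Keras.
        print(sess.run(total_loss))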