Skip to content

Commit

Permalink
(Hopefully) fixed the inf loss and loss explosion issue by ignoring a…
Browse files Browse the repository at this point in the history
…ugmented boxes < 4 px in width and height. Fixes #222
  • Loading branch information
dbolya committed Dec 6, 2019
1 parent 02bde37 commit 821e830
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ This document will detail all changes I make.
I don't know how I'm going to be versioning things yet, so you get dates for now.

```
2019.12.06:
- Made training much more stable (no more infs and hopefully fewer loss explosions) by ignoring
augmented boxes with < 4px of height and width (this includes 0 area boxes which caused the inf).
See #222 for details.
2019.11.20:
- Fixed bug where saving videos wouldn't work when using cv2 not compiled with display support (#197).
Expand Down
14 changes: 12 additions & 2 deletions data/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from .config import cfg
from pycocotools import mask as maskUtils
import random

def get_label_map():
if cfg.dataset.label_map is None:
Expand Down Expand Up @@ -36,14 +37,16 @@ def __call__(self, target, width, height):
for obj in target:
if 'bbox' in obj:
bbox = obj['bbox']
label_idx = self.label_map[obj['category_id']] - 1
label_idx = obj['category_id']
if label_idx >= 0:
label_idx = self.label_map[label_idx] - 1
final_box = list(np.array([bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]])/scale)
final_box.append(label_idx)
res += [final_box] # [xmin, ymin, xmax, ymax, label_idx]
else:
print("No bbox found for object ", obj)

return res # [[xmin, ymin, xmax, ymax, label_idx], ... ]
return res


class COCODetection(data.Dataset):
Expand Down Expand Up @@ -121,6 +124,9 @@ def pull_item(self, index):
target = [x for x in target if not ('iscrowd' in x and x['iscrowd'])]
num_crowds = len(crowd)

for x in crowd:
x['category_id'] = -1

# This is so we ensure that all crowd annotations are at the end of the array
target += crowd

Expand Down Expand Up @@ -164,6 +170,10 @@ def pull_item(self, index):
masks = None
target = None

if target.shape[0] == 0:
print('Warning: Augmentation output an example with no ground truth. Resampling...')
return self.pull_image(random.randint(0, len(self.ids)-1))

return torch.from_numpy(img).permute(2, 0, 1), target, masks, height, width, num_crowds

def pull_image(self, index):
Expand Down
4 changes: 4 additions & 0 deletions data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@ def print(self):
# With uniform probability, rotate the image [0,90,180,270] degrees
'augment_random_rot90': False,

# Discard detections with width and height smaller than this (in absolute width and height)
'discard_box_width': 4 / 550,
'discard_box_height': 4 / 550,

# If using batchnorm anywhere in the backbone, freeze the batchnorm layer during training.
# Note: any additional batch norm layers after the backbone will not be frozen.
'freeze_bn': False,
Expand Down
2 changes: 1 addition & 1 deletion layers/output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
if visualize_lincomb:
display_lincomb(proto_data, masks)

masks = torch.matmul(proto_data, masks.t())
masks = proto_data @ masks.t()
masks = cfg.mask_proto_mask_activation(masks)

# Crop masks before upsampling because you know why
Expand Down
10 changes: 10 additions & 0 deletions utils/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,16 @@ def __call__(self, image, masks, boxes, labels=None):
boxes[:, [0, 2]] *= (width / img_w)
boxes[:, [1, 3]] *= (height / img_h)

# Discard boxes that are smaller than we'd like
w = boxes[:, 2] - boxes[:, 0]
h = boxes[:, 3] - boxes[:, 1]

keep = (w > cfg.discard_box_width) * (h > cfg.discard_box_height)
masks = masks[keep]
boxes = boxes[keep]
labels['labels'] = labels['labels'][keep]
labels['num_crowds'] = (labels['labels'] < 0).sum()

return image, masks, boxes, labels


Expand Down

0 comments on commit 821e830

Please sign in to comment.