Skip to content

Commit

Permalink
object detection label format (#941)
Browse files Browse the repository at this point in the history
* added object detection metrics

* object detection labels are now list of dicts
  • Loading branch information
lcadalzo authored Nov 18, 2020
1 parent 2cc3625 commit e09b6c4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 36 deletions.
12 changes: 12 additions & 0 deletions armory/data/adversarial_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,23 @@ def gtsrb_poison(
)


def apricot_label_preprocessing(x, y):
    """
    Normalize APRICOT labels to a list of dicts.

    With batch_size > 1 the labels already arrive as a list of dicts and are
    returned unchanged; with batch_size == 1 the single label dict is wrapped
    in a one-element list. The image batch ``x`` is accepted only to match the
    label-preprocessing callable signature and is not used.
    """
    return [y] if isinstance(y, dict) else y


def apricot_dev_adversarial(
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
preprocessing_fn: Callable = apricot_canonical_preprocessing,
label_preprocessing_fn: Callable = apricot_label_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
shuffle_files: bool = False,
Expand Down Expand Up @@ -328,6 +339,7 @@ def replace_magic_val(data, raw_val, transformed_val, sub_key):
epochs=epochs,
dataset_dir=dataset_dir,
preprocessing_fn=preprocessing_fn,
label_preprocessing_fn=label_preprocessing_fn,
as_supervised=False,
supervised_xy_keys=("image", "objects"),
shuffle_files=shuffle_files,
Expand Down
40 changes: 18 additions & 22 deletions armory/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,8 @@ def get_batch(self) -> (np.ndarray, np.ndarray):

if self.variable_y:
if isinstance(y_list[0], dict):
# Translate a list of dicts into a dict of arrays
y = {}
for k in y_list[0].keys():
y[k] = self.np_1D_object_array([y_i[k] for y_i in y_list])
# Store y as a list of dicts
y = y_list
elif isinstance(y_list[0], tuple):
# Translate a list of tuples into a tuple of arrays
y = tuple(self.np_1D_object_array(i) for i in zip(*y_list))
Expand Down Expand Up @@ -967,29 +965,27 @@ def ucf101(
)


def xview_label_preprocessing(x, y):
    """
    Convert xView labels from TF box format to PyTorch box format.

    TF format: [y1/height, x1/width, y2/height, x2/width] (normalized)
    PyTorch format: [x1, y1, x2, y2] (unnormalized pixel coordinates)

    Additionally, if batch_size is 1 and ``y`` is a single label dict, it is
    wrapped into a list of length 1 so the output is always a list of dicts,
    one per image in the batch.

    :param x: batch of images; x[i].shape[:2] gives (height, width) for image i
    :param y: label dict (batch_size == 1) or list of label dicts, each with a
        "boxes" entry holding normalized TF-format boxes
    :return: list of label dicts with "boxes" converted to pixel-space
        [x1, y1, x2, y2]
    """
    y_preprocessed = []
    # This will be true only when batch_size is 1
    if isinstance(y, dict):
        y = [y]
    for i, label_dict in enumerate(y):
        # Ensure an (n_boxes, 4) shape even when boxes arrive flattened
        orig_boxes = label_dict["boxes"].reshape((-1, 4))
        # Reorder [y1, x1, y2, x2] -> [x1, y1, x2, y2]
        converted_boxes = orig_boxes[:, [1, 0, 3, 2]]
        height, width = x[i].shape[:2]
        # Un-normalize: scale x-coords by width and y-coords by height
        converted_boxes *= [width, height, width, height]
        label_dict["boxes"] = converted_boxes
        y_preprocessed.append(label_dict)
    return y_preprocessed


def xview(
Expand All @@ -999,7 +995,7 @@ def xview(
dataset_dir: str = None,
preprocessing_fn: Callable = xview_canonical_preprocessing,
fit_preprocessing_fn: Callable = None,
label_preprocessing_fn: Callable = tf_to_pytorch_box_conversion,
label_preprocessing_fn: Callable = xview_label_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
shuffle_files: bool = True,
Expand Down
28 changes: 14 additions & 14 deletions armory/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,23 +497,23 @@ def object_detection_AP_per_class(list_of_ys, list_of_y_preds):
for batch_idx, (y, y_pred) in enumerate(zip(list_of_ys, list_of_y_preds)):
for img_idx in range(len(y_pred)):
global_img_idx = (batch_size * batch_idx) + img_idx
for gt_box_idx in range(y["labels"][img_idx].size):
label = y["labels"][img_idx][gt_box_idx]
box = y["boxes"][img_idx][gt_box_idx]

img_labels = y[img_idx]["labels"].flatten()
img_boxes = y[img_idx]["boxes"].reshape((-1, 4))
for gt_box_idx in range(img_labels.size):
label = img_labels[gt_box_idx]
box = img_boxes[gt_box_idx]
gt_box_dict = {"img_idx": global_img_idx, "label": label, "box": box}
gt_boxes_list.append(gt_box_dict)

for pred_box_idx in range(y_pred[img_idx]["labels"].size):
label = y_pred[img_idx]["labels"][pred_box_idx]
box = y_pred[img_idx]["boxes"][pred_box_idx]
score = y_pred[img_idx]["scores"][pred_box_idx]

pred_label = y_pred[img_idx]["labels"][pred_box_idx]
pred_box = y_pred[img_idx]["boxes"][pred_box_idx]
pred_score = y_pred[img_idx]["scores"][pred_box_idx]
pred_box_dict = {
"img_idx": global_img_idx,
"label": label,
"box": box,
"score": score,
"label": pred_label,
"box": pred_box,
"score": pred_score,
}
pred_boxes_list.append(pred_box_dict)

Expand Down Expand Up @@ -677,10 +677,10 @@ class (at a location overlapping the patch). A false positive is the case where
for img_idx in range(len(y_pred)):
global_img_idx = (batch_size * batch_idx) + img_idx
idx_of_patch = np.where(
y["labels"][img_idx] == ADV_PATCH_MAGIC_NUMBER_LABEL_ID
y[img_idx]["labels"].flatten() == ADV_PATCH_MAGIC_NUMBER_LABEL_ID
)[0]
patch_box = y["boxes"][img_idx][idx_of_patch].flatten()
patch_id = int(y["patch_id"][img_idx][idx_of_patch])
patch_box = y[img_idx]["boxes"].reshape((-1, 4))[idx_of_patch].flatten()
patch_id = int(y[img_idx]["patch_id"].flatten()[idx_of_patch])
patch_target_label = APRICOT_PATCHES[patch_id]["adv_target"]
patch_box_dict = {
"img_idx": global_img_idx,
Expand Down

0 comments on commit e09b6c4

Please sign in to comment.