From 3c35123e3d1187189efd3bb6acde4d7633bbffd7 Mon Sep 17 00:00:00 2001 From: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com> Date: Tue, 11 Jul 2023 10:29:22 -0600 Subject: [PATCH] Use TF-native NMS when in TF backend (#1931) * Use TF-native NMS when in TF backend * I promise I have used a computer before --- .../object_detection/non_max_suppression.py | 83 +++++++++---------- 1 file changed, 38 insertions(+), 45 deletions(-) diff --git a/keras_cv/layers/object_detection/non_max_suppression.py b/keras_cv/layers/object_detection/non_max_suppression.py index 3f50e47246..c15dfaeacb 100644 --- a/keras_cv/layers/object_detection/non_max_suppression.py +++ b/keras_cv/layers/object_detection/non_max_suppression.py @@ -88,60 +88,53 @@ def call( confidence_prediction = ops.max(class_prediction, axis=-1) - # TODO(tirthasheshpatel): Use backend-specific op where available - if multi_backend(): + if not multi_backend() or keras.backend.backend() == "tensorflow": + idx, valid_det = tf.image.non_max_suppression_padded( + box_prediction, + confidence_prediction, + max_output_size=self.max_detections, + iou_threshold=self.iou_threshold, + score_threshold=self.confidence_threshold, + pad_to_max_output_size=True, + sorted_input=False, + ) + elif keras.backend.backend() == "torch": # Since TorchVision has a nice efficient NMS op, we might as well # use it! - if keras.backend.backend() == "torch": - import torchvision - - batch_size = box_prediction.shape[0] - idx = ops.zeros((batch_size, self.max_detections)) - valid_det = ops.zeros((batch_size), "int32") - - for batch_idx in range(batch_size): - conf_mask = ( - confidence_prediction[batch_idx] - > self.confidence_threshold - ) - conf_mask_idx = ops.squeeze(ops.nonzero(conf_mask), axis=0) - conf_i = confidence_prediction[batch_idx][conf_mask] - box_i = box_prediction[batch_idx][conf_mask] - - idx_i = torchvision.ops.nms( - box_i, conf_i, iou_threshold=self.iou_threshold - ) - - idx_i = conf_mask_idx[idx_i] - - num_boxes = idx_i.shape[0] - if num_boxes >= self.max_detections: - idx_i = idx_i[: self.max_detections] - num_boxes = self.max_detections - - valid_det[batch_idx] = ops.cast(ops.size(idx_i), "int32") - idx[batch_idx, :num_boxes] = idx_i - - else: - idx, valid_det = non_max_suppression( - box_prediction, - confidence_prediction, - max_output_size=self.max_detections, - iou_threshold=self.iou_threshold, - score_threshold=self.confidence_threshold, + import torchvision + + batch_size = box_prediction.shape[0] + idx = ops.zeros((batch_size, self.max_detections)) + valid_det = ops.zeros((batch_size), "int32") + + for batch_idx in range(batch_size): + conf_mask = ( + confidence_prediction[batch_idx] > self.confidence_threshold + ) + conf_mask_idx = ops.squeeze(ops.nonzero(conf_mask), axis=0) + conf_i = confidence_prediction[batch_idx][conf_mask] + box_i = box_prediction[batch_idx][conf_mask] + + idx_i = torchvision.ops.nms( + box_i, conf_i, iou_threshold=self.iou_threshold ) + + idx_i = conf_mask_idx[idx_i] + + num_boxes = idx_i.shape[0] + if num_boxes >= self.max_detections: + idx_i = idx_i[: self.max_detections] + num_boxes = self.max_detections + + valid_det[batch_idx] = ops.cast(ops.size(idx_i), "int32") + idx[batch_idx, :num_boxes] = idx_i else: - # For non-multibackend, our NMS fails during graph tracing due to - # the lack of a defined batch size, so we just fall back to the - # original implementation that this is ported from. - idx, valid_det = tf.image.non_max_suppression_padded( + idx, valid_det = non_max_suppression( box_prediction, confidence_prediction, max_output_size=self.max_detections, iou_threshold=self.iou_threshold, score_threshold=self.confidence_threshold, - pad_to_max_output_size=True, - sorted_input=False, ) box_prediction = ops.take_along_axis(