Skip to content

Commit

Permalink
Use TF-native NMS when in TF backend (keras-team#1931)
Browse files Browse the repository at this point in the history
* Use TF-native NMS when in TF backend

* I promise I have used a computer before
  • Loading branch information
ianstenbit authored Jul 11, 2023
1 parent bd747c4 commit 3c35123
Showing 1 changed file with 38 additions and 45 deletions.
83 changes: 38 additions & 45 deletions keras_cv/layers/object_detection/non_max_suppression.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,60 +88,53 @@ def call(

confidence_prediction = ops.max(class_prediction, axis=-1)

# TODO(tirthasheshpatel): Use backend-specific op where available
if multi_backend():
if not multi_backend() or keras.backend.backend() == "tensorflow":
idx, valid_det = tf.image.non_max_suppression_padded(
box_prediction,
confidence_prediction,
max_output_size=self.max_detections,
iou_threshold=self.iou_threshold,
score_threshold=self.confidence_threshold,
pad_to_max_output_size=True,
sorted_input=False,
)
elif keras.backend.backend() == "torch":
# Since TorchVision has a nice efficient NMS op, we might as well
# use it!
if keras.backend.backend() == "torch":
import torchvision

batch_size = box_prediction.shape[0]
idx = ops.zeros((batch_size, self.max_detections))
valid_det = ops.zeros((batch_size), "int32")

for batch_idx in range(batch_size):
conf_mask = (
confidence_prediction[batch_idx]
> self.confidence_threshold
)
conf_mask_idx = ops.squeeze(ops.nonzero(conf_mask), axis=0)
conf_i = confidence_prediction[batch_idx][conf_mask]
box_i = box_prediction[batch_idx][conf_mask]

idx_i = torchvision.ops.nms(
box_i, conf_i, iou_threshold=self.iou_threshold
)

idx_i = conf_mask_idx[idx_i]

num_boxes = idx_i.shape[0]
if num_boxes >= self.max_detections:
idx_i = idx_i[: self.max_detections]
num_boxes = self.max_detections

valid_det[batch_idx] = ops.cast(ops.size(idx_i), "int32")
idx[batch_idx, :num_boxes] = idx_i

else:
idx, valid_det = non_max_suppression(
box_prediction,
confidence_prediction,
max_output_size=self.max_detections,
iou_threshold=self.iou_threshold,
score_threshold=self.confidence_threshold,
import torchvision

batch_size = box_prediction.shape[0]
idx = ops.zeros((batch_size, self.max_detections))
valid_det = ops.zeros((batch_size), "int32")

for batch_idx in range(batch_size):
conf_mask = (
confidence_prediction[batch_idx] > self.confidence_threshold
)
conf_mask_idx = ops.squeeze(ops.nonzero(conf_mask), axis=0)
conf_i = confidence_prediction[batch_idx][conf_mask]
box_i = box_prediction[batch_idx][conf_mask]

idx_i = torchvision.ops.nms(
box_i, conf_i, iou_threshold=self.iou_threshold
)

idx_i = conf_mask_idx[idx_i]

num_boxes = idx_i.shape[0]
if num_boxes >= self.max_detections:
idx_i = idx_i[: self.max_detections]
num_boxes = self.max_detections

valid_det[batch_idx] = ops.cast(ops.size(idx_i), "int32")
idx[batch_idx, :num_boxes] = idx_i
else:
# For non-multibackend, our NMS fails during graph tracing due to
# the lack of a defined batch size, so we just fall back to the
# original implementation that this is ported from.
idx, valid_det = tf.image.non_max_suppression_padded(
idx, valid_det = non_max_suppression(
box_prediction,
confidence_prediction,
max_output_size=self.max_detections,
iou_threshold=self.iou_threshold,
score_threshold=self.confidence_threshold,
pad_to_max_output_size=True,
sorted_input=False,
)

box_prediction = ops.take_along_axis(
Expand Down

0 comments on commit 3c35123

Please sign in to comment.