Skip to content

Commit

Permalink
feat: add NMS to Florence2Sam2 (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
CamiloInx authored Sep 19, 2024
1 parent 9701f5e commit 2eea1c2
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions vision_agent_tools/models/florence2_sam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,33 @@ def __init__(self, model_config: Florence2SAM2Config | None = None):
)
self.image_predictor = SAM2ImagePredictor(self.video_predictor)

def _dummy_agnostic_non_max_suppression(self, predictions, nms_threshold):
"""
Apply agnostic Non-Maximum Suppression (NMS) to filter overlapping predictions.
Parameters:
predictions (dict[int, ImageBboxAndMaskLabel]): Dictionary of predictions.
nms_threshold (float): The IoU threshold value used for NMS.
Returns:
dict[int, ImageBboxAndMaskLabel]: Filtered predictions after applying NMS.
"""
filtered_predictions = {}
prediction_items = list(predictions.items())

while prediction_items:
best_prediction = prediction_items.pop(0)
filtered_predictions[best_prediction[0]] = best_prediction[1]

prediction_items = [
pred
for pred in prediction_items
if self._calculate_iou(best_prediction[1].mask, pred[1].mask)
< nms_threshold
]

return filtered_predictions

def _calculate_iou(
self, mask1: SegmentationBitMask, mask2: SegmentationBitMask
) -> float:
Expand Down Expand Up @@ -105,6 +132,7 @@ def _update_reference_predictions(
new_predictions: dict[int, ImageBboxAndMaskLabel],
objects_count: int,
iou_threshold: float = 0.8,
nms_threshold: float = 0.3,
) -> tuple[dict[int, ImageBboxAndMaskLabel], int]:
"""
Updates the object prediction ids of the 'new_predictions' input to match
Expand Down Expand Up @@ -139,7 +167,9 @@ def _update_reference_predictions(
new_obj_id = objects_count
new_prediction_objects[new_obj_id] = new_predictions[new_annotation_id]

updated_predictions = {**last_predictions, **new_prediction_objects}
updated_predictions = self._dummy_agnostic_non_max_suppression(
{**last_predictions, **new_prediction_objects}, nms_threshold
)
return (updated_predictions, objects_count)

@torch.inference_mode()
Expand Down Expand Up @@ -197,6 +227,7 @@ def handle_video(
video: VideoNumpy,
chunk_length: int | None = 20,
iou_threshold: float = 0.8,
nms_threshold: float = 0.3,
) -> dict[int, dict[int, ImageBboxAndMaskLabel]]:
video_shape = video.shape
num_frames = video_shape[0]
Expand All @@ -221,7 +252,7 @@ def handle_video(
# and update the object prediction id, to match the previous id.
# Also add the new objects in case they didn't exist before.
updated_objs, objects_count = self._update_reference_predictions(
last_chunk_frame_pred, objs, objects_count, iou_threshold
last_chunk_frame_pred, objs, objects_count, iou_threshold, nms_threshold
)
self.video_predictor.reset_state(inference_state)

Expand Down Expand Up @@ -279,6 +310,7 @@ def __call__(
video: VideoNumpy | None = None,
chunk_length: int | None = 20,
iou_threshold: float = 0.8,
nms_threshold: float = 0.3,
) -> dict[int, dict[int, ImageBboxAndMaskLabel]]:
"""
Florence2Sam2 model find objects in an image and track objects in a video.
Expand Down Expand Up @@ -313,5 +345,5 @@ def __call__(
return self.handle_image(prompt, image)
elif video is not None:
assert video.ndim == 4, "Video should have 4 dimensions"
return self.handle_video(prompt, video, chunk_length, iou_threshold)
return self.handle_video(prompt, video, chunk_length, iou_threshold, nms_threshold)
# No need to raise an error here, the validatie_call decorator will take care of it

0 comments on commit 2eea1c2

Please sign in to comment.