Skip to content

Commit

Permalink
Fix in filter_bboxes (#1949)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Sep 22, 2024
1 parent b358a88 commit 37414e0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,16 @@ repos:
entry: python tools/check_docstrings.py
language: system
types: [python]
- repo: local
hooks:
- id: check-albucore-version
name: Check albucore version
entry: python ./tools/check_albucore_version.py
language: system
files: setup.py
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.6
rev: v0.6.7
hooks:
# Run the linter.
- id: ruff
Expand Down
5 changes: 2 additions & 3 deletions albumentations/core/bbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def filter_bboxes(
epsilon = 1e-7

if len(bboxes) == 0:
return np.array([], dtype=np.float32)
return np.array([], dtype=np.float32).reshape(0, 4)

# Calculate areas of bounding boxes before clipping in pixels
denormalized_box_areas = calculate_bbox_areas_in_pixels(bboxes, image_shape)
Expand Down Expand Up @@ -474,8 +474,7 @@ def filter_bboxes(
# Apply the mask to get the filtered bboxes
filtered_bboxes = clipped_bboxes[mask]

# If no bboxes pass the filter, return an empty array with the same number of columns as input
return filtered_bboxes if len(filtered_bboxes) > 0 else np.array([], dtype=np.float32)
return np.array([], dtype=np.float32).reshape(0, 4) if len(filtered_bboxes) == 0 else filtered_bboxes


def union_of_bboxes(bboxes: np.ndarray, erosion_rate: float) -> np.ndarray | None:
Expand Down
12 changes: 9 additions & 3 deletions tests/test_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_check_bboxes_additional_columns():
np.array([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], [0.5, 0.5, 0.6, 0.6]]),
(100, 100),
200, 0, 0, 0,
np.array([])
np.array([]).reshape(0, 4)
),
(
np.array([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], [0.5, 0.5, 0.6, 0.6]]),
Expand Down Expand Up @@ -524,14 +524,20 @@ def test_check_bboxes_additional_columns():
np.array([[0.1, 0.1, 0.2, 0.2, 1], [0.3, 0.3, 0.4, 0.4, 2], [0.5, 0.5, 0.6, 0.7, 3]]),
(100, 100),
300, 0, 0, 0,
np.array([])
np.array([]).reshape(0, 4)
),
(
np.array([]),
(100, 100),
0, 0, 0, 0,
np.array([])
np.array([]).reshape(0, 4)
),
(
np.array([[0.1, 0.1, 0.2, 0.2]]),
(100, 100),
101, 0, 0, 0,
np.array([]).reshape(0, 4)
)
])
def test_filter_bboxes(bboxes, image_shape, min_area, min_visibility, min_width, min_height, expected):
result = filter_bboxes(bboxes, image_shape, min_area, min_visibility, min_width, min_height)
Expand Down

0 comments on commit 37414e0

Please sign in to comment.