From 37414e0f002684de1b15465edc1d9ca3ac53dbad Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Sun, 22 Sep 2024 12:52:41 -0700 Subject: [PATCH] Fix in filter_bboxes (#1949) --- .pre-commit-config.yaml | 9 ++++++++- albumentations/core/bbox_utils.py | 5 ++--- tests/test_bbox.py | 12 +++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a250af51..c77dbb920 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,9 +42,16 @@ repos: entry: python tools/check_docstrings.py language: system types: [python] + - repo: local + hooks: + - id: check-albucore-version + name: Check albucore version + entry: python ./tools/check_albucore_version.py + language: system + files: setup.py - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.6 + rev: v0.6.7 hooks: # Run the linter. - id: ruff diff --git a/albumentations/core/bbox_utils.py b/albumentations/core/bbox_utils.py index 4088f83ec..4777e7c0a 100644 --- a/albumentations/core/bbox_utils.py +++ b/albumentations/core/bbox_utils.py @@ -445,7 +445,7 @@ def filter_bboxes( epsilon = 1e-7 if len(bboxes) == 0: - return np.array([], dtype=np.float32) + return np.array([], dtype=np.float32).reshape(0, 4) # Calculate areas of bounding boxes before clipping in pixels denormalized_box_areas = calculate_bbox_areas_in_pixels(bboxes, image_shape) @@ -474,8 +474,7 @@ def filter_bboxes( # Apply the mask to get the filtered bboxes filtered_bboxes = clipped_bboxes[mask] - # If no bboxes pass the filter, return an empty array with the same number of columns as input - return filtered_bboxes if len(filtered_bboxes) > 0 else np.array([], dtype=np.float32) + return np.array([], dtype=np.float32).reshape(0, 4) if len(filtered_bboxes) == 0 else filtered_bboxes def union_of_bboxes(bboxes: np.ndarray, erosion_rate: float) -> np.ndarray | None: diff --git a/tests/test_bbox.py b/tests/test_bbox.py index ce0ced405..53aaa835b 100644 --- a/tests/test_bbox.py +++ b/tests/test_bbox.py @@ -494,7 +494,7 @@ def test_check_bboxes_additional_columns(): np.array([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], [0.5, 0.5, 0.6, 0.6]]), (100, 100), 200, 0, 0, 0, - np.array([]) + np.array([]).reshape(0, 4) ), ( np.array([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], [0.5, 0.5, 0.6, 0.6]]), @@ -524,14 +524,20 @@ def test_check_bboxes_additional_columns(): np.array([[0.1, 0.1, 0.2, 0.2, 1], [0.3, 0.3, 0.4, 0.4, 2], [0.5, 0.5, 0.6, 0.7, 3]]), (100, 100), 300, 0, 0, 0, - np.array([]) + np.array([]).reshape(0, 4) ), ( np.array([]), (100, 100), 0, 0, 0, 0, - np.array([]) + np.array([]).reshape(0, 4) ), + ( + np.array([[0.1, 0.1, 0.2, 0.2]]), + (100, 100), + 101, 0, 0, 0, + np.array([]).reshape(0, 4) + ) ]) def test_filter_bboxes(bboxes, image_shape, min_area, min_visibility, min_width, min_height, expected): result = filter_bboxes(bboxes, image_shape, min_area, min_visibility, min_width, min_height)