Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

yolo8 improvements #55

Merged
merged 14 commits into from
Aug 16, 2024
Merged
86 changes: 44 additions & 42 deletions datumaro/plugins/yolo_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,19 @@ def _make_annotation_line(self, width: int, height: int, anno: Annotation) -> Op

values = _make_yolo_bbox((width, height), anno.points)
string_values = " ".join("%.6f" % p for p in values)
return "%s %s\n" % (anno.label, string_values)
return "%s %s\n" % (self._map_labels_for_save[anno.label], string_values)

@cached_property
def _labels_to_save(self) -> List[int]:
return [
label_id
for label_id, label in enumerate(self._extractor.categories()[AnnotationType.label])
if label.parent == ""
]

@cached_property
def _map_labels_for_save(self) -> Dict[int, int]:
return {label_id: index for index, label_id in enumerate(self._labels_to_save)}

@staticmethod
def _make_image_subset_folder(save_dir: str, subset: str) -> str:
Expand Down Expand Up @@ -261,7 +273,7 @@ def patch(cls, dataset: IExtractor, patch: DatasetPatch, save_dir: str, **kwargs
os.remove(ann_path)


class YOLOv8Converter(YoloConverter):
class YOLOv8DetectionConverter(YoloConverter):
RESERVED_CONFIG_KEYS = YOLOv8Path.RESERVED_CONFIG_KEYS

def __init__(
Expand All @@ -276,6 +288,10 @@ def __init__(
super().__init__(extractor, save_dir, add_path_prefix=add_path_prefix, **kwargs)
self._config_filename = config_file or YOLOv8Path.DEFAULT_CONFIG_FILE

def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
if len(item.annotations) > 0:
SpecLad marked this conversation as resolved.
Show resolved Hide resolved
super()._export_item_annotation(item, subset_dir)

@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
Expand All @@ -287,15 +303,19 @@ def build_cmdline_parser(cls, **kwargs):
)
return parser

def _save_config_files(self, subset_lists: Dict[str, str]):
def _save_config_files(self, subset_lists: Dict[str, str], **extra_config_fields):
extractor = self._extractor
save_dir = self._save_dir
with open(osp.join(save_dir, self._config_filename), "w", encoding="utf-8") as f:
label_categories = extractor.categories()[AnnotationType.label]
data = dict(
path=".",
names={idx: label.name for idx, label in enumerate(label_categories.items)},
names={
index: label_categories[label_id].name
for label_id, index in self._map_labels_for_save.items()
},
**subset_lists,
**extra_config_fields,
)
yaml.dump(data, f)

Expand All @@ -308,65 +328,47 @@ def _make_annotation_subset_folder(save_dir: str, subset: str) -> str:
return osp.join(save_dir, YOLOv8Path.LABELS_FOLDER_NAME, subset)


class YOLOv8SegmentationConverter(YOLOv8Converter):
class YOLOv8SegmentationConverter(YOLOv8DetectionConverter):
def _make_annotation_line(self, width: int, height: int, anno: Annotation) -> Optional[str]:
if anno.label is None or not isinstance(anno, Polygon):
return
values = [value / size for value, size in zip(anno.points, cycle((width, height)))]
string_values = " ".join("%.6f" % p for p in values)
return "%s %s\n" % (anno.label, string_values)
return "%s %s\n" % (self._map_labels_for_save[anno.label], string_values)


class YOLOv8OrientedBoxesConverter(YOLOv8Converter):
class YOLOv8OrientedBoxesConverter(YOLOv8DetectionConverter):
def _make_annotation_line(self, width: int, height: int, anno: Annotation) -> Optional[str]:
if anno.label is None or not isinstance(anno, Bbox):
return
points = _bbox_annotation_as_polygon(anno)
values = [value / size for value, size in zip(points, cycle((width, height)))]
string_values = " ".join("%.6f" % p for p in values)
return "%s %s\n" % (anno.label, string_values)
return "%s %s\n" % (self._map_labels_for_save[anno.label], string_values)


class YOLOv8PoseConverter(YOLOv8Converter):
class YOLOv8PoseConverter(YOLOv8DetectionConverter):
@cached_property
def _map_labels_for_save(self):
def _labels_to_save(self) -> List[int]:
point_categories = self._extractor.categories().get(
AnnotationType.points, PointsCategories.from_iterable([])
)
return {label_id: index for index, label_id in enumerate(sorted(point_categories.items))}

def _save_config_files(self, subset_lists: Dict[str, str]):
extractor = self._extractor
save_dir = self._save_dir
return sorted(point_categories.items)

point_categories = extractor.categories().get(
AnnotationType.points, PointsCategories.from_iterable([])
)
if len(set(len(cat.labels) for cat in point_categories.items.values())) > 1:
raise DatasetExportError(
"Can't export: skeletons should have the same number of points"
)
n_of_points = (
len(next(iter(point_categories.items.values())).labels)
if len(point_categories) > 0
else 0
@cached_property
def _max_number_of_points(self):
point_categories = self._extractor.categories().get(AnnotationType.points)
if point_categories is None or len(point_categories) == 0:
return 0
return max(len(category.labels) for category in point_categories.items.values())

def _save_config_files(self, subset_lists: Dict[str, str], **extra_config_fields):
super()._save_config_files(
subset_lists=subset_lists,
kpt_shape=[self._max_number_of_points, 3],
**extra_config_fields,
)

with open(osp.join(save_dir, self._config_filename), "w", encoding="utf-8") as f:
label_categories = extractor.categories()[AnnotationType.label]
parent_categories = {
self._map_labels_for_save[label_id]: label_categories.items[label_id].name
for label_id in point_categories.items
}
assert set(parent_categories.keys()) == set(range(len(parent_categories)))
data = dict(
path=".",
names=parent_categories,
kpt_shape=[n_of_points, 3],
**subset_lists,
)
yaml.dump(data, f)

def _make_annotation_line(self, width: int, height: int, skeleton: Annotation) -> Optional[str]:
if skeleton.label is None or not isinstance(skeleton, Skeleton):
return
Expand All @@ -385,7 +387,7 @@ def _make_annotation_line(self, width: int, height: int, skeleton: Annotation) -
.labels
]

points_values = [f"0.0, 0.0, {Points.Visibility.absent.value}"] * len(point_label_ids)
points_values = [f"0.0 0.0 {Points.Visibility.absent.value}"] * self._max_number_of_points
for element in skeleton.elements:
assert len(element.points) == 2 and len(element.visibility) == 1
position = point_label_ids.index(element.label)
Expand Down
80 changes: 28 additions & 52 deletions datumaro/plugins/yolo_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import math
import os
import os.path as osp
import re
Expand All @@ -14,6 +13,8 @@
from itertools import cycle
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import cv2
import numpy as np
import yaml

from datumaro.components.annotation import (
Expand Down Expand Up @@ -319,7 +320,7 @@ def get_subset(self, name):
return self._subsets[name]


class YOLOv8Extractor(YoloExtractor):
class YOLOv8DetectionExtractor(YoloExtractor):
RESERVED_CONFIG_KEYS = YOLOv8Path.RESERVED_CONFIG_KEYS

def __init__(
Expand All @@ -330,6 +331,13 @@ def __init__(
) -> None:
super().__init__(*args, **kwargs)

def _parse_annotations(
self, anno_path: str, image: Image, *, item_id: Tuple[str, str]
) -> List[Annotation]:
if not osp.exists(anno_path):
return []
return super()._parse_annotations(anno_path, image, item_id=item_id)

@cached_property
def _config(self) -> Dict[str, Any]:
with open(self._config_path) as stream:
Expand Down Expand Up @@ -424,7 +432,7 @@ def _iterate_over_image_paths(
yield from subset_images_source


class YOLOv8SegmentationExtractor(YOLOv8Extractor):
class YOLOv8SegmentationExtractor(YOLOv8DetectionExtractor):
def _load_segmentation_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Polygon:
Expand All @@ -451,30 +459,7 @@ def _load_one_annotation(
)


class YOLOv8OrientedBoxesExtractor(YOLOv8Extractor):
RECTANGLE_ANGLE_PRECISION = math.pi * 1 / 180

@classmethod
def _check_is_rectangle(
cls, p1: Tuple[int, int], p2: Tuple[int, int], p3: Tuple[int, int], p4: Tuple[int, int]
) -> None:
p12_angle = math.atan2(p2[0] - p1[0], p2[1] - p1[1])
p23_angle = math.atan2(p3[0] - p2[0], p3[1] - p2[1])
p43_angle = math.atan2(p3[0] - p4[0], p3[1] - p4[1])
p14_angle = math.atan2(p4[0] - p1[0], p4[1] - p1[1])

if (
abs(p12_angle - p43_angle) > 0.001
or abs(p23_angle - p14_angle) > cls.RECTANGLE_ANGLE_PRECISION
):
raise InvalidAnnotationError(
"Given points do not form a rectangle: opposite sides have different slope angles."
)
if abs((p12_angle - p23_angle) % math.pi - math.pi / 2) > cls.RECTANGLE_ANGLE_PRECISION:
raise InvalidAnnotationError(
"Given points do not form a rectangle: adjacent sides are not orthogonal."
)

class YOLOv8OrientedBoxesExtractor(YOLOv8DetectionExtractor):
def _load_one_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Annotation:
Expand All @@ -491,30 +476,23 @@ def _load_one_annotation(
)
for idx, (x, y) in enumerate(take_by(parts[1:], 2))
]
self._check_is_rectangle(*points)

(x1, y1), (x2, y2), (x3, y3), (x4, y4) = points

width = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
height = math.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
rotation = math.atan2(y2 - y1, x2 - x1)
if rotation < 0:
rotation += math.pi * 2

center_x = (x1 + x2 + x3 + x4) / 4
center_y = (y1 + y2 + y3 + y4) / 4
(center_x, center_y), (width, height), rotation = cv2.minAreaRect(
np.array(points, dtype=np.float32)
)
rotation = rotation % 180

return Bbox(
x=center_x - width / 2,
y=center_y - height / 2,
w=width,
h=height,
label=label_id,
attributes=(dict(rotation=math.degrees(rotation)) if abs(rotation) > 0.00001 else {}),
attributes=(dict(rotation=rotation) if abs(rotation) > 0.00001 else {}),
)


class YOLOv8PoseExtractor(YOLOv8Extractor):
class YOLOv8PoseExtractor(YOLOv8DetectionExtractor):
def __init__(
self,
*args,
Expand Down Expand Up @@ -572,7 +550,7 @@ def _load_categories(self) -> CategoriesInfo:
if has_meta_file(self._path):
return self._load_categories_from_meta_file()

number_of_points, _ = self._kpt_shape
max_number_of_points, _ = self._kpt_shape
skeleton_labels = self._load_names_from_config_file()

if self._skeleton_sub_labels:
Expand All @@ -584,16 +562,17 @@ def _load_categories(self) -> CategoriesInfo:
if skeletons_with_wrong_sub_labels := [
skeleton
for skeleton in skeleton_labels
if len(self._skeleton_sub_labels[skeleton]) != number_of_points
if len(self._skeleton_sub_labels[skeleton]) > max_number_of_points
]:
raise InvalidAnnotationError(
f"Number of points in skeletons according to config file is {number_of_points}. "
f"Following skeletons have number of sub labels which differs: {skeletons_with_wrong_sub_labels}"
f"Number of points in skeletons according to config file is {max_number_of_points}. "
f"Following skeletons have more sub labels: {skeletons_with_wrong_sub_labels}"
)

children_labels = self._skeleton_sub_labels or {
skeleton_label: [
f"{skeleton_label}_point_{point_index}" for point_index in range(number_of_points)
f"{skeleton_label}_point_{point_index}"
for point_index in range(max_number_of_points)
]
for skeleton_label in skeleton_labels
}
Expand Down Expand Up @@ -625,12 +604,12 @@ def _map_label_id(self, ann_label_id: str) -> int:
def _load_one_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Annotation:
number_of_points, values_per_point = self._kpt_shape
if len(parts) != 5 + number_of_points * values_per_point:
max_number_of_points, values_per_point = self._kpt_shape
if len(parts) != 5 + max_number_of_points * values_per_point:
raise InvalidAnnotationError(
f"Unexpected field count {len(parts)} in the skeleton description. "
"Expected 5 fields (label, xc, yc, w, h)"
f"and then {values_per_point} for each of {number_of_points} points"
f"and then {values_per_point} for each of {max_number_of_points} points"
)

label_id = self._map_label_id(parts[0])
Expand Down Expand Up @@ -674,7 +653,4 @@ def _load_one_annotation(
),
]
]
return Skeleton(
points,
label=label_id,
)
return Skeleton(points, label=label_id)
12 changes: 6 additions & 6 deletions datumaro/plugins/yolo_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datumaro import Importer
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.plugins.yolo_format.extractor import (
YOLOv8Extractor,
YOLOv8DetectionExtractor,
YOLOv8OrientedBoxesExtractor,
YOLOv8PoseExtractor,
YOLOv8SegmentationExtractor,
Expand All @@ -31,8 +31,8 @@ def find_sources(cls, path) -> List[Dict[str, Any]]:
return cls._find_sources_recursive(path, ".data", "yolo")


class YOLOv8Importer(Importer):
EXTRACTOR = YOLOv8Extractor
class YOLOv8DetectionImporter(Importer):
EXTRACTOR = YOLOv8DetectionExtractor

@classmethod
def build_cmdline_parser(cls, **kwargs):
Expand Down Expand Up @@ -84,15 +84,15 @@ def find_sources_with_params(
]


class YOLOv8SegmentationImporter(YOLOv8Importer):
class YOLOv8SegmentationImporter(YOLOv8DetectionImporter):
EXTRACTOR = YOLOv8SegmentationExtractor


class YOLOv8OrientedBoxesImporter(YOLOv8Importer):
class YOLOv8OrientedBoxesImporter(YOLOv8DetectionImporter):
EXTRACTOR = YOLOv8OrientedBoxesExtractor


class YOLOv8PoseImporter(YOLOv8Importer):
class YOLOv8PoseImporter(YOLOv8DetectionImporter):
EXTRACTOR = YOLOv8PoseExtractor

@classmethod
Expand Down
Loading
Loading