Skip to content

Commit

Permalink
[Datumaro] Label remapping transform (#1233)
Browse files Browse the repository at this point in the history
* Add label remapping transform

* Apply transforms before project saving

* Refactor voc converter
  • Loading branch information
nmanovic authored Mar 5, 2020
1 parent 78dad73 commit be5577d
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 17 deletions.
2 changes: 2 additions & 0 deletions datumaro/datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ def sources(self):
return self._sources

def _save_branch_project(self, extractor, save_dir=None):
extractor = Dataset.from_extractors(extractor) # apply lazy transforms

# NOTE: probably this function should be in the ViewModel layer
save_dir = osp.abspath(save_dir)
if save_dir:
Expand Down
120 changes: 118 additions & 2 deletions datumaro/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
#
# SPDX-License-Identifier: MIT

from enum import Enum
import logging as log
import os.path as osp
import random

import pycocotools.mask as mask_utils

from datumaro.components.extractor import (Transform, AnnotationType,
RleMask, Polygon, Bbox)
RleMask, Polygon, Bbox,
LabelCategories, MaskCategories, PointsCategories
)
from datumaro.components.cli_plugin import CliPlugin
import datumaro.util.mask_tools as mask_tools
from datumaro.util.annotation_tools import find_group_leader, find_instances
Expand Down Expand Up @@ -46,7 +49,7 @@ def crop_segments(cls, segment_anns, img_width, img_height):
segments.append(s.points)
elif s.type == AnnotationType.mask:
if isinstance(s, RleMask):
rle = s._rle
rle = s.rle
else:
rle = mask_tools.mask_to_rle(s.image)
segments.append(rle)
Expand Down Expand Up @@ -365,3 +368,116 @@ def transform_item(self, item):
if item.has_image and item.image.filename:
name = osp.splitext(item.image.filename)[0]
return self.wrap_item(item, id=name)

class RemapLabels(Transform, CliPlugin):
    """
    Renames, merges or removes dataset labels according to a name mapping.

    The mapping is given as source label name -> destination label name.
    Several source labels may point to the same destination label, in which
    case they share one destination index. Labels that are not mentioned in
    the mapping are either kept under their own name or removed, depending
    on the default action. Annotations whose label was removed are dropped
    from the items.
    """

    DefaultAction = Enum('DefaultAction', ['keep', 'delete'])

    @staticmethod
    def _split_arg(s):
        """
        Parses a '<src>:<dst>' command line value into a (src, dst) tuple.

        Raises argparse.ArgumentTypeError if the value does not consist of
        exactly two ':'-separated parts.
        """
        parts = s.split(':')
        if len(parts) != 2:
            import argparse
            # A message is required for argparse to print a useful error
            raise argparse.ArgumentTypeError(
                "Expected a label mapping in the form '<src>:<dst>', "
                "but got '%s'" % s)
        return (parts[0], parts[1])

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extends the base plugin parser with the label mapping options."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('-l', '--label', action='append',
            type=cls._split_arg, dest='mapping',
            help="Label in the form of: '<src>:<dst>' (repeatable)")
        parser.add_argument('--default',
            choices=[a.name for a in cls.DefaultAction],
            default=cls.DefaultAction.keep.name,
            help="Action for unspecified labels")
        return parser

    def __init__(self, extractor, mapping, default=None):
        """
        Parameters:
            extractor: the source of items and categories to be transformed
            mapping: a dict, or a list of (src, dst) pairs, of label names.
                None is treated as an empty mapping (this is what argparse
                produces for an 'append' option that was never specified)
            default: a DefaultAction member or its name, the action for the
                labels missing in the mapping. None means 'keep', which
                matches the command line default
        """
        super().__init__(extractor)

        if default is None:
            default = self.DefaultAction.keep
        assert isinstance(default, (str, self.DefaultAction))
        if isinstance(default, str):
            default = self.DefaultAction[default]

        if mapping is None:
            mapping = {}
        assert isinstance(mapping, (dict, list))
        if isinstance(mapping, list):
            mapping = dict(mapping)

        self._categories = {}

        # Fallback for sources without label categories: there is nothing
        # to remap, so annotation labels are passed through unchanged.
        # It is replaced by _make_label_id_map() when labels are present.
        self._map_id = lambda src_id: src_id

        src_label_cat = self._extractor.categories().get(AnnotationType.label)
        if src_label_cat is not None:
            self._make_label_id_map(src_label_cat, mapping, default)

        # NOTE(review): the mask and points categories below are keyed by
        # the *source* label ids, and a source id (other than 0) that maps
        # to the destination index 0 is skipped because of the truthiness
        # check. This behavior is preserved as is - confirm it is intended.
        src_mask_cat = self._extractor.categories().get(AnnotationType.mask)
        if src_mask_cat is not None:
            assert src_label_cat is not None
            dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
            dst_mask_cat.colormap = {
                src_id: src_mask_cat.colormap[src_id]
                for src_id, _ in enumerate(src_label_cat.items)
                if self._map_id(src_id) or src_id == 0
            }
            self._categories[AnnotationType.mask] = dst_mask_cat

        src_points_cat = self._extractor.categories().get(AnnotationType.points)
        if src_points_cat is not None:
            assert src_label_cat is not None
            dst_points_cat = PointsCategories(attributes=src_points_cat.attributes)
            dst_points_cat.items = {
                src_id: src_points_cat.items[src_id]
                for src_id, _ in enumerate(src_label_cat.items)
                if self._map_id(src_id) or src_id == 0
            }
            self._categories[AnnotationType.points] = dst_points_cat

    def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
        """
        Builds the destination label categories and the source index ->
        destination index table, then stores them as
        self._categories[AnnotationType.label] and self._map_id.

        A source label without a (non-empty) destination name is kept under
        its own name when the default action is 'keep', and is left out of
        the table otherwise.
        """
        dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
        id_mapping = {}
        for src_index, src_label in enumerate(src_label_cat.items):
            dst_label = label_mapping.get(src_label.name)
            if not dst_label and default_action == self.DefaultAction.keep:
                dst_label = src_label.name # keep unspecified as is
            if not dst_label:
                continue

            # Labels merged into the same destination name share its index
            dst_index = dst_label_cat.find(dst_label)[0]
            if dst_index is None:
                dst_label_cat.add(dst_label,
                    src_label.parent, src_label.attributes)
                dst_index = dst_label_cat.find(dst_label)[0]
            id_mapping[src_index] = dst_index

        if log.getLogger().isEnabledFor(log.DEBUG):
            log.debug("Label mapping:")
            for src_id, src_label in enumerate(src_label_cat.items):
                dst_id = id_mapping.get(src_id)
                # The destination index 0 is valid, so compare with None
                if dst_id is not None:
                    log.debug("#%s '%s' -> #%s '%s'",
                        src_id, src_label.name, dst_id,
                        dst_label_cat.items[dst_id].name
                    )
                else:
                    log.debug("#%s '%s' -> <deleted>", src_id, src_label.name)

        self._map_id = lambda src_id: id_mapping.get(src_id, None)
        self._categories[AnnotationType.label] = dst_label_cat

    def categories(self):
        """Returns the remapped categories, keyed by annotation type."""
        return self._categories

    def transform_item(self, item):
        """
        Replaces annotation labels with the remapped indices and drops the
        annotations whose label was deleted. Annotations of other types and
        annotations without a label are kept untouched.

        The item and its annotations are modified in place.
        """
        # TODO: provide non-inplace version
        labelled_types = {
            AnnotationType.label, AnnotationType.mask,
            AnnotationType.points, AnnotationType.polygon,
            AnnotationType.polyline, AnnotationType.bbox
        }

        annotations = []
        for ann in item.annotations:
            if ann.type in labelled_types and ann.label is not None:
                conv_label = self._map_id(ann.label)
                if conv_label is None:
                    continue # the label was deleted, so is the annotation
                ann._label = conv_label
            annotations.append(ann)
        item._annotations = annotations
        return item
25 changes: 12 additions & 13 deletions datumaro/datumaro/plugins/voc_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,13 @@ def _write_xml_bbox(bbox, parent_elem):
class _Converter:
def __init__(self, extractor, save_dir,
tasks=None, apply_colormap=True, save_images=False, label_map=None):
assert tasks is None or isinstance(tasks, (VocTask, list))
assert tasks is None or isinstance(tasks, (VocTask, list, set))
if tasks is None:
tasks = list(VocTask)
tasks = set(VocTask)
elif isinstance(tasks, VocTask):
tasks = [tasks]
tasks = {tasks}
else:
tasks = [t if t in VocTask else VocTask[t] for t in tasks]

tasks = set(t if t in VocTask else VocTask[t] for t in tasks)
self._tasks = tasks

self._extractor = extractor
Expand Down Expand Up @@ -259,10 +258,10 @@ def save_subsets(self):
if len(actions_elem) != 0:
obj_elem.append(actions_elem)

if set(self._tasks) & set([None,
if self._tasks & {None,
VocTask.detection,
VocTask.person_layout,
VocTask.action_classification]):
VocTask.action_classification}:
with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f:
f.write(ET.tostring(root_elem,
encoding='unicode', pretty_print=True))
Expand Down Expand Up @@ -302,19 +301,19 @@ def save_subsets(self):
action_list[item.id] = None
segm_list[item.id] = None

if set(self._tasks) & set([None,
if self._tasks & {None,
VocTask.classification,
VocTask.detection,
VocTask.action_classification,
VocTask.person_layout]):
VocTask.person_layout}:
self.save_clsdet_lists(subset_name, clsdet_list)
if set(self._tasks) & set([None, VocTask.classification]):
if self._tasks & {None, VocTask.classification}:
self.save_class_lists(subset_name, class_lists)
if set(self._tasks) & set([None, VocTask.action_classification]):
if self._tasks & {None, VocTask.action_classification}:
self.save_action_lists(subset_name, action_list)
if set(self._tasks) & set([None, VocTask.person_layout]):
if self._tasks & {None, VocTask.person_layout}:
self.save_layout_lists(subset_name, layout_list)
if set(self._tasks) & set([None, VocTask.segmentation]):
if self._tasks & {None, VocTask.segmentation}:
self.save_segm_lists(subset_name, segm_list)

def save_action_lists(self, subset_name, action_list):
Expand Down
98 changes: 96 additions & 2 deletions datumaro/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from unittest import TestCase

from datumaro.components.extractor import (Extractor, DatasetItem,
Mask, Polygon, PolyLine, Points, Bbox
Mask, Polygon, PolyLine, Points, Bbox, Label,
LabelCategories, MaskCategories, AnnotationType
)
from datumaro.util.test_utils import compare_datasets
import datumaro.util.mask_tools as mask_tools
import datumaro.plugins.transforms as transforms
from datumaro.util.test_utils import compare_datasets


class TransformsTest(TestCase):
Expand Down Expand Up @@ -361,3 +363,95 @@ def __iter__(self):
('train', -0.5),
('test', 1.5),
])

def test_remap_labels(self):
    # Checks that explicitly mapped labels are renamed / merged and that
    # the remaining ones are kept when the default action is 'keep'.

    def make_label_categories(names):
        categories = LabelCategories()
        for name in names:
            categories.add(name)
        return categories

    class SrcExtractor(Extractor):
        def __iter__(self):
            source_annotations = [
                # Should be remapped
                Label(1),
                Bbox(1, 2, 3, 4, label=2),
                Mask(image=np.array([1]), label=3),

                # Should be kept
                Polygon([1, 1, 2, 2, 3, 4], label=4),
                PolyLine([1, 3, 4, 2, 5, 6], label=None)
            ]
            return iter([
                DatasetItem(id=1, annotations=source_annotations),
            ])

        def categories(self):
            label_cat = make_label_categories([
                'label0', 'label1', 'label2', 'label3', 'label4',
            ])
            mask_cat = MaskCategories(
                colormap=mask_tools.generate_colormap(5))

            return {
                AnnotationType.label: label_cat,
                AnnotationType.mask: mask_cat,
            }

    class DstExtractor(Extractor):
        def __iter__(self):
            expected_annotations = [
                Label(1),
                Bbox(1, 2, 3, 4, label=0),
                Mask(image=np.array([1]), label=1),

                Polygon([1, 1, 2, 2, 3, 4], label=2),
                PolyLine([1, 3, 4, 2, 5, 6], label=None)
            ]
            return iter([
                DatasetItem(id=1, annotations=expected_annotations),
            ])

        def categories(self):
            label_cat = make_label_categories([
                'label0', 'label9', 'label4',
            ])

            kept_ids = { 0, 1, 3, 4 }
            full_colormap = mask_tools.generate_colormap(5)
            mask_cat = MaskCategories(colormap={
                index: color for index, color in full_colormap.items()
                if index in kept_ids
            })

            return {
                AnnotationType.label: label_cat,
                AnnotationType.mask: mask_cat,
            }

    name_mapping = {
        'label1': 'label9',
        'label2': 'label0',
        'label3': 'label9',
    }
    actual = transforms.RemapLabels(SrcExtractor(),
        mapping=name_mapping, default='keep')

    compare_datasets(self, DstExtractor(), actual)

def test_remap_labels_delete_unspecified(self):
    # Checks that a label missing in the mapping is removed, together
    # with its annotations, when the default action is 'delete'.

    class SrcExtractor(Extractor):
        def __iter__(self):
            labelled_item = DatasetItem(id=1, annotations=[ Label(0) ])
            return iter([ labelled_item ])

        def categories(self):
            label_cat = LabelCategories()
            label_cat.add('label0')

            return { AnnotationType.label: label_cat }

    class DstExtractor(Extractor):
        def __iter__(self):
            empty_item = DatasetItem(id=1, annotations=[])
            return iter([ empty_item ])

        def categories(self):
            empty_label_cat = LabelCategories()
            return { AnnotationType.label: empty_label_cat }

    source = SrcExtractor()
    actual = transforms.RemapLabels(source, mapping={}, default='delete')

    compare_datasets(self, DstExtractor(), actual)

0 comments on commit be5577d

Please sign in to comment.