Skip to content

Commit

Permalink
[Datumaro] Plugins and transforms (#1126)
Browse files Browse the repository at this point in the history
* Fix model run command

* Rename annotation types, update class interfaces

* Fix random cvat format test fails

* Mask operations and dataset format fixes

* Update tests, extract format testing functions

* Add transform interface

* Implement plugin system

* Update tests with plugins

* Fix logging

* Add transfroms

* Update cvat integration
  • Loading branch information
zhiltsov-max authored Feb 6, 2020
1 parent 939de86 commit 2848f1d
Show file tree
Hide file tree
Showing 63 changed files with 2,287 additions and 1,641 deletions.
14 changes: 5 additions & 9 deletions cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,35 +156,31 @@ def convert_attrs(label, cvat_attrs):

for tag_obj in cvat_anno.tags:
anno_group = tag_obj.group
if isinstance(anno_group, int):
anno_group = anno_group
anno_label = map_label(tag_obj.label)
anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)

anno = datumaro.LabelObject(label=anno_label,
anno = datumaro.Label(label=anno_label,
attributes=anno_attr, group=anno_group)
item_anno.append(anno)

for shape_obj in cvat_anno.labeled_shapes:
anno_group = shape_obj.group
if isinstance(anno_group, int):
anno_group = anno_group
anno_label = map_label(shape_obj.label)
anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)

anno_points = shape_obj.points
if shape_obj.type == ShapeType.POINTS:
anno = datumaro.PointsObject(anno_points,
anno = datumaro.Points(anno_points,
label=anno_label, attributes=anno_attr, group=anno_group)
elif shape_obj.type == ShapeType.POLYLINE:
anno = datumaro.PolyLineObject(anno_points,
anno = datumaro.PolyLine(anno_points,
label=anno_label, attributes=anno_attr, group=anno_group)
elif shape_obj.type == ShapeType.POLYGON:
anno = datumaro.PolygonObject(anno_points,
anno = datumaro.Polygon(anno_points,
label=anno_label, attributes=anno_attr, group=anno_group)
elif shape_obj.type == ShapeType.RECTANGLE:
x0, y0, x1, y1 = anno_points
anno = datumaro.BboxObject(x0, y0, x1 - x0, y1 - y0,
anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0,
label=anno_label, attributes=anno_attr, group=anno_group)
else:
raise Exception("Unknown shape type '%s'" % (shape_obj.type))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
'server_port': 80
}, schema=CONFIG_SCHEMA, mutable=False)

class cvat_rest_api_task_images(datumaro.Extractor):
class cvat_rest_api_task_images(datumaro.SourceExtractor):
def _image_local_path(self, item_id):
task_id = self._config.task_id
return osp.join(self._cache_dir,
Expand All @@ -53,7 +53,7 @@ def _connect(self):

session = None
try:
print("Enter credentials for '%s:%s':" % \
print("Enter credentials for '%s:%s' to read task data:" % \
(self._config.server_host, self._config.server_port))
username = input('User: ')
password = getpass.getpass()
Expand Down
90 changes: 7 additions & 83 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import django_rq

from cvat.apps.engine.log import slogger
from cvat.apps.engine.models import Task, ShapeType
from cvat.apps.engine.models import Task
from .util import current_function_name, make_zip_archive

_CVAT_ROOT_DIR = __file__[:__file__.rfind('cvat/')]
_DATUMARO_REPO_PATH = osp.join(_CVAT_ROOT_DIR, 'datumaro')
sys.path.append(_DATUMARO_REPO_PATH)
from datumaro.components.project import Project
from datumaro.components.project import Project, Environment
import datumaro.components.extractor as datumaro
from .bindings import CvatImagesDirExtractor, CvatTaskExtractor

Expand Down Expand Up @@ -132,83 +132,7 @@ def _generate_categories(self):
return categories

def put_annotations(self, annotations):
patch = {}

categories = self._dataset.categories()
label_cat = categories[datumaro.AnnotationType.label]

label_map = {}
attr_map = {}
db_labels = self._db_task.label_set.all()
for db_label in db_labels:
label_map[db_label.id] = label_cat.find(db_label.name)

db_attributes = db_label.attributespec_set.all()
for db_attr in db_attributes:
attr_map[(db_label.id, db_attr.id)] = db_attr.name
map_label = lambda label_db_id: label_map[label_db_id]
map_attr = lambda label_db_id, attr_db_id: \
attr_map[(label_db_id, attr_db_id)]

for tag_obj in annotations['tags']:
item_id = str(tag_obj['frame'])
item_anno = patch.get(item_id, [])

anno_group = tag_obj['group']
if isinstance(anno_group, int):
anno_group = [anno_group]
anno_label = map_label(tag_obj['label_id'])
anno_attr = {}
for attr in tag_obj['attributes']:
attr_name = map_attr(tag_obj['label_id'], attr['id'])
anno_attr[attr_name] = attr['value']

anno = datumaro.LabelObject(label=anno_label,
attributes=anno_attr, group=anno_group)
item_anno.append(anno)

patch[item_id] = item_anno

for shape_obj in annotations['shapes']:
item_id = str(shape_obj['frame'])
item_anno = patch.get(item_id, [])

anno_group = shape_obj['group']
if isinstance(anno_group, int):
anno_group = [anno_group]
anno_label = map_label(shape_obj['label_id'])
anno_attr = {}
for attr in shape_obj['attributes']:
attr_name = map_attr(shape_obj['label_id'], attr['id'])
anno_attr[attr_name] = attr['value']

anno_points = shape_obj['points']
if shape_obj['type'] == ShapeType.POINTS:
anno = datumaro.PointsObject(anno_points,
label=anno_label, attributes=anno_attr, group=anno_group)
elif shape_obj['type'] == ShapeType.POLYLINE:
anno = datumaro.PolyLineObject(anno_points,
label=anno_label, attributes=anno_attr, group=anno_group)
elif shape_obj['type'] == ShapeType.POLYGON:
anno = datumaro.PolygonObject(anno_points,
label=anno_label, attributes=anno_attr, group=anno_group)
elif shape_obj['type'] == ShapeType.RECTANGLE:
x0, y0, x1, y1 = anno_points
anno = datumaro.BboxObject(x0, y0, x1 - x0, y1 - y0,
label=anno_label, attributes=anno_attr, group=anno_group)
else:
raise Exception("Unknown shape type '%s'" % (shape_obj['type']))

item_anno.append(anno)

patch[item_id] = item_anno

# TODO: support track annotations

patch = [datumaro.DatasetItem(id=id_, annotations=anno) \
for id_, ann in patch.items()]

self._dataset.update(patch)
raise NotImplementedError()

def save(self, save_dir=None, save_images=False):
if self._dataset is not None:
Expand Down Expand Up @@ -296,10 +220,10 @@ def _remote_export(self, save_dir, server_url=None):
osp.join(templates_dir, 'README.md'),
osp.join(target_dir, 'README.md'))

templates_dir = osp.join(templates_dir, 'extractors')
templates_dir = osp.join(templates_dir, 'plugins')
target_dir = osp.join(target_dir,
exported_project.config.env_dir,
exported_project.env.config.extractors_dir)
exported_project.config.plugins_dir)
os.makedirs(target_dir, exist_ok=True)
shutil.copyfile(
osp.join(templates_dir, _TASK_IMAGES_REMOTE_EXTRACTOR + '.py'),
Expand Down Expand Up @@ -409,9 +333,9 @@ def clear_export_cache(task_id, file_path, file_ctime):
]

def get_export_formats():
from datumaro.components import converters
converters = Environment().converters

available_formats = set(name for name, _ in converters.items)
available_formats = set(converters.items)
available_formats.add(EXPORT_FORMAT_DATUMARO_PROJECT)

public_formats = []
Expand Down
65 changes: 61 additions & 4 deletions datumaro/datumaro/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import argparse
import logging as log
import logging.handlers
import os
import sys

from . import contexts, commands
Expand Down Expand Up @@ -81,15 +83,70 @@ def make_parser():

return parser

def set_up_logger(args):
log.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
level=args.loglevel)
class _LogManager:
_LOGLEVEL_ENV_NAME = '_DATUMARO_INIT_LOGLEVEL'
_BUFFER_SIZE = 1000
_root = None
_init_handler = None
_default_handler = None

@classmethod
def init_basic_logger(cls):
base_loglevel = os.getenv(cls._LOGLEVEL_ENV_NAME, 'info')
base_loglevel = loglevel(base_loglevel)
root = log.getLogger()
root.setLevel(base_loglevel)

# NOTE: defer use of this handler until the logger
# is properly initialized, but keep logging enabled before this.
# Store messages obtained during initialization and print them after
# if necessary.
default_handler = log.StreamHandler()
default_handler.setFormatter(
log.Formatter('%(asctime)s %(levelname)s: %(message)s'))

init_handler = logging.handlers.MemoryHandler(cls._BUFFER_SIZE,
target=default_handler)
root.addHandler(init_handler)

cls._root = root
cls._init_handler = init_handler
cls._default_handler = default_handler

@classmethod
def set_up_logger(cls, level):
log.getLogger().setLevel(level)

if cls._init_handler:
# NOTE: Handlers are not capable of filtering with loglevel
# despite a level can be set for a handler. The level is checked
# by Logger. However, handler filters are checked at handler level.
class LevelFilter:
def __init__(self, level):
super().__init__()
self.level = level

def filter(self, record):
return record.levelno >= self.level
filt = LevelFilter(level)
cls._default_handler.addFilter(filt)

cls._root.removeHandler(cls._init_handler)
cls._init_handler.close()
del cls._init_handler
cls._init_handler = None

cls._default_handler.removeFilter(filt)

cls._root.addHandler(cls._default_handler)

def main(args=None):
_LogManager.init_basic_logger()

parser = make_parser()
args = parser.parse_args(args)

set_up_logger(args)
_LogManager.set_up_logger(args.loglevel)

if 'command' not in args:
parser.print_help()
Expand Down
Loading

0 comments on commit 2848f1d

Please sign in to comment.