diff --git a/.gitignore b/.gitignore index 00e9b3f9ddd..ba3393384c1 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,8 @@ __pycache__ # Ignore development npm files node_modules +# Ignore goturn model +goturn.caffemodel +goturn.prototxt + .DS_Store diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 278f79ddc48..407b945f525 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -134,6 +134,17 @@ wget https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_co export AUTO_SEGMENTATION_PATH="/path/to/dir" # dir must contain mask_rcnn_coco.h5 file ``` +### Tracking +- Download the GOTURN model if the the OpenCV GOTURN tracker is to be used. +```sh +#This downloads it into the cvat folder +components/tracking/install.sh +``` +- Add next lines to ``.env/bin/activate``: +```sh + export TRACKING="yes" +``` + ## JavaScript/Typescript coding style We use the [Airbnb JavaScript Style Guide](https://github.com/airbnb/javascript) for JavaScript code with a diff --git a/Dockerfile b/Dockerfile index 4b2a04484ad..4f8a4c62375 100644 --- a/Dockerfile +++ b/Dockerfile @@ -90,6 +90,13 @@ RUN if [ "$AUTO_SEGMENTATION" = "yes" ]; then \ bash -i /tmp/components/auto_segmentation/install.sh; \ fi +ARG TRACKING +ENV TRACKING=${TRACKING} +ENV TRACKING_PATH=${HOME}/tracking +RUN if [ "$TRACKING" = "yes" ]; then \ + bash -i /tmp/components/tracking/install.sh; \ + fi + ARG WITH_TESTS RUN if [ "$WITH_TESTS" = "yes" ]; then \ wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add - && \ diff --git a/components/tracking/README.md b/components/tracking/README.md new file mode 100644 index 00000000000..c872c72916c --- /dev/null +++ b/components/tracking/README.md @@ -0,0 +1,19 @@ +## Tracking +This components allows to track bounding boxes in consecutive images. + +### Build docker image +```bash +# From project root directory +docker-compose -f docker-compose.yml -f components/tracking/docker-compose.tracking.yml build +``` + +### Run docker container +```bash +# From project root directory +docker-compose -f docker-compose.yml -f components/tracking/docker-compose.tracking.yml up -d +``` + +### TODO +* Make API consistent (one request per tracking job) +* Put jobs into queque +* Enable by default? \ No newline at end of file diff --git a/components/tracking/docker-compose.tracking.yml b/components/tracking/docker-compose.tracking.yml new file mode 100644 index 00000000000..3b3b28ee853 --- /dev/null +++ b/components/tracking/docker-compose.tracking.yml @@ -0,0 +1,13 @@ +# +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT +# +version: "2.3" + +services: + cvat: + build: + context: . + args: + TRACKING: "yes" diff --git a/components/tracking/install.sh b/components/tracking/install.sh new file mode 100755 index 00000000000..d55f48d65f1 --- /dev/null +++ b/components/tracking/install.sh @@ -0,0 +1,15 @@ +#!/bin/sh +set -e + +# Install GOTURN model for OpenCV; It is not in a package and has to be downloaded +cd "$(dirname "$0")" +echo "Downloading GOTURN model. This can take a while." +wget -q 'https://github.com/opencv/opencv_extra/raw/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking/goturn.prototxt' +wget -q 'https://github.com/opencv/opencv_extra/raw/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking/goturn.caffemodel.zip.001' +wget -q 'https://github.com/opencv/opencv_extra/raw/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking/goturn.caffemodel.zip.002' +wget -q 'https://github.com/opencv/opencv_extra/raw/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking/goturn.caffemodel.zip.003' +wget -q 'https://github.com/opencv/opencv_extra/raw/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking/goturn.caffemodel.zip.004' +cat goturn.caffemodel.zip* > goturn.caffemodel.zip +unzip goturn.caffemodel.zip +mv goturn.caffemodel goturn.prototxt ../.. +rm goturn.caffemodel.zip* \ No newline at end of file diff --git a/cvat/apps/engine/static/engine/js/shapes.js b/cvat/apps/engine/static/engine/js/shapes.js index 2ff5c4e96ac..cf46dd256b9 100644 --- a/cvat/apps/engine/static/engine/js/shapes.js +++ b/cvat/apps/engine/static/engine/js/shapes.js @@ -142,21 +142,22 @@ class ShapeModel extends Listener { return interpolated; } - _neighboringFrames(frame) { + _neighboringFrames(frame, predicate = () => true) { if (!Number.isInteger(frame) || frame < 0) { throw Error(`Got invalid frame: ${frame}`); } let leftFrame = null; let rightFrame = null; - for (let frameKey in this._positions) { frameKey = +frameKey; - if (frameKey < frame && (frameKey > leftFrame || leftFrame === null)) { + if (frameKey < frame && (frameKey > leftFrame || leftFrame === null) + && predicate(this._positions[frameKey])) { leftFrame = frameKey; } - if (frameKey > frame && (frameKey < rightFrame || rightFrame === null)) { + if (frameKey > frame && (frameKey < rightFrame || rightFrame === null) + && predicate(this._positions[frameKey])) { rightFrame = frameKey; } } @@ -435,12 +436,14 @@ class ShapeModel extends Listener { this.notify('click'); } - prevKeyFrame() { - return this._neighboringFrames(window.cvat.player.frames.current)[0]; + prevKeyFrame(predicate = () => true) { + return this._neighboringFrames(window.cvat.player.frames.current, + predicate)[0]; } - nextKeyFrame() { - return this._neighboringFrames(window.cvat.player.frames.current)[1]; + nextKeyFrame(predicate = () => true) { + return this._neighboringFrames(window.cvat.player.frames.current, + predicate)[1]; } initKeyFrame() { @@ -678,7 +681,8 @@ class BoxModel extends ShapeModel { ytl: Math.clamp(position.ytl, 0, window.cvat.player.geometry.frameHeight), xbr: Math.clamp(position.xbr, 0, window.cvat.player.geometry.frameWidth), ybr: Math.clamp(position.ybr, 0, window.cvat.player.geometry.frameHeight), - occluded: position.occluded, + byMachine: position.byMachine ? true : false, + occluded: position.occluded ? true : false, z_order: position.z_order, }; diff --git a/cvat/apps/tracking/__init__.py b/cvat/apps/tracking/__init__.py new file mode 100644 index 00000000000..b66d1baa777 --- /dev/null +++ b/cvat/apps/tracking/__init__.py @@ -0,0 +1,10 @@ + +# Copyright (C) 2018-2019 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from cvat.settings.base import JS_3RDPARTY + +# default_app_config = 'cvat.apps.tracking.apps.TrackingConfig' + +JS_3RDPARTY['engine'] = JS_3RDPARTY.get('engine', []) + ['tracking/js/enginePlugin.js'] diff --git a/cvat/apps/tracking/admin.py b/cvat/apps/tracking/admin.py new file mode 100644 index 00000000000..3c40ebdfe11 --- /dev/null +++ b/cvat/apps/tracking/admin.py @@ -0,0 +1,8 @@ + +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT + + +# Register your models here. + diff --git a/cvat/apps/tracking/apps.py b/cvat/apps/tracking/apps.py new file mode 100644 index 00000000000..9bc2f945e18 --- /dev/null +++ b/cvat/apps/tracking/apps.py @@ -0,0 +1,11 @@ + +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from django.apps import AppConfig + + +class TrackingConfig(AppConfig): + name = 'tracking' + diff --git a/cvat/apps/tracking/migrations/__init__.py b/cvat/apps/tracking/migrations/__init__.py new file mode 100644 index 00000000000..d8e62e54b35 --- /dev/null +++ b/cvat/apps/tracking/migrations/__init__.py @@ -0,0 +1,5 @@ + +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT + diff --git a/cvat/apps/tracking/models.py b/cvat/apps/tracking/models.py new file mode 100644 index 00000000000..37401bdd220 --- /dev/null +++ b/cvat/apps/tracking/models.py @@ -0,0 +1,8 @@ + +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT + + +# Create your models here. + diff --git a/cvat/apps/tracking/static/tracking/js/enginePlugin.js b/cvat/apps/tracking/static/tracking/js/enginePlugin.js new file mode 100644 index 00000000000..6c41863cb98 --- /dev/null +++ b/cvat/apps/tracking/static/tracking/js/enginePlugin.js @@ -0,0 +1,456 @@ +/* + * Copyright (C) 2018 Intel Corporation + * + * SPDX-License-Identifier: MIT + */ + +/* global showMessage userConfirm */ + +/* + * Copyright (C) 2019 + * + * SPDX-License-Identifier: MIT + */ + +/* exported ShapeTracker ShapeTrackerController ShapeTrackerView buildShapeTracker */ + +/* global + Mousetrap: false + Listener: false + showOverlay: false + Logger: false +*/ + + +document.addEventListener('DOMContentLoaded', () => { + const TrackerTypes = { + BOOSTING: 'BOOSTING', + MIL: 'MIL', + KCF: 'KCF', + CSRT: 'CSRT', + MEDIANFLOW: 'MEDIANFLOW', + TLD: 'TLD', + MOSSE: 'MOSSE', + GOTRUN: 'GOTRUN', + }; + + const EndFrameModes = { + nextKeyframe: 'nextKeyframe', + nextUserKeyframe: 'nextUserKeyframe', + nextUserKeyframeMax: 'nextUserKeyframeMax', + end: 'end', + }; + + + //TODO: Why does the class Listener is not called Observable or Listenable? + /** Class to track shapes + * + * This tracks a object that is marked with a bounding box in consecutive + * images by adding new keyframes. + * . */ + class ShapeTrackerModel extends Listener { + constructor(shapeCollection) { + super('onShapeTrackerModelUpdate', () => this._state); + this._shapeCollection = shapeCollection; + this.trackerOptions = { + trackerType: TrackerTypes.BOOSTING, + }; + this._endFrameMode = EndFrameModes.nextUserKeyframe; + this.maxNrOfFrames = 20; + this._state = { + jobRunning: false, + }; + } + + // Setter and getters + set trackerType(value) { + if (value in TrackerTypes) { + this.trackerOptions.trackerType = value; + } else { + console.error('Wrong tracker type:'+value); + } + } + + get trackerType() { + return this.trackerOptions.trackerType; + } + + set endFrameMode(value) { + if (value in EndFrameModes) { + this._endFrameMode = value; + } else { + console.error('Wrong end frame mode:'+value); + } + } + + get endFrameMode() { + return this._endFrameMode; + } + + // Actions + /* Run the tracker on the currently selected shape. + * This insert new shapes from tracker. + */ + async trackCurrentSelectedShape() { + // TODO: Allow to track multiple shapes in parallel + this._state.jobRunning = true; + this.notify('startTracking'); + + const { activeShape } = this._shapeCollection; + if (activeShape) { + await this._trackShape(activeShape); + } + + this._state.jobRunning = false; + this.notify('stopTracking'); + } + + /* Run the tracker on the specified shape. + * This insert new shapes from tracker into the shapeCollection. + */ + async _trackShape(shape) { + // Can only track box shapes in interpolation mode + if (shape.type.startsWith('annotation')) { + showMessage('Can only track interpolation shapes.\n'+ + 'Make sure the you choose Interpolation instead of'+ + 'Annotation in the bottom right.'); + return; + } + if (!shape.type.endsWith('box')) { + showMessage('Can only track box shapes.\n'+ + 'Make sure the you choose Box in the bottom right.'); + return; + } + + // Save shapes from all frames for undu/redo + const oldPositions = JSON.parse(JSON.stringify(shape._positions)); + + // Determine first and last frame the tracker should be applied to + const { startFrame, stopFrame } = this._getStartStopFrame(shape); + + // Send request to server and insert new shapes + try { + let track = await this._makeRequest(shape, startFrame, stopFrame); + this._insertTrackedShapes(shape, track); + } catch (error) { + // _makeRequest has been rejected + console.error(error); + showMessage(error); + return; + } + + // Undo/redo code for complete tracking action + window.cvat.addAction('Track object', () => { + shape._positions = oldPositions; + }, () => { + result.forEach(track => this._insertTrackedShapes(shape, track)); + }, startFrame); + // End of undo/redo code + } + + /* Return the frist and last frame (number) this tracker should be + * applied to as an array of the form. + * {startFrame:, stopFrame:} + * + * startFrame: start tracking at this frame (included) + * stopFrame: track until this frame (exluded) + * This is determined by the settings (frame mode etc.). + */ + _getStartStopFrame(shape) { + const startFrame = window.cvat.player.frames.current; + if (this._endFrameMode === EndFrameModes.nextKeyframe) { + // Track until the next keyframe (or the end if there is none) + const nextKeyFrame = shape.nextKeyFrame(); + const stopFrame = nextKeyFrame != null ? nextKeyFrame + : window.cvat.player.frames.stop+1; + return { startFrame, stopFrame }; + } + if (this._endFrameMode === EndFrameModes.nextUserKeyframe) { + // Option: Track until keyframe by human + // nextHumanFrame == null if there is no such frame + const nextHumanFrame = shape.nextKeyFrame(frame => !frame.byMachine); + const stopFrame = Math.min(nextHumanFrame != null ? nextHumanFrame + : window.cvat.player.frames.stop+1, window.cvat.player.frames.stop + 1); + return { startFrame, stopFrame }; + } + if (this._endFrameMode === EndFrameModes.nextUserKeyframeMax) { + // Track at most 10 frames and until keyframe by human + // nextHumanFrame == null if there is no such frame + const nextHumanFrame = shape.nextKeyFrame(frame => !frame.byMachine); + const stopFrame = Math.min(window.cvat.player.frames.current + 10, + nextHumanFrame != null ? nextHumanFrame + : window.cvat.player.frames.stop + 1, + window.cvat.player.frames.stop + 1); + return { startFrame, stopFrame }; + } + + return { startFrame, stopFrame: window.cvat.player.frames.stop+1 }; + } + + async _makeRequest(shape, startFrame, stopFrame) { + const exportedCollection = this._shapeCollection.export(); + const serverReprAndclientReprPairs = exportedCollection[1]; + const track = serverReprAndclientReprPairs + .filter(serverAndClient => serverAndClient[1] === shape)[0][0]; + + if (!track) { + // Should never happen. + return new Error('Tracker: Could not transform to server model!'); + } + + const trackingJob = { + trackId: track.id, + track, + startFrame, + stopFrame, + trackerOptions: this.trackerOptions, + }; + + const data = { + jobId:window.cvat.job.id, + trackingJob: trackingJob, + }; + + return new Promise((resolve, reject) => { + $.ajax({ + url: '/tracking/track', + type: 'POST', + data: JSON.stringify(data), + contentType: 'application/json', + }).done((newShapes) => { + resolve(newShapes); + }).fail((errorData) => { + const message = `Could not track shape. Code: ${errorData.status}. ` + + `Message: ${errorData.responseText || errorData.statusText}`; + reject(new Error(message)); + }); + }); + } + + _insertTrackedShapes(shape, trackingJobResult) { + trackingJobResult.shapes.forEach((serverShape) => { + // eslint-disable-next-line no-underscore-dangle + const position = window.cvat.translate._boxTranslator + .serverToClient(serverShape); + position.byMachine = true; + position.outside = serverShape.outside; + position.occluded = serverShape.occluded; + position.z_order = serverShape.z_order; + // Insert/update postion silenty, i.e. wihtout adding to undo + // stack. Complete tracking action is added to undo + shape.updatePosition(serverShape.frame, position, true); + }); + } + + _logStatus() { + console.log( + 'ShapeTrackerModel:', + this.trackerOptions, + this._endFrameMode, + this.maxNrOfFrames, + this._state, + ); + } + } + + class ShapeTrackerController { + constructor(shapeTrackerModel) { + this._shapeTrackerModel = shapeTrackerModel; + } + + trackCurrentSelectedShape() { + this._shapeTrackerModel.trackCurrentSelectedShape(); + } + + setTrackerType(type) { + this._shapeTrackerModel.trackerType = type; + } + + setEndFrame(mode) { + this._shapeTrackerModel.endFrameMode = mode; + } + + setMaxNrOfFrames(frames) { + this._shapeTrackerModel.maxNrOfFrames = frames; + } + } + + class ShapeTrackerView { + constructor(shapeTrackerModel, shapeTrackerController) { + this._shapeTrackerModel = shapeTrackerModel; + this._shapeTrackerController = shapeTrackerController; + this._shapeTrackerModel.subscribe(this); + this._overlay = null; + this._setupSettingGUI(); + this._setupContextMenuGUI(); + this._setupKeybindings(); + } + + /* Setup key bindings.*/ + _setupKeybindings(){ + // TODO: Should the userConfig.js file be modified instead? + // Insert entry for new short cut + window.cvat.config._shortkeys.track = { + value: ['t'], + view_value: 'T', + description: 'Track', + }; + + // Connect shortcuts for actions + const trackHandler = Logger.shortkeyLogDecorator(() => { + this._shapeTrackerController.trackCurrentSelectedShape(); + }).bind(this); + + Mousetrap.bind( + window.cvat.config._shortkeys.track.value, + trackHandler.bind(this), + 'keydown', + ); + } + + /* Create and connect DOM elemtents in settings page. */ + _setupSettingGUI(){ + // Create DOM elements in settings menu + // TODO: Should the annotation.html be modified instead? + // TODO: How does this work with React? + let otherSettings = $('#otherSettigns'); + otherSettings.css('height','60%'); + $(`
+
+ + + + + + + + + +
+ +
+ +
+
+ `).insertAfter(otherSettings); + + // Connect to setting UI + // Bind Tracker type selection + this._shapeTrackerTypeSelect = $('#shapeTrackerTypeSelect'); + this._shapeTrackerTypeSelect.on('change', + e => this._shapeTrackerController.setTrackerType(e.target.value)); + + // Bind Selection of tracker mode for next frame + this._shapeTrackerEndSelect = $('#shapeTrackerEndSelect'); + this._shapeTrackerEndSelect.on('change', (e) => { + const { target: { value } } = e; + if (value.startsWith(EndFrameModes.nextUserKeyframeMax)) { + const maxNrOfFrames = parseInt( + value.replace(EndFrameModes.nextUserKeyframeMax, ''), 10, + ); + + this._shapeTrackerController.setMaxNrOfFrames(maxNrOfFrames); + this._shapeTrackerController.setEndFrame( + EndFrameModes.nextUserKeyframeMax, + ); + } else { + this._shapeTrackerController.setEndFrame(value); + } + }); + + this._updateSetting(); + } + + /* Create and connect DOM elemtents in context menu. */ + _setupContextMenuGUI(){ + // Create DOM elements + let shapeContextMenu = $('#shapeContextMenu'); + $(`
  • Track
  • `).appendTo(shapeContextMenu); + + // Connect menue entry in context menu + $('#shapeContextMenu li').click((e) => { + switch ($(e.target).attr('action')) { + case 'trackShape': + this._shapeTrackerController.trackCurrentSelectedShape(); + break; + default: + break; + } + }); + } + + /* Update setting gui according to model. */ + _updateSetting() { + this._shapeTrackerTypeSelect.val(this._shapeTrackerModel.trackerType); + let trackerEndSelectVal = this._shapeTrackerModel.endFrameMode + === EndFrameModes.nextUserKeyframeMax + ? EndFrameModes.nextUserKeyframeMax + this._shapeTrackerModel.maxNrOfFrames + : this._shapeTrackerModel.endFrameMode; + this._shapeTrackerEndSelect.val(trackerEndSelectVal); + } + + /* Called whenever the shapeTrackerModel changes. */ + onShapeTrackerModelUpdate(state) { + if (state.jobRunning) { + this._overlay = showOverlay('Wait for tracking result ...'); + } else if (this._overlay) { + this._overlay.remove(); + this._overlay = null; + } + } + } + + function buildShapeTracker(shapeCollectionModel) { + let trackerModel = new ShapeTrackerModel(shapeCollectionModel); + let trackerContoller = new ShapeTrackerController(trackerModel); + let trackerView = new ShapeTrackerView(trackerModel, trackerContoller); + + // TODO: remove (only attached to window for easy debugging) + window.trackerModel = trackerModel; + window.trackerController = trackerContoller; + window.trackerView = trackerView; + } + + // Patch the shape collection constructor to get a reference + // to the shape collection + // TODO: Should tracker be integrated into engine instead of a enginePlugin? + function ShapeCollectionModelWrapper(OriginalClass) { + // Constructor will patch some properties for each instance + function constructorDecorator(...args) { + const shapeCollectionModel = new OriginalClass(...args); + buildShapeTracker(shapeCollectionModel); + return shapeCollectionModel; + } + + constructorDecorator.prototype = OriginalClass.prototype; + constructorDecorator.prototype.constructor = constructorDecorator; + return constructorDecorator; + } + + // Apply patch for classes + ShapeCollectionModel = ShapeCollectionModelWrapper(ShapeCollectionModel); +}); + + + + diff --git a/cvat/apps/tracking/tests.py b/cvat/apps/tracking/tests.py new file mode 100644 index 00000000000..d20a46ab6a6 --- /dev/null +++ b/cvat/apps/tracking/tests.py @@ -0,0 +1,8 @@ + +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT + + +# Create your tests here. + diff --git a/cvat/apps/tracking/tracker.py b/cvat/apps/tracking/tracker.py new file mode 100644 index 00000000000..4f470218937 --- /dev/null +++ b/cvat/apps/tracking/tracker.py @@ -0,0 +1,101 @@ +from cvat.apps.engine.models import TrackedShape +import cv2 +import copy + + +def rectanlge_to_cv_bbox(rectangle_points): + """ + Convert the CVAT rectangle points (serverside) to a OpenCV rectangle. + :param tuple rectangle_points: Tuple of form (x1,y1,x2,y2) + :return: Form (x1, y1, width, height) + """ + # Dimensions must be ints, otherwise tracking throws a exception + return (int(rectangle_points[0]), int(rectangle_points[1]), + int(rectangle_points[2] - rectangle_points[0]), + int(rectangle_points[3] - rectangle_points[1])) + +def cv_bbox_to_rectangle(cv_bbox): + """ + Convert the OpenCV bounding box points to a CVAT rectangle points. + :param tuple cv_bbox: Form (x1, y1, width, height) + :return: Form (x1,y1,x2,y2) + """ + return (cv_bbox[0], cv_bbox[1], cv_bbox[0] + cv_bbox[2], cv_bbox[1] + cv_bbox[3]) + +def image_iterable_from_task(task, start_frame, stop_frame): + """ + Create a iterable to iterate over the images from a CVAT task. + :param Task task The Django model of type Task + :param int start_frame: Frame number where iteration should start (included) + :param int stop_frame: First frame that is excluded from iteration (excluded) + :return: Iterable over OpenCV images + """ + for frame in range(start_frame, stop_frame): + image_path = task.get_frame_path(frame) + img = cv2.imread(image_path) + yield frame, img + +class RectangleTracker: + trackerTypes = ['BOOSTING', 'MIL', 'KCF', 'CSRT', 'MEDIANFLOW', 'TLD', + 'MOSSE', 'GOTRUN'] + + def __init__(self, trackerType = "BOOSTING"): + """Create tracker. + :param str trackerType: String specifying tracker, see trackerTypes. + """ + trackerTypes_constructor = { + 'BOOSTING': cv2.TrackerBoosting_create, + 'MIL': cv2.TrackerMIL_create, + 'KCF': cv2.TrackerKCF_create, + 'CSRT': cv2.TrackerCSRT_create, + 'MEDIANFLOW': cv2.TrackerMedianFlow_create, + 'TLD': cv2.TrackerTLD_create, + 'MOSSE': cv2.TrackerMOSSE_create, + 'GOTRUN': cv2.TrackerGOTURN_create, + } + if trackerType not in trackerTypes_constructor: + raise Exception("Tracker type not known:" + trackerType) + self._tracker = trackerTypes_constructor[trackerType]() + + def track_rectangles(self, task, start_shape, stop_frame): + """ + Follow an the rectangle in consecutive frames of a task. + :param Task task: The Django Task with the images + :param TrackedShape start_shape: Start tracking with this shape; This + specifies the frame to start at (start_shape.frame). + :param int stop_frame: Stop tracking at this frame (excluded). + :return: List of Shapes (Rectangles) with new shapes. + """ + if not isinstance(start_shape, TrackedShape): + raise Exception("start_shape must be of type TrackedShape") + + # Only track in to future. + start_frame = start_shape.frame + if stop_frame < start_frame: + return [] + + # Load the image iterable for range of frames + # and init the tracker with the bounding box from the user given shape + images = image_iterable_from_task(task, start_frame, stop_frame) + img0 = next(images)[1] + bbox = rectanlge_to_cv_bbox(start_shape.points) + no_error = self._tracker.init(img0, bbox) + + #Generated shapes + shapes_by_tracking = [] + for frame, img in images: + # Let the tracker find the bounding box in the next image(s) + no_error, bbox = self._tracker.update(img) + + if no_error: + new_shape = copy.copy(start_shape) + new_shape.pk = None + new_shape.points = cv_bbox_to_rectangle(bbox) + new_shape.frame = frame + shapes_by_tracking.append(new_shape) + else: + break + + return shapes_by_tracking + + diff --git a/cvat/apps/tracking/urls.py b/cvat/apps/tracking/urls.py new file mode 100644 index 00000000000..c704a0ddd88 --- /dev/null +++ b/cvat/apps/tracking/urls.py @@ -0,0 +1,11 @@ + +# Copyright (C) 2018 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from django.urls import path +from . import views + +urlpatterns = [ + path('track', views.track), +] diff --git a/cvat/apps/tracking/views.py b/cvat/apps/tracking/views.py new file mode 100644 index 00000000000..99b9652265c --- /dev/null +++ b/cvat/apps/tracking/views.py @@ -0,0 +1,63 @@ + +# Copyright (C) 2018-2020 Intel Corporation +# +# SPDX-License-Identifier: MIT +import copy + +from django.http import JsonResponse +from rest_framework.decorators import api_view +from rules.contrib.views import permission_required + +from cvat.apps.authentication.decorators import login_required +from cvat.apps.engine.data_manager import TrackManager +from cvat.apps.engine.models import (Job, TrackedShape) +from cvat.apps.engine.serializers import (TrackedShapeSerializer) +from .tracker import RectangleTracker + +# TODO: Put tracker into background task. + +@api_view(['POST']) +@login_required +@permission_required(perm=['engine.task.access'], raise_exception=True) +def track(request): + # Track (with bounding boxes) that should be enriched with new shapes. + # Done by tracking a existing bounding box + + tracking_job = request.data['trackingJob'] + job_id = request.data['jobId'] + track = tracking_job['track'] #already in server model + # Start the tracking with the bounding box in this frame + start_frame = tracking_job['startFrame'] + # Until track this bounding box until this frame (excluded) + stop_frame = tracking_job['stopFrame'] + # Track the bounding boxes in images from this track + task = Job.objects.get(pk=job_id).segment.task + + # If we in a large task this creates unnessary many shapes + # We only need them between start_frame and stop_frame + shapes_of_track = TrackManager([tracking_job['track']]).to_shapes( + stop_frame) + first_frame_in_track = shapes_of_track[0]['frame'] + + def shape_to_db(tracked_shape_on_wire): + s = copy.copy(tracked_shape_on_wire) + s.pop('group', 0) + s.pop('attributes', 0) + s.pop('label_id', 0) + s.pop('byMachine', 0) + s.pop('keyframe') + return TrackedShape(**s) + + # This bounding box is used as a reference for tracking + start_shape = shape_to_db(shapes_of_track[start_frame-first_frame_in_track]) + + # Do the actual tracking and serializee back + tracker = RectangleTracker() + new_shapes = tracker.track_rectangles(task, start_shape, stop_frame) + new_shapes = [TrackedShapeSerializer(s).data for s in new_shapes] + + # Pack recognized shape in a track onto the wire + track_with_new_shapes = copy.copy(track) + track_with_new_shapes['shapes'] = new_shapes + + return JsonResponse(track_with_new_shapes) \ No newline at end of file diff --git a/cvat/requirements/base.txt b/cvat/requirements/base.txt index 7b688b34ec1..d47ac1f78ae 100644 --- a/cvat/requirements/base.txt +++ b/cvat/requirements/base.txt @@ -43,6 +43,7 @@ scikit-image==0.15.0 tensorflow==1.15.0 keras==2.2.5 opencv-python==4.1.0.25 +opencv-contrib-python==4.1.2.30 h5py==2.9.0 imgaug==0.2.9 django-cors-headers==3.2.0 diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 08ae7e2b4e0..a90ae367b05 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -171,6 +171,12 @@ def generate_ssh_keys(): if 'yes' == os.environ.get('AUTO_SEGMENTATION', 'no'): INSTALLED_APPS += ['cvat.apps.auto_segmentation'] +# Enable (shape) tracking if environment variable TRACKING is set to yes +if 'yes' == os.environ.get('TRACKING', 'no'): + INSTALLED_APPS += ['cvat.apps.tracking'] + + + MIDDLEWARE = [ 'django.middleware.security.SecurityMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', diff --git a/cvat/urls.py b/cvat/urls.py index 6ae59f6b03a..5feb4eb338f 100644 --- a/cvat/urls.py +++ b/cvat/urls.py @@ -55,3 +55,7 @@ # new feature by Mohammad if apps.is_installed('cvat.apps.auto_segmentation'): urlpatterns.append(path('tensorflow/segmentation/', include('cvat.apps.auto_segmentation.urls'))) + +# Urls for (shape) tracking +if apps.is_installed('cvat.apps.tracking'): + urlpatterns.append(path('tracking/', include('cvat.apps.tracking.urls')))