Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add propagate shapes action #8044

Merged
merged 17 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Added

- `Propagate shapes` action to create copies of visible shapes on multiple frames forward or backward
(<https://github.com/cvat-ai/cvat/pull/8044>)
2 changes: 1 addition & 1 deletion cvat-core/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "cvat-core",
"version": "15.0.6",
"version": "15.0.7",
"type": "module",
"description": "Part of Computer Vision Tool which presents an interface for client-side integration",
"main": "src/api.ts",
Expand Down
73 changes: 69 additions & 4 deletions cvat-core/src/annotations-actions.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
// Copyright (C) 2023 CVAT.ai Corporation
// Copyright (C) 2023-2024 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import { omit, throttle } from 'lodash';
import { ArgumentError } from './exceptions';
import { SerializedCollection } from './server-response-types';
import { SerializedCollection, SerializedShape } from './server-response-types';
import { Job, Task } from './session';
import { EventScope, ObjectType } from './enums';
import ObjectState from './object-state';
import { getAnnotations, getCollection } from './annotations';
import { propagateShapes } from './object-utils';

export interface SingleFrameActionInput {
collection: Omit<SerializedCollection, 'tracks' | 'tags' | 'version'>;
Expand All @@ -28,12 +29,20 @@ export enum ActionParameterType {
NUMBER = 'number',
}

// For SELECT values should be a list of possible options
// For NUMBER values should be a list with [min, max, step],
// or a callback ({ instance }: { instance: Job | Task }) => [min, max, step]
type ActionParameters = Record<string, {
type: ActionParameterType;
values: string[];
defaultValue: string;
values: string[] | (({ instance }: { instance: Job | Task }) => string[]);
defaultValue: string | (({ instance }: { instance: Job | Task }) => string);
}>;

export enum FrameSelectionType {
SEGMENT = 'segment',
CURRENT_FRAME = 'current_frame',
}

export default class BaseSingleFrameAction {
/* eslint-disable @typescript-eslint/no-unused-vars */
public async init(
Expand All @@ -58,6 +67,10 @@ export default class BaseSingleFrameAction {
public get parameters(): ActionParameters | null {
throw new Error('Method not implemented');
}

public get frameSelection(): FrameSelectionType {
return FrameSelectionType.SEGMENT;
}
}

class RemoveFilteredShapes extends BaseSingleFrameAction {
Expand All @@ -82,6 +95,57 @@ class RemoveFilteredShapes extends BaseSingleFrameAction {
}
}

class PropagateShapes extends BaseSingleFrameAction {
#targetFrame: number;

public async init(instance, parameters): Promise<void> {
this.#targetFrame = parameters['Target frame'];
}

public async destroy(): Promise<void> {
// nothing to destroy
}

public async run(
instance,
{ collection: { shapes }, frameData: { number } },
): Promise<SingleFrameActionOutput> {
if (number === this.#targetFrame) {
return { collection: { shapes } };
}
const propagatedShapes = propagateShapes<SerializedShape>(shapes, number, this.#targetFrame);
return { collection: { shapes: [...shapes, ...propagatedShapes] } };
}

public get name(): string {
return 'Propagate shapes';
}

public get parameters(): ActionParameters | null {
return {
'Target frame': {
type: ActionParameterType.NUMBER,
values: ({ instance }) => {
if (instance instanceof Job) {
return [instance.startFrame, instance.stopFrame, 1].map((val) => val.toString());
}
return [0, instance.size - 1, 1].map((val) => val.toString());
},
defaultValue: ({ instance }) => {
if (instance instanceof Job) {
return instance.stopFrame.toString();
}
return (instance.size - 1).toString();
},
},
};
}

public get frameSelection(): FrameSelectionType {
return FrameSelectionType.CURRENT_FRAME;
}
}

const registeredActions: BaseSingleFrameAction[] = [];

export async function listActions(): Promise<BaseSingleFrameAction[]> {
Expand All @@ -102,6 +166,7 @@ export async function registerAction(action: BaseSingleFrameAction): Promise<voi
}

registerAction(new RemoveFilteredShapes());
registerAction(new PropagateShapes());

async function runSingleFrameChain(
instance: Job | Task,
Expand Down
5 changes: 5 additions & 0 deletions cvat-core/src/annotations-objects.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2187,6 +2187,11 @@ export class MaskShape extends Shape {

constructor(data: SerializedShape, clientID: number, color: string, injection: AnnotationInjection) {
super(data, clientID, color, injection);
const [left, top, right, bottom] = this.points.slice(-4);
const { width, height } = this.frameMeta[this.frame];
bsekachev marked this conversation as resolved.
Show resolved Hide resolved
if (left > width || top > height || right > width || bottom > height) {
klakhov marked this conversation as resolved.
Show resolved Hide resolved
this.points = cropMask(this.points, width, height);
}
[this.left, this.top, this.right, this.bottom] = this.points.splice(-4, 4);
this.getMasksOnFrame = injection.getMasksOnFrame;
this.pinned = true;
Expand Down
3 changes: 2 additions & 1 deletion cvat-core/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
Exception, ArgumentError, DataError, ScriptingError, ServerError,
} from './exceptions';

import { mask2Rle, rle2Mask } from './object-utils';
import { mask2Rle, rle2Mask, propagateShapes } from './object-utils';
import User from './user';
import pjson from '../package.json';
import config from './config';
Expand Down Expand Up @@ -397,6 +397,7 @@ function build(): CVATCore {
utils: {
mask2Rle,
rle2Mask,
propagateShapes,
},
};

Expand Down
3 changes: 2 additions & 1 deletion cvat-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { AnnotationFormats } from './annotation-formats';
import logger from './logger';
import * as enums from './enums';
import config from './config';
import { mask2Rle, rle2Mask } from './object-utils';
import { mask2Rle, rle2Mask, propagateShapes } from './object-utils';
import User from './user';
import Project from './project';
import { Job, Task } from './session';
Expand Down Expand Up @@ -201,5 +201,6 @@ export default interface CVATCore {
utils: {
mask2Rle: typeof mask2Rle;
rle2Mask: typeof rle2Mask;
propagateShapes: typeof propagateShapes;
};
}
63 changes: 62 additions & 1 deletion cvat-core/src/object-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import { DataError, ArgumentError } from './exceptions';
import { Attribute } from './labels';
import { ShapeType, AttributeType } from './enums';
import { ShapeType, AttributeType, ObjectType } from './enums';
import { SerializedShape } from './server-response-types';
import ObjectState, { SerializedData } from './object-state';

export function checkNumberOfPoints(shapeType: ShapeType, points: number[]): void {
if (shapeType === ShapeType.RECTANGLE) {
Expand Down Expand Up @@ -356,3 +358,62 @@ export function rle2Mask(rle: number[], width: number, height: number): number[]

return decoded;
}

export function propagateShapes<T extends SerializedShape | ObjectState>(
shapes: T[], from: number, to: number,
): T[] {
const getCopy = (shape: T): SerializedShape | SerializedData => {
if (shape instanceof ObjectState) {
return {
attributes: shape.attributes,
points: shape.shapeType === 'skeleton' ? null : shape.points,
occluded: shape.occluded,
objectType: shape.objectType !== ObjectType.TRACK ? shape.objectType : ObjectType.SHAPE,
shapeType: shape.shapeType,
label: shape.label,
zOrder: shape.zOrder,
rotation: shape.rotation,
frame: from,
elements: shape.shapeType === 'skeleton' ? shape.elements
.map((element: ObjectState): any => getCopy(element as T)) : [],
source: shape.source,
};
}
return {
attributes: [...shape.attributes.map((attribute) => ({ ...attribute }))],
points: shape.type === 'skeleton' ? null : [...shape.points],
occluded: shape.occluded,
type: shape.type,
label_id: shape.label_id,
z_order: shape.z_order,
rotation: shape.rotation,
frame: from,
elements: shape.type === 'skeleton' ? shape.elements
.map((element: SerializedShape): SerializedShape => getCopy(element as T) as SerializedShape) : [],
source: shape.source,
group: 0,
outside: false,
};
};

const states: T[] = [];
const sign = Math.sign(to - from);
for (let frame = from + sign; sign > 0 ? frame <= to : frame >= to; frame += sign) {
for (const shape of shapes) {
const copy = getCopy(shape);

copy.frame = frame;
copy.elements?.forEach((element: Omit<SerializedShape, 'elements'> | SerializedData): void => {
element.frame = frame;
});

if (shape instanceof ObjectState) {
states.push(new ObjectState(copy as SerializedData) as T);
} else {
states.push(copy as T);
}
}
}

return states;
}
2 changes: 1 addition & 1 deletion cvat-ui/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "cvat-ui",
"version": "1.63.11",
"version": "1.63.12",
"description": "CVAT single-page application",
"main": "src/index.tsx",
"scripts": {
Expand Down
25 changes: 1 addition & 24 deletions cvat-ui/src/actions/annotation-actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -473,31 +473,8 @@ export function propagateObjectAsync(from: number, to: number): ThunkAction {
throw new Error('There is not an activated object state to be propagated');
}

const getCopyFromState = (_objectState: any): any => ({
attributes: _objectState.attributes,
points: _objectState.shapeType === 'skeleton' ? null : _objectState.points,
occluded: _objectState.occluded,
objectType: _objectState.objectType !== ObjectType.TRACK ? _objectState.objectType : ObjectType.SHAPE,
shapeType: _objectState.shapeType,
label: _objectState.label,
zOrder: _objectState.zOrder,
rotation: _objectState.rotation,
frame: from,
elements: _objectState.shapeType === 'skeleton' ? _objectState.elements
.map((element: any): any => getCopyFromState(element)) : [],
source: _objectState.source,
});

const copy = getCopyFromState(objectState);
await sessionInstance.logger.log(EventScope.propagateObject, { count: Math.abs(to - from) });
const states = [];
const sign = Math.sign(to - from);
for (let frame = from + sign; sign > 0 ? frame <= to : frame >= to; frame += sign) {
copy.frame = frame;
copy.elements.forEach((element: any) => { element.frame = frame; });
const newState = new cvat.classes.ObjectState(copy);
states.push(newState);
}
const states = cvat.utils.propagateShapes<ObjectState>([objectState], from, to);

await sessionInstance.annotations.put(states);
const history = await sessionInstance.actions.get();
Expand Down
Loading
Loading