Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed typescript issues in 'tools-control.tsx' #7785

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { AIToolsIcon } from 'icons';
import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper';
import {
getCore, Label, MLModel, ObjectState, Job,
LabelType,
} from 'cvat-core-wrapper';
import openCVWrapper, { MatType } from 'utils/opencv-wrapper/opencv-wrapper';
import {
Expand Down Expand Up @@ -57,7 +58,7 @@ interface StateToProps {
canvasInstance: Canvas;
labels: Label[];
states: ObjectState[];
activeLabelID: number;
activeLabelID: number | null;
jobInstance: Job;
isActivated: boolean;
frame: number;
Expand Down Expand Up @@ -114,7 +115,7 @@ function mapStateToProps(state: CombinedState): StateToProps {
labels,
states,
canvasInstance: canvasInstance as Canvas,
jobInstance,
jobInstance: jobInstance as Job,
frame,
curZOrder,
defaultApproxPolyAccuracy,
Expand Down Expand Up @@ -142,7 +143,7 @@ interface TrackedShape {

interface State {
activeInteractor: MLModel | null;
activeLabelID: number;
activeLabelID: number | null;
activeTracker: MLModel | null;
convertMasksToPolygons: boolean;
trackedShapes: TrackedShape[];
Expand All @@ -153,6 +154,10 @@ interface State {
portals: React.ReactPortal[];
}

type InteractorResults = Extract<Awaited<ReturnType<typeof core.lambda.call>>, { mask: number[][] }>;
type TrackerResults = Extract<Awaited<ReturnType<typeof core.lambda.call>>, { states: any[]; shapes: number[][] }>;
type DetectedShapes = Extract<Awaited<ReturnType<typeof core.lambda.call>>, { length: number }>;

function trackedRectangleMapper(shape: number[]): number[] {
return shape.reduce(
(acc: number[], value: number, index: number): number[] => {
Expand Down Expand Up @@ -211,7 +216,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
isAborted: boolean;
latestResponse: {
rle: number[];
points: number[][];
points: [number, number][];
bounds?: [number, number, number, number];
};
lastestApproximatedPoints: number[][];
Expand All @@ -232,7 +237,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
convertMasksToPolygons: false,
activeInteractor: props.interactors.length ? props.interactors[0] : null,
activeTracker: props.trackers.length ? props.trackers[0] : null,
activeLabelID: props.labels.length ? props.labels[0].id : null,
activeLabelID: props.labels.length ? props.labels[0].id as number : null,
approxPolyAccuracy: props.defaultApproxPolyAccuracy,
trackedShapes: [],
fetching: false,
Expand Down Expand Up @@ -374,11 +379,12 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
try {
// run server request
this.setState({ fetching: true });

const response = await core.lambda.call(
jobInstance.taskId,
interactor,
{ ...data, job: jobInstance.id },
);
) as InteractorResults;

// if only mask presented, let's receive points
if (response.mask && !response.points) {
Expand All @@ -388,7 +394,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
}

// approximation with cv.approxPolyDP
const approximated = await this.approximateResponsePoints(response.points);
const approximated = await this.approximateResponsePoints(response.points as [number, number][]);
const rle = core.utils.mask2Rle(response.mask.flat());
if (response.bounds) {
rle.push(...response.bounds);
Expand All @@ -398,21 +404,19 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
rle.push(0, 0, width - 1, height - 1);
}

response.mask = rle;

if (this.interaction.id !== interactionId || this.interaction.isAborted) {
// new interaction session or the session is aborted
return;
}

this.interaction.latestResponse = {
bounds: response.bounds,
points: response.points,
points: response.points as [number, number][],
rle,
};
this.interaction.lastestApproximatedPoints = approximated;

this.setState({ pointsReceived: !!response.points.length });
this.setState({ pointsReceived: !!response.points?.length });
} finally {
if (this.interaction.id === interactionId && this.interaction.hideMessage) {
this.interaction.hideMessage();
Expand Down Expand Up @@ -482,16 +486,15 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
};

private onTracking = async (e: Event): Promise<void> => {
const { trackedShapes, activeTracker } = this.state;
const { trackedShapes, activeTracker, activeLabelID } = this.state;
const {
isActivated, jobInstance, frame, curZOrder, fetchAnnotations,
} = this.props;

if (!isActivated) {
if (!isActivated || !activeLabelID) {
return;
}

const { activeLabelID } = this.state;
const [label] = jobInstance.labels.filter((_label: any): boolean => _label.id === activeLabelID);

const { isDone, shapesUpdated } = (e as CustomEvent).detail;
Expand Down Expand Up @@ -694,18 +697,19 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
const {
serverlessState, shapePoints, clientID, trackerModel,
} = trackedShape;
const [clientState] = objectStates.filter((_state: any): boolean => _state.clientID === clientID);
const clientState = objectStates.find((_state): boolean => _state.clientID === clientID);
const keyframes = clientState?.keyframes;

if (
!clientState ||
clientState.keyframes.prev !== frame - 1 ||
clientState.keyframes.last >= frame
!clientState || !keyframes ||
keyframes?.prev !== frame - 1 ||
(typeof keyframes?.last === 'number' && keyframes?.last >= frame)
) {
return acc;
}

if (clientState && !clientState.outside) {
const { points } = clientState;
const points = clientState.points as number[];
withServerRequest = true;
const stateIsRelevant =
serverlessState !== null &&
Expand Down Expand Up @@ -762,12 +766,12 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
duration: 0,
className: 'cvat-tracking-notice',
});
// eslint-disable-next-line no-await-in-loop

const response = await core.lambda.call(jobInstance.taskId, tracker, {
frame: frame - 1,
shapes: trackableObjects.shapes,
job: jobInstance.id,
});
}) as TrackerResults;

const { states: serverlessStates } = response;
const statefullContainer = trackingData.statefull[trackerID] || {
Expand Down Expand Up @@ -816,7 +820,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
shapes: trackableObjects.shapes,
states: trackableObjects.states,
job: jobInstance.id,
});
}) as TrackerResults;

response.shapes = response.shapes.map(trackedRectangleMapper);
for (let i = 0; i < trackableObjects.clientIDs.length; i++) {
Expand Down Expand Up @@ -865,7 +869,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
frame,
objectType: ObjectType.SHAPE,
source: core.enums.Source.SEMI_AUTO,
label: labels.length ? labels.filter((label: any) => label.id === activeLabelID)[0] : null,
label: labels.find((label) => label.id === activeLabelID as number) as Label,
shapeType: ShapeType.POLYGON,
points: this.interaction.lastestApproximatedPoints.flat(),
occluded: false,
Expand All @@ -878,7 +882,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
frame,
objectType: ObjectType.SHAPE,
source: core.enums.Source.SEMI_AUTO,
label: labels.length ? labels.filter((label: any) => label.id === activeLabelID)[0] : null,
label: labels.find((label) => label.id === activeLabelID as number) as Label,
shapeType: ShapeType.MASK,
points: this.interaction.latestResponse.rle,
occluded: false,
Expand Down Expand Up @@ -910,19 +914,19 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
mask: number[][],
left: number,
top: number,
): Promise<number[]> {
): Promise<[number, number][]> {
await this.initializeOpenCV();

const src = openCVWrapper.mat.fromData(mask[0].length, mask.length, MatType.CV_8UC1, mask.flat());
try {
const polygons = openCVWrapper.contours.findContours(src, true);
return polygons[0].map((val: number, idx: number) => {
return polygons[0].reduce<[number, number][]>((acc, _, idx, array) => {
if (idx % 2) {
return val + top;
acc.push([array[idx - 1] + left, array[idx] + top]);
}

return val + left;
});
return acc;
}, []);
} finally {
src.delete();
}
Expand Down Expand Up @@ -1028,9 +1032,9 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
className='cvat-tools-track-button'
disabled={!activeTracker || fetching || frame === jobInstance.stopFrame}
onClick={() => {
this.setState({ mode: 'tracking' });
if (activeTracker && activeLabelID) {
this.setState({ mode: 'tracking' });

if (activeTracker) {
canvasInstance.cancel();
canvasInstance.interact({
shapeType: 'rectangle',
Expand All @@ -1052,7 +1056,9 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
}

private renderInteractorBlock(): JSX.Element {
const { interactors, canvasInstance, onInteractionStart } = this.props;
const {
interactors, canvasInstance, labels, onInteractionStart,
} = this.props;
const {
activeInteractor, activeLabelID, fetching,
} = this.state;
Expand Down Expand Up @@ -1123,9 +1129,9 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
fetching ||
activeInteractor.version < MIN_SUPPORTED_INTERACTOR_VERSION}
onClick={() => {
this.setState({ mode: 'interaction' });
if (activeInteractor && activeLabelID && labels.length) {
this.setState({ mode: 'interaction' });

if (activeInteractor) {
canvasInstance.cancel();
activeInteractor.onChangeToolsBlockerState = this.onChangeToolsBlockerState;
canvasInstance.interact({
Expand Down Expand Up @@ -1228,21 +1234,10 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
this.setState({ mode: 'detection', fetching: true });
const result = await core.lambda.call(jobInstance.taskId, model, {
...body, frame, job: jobInstance.id,
});

type SerializedShape = {
type: ShapeType;
rotation?: number;
attributes: { name: string; value: string }[];
label: string;
outside?: boolean;
points?: number[];
mask?: number[];
elements: SerializedShape[];
};
}) as DetectedShapes;

const states = result.map(
(data: SerializedShape): ObjectState | null => {
(data): ObjectState | null => {
const jobLabel = jobInstance.labels
.find((jLabel: Label): boolean => jLabel.name === data.label);

Expand All @@ -1261,7 +1256,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
zOrder: curZOrder,
};

if (data.type === ShapeType.SKELETON && jobLabel.type === ShapeType.SKELETON) {
if (data.type === ShapeType.SKELETON && jobLabel.type === LabelType.SKELETON) {
// find a center of the skeleton
// to set this center as outside points position
const center = data.elements.reduce<[number, number]>((acc, { points }) => {
Expand All @@ -1276,7 +1271,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
return {
label: sublabel,
objectType: ObjectType.SHAPE,
shapeType: sublabel.type,
shapeType: sublabel.type as any as ShapeType,
attributes: {},
frame,
source: core.enums.Source.AUTO,
Expand All @@ -1289,7 +1284,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
outside: !!element.outside || false,
} : {}),
};
}).map((elementData) => new core.classes.ObjectState({ ...elementData }));
});

if (elements.every((element) => element.outside)) {
return null;
Expand All @@ -1306,20 +1301,23 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
if (data.type === 'mask' && data.points && body.convMaskToPoly) {
return new core.classes.ObjectState({
...objectData,
shapeType: 'polygon',
shapeType: ShapeType.POLYGON,
points: data.points,
});
}

if (data.type === 'mask') {
const [left, top, right, bottom] = (data.mask as number[]).splice(-4);
const rle = core.utils.mask2Rle(data.mask);
rle.push(left, top, right, bottom);
return new core.classes.ObjectState({
...objectData,
shapeType: data.type,
points: rle,
});
if (data.mask) {
const [left, top, right, bottom] = data.mask.splice(-4);
const rle = core.utils.mask2Rle(data.mask);
rle.push(left, top, right, bottom);
return new core.classes.ObjectState({
...objectData,
shapeType: data.type,
points: rle,
});
}
return null;
}

return new core.classes.ObjectState({
Expand Down
Loading