Skip to content

Commit

Permalink
[Feature] refactored for efficiency (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucksp authored Dec 6, 2024
1 parent cdf40ae commit 21490b4
Showing 1 changed file with 86 additions and 59 deletions.
145 changes: 86 additions & 59 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,54 @@ interface Props {
* @param recognizedObjectsCount: just to figure out what the next trackingObject key should be.
*/
export class FrameRecognition {
trackingObjects: Record<string, RecognizedObject>;
recognizedObjectsCount: number;
height: number;
width: number;
recognitionCount: number;
scoreThreshold: number;
private static readonly MAX_HISTORY_SIZE = 50;
private static readonly MAX_MISSED_COUNT = 30;
private static readonly TRACKING_THRESHOLD = 0.03;

private trackingObjects: Map<number, RecognizedObject>;
private recognizedObjectsCount: number;
private readonly height: number;
private readonly width: number;
private readonly recognitionCount: number;
private readonly scoreThreshold: number;

constructor({ height, width, recognitionCount = 20, score = 0.25 }: Props) {
this.height = height;
this.width = width;
this.trackingObjects = {};
this.trackingObjects = new Map();
this.recognizedObjectsCount = 0;
this.recognitionCount = recognitionCount;
this.scoreThreshold = score;
}

// Memoize and optimize confident object retrieval
private _cachedConfidentObject: RecognizedObject | null = null;
private _lastConfidentObjectCheck = 0;

/**
* getConfidentObject are the most confident about.
* It finds the object with the highest score & count.
* @returns RecognizedObject
* @todo allow more than 1 output item
*/
getConfidentObject() {
getConfidentObject(): RecognizedObject | null {
const currentTime = Date.now();
if (
this._cachedConfidentObject &&
currentTime - this._lastConfidentObjectCheck < 100
) {
return this._cachedConfidentObject;
}

let topObject: RecognizedObject | null = null;

for (const recognized of Object.values(this.trackingObjects)) {
for (const recognized of this.trackingObjects.values()) {
if (recognized.recognitionCount < this.recognitionCount) {
continue;
}

if (!topObject) {
topObject = recognized;
} else if (
if (
!topObject ||
recognized.score > topObject.score ||
(recognized.score === topObject.score &&
recognized.recognitionCount > topObject.recognitionCount)
Expand All @@ -76,6 +91,8 @@ export class FrameRecognition {
}
}

this._cachedConfidentObject = topObject;
this._lastConfidentObjectCheck = currentTime;
return topObject;
}

Expand All @@ -92,7 +109,7 @@ export class FrameRecognition {
labels: { [x: string]: number },
scores: number[]
) {
const missedKeys = new Set(Object.keys(this.trackingObjects));
const missedKeys = new Set(this.trackingObjects.keys());

for (let i = 0; i < Object.keys(labels).length; i++) {
const label = labels[i];
Expand All @@ -109,47 +126,58 @@ export class FrameRecognition {
bottom: boundingBoxes[i * 4 + 3],
};

// Optimize bounding box transformation
boundingBox.top =
0.5 - ((0.5 - boundingBox.top) * this.width) / this.height;
boundingBox.bottom =
0.5 - ((0.5 - boundingBox.bottom) * this.width) / this.height;

// Convert to number key to match Map type
const recognizedKey = this.findTrackedObjectKey(boundingBox, label);
if (!recognizedKey) {
if (recognizedKey === null) {
this.addNewObject(boundingBox, label, score);
} else {
this.updateTrackingObject(recognizedKey, boundingBox, score);
missedKeys.delete(recognizedKey);
}
}

// Optimize penalty application
for (const missedKey of missedKeys) {
this.penalizeObject(missedKey);
}

// Invalidate cached confident object
this._cachedConfidentObject = null;
}

/**
* addNewObject() adds new record to trackingObjects
*/
addNewObject(boundingBox: BoundingBox, label: number, score: number) {
const newId = this.recognizedObjectsCount + 1;
this.trackingObjects[newId] = {
private addNewObject(boundingBox: BoundingBox, label: number, score: number) {
// Increment first to start with 1 as the first key
this.recognizedObjectsCount++;
this.trackingObjects.set(this.recognizedObjectsCount, {
label,
boundingBox,
boundingBoxes: [boundingBox],
score,
scores: [score],
recognitionCount: 1,
missedCount: 0,
};
this.recognizedObjectsCount = newId;
});
}

/**
* updateTrackingObject() updates existing record in the trackingObjects. It updates the averages for score, bounding box position.
*/
updateTrackingObject(key: string, boundingBox: BoundingBox, score: number) {
const trackedObject = this.trackingObjects[key];
private updateTrackingObject(
key: number,
boundingBox: BoundingBox,
score: number
) {
const trackedObject = this.trackingObjects.get(key);
if (!trackedObject) return;

this.updateAveragedArray(trackedObject.scores, score);
trackedObject.score = this.getAverage(trackedObject.scores);
Expand All @@ -165,12 +193,14 @@ export class FrameRecognition {
/**
* penalizeObject(): called when the object with the given key is not recognized in the latest frame data. It adds to the missedCount so that if the threshold is reached, the entire key of trackingObject will be removed.
*/
penalizeObject(key: string) {
const trackedObject = this.trackingObjects[key];
private penalizeObject(key: number) {
const trackedObject = this.trackingObjects.get(key);
if (!trackedObject) return;

trackedObject.missedCount++;

if (trackedObject.missedCount >= 30) {
delete this.trackingObjects[key];
if (trackedObject.missedCount >= FrameRecognition.MAX_MISSED_COUNT) {
this.trackingObjects.delete(key);
}
}

Expand All @@ -180,25 +210,24 @@ export class FrameRecognition {
* @param label
* @returns key from trackingObjects
*/
findTrackedObjectKey(boundingBox: BoundingBox, label: number): string | null {
const threshold = 0.03;

for (const [key, tracked] of Object.entries(this.trackingObjects)) {
private findTrackedObjectKey(
boundingBox: BoundingBox,
label: number
): number | null {
for (const [key, tracked] of this.trackingObjects.entries()) {
if (tracked.label !== label) continue;

const topDelta = boundingBox.top - tracked.boundingBox.top;
const bottomDelta = boundingBox.bottom - tracked.boundingBox.bottom;
const leftDelta = boundingBox.left - tracked.boundingBox.left;
const rightDelta = boundingBox.right - tracked.boundingBox.right;

if (
Math.abs(topDelta) <= threshold &&
Math.abs(bottomDelta) <= threshold &&
Math.abs(leftDelta) <= threshold &&
Math.abs(rightDelta) <= threshold
) {
return key;
}
const isClose =
Math.abs(boundingBox.top - tracked.boundingBox.top) <=
FrameRecognition.TRACKING_THRESHOLD &&
Math.abs(boundingBox.bottom - tracked.boundingBox.bottom) <=
FrameRecognition.TRACKING_THRESHOLD &&
Math.abs(boundingBox.left - tracked.boundingBox.left) <=
FrameRecognition.TRACKING_THRESHOLD &&
Math.abs(boundingBox.right - tracked.boundingBox.right) <=
FrameRecognition.TRACKING_THRESHOLD;

if (isClose) return key;
}

return null;
Expand All @@ -207,32 +236,29 @@ export class FrameRecognition {
/**
* updatingAveragedArray() adds a new element to the array, based on FIFO logic if over the threshold.
*/
updateAveragedArray<T>(values: T[], newValue: T) {
if (values.length >= 50) {
private updateAveragedArray<T>(values: T[], newValue: T): void {
if (values.length >= FrameRecognition.MAX_HISTORY_SIZE) {
values.shift();
}
values.push(newValue);
}

getAverage(values: number[]) {
const sum = values.reduce((acc, value) => acc + value, 0);
return sum / values.length;
private getAverage(values: number[]): number {
return values.reduce((acc, value) => acc + value, 0) / values.length;
}

getAverageBoundingBox(boxes: BoundingBox[]) {
private getAverageBoundingBox(boxes: BoundingBox[]): BoundingBox {
const count = boxes.length;
const total = boxes.reduce(
(acc, box) => {
acc.top += box.top;
acc.left += box.left;
acc.bottom += box.bottom;
acc.right += box.right;
return acc;
},
(acc, box) => ({
top: acc.top + box.top,
left: acc.left + box.left,
bottom: acc.bottom + box.bottom,
right: acc.right + box.right,
}),
{ top: 0, left: 0, bottom: 0, right: 0 }
);

const count = boxes.length;

return {
top: total.top / count,
left: total.left / count,
Expand All @@ -241,8 +267,9 @@ export class FrameRecognition {
};
}

resetAll() {
this.trackingObjects = {};
resetAll(): void {
this.trackingObjects.clear();
this.recognizedObjectsCount = 0;
this._cachedConfidentObject = null;
}
}

0 comments on commit 21490b4

Please sign in to comment.