Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): separate face search relation #10371

Merged
merged 7 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions e2e/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,14 +398,7 @@ export const utils = {
return;
}

const vector = Array.from({ length: 512 }, Math.random);
const embedding = `[${vector.join(',')}]`;

await client.query('INSERT INTO asset_faces ("assetId", "personId", "embedding") VALUES ($1, $2, $3)', [
assetId,
personId,
embedding,
]);
await client.query('INSERT INTO asset_faces ("assetId", "personId") VALUES ($1, $2)', [assetId, personId]);
},

setPersonThumbnail: async (personId: string) => {
Expand Down
8 changes: 4 additions & 4 deletions server/src/entities/asset-face.entity.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { AssetEntity } from 'src/entities/asset.entity';
import { FaceSearchEntity } from 'src/entities/face-search.entity';
import { PersonEntity } from 'src/entities/person.entity';
import { Column, Entity, Index, ManyToOne, PrimaryGeneratedColumn } from 'typeorm';
import { Column, Entity, Index, ManyToOne, OneToOne, PrimaryGeneratedColumn } from 'typeorm';

@Entity('asset_faces', { synchronize: false })
@Index('IDX_asset_faces_assetId_personId', ['assetId', 'personId'])
Expand All @@ -15,9 +16,8 @@ export class AssetFaceEntity {
@Column({ nullable: true, type: 'uuid' })
personId!: string | null;

@Index('face_index', { synchronize: false })
@Column({ type: 'float4', array: true, select: false, transformer: { from: (v) => JSON.parse(v), to: (v) => v } })
embedding!: number[];
@OneToOne(() => FaceSearchEntity, (faceSearchEntity) => faceSearchEntity.face, { cascade: ['insert'] })
faceSearch?: FaceSearchEntity;

@Column({ default: 0, type: 'int' })
imageWidth!: number;
Expand Down
21 changes: 21 additions & 0 deletions server/src/entities/face-search.entity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { AssetFaceEntity } from 'src/entities/asset-face.entity';
import { asVector } from 'src/utils/database';
import { Column, Entity, Index, JoinColumn, OneToOne, PrimaryColumn } from 'typeorm';

@Entity('face_search', { synchronize: false })
mertalev marked this conversation as resolved.
Show resolved Hide resolved
export class FaceSearchEntity {
@OneToOne(() => AssetFaceEntity, { onDelete: 'CASCADE', nullable: true })
@JoinColumn({ name: 'faceId', referencedColumnName: 'id' })
face?: AssetFaceEntity;

@PrimaryColumn()
faceId!: string;

@Index('face_index', { synchronize: false })
@Column({
type: 'float4',
array: true,
transformer: { from: (v) => JSON.parse(v), to: (v) => asVector(v) },
})
embedding!: number[];
}
2 changes: 2 additions & 0 deletions server/src/entities/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { AssetStackEntity } from 'src/entities/asset-stack.entity';
import { AssetEntity } from 'src/entities/asset.entity';
import { AuditEntity } from 'src/entities/audit.entity';
import { ExifEntity } from 'src/entities/exif.entity';
import { FaceSearchEntity } from 'src/entities/face-search.entity';
import { GeodataPlacesEntity } from 'src/entities/geodata-places.entity';
import { LibraryEntity } from 'src/entities/library.entity';
import { MemoryEntity } from 'src/entities/memory.entity';
Expand All @@ -34,6 +35,7 @@ export const entities = [
AssetJobStatusEntity,
AuditEntity,
ExifEntity,
FaceSearchEntity,
GeodataPlacesEntity,
MemoryEntity,
MoveEntity,
Expand Down
54 changes: 54 additions & 0 deletions server/src/migrations/1718486162779-AddFaceSearchRelation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { getVectorExtension } from 'src/database.config';
import { DatabaseExtension } from 'src/interfaces/database.interface';
import { MigrationInterface, QueryRunner } from 'typeorm';

export class AddFaceSearchRelation1718486162779 implements MigrationInterface {
public async up(queryRunner: QueryRunner): Promise<void> {
if (getVectorExtension() === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}

await queryRunner.query(`
CREATE TABLE face_search (
"faceId" uuid PRIMARY KEY REFERENCES asset_faces(id) ON DELETE CASCADE,
embedding vector(512) NOT NULL )`);

await queryRunner.query(`ALTER TABLE face_search ALTER COLUMN embedding SET STORAGE EXTERNAL`);
await queryRunner.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET STORAGE EXTERNAL`);

await queryRunner.query(`
INSERT INTO face_search("faceId", embedding)
SELECT id, embedding
FROM asset_faces faces`);

await queryRunner.query(`ALTER TABLE asset_faces DROP COLUMN "embedding"`);

await queryRunner.query(`
CREATE INDEX face_index ON face_search
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
}

public async down(queryRunner: QueryRunner): Promise<void> {
if (getVectorExtension() === DatabaseExtension.VECTORS) {
await queryRunner.query(`SET search_path TO "$user", public, vectors`);
await queryRunner.query(`SET vectors.pgvector_compatibility=on`);
}

await queryRunner.query(`ALTER TABLE asset_faces ADD COLUMN "embedding" vector(512)`);
await queryRunner.query(`ALTER TABLE face_search ALTER COLUMN embedding SET STORAGE DEFAULT`);
await queryRunner.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET STORAGE DEFAULT`);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I missed this line. Nice!

await queryRunner.query(`
UPDATE asset_faces
SET embedding = fs.embedding
FROM face_search fs
WHERE id = fs."faceId"`);
await queryRunner.query(`DROP TABLE face_search`);

await queryRunner.query(`
CREATE INDEX face_index ON asset_faces
USING hnsw (embedding vector_cosine_ops)
WITH (ef_construction = 300, m = 16)`);
}
}
5 changes: 3 additions & 2 deletions server/src/queries/search.repository.sql
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,16 @@ WITH
"faces"."boundingBoxY1" AS "boundingBoxY1",
"faces"."boundingBoxX2" AS "boundingBoxX2",
"faces"."boundingBoxY2" AS "boundingBoxY2",
"faces"."embedding" <= > $1 AS "distance"
"search"."embedding" <= > $1 AS "distance"
FROM
"asset_faces" "faces"
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
AND ("asset"."deletedAt" IS NULL)
INNER JOIN "face_search" "search" ON "search"."faceId" = "faces"."id"
WHERE
"asset"."ownerId" IN ($2)
ORDER BY
"faces"."embedding" <= > $1 ASC
"search"."embedding" <= > $1 ASC
LIMIT
100
)
Expand Down
2 changes: 1 addition & 1 deletion server/src/repositories/database.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ export class DatabaseRepository implements IDatabaseRepository {
} catch (error) {
if (getVectorExtension() === DatabaseExtension.VECTORS) {
this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`);
const table = index === VectorIndex.CLIP ? 'smart_search' : 'asset_faces';
const table = index === VectorIndex.CLIP ? 'smart_search' : 'face_search';
const dimSize = await this.getDimSize(table);
await this.dataSource.manager.transaction(async (manager) => {
await this.setSearchPath(manager);
Expand Down
7 changes: 2 additions & 5 deletions server/src/repositories/person.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import {
PersonStatistics,
UpdateFacesData,
} from 'src/interfaces/person.interface';
import { asVector } from 'src/utils/database';
import { Instrumentation } from 'src/utils/instrumentation';
import { Paginated, PaginationOptions, paginate } from 'src/utils/pagination';
import { FindManyOptions, FindOptionsRelations, FindOptionsSelect, In, Repository } from 'typeorm';
Expand Down Expand Up @@ -249,10 +248,8 @@ export class PersonRepository implements IPersonRepository {
}

async createFaces(entities: AssetFaceEntity[]): Promise<string[]> {
const res = await this.assetFaceRepository.insert(
entities.map((entity) => ({ ...entity, embedding: () => asVector(entity.embedding, true) })),
);
return res.identifiers.map((row) => row.id);
const res = await this.assetFaceRepository.save(entities);
return res.map((row) => row.id);
}

async update(entity: Partial<PersonEntity>): Promise<PersonEntity> {
Expand Down
5 changes: 3 additions & 2 deletions server/src/repositories/search.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ export class SearchRepository implements ISearchRepository {
await this.assetRepository.manager.transaction(async (manager) => {
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('faces.embedding <=> :embedding', 'distance')
.select('search.embedding <=> :embedding', 'distance')
.innerJoin('faces.asset', 'asset')
.innerJoin('faces.faceSearch', 'search')
.where('asset.ownerId IN (:...userIds )')
.orderBy('faces.embedding <=> :embedding')
.orderBy('search.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });

cte.limit(numResults);
Expand Down
5 changes: 4 additions & 1 deletion server/src/services/person.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -668,15 +668,18 @@ describe(PersonService.name, () => {
machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
const faceId = 'face-id';
cryptoMock.randomUUID.mockReturnValue(faceId);
const face = {
id: faceId,
assetId: 'asset-id',
embedding: [1, 2, 3, 4],
boundingBoxX1: 100,
boundingBoxY1: 100,
boundingBoxX2: 200,
boundingBoxY2: 200,
imageHeight: 500,
imageWidth: 400,
faceSearch: { faceId, embedding: [1, 2, 3, 4] },
};

await sut.handleDetectFaces({ id: assetStub.image.id });
Expand Down
41 changes: 26 additions & 15 deletions server/src/services/person.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
mapFaces,
mapPerson,
} from 'src/dtos/person.dto';
import { AssetFaceEntity } from 'src/entities/asset-face.entity';
import { AssetEntity, AssetType } from 'src/entities/asset.entity';
import { PersonPathType } from 'src/entities/move.entity';
import { PersonEntity } from 'src/entities/person.entity';
Expand Down Expand Up @@ -70,7 +71,7 @@ export class PersonService {
@Inject(IStorageRepository) private storageRepository: IStorageRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISearchRepository) private smartInfoRepository: ISearchRepository,
@Inject(ICryptoRepository) cryptoRepository: ICryptoRepository,
@Inject(ICryptoRepository) private cryptoRepository: ICryptoRepository,
@Inject(ILoggerRepository) private logger: ILoggerRepository,
) {
this.access = AccessCore.create(accessRepository);
Expand Down Expand Up @@ -347,16 +348,21 @@ export class PersonService {

if (faces.length > 0) {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
const mappedFaces = faces.map((face) => ({
assetId: asset.id,
embedding: face.embedding,
imageHeight,
imageWidth,
boundingBoxX1: face.boundingBox.x1,
boundingBoxY1: face.boundingBox.y1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
}));
const mappedFaces: Partial<AssetFaceEntity>[] = [];
for (const face of faces) {
const faceId = this.cryptoRepository.randomUUID();
mappedFaces.push({
id: faceId,
assetId: asset.id,
imageHeight,
imageWidth,
boundingBoxX1: face.boundingBox.x1,
boundingBoxY1: face.boundingBox.y1,
boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
faceSearch: { faceId, embedding: face.embedding },
});
}

const faceIds = await this.repository.createFaces(mappedFaces);
await this.jobRepository.queueAll(faceIds.map((id) => ({ name: JobName.FACIAL_RECOGNITION, data: { id } })));
Expand Down Expand Up @@ -409,22 +415,27 @@ export class PersonService {

const face = await this.repository.getFaceByIdWithAssets(
id,
{ person: true, asset: true },
{ id: true, personId: true, embedding: true },
{ person: true, asset: true, faceSearch: true },
{ id: true, personId: true, faceSearch: { embedding: true } },
);
if (!face || !face.asset) {
this.logger.warn(`Face ${id} not found`);
return JobStatus.FAILED;
}

if (!face.faceSearch?.embedding) {
this.logger.warn(`Face ${id} does not have an embedding`);
return JobStatus.FAILED;
}

if (face.personId) {
this.logger.debug(`Face ${id} already has a person assigned`);
return JobStatus.SKIPPED;
}

const matches = await this.smartInfoRepository.searchFaces({
userIds: [face.asset.ownerId],
embedding: face.embedding,
embedding: face.faceSearch.embedding,
maxDistance: machineLearning.facialRecognition.maxDistance,
numResults: machineLearning.facialRecognition.minFaces,
});
Expand All @@ -448,7 +459,7 @@ export class PersonService {
if (!personId) {
const matchWithPerson = await this.smartInfoRepository.searchFaces({
userIds: [face.asset.ownerId],
embedding: face.embedding,
embedding: face.faceSearch.embedding,
maxDistance: machineLearning.facialRecognition.maxDistance,
numResults: 1,
hasPerson: true,
Expand Down
Loading
Loading