From ac944b03fef8fc415a7515ee74203446d3aa8625 Mon Sep 17 00:00:00 2001 From: Jay Wang Date: Fri, 2 Feb 2024 00:48:19 -0500 Subject: [PATCH] Add embedding worker Signed-off-by: Jay Wang --- .../components/prompt-panel/prompt-panel.ts | 50 ++++++++++++++++++- .../rag-playground/src/workers/embedding.ts | 48 ++++++++++++++---- 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/examples/rag-playground/src/components/prompt-panel/prompt-panel.ts b/examples/rag-playground/src/components/prompt-panel/prompt-panel.ts index 9f5e78e..61af1fc 100644 --- a/examples/rag-playground/src/components/prompt-panel/prompt-panel.ts +++ b/examples/rag-playground/src/components/prompt-panel/prompt-panel.ts @@ -1,7 +1,9 @@ import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit'; import { customElement, property, state, query } from 'lit/decorators.js'; import { unsafeHTML } from 'lit/directives/unsafe-html.js'; -import { pipeline } from '@xenova/transformers'; +import { EmbeddingModel } from '../../workers/embedding'; + +import type { EmbeddingWorkerMessage } from '../../workers/embedding'; import componentCSS from './prompt-panel.css?inline'; import EmbeddingWorkerInline from '../../workers/embedding?worker&inline'; @@ -17,12 +19,25 @@ export class MememoPromptPanel extends LitElement { //==========================================================================|| embeddingWorker: Worker; + embeddingWorkerRequestCount = 0; + + get embeddingWorkerRequestID() { + this.embeddingWorkerRequestCount++; + return `prompt-panel-${this.embeddingWorkerRequestCount}`; + } + //==========================================================================|| // Lifecycle Methods || //==========================================================================|| constructor() { super(); this.embeddingWorker = new EmbeddingWorkerInline(); + this.embeddingWorker.addEventListener( + 'message', + (e: MessageEvent) => { + this.embeddingWorkerMessageHandler(e); + } + ); } firstUpdated() { @@ -40,12 +55,43 @@ export class MememoPromptPanel extends LitElement { //==========================================================================|| async initData() {} - async getEmbedding() {} + getEmbedding() { + const message: EmbeddingWorkerMessage = { + command: 'startExtractEmbedding', + payload: { + detail: '', + requestID: this.embeddingWorkerRequestID, + model: EmbeddingModel.gteSmall, + sentences: ['Hello, how are you', 'yo'] + } + }; + this.embeddingWorker.postMessage(message); + } //==========================================================================|| // Event Handlers || //==========================================================================|| + embeddingWorkerMessageHandler(e: MessageEvent) { + switch (e.data.command) { + case 'finishExtractEmbedding': { + const embeddings = e.data.payload.embeddings; + console.log(embeddings); + break; + } + + case 'error': { + console.error('Worker error: ', e.data.payload.message); + break; + } + + default: { + console.error('Worker: unknown message', e.data.command); + break; + } + } + } + //==========================================================================|| // Private Helpers || //==========================================================================|| diff --git a/examples/rag-playground/src/workers/embedding.ts b/examples/rag-playground/src/workers/embedding.ts index fff61ac..dbd257c 100644 --- a/examples/rag-playground/src/workers/embedding.ts +++ b/examples/rag-playground/src/workers/embedding.ts @@ -10,7 +10,7 @@ export type EmbeddingWorkerMessage = command: 'startExtractEmbedding'; payload: { requestID: string; - text: string; + sentences: string[]; model: EmbeddingModel; detail: string; }; @@ -19,10 +19,10 @@ export type EmbeddingWorkerMessage = command: 'finishExtractEmbedding'; payload: { requestID: string; - text: string; + sentences: string[]; model: EmbeddingModel; detail: string; - embedding: number[]; + embeddings: number[][]; }; } | { @@ -39,26 +39,56 @@ const extractors: Record> = { 'gte-small': pipeline('feature-extraction', 'gte-small') }; +/** + * Helper function to handle calls from the main thread + * @param e Message event + */ +self.onmessage = (e: MessageEvent) => { + switch (e.data.command) { + case 'startExtractEmbedding': { + const { model, sentences, requestID, detail } = e.data.payload; + startExtractEmbedding(model, sentences, requestID, detail); + break; + } + + default: { + console.error('Worker: unknown message', e.data.command); + break; + } + } +}; + /** * Extract embedding from the input text * @param model Embedding model * @param text Input text */ -export const getEmbedding = async ( +export const startExtractEmbedding = async ( model: EmbeddingModel, - text: string, + sentences: string[], requestID: string, detail: string ) => { try { const extractor = await extractors[model]; - const sentences = [text]; const output = await extractor(sentences, { pooling: 'mean', normalize: true }); - const embedding = Array.from(output.data as Float32Array); + const embeddings: number[][] = []; + const flattenEmbedding: number[] = Array.from( + output.data as Float32Array + ); + + // Un-flatten the embedding output + for (let i = 0; i < output.dims[0]; i++) { + const curRow = flattenEmbedding.slice( + i * output.dims[1], + (i + 1) * output.dims[1] + ); + embeddings.push(curRow); + } // Send result to the main thread const message: EmbeddingWorkerMessage = { @@ -67,8 +97,8 @@ export const getEmbedding = async ( model, requestID, detail, - text, - embedding + sentences, + embeddings } }; postMessage(message);