Skip to content

Commit

Permalink
Add embedding worker
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <jay@zijie.wang>
  • Loading branch information
xiaohk committed Feb 2, 2024
1 parent 9054850 commit ac944b0
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit';
import { customElement, property, state, query } from 'lit/decorators.js';
import { unsafeHTML } from 'lit/directives/unsafe-html.js';
import { pipeline } from '@xenova/transformers';
import { EmbeddingModel } from '../../workers/embedding';

import type { EmbeddingWorkerMessage } from '../../workers/embedding';

import componentCSS from './prompt-panel.css?inline';
import EmbeddingWorkerInline from '../../workers/embedding?worker&inline';
Expand All @@ -17,12 +19,25 @@ export class MememoPromptPanel extends LitElement {
//==========================================================================||
embeddingWorker: Worker;

embeddingWorkerRequestCount = 0;

/**
 * Produce a fresh request ID for messages sent to the embedding worker.
 *
 * Reading this accessor has a side effect: it advances
 * `embeddingWorkerRequestCount`, so every read yields a new ID of the form
 * `prompt-panel-<count>`.
 */
get embeddingWorkerRequestID() {
  const nextCount = this.embeddingWorkerRequestCount + 1;
  this.embeddingWorkerRequestCount = nextCount;
  return `prompt-panel-${nextCount}`;
}

//==========================================================================||
// Lifecycle Methods ||
//==========================================================================||
/**
 * Create the panel, spawn the inline embedding worker, and forward every
 * message the worker posts to `embeddingWorkerMessageHandler`.
 */
constructor() {
  super();

  const worker = new EmbeddingWorkerInline();
  this.embeddingWorker = worker;

  // Arrow function keeps `this` bound to the component inside the handler
  worker.addEventListener(
    'message',
    (event: MessageEvent<EmbeddingWorkerMessage>) => {
      this.embeddingWorkerMessageHandler(event);
    }
  );
}

firstUpdated() {
Expand All @@ -40,12 +55,43 @@ export class MememoPromptPanel extends LitElement {
//==========================================================================||
async initData() {}

/**
 * Ask the embedding worker to extract embeddings.
 *
 * Posts a `startExtractEmbedding` message carrying a fresh request ID, the
 * `gteSmall` model, and the sentences to embed. Nothing is returned here:
 * messages posted back by the worker are delivered to
 * `embeddingWorkerMessageHandler` through the listener registered in the
 * constructor.
 *
 * NOTE(review): the sentences are currently hard-coded sample inputs.
 */
getEmbedding() {
  const message: EmbeddingWorkerMessage = {
    command: 'startExtractEmbedding',
    payload: {
      detail: '',
      // Each read of the getter yields a new unique ID
      requestID: this.embeddingWorkerRequestID,
      model: EmbeddingModel.gteSmall,
      sentences: ['Hello, how are you', 'yo']
    }
  };
  this.embeddingWorker.postMessage(message);
}

//==========================================================================||
// Event Handlers ||
//==========================================================================||

/**
 * Dispatch a message posted by the embedding worker.
 *
 * - `finishExtractEmbedding`: logs the extracted embeddings.
 * - `error`: logs the error message reported by the worker.
 * - anything else: logs the unexpected command name.
 *
 * @param e Message event received from the embedding worker
 */
embeddingWorkerMessageHandler(e: MessageEvent<EmbeddingWorkerMessage>) {
  const workerMessage = e.data;

  if (workerMessage.command === 'finishExtractEmbedding') {
    const { embeddings } = workerMessage.payload;
    console.log(embeddings);
    return;
  }

  if (workerMessage.command === 'error') {
    console.error('Worker error: ', workerMessage.payload.message);
    return;
  }

  // Any other command is not expected on the main thread
  console.error('Worker: unknown message', workerMessage.command);
}

//==========================================================================||
// Private Helpers ||
//==========================================================================||
Expand Down
48 changes: 39 additions & 9 deletions examples/rag-playground/src/workers/embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export type EmbeddingWorkerMessage =
command: 'startExtractEmbedding';
payload: {
requestID: string;
text: string;
sentences: string[];
model: EmbeddingModel;
detail: string;
};
Expand All @@ -19,10 +19,10 @@ export type EmbeddingWorkerMessage =
command: 'finishExtractEmbedding';
payload: {
requestID: string;
text: string;
sentences: string[];
model: EmbeddingModel;
detail: string;
embedding: number[];
embeddings: number[][];
};
}
| {
Expand All @@ -39,26 +39,56 @@ const extractors: Record<EmbeddingModel, Promise<FeatureExtractionPipeline>> = {
'gte-small': pipeline('feature-extraction', 'gte-small')
};

/**
 * Entry point for messages posted to this worker from the main thread.
 * A `startExtractEmbedding` request is forwarded to `startExtractEmbedding`;
 * any other command is reported as an error on the console.
 * @param e Message event
 */
self.onmessage = (e: MessageEvent<EmbeddingWorkerMessage>) => {
  const request = e.data;

  // Guard: only extraction requests are handled by the worker
  if (request.command !== 'startExtractEmbedding') {
    console.error('Worker: unknown message', request.command);
    return;
  }

  const { model, sentences, requestID, detail } = request.payload;
  startExtractEmbedding(model, sentences, requestID, detail);
};

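/**
 * Extract embeddings from the input sentences
 * @param model Embedding model
 * @param sentences Input sentences
 * @param requestID ID of this request, included in the result message
 * @param detail Extra detail string, included in the result message
 */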
export const getEmbedding = async (
export const startExtractEmbedding = async (
model: EmbeddingModel,
text: string,
sentences: string[],
requestID: string,
detail: string
) => {
try {
const extractor = await extractors[model];
const sentences = [text];
const output = await extractor(sentences, {
pooling: 'mean',
normalize: true
});

const embedding = Array.from<number>(output.data as Float32Array);
const embeddings: number[][] = [];
const flattenEmbedding: number[] = Array.from<number>(
output.data as Float32Array
);

// Un-flatten the embedding output
for (let i = 0; i < output.dims[0]; i++) {
const curRow = flattenEmbedding.slice(
i * output.dims[1],
(i + 1) * output.dims[1]
);
embeddings.push(curRow);
}

// Send result to the main thread
const message: EmbeddingWorkerMessage = {
Expand All @@ -67,8 +97,8 @@ export const getEmbedding = async (
model,
requestID,
detail,
text,
embedding
sentences,
embeddings
}
};
postMessage(message);
Expand Down

0 comments on commit ac944b0

Please sign in to comment.