Skip to content

Commit

Permalink
Bind semantic search!
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <jay@zijie.wang>
  • Loading branch information
xiaohk committed Feb 6, 2024
1 parent eca65f2 commit 190fbd7
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 19 deletions.
94 changes: 93 additions & 1 deletion examples/rag-playground/src/components/playground/playground.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit';
import { customElement, property, state, query } from 'lit/decorators.js';
import { unsafeHTML } from 'lit/directives/unsafe-html.js';
import { EmbeddingModel } from '../../workers/embedding';
import type { EmbeddingWorkerMessage } from '../../workers/embedding';
import type { MememoTextViewer } from '../text-viewer/text-viewer';

import '../query-box/query-box';
import '../text-viewer/text-viewer';

import componentCSS from './playground.css?inline';
import EmbeddingWorkerInline from '../../workers/embedding?worker&inline';

interface DatasetInfo {
dataURL: string;
Expand Down Expand Up @@ -36,12 +40,33 @@ export class MememoPlayground extends LitElement {
//==========================================================================||
// Class Properties ||
//==========================================================================||
userQuery = '';
embeddingWorker: Worker;
embeddingWorkerRequestCount = 0;
get embeddingWorkerRequestID() {
this.embeddingWorkerRequestCount++;
return `prompt-panel-${this.embeddingWorkerRequestCount}`;
}

@state()
topK = 10;

@query('mememo-text-viewer')
textViewerComponent: MememoTextViewer | undefined | null;

//==========================================================================||
// Lifecycle Methods ||
//==========================================================================||
constructor() {
super();

this.embeddingWorker = new EmbeddingWorkerInline();
this.embeddingWorker.addEventListener(
'message',
(e: MessageEvent<EmbeddingWorkerMessage>) => {
this.embeddingWorkerMessageHandler(e);
}
);
}

/**
Expand All @@ -55,9 +80,73 @@ export class MememoPlayground extends LitElement {
//==========================================================================||
async initData() {}

/**
* Extract embeddings for the input sentences
* @param sentences Input sentences
*/
getEmbedding(sentences: string[]) {
const message: EmbeddingWorkerMessage = {
command: 'startExtractEmbedding',
payload: {
detail: '',
requestID: this.embeddingWorkerRequestID,
model: EmbeddingModel.gteSmall,
sentences: sentences
}
};
this.embeddingWorker.postMessage(message);
}

/**
* Use k-nearest neighbor to find semantically similar documents
* @param embedding Embeddings of the user query
*/
semanticSearch(embedding: number[]) {
if (!this.textViewerComponent) {
throw Error('textViewerComponent is not initialized.');
}

this.textViewerComponent.semanticSearch(embedding, this.topK, 0.5);
}

/**
* Augment the prompt using relevant documents
* @param relevantDocuments Documents that are relevant to the user query
*/
compilePrompt(relevantDocuments: string[]) {
console.log(relevantDocuments);
}

//==========================================================================||
// Event Handlers ||
//==========================================================================||
userQueryRunClickHandler(e: CustomEvent<string>) {
this.userQuery = e.detail;

// Extract embeddings for the user query
this.getEmbedding([this.userQuery]);
}

embeddingWorkerMessageHandler(e: MessageEvent<EmbeddingWorkerMessage>) {
switch (e.data.command) {
case 'finishExtractEmbedding': {
const { embeddings } = e.data.payload;
// Start semantic search using the embedding
this.semanticSearch(embeddings[0]);
break;
}

case 'error': {
console.error('Worker error: ', e.data.payload.message);
break;
}

default: {
console.error('Worker: unknown message', e.data.command);
break;
}
}
}

//==========================================================================||
// Private Helpers ||
Expand All @@ -70,7 +159,10 @@ export class MememoPlayground extends LitElement {
return html`
<div class="playground">
<div class="container container-input">
<mememo-query-box></mememo-query-box>
<mememo-query-box
@runButtonClicked=${(e: CustomEvent<string>) =>
this.userQueryRunClickHandler(e)}
></mememo-query-box>
</div>
<div class="container container-search">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ export class MememoPromptPanel extends LitElement {
//==========================================================================||
async initData() {}

getEmbedding() {
getEmbedding(sentences: string[]) {
const message: EmbeddingWorkerMessage = {
command: 'startExtractEmbedding',
payload: {
detail: '',
requestID: this.embeddingWorkerRequestID,
model: EmbeddingModel.gteSmall,
sentences: ['Hello, how are you']
sentences: sentences
}
};
this.embeddingWorker.postMessage(message);
Expand Down
78 changes: 77 additions & 1 deletion examples/rag-playground/src/components/query-box/query-box.css
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,84 @@ textarea {
}

.header {
font-weight: 800;
font-size: var(--font-u1);
color: var(--gray-600);
line-height: 1;

width: 100%;
display: flex;
flex-direction: row;
justify-content: space-between;
justify-content: center;
gap: 10px;
align-items: center;

.text {
font-weight: 800;
}
}

.svg-icon {
display: flex;
justify-content: center;
align-items: center;
width: 1em;
height: 1em;

color: currentColor;
transition: transform 80ms linear;
transform-origin: center;

& svg {
fill: currentColor;
width: 100%;
height: 100%;
}
}

.button-group {
display: flex;
flex-flow: row;
align-items: center;
gap: 10px;
}

button {
all: unset;

display: flex;
line-height: 1;
font-size: var(--font-d2);
padding: 4px 6px;
border-radius: 5px;
white-space: nowrap;
height: var(--font-d2);

cursor: pointer;
user-select: none;
-webkit-user-select: none;

background-color: color-mix(in lab, var(--gray-200), white 20%);
color: var(--gray-800);
display: flex;
flex-flow: row;
align-items: center;
font-size: var(--font-d1);

&:hover {
background-color: color-mix(in lab, var(--gray-300), white 30%);
}

&:active {
background-color: color-mix(in lab, var(--gray-300), white 20%);
}

.svg-icon {
position: relative;
top: 1px;
margin-right: 3px;
color: var(--gray-700);
width: 12px;
height: 12px;
}
}
42 changes: 40 additions & 2 deletions examples/rag-playground/src/components/query-box/query-box.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import { customElement, property, state, query } from 'lit/decorators.js';
import { unsafeHTML } from 'lit/directives/unsafe-html.js';

import componentCSS from './query-box.css?inline';
import searchIcon from '../../images/icon-search.svg?raw';
import refreshIcon from '../../images/icon-refresh2.svg?raw';
import playIcon from '../../images/icon-play.svg?raw';

/**
* Query box element.
Expand All @@ -13,12 +16,17 @@ export class MememoQueryBox extends LitElement {
//==========================================================================||
// Class Properties ||
//==========================================================================||
userQuery: string;
defaultQuery: string;

//==========================================================================||
// Lifecycle Methods ||
//==========================================================================||
constructor() {
super();
this.defaultQuery =
'What are some ways to integrate information retrieval into machine learning?';
this.userQuery = this.defaultQuery;
}

/**
Expand All @@ -35,6 +43,20 @@ export class MememoQueryBox extends LitElement {
//==========================================================================||
// Event Handlers ||
//==========================================================================||
textareaInput(e: InputEvent) {
const textareaElement = e.currentTarget as HTMLTextAreaElement;
this.userQuery = textareaElement.value;
}

runButtonClicked() {
// Notify the parent to run the user query
const event = new CustomEvent('runButtonClicked', {
bubbles: true,
composed: true,
detail: this.userQuery
});
this.dispatchEvent(event);
}

//==========================================================================||
// Private Helpers ||
Expand All @@ -46,8 +68,24 @@ export class MememoQueryBox extends LitElement {
render() {
return html`
<div class="query-box">
<div class="header">User Query</div>
<textarea rows="5"></textarea>
<div class="header">
<span class="text">User Query</span>
<div class="button-group">
<button>
<span class="svg-icon">${unsafeHTML(refreshIcon)}</span>
random
</button>
<button @click=${() => this.runButtonClicked()}>
<span class="svg-icon">${unsafeHTML(playIcon)}</span>
run
</button>
</div>
</div>
<textarea rows="5" @input=${(e: InputEvent) => this.textareaInput(e)}>
${this.userQuery}</textarea
>
</div>
`;
}
Expand Down
35 changes: 35 additions & 0 deletions examples/rag-playground/src/components/text-viewer/text-viewer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ export class MememoTextViewer extends LitElement {
return this.lexicalSearchRequestCount++;
}

semanticSearchRequestCount = 0;
get semanticSearchRequestID() {
return this.semanticSearchRequestCount++;
}

//==========================================================================||
// Lifecycle Methods ||
//==========================================================================||
Expand Down Expand Up @@ -125,6 +130,25 @@ export class MememoTextViewer extends LitElement {
this.mememoWorker.postMessage(message);
}

/**
* Retrieve relevant documents using MeMemo
* @param embedding Input embedding
* @param topK Top k relevant documents to retrieve
* @param maxDistance Distance threshold for relevance
*/
semanticSearch(embedding: number[], topK: number, maxDistance: number) {
const message: MememoWorkerMessage = {
command: 'startSemanticSearch',
payload: {
embedding,
requestID: this.semanticSearchRequestID,
topK,
maxDistance
}
};
this.mememoWorker.postMessage(message);
}

//==========================================================================||
// Event Handlers ||
//==========================================================================||
Expand Down Expand Up @@ -249,6 +273,17 @@ export class MememoTextViewer extends LitElement {
break;
}

case 'finishSemanticSearch': {
const { documents, documentDistances, embedding } = e.data.payload;

// Update the shown documents
this.curQuery = null;
this.isFiltered = true;
this.curDocuments = documents;
this.shownDocuments = this.curDocuments.slice(0, this.shownDocumentCap);
break;
}

case 'finishExportIndex': {
const indexJSON = e.data.payload.indexJSON;
downloadJSON(indexJSON, undefined, 'index.json');
Expand Down
7 changes: 7 additions & 0 deletions examples/rag-playground/src/images/icon-play.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 11 additions & 0 deletions examples/rag-playground/src/images/icon-refresh2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 190fbd7

Please sign in to comment.