diff --git a/examples/src/indexes/vector_stores/chroma/fromTexts.ts b/examples/src/indexes/vector_stores/chroma/fromTexts.ts index 029f0e7186d0..5db2707f54bb 100644 --- a/examples/src/indexes/vector_stores/chroma/fromTexts.ts +++ b/examples/src/indexes/vector_stores/chroma/fromTexts.ts @@ -34,3 +34,18 @@ console.log(response); } ] */ + +// You can also filter by metadata +const filteredResponse = await vectorStore.similaritySearch("scared", 2, { + id: 1, +}); + +console.log(filteredResponse); +/* +[ + Document { + pageContent: 'Achilles: Yiikes! What is that?', + metadata: { id: 1 } + } +] +*/ diff --git a/langchain/package.json b/langchain/package.json index 71af92601843..5bef5d37b373 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -377,7 +377,7 @@ "apify-client": "^2.7.1", "axios": "^0.26.0", "cheerio": "^1.0.0-rc.12", - "chromadb": "^1.4.0", + "chromadb": "^1.4.2", "cohere-ai": "^5.0.2", "d3-dsv": "^2.0.0", "dotenv": "^16.0.3", @@ -431,7 +431,7 @@ "apify-client": "^2.7.1", "axios": "*", "cheerio": "^1.0.0-rc.12", - "chromadb": "^1.4.0", + "chromadb": "^1.4.2", "cohere-ai": "^5.0.2", "d3-dsv": "^2.0.0", "epub2": "^3.0.1", diff --git a/langchain/src/vectorstores/chroma.ts b/langchain/src/vectorstores/chroma.ts index 7ecf08a54179..8d89b6dbd108 100644 --- a/langchain/src/vectorstores/chroma.ts +++ b/langchain/src/vectorstores/chroma.ts @@ -10,14 +10,18 @@ export type ChromaLibArgs = url?: string; numDimensions?: number; collectionName?: string; + filter?: object; } | { index?: ChromaClientT; numDimensions?: number; collectionName?: string; + filter?: object; }; export class Chroma extends VectorStore { + declare FilterType: object; + index?: ChromaClientT; collection?: Collection; @@ -28,6 +32,8 @@ export class Chroma extends VectorStore { url: string; + filter?: object; + constructor(embeddings: Embeddings, args: ChromaLibArgs) { super(embeddings, args); this.numDimensions = args.numDimensions; @@ -38,6 +44,8 @@ export class Chroma extends VectorStore { } else if ("url" in args) { this.url = args.url || "http://localhost:8000"; } + + this.filter = args.filter; } async addDocuments(documents: Document[]): Promise { @@ -52,11 +60,15 @@ export class Chroma extends VectorStore { if (!this.collection) { if (!this.index) { const { ChromaClient } = await Chroma.imports(); - this.index = new ChromaClient(this.url); + this.index = new ChromaClient({ path: this.url }); + } + try { + this.collection = await this.index.getOrCreateCollection({ + name: this.collectionName, + }); + } catch (err) { + throw new Error(`Chroma getOrCreateCollection error: ${err}`); } - this.collection = await this.index.getOrCreateCollection( - this.collectionName - ); } return this.collection; @@ -80,22 +92,35 @@ export class Chroma extends VectorStore { const collection = await this.ensureCollection(); const docstoreSize = await collection.count(); - await collection.add( - Array.from({ length: vectors.length }, (_, i) => + await collection.add({ + ids: Array.from({ length: vectors.length }, (_, i) => (docstoreSize + i).toString() ), - vectors, - documents.map(({ metadata }) => metadata), - documents.map(({ pageContent }) => pageContent) - ); + embeddings: vectors, + metadatas: documents.map(({ metadata }) => metadata), + documents: documents.map(({ pageContent }) => pageContent), + }); } - async similaritySearchVectorWithScore(query: number[], k: number) { + async similaritySearchVectorWithScore( + query: number[], + k: number, + filter?: this["FilterType"] + ) { + if (filter && this.filter) { + throw new Error("cannot provide both `filter` and `this.filter`"); + } + const _filter = filter ?? this.filter; + const collection = await this.ensureCollection(); // similaritySearchVectorWithScore supports one query vector at a time // chroma supports multiple query vectors at a time - const result = await collection.query(query, k); + const result = await collection.query({ + query_embeddings: query, + n_results: k, + where: { ..._filter }, + }); const { ids, distances, documents, metadatas } = result; if (!ids || !distances || !documents || !metadatas) { @@ -111,8 +136,8 @@ export class Chroma extends VectorStore { for (let i = 0; i < firstIds.length; i += 1) { results.push([ new Document({ - pageContent: firstDocuments[i], - metadata: firstMetadatas[i], + pageContent: firstDocuments?.[i] ?? "", + metadata: firstMetadatas?.[i] ?? {}, }), firstDistances[i], ]); diff --git a/langchain/src/vectorstores/tests/chroma.int.test.ts b/langchain/src/vectorstores/tests/chroma.int.test.ts new file mode 100644 index 000000000000..5629c8e7b5b7 --- /dev/null +++ b/langchain/src/vectorstores/tests/chroma.int.test.ts @@ -0,0 +1,52 @@ +/* eslint-disable no-process-env */ +/* eslint-disable @typescript-eslint/no-non-null-assertion */ +import { beforeEach, describe, expect, test } from "@jest/globals"; +import { faker } from "@faker-js/faker"; +import * as uuid from "uuid"; +import { Document } from "../../document.js"; +import { Chroma } from "../chroma.js"; +import { OpenAIEmbeddings } from "../../embeddings/openai.js"; + +describe("Chroma", () => { + let chromaStore: Chroma; + + beforeEach(async () => { + const embeddings = new OpenAIEmbeddings(); + chromaStore = new Chroma(embeddings, { + url: "http://localhost:8000", + collectionName: "test-collection", + }); + }); + + test("auto-generated ids", async () => { + const pageContent = faker.lorem.sentence(5); + + await chromaStore.addDocuments([{ pageContent, metadata: { foo: "bar" } }]); + + const results = await chromaStore.similaritySearch(pageContent, 1); + + expect(results).toEqual([ + new Document({ metadata: { foo: "bar" }, pageContent }), + ]); + }); + + test("metadata filtering", async () => { + const pageContent = faker.lorem.sentence(5); + const id = uuid.v4(); + + await chromaStore.addDocuments([ + { pageContent, metadata: { foo: "bar" } }, + { pageContent, metadata: { foo: id } }, + { pageContent, metadata: { foo: "qux" } }, + ]); + + // If the filter wasn't working, we'd get all 3 documents back + const results = await chromaStore.similaritySearch(pageContent, 3, { + foo: id, + }); + + expect(results).toEqual([ + new Document({ metadata: { foo: id }, pageContent }), + ]); + }); +}); diff --git a/yarn.lock b/yarn.lock index cdfebd81ad30..02df15b3b832 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10615,6 +10615,13 @@ __metadata: languageName: node linkType: hard +"chromadb@npm:^1.4.2": + version: 1.4.2 + resolution: "chromadb@npm:1.4.2" + checksum: 6fe81e889ae3f6ac39c1a168bcabe12771b0a5271be503db47aa3f8e5c159e29e5eaa7387780fb08d8ef370a42ce58455de1f0622367cb62c4486dfb3b84c20f + languageName: node + linkType: hard + "chrome-trace-event@npm:^1.0.2": version: 1.0.3 resolution: "chrome-trace-event@npm:1.0.3" @@ -18015,7 +18022,7 @@ __metadata: axios: ^0.26.0 binary-extensions: ^2.2.0 cheerio: ^1.0.0-rc.12 - chromadb: ^1.4.0 + chromadb: ^1.4.2 cohere-ai: ^5.0.2 d3-dsv: ^2.0.0 dotenv: ^16.0.3 @@ -18081,7 +18088,7 @@ __metadata: apify-client: ^2.7.1 axios: "*" cheerio: ^1.0.0-rc.12 - chromadb: ^1.4.0 + chromadb: ^1.4.2 cohere-ai: ^5.0.2 d3-dsv: ^2.0.0 epub2: ^3.0.1