Skip to content

Commit

Permalink
Add Cassandra vector store implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 23, 2024
1 parent ea46820 commit a1dd4d5
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 3 deletions.
1 change: 1 addition & 0 deletions dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ numpy
pypi
nbformat
semversioner
cassio

# Library Methods
iterrows
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/verbs/text/embed/text_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def text_embed(
max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
vector_store: # The optional configuration for the vector store
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, cassandra
<...>
```
"""
Expand Down
2 changes: 2 additions & 0 deletions graphrag/vector_stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

from .azure_ai_search import AzureAISearch
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
from .cassandra import CassandraVectorStore
from .lancedb import LanceDBVectorStore
from .typing import VectorStoreFactory, VectorStoreType

__all__ = [
"AzureAISearch",
"BaseVectorStore",
"CassandraVectorStore",
"LanceDBVectorStore",
"VectorStoreDocument",
"VectorStoreFactory",
Expand Down
122 changes: 122 additions & 0 deletions graphrag/vector_stores/cassandra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Apache Cassandra vector store implementation package."""

from typing import Any

import cassio
from cassio.table import MetadataVectorCassandraTable
from typing_extensions import override

from graphrag.model.types import TextEmbedder

from .base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)


class CassandraVectorStore(BaseVectorStore):
"""The Apache Cassandra vector storage implementation."""

def __init__(
self,
collection_name: str,
token: str | None = None,
database_id: str | None = None,
keyspace: str | None = None,
**kwargs: Any,
):
super().__init__(collection_name)
cassio.init(
token=token,
database_id=database_id,
keyspace=keyspace,
**kwargs,
)

@override
def connect(self, keyspace: str | None = None, **kwargs: Any) -> None:
self.db_connection = cassio.config.resolve_session()
self.keyspace = cassio.config.resolve_keyspace(keyspace)

@override
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
if overwrite:
self.db_connection.execute(
f"DROP TABLE IF EXISTS {self.keyspace}.{self.collection_name};"
)

if not documents:
return

if not self.document_collection or overwrite:
dimension = DEFAULT_VECTOR_SIZE
for doc in documents:
if doc.vector:
dimension = len(doc.vector)
break
self.document_collection = MetadataVectorCassandraTable(
table=self.collection_name,
vector_dimension=dimension,
primary_key_type="TEXT",
)

futures = [
self.document_collection.put_async(
row_id=doc.id,
body_blob=doc.text,
vector=doc.vector,
metadata=doc.attributes,
)
for doc in documents
if doc.vector
]

for future in futures:
future.result()

@override
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
msg = "Cassandra vector store doesn't support filtering by IDs."
raise NotImplementedError(msg)

@override
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
response = self.document_collection.metric_ann_search(
vector=query_embedding,
n=k,
metric="cos",
**kwargs,
)

return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=doc["row_id"],
text=doc["body_blob"],
vector=doc["vector"],
attributes=doc["metadata"],
),
score=doc["distance"],
)
for doc in response
]

@override
def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
query_embedding = text_embedder(text)
if query_embedding:
return self.similarity_search_by_vector(
query_embedding=query_embedding, k=k, **kwargs
)
return []
6 changes: 5 additions & 1 deletion graphrag/vector_stores/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from typing import ClassVar

from . import BaseVectorStore, CassandraVectorStore
from .azure_ai_search import AzureAISearch
from .lancedb import LanceDBVectorStore

Expand All @@ -15,6 +16,7 @@ class VectorStoreType(str, Enum):

LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"
Cassandra = "cassandra"


class VectorStoreFactory:
Expand All @@ -30,13 +32,15 @@ def register(cls, vector_store_type: str, vector_store: type):
@classmethod
def get_vector_store(
cls, vector_store_type: VectorStoreType | str, kwargs: dict
) -> LanceDBVectorStore | AzureAISearch:
) -> BaseVectorStore:
"""Get the vector store type from a string."""
match vector_store_type:
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
case VectorStoreType.AzureAISearch:
return AzureAISearch(**kwargs)
case VectorStoreType.Cassandra:
return CassandraVectorStore(**kwargs)
case _:
if vector_store_type in cls.vector_store_types:
return cls.vector_store_types[vector_store_type](**kwargs)
Expand Down
80 changes: 79 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ datashaper = "^0.0.49"
azure-search-documents = "^11.4.0"
lancedb = "^0.13.0"


# Async IO
aiolimiter = "^1.1.0"
aiofiles = "^24.1.0"
Expand Down Expand Up @@ -87,6 +88,7 @@ azure-identity = "^1.17.1"
json-repair = "^0.28.4"

future = "^1.0.0" # Needed until graspologic fixes their dependency
cassio = "^0.1.9"

[tool.poetry.group.dev.dependencies]
coverage = "^7.6.0"
Expand Down

0 comments on commit a1dd4d5

Please sign in to comment.