Fixed atlas vector search #2435

Merged · 1 commit · Sep 5, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add startup event to initialize db.apply jobs
- Update job_id after job submission
- Fixed default event.uuid
- Fixed atlas vector search

#### New Features & Functionality

4 changes: 3 additions & 1 deletion superduper/backends/local/artifacts.py
@@ -67,7 +67,9 @@ def drop(self, force: bool = False):
):
logging.warn('Aborting...')
shutil.rmtree(self.conn, ignore_errors=force)
os.makedirs(self.conn)
if os.path.exists(self.conn):
logging.warn('Failed to drop artifact store')
os.makedirs(self.conn, exist_ok=True)

def put_bytes(
self,
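The new `drop` logic is deliberately defensive: with `force=True`, `shutil.rmtree(..., ignore_errors=force)` can fail silently, so the code now warns when the directory survives and recreates it with `exist_ok=True` instead of crashing. A minimal standalone sketch of the pattern (`drop_store` is a hypothetical helper, not part of superduper):

```python
import logging
import os
import shutil


def drop_store(path: str, force: bool = False) -> None:
    # With force=True, rmtree swallows errors instead of raising them.
    shutil.rmtree(path, ignore_errors=force)
    if os.path.exists(path):
        # rmtree failed silently; warn rather than crash on the recreate below.
        logging.warning('Failed to drop artifact store')
    # exist_ok=True keeps the recreate idempotent even if the tree survived.
    os.makedirs(path, exist_ok=True)
```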
4 changes: 3 additions & 1 deletion superduper/base/datalayer.py
@@ -150,6 +150,7 @@ def initialize_vector_searcher(
"""
searcher_type = searcher_type or s.CFG.cluster.vector_search.type

logging.info(f"Initializing vector searcher with type {searcher_type}")
if isinstance(vi, str):
vi = self.vector_indices.force_load(vi)
from superduper import VectorIndex
@@ -160,6 +161,7 @@ def initialize_vector_searcher(
clt = vi.indexing_listener.select.table_or_collection

vector_search_cls = vector_searcher_implementations[searcher_type]
logging.info(f"Using vector searcher: {vector_search_cls}")
vector_comparison = vector_search_cls.from_component(vi)

assert isinstance(clt.identifier, str), 'clt.identifier must be a string'
@@ -209,7 +211,7 @@ def drop(self, force: bool = False, data: bool = False):
):
logging.warn("Aborting...")

if self._cfg.cluster.vector_search.uri is not None:
if self._cfg.cluster.vector_search.uri:
for vi in self.show('vector_index'):
FastVectorSearcher.drop_remote(vi)

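The `drop` fix replaces an `is not None` test with a plain truthiness test, so an empty-string URI (a typical config default) no longer triggers remote vector-index drops. A minimal illustration, with a hypothetical config value:

```python
uri = ""  # vector_search.uri is set but empty, e.g. a config default

if uri is not None:
    print("old check: attempts to drop remote indexes")  # runs, incorrectly

if uri:
    print("new check: only runs for a real URI")  # skipped for ""
```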
267 changes: 121 additions & 146 deletions superduper/vector_search/atlas.py
@@ -1,12 +1,8 @@
import copy
import json
import re
import typing as t
from functools import cached_property

from superduper import CFG, logging
from superduper.components.model import APIBaseModel
from superduper.vector_search.base import BaseVectorSearcher
from superduper.vector_search.base import BaseVectorSearcher, VectorItem

if t.TYPE_CHECKING:
from superduper.components.vector_index import VectorIndex
@@ -22,29 +18,32 @@ class MongoAtlasVectorSearcher(BaseVectorSearcher):
:param output_path: Path to the output
"""

native_service: t.ClassVar[bool] = False

def __init__(
self,
identifier: str,
collection: str,
dimensions: t.Optional[int] = None,
dimensions: int,
measure: t.Optional[str] = None,
output_path: t.Optional[str] = None,
):
import pymongo

self.identifier = identifier
vector_search_uri = CFG.cluster.vector_search.uri
assert vector_search_uri, 'Vector search URI is required'
db_name = vector_search_uri.split('/')[-1]
vector_search_uri = CFG.cluster.vector_search.uri or CFG.data_backend
assert vector_search_uri, "Vector search URI is required"
db_name = vector_search_uri.split("/")[-1]
self.database = getattr(pymongo.MongoClient(vector_search_uri), db_name)
assert output_path
self.output_path = output_path
self.collection = collection
self.measure = measure
self.dimensions = dimensions
self._is_exists = False

if not self._check_if_exists(identifier):
self._create_index(collection, output_path)
self._check_if_exists(create=True)
super().__init__(identifier=identifier, dimensions=dimensions, measure=measure)

def __len__(self):
pass
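The constructor now falls back to `CFG.data_backend` when no dedicated vector-search URI is configured, and derives the database name from the last path segment of the connection string. A sketch of that connection logic, assuming an illustrative Atlas-style URI:

```python
import pymongo

# Hypothetical connection string; the split below assumes the database name
# is the final path segment (trailing query parameters would break it).
uri = "mongodb+srv://user:pass@cluster0.example.mongodb.net/documents"
db_name = uri.split("/")[-1]  # -> "documents"

client = pymongo.MongoClient(uri)
database = client[db_name]  # same handle as getattr(client, db_name)
```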
@@ -54,89 +53,43 @@ def index(self):
"""Return the index collection."""
return self.database[self.collection]

def is_initialized(self, identifier):
"""Check if vector index initialized."""
return self._check_if_exists(create=False)

@classmethod
def from_component(cls, vi: 'VectorIndex'):
def from_component(cls, vi: "VectorIndex"):
"""Create a vector searcher from a vector index.

:param vi: VectorIndex instance
"""
from superduper.components.listener import Listener
from superduper.components.model import ObjectModel

assert isinstance(vi.indexing_listener, Listener)
assert vi.indexing_listener.select is not None
collection = vi.indexing_listener.select.table_or_collection.identifier

indexing_key = vi.indexing_listener.key
assert isinstance(
indexing_key, str
), 'Only single key is support for atlas search'
if indexing_key.startswith(CFG.output_prefix):
indexing_key = indexing_key[len(CFG.output_prefix) :]
assert isinstance(vi.indexing_listener.model, ObjectModel) or isinstance(
vi.indexing_listener.model, APIBaseModel
)
assert isinstance(collection, str), 'Collection is required to be a string'
indexing_model = vi.indexing_listener.model.identifier
indexing_version = vi.indexing_listener.model.version
output_path = (
f'{CFG.output_prefix}{indexing_key}.{indexing_model}.{indexing_version}'
)
path = collection = vi.indexing_listener.outputs

return MongoAtlasVectorSearcher(
identifier=vi.identifier,
dimensions=vi.dimensions,
measure=vi.measure,
output_path=output_path,
output_path=path,
collection=collection,
)

def _replace_document_with_vector(self, step):
step = copy.deepcopy(step)
assert "like" in step['$vectorSearch']
vector = step['$vectorSearch']['like']
step['$vectorSearch']['queryVector'] = vector
def add(self, items: t.Sequence[VectorItem], cache: bool = False) -> None:
"""
Add items to the index.

step['$vectorSearch']['path'] = self.output_path
step['$vectorSearch']['index'] = self.identifier
del step['$vectorSearch']['like']
return step
:param items: t.Sequence of VectorItems
"""
self._check_if_exists(create=True)

def _prepare_pipeline(self, pipeline):
pipeline = copy.deepcopy(pipeline)
try:
search_step = next(
(i, step) for i, step in enumerate(pipeline) if '$vectorSearch' in step
)
except StopIteration:
return pipeline
pipeline[search_step[0]] = self._replace_document_with_vector(
search_step[1],
)
return pipeline

def _find(self, h, n=100):
h = self.to_list(h)
pl = [
{
"$vectorSearch": {
'like': h,
"limit": n,
'numCandidates': n,
}
},
{'$addFields': {'score': {'$meta': 'vectorSearchScore'}}},
]
pl = self._prepare_pipeline(
pl,
)
cursor = self.index.aggregate(pl)
scores = []
ids = []
for vector in cursor:
scores.append(vector['score'])
ids.append(str(vector['_id']))
return ids, scores
def delete(self, ids: t.Sequence[str]) -> None:
"""Remove items from the index.

:param ids: t.Sequence of ids of vectors.
"""

def find_nearest_from_id(self, id: str, n=100, within_ids=None):
"""Find the nearest vectors to the given ID.
@@ -145,7 +98,7 @@ def find_nearest_from_id(self, id: str, n=100, within_ids=None):
:param n: number of nearest vectors to return
:param within_ids: list of IDs to search within
"""
h = self.index.find_one({'id': id})
h = self.index.find_one({"_id": id})[self.output_path]
return self.find_nearest_from_array(h, n=n, within_ids=within_ids)

def find_nearest_from_array(self, h, n=100, within_ids=None):
@@ -155,84 +108,106 @@ def find_nearest_from_array(self, h, n=100, within_ids=None):
:param n: number of nearest vectors to return
:param within_ids: list of IDs to search within
"""
return self._find(h, n=n)

def add(self, items):
"""Add vectors to the index.

:param items: List of vectors to add
"""
items = list(map(lambda x: x.to_dict(), items))
if not CFG.cluster.vector_search == CFG.data_backend:
self.index.insert_many(items)
self._check_if_exists(create=True)
self._check_queryable()
vector_search = {
"index": self.identifier,
"path": self.output_path,
"queryVector": h,
"numCandidates": n,
"limit": n,
}
if within_ids:
vector_search["filter"] = {"_id": {"$in": within_ids}}

def delete(self, items):
"""Delete vectors from the index.
project = {
"_id": 1,
"_source": 1,
"score": {"$meta": "vectorSearchScore"},
}

:param items: List of vectors to delete
"""
ids = list(map(lambda x: x.id, items))
if not CFG.cluster.vector_search == CFG.data_backend:
self.index.delete_many({'id': {'$in': ids}})
pipeline = [
{"$vectorSearch": vector_search},
{"$project": project},
]

def _create_index(self, collection: str, output_path: str):
"""
Create a vector index in the data backend if an Atlas deployment.
cursor = self.index.aggregate(pipeline)
scores = []
ids = []
for vector in cursor:
scores.append(vector["score"])
ids.append(str(vector["_source"]))
return ids, scores
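The rewrite drops the old `like`-placeholder pipeline machinery (`_replace_document_with_vector` / `_prepare_pipeline` / `_find`) in favour of building the `$vectorSearch` stage directly. A self-contained sketch of the same aggregation shape (URI, index, and field names are illustrative):

```python
import pymongo

client = pymongo.MongoClient(
    "mongodb+srv://user:pass@cluster0.example.mongodb.net"
)
coll = client["documents"]["docs"]

query_vector = [0.12, 0.56, 0.98]  # must match the index dimensionality

pipeline = [
    {
        "$vectorSearch": {
            "index": "my_vector_index",  # name given at index creation
            "path": "embedding",         # document field holding the vector
            "queryVector": query_vector,
            "numCandidates": 100,        # candidates scanned before ranking
            "limit": 10,                 # results returned
        }
    },
    {"$project": {"_id": 1, "score": {"$meta": "vectorSearchScore"}}},
]

for doc in coll.aggregate(pipeline):
    print(doc["_id"], doc["score"])
```

One caveat worth noting: Atlas only honours a `$vectorSearch` `filter` on fields indexed with type `filter`, so the `within_ids` clause above appears to assume `_id` is included in the index definition.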

:param collection: Collection name
:param output_path: Path to the output
"""
_, key, model, version = output_path.split('.')
# TODO: Need to fix this and test it with CFG.output_prefix
if re.match(r'^_outputs\.[A-Za-z0-9_]+\.[A-Za-z0-9_]+', key):
key = key.split('.')[1]
def _create_index(self):
"""Create a vector index in the data backend if an Atlas deployment."""
if self.collection not in self.database.list_collection_names():
logging.warn(
f"Collection {self.collection} does not exist. " "Cannot create index."
)
return
from pymongo.operations import SearchIndexModel

fields4 = {
str(version): [
definition = {
"fields": [
{
"dimensions": self.dimensions,
"type": "vector",
"numDimensions": self.dimensions,
"path": self.output_path,
"similarity": self.measure,
"type": "knnVector",
}
},
]
}
fields3 = {
model: {
"fields": fields4,
"type": "document",
}
}
fields2 = {
key: {
"fields": fields3,
"type": "document",
}
}
fields1 = {
"_outputs": {
"fields": fields2,
"type": "document",
}
}
index_definition = {
"createSearchIndexes": collection,
"indexes": [
{
"name": self.identifier,
"definition": {
"mappings": {
"dynamic": True,
"fields": fields1,
}
},
}
],
}
logging.info(json.dumps(index_definition, indent=2))
self.database.command(index_definition)

def _check_if_exists(self, index: str):
indexes = self.index.list_search_indexes()
return len(
[i for i in indexes if i['name'] == index and i['status'] == 'READY']
search_index_model = SearchIndexModel(
definition=definition,
name=self.identifier,
type="vectorSearch",
)
logging.info(
f"Creating search index [{self.identifier}] on {self.collection} "
f"-- Definition: {definition}"
)
result = self.index.create_search_index(model=search_index_model)
return result
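`SearchIndexModel` plus `Collection.create_search_index` is the programmatic equivalent of defining the index in the Atlas UI; note that the `type="vectorSearch"` keyword is a relatively recent pymongo addition, so a current driver is assumed here. A standalone sketch of the same call shape (names and dimensionality are illustrative):

```python
from pymongo.operations import SearchIndexModel

search_index_model = SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "numDimensions": 3,
                "path": "embedding",
                "similarity": "cosine",  # Atlas also accepts euclidean / dotProduct
            }
        ]
    },
    name="my_vector_index",
    type="vectorSearch",
)
# coll is a pymongo Collection handle for the indexed collection.
coll.create_search_index(model=search_index_model)
```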

def _check_if_exists(self, create=True):
if self._is_exists:
return True
index = self._get_index()
if bool(index):
self._is_exists = True
elif create:
self._create_index()
return self._is_exists

def _check_queryable(self):
index = self._get_index()
if not index:
raise FileNotFoundError(
f"Index {self.identifier} does not exist in the collection "
f"{self.collection}. Cannot perform query."
)

if not index.get("queryable"):
raise FileNotFoundError(
f"Index {self.identifier} is pending and not yet queryable. "
"Please wait until the index is fully ready for queries."
f"Cannot perform query. "
f"You need to wait for the index to be queryable. "
f"Index: {index}"
)
return True

def _get_index(self):
try:
indexes = self.index.list_search_indexes()
except Exception:
return False

indexes = [i for i in indexes if i["name"] == self.identifier]

if not indexes:
return None
else:
return indexes[0]
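Because Atlas builds search indexes asynchronously, `_check_queryable` can keep raising for a while after `_create_index` returns. A hypothetical helper (not part of the PR) showing how a caller could wait out the build instead of retrying on the error:

```python
import time

from pymongo.collection import Collection


def wait_until_queryable(coll: Collection, index_name: str, timeout: float = 300.0):
    """Poll list_search_indexes() until the named index reports queryable."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        matches = [i for i in coll.list_search_indexes() if i["name"] == index_name]
        if matches and matches[0].get("queryable"):
            return
        time.sleep(5)  # index builds take a little while on Atlas
    raise TimeoutError(f"Index {index_name} not queryable after {timeout}s")
```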