Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add batch_size to FAISS __init__ #6401

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions haystack/document_stores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
ef_search: int = 20,
ef_construction: int = 80,
validate_index_sync: bool = True,
batch_size: int = 10_000,
):
"""
:param sql_url: SQL connection URL for the database. The default value is `"sqlite:///faiss_document_store.db"`. It defaults to a local, file-based SQLite DB. For large scale deployment, we recommend Postgres.
Expand Down Expand Up @@ -103,6 +104,8 @@ def __init__(
:param ef_search: Used only if `index_factory == "HNSW"`.
:param ef_construction: Used only if `index_factory == "HNSW"`.
:param validate_index_sync: Checks if the document count equals the embedding count at initialization time.
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
memory issues, decrease the batch_size.
"""
faiss_import.check()
# special case if we want to load an existing index from disk
Expand Down Expand Up @@ -152,6 +155,7 @@ def __init__(

self.return_embedding = return_embedding
self.embedding_field = embedding_field
self.batch_size = batch_size

self.progress_bar = progress_bar

Expand Down Expand Up @@ -216,7 +220,7 @@ def write_documents(
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
) -> None:
Expand All @@ -240,6 +244,8 @@ def write_documents(
raise NotImplementedError("FAISSDocumentStore does not support headers.")

index = index or self.index
batch_size = batch_size or self.batch_size

duplicate_documents = duplicate_documents or self.duplicate_documents
assert (
duplicate_documents in self.duplicate_documents_options
Expand Down Expand Up @@ -324,7 +330,7 @@ def update_embeddings(
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[FilterType] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
):
"""
Updates the embeddings in the document store using the encoding model specified in the retriever.
Expand All @@ -342,6 +348,7 @@ def update_embeddings(
:return: None
"""
index = index or self.index
batch_size = batch_size or self.batch_size

if update_existing_embeddings is True:
if filters is None:
Expand Down Expand Up @@ -404,9 +411,10 @@ def get_all_documents(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
batch_size = batch_size or self.batch_size
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")

Expand All @@ -421,7 +429,7 @@ def get_all_documents_generator(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> Generator[Document, None, None]:
"""
Expand All @@ -440,6 +448,7 @@ def get_all_documents_generator(
raise NotImplementedError("FAISSDocumentStore does not support headers.")

index = index or self.index
batch_size = batch_size or self.batch_size
documents = super(FAISSDocumentStore, self).get_all_documents_generator(
index=index, filters=filters, batch_size=batch_size, return_embedding=False
)
Expand All @@ -455,13 +464,15 @@ def get_documents_by_id(
self,
ids: List[str],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")

index = index or self.index
batch_size = batch_size or self.batch_size

documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index, batch_size=batch_size)
if self.return_embedding:
for doc in documents:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add batch_size to the __init__ method of FAISS Document Store. This works as the default value for all methods of
FAISS Document Store that support batch_size.