Skip to content

Commit

Permalink
support IndexFlatIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Oct 8, 2024
1 parent fa56510 commit 305a0d8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/tutorials/vector_db/optimizing_faiss.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ D, I = index.search(query_vectors, k)

## Choosing the Right Index
* `IndexFlatL2`: Use when accuracy is the primary concern, and the dataset is relatively small.
* `IndexFlatIP`: Use it if you use inner product or cosine for similarity measurement.
* `IndexIVFFlat`: Use when dealing with large datasets and you need to speed up the search process while maintaining reasonable accuracy.
* `IndexPQ`: Use when you need to optimize for memory usage and speed at the cost of some precision.

Expand All @@ -68,7 +69,7 @@ You must specify the values in your configuration file or after instantiating yo
```
vector_db:
db_type: faiss
faiss_index_type: IndexFlatL2, IndexIVFFlat, or IndexIVFPQ
faiss_index_type: IndexFlatL2, IndexFlatIP, IndexIVFFlat, or IndexIVFPQ
faiss_quantized_index_params: Parameters to pass into IndexIVFPQ (d, nlist, m, bits)
faiss_clustered_index_params: Parameters to pass into IndexIVFFlat (d, nlist)
faiss_index_nprobe: Set nprobe value. This defines how many nearby cells to search. It is applicable for both IndexIVFFlat and IndexIVFPQ
Expand Down
8 changes: 6 additions & 2 deletions src/agrag/modules/vector_db/faiss/faiss_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def construct_faiss_index(
Parameters:
----------
index_type: str
Type of FAISS Index to use (IndexFlatL2, IndexIVFFlat, IndexIVFPQ)
Type of FAISS Index to use (IndexFlatL2, IndexFlatIP, IndexIVFFlat, IndexIVFPQ)
embeddings : List[torch.Tensor]
A list of embeddings to be stored in the FAISS index.
embedding_dim: int
Expand All @@ -29,7 +29,7 @@ def construct_faiss_index(
Returns:
-------
Union[IndexFlatL2, IndexIVFFlat, IndexIVFPQ]
Union[IndexFlatL2, IndexFlatIP, IndexIVFFlat, IndexIVFPQ]
The constructed FAISS index.
"""
d = embeddings[0].shape[-1]
Expand All @@ -52,6 +52,8 @@ def construct_faiss_index(
index = faiss.IndexIVFFlat(quantizer, d, **faiss_clustered_index_params)
elif index_type == "IndexFlatL2":
index = quantizer
elif index_type == "IndexFlatIP": # Exact Search for Inner Product
index = faiss.IndexFlatIP(d)
else:
raise ValueError(f"Unsupported FAISS index type {index_type}")

Expand All @@ -61,6 +63,8 @@ def construct_faiss_index(

if index_type == "IndexFlatL2":
index.add(np.array(embeddings))
elif index_type == "IndexFlatIP":
index.add(np.array(embeddings))
elif index_type in ("IndexIVFPQ", "IndexIVFFlat"):
index.train(np.array(embeddings))
assert (
Expand Down
2 changes: 1 addition & 1 deletion src/agrag/modules/vector_db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def save_index(
with open(index_path, "w") as fp:
pass
if db_type == "faiss":
if not isinstance(index, (faiss.IndexFlatL2, faiss.IndexIVFFlat, faiss.IndexIVFPQ)):
if not isinstance(index, (faiss.IndexFlatL2, faiss.IndexFlatIP, faiss.IndexIVFFlat, faiss.IndexIVFPQ)):
raise TypeError("Index for FAISS incorrectly created. Not of a valid FAISS index type.")
success = save_faiss_index(index, index_path)
if s3_bucket and success:
Expand Down

0 comments on commit 305a0d8

Please sign in to comment.