Skip to content

Commit

Permalink
Fix Elasticsearch thread-safety bug (#215)
Browse files Browse the repository at this point in the history
* Fix bug

* Remove duplicate properties
  • Loading branch information
moria97 authored Sep 13, 2024
1 parent 2f288c1 commit 6566427
Showing 1 changed file with 49 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import nest_asyncio
import threading
import numpy as np
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
Expand Down Expand Up @@ -198,10 +199,11 @@ class Config:
text_field: str = "content"
vector_field: str = "embedding"
batch_size: int = 200
embedding_dimension: int = 1536
distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE"
retrieval_strategy: AsyncRetrievalStrategy

_store = PrivateAttr()
_local_storage = PrivateAttr()

def __init__(
self,
Expand All @@ -221,40 +223,13 @@ def __init__(
) -> None:
nest_asyncio.apply()

if not es_client:
es_client = get_elasticsearch_client(
url=es_url,
cloud_id=es_cloud_id,
api_key=es_api_key,
username=es_user,
password=es_password,
)
self._local_storage = threading.local()

if retrieval_strategy is None:
retrieval_strategy = AsyncDenseVectorStrategy(
distance=DistanceMetric[distance_strategy]
)

metadata_mappings = {
"document_id": {"type": "keyword"},
"doc_id": {"type": "keyword"},
"ref_doc_id": {"type": "keyword"},
}

self._store = AsyncVectorStore(
user_agent=get_user_agent(),
client=es_client,
index=index_name,
retrieval_strategy=retrieval_strategy,
text_field=text_field,
vector_field=vector_field,
metadata_mappings=metadata_mappings,
num_dimensions=embedding_dimension,
)
asyncio.get_event_loop().run_until_complete(
self._store._create_index_if_not_exists()
)

super().__init__(
index_name=index_name,
es_client=es_client,
Expand All @@ -265,18 +240,53 @@ def __init__(
es_password=es_password,
text_field=text_field,
vector_field=vector_field,
embedding_dimension=embedding_dimension,
batch_size=batch_size,
distance_strategy=distance_strategy,
retrieval_strategy=retrieval_strategy,
)
asyncio.get_event_loop().run_until_complete(
self._get_store()._create_index_if_not_exists()
)

@property
def client(self) -> Any:
"""Get async elasticsearch client."""
return self._store.client
if self.es_client is not None:
return self.es_client

if not hasattr(self._local_storage, "es_client"):
self._local_storage.es_client = get_elasticsearch_client(
url=self.es_url,
cloud_id=self.es_cloud_id,
api_key=self.es_api_key,
username=self.es_user,
password=self.es_password,
)
return self._local_storage.es_client

def _get_store(self, retrieval_strategy=None) -> Any:
    """Build a fresh ``AsyncVectorStore`` from this instance's settings.

    A new store object is constructed on every call; nothing is cached on
    ``self``. The Elasticsearch client is resolved through the ``client``
    property.

    Args:
        retrieval_strategy: Strategy to hand to the store. When falsy
            (including the default ``None``), ``self.retrieval_strategy``
            is used instead.

    Returns:
        An ``AsyncVectorStore`` configured with this instance's index name,
        text/vector field names and embedding dimension.
    """
    strategy = retrieval_strategy or self.retrieval_strategy

    # The document-identifier metadata fields are all mapped as "keyword".
    keyword_fields = ("document_id", "doc_id", "ref_doc_id")
    metadata_mappings = {name: {"type": "keyword"} for name in keyword_fields}

    return AsyncVectorStore(
        user_agent=get_user_agent(),
        client=self.client,
        index=self.index_name,
        retrieval_strategy=strategy,
        text_field=self.text_field,
        vector_field=self.vector_field,
        metadata_mappings=metadata_mappings,
        num_dimensions=self.embedding_dimension,
    )

def close(self) -> None:
return asyncio.get_event_loop().run_until_complete(self._store.close())
return asyncio.get_event_loop().run_until_complete(self._get_store().close())

def add(
self,
Expand Down Expand Up @@ -345,10 +355,11 @@ async def async_add(
texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
metadatas.append(node_to_metadata_dict(node, remove_text=True))

if not self._store.num_dimensions:
self._store.num_dimensions = len(embeddings[0])
es_store = self._get_store()
if not es_store.num_dimensions:
es_store.num_dimensions = len(embeddings[0])

return await self._store.add_texts(
return await es_store.add_texts(
texts=texts,
metadatas=metadatas,
vectors=embeddings,
Expand Down Expand Up @@ -385,7 +396,8 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
Raises:
Exception: If AsyncElasticsearch delete_by_query fails.
"""
await self._store.delete(
es_store = self._get_store()
await es_store.delete(
query={"term": {"metadata.ref_doc_id": ref_doc_id}}, **delete_kwargs
)

Expand Down Expand Up @@ -463,26 +475,13 @@ async def aquery(
else:
retrieval_strategy = AsyncDenseVectorStrategy()

metadata_mappings = {
"document_id": {"type": "keyword"},
"doc_id": {"type": "keyword"},
"ref_doc_id": {"type": "keyword"},
}
self._store = AsyncVectorStore(
user_agent=get_user_agent(),
client=self.es_client,
index=self.index_name,
retrieval_strategy=retrieval_strategy,
text_field=self.text_field,
vector_field=self.vector_field,
metadata_mappings=metadata_mappings,
)
es_store = self._get_store(retrieval_strategy=retrieval_strategy)

if query.filters is not None and len(query.filters.legacy_filters()) > 0:
filter = [_to_elasticsearch_filter(query.filters)]
else:
filter = es_filter or []
hits = await self._store.search(
hits = await es_store.search(
query=query.query_str,
query_vector=query.query_embedding,
k=query.similarity_top_k,
Expand Down

0 comments on commit 6566427

Please sign in to comment.