diff --git a/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py b/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py index e26013a3..300da81c 100644 --- a/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py +++ b/src/pai_rag/integrations/vector_stores/elasticsearch/my_elasticsearch.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union import nest_asyncio +import threading import numpy as np from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.schema import BaseNode, MetadataMode, TextNode @@ -198,10 +199,11 @@ class Config: text_field: str = "content" vector_field: str = "embedding" batch_size: int = 200 + embedding_dimension: int = 1536 distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE" retrieval_strategy: AsyncRetrievalStrategy - _store = PrivateAttr() + _local_storage = PrivateAttr() def __init__( self, @@ -221,40 +223,13 @@ def __init__( ) -> None: nest_asyncio.apply() - if not es_client: - es_client = get_elasticsearch_client( - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - ) + self._local_storage = threading.local() if retrieval_strategy is None: retrieval_strategy = AsyncDenseVectorStrategy( distance=DistanceMetric[distance_strategy] ) - metadata_mappings = { - "document_id": {"type": "keyword"}, - "doc_id": {"type": "keyword"}, - "ref_doc_id": {"type": "keyword"}, - } - - self._store = AsyncVectorStore( - user_agent=get_user_agent(), - client=es_client, - index=index_name, - retrieval_strategy=retrieval_strategy, - text_field=text_field, - vector_field=vector_field, - metadata_mappings=metadata_mappings, - num_dimensions=embedding_dimension, - ) - asyncio.get_event_loop().run_until_complete( - self._store._create_index_if_not_exists() - ) - super().__init__( index_name=index_name, es_client=es_client, @@ -265,18 +240,53 @@ def __init__( es_password=es_password, text_field=text_field, vector_field=vector_field, + embedding_dimension=embedding_dimension, batch_size=batch_size, distance_strategy=distance_strategy, retrieval_strategy=retrieval_strategy, ) + asyncio.get_event_loop().run_until_complete( + self._get_store()._create_index_if_not_exists() + ) @property def client(self) -> Any: """Get async elasticsearch client.""" - return self._store.client + if self.es_client is not None: + return self.es_client + + if not hasattr(self._local_storage, "es_client"): + self._local_storage.es_client = get_elasticsearch_client( + url=self.es_url, + cloud_id=self.es_cloud_id, + api_key=self.es_api_key, + username=self.es_user, + password=self.es_password, + ) + return self._local_storage.es_client + + def _get_store(self, retrieval_strategy=None) -> Any: + metadata_mappings = { + "document_id": {"type": "keyword"}, + "doc_id": {"type": "keyword"}, + "ref_doc_id": {"type": "keyword"}, + } + if not retrieval_strategy: + retrieval_strategy = self.retrieval_strategy + + return AsyncVectorStore( + user_agent=get_user_agent(), + client=self.client, + index=self.index_name, + retrieval_strategy=retrieval_strategy, + text_field=self.text_field, + vector_field=self.vector_field, + metadata_mappings=metadata_mappings, + num_dimensions=self.embedding_dimension, + ) def close(self) -> None: - return asyncio.get_event_loop().run_until_complete(self._store.close()) + return asyncio.get_event_loop().run_until_complete(self._get_store().close()) def add( self, @@ -345,10 +355,11 @@ async def async_add( texts.append(node.get_content(metadata_mode=MetadataMode.NONE)) metadatas.append(node_to_metadata_dict(node, remove_text=True)) - if not self._store.num_dimensions: - self._store.num_dimensions = len(embeddings[0]) + es_store = self._get_store() + if not es_store.num_dimensions: + es_store.num_dimensions = len(embeddings[0]) - return await self._store.add_texts( + return await es_store.add_texts( texts=texts, metadatas=metadatas, vectors=embeddings, @@ -385,7 +396,8 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: Raises: Exception: If AsyncElasticsearch delete_by_query fails. """ - await self._store.delete( + es_store = self._get_store() + await es_store.delete( query={"term": {"metadata.ref_doc_id": ref_doc_id}}, **delete_kwargs ) @@ -463,26 +475,13 @@ async def aquery( else: retrieval_strategy = AsyncDenseVectorStrategy() - metadata_mappings = { - "document_id": {"type": "keyword"}, - "doc_id": {"type": "keyword"}, - "ref_doc_id": {"type": "keyword"}, - } - self._store = AsyncVectorStore( - user_agent=get_user_agent(), - client=self.es_client, - index=self.index_name, - retrieval_strategy=retrieval_strategy, - text_field=self.text_field, - vector_field=self.vector_field, - metadata_mappings=metadata_mappings, - ) + es_store = self._get_store(retrieval_strategy=retrieval_strategy) if query.filters is not None and len(query.filters.legacy_filters()) > 0: filter = [_to_elasticsearch_filter(query.filters)] else: filter = es_filter or [] - hits = await self._store.search( + hits = await es_store.search( query=query.query_str, query_vector=query.query_embedding, k=query.similarity_top_k,