diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py index c2bbd4df4986d..8948eebd0f962 100644 --- a/langchain/chains/chat_vector_db/base.py +++ b/langchain/chains/chat_vector_db/base.py @@ -1,7 +1,7 @@ """Chain for chatting with a vector database.""" from __future__ import annotations -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Optional, Union from pydantic import BaseModel @@ -33,6 +33,7 @@ class ChatVectorDBChain(Chain, BaseModel): output_key: str = "answer" return_source_documents: bool = False top_k_docs_for_context: int = 4 + search_distance: Optional[Union[float, str]] """Return the source documents.""" @property @@ -90,7 +91,7 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: else: new_question = question docs = self.vectorstore.similarity_search( - new_question, k=self.top_k_docs_for_context, **vectordbkwargs + new_question, k=self.top_k_docs_for_context, search_distance=self.search_distance, **vectordbkwargs ) new_inputs = inputs.copy() new_inputs["question"] = new_question @@ -113,7 +114,7 @@ async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: new_question = question # TODO: This blocks the event loop, but it's not clear how to avoid it. docs = self.vectorstore.similarity_search( - new_question, k=self.top_k_docs_for_context, **vectordbkwargs + new_question, k=self.top_k_docs_for_context, search_distance=self.search_distance, **vectordbkwargs ) new_inputs = inputs.copy() new_inputs["question"] = new_question diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index 2ee82e59b838f..ab904a297f93d 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -1,7 +1,7 @@ """Wrapper around weaviate vector database.""" from __future__ import annotations -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable, List, Optional, Dict from uuid import uuid4 from langchain.docstore.document import Document @@ -78,9 +78,9 @@ def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> List[Document]: """Look up similar documents in weaviate.""" - content = {"concepts": [query]} - if kwargs.get("certainty") and isinstance(kwargs.get("certainty"), float): - content["certainty"] = kwargs.get("certainty") + content: Dict[str, Any] = {"concepts": [query]} + if kwargs.get("search_distance"): + content["certainty"] = kwargs.get("search_distance") query_obj = self._client.query.get(self._index_name, self._query_attrs) result = query_obj.with_near_text(content).with_limit(k).do() docs = []