Skip to content

Commit

Permalink
Adding the search_distance attribute to the ChatVectorDBChain base cl…
Browse files Browse the repository at this point in the history
…ass.

Fixing some bugs
  • Loading branch information
mpuig committed Feb 22, 2023
1 parent 9b780d4 commit 9be07ac
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 4 additions & 3 deletions langchain/chains/chat_vector_db/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Optional, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -33,6 +33,7 @@ class ChatVectorDBChain(Chain, BaseModel):
output_key: str = "answer"
return_source_documents: bool = False
top_k_docs_for_context: int = 4
search_distance: Optional[Union[float, str]]
"""Return the source documents."""

@property
Expand Down Expand Up @@ -90,7 +91,7 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
else:
new_question = question
docs = self.vectorstore.similarity_search(
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
new_question, k=self.top_k_docs_for_context, search_distance=self.search_distance, **vectordbkwargs
)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
Expand All @@ -113,7 +114,7 @@ async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
new_question = question
# TODO: This blocks the event loop, but it's not clear how to avoid it.
docs = self.vectorstore.similarity_search(
new_question, k=self.top_k_docs_for_context, **vectordbkwargs
new_question, k=self.top_k_docs_for_context, search_distance=self.search_distance, **vectordbkwargs
)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
Expand Down
8 changes: 4 additions & 4 deletions langchain/vectorstores/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Wrapper around weaviate vector database."""
from __future__ import annotations

from typing import Any, Iterable, List, Optional
from typing import Any, Iterable, List, Optional, Dict
from uuid import uuid4

from langchain.docstore.document import Document
Expand Down Expand Up @@ -78,9 +78,9 @@ def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Look up similar documents in weaviate."""
content = {"concepts": [query]}
if kwargs.get("certainty") and isinstance(kwargs.get("certainty"), float):
content["certainty"] = kwargs.get("certainty")
content: Dict[str, Any] = {"concepts": [query]}
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(self._index_name, self._query_attrs)
result = query_obj.with_near_text(content).with_limit(k).do()
docs = []
Expand Down

0 comments on commit 9be07ac

Please sign in to comment.