Skip to content

Commit

Permalink
Vectorstore: use a retriever query for hybrid search
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Sep 30, 2024
1 parent 14e6265 commit 32ab831
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 90 deletions.
53 changes: 34 additions & 19 deletions elasticsearch/helpers/vectorstore/_async/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,31 +283,46 @@ def _hybrid(
) -> Dict[str, Any]:
# Add a query to the knn query.
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"knn": knn,
"query": {
"bool": {
"must": [
"retriever": {
"rrf": {
"retrievers": [
{
"match": {
self.text_field: {
"query": query,
}
}
}
"standard": {
"query": {
"bool": {
"must": [
{
"match": {
self.text_field: {
"query": query,
}
}
}
],
"filter": filter,
}
},
},
},
{"knn": knn},
],
"filter": filter,
}
**rrf_options,
},
},
}

if isinstance(self.rrf, Dict):
query_body["rank"] = {"rrf": self.rrf}
elif isinstance(self.rrf, bool) and self.rrf is True:
query_body["rank"] = {"rrf": {}}

return query_body

def needs_inference(self) -> bool:
Expand Down
53 changes: 34 additions & 19 deletions elasticsearch/helpers/vectorstore/_sync/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,31 +283,46 @@ def _hybrid(
) -> Dict[str, Any]:
# Add a query to the knn query.
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"knn": knn,
"query": {
"bool": {
"must": [
"retriever": {
"rrf": {
"retrievers": [
{
"match": {
self.text_field: {
"query": query,
}
}
}
"standard": {
"query": {
"bool": {
"must": [
{
"match": {
self.text_field: {
"query": query,
}
}
}
],
"filter": filter,
}
},
},
},
{"knn": knn},
],
"filter": filter,
}
**rrf_options,
},
},
}

if isinstance(self.rrf, Dict):
query_body["rank"] = {"rrf": self.rrf}
elif isinstance(self.rrf, bool) and self.rrf is True:
query_body["rank"] = {"rrf": {}}

return query_body

def needs_inference(self) -> bool:
Expand Down
174 changes: 122 additions & 52 deletions test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,20 +349,48 @@ def test_search_knn_with_hybrid_search(

def assert_query(query_body: dict, query: Optional[str]) -> dict:
assert query_body == {
"knn": {
"field": "vector_field",
"filter": [],
"k": 1,
"num_candidates": 50,
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
},
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
"retriever": {
"rrf": {
"retrievers": [
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{
"match": {
"text_field": {"query": "foo"}
}
}
],
}
},
},
},
{
"knn": {
"field": "vector_field",
"filter": [],
"k": 1,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
},
],
}
},
"rank": {"rrf": {}},
}
}
return query_body

Expand All @@ -381,36 +409,52 @@ def assert_query(
expected_rrf: Union[dict, bool],
) -> dict:
cmp_query_body = {
"knn": {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
"retriever": {
"rrf": {
"retrievers": [
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{
"match": {
"text_field": {"query": "foo"}
}
}
],
}
},
},
},
{
"knn": {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
},
],
}
},
}
}

if isinstance(expected_rrf, dict):
cmp_query_body["rank"] = {"rrf": expected_rrf}
elif isinstance(expected_rrf, bool) and expected_rrf is True:
cmp_query_body["rank"] = {"rrf": {}}
cmp_query_body["retriever"]["rrf"].update(expected_rrf)

assert query_body == cmp_query_body

Expand All @@ -420,7 +464,7 @@ def assert_query(
rrf_test_cases: List[Union[dict, bool]] = [
True,
False,
{"rank_constant": 1, "window_size": 5},
{"rank_constant": 1, "rank_window_size": 5},
]
for rrf_test_case in rrf_test_cases:
store = VectorStore(
Expand All @@ -441,21 +485,47 @@ def assert_query(
# 2. check query result is okay
es_output = store.client.search(
index=index,
query={
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
retriever={
"rrf": {
"retrievers": [
{
"knn": {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
},
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{"match": {"text_field": {"query": "foo"}}}
],
}
},
},
},
],
"rank_constant": 1,
"rank_window_size": 5,
}
},
knn={
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
},
size=3,
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
)

assert [o["_source"]["text_field"] for o in output] == [
Expand Down

0 comments on commit 32ab831

Please sign in to comment.