Skip to content

Commit

Permalink
Vectorstore: use a retriever query for hybrid search (#2666)
Browse files Browse the repository at this point in the history
* Vectorstore: use a retriever query for hybrid search

Fixes #2651

* only run hybrid search tests when the stack version is >= 8.14

* restore support for rrf=False
  • Loading branch information
miguelgrinberg authored Oct 14, 2024
1 parent 14e6265 commit e22de7e
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 72 deletions.
39 changes: 30 additions & 9 deletions elasticsearch/helpers/vectorstore/_async/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,9 @@ def _hybrid(
) -> Dict[str, Any]:
# Add a query to the knn query.
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
query_body = {
"knn": knn,
standard_query = {
"query": {
"bool": {
"must": [
Expand All @@ -300,14 +299,36 @@ def _hybrid(
],
"filter": filter,
}
},
}
}

if isinstance(self.rrf, Dict):
query_body["rank"] = {"rrf": self.rrf}
elif isinstance(self.rrf, bool) and self.rrf is True:
query_body["rank"] = {"rrf": {}}

if self.rrf is False:
query_body = {
"knn": knn,
**standard_query,
}
else:
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"retriever": {
"rrf": {
"retrievers": [
{"standard": standard_query},
{"knn": knn},
],
**rrf_options,
},
},
}
return query_body

def needs_inference(self) -> bool:
Expand Down
39 changes: 30 additions & 9 deletions elasticsearch/helpers/vectorstore/_sync/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,9 @@ def _hybrid(
) -> Dict[str, Any]:
# Add a query to the knn query.
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
query_body = {
"knn": knn,
standard_query = {
"query": {
"bool": {
"must": [
Expand All @@ -300,14 +299,36 @@ def _hybrid(
],
"filter": filter,
}
},
}
}

if isinstance(self.rrf, Dict):
query_body["rank"] = {"rrf": self.rrf}
elif isinstance(self.rrf, bool) and self.rrf is True:
query_body["rank"] = {"rrf": {}}

if self.rrf is False:
query_body = {
"knn": knn,
**standard_query,
}
else:
rrf_options = {}
if isinstance(self.rrf, Dict):
if "rank_constant" in self.rrf:
rrf_options["rank_constant"] = self.rrf["rank_constant"]
if "window_size" in self.rrf:
# 'window_size' was renamed to 'rank_window_size', but we support
# the older name for backwards compatibility
rrf_options["rank_window_size"] = self.rrf["window_size"]
if "rank_window_size" in self.rrf:
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
query_body = {
"retriever": {
"rrf": {
"retrievers": [
{"standard": standard_query},
{"knn": knn},
],
**rrf_options,
},
},
}
return query_body

def needs_inference(self) -> bool:
Expand Down
188 changes: 134 additions & 54 deletions test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
VectorStore,
)
from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed
from test_elasticsearch.utils import es_version

from . import ConsistentFakeEmbeddings, FakeEmbeddings

Expand Down Expand Up @@ -337,6 +338,9 @@ def test_search_knn_with_hybrid_search(
self, sync_client: Elasticsearch, index: str
) -> None:
"""Test end to end construction and search with metadata."""
if es_version(sync_client) < (8, 14):
pytest.skip("This test requires Elasticsearch 8.14 or newer")

store = VectorStore(
index=index,
retrieval_strategy=DenseVectorStrategy(hybrid=True),
Expand All @@ -349,20 +353,48 @@ def test_search_knn_with_hybrid_search(

def assert_query(query_body: dict, query: Optional[str]) -> dict:
assert query_body == {
"knn": {
"field": "vector_field",
"filter": [],
"k": 1,
"num_candidates": 50,
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
},
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
"retriever": {
"rrf": {
"retrievers": [
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{
"match": {
"text_field": {"query": "foo"}
}
}
],
}
},
},
},
{
"knn": {
"field": "vector_field",
"filter": [],
"k": 1,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
},
],
}
},
"rank": {"rrf": {}},
}
}
return query_body

Expand All @@ -373,55 +405,77 @@ def test_search_knn_with_hybrid_search_rrf(
self, sync_client: Elasticsearch, index: str
) -> None:
"""Test end to end construction and rrf hybrid search with metadata."""
if es_version(sync_client) < (8, 14):
pytest.skip("This test requires Elasticsearch 8.14 or newer")

texts = ["foo", "bar", "baz"]

def assert_query(
query_body: dict,
query: Optional[str],
expected_rrf: Union[dict, bool],
) -> dict:
cmp_query_body = {
"knn": {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
standard_query = {
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
}
},
}
}
knn_query = {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
}

if isinstance(expected_rrf, dict):
cmp_query_body["rank"] = {"rrf": expected_rrf}
elif isinstance(expected_rrf, bool) and expected_rrf is True:
cmp_query_body["rank"] = {"rrf": {}}
if expected_rrf is not False:
cmp_query_body = {
"retriever": {
"rrf": {
"retrievers": [
{"standard": standard_query},
{"knn": knn_query},
],
}
}
}
if isinstance(expected_rrf, dict):
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
else:
cmp_query_body = {
"knn": knn_query,
**standard_query,
}

assert query_body == cmp_query_body

return query_body

# 1. check query_body is okay
rrf_test_cases: List[Union[dict, bool]] = [
True,
False,
{"rank_constant": 1, "window_size": 5},
]
if es_version(sync_client) >= (8, 14):
rrf_test_cases: List[Union[dict, bool]] = [
True,
False,
{"rank_constant": 1, "rank_window_size": 5},
]
else:
# for 8.13.x and older there is no retriever query, so we can only
# run hybrid searches with rrf=False
rrf_test_cases: List[Union[dict, bool]] = [False]
for rrf_test_case in rrf_test_cases:
store = VectorStore(
index=index,
Expand All @@ -441,21 +495,47 @@ def assert_query(
# 2. check query result is okay
es_output = store.client.search(
index=index,
query={
"bool": {
"filter": [],
"must": [{"match": {"text_field": {"query": "foo"}}}],
retriever={
"rrf": {
"retrievers": [
{
"knn": {
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
},
{
"standard": {
"query": {
"bool": {
"filter": [],
"must": [
{"match": {"text_field": {"query": "foo"}}}
],
}
},
},
},
],
"rank_constant": 1,
"rank_window_size": 5,
}
},
knn={
"field": "vector_field",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
},
size=3,
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
)

assert [o["_source"]["text_field"] for o in output] == [
Expand Down

0 comments on commit e22de7e

Please sign in to comment.