Skip to content

Commit

Permalink
Merge pull request #1256 from weaviate/hybrid_vector_distance
Browse files Browse the repository at this point in the history
Add support for hybrid search with vector distance
  • Loading branch information
dirkkul authored Aug 27, 2024
2 parents 8cf04ec + 0faacca commit 0051a20
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 59 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ env:
WEAVIATE_123: 1.23.16
WEAVIATE_124: 1.24.21
WEAVIATE_125: 1.25.8
WEAVIATE_126: 1.26.1
WEAVIATE_126: preview-increase-version-number-8b44fe6


jobs:
Expand Down
47 changes: 47 additions & 0 deletions integration/test_collection_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,50 @@ def test_vector_per_target(
).objects
assert len(objs) == 1
assert objs[0].uuid == uuid1


def test_vector_distance(collection_factory: CollectionFactory) -> None:
    """Check that ``max_vector_distance`` restricts the results of a hybrid query.

    Three objects with mutually orthogonal vectors are inserted. An unrestricted
    hybrid query is expected to return all three, with the object whose vector
    equals the query vector ranked first. With ``max_vector_distance=0.1`` only
    that object is expected to be returned.

    The test is skipped on Weaviate servers older than 1.26.3.
    """
    collection = collection_factory(
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
            vectorize_collection_name=False
        ),
    )

    # The server version is read from the collection's connection, so the
    # version gate can only run after the collection has been created.
    if collection._connection._weaviate_version.is_lower_than(1, 26, 3):
        pytest.skip("Hybrid max vector distance is only supported in versions 1.26.3 and higher")

    # Only the first object matches the query vector used below.
    uuid1 = collection.data.insert({}, vector=[1, 0, 0])
    collection.data.insert({}, vector=[0, 1, 0])
    collection.data.insert({}, vector=[0, 0, 1])

    # Baseline: without a distance limit every object is returned.
    objs = collection.query.hybrid("name", vector=[1, 0, 0])
    assert len(objs.objects) == 3
    assert objs.objects[0].uuid == uuid1

    # With a tight distance limit only the exact vector match should remain.
    objs = collection.query.hybrid("name", vector=[1, 0, 0], max_vector_distance=0.1)
    assert len(objs.objects) == 1
    assert objs.objects[0].uuid == uuid1


def test_aggregate_max_vector_distance(collection_factory: CollectionFactory) -> None:
    """Check that ``max_vector_distance`` is honoured by hybrid aggregation.

    Four objects are inserted: two carry the vector used in the query and two
    carry a vector orthogonal to it. Aggregating with ``max_vector_distance=0.5``
    is expected to yield a total count of 2.

    The test is skipped on Weaviate servers older than 1.26.3.
    """
    collection = collection_factory(
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.none(),
    )

    # The server version is read from the collection's connection, so the
    # version gate can only run after the collection has been created.
    if collection._connection._weaviate_version.is_lower_than(1, 26, 3):
        pytest.skip("Hybrid max vector distance is only supported in versions 1.26.3 and higher")

    # Every name contains the keyword "banana"; only "one" and "four" share
    # the vector that is used in the query below.
    collection.data.insert({"name": "banana one"}, vector=[1, 0, 0, 0])
    collection.data.insert({"name": "banana two"}, vector=[0, 1, 0, 0])
    collection.data.insert({"name": "banana three"}, vector=[0, 1, 0, 0])
    collection.data.insert({"name": "banana four"}, vector=[1, 0, 0, 0])

    res = collection.aggregate.hybrid(
        "banana",
        vector=[1, 0, 0, 0],
        max_vector_distance=0.5,
        return_metrics=[wvc.aggregate.Metrics("name").text(count=True)],
    )
    # Presumably the two counted objects are the ones matching the query vector.
    assert res.total_count == 2
10 changes: 6 additions & 4 deletions weaviate/collections/aggregations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import json
import os
import pathlib

from typing import List, Optional, TypeVar, Union, cast
from typing_extensions import ParamSpec

from httpx import ConnectError
from typing_extensions import ParamSpec

from weaviate.collections.classes.aggregate import (
AProperties,
Expand Down Expand Up @@ -34,13 +33,13 @@
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.filters import _Filters
from weaviate.collections.classes.grpc import Move
from weaviate.connect import ConnectionV4
from weaviate.collections.filters import _FilterToREST
from weaviate.connect import ConnectionV4
from weaviate.exceptions import WeaviateInvalidInputError, WeaviateQueryError
from weaviate.gql.aggregate import AggregateBuilder
from weaviate.types import NUMBER, UUID
from weaviate.util import file_encoder_b64, _decode_json_response_dict
from weaviate.validator import _ValidateArgument, _validate_input
from weaviate.types import NUMBER, UUID

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -255,6 +254,7 @@ def _add_hybrid_to_builder(
query_properties: Optional[List[str]],
object_limit: Optional[int],
target_vector: Optional[str],
max_vector_distance: Optional[NUMBER],
) -> AggregateBuilder:
payload: dict = {}
if query is not None:
Expand All @@ -267,6 +267,8 @@ def _add_hybrid_to_builder(
payload["properties"] = query_properties
if target_vector is not None:
payload["targetVectors"] = [target_vector]
if max_vector_distance is not None:
payload["maxVectorDistance"] = max_vector_distance
builder = builder.with_hybrid(payload)
if object_limit is not None:
builder = builder.with_object_limit(object_limit)
Expand Down
10 changes: 9 additions & 1 deletion weaviate/collections/aggregations/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def hybrid(
filters: Optional[_Filters] = None,
group_by: Optional[Union[str, GroupByAggregate]] = None,
target_vector: Optional[str] = None,
max_vector_distance: Optional[NUMBER] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> Union[AggregateReturn, AggregateGroupByReturn]:
Expand Down Expand Up @@ -70,7 +71,14 @@ async def hybrid(
)
builder = self._base(return_metrics, filters, total_count)
builder = self._add_hybrid_to_builder(
builder, query, alpha, vector, query_properties, object_limit, target_vector
builder,
query,
alpha,
vector,
query_properties,
object_limit,
target_vector,
max_vector_distance,
)
builder = self._add_groupby_to_builder(builder, group_by)
res = await self._do(builder)
Expand Down
6 changes: 6 additions & 0 deletions weaviate/collections/aggregations/hybrid.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class _HybridAsync(_AggregateAsync):
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
target_vector: Optional[str] = None,
max_vector_distance: Optional[float] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> AggregateReturn: ...
Expand All @@ -38,6 +39,7 @@ class _HybridAsync(_AggregateAsync):
filters: Optional[_Filters] = None,
group_by: Union[str, GroupByAggregate],
target_vector: Optional[str] = None,
max_vector_distance: Optional[float] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> AggregateGroupByReturn: ...
Expand All @@ -53,6 +55,7 @@ class _HybridAsync(_AggregateAsync):
filters: Optional[_Filters] = None,
group_by: Optional[Union[str, GroupByAggregate]] = None,
target_vector: Optional[str] = None,
max_vector_distance: Optional[float] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> Union[AggregateReturn, AggregateGroupByReturn]: ...
Expand All @@ -70,6 +73,7 @@ class _Hybrid(_AggregateAsync):
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
target_vector: Optional[str] = None,
max_vector_distance: Optional[float] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> AggregateReturn: ...
Expand All @@ -85,6 +89,7 @@ class _Hybrid(_AggregateAsync):
filters: Optional[_Filters] = None,
group_by: Union[str, GroupByAggregate],
target_vector: Optional[str] = None,
max_vector_distance: Optional[float] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> AggregateGroupByReturn: ...
Expand All @@ -100,6 +105,7 @@ class _Hybrid(_AggregateAsync):
filters: Optional[_Filters] = None,
group_by: Optional[Union[str, GroupByAggregate]] = None,
target_vector: Optional[str] = None,
max_vector_distance: Optional[float] = None,
total_count: bool = True,
return_metrics: Optional[PropertiesMetrics] = None,
) -> Union[AggregateReturn, AggregateGroupByReturn]: ...
5 changes: 3 additions & 2 deletions weaviate/collections/grpc/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
get_args,
)

from typing_extensions import TypeAlias

from grpc.aio import AioRpcError # type: ignore
from typing_extensions import TypeAlias

from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.filters import _Filters
Expand Down Expand Up @@ -165,6 +164,7 @@ def hybrid(
vector: Optional[HybridVectorType] = None,
properties: Optional[List[str]] = None,
fusion_type: Optional[HybridFusion] = None,
distance: Optional[NUMBER] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
autocut: Optional[int] = None,
Expand Down Expand Up @@ -274,6 +274,7 @@ def hybrid(
near_text=near_text,
near_vector=near_vector,
vector_bytes=vector_bytes,
vector_distance=distance,
)
if query is not None or vector is not None
else None
Expand Down
2 changes: 2 additions & 0 deletions weaviate/collections/queries/hybrid/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def hybrid(
vector: Optional[HybridVectorType] = None,
query_properties: Optional[List[str]] = None,
fusion_type: Optional[HybridFusion] = None,
max_vector_distance: Optional[NUMBER] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
auto_limit: Optional[int] = None,
Expand Down Expand Up @@ -121,6 +122,7 @@ async def hybrid(
fusion_type=fusion_type,
limit=limit,
offset=offset,
distance=max_vector_distance,
autocut=auto_limit,
filters=filters,
group_by=_GroupBy.from_input(group_by),
Expand Down
Loading

0 comments on commit 0051a20

Please sign in to comment.