From 58126ed79ee73bcdeb734fd00e2d486ef3b416ef Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 4 Jul 2024 12:33:47 +0200 Subject: [PATCH 1/6] Add support for vector per target for hybrid --- integration/test_collection.py | 356 --------------------- integration/test_collection_hybrid.py | 425 ++++++++++++++++++++++++++ weaviate/collections/classes/grpc.py | 12 +- weaviate/collections/grpc/query.py | 198 ++++++------ 4 files changed, 531 insertions(+), 460 deletions(-) create mode 100644 integration/test_collection_hybrid.py diff --git a/integration/test_collection.py b/integration/test_collection.py index 5f705a302..bb1a2bd68 100644 --- a/integration/test_collection.py +++ b/integration/test_collection.py @@ -7,7 +7,6 @@ import pytest -import weaviate.classes as wvc from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name from integration.constants import WEAVIATE_LOGO_OLD_ENCODED, WEAVIATE_LOGO_NEW_ENCODED from weaviate.collections.classes.batch import ErrorObject @@ -541,144 +540,6 @@ def test_types(collection_factory: CollectionFactory, data_type: DataType, value assert object_get_from_batch is not None and object_get_from_batch.properties[name] == value -@pytest.mark.parametrize("fusion_type", [HybridFusion.RANKED, HybridFusion.RELATIVE_SCORE]) -def test_search_hybrid(collection_factory: CollectionFactory, fusion_type: HybridFusion) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4()) - collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4()) - objs = collection.query.hybrid( - alpha=0, query="name", fusion_type=fusion_type, include_vector=True - ).objects - assert len(objs) == 1 - - objs = collection.query.hybrid( - alpha=1, query="name", fusion_type=fusion_type, 
vector=objs[0].vector["default"] - ).objects - assert len(objs) == 2 - - -def test_search_hybrid_group_by(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4()) - collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4()) - if collection._connection.supports_groupby_in_bm25_and_hybrid(): - objs = collection.query.hybrid( - alpha=0, - query="name", - include_vector=True, - group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2), - ).objects - assert len(objs) == 1 - assert objs[0].belongs_to_group == "some name" - else: - with pytest.raises(WeaviateUnsupportedFeatureError): - collection.query.hybrid( - alpha=0, - query="name", - include_vector=True, - group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2), - ) - - -@pytest.mark.parametrize("query", [None, ""]) -def test_search_hybrid_only_vector( - collection_factory: CollectionFactory, query: Optional[str] -) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_ = collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4()) - vec = collection.query.fetch_object_by_id(uuid_, include_vector=True).vector - assert vec is not None - - collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4()) - - objs = collection.query.hybrid(alpha=1, query=query, vector=vec["default"]).objects - assert len(objs) == 2 - - -@pytest.mark.parametrize("limit", [1, 2]) -def test_hybrid_limit(collection_factory: CollectionFactory, limit: int) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - 
vectorizer_config=Configure.Vectorizer.none(), - ) - - res = collection.data.insert_many( - [ - {"Name": "test"}, - {"Name": "another"}, - {"Name": "test"}, - ] - ) - assert res.has_errors is False - assert len(collection.query.hybrid(query="test", alpha=0, limit=limit).objects) == limit - - -@pytest.mark.parametrize("offset,expected", [(0, 2), (1, 1), (2, 0)]) -def test_hybrid_offset(collection_factory: CollectionFactory, offset: int, expected: int) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.none(), - ) - - res = collection.data.insert_many( - [ - {"Name": "test"}, - {"Name": "another"}, - {"Name": "test"}, - ] - ) - assert res.has_errors is False - - assert len(collection.query.hybrid(query="test", alpha=0, offset=offset).objects) == expected - - -def test_hybrid_alpha(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - - res = collection.data.insert_many( - [ - {"name": "banana"}, - {"name": "fruit"}, - {"name": "car"}, - ] - ) - assert res.has_errors is False - - hybrid_res = collection.query.hybrid(query="fruit", alpha=0) - bm25_res = collection.query.bm25(query="fruit") - assert all( - bm25_res.objects[i].uuid == hybrid_res.objects[i].uuid - for i in range(len(hybrid_res.objects)) - ) - - hybrid_res = collection.query.hybrid(query="fruit", alpha=1) - text_res = collection.query.near_text(query="fruit") - assert all( - text_res.objects[i].uuid == hybrid_res.objects[i].uuid - for i in range(len(hybrid_res.objects)) - ) - - def test_bm25(collection_factory: CollectionFactory) -> None: collection = collection_factory( properties=[Property(name="Name", data_type=DataType.TEXT)], @@ -1844,220 +1705,3 @@ def test_none_query_hybrid_bm25(collection_factory: 
CollectionFactory) -> None: bm25_objs = collection.query.bm25(query=None, return_metadata=MetadataQuery.full()).objects assert len(bm25_objs) == 3 assert all(obj.metadata.score is not None and obj.metadata.score == 0.0 for obj in bm25_objs) - - -def test_hybrid_near_vector_search(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[ - Property(name="text", data_type=DataType.TEXT), - ], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana = collection.data.insert({"text": "banana"}) - obj = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - if collection._connection._weaviate_version.is_lower_than(1, 25, 0): - with pytest.raises(WeaviateUnsupportedFeatureError): - collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_vector(vector=obj.vector["default"]), - ).objects - return - - collection.data.insert({"text": "dog"}) - collection.data.insert({"text": "different concept"}) - - hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_vector(vector=obj.vector["default"]), - ).objects - - assert hybrid_objs[0].uuid == uuid_banana - assert len(hybrid_objs) == 3 - - # make a near vector search to get the distance - near_vec = collection.query.near_vector( - near_vector=obj.vector["default"], return_metadata=["distance"] - ).objects - assert near_vec[0].metadata.distance is not None - - hybrid_objs2 = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_vector( - vector=obj.vector["default"], distance=near_vec[0].metadata.distance + 0.001 - ), - return_metadata=MetadataQuery.full(), - ).objects - - assert hybrid_objs2[0].uuid == uuid_banana - assert len(hybrid_objs2) == 1 - - -def test_hybrid_near_vector_search_named_vectors(collection_factory: CollectionFactory) -> None: - dummy = collection_factory("dummy") - collection_maker = lambda: 
collection_factory( - properties=[ - Property(name="text", data_type=DataType.TEXT), - Property(name="int", data_type=DataType.INT), - ], - vectorizer_config=[ - Configure.NamedVectors.text2vec_contextionary( - name="text", vectorize_collection_name=False - ), - Configure.NamedVectors.text2vec_contextionary( - name="int", vectorize_collection_name=False - ), - ], - ) - - if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): - with pytest.raises(WeaviateInvalidInputError): - collection_maker() - return - - collection = collection_maker() - uuid_banana = collection.data.insert({"text": "banana"}) - collection.data.insert({"text": "dog"}) - collection.data.insert({"text": "different concept"}) - - obj = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - if collection._connection._weaviate_version.is_lower_than(1, 25, 0): - with pytest.raises(WeaviateUnsupportedFeatureError): - hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_vector(vector=obj.vector["text"]), - target_vector="text", - ).objects - return - - hybrid_objs = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_vector(vector=obj.vector["text"]), - target_vector="text", - ).objects - - assert hybrid_objs[0].uuid == uuid_banana - assert len(hybrid_objs) == 3 - - # make a near vector search to get the distance - near_vec = collection.query.near_vector( - near_vector=obj.vector["text"], return_metadata=["distance"], target_vector="text" - ).objects - assert near_vec[0].metadata.distance is not None - - hybrid_objs2 = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_vector( - vector=obj.vector["text"], - distance=near_vec[0].metadata.distance + 0.001, - ), - target_vector="text", - return_metadata=MetadataQuery.full(), - ).objects - - assert hybrid_objs2[0].uuid == uuid_banana - assert len(hybrid_objs2) == 1 - - -def test_hybrid_near_text_search(collection_factory: 
CollectionFactory) -> None: - collection = collection_factory( - properties=[ - Property(name="text", data_type=DataType.TEXT), - ], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - - if collection._connection._weaviate_version.is_lower_than(1, 25, 0): - with pytest.raises(WeaviateUnsupportedFeatureError): - collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_text(query="banana pudding"), - ).objects - return - - uuid_banana_pudding = collection.data.insert({"text": "banana pudding"}) - collection.data.insert({"text": "banana smoothie"}) - collection.data.insert({"text": "different concept"}) - - hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_text(query="banana pudding"), - ).objects - - assert hybrid_objs[0].uuid == uuid_banana_pudding - assert len(hybrid_objs) == 3 - - hybrid_objs2 = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_text( - query="banana", - move_to=wvc.query.Move(concepts="pudding", force=0.1), - move_away=wvc.query.Move(concepts="smoothie", force=0.1), - ), - return_metadata=MetadataQuery.full(), - ).objects - - assert hybrid_objs2[0].uuid == uuid_banana_pudding - - -def test_hybrid_near_text_search_named_vectors(collection_factory: CollectionFactory) -> None: - dummy = collection_factory("dummy") - collection_maker = lambda: collection_factory( - properties=[ - Property(name="text", data_type=DataType.TEXT), - Property(name="int", data_type=DataType.INT), - ], - vectorizer_config=[ - Configure.NamedVectors.text2vec_contextionary( - name="text", vectorize_collection_name=False - ), - Configure.NamedVectors.text2vec_contextionary( - name="int", vectorize_collection_name=False - ), - ], - ) - if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): - with pytest.raises(WeaviateInvalidInputError): - collection_maker() - return - - collection = collection_maker() - 
uuid_banana_pudding = collection.data.insert({"text": "banana pudding"}) - collection.data.insert({"text": "banana smoothie"}) - collection.data.insert({"text": "different concept"}) - - if collection._connection._weaviate_version.is_lower_than(1, 25, 0): - with pytest.raises(WeaviateUnsupportedFeatureError): - hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_text(query="banana pudding"), - target_vector="text", - ).objects - return - - hybrid_objs = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_text(query="banana pudding"), - target_vector="text", - ).objects - - assert hybrid_objs[0].uuid == uuid_banana_pudding - assert len(hybrid_objs) == 3 - - hybrid_objs2 = collection.query.hybrid( - query=None, - vector=wvc.query.HybridVector.near_text( - query="banana", - move_to=wvc.query.Move(concepts="pudding", force=0.1), - move_away=wvc.query.Move(concepts="smoothie", force=0.1), - ), - target_vector="text", - return_metadata=MetadataQuery.full(), - ).objects - - assert hybrid_objs2[0].uuid == uuid_banana_pudding diff --git a/integration/test_collection_hybrid.py b/integration/test_collection_hybrid.py new file mode 100644 index 000000000..8f88936a4 --- /dev/null +++ b/integration/test_collection_hybrid.py @@ -0,0 +1,425 @@ +import uuid +from typing import Any, List, Optional + +import numpy as np +import pandas as pd +import polars as pl +import pytest + +import weaviate.classes as wvc +from integration.conftest import CollectionFactory +from weaviate.collections.classes.config import ( + Configure, + DataType, + Property, +) +from weaviate.collections.classes.grpc import ( + HybridFusion, + GroupBy, + MetadataQuery, + NearVectorInputType, +) +from weaviate.collections.classes.internal import Object +from weaviate.exceptions import ( + WeaviateInvalidInputError, + WeaviateUnsupportedFeatureError, +) + + +@pytest.mark.parametrize("fusion_type", [HybridFusion.RANKED, 
HybridFusion.RELATIVE_SCORE]) +def test_search_hybrid(collection_factory: CollectionFactory, fusion_type: HybridFusion) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4()) + collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4()) + objs = collection.query.hybrid( + alpha=0, query="name", fusion_type=fusion_type, include_vector=True + ).objects + assert len(objs) == 1 + + objs = collection.query.hybrid( + alpha=1, query="name", fusion_type=fusion_type, vector=objs[0].vector["default"] + ).objects + assert len(objs) == 2 + + +def test_search_hybrid_group_by(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4()) + collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4()) + if collection._connection.supports_groupby_in_bm25_and_hybrid(): + objs = collection.query.hybrid( + alpha=0, + query="name", + include_vector=True, + group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2), + ).objects + assert len(objs) == 1 + assert objs[0].belongs_to_group == "some name" + else: + with pytest.raises(WeaviateUnsupportedFeatureError): + collection.query.hybrid( + alpha=0, + query="name", + include_vector=True, + group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2), + ) + + +@pytest.mark.parametrize("query", [None, ""]) +def test_search_hybrid_only_vector( + collection_factory: CollectionFactory, query: Optional[str] +) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + 
vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_ = collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4()) + vec = collection.query.fetch_object_by_id(uuid_, include_vector=True).vector + assert vec is not None + + collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4()) + + objs = collection.query.hybrid(alpha=1, query=query, vector=vec["default"]).objects + assert len(objs) == 2 + + +@pytest.mark.parametrize("limit", [1, 2]) +def test_hybrid_limit(collection_factory: CollectionFactory, limit: int) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.none(), + ) + + res = collection.data.insert_many( + [ + {"Name": "test"}, + {"Name": "another"}, + {"Name": "test"}, + ] + ) + assert res.has_errors is False + assert len(collection.query.hybrid(query="test", alpha=0, limit=limit).objects) == limit + + +@pytest.mark.parametrize("offset,expected", [(0, 2), (1, 1), (2, 0)]) +def test_hybrid_offset(collection_factory: CollectionFactory, offset: int, expected: int) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.none(), + ) + + res = collection.data.insert_many( + [ + {"Name": "test"}, + {"Name": "another"}, + {"Name": "test"}, + ] + ) + assert res.has_errors is False + + assert len(collection.query.hybrid(query="test", alpha=0, offset=offset).objects) == expected + + +def test_hybrid_alpha(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + + res = collection.data.insert_many( + [ + {"name": "banana"}, + {"name": "fruit"}, + {"name": "car"}, + ] + ) + assert res.has_errors is False + + hybrid_res 
= collection.query.hybrid(query="fruit", alpha=0) + bm25_res = collection.query.bm25(query="fruit") + assert all( + bm25_res.objects[i].uuid == hybrid_res.objects[i].uuid + for i in range(len(hybrid_res.objects)) + ) + + hybrid_res = collection.query.hybrid(query="fruit", alpha=1) + text_res = collection.query.near_text(query="fruit") + assert all( + text_res.objects[i].uuid == hybrid_res.objects[i].uuid + for i in range(len(hybrid_res.objects)) + ) + + +def test_hybrid_near_vector_search(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[ + Property(name="text", data_type=DataType.TEXT), + ], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"text": "banana"}) + obj = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + if collection._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateUnsupportedFeatureError): + collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector(vector=obj.vector["default"]), + ).objects + return + + collection.data.insert({"text": "dog"}) + collection.data.insert({"text": "different concept"}) + + hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector(vector=obj.vector["default"]), + ).objects + + assert hybrid_objs[0].uuid == uuid_banana + assert len(hybrid_objs) == 3 + + # make a near vector search to get the distance + near_vec = collection.query.near_vector( + near_vector=obj.vector["default"], return_metadata=["distance"] + ).objects + assert near_vec[0].metadata.distance is not None + + hybrid_objs2 = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector( + vector=obj.vector["default"], distance=near_vec[0].metadata.distance + 0.001 + ), + return_metadata=MetadataQuery.full(), + ).objects + + assert hybrid_objs2[0].uuid == 
uuid_banana + assert len(hybrid_objs2) == 1 + + +def test_hybrid_near_vector_search_named_vectors(collection_factory: CollectionFactory) -> None: + dummy = collection_factory("dummy") + collection_maker = lambda: collection_factory( + properties=[ + Property(name="text", data_type=DataType.TEXT), + Property(name="int", data_type=DataType.INT), + ], + vectorizer_config=[ + Configure.NamedVectors.text2vec_contextionary( + name="text", vectorize_collection_name=False + ), + Configure.NamedVectors.text2vec_contextionary( + name="int", vectorize_collection_name=False + ), + ], + ) + + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + with pytest.raises(WeaviateInvalidInputError): + collection_maker() + return + + collection = collection_maker() + uuid_banana = collection.data.insert({"text": "banana"}) + collection.data.insert({"text": "dog"}) + collection.data.insert({"text": "different concept"}) + + obj = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + if collection._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateUnsupportedFeatureError): + hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector(vector=obj.vector["text"]), + target_vector="text", + ).objects + return + + hybrid_objs = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector(vector=obj.vector["text"]), + target_vector="text", + ).objects + + assert hybrid_objs[0].uuid == uuid_banana + assert len(hybrid_objs) == 3 + + # make a near vector search to get the distance + near_vec = collection.query.near_vector( + near_vector=obj.vector["text"], return_metadata=["distance"], target_vector="text" + ).objects + assert near_vec[0].metadata.distance is not None + + hybrid_objs2 = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector( + vector=obj.vector["text"], + distance=near_vec[0].metadata.distance + 0.001, + ), 
+ target_vector="text", + return_metadata=MetadataQuery.full(), + ).objects + + assert hybrid_objs2[0].uuid == uuid_banana + assert len(hybrid_objs2) == 1 + + +def test_hybrid_near_text_search(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[ + Property(name="text", data_type=DataType.TEXT), + ], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + + if collection._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateUnsupportedFeatureError): + collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_text(query="banana pudding"), + ).objects + return + + uuid_banana_pudding = collection.data.insert({"text": "banana pudding"}) + collection.data.insert({"text": "banana smoothie"}) + collection.data.insert({"text": "different concept"}) + + hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_text(query="banana pudding"), + ).objects + + assert hybrid_objs[0].uuid == uuid_banana_pudding + assert len(hybrid_objs) == 3 + + hybrid_objs2 = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_text( + query="banana", + move_to=wvc.query.Move(concepts="pudding", force=0.1), + move_away=wvc.query.Move(concepts="smoothie", force=0.1), + ), + return_metadata=MetadataQuery.full(), + ).objects + + assert hybrid_objs2[0].uuid == uuid_banana_pudding + + +def test_hybrid_near_text_search_named_vectors(collection_factory: CollectionFactory) -> None: + dummy = collection_factory("dummy") + collection_maker = lambda: collection_factory( + properties=[ + Property(name="text", data_type=DataType.TEXT), + Property(name="int", data_type=DataType.INT), + ], + vectorizer_config=[ + Configure.NamedVectors.text2vec_contextionary( + name="text", vectorize_collection_name=False + ), + Configure.NamedVectors.text2vec_contextionary( + name="int", 
vectorize_collection_name=False + ), + ], + ) + if dummy._connection._weaviate_version.is_lower_than(1, 24, 0): + with pytest.raises(WeaviateInvalidInputError): + collection_maker() + return + + collection = collection_maker() + uuid_banana_pudding = collection.data.insert({"text": "banana pudding"}) + collection.data.insert({"text": "banana smoothie"}) + collection.data.insert({"text": "different concept"}) + + if collection._connection._weaviate_version.is_lower_than(1, 25, 0): + with pytest.raises(WeaviateUnsupportedFeatureError): + hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_text(query="banana pudding"), + target_vector="text", + ).objects + return + + hybrid_objs = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_text(query="banana pudding"), + target_vector="text", + ).objects + + assert hybrid_objs[0].uuid == uuid_banana_pudding + assert len(hybrid_objs) == 3 + + hybrid_objs2 = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_text( + query="banana", + move_to=wvc.query.Move(concepts="pudding", force=0.1), + move_away=wvc.query.Move(concepts="smoothie", force=0.1), + ), + target_vector="text", + return_metadata=MetadataQuery.full(), + ).objects + + assert hybrid_objs2[0].uuid == uuid_banana_pudding + + +@pytest.mark.parametrize( + "vector", + [ + {"first": [1, 0], "second": [1, 0, 0]}, + {"first": [1, 0], "second": np.array([1, 0, 0])}, + {"first": [1, 0], "second": pl.Series([1, 0, 0])}, + {"first": [1, 0], "second": pd.Series([1, 0, 0])}, + [[1, 0], [1, 0, 0]], + [[1, 0], np.array([1, 0, 0])], + [[1, 0], pl.Series([1, 0, 0])], + [[1, 0], pd.Series([1, 0, 0])], + ], +) +def test_vector_per_target( + collection_factory: CollectionFactory, vector: NearVectorInputType +) -> None: + dummy = collection_factory("dummy") + if dummy._connection._weaviate_version.is_lower_than(1, 26, 0): + pytest.skip("No multi target search below 1.26") + + 
collection = collection_factory( + properties=[ + Property(name="text", data_type=DataType.TEXT), + ], + vectorizer_config=[ + Configure.NamedVectors.none("first"), + Configure.NamedVectors.none("second"), + ], + ) + + uuid1 = collection.data.insert( + {"text": "banana"}, vector={"first": [1, 0], "second": [1, 0, 0]} + ) + collection.data.insert({"text": "apple"}, vector={"first": [0, 1], "second": [0, 0, 1]}) + + objs = collection.query.hybrid( + query=None, + vector=wvc.query.HybridVector.near_vector(vector, distance=0.1), + target_vector=["first", "second"], + ).objects + assert len(objs) == 1 + assert objs[0].uuid == uuid1 diff --git a/weaviate/collections/classes/grpc.py b/weaviate/collections/classes/grpc.py index 9c1f7c7c9..1e562feed 100644 --- a/weaviate/collections/classes/grpc.py +++ b/weaviate/collections/classes/grpc.py @@ -227,6 +227,9 @@ class Rerank(_WeaviateInput): query: Optional[str] = Field(default=None) +NearVectorInputType = Union[List[float], Dict[str, List[float]], List[List[float]]] + + class _HybridNearBase(_WeaviateInput): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") @@ -241,10 +244,10 @@ class _HybridNearText(_HybridNearBase): class _HybridNearVector(_HybridNearBase): - vector: List[float] + vector: NearVectorInputType -HybridVectorType = Union[List[float], _HybridNearText, _HybridNearVector] +HybridVectorType = Union[NearVectorInputType, _HybridNearText, _HybridNearVector] class _MultiTargetVectorJoinEnum(BaseEnum): @@ -367,7 +370,7 @@ def near_text( @staticmethod def near_vector( - vector: List[float], + vector: NearVectorInputType, *, certainty: Optional[float] = None, distance: Optional[float] = None, @@ -386,9 +389,6 @@ def near_vector( return _HybridNearVector(vector=vector, distance=distance, certainty=certainty) -NearVectorInputType = Union[List[float], Dict[str, List[float]], List[List[float]]] - - class _QueryReference(_WeaviateInput): link_on: str include_vector: bool = Field(default=False) diff --git 
a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index 08314170d..e831c1881 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -189,7 +189,7 @@ def hybrid( _ValidateArgument([None, str], "query", query), _ValidateArgument([float, int, None], "alpha", alpha), _ValidateArgument( - [list, _HybridNearText, _HybridNearVector, None], "vector", vector + [List, _HybridNearText, _HybridNearVector, None], "vector", vector ), _ValidateArgument([List, None], "properties", properties), _ValidateArgument([HybridFusion, None], "fusion_type", fusion_type), @@ -205,6 +205,32 @@ def hybrid( targets, target_vector = self.__target_vector_to_grpc(target_vector) + near_text, near_vector, vector_bytes = None, None, None + + if vector is None: + pass + elif isinstance(vector, _HybridNearText): + near_text = search_get_pb2.NearTextSearch( + query=[vector.text] if isinstance(vector.text, str) else vector.text, + certainty=vector.certainty, + distance=vector.distance, + move_away=self.__parse_move(vector.move_away), + move_to=self.__parse_move(vector.move_to), + ) + elif isinstance(vector, _HybridNearVector): + vector_per_target, vector_bytes = self.__vector_per_target( + vector.vector, targets, "vector" + ) + near_vector = search_get_pb2.NearVector( + vector_bytes=vector_bytes, + certainty=vector.certainty, + distance=vector.distance, + vector_per_target=vector_per_target, + ) + else: + vector = _get_vector_v4(vector) + vector_bytes = struct.pack("{}f".format(len(vector)), *vector) + hybrid_search = ( search_get_pb2.Hybrid( properties=properties, @@ -220,31 +246,9 @@ def hybrid( ), target_vectors=target_vector, targets=targets, - vector_bytes=( - struct.pack("{}f".format(len(vector)), *vector) - if vector is not None and isinstance(vector, list) - else None - ), - near_text=( - search_get_pb2.NearTextSearch( - query=[vector.text] if isinstance(vector.text, str) else vector.text, - certainty=vector.certainty, - 
distance=vector.distance, - move_away=self.__parse_move(vector.move_away), - move_to=self.__parse_move(vector.move_to), - ) - if vector is not None and isinstance(vector, _HybridNearText) - else None - ), - near_vector=( - search_get_pb2.NearVector( - vector_bytes=struct.pack("{}f".format(len(vector.vector)), *vector.vector), - certainty=vector.certainty, - distance=vector.distance, - ) - if vector is not None and isinstance(vector, _HybridNearVector) - else None - ), + near_text=near_text, + near_vector=near_vector, + vector_bytes=vector_bytes, ) if query is not None or vector is not None else None @@ -351,78 +355,10 @@ def near_vector( certainty, distance = self.__parse_near_options(certainty, distance) targets, target_vectors = self.__target_vector_to_grpc(target_vector) - invalid_nv_exception = WeaviateInvalidInputError( - f"""near vector argument can be: - - a list of numbers - - a list of lists of numbers for multi target search - - a dictionary with target names as keys and lists of numbers as values - received: {near_vector}""" - ) - if isinstance(near_vector, dict): - if targets is None or len(targets.target_vectors) != len(near_vector): - raise WeaviateInvalidInputError( - "The number of target vectors must be equal to the number of vectors." 
- ) - - vector_per_target: Dict[str, bytes] = {} - for key, value in near_vector.items(): - nv = _get_vector_v4(value) - - if ( - not isinstance(nv, list) - or len(nv) == 0 - or not isinstance(nv[0], get_args(NUMBER)) - ): - raise invalid_nv_exception - - vector_per_target[key] = struct.pack("{}f".format(len(nv)), *nv) - near_vector_grpc = search_get_pb2.NearVector( - certainty=certainty, - distance=distance, - targets=targets, - target_vectors=target_vectors, - vector_per_target=vector_per_target, - ) - else: - if len(near_vector) == 0: - raise invalid_nv_exception - - if _is_1d_vector(near_vector): - near_vector = _get_vector_v4(near_vector) - if not isinstance(near_vector, list): - raise invalid_nv_exception - near_vector_grpc = search_get_pb2.NearVector( - certainty=certainty, - distance=distance, - vector_bytes=struct.pack("{}f".format(len(near_vector)), *near_vector), - targets=targets, - target_vectors=target_vectors, - ) - else: - vector_per_target_tmp: Dict[str, bytes] = {} - if targets is None or len(targets.target_vectors) != len(near_vector): - raise WeaviateInvalidInputError( - "The number of target vectors must be equal to the number of vectors." 
- ) - for i, vector in enumerate(near_vector): - nv = _get_vector_v4(vector) - if ( - not isinstance(nv, list) - or len(nv) == 0 - or not isinstance(nv[0], get_args(NUMBER)) - ): - raise invalid_nv_exception - vector_per_target_tmp[targets.target_vectors[i]] = struct.pack( - "{}f".format(len(nv)), *nv - ) - near_vector_grpc = search_get_pb2.NearVector( - certainty=certainty, - distance=distance, - targets=targets, - target_vectors=target_vectors, - vector_per_target=vector_per_target_tmp, - ) + vector_per_target_tmp, near_vector_grpc = self.__vector_per_target( + near_vector, targets, "near_vector" + ) request = self.__create_request( limit=limit, offset=offset, @@ -434,7 +370,14 @@ def near_vector( rerank=rerank, autocut=autocut, group_by=group_by, - near_vector=near_vector_grpc, + near_vector=search_get_pb2.NearVector( + certainty=certainty, + distance=distance, + targets=targets, + target_vectors=target_vectors, + vector_per_target=vector_per_target_tmp, + vector=near_vector_grpc, + ), ) return self.__call(request) @@ -902,3 +845,62 @@ def __target_vector_to_grpc( return search_get_pb2.Targets(target_vectors=target_vector), None else: return target_vector.to_grpc_target_vector(), None + + @staticmethod + def __vector_per_target( + vector: NearVectorInputType, targets: Optional[search_get_pb2.Targets], argument_name: str + ) -> Tuple[Optional[Dict[str, bytes]], Optional[bytes]]: + invalid_nv_exception = WeaviateInvalidInputError( + f"""{argument_name} argument can be: + - a list of numbers + - a list of lists of numbers for multi target search + - a dictionary with target names as keys and lists of numbers as values + received: {vector}""" + ) + if isinstance(vector, dict): + if targets is None or len(targets.target_vectors) != len(vector): + raise WeaviateInvalidInputError( + "The number of target vectors must be equal to the number of vectors." 
+ ) + + vector_per_target: Dict[str, bytes] = {} + for key, value in vector.items(): + nv = _get_vector_v4(value) + + if ( + not isinstance(nv, list) + or len(nv) == 0 + or not isinstance(nv[0], get_args(NUMBER)) + ): + raise invalid_nv_exception + + vector_per_target[key] = struct.pack("{}f".format(len(nv)), *nv) + + return vector_per_target, None + else: + if len(vector) == 0: + raise invalid_nv_exception + + if _is_1d_vector(vector): + near_vector = _get_vector_v4(vector) + if not isinstance(near_vector, list): + raise invalid_nv_exception + return None, struct.pack("{}f".format(len(near_vector)), *near_vector) + else: + vector_per_target = {} + if targets is None or len(targets.target_vectors) != len(vector): + raise WeaviateInvalidInputError( + "The number of target vectors must be equal to the number of vectors." + ) + for i, inner_vector in enumerate(vector): + nv = _get_vector_v4(inner_vector) + if ( + not isinstance(nv, list) + or len(nv) == 0 + or not isinstance(nv[0], get_args(NUMBER)) + ): + raise invalid_nv_exception + vector_per_target[targets.target_vectors[i]] = struct.pack( + "{}f".format(len(nv)), *nv + ) + return vector_per_target, None From 7f226c3af6b0db3500bfc57e155feeaa02469949 Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 4 Jul 2024 13:08:43 +0200 Subject: [PATCH 2/6] Fix argument for vector --- integration/conftest.py | 7 ++++--- weaviate/collections/grpc/query.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/integration/conftest.py b/integration/conftest.py index f41c7ebaf..54116928b 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -19,11 +19,10 @@ _VectorIndexConfigCreate, _RerankerConfigCreate, ) +from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate from weaviate.collections.classes.types import Properties from weaviate.config import AdditionalConfig -from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate - class 
CollectionFactory(Protocol): """Typing for fixture.""" @@ -77,7 +76,9 @@ def _factory( reranker_config: Optional[_RerankerConfigCreate] = None, ) -> Collection[Any, Any]: nonlocal client_fixture, name_fixture - name_fixture = _sanitize_collection_name(request.node.name) + name + name_fixture = ( + _sanitize_collection_name(request.node.fspath.basename + "_" + request.node.name) + name + ) client_fixture = weaviate.connect_to_local( headers=headers, grpc_port=ports[1], diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index e831c1881..2f8762b67 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -376,7 +376,7 @@ def near_vector( targets=targets, target_vectors=target_vectors, vector_per_target=vector_per_target_tmp, - vector=near_vector_grpc, + vector_bytes=near_vector_grpc, ), ) From cae600f95821bdbff8d03e87c7cbfaf3d50d9393 Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 4 Jul 2024 13:21:46 +0200 Subject: [PATCH 3/6] Fix hybrid near vector search --- weaviate/collections/grpc/query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index 2f8762b67..f2f91f521 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -218,11 +218,11 @@ def hybrid( move_to=self.__parse_move(vector.move_to), ) elif isinstance(vector, _HybridNearVector): - vector_per_target, vector_bytes = self.__vector_per_target( + vector_per_target, vector_bytes_tmp = self.__vector_per_target( vector.vector, targets, "vector" ) near_vector = search_get_pb2.NearVector( - vector_bytes=vector_bytes, + vector_bytes=vector_bytes_tmp, certainty=vector.certainty, distance=vector.distance, vector_per_target=vector_per_target, From ac8429ceb20345690be8b06ce4df92457a215899 Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 4 Jul 2024 13:53:56 +0200 Subject: [PATCH 4/6] Add support for vector per
target for vector argument in hybrid --- integration/test_collection_hybrid.py | 19 +++++++++++------- weaviate/collections/grpc/query.py | 28 ++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/integration/test_collection_hybrid.py b/integration/test_collection_hybrid.py index 8f88936a4..f30dce1f1 100644 --- a/integration/test_collection_hybrid.py +++ b/integration/test_collection_hybrid.py @@ -402,19 +402,24 @@ def test_vector_per_target( pytest.skip("No multi target search below 1.26") collection = collection_factory( - properties=[ - Property(name="text", data_type=DataType.TEXT), - ], + properties=[], vectorizer_config=[ Configure.NamedVectors.none("first"), Configure.NamedVectors.none("second"), ], ) - uuid1 = collection.data.insert( - {"text": "banana"}, vector={"first": [1, 0], "second": [1, 0, 0]} - ) - collection.data.insert({"text": "apple"}, vector={"first": [0, 1], "second": [0, 0, 1]}) + uuid1 = collection.data.insert({}, vector={"first": [1, 0], "second": [1, 0, 0]}) + uuid2 = collection.data.insert({}, vector={"first": [0, 1], "second": [0, 0, 1]}) + + objs = collection.query.hybrid( + query=None, + vector=vector, + target_vector=["first", "second"], + ).objects + assert len(objs) == 2 + assert objs[0].uuid == uuid1 + assert objs[1].uuid == uuid2 objs = collection.query.hybrid( query=None, diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index f2f91f521..121395946 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -189,7 +189,18 @@ def hybrid( _ValidateArgument([None, str], "query", query), _ValidateArgument([float, int, None], "alpha", alpha), _ValidateArgument( - [List, _HybridNearText, _HybridNearVector, None], "vector", vector + [ + List, + Dict, + _ExtraTypes.PANDAS, + _ExtraTypes.POLARS, + _ExtraTypes.NUMPY, + _ExtraTypes.TF, + _HybridNearText, + _HybridNearVector, + ], + "vector", + vector, ), _ValidateArgument([List, None], 
"properties", properties), _ValidateArgument([HybridFusion, None], "fusion_type", fusion_type), @@ -209,6 +220,9 @@ def hybrid( if vector is None: pass + elif isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], float): + # fast path for simple vector + vector_bytes = struct.pack("{}f".format(len(vector)), *vector) elif isinstance(vector, _HybridNearText): near_text = search_get_pb2.NearTextSearch( query=[vector.text] if isinstance(vector.text, str) else vector.text, @@ -228,8 +242,16 @@ def hybrid( vector_per_target=vector_per_target, ) else: - vector = _get_vector_v4(vector) - vector_bytes = struct.pack("{}f".format(len(vector)), *vector) + vector_per_target, vector_bytes_tmp = self.__vector_per_target( + vector, targets, "vector" + ) + if vector_per_target is not None: + near_vector = search_get_pb2.NearVector( + vector_bytes=vector_bytes_tmp, + vector_per_target=vector_per_target, + ) + else: + vector_bytes = vector_bytes_tmp hybrid_search = ( search_get_pb2.Hybrid( From 8d95d60117954e43cb9b4b018543b89a6ae28a8d Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 4 Jul 2024 14:02:10 +0200 Subject: [PATCH 5/6] Add fast path for near vector search with list --- weaviate/collections/grpc/query.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index 121395946..797c7b119 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -378,9 +378,20 @@ def near_vector( targets, target_vectors = self.__target_vector_to_grpc(target_vector) - vector_per_target_tmp, near_vector_grpc = self.__vector_per_target( - near_vector, targets, "near_vector" - ) + if ( + isinstance(near_vector, list) + and len(near_vector) > 0 + and isinstance(near_vector[0], float) + ): + # fast path for simple vector + near_vector_grpc: Optional[bytes] = struct.pack( + "{}f".format(len(near_vector)), *near_vector + ) + vector_per_target_tmp 
= None + else: + vector_per_target_tmp, near_vector_grpc = self.__vector_per_target( + near_vector, targets, "near_vector" + ) request = self.__create_request( limit=limit, offset=offset, From 0cc1b9a8b48d8a5e4ee03ed088947904ae25d9c7 Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 4 Jul 2024 14:04:39 +0200 Subject: [PATCH 6/6] Allow none as vector input again --- weaviate/collections/grpc/query.py | 1 + 1 file changed, 1 insertion(+) diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index 797c7b119..1e80417ee 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -198,6 +198,7 @@ def hybrid( _ExtraTypes.TF, _HybridNearText, _HybridNearVector, + None, ], "vector", vector,