Skip to content

Commit

Permalink
Merge pull request #1160 from weaviate/hybrid_vector_per_target
Browse files Browse the repository at this point in the history
Hybrid search vector per target
  • Loading branch information
dirkkul committed Jul 5, 2024
2 parents c60e2f6 + 0cc1b9a commit 6d6e80c
Show file tree
Hide file tree
Showing 5 changed files with 572 additions and 461 deletions.
7 changes: 4 additions & 3 deletions integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
_VectorIndexConfigCreate,
_RerankerConfigCreate,
)
from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate
from weaviate.collections.classes.types import Properties
from weaviate.config import AdditionalConfig

from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate


class CollectionFactory(Protocol):
"""Typing for fixture."""
Expand Down Expand Up @@ -77,7 +76,9 @@ def _factory(
reranker_config: Optional[_RerankerConfigCreate] = None,
) -> Collection[Any, Any]:
nonlocal client_fixture, name_fixture
name_fixture = _sanitize_collection_name(request.node.name) + name
name_fixture = (
_sanitize_collection_name(request.node.fspath.basename + "_" + request.node.name) + name
)
client_fixture = weaviate.connect_to_local(
headers=headers,
grpc_port=ports[1],
Expand Down
356 changes: 0 additions & 356 deletions integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pytest

import weaviate.classes as wvc
from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name
from integration.constants import WEAVIATE_LOGO_OLD_ENCODED, WEAVIATE_LOGO_NEW_ENCODED
from weaviate.collections.classes.batch import ErrorObject
Expand Down Expand Up @@ -541,144 +540,6 @@ def test_types(collection_factory: CollectionFactory, data_type: DataType, value
assert object_get_from_batch is not None and object_get_from_batch.properties[name] == value


@pytest.mark.parametrize("fusion_type", [HybridFusion.RANKED, HybridFusion.RELATIVE_SCORE])
def test_search_hybrid(collection_factory: CollectionFactory, fusion_type: HybridFusion) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4())
collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4())
objs = collection.query.hybrid(
alpha=0, query="name", fusion_type=fusion_type, include_vector=True
).objects
assert len(objs) == 1

objs = collection.query.hybrid(
alpha=1, query="name", fusion_type=fusion_type, vector=objs[0].vector["default"]
).objects
assert len(objs) == 2


def test_search_hybrid_group_by(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4())
collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4())
if collection._connection.supports_groupby_in_bm25_and_hybrid():
objs = collection.query.hybrid(
alpha=0,
query="name",
include_vector=True,
group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2),
).objects
assert len(objs) == 1
assert objs[0].belongs_to_group == "some name"
else:
with pytest.raises(WeaviateUnsupportedFeatureError):
collection.query.hybrid(
alpha=0,
query="name",
include_vector=True,
group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2),
)


@pytest.mark.parametrize("query", [None, ""])
def test_search_hybrid_only_vector(
collection_factory: CollectionFactory, query: Optional[str]
) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_ = collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4())
vec = collection.query.fetch_object_by_id(uuid_, include_vector=True).vector
assert vec is not None

collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4())

objs = collection.query.hybrid(alpha=1, query=query, vector=vec["default"]).objects
assert len(objs) == 2


@pytest.mark.parametrize("limit", [1, 2])
def test_hybrid_limit(collection_factory: CollectionFactory, limit: int) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)

res = collection.data.insert_many(
[
{"Name": "test"},
{"Name": "another"},
{"Name": "test"},
]
)
assert res.has_errors is False
assert len(collection.query.hybrid(query="test", alpha=0, limit=limit).objects) == limit


@pytest.mark.parametrize("offset,expected", [(0, 2), (1, 1), (2, 0)])
def test_hybrid_offset(collection_factory: CollectionFactory, offset: int, expected: int) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)

res = collection.data.insert_many(
[
{"Name": "test"},
{"Name": "another"},
{"Name": "test"},
]
)
assert res.has_errors is False

assert len(collection.query.hybrid(query="test", alpha=0, offset=offset).objects) == expected


def test_hybrid_alpha(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)

res = collection.data.insert_many(
[
{"name": "banana"},
{"name": "fruit"},
{"name": "car"},
]
)
assert res.has_errors is False

hybrid_res = collection.query.hybrid(query="fruit", alpha=0)
bm25_res = collection.query.bm25(query="fruit")
assert all(
bm25_res.objects[i].uuid == hybrid_res.objects[i].uuid
for i in range(len(hybrid_res.objects))
)

hybrid_res = collection.query.hybrid(query="fruit", alpha=1)
text_res = collection.query.near_text(query="fruit")
assert all(
text_res.objects[i].uuid == hybrid_res.objects[i].uuid
for i in range(len(hybrid_res.objects))
)


def test_bm25(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
Expand Down Expand Up @@ -1844,220 +1705,3 @@ def test_none_query_hybrid_bm25(collection_factory: CollectionFactory) -> None:
bm25_objs = collection.query.bm25(query=None, return_metadata=MetadataQuery.full()).objects
assert len(bm25_objs) == 3
assert all(obj.metadata.score is not None and obj.metadata.score == 0.0 for obj in bm25_objs)


def test_hybrid_near_vector_search(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[
Property(name="text", data_type=DataType.TEXT),
],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"text": "banana"})
obj = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
with pytest.raises(WeaviateUnsupportedFeatureError):
collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_vector(vector=obj.vector["default"]),
).objects
return

collection.data.insert({"text": "dog"})
collection.data.insert({"text": "different concept"})

hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_vector(vector=obj.vector["default"]),
).objects

assert hybrid_objs[0].uuid == uuid_banana
assert len(hybrid_objs) == 3

# make a near vector search to get the distance
near_vec = collection.query.near_vector(
near_vector=obj.vector["default"], return_metadata=["distance"]
).objects
assert near_vec[0].metadata.distance is not None

hybrid_objs2 = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_vector(
vector=obj.vector["default"], distance=near_vec[0].metadata.distance + 0.001
),
return_metadata=MetadataQuery.full(),
).objects

assert hybrid_objs2[0].uuid == uuid_banana
assert len(hybrid_objs2) == 1


def test_hybrid_near_vector_search_named_vectors(collection_factory: CollectionFactory) -> None:
dummy = collection_factory("dummy")
collection_maker = lambda: collection_factory(
properties=[
Property(name="text", data_type=DataType.TEXT),
Property(name="int", data_type=DataType.INT),
],
vectorizer_config=[
Configure.NamedVectors.text2vec_contextionary(
name="text", vectorize_collection_name=False
),
Configure.NamedVectors.text2vec_contextionary(
name="int", vectorize_collection_name=False
),
],
)

if dummy._connection._weaviate_version.is_lower_than(1, 24, 0):
with pytest.raises(WeaviateInvalidInputError):
collection_maker()
return

collection = collection_maker()
uuid_banana = collection.data.insert({"text": "banana"})
collection.data.insert({"text": "dog"})
collection.data.insert({"text": "different concept"})

obj = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
with pytest.raises(WeaviateUnsupportedFeatureError):
hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_vector(vector=obj.vector["text"]),
target_vector="text",
).objects
return

hybrid_objs = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_vector(vector=obj.vector["text"]),
target_vector="text",
).objects

assert hybrid_objs[0].uuid == uuid_banana
assert len(hybrid_objs) == 3

# make a near vector search to get the distance
near_vec = collection.query.near_vector(
near_vector=obj.vector["text"], return_metadata=["distance"], target_vector="text"
).objects
assert near_vec[0].metadata.distance is not None

hybrid_objs2 = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_vector(
vector=obj.vector["text"],
distance=near_vec[0].metadata.distance + 0.001,
),
target_vector="text",
return_metadata=MetadataQuery.full(),
).objects

assert hybrid_objs2[0].uuid == uuid_banana
assert len(hybrid_objs2) == 1


def test_hybrid_near_text_search(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[
Property(name="text", data_type=DataType.TEXT),
],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)

if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
with pytest.raises(WeaviateUnsupportedFeatureError):
collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_text(query="banana pudding"),
).objects
return

uuid_banana_pudding = collection.data.insert({"text": "banana pudding"})
collection.data.insert({"text": "banana smoothie"})
collection.data.insert({"text": "different concept"})

hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_text(query="banana pudding"),
).objects

assert hybrid_objs[0].uuid == uuid_banana_pudding
assert len(hybrid_objs) == 3

hybrid_objs2 = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_text(
query="banana",
move_to=wvc.query.Move(concepts="pudding", force=0.1),
move_away=wvc.query.Move(concepts="smoothie", force=0.1),
),
return_metadata=MetadataQuery.full(),
).objects

assert hybrid_objs2[0].uuid == uuid_banana_pudding


def test_hybrid_near_text_search_named_vectors(collection_factory: CollectionFactory) -> None:
dummy = collection_factory("dummy")
collection_maker = lambda: collection_factory(
properties=[
Property(name="text", data_type=DataType.TEXT),
Property(name="int", data_type=DataType.INT),
],
vectorizer_config=[
Configure.NamedVectors.text2vec_contextionary(
name="text", vectorize_collection_name=False
),
Configure.NamedVectors.text2vec_contextionary(
name="int", vectorize_collection_name=False
),
],
)
if dummy._connection._weaviate_version.is_lower_than(1, 24, 0):
with pytest.raises(WeaviateInvalidInputError):
collection_maker()
return

collection = collection_maker()
uuid_banana_pudding = collection.data.insert({"text": "banana pudding"})
collection.data.insert({"text": "banana smoothie"})
collection.data.insert({"text": "different concept"})

if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
with pytest.raises(WeaviateUnsupportedFeatureError):
hybrid_objs: List[Object[Any, Any]] = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_text(query="banana pudding"),
target_vector="text",
).objects
return

hybrid_objs = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_text(query="banana pudding"),
target_vector="text",
).objects

assert hybrid_objs[0].uuid == uuid_banana_pudding
assert len(hybrid_objs) == 3

hybrid_objs2 = collection.query.hybrid(
query=None,
vector=wvc.query.HybridVector.near_text(
query="banana",
move_to=wvc.query.Move(concepts="pudding", force=0.1),
move_away=wvc.query.Move(concepts="smoothie", force=0.1),
),
target_vector="text",
return_metadata=MetadataQuery.full(),
).objects

assert hybrid_objs2[0].uuid == uuid_banana_pudding
Loading

0 comments on commit 6d6e80c

Please sign in to comment.