Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid search vector per target #1160

Merged
merged 6 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
_VectorIndexConfigCreate,
_RerankerConfigCreate,
)
from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate
from weaviate.collections.classes.types import Properties
from weaviate.config import AdditionalConfig

from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate


class CollectionFactory(Protocol):
"""Typing for fixture."""
Expand Down Expand Up @@ -77,7 +76,9 @@ def _factory(
reranker_config: Optional[_RerankerConfigCreate] = None,
) -> Collection[Any, Any]:
nonlocal client_fixture, name_fixture
name_fixture = _sanitize_collection_name(request.node.name) + name
name_fixture = (
_sanitize_collection_name(request.node.fspath.basename + "_" + request.node.name) + name
)
client_fixture = weaviate.connect_to_local(
headers=headers,
grpc_port=ports[1],
Expand Down
356 changes: 0 additions & 356 deletions integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pytest

import weaviate.classes as wvc
from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name
from integration.constants import WEAVIATE_LOGO_OLD_ENCODED, WEAVIATE_LOGO_NEW_ENCODED
from weaviate.collections.classes.batch import ErrorObject
Expand Down Expand Up @@ -541,144 +540,6 @@ def test_types(collection_factory: CollectionFactory, data_type: DataType, value
assert object_get_from_batch is not None and object_get_from_batch.properties[name] == value


@pytest.mark.parametrize("fusion_type", [HybridFusion.RANKED, HybridFusion.RELATIVE_SCORE])
def test_search_hybrid(collection_factory: CollectionFactory, fusion_type: HybridFusion) -> None:
    """Run a hybrid query at alpha=0 and alpha=1 for each fusion type.

    Two objects are inserted. The alpha=0 query for "name" is expected to
    return exactly one object; its "default" vector is then passed to an
    alpha=1 query, which is expected to return both objects.
    """
    vectorizer = Configure.Vectorizer.text2vec_contextionary(vectorize_collection_name=False)
    collection = collection_factory(
        properties=[Property(name="Name", data_type=DataType.TEXT)],
        vectorizer_config=vectorizer,
    )
    for text in ("some name", "other word"):
        collection.data.insert({"Name": text}, uuid=uuid.uuid4())

    first_pass = collection.query.hybrid(
        alpha=0, query="name", fusion_type=fusion_type, include_vector=True
    ).objects
    assert len(first_pass) == 1

    # Reuse the vector returned by the first query as the explicit search vector.
    seed_vector = first_pass[0].vector["default"]
    second_pass = collection.query.hybrid(
        alpha=1, query="name", fusion_type=fusion_type, vector=seed_vector
    ).objects
    assert len(second_pass) == 2


def test_search_hybrid_group_by(collection_factory: CollectionFactory) -> None:
    """Exercise hybrid search with a ``group_by`` clause.

    When the connection reports group-by support for bm25/hybrid, the query
    must return one object belonging to the group "some name"; otherwise the
    same query must raise ``WeaviateUnsupportedFeatureError``.
    """
    vectorizer = Configure.Vectorizer.text2vec_contextionary(vectorize_collection_name=False)
    collection = collection_factory(
        properties=[Property(name="Name", data_type=DataType.TEXT)],
        vectorizer_config=vectorizer,
    )
    for text in ("some name", "other word"):
        collection.data.insert({"Name": text}, uuid=uuid.uuid4())

    def grouped_hybrid():
        # Same request is issued on both the supported and unsupported path.
        return collection.query.hybrid(
            alpha=0,
            query="name",
            include_vector=True,
            group_by=GroupBy(prop="name", objects_per_group=1, number_of_groups=2),
        )

    if not collection._connection.supports_groupby_in_bm25_and_hybrid():
        with pytest.raises(WeaviateUnsupportedFeatureError):
            grouped_hybrid()
        return

    grouped = grouped_hybrid().objects
    assert len(grouped) == 1
    assert grouped[0].belongs_to_group == "some name"


@pytest.mark.parametrize("query", [None, ""])
def test_search_hybrid_only_vector(
    collection_factory: CollectionFactory, query: Optional[str]
) -> None:
    """Hybrid search with an explicit vector and a ``None`` or empty query.

    The vector of the first inserted object is fetched and used as the search
    vector with alpha=1; both inserted objects are expected in the result.
    """
    vectorizer = Configure.Vectorizer.text2vec_contextionary(vectorize_collection_name=False)
    collection = collection_factory(
        properties=[Property(name="Name", data_type=DataType.TEXT)],
        vectorizer_config=vectorizer,
    )

    first_uuid = collection.data.insert({"Name": "some name"}, uuid=uuid.uuid4())
    stored_vectors = collection.query.fetch_object_by_id(first_uuid, include_vector=True).vector
    assert stored_vectors is not None

    collection.data.insert({"Name": "other word"}, uuid=uuid.uuid4())

    result = collection.query.hybrid(alpha=1, query=query, vector=stored_vectors["default"])
    assert len(result.objects) == 2


@pytest.mark.parametrize("limit", [1, 2])
def test_hybrid_limit(collection_factory: CollectionFactory, limit: int) -> None:
    """The number of hybrid results must equal the requested ``limit``.

    Three objects are inserted, two of which have the name "test"; the query
    for "test" with alpha=0 is checked for limits 1 and 2.
    """
    collection = collection_factory(
        properties=[Property(name="Name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.none(),
    )

    rows = [{"Name": value} for value in ("test", "another", "test")]
    insert_result = collection.data.insert_many(rows)
    assert insert_result.has_errors is False

    returned = collection.query.hybrid(query="test", alpha=0, limit=limit).objects
    assert len(returned) == limit


@pytest.mark.parametrize("offset,expected", [(0, 2), (1, 1), (2, 0)])
def test_hybrid_offset(collection_factory: CollectionFactory, offset: int, expected: int) -> None:
    """The hybrid result count must shrink as ``offset`` grows.

    Three objects are inserted, two of which have the name "test"; for each
    offset the query for "test" with alpha=0 must return ``expected`` objects.
    """
    collection = collection_factory(
        properties=[Property(name="Name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.none(),
    )

    rows = [{"Name": value} for value in ("test", "another", "test")]
    insert_result = collection.data.insert_many(rows)
    assert insert_result.has_errors is False

    returned = collection.query.hybrid(query="test", alpha=0, offset=offset).objects
    assert len(returned) == expected


def test_hybrid_alpha(collection_factory: CollectionFactory) -> None:
    """Compare hybrid results with bm25 (alpha=0) and near_text (alpha=1).

    For each alpha, every object returned by the hybrid query must have the
    same uuid as the object at the same position in the reference query.
    The hybrid result is additionally required to be non-empty, because
    ``all()`` over an empty sequence is ``True`` and the comparison would
    otherwise pass without checking anything.
    """
    collection = collection_factory(
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
            vectorize_collection_name=False
        ),
    )

    res = collection.data.insert_many(
        [
            {"name": "banana"},
            {"name": "fruit"},
            {"name": "car"},
        ]
    )
    assert res.has_errors is False

    hybrid_res = collection.query.hybrid(query="fruit", alpha=0)
    bm25_res = collection.query.bm25(query="fruit")
    # Guard against a vacuous pass when the hybrid query returns nothing.
    assert len(hybrid_res.objects) > 0
    assert all(
        bm25_res.objects[i].uuid == hybrid_obj.uuid
        for i, hybrid_obj in enumerate(hybrid_res.objects)
    )

    hybrid_res = collection.query.hybrid(query="fruit", alpha=1)
    text_res = collection.query.near_text(query="fruit")
    # Same guard for the vector-only side of the comparison.
    assert len(hybrid_res.objects) > 0
    assert all(
        text_res.objects[i].uuid == hybrid_obj.uuid
        for i, hybrid_obj in enumerate(hybrid_res.objects)
    )


def test_bm25(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
Expand Down Expand Up @@ -1844,220 +1705,3 @@ def test_none_query_hybrid_bm25(collection_factory: CollectionFactory) -> None:
bm25_objs = collection.query.bm25(query=None, return_metadata=MetadataQuery.full()).objects
assert len(bm25_objs) == 3
assert all(obj.metadata.score is not None and obj.metadata.score == 0.0 for obj in bm25_objs)


def test_hybrid_near_vector_search(collection_factory: CollectionFactory) -> None:
    """Hybrid search driven by ``HybridVector.near_vector`` on the default vector.

    On servers older than 1.25.0 the call must raise
    ``WeaviateUnsupportedFeatureError``. Otherwise the object whose vector is
    used must rank first among three results, and restricting the distance to
    just above that object's own near-vector distance must leave only it.
    """
    vectorizer = Configure.Vectorizer.text2vec_contextionary(vectorize_collection_name=False)
    collection = collection_factory(
        properties=[Property(name="text", data_type=DataType.TEXT)],
        vectorizer_config=vectorizer,
    )
    uuid_banana = collection.data.insert({"text": "banana"})
    banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

    if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
        with pytest.raises(WeaviateUnsupportedFeatureError):
            collection.query.hybrid(
                query=None,
                vector=wvc.query.HybridVector.near_vector(vector=banana.vector["default"]),
            ).objects
        return

    for text in ("dog", "different concept"):
        collection.data.insert({"text": text})

    unrestricted: List[Object[Any, Any]] = collection.query.hybrid(
        query=None,
        vector=wvc.query.HybridVector.near_vector(vector=banana.vector["default"]),
    ).objects

    assert unrestricted[0].uuid == uuid_banana
    assert len(unrestricted) == 3

    # Plain near-vector search, used only to learn the closest object's distance.
    closest = collection.query.near_vector(
        near_vector=banana.vector["default"], return_metadata=["distance"]
    ).objects[0]
    assert closest.metadata.distance is not None

    distance_cap = closest.metadata.distance + 0.001
    restricted = collection.query.hybrid(
        query=None,
        vector=wvc.query.HybridVector.near_vector(
            vector=banana.vector["default"], distance=distance_cap
        ),
        return_metadata=MetadataQuery.full(),
    ).objects

    assert restricted[0].uuid == uuid_banana
    assert len(restricted) == 1


def test_hybrid_near_vector_search_named_vectors(collection_factory: CollectionFactory) -> None:
    """Hybrid ``near_vector`` search against the named vector "text".

    A throwaway collection is created first so the server version can be read.
    Below 1.24.0, creating the named-vector collection must raise
    ``WeaviateInvalidInputError``; below 1.25.0, the hybrid call must raise
    ``WeaviateUnsupportedFeatureError``. Otherwise the same ranking and
    distance checks as the default-vector test are applied with
    ``target_vector="text"``.
    """
    dummy = collection_factory("dummy")

    def make_collection():
        # Two named vectors, both produced by text2vec_contextionary.
        return collection_factory(
            properties=[
                Property(name="text", data_type=DataType.TEXT),
                Property(name="int", data_type=DataType.INT),
            ],
            vectorizer_config=[
                Configure.NamedVectors.text2vec_contextionary(
                    name="text", vectorize_collection_name=False
                ),
                Configure.NamedVectors.text2vec_contextionary(
                    name="int", vectorize_collection_name=False
                ),
            ],
        )

    if dummy._connection._weaviate_version.is_lower_than(1, 24, 0):
        with pytest.raises(WeaviateInvalidInputError):
            make_collection()
        return

    collection = make_collection()
    uuid_banana = collection.data.insert({"text": "banana"})
    for text in ("dog", "different concept"):
        collection.data.insert({"text": text})

    banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

    if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
        with pytest.raises(WeaviateUnsupportedFeatureError):
            collection.query.hybrid(
                query=None,
                vector=wvc.query.HybridVector.near_vector(vector=banana.vector["text"]),
                target_vector="text",
            ).objects
        return

    unrestricted: List[Object[Any, Any]] = collection.query.hybrid(
        query=None,
        vector=wvc.query.HybridVector.near_vector(vector=banana.vector["text"]),
        target_vector="text",
    ).objects

    assert unrestricted[0].uuid == uuid_banana
    assert len(unrestricted) == 3

    # Plain near-vector search, used only to learn the closest object's distance.
    closest = collection.query.near_vector(
        near_vector=banana.vector["text"], return_metadata=["distance"], target_vector="text"
    ).objects[0]
    assert closest.metadata.distance is not None

    distance_cap = closest.metadata.distance + 0.001
    restricted = collection.query.hybrid(
        query=None,
        vector=wvc.query.HybridVector.near_vector(
            vector=banana.vector["text"],
            distance=distance_cap,
        ),
        target_vector="text",
        return_metadata=MetadataQuery.full(),
    ).objects

    assert restricted[0].uuid == uuid_banana
    assert len(restricted) == 1


def test_hybrid_near_text_search(collection_factory: CollectionFactory) -> None:
    """Hybrid search driven by ``HybridVector.near_text`` on the default vector.

    On servers older than 1.25.0 the call must raise
    ``WeaviateUnsupportedFeatureError``. Otherwise "banana pudding" must rank
    first among three results, both for the plain query and for the query
    "banana" combined with ``move_to`` / ``move_away`` concepts.
    """
    vectorizer = Configure.Vectorizer.text2vec_contextionary(vectorize_collection_name=False)
    collection = collection_factory(
        properties=[Property(name="text", data_type=DataType.TEXT)],
        vectorizer_config=vectorizer,
    )

    if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
        with pytest.raises(WeaviateUnsupportedFeatureError):
            collection.query.hybrid(
                query=None,
                vector=wvc.query.HybridVector.near_text(query="banana pudding"),
            ).objects
        return

    uuid_banana_pudding = collection.data.insert({"text": "banana pudding"})
    for text in ("banana smoothie", "different concept"):
        collection.data.insert({"text": text})

    plain_hits: List[Object[Any, Any]] = collection.query.hybrid(
        query=None,
        vector=wvc.query.HybridVector.near_text(query="banana pudding"),
    ).objects

    assert plain_hits[0].uuid == uuid_banana_pudding
    assert len(plain_hits) == 3

    moved_search = wvc.query.HybridVector.near_text(
        query="banana",
        move_to=wvc.query.Move(concepts="pudding", force=0.1),
        move_away=wvc.query.Move(concepts="smoothie", force=0.1),
    )
    moved_hits = collection.query.hybrid(
        query=None,
        vector=moved_search,
        return_metadata=MetadataQuery.full(),
    ).objects

    assert moved_hits[0].uuid == uuid_banana_pudding


def test_hybrid_near_text_search_named_vectors(collection_factory: CollectionFactory) -> None:
    """Hybrid ``near_text`` search against the named vector "text".

    A throwaway collection is created first so the server version can be read.
    Below 1.24.0, creating the named-vector collection must raise
    ``WeaviateInvalidInputError``; below 1.25.0, the hybrid call must raise
    ``WeaviateUnsupportedFeatureError``. Otherwise "banana pudding" must rank
    first, with and without ``move_to`` / ``move_away`` concepts.
    """
    dummy = collection_factory("dummy")

    def make_collection():
        # Two named vectors, both produced by text2vec_contextionary.
        return collection_factory(
            properties=[
                Property(name="text", data_type=DataType.TEXT),
                Property(name="int", data_type=DataType.INT),
            ],
            vectorizer_config=[
                Configure.NamedVectors.text2vec_contextionary(
                    name="text", vectorize_collection_name=False
                ),
                Configure.NamedVectors.text2vec_contextionary(
                    name="int", vectorize_collection_name=False
                ),
            ],
        )

    if dummy._connection._weaviate_version.is_lower_than(1, 24, 0):
        with pytest.raises(WeaviateInvalidInputError):
            make_collection()
        return

    collection = make_collection()
    uuid_banana_pudding = collection.data.insert({"text": "banana pudding"})
    for text in ("banana smoothie", "different concept"):
        collection.data.insert({"text": text})

    if collection._connection._weaviate_version.is_lower_than(1, 25, 0):
        with pytest.raises(WeaviateUnsupportedFeatureError):
            collection.query.hybrid(
                query=None,
                vector=wvc.query.HybridVector.near_text(query="banana pudding"),
                target_vector="text",
            ).objects
        return

    plain_hits: List[Object[Any, Any]] = collection.query.hybrid(
        query=None,
        vector=wvc.query.HybridVector.near_text(query="banana pudding"),
        target_vector="text",
    ).objects

    assert plain_hits[0].uuid == uuid_banana_pudding
    assert len(plain_hits) == 3

    moved_search = wvc.query.HybridVector.near_text(
        query="banana",
        move_to=wvc.query.Move(concepts="pudding", force=0.1),
        move_away=wvc.query.Move(concepts="smoothie", force=0.1),
    )
    moved_hits = collection.query.hybrid(
        query=None,
        vector=moved_search,
        target_vector="text",
        return_metadata=MetadataQuery.full(),
    ).objects

    assert moved_hits[0].uuid == uuid_banana_pudding
Loading