Skip to content

Commit

Permalink
Merge branch 'main' into fix_tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkkul authored Aug 20, 2024
2 parents 85a602a + 8c3f915 commit c170177
Show file tree
Hide file tree
Showing 12 changed files with 813 additions and 30 deletions.
28 changes: 14 additions & 14 deletions integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def test_authentication_client_credentials(
@pytest.mark.parametrize(
"name,user,env_variable_name,port,scope,warning",
[
(
"WCS",
"ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
"WCS_DUMMY_CI_PW",
WCS_PORT,
None,
False,
),
# ( # WCS keycloak times out too often
# "WCS",
# "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
# "WCS_DUMMY_CI_PW",
# WCS_PORT,
# None,
# False,
# ),
(
"okta",
"test@test.de",
Expand Down Expand Up @@ -168,12 +168,12 @@ def _get_access_token(url: str, user: str, pw: str) -> Dict[str, str]:
@pytest.mark.parametrize(
"name,user,env_variable_name,port",
[
(
"WCS",
"ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
"WCS_DUMMY_CI_PW",
WCS_PORT,
),
# ( # WCS keycloak times out too often
# "WCS",
# "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
# "WCS_DUMMY_CI_PW",
# WCS_PORT,
# ),
(
"okta",
"test@test.de",
Expand Down
99 changes: 99 additions & 0 deletions integration/test_collection_async.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import datetime
import uuid
from typing import Iterable

import pytest

import weaviate.classes as wvc
from weaviate.collections.classes.config import DataType, Property
from weaviate.collections.classes.data import DataObject
from weaviate.types import UUID

from .conftest import AsyncCollectionFactory, AsyncOpenAICollectionFactory

# UUID1 is a fixed identifier; UUID2 and UUID3 are random and regenerated on every import.
UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
UUID2 = uuid.uuid4()
UUID3 = uuid.uuid4()

# 2012-02-09 parsed from a date-only string, then tagged as UTC so the value is timezone-aware.
DATE1 = datetime.datetime.strptime("2012-02-09", "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)

Expand All @@ -32,6 +37,51 @@ async def test_fetch_objects(async_collection_factory: AsyncCollectionFactory) -
assert res.objects[0].properties["name"] == "John Doe"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ids, expected_len, expected",
    [
        ([], 0, set()),
        ((), 0, set()),
        ([UUID3], 1, {UUID3}),
        ([UUID1, UUID2], 2, {UUID1, UUID2}),
        ((UUID1, UUID3), 2, {UUID1, UUID3}),
        # A repeated ID is expected to yield one object per distinct UUID.
        ((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
    ],
)
async def test_fetch_objects_by_ids(
    async_collection_factory: AsyncCollectionFactory,
    ids: Iterable[UUID],
    expected_len: int,
    expected: set,
) -> None:
    """Fetch objects by an iterable of UUIDs through the async query API.

    Three objects with known UUIDs are inserted, then the parametrized
    ``ids`` (lists and tuples; empty, single, multiple, and with a repeat)
    are looked up and the returned UUIDs are compared against ``expected``.
    """
    name_property = Property(name="name", data_type=DataType.TEXT)
    collection = await async_collection_factory(
        properties=[name_property],
        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
    )

    # (name, uuid) pairs for the objects every parametrized case queries against.
    seed_rows = (
        ("first", UUID1),
        ("second", UUID2),
        ("third", UUID3),
    )
    seed_objects = [
        DataObject(properties={"name": label}, uuid=identifier)
        for label, identifier in seed_rows
    ]
    await collection.data.insert_many(seed_objects)

    response = await collection.query.fetch_objects_by_ids(ids)
    fetched = response.objects
    assert len(fetched) == expected_len
    assert {obj.uuid for obj in fetched} == expected


@pytest.mark.asyncio
async def test_config_update(async_collection_factory: AsyncCollectionFactory) -> None:
collection = await async_collection_factory(
Expand Down Expand Up @@ -200,3 +250,52 @@ async def test_generate(async_openai_collection: AsyncOpenAICollectionFactory) -
assert len(res.objects) == 2
for obj in res.objects:
assert obj.generated is not None


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "ids, expected_len, expected",
    [
        ([], 0, set()),
        ((), 0, set()),
        ([UUID3], 1, {UUID3}),
        ([UUID1, UUID2], 2, {UUID1, UUID2}),
        ((UUID1, UUID3), 2, {UUID1, UUID3}),
        # A repeated ID is expected to yield one object per distinct UUID.
        ((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
    ],
)
async def test_generate_by_ids(
    async_openai_collection: AsyncOpenAICollectionFactory,
    ids: Iterable[UUID],
    expected_len: int,
    expected: set,
) -> None:
    """Run a generative fetch-by-IDs request through the async generate API.

    Three objects with known UUIDs are inserted, then the parametrized
    ``ids`` are fetched with both a single prompt and a grouped task. The
    test checks the grouped result, the set of returned UUIDs, and that
    every returned object carries its own generated value.
    """
    collection = await async_openai_collection(
        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
    )

    # (text, uuid) pairs for the objects every parametrized case queries against.
    seed_rows = (
        ("John Doe", UUID1),
        ("Jane Doe", UUID2),
        ("J. Doe", UUID3),
    )
    seed_objects = [
        DataObject(properties={"text": text}, uuid=identifier)
        for text, identifier in seed_rows
    ]
    await collection.data.insert_many(seed_objects)

    response = await collection.generate.fetch_objects_by_ids(
        ids,
        single_prompt="Who is this? {text}",
        grouped_task="Who are these people?",
    )

    assert response is not None
    assert response.generated is not None

    fetched = response.objects
    assert len(fetched) == expected_len
    assert {item.uuid for item in fetched} == expected
    for item in fetched:
        assert item.generated is not None
49 changes: 48 additions & 1 deletion integration/test_collection_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import time
import uuid
from typing import Callable, List, Optional
from typing import Callable, Iterable, List, Optional

import pytest as pytest

Expand All @@ -21,6 +21,7 @@
)
from weaviate.collections.classes.grpc import MetadataQuery, QueryReference, Sort
from weaviate.collections.classes.internal import ReferenceToMulti
from weaviate.types import UUID

NOW = datetime.datetime.now(datetime.timezone.utc)
LATER = NOW + datetime.timedelta(hours=1)
Expand Down Expand Up @@ -548,6 +549,52 @@ def test_filter_id(collection_factory: CollectionFactory, weav_filter: _FilterVa
assert objects[0].uuid == UUID1


@pytest.mark.parametrize(
    "ids, expected_len, expected",
    [
        ([], 0, set()),
        ((), 0, set()),
        ([UUID3], 1, {UUID3}),
        ([UUID1, UUID2], 2, {UUID1, UUID2}),
        ((UUID1, UUID3), 2, {UUID1, UUID3}),
        # A repeated ID is expected to yield one object per distinct UUID.
        ((UUID1, UUID3, UUID3), 2, {UUID1, UUID3}),
    ],
)
def test_filter_ids(
    collection_factory: CollectionFactory,
    ids: Iterable[UUID],
    expected_len: int,
    expected: set,
) -> None:
    """Fetch objects by an iterable of UUIDs through the sync query API.

    Three objects with known UUIDs are inserted, then the parametrized
    ``ids`` (lists and tuples; empty, single, multiple, and with a repeat)
    are looked up and the returned UUIDs are compared against ``expected``.
    """
    name_property = Property(name="Name", data_type=DataType.TEXT)
    collection = collection_factory(
        properties=[name_property],
        vectorizer_config=Configure.Vectorizer.none(),
    )

    # (name, uuid) pairs for the objects every parametrized case queries against.
    seed_rows = (
        ("first", UUID1),
        ("second", UUID2),
        ("third", UUID3),
    )
    seed_objects = [
        DataObject(properties={"name": label}, uuid=identifier)
        for label, identifier in seed_rows
    ]
    collection.data.insert_many(seed_objects)

    response = collection.query.fetch_objects_by_ids(ids)
    fetched = response.objects

    assert len(fetched) == expected_len
    assert {obj.uuid for obj in fetched} == expected


@pytest.mark.parametrize("path", ["_creationTimeUnix", "_lastUpdateTimeUnix"])
def test_filter_timestamp_direct_path(collection_factory: CollectionFactory, path: str) -> None:
collection = collection_factory(
Expand Down
30 changes: 16 additions & 14 deletions integration_v3/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ def test_authentication_client_credentials(
@pytest.mark.parametrize(
"name,user,env_variable_name,port,scope,warning",
[
(
"WCS",
"ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
"WCS_DUMMY_CI_PW",
WCS_PORT,
None,
False,
),
# ( # WCS keycloak times out too often
# "WCS",
# "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
# "WCS_DUMMY_CI_PW",
# WCS_PORT,
# None,
# False,
# ),
(
"okta",
"test@test.de",
Expand Down Expand Up @@ -168,12 +168,12 @@ def _get_access_token(url: str, user: str, pw: str) -> Dict[str, str]:
@pytest.mark.parametrize(
"name,user,env_variable_name,port",
[
(
"WCS",
"ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
"WCS_DUMMY_CI_PW",
WCS_PORT,
),
# ( # WCS keycloak times out too often
# "WCS",
# "ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net",
# "WCS_DUMMY_CI_PW",
# WCS_PORT,
# ),
(
"okta",
"test@test.de",
Expand Down Expand Up @@ -227,6 +227,8 @@ def test_client_with_authentication_with_anon_weaviate(recwarn):
def test_bearer_token_without_refresh(recwarn):
"""Test that the client warns users when only supplying an access token without refresh."""

pytest.skip("WCS keycloak times out too often")

# testing for warnings can be flaky without this as there are open SSL connections
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning)
Expand Down
2 changes: 1 addition & 1 deletion requirements-devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ polars>=0.20.26,<1.3.0

fastapi==0.111.1
flask[async]==3.0.3
litestar==2.9.1
litestar==2.10.0

mypy==1.11.0
mypy-extensions==1.0.0
Expand Down
6 changes: 6 additions & 0 deletions weaviate/collections/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
_FetchObjectsGenerateAsync,
_FetchObjectsGenerate,
)
from weaviate.collections.queries.fetch_objects_by_ids import (
_FetchObjectsByIDsGenerateAsync,
_FetchObjectsByIDsGenerate,
)
from weaviate.collections.queries.hybrid import _HybridGenerateAsync, _HybridGenerate
from weaviate.collections.queries.near_image import _NearImageGenerateAsync, _NearImageGenerate
from weaviate.collections.queries.near_media import _NearMediaGenerateAsync, _NearMediaGenerate
Expand All @@ -19,6 +23,7 @@ class _GenerateCollectionAsync(
Generic[TProperties, References],
_BM25GenerateAsync[TProperties, References],
_FetchObjectsGenerateAsync[TProperties, References],
_FetchObjectsByIDsGenerateAsync[TProperties, References],
_HybridGenerateAsync[TProperties, References],
_NearImageGenerateAsync[TProperties, References],
_NearMediaGenerateAsync[TProperties, References],
Expand All @@ -33,6 +38,7 @@ class _GenerateCollection(
Generic[TProperties, References],
_BM25Generate[TProperties, References],
_FetchObjectsGenerate[TProperties, References],
_FetchObjectsByIDsGenerate[TProperties, References],
_HybridGenerate[TProperties, References],
_NearImageGenerate[TProperties, References],
_NearMediaGenerate[TProperties, References],
Expand Down
9 changes: 9 additions & 0 deletions weaviate/collections/queries/fetch_objects_by_ids/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .generate import _FetchObjectsByIDsGenerateAsync, _FetchObjectsByIDsGenerate
from .query import _FetchObjectsByIDsQueryAsync, _FetchObjectsByIDsQuery

# Names re-exported from the sibling `generate` and `query` modules:
# the sync and async variants of the fetch-objects-by-IDs classes.
__all__ = [
"_FetchObjectsByIDsGenerate",
"_FetchObjectsByIDsGenerateAsync",
"_FetchObjectsByIDsQuery",
"_FetchObjectsByIDsQueryAsync",
]
Loading

0 comments on commit c170177

Please sign in to comment.