Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send tenants.update requests in batches of 100 #1192

Merged
merged 9 commits into from
Jul 23, 2024
6 changes: 3 additions & 3 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ on:

env:
WEAVIATE_123: 1.23.16
WEAVIATE_124: 1.24.20
WEAVIATE_125: 1.25.7
WEAVIATE_126: preview--0ccf121
WEAVIATE_124: 1.24.21
WEAVIATE_125: 1.25.8
WEAVIATE_126: preview--caba86b


jobs:
Expand Down
25 changes: 25 additions & 0 deletions integration/test_tenants.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,28 @@ def test_tenants_update_with_read_only_activity_status(
)
with pytest.raises(WeaviateInvalidInputError):
collection.tenants.update(tenants)


def test_tenants_create_and_update_1001_tenants(
collection_factory: CollectionFactory,
) -> None:
collection = collection_factory(
vectorizer_config=Configure.Vectorizer.none(),
multi_tenancy_config=Configure.multi_tenancy(),
)

tenants = [TenantCreate(name=f"tenant{i}") for i in range(1001)]

collection.tenants.create(tenants)
t = collection.tenants.get()
assert len(t) == 1001
assert all(tenant.activity_status == TenantActivityStatus.ACTIVE for tenant in t.values())

tenants = [
Tenant(name=f"tenant{i}", activity_status=TenantActivityStatus.INACTIVE)
for i in range(1001)
]
collection.tenants.update(tenants)
t = collection.tenants.get()
assert len(t) == 1001
assert all(tenant.activity_status == TenantActivityStatus.INACTIVE for tenant in t.values())
2 changes: 1 addition & 1 deletion weaviate/collections/collection/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def iterator(
) -> _ObjectAIterator[TProperties, TReferences]:
...

def iterator( # type: ignore
def iterator(
self,
include_vector: bool = False,
return_metadata: Optional[METADATA] = None,
Expand Down
4 changes: 1 addition & 3 deletions weaviate/collections/collection/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,7 @@ def iterator(
) -> _ObjectIterator[TProperties, TReferences]:
...

# weaviate/collections/collection.py:263: error: Overloaded function implementation does not accept all possible arguments of signature 3 [misc]
# weaviate/collections/collection.py:263: error: Overloaded function implementation cannot produce return type of signature 3 [misc]
def iterator( # type: ignore
def iterator(
self,
include_vector: bool = False,
return_metadata: Optional[METADATA] = None,
Expand Down
67 changes: 45 additions & 22 deletions weaviate/collections/tenants/tenants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import ceil
from typing import Any, Dict, List, Optional, Sequence, Union

from weaviate.collections.classes.config import ConsistencyLevel
Expand All @@ -20,6 +21,8 @@
TenantUpdateInputType = Union[Tenant, TenantUpdate]
TenantOutputType = Tenant

MUTATE_TENANT_BATCH_SIZE = 100


class _TenantsBase:
def __init__(
Expand Down Expand Up @@ -85,14 +88,15 @@ async def create(
)

path = "/schema/" + self._name + "/tenants"
await self._connection.post(
path=path,
weaviate_object=self.__map_create_tenants(tenants),
error_msg=f"Collection tenants may not have been added properly for {self._name}",
status_codes=_ExpectedStatusCodes(
ok_in=200, error=f"Add collection tenants for {self._name}"
),
)
for mapped_tenants in self.__map_create_tenants(tenants):
await self._connection.post(
path=path,
weaviate_object=mapped_tenants,
error_msg=f"Collection tenants may not have been added properly for {self._name}",
status_codes=_ExpectedStatusCodes(
ok_in=200, error=f"Add collection tenants for {self._name}"
),
)

async def remove(self, tenants: Union[str, Tenant, Sequence[Union[str, Tenant]]]) -> None:
"""Remove the specified tenants from a collection in Weaviate.
Expand Down Expand Up @@ -210,27 +214,45 @@ def __map_update_tenant(self, tenant: TenantUpdateInputType) -> TenantUpdate:

def __map_create_tenants(
self, tenant: Union[str, Tenant, TenantCreate, Sequence[Union[str, Tenant, TenantCreate]]]
) -> List[dict]:
) -> List[List[dict]]:
if (
isinstance(tenant, str)
or isinstance(tenant, Tenant)
or isinstance(tenant, TenantCreate)
):
return [self.__map_create_tenant(tenant).model_dump()]
return [[self.__map_create_tenant(tenant).model_dump()]]
else:
return [self.__map_create_tenant(t).model_dump() for t in tenant]
batches = ceil(len(tenant) / MUTATE_TENANT_BATCH_SIZE)
tsmith023 marked this conversation as resolved.
Show resolved Hide resolved
return [
[
self.__map_create_tenant(tenant[i + b * MUTATE_TENANT_BATCH_SIZE]).model_dump()
for i in range(
min(len(tenant) - b * MUTATE_TENANT_BATCH_SIZE, MUTATE_TENANT_BATCH_SIZE)
)
]
for b in range(batches)
]

def __map_update_tenants(
self, tenant: Union[TenantUpdateInputType, Sequence[TenantUpdateInputType]]
) -> List[dict]:
) -> List[List[dict]]:
if (
isinstance(tenant, str)
or isinstance(tenant, Tenant)
or isinstance(tenant, TenantUpdate)
):
return [self.__map_update_tenant(tenant).model_dump()]
return [[self.__map_update_tenant(tenant).model_dump()]]
else:
return [self.__map_update_tenant(t).model_dump() for t in tenant]
batches = ceil(len(tenant) / MUTATE_TENANT_BATCH_SIZE)
return [
[
self.__map_update_tenant(tenant[i + b * MUTATE_TENANT_BATCH_SIZE]).model_dump()
for i in range(
min(len(tenant) - b * MUTATE_TENANT_BATCH_SIZE, MUTATE_TENANT_BATCH_SIZE)
)
]
for b in range(batches)
]

async def get(self) -> Dict[str, TenantOutputType]:
"""Return all tenants currently associated with a collection in Weaviate.
Expand Down Expand Up @@ -340,14 +362,15 @@ async def update(
)

path = "/schema/" + self._name + "/tenants"
await self._connection.put(
path=path,
weaviate_object=self.__map_update_tenants(tenants),
error_msg=f"Collection tenants may not have been updated properly for {self._name}",
status_codes=_ExpectedStatusCodes(
ok_in=200, error=f"Update collection tenants for {self._name}"
),
)
for mapped_tenants in self.__map_update_tenants(tenants):
await self._connection.put(
path=path,
weaviate_object=mapped_tenants,
error_msg=f"Collection tenants may not have been updated properly for {self._name}",
status_codes=_ExpectedStatusCodes(
ok_in=200, error=f"Update collection tenants for {self._name}"
),
)

async def exists(self, tenant: Union[str, Tenant]) -> bool:
"""Check if a tenant exists for a collection in Weaviate.
Expand Down