Skip to content

Commit

Permalink
Merge pull request #1155 from weaviate/dev/fix-grpc-mock-test-fixtures
Browse files Browse the repository at this point in the history
Make `start_grpc_server` function scoped so that mocks can be added at the test level
  • Loading branch information
tsmith023 authored Jul 3, 2024
2 parents 083ab28 + 0fbae51 commit 9f1f3af
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 79 deletions.
110 changes: 69 additions & 41 deletions mock_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from concurrent import futures
from typing import Generator
from typing import Generator, Mapping

import grpc
import pytest
Expand All @@ -12,7 +12,7 @@

import weaviate
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.proto.v1 import tenants_pb2, weaviate_pb2_grpc
from weaviate.proto.v1 import properties_pb2, tenants_pb2, search_get_pb2, weaviate_pb2_grpc

MOCK_IP = "127.0.0.1"
MOCK_PORT = 23536
Expand All @@ -25,7 +25,6 @@
http=ProtocolParams(host=MOCK_IP, port=MOCK_PORT, secure=False),
grpc=ProtocolParams(host=MOCK_IP, port=MOCK_PORT + 1, secure=False),
)
TENANTS_GET_COLLECTION_NAME = "TenantsGetCollectionName"

# pytest_httpserver 'Authorization' HeaderValueMatcher does not work with Bearer tokens.
# Hence, overwrite it with the default header value matcher that just compares for equality.
Expand Down Expand Up @@ -77,48 +76,20 @@ def weaviate_auth_mock(weaviate_mock: HTTPServer):
yield weaviate_mock


# Implement the health check service
class MockHealthServicer(HealthServicer):
def Check(self, request: HealthCheckRequest, context: ServicerContext) -> HealthCheckResponse:
return HealthCheckResponse(status=HealthCheckResponse.SERVING)


class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def TenantsGet(
self, request: tenants_pb2.TenantsGetRequest, context: ServicerContext
) -> tenants_pb2.TenantsGetReply:
return tenants_pb2.TenantsGetReply(
tenants=[
tenants_pb2.Tenant(
name="tenant1", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_HOT
),
tenants_pb2.Tenant(
name="tenant2", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_COLD
),
tenants_pb2.Tenant(
name="tenant3", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FROZEN
),
tenants_pb2.Tenant(
name="tenant4", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FREEZING
),
tenants_pb2.Tenant(
name="tenant5", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFREEZING
),
tenants_pb2.Tenant(
name="tenant6", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFROZEN
),
]
)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def start_grpc_server() -> Generator[grpc.Server, None, None]:
# Create a gRPC server
server: grpc.Server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

# Implement the health check service
class MockHealthServicer(HealthServicer):
def Check(
self, request: HealthCheckRequest, context: ServicerContext
) -> HealthCheckResponse:
return HealthCheckResponse(status=HealthCheckResponse.SERVING)

# Add the health check service to the server
add_HealthServicer_to_server(MockHealthServicer(), server)
weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), server)

# Listen on a specific port
server.add_insecure_port(f"[::]:{MOCK_PORT_GRPC}")
Expand All @@ -140,5 +111,62 @@ def weaviate_client(


@pytest.fixture(scope="function")
def tenants_collection(weaviate_client: weaviate.WeaviateClient) -> weaviate.collections.Collection:
return weaviate_client.collections.get(TENANTS_GET_COLLECTION_NAME)
def tenants_collection(
weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
) -> weaviate.collections.Collection:
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def TenantsGet(
self, request: tenants_pb2.TenantsGetRequest, context: ServicerContext
) -> tenants_pb2.TenantsGetReply:
return tenants_pb2.TenantsGetReply(
tenants=[
tenants_pb2.Tenant(
name="tenant1", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_HOT
),
tenants_pb2.Tenant(
name="tenant2", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_COLD
),
tenants_pb2.Tenant(
name="tenant3", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FROZEN
),
tenants_pb2.Tenant(
name="tenant4", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_FREEZING
),
tenants_pb2.Tenant(
name="tenant5",
activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFREEZING,
),
tenants_pb2.Tenant(
name="tenant6", activity_status=tenants_pb2.TENANT_ACTIVITY_STATUS_UNFROZEN
),
]
)

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
return weaviate_client.collections.get("TenantsGetCollectionName")


@pytest.fixture(scope="function")
def year_zero_collection(
weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
) -> weaviate.collections.Collection:
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
) -> search_get_pb2.SearchReply:
zero_date: properties_pb2.Value.date_value = properties_pb2.Value(
date_value="0000-01-30T00:00:00Z"
)
date_prop: Mapping[str, properties_pb2.Value.date_value] = {"date": zero_date}
return search_get_pb2.SearchReply(
results=[
search_get_pb2.SearchResult(
properties=search_get_pb2.PropertiesResult(
non_ref_props=properties_pb2.Properties(fields=date_prop)
)
),
]
)

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
return weaviate_client.collections.get("YearZeroCollection")
41 changes: 3 additions & 38 deletions mock_tests/test_collection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import json
import time
from typing import Any, Dict, Mapping
from typing import Any, Dict

import grpc
import pytest
Expand All @@ -28,7 +28,6 @@
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.connect.integrations import _IntegrationConfig
from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateStartUpError
from weaviate.proto.v1 import weaviate_pb2_grpc, search_get_pb2, properties_pb2

ACCESS_TOKEN = "HELLO!IamAnAccessToken"
REFRESH_TOKEN = "UseMeToRefreshYourAccessToken"
Expand Down Expand Up @@ -343,43 +342,9 @@ def test_integration_config(
weaviate_no_auth_mock.check_assertions()


def test_year_zero(weaviate_no_auth_mock: HTTPServer, start_grpc_server: grpc.Server) -> None:
zero_date: properties_pb2.Value.date_value = properties_pb2.Value(
date_value="0000-01-30T00:00:00Z"
)
date_prop: Mapping[str, properties_pb2.Value.date_value] = {"date": zero_date}

class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
) -> search_get_pb2.SearchReply:
return search_get_pb2.SearchReply(
results=[
search_get_pb2.SearchResult(
properties=search_get_pb2.PropertiesResult(
non_ref_props=properties_pb2.Properties(fields=date_prop)
)
),
]
)

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
schema = {
"class": "Test",
"properties": [],
"vectorizer": "none",
}
weaviate_no_auth_mock.expect_request("/v1/schema/Test").respond_with_json(
response_json=schema, status=200
)

client = weaviate.connect_to_local(
port=MOCK_PORT,
host=MOCK_IP,
grpc_port=MOCK_PORT_GRPC,
)
def test_year_zero(year_zero_collection: weaviate.collections.Collection) -> None:
with pytest.warns(UserWarning) as recwarn:
objs = client.collections.get("Test").query.fetch_objects().objects
objs = year_zero_collection.query.fetch_objects().objects
assert objs[0].properties["date"] == datetime.datetime.min

assert str(recwarn[0].message).startswith("Con004")

0 comments on commit 9f1f3af

Please sign in to comment.