From d866a1bd09e9e79a02e0ed42a9d59f30c85bb7bc Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Thu, 16 Nov 2023 10:33:50 +0900 Subject: [PATCH] Skip interceptor for methods that are not registered in the server (#48) * Skip interceptor for methods that are not registered in the server Co-authored-by: Xander Johnson * Bump version --------- Co-authored-by: Xander Johnson --- pyproject.toml | 2 +- src/grpc_interceptor/server.py | 4 ++++ tests/test_server.py | 10 +++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8da8126..9d996a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "grpc-interceptor" -version = "0.15.3" +version = "0.15.4" description = "Simplifies gRPC interceptors" license = "MIT" readme = "README.md" diff --git a/src/grpc_interceptor/server.py b/src/grpc_interceptor/server.py index 6cd6ebc..6733d0f 100644 --- a/src/grpc_interceptor/server.py +++ b/src/grpc_interceptor/server.py @@ -121,6 +121,10 @@ async def intercept_service(self, continuation, handler_call_details): have a public name. Do not override it, unless you know what you're doing. """ next_handler = await continuation(handler_call_details) + # Returns None if the method isn't implemented. + if not next_handler: + return + handler_factory, next_handler_method = _get_factory_and_method(next_handler) if next_handler.response_streaming: diff --git a/tests/test_server.py b/tests/test_server.py index f7ca614..d0d9553 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -177,15 +177,19 @@ def test_aborting_interceptor(aio): assert e.value.details() == "oh no" -def test_method_not_found(): +@pytest.mark.parametrize("aio", [False, True]) +def test_method_not_found(aio): """Calling undefined endpoints should return Unimplemented. Interceptors are not invoked when the RPC call is not handled. """ - intr = CountingInterceptor() + intr_type = AsyncCountingInterceptor if aio else CountingInterceptor + intr = intr_type() interceptors = [intr] - with dummy_channel(special_cases={}, interceptors=interceptors) as channel: + with dummy_channel( + special_cases={}, interceptors=interceptors, aio_server=aio + ) as channel: with pytest.raises(grpc.RpcError) as e: channel.unary_unary( "/DummyService/Unimplemented",