diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py index 9423d2b6..718b5f05 100644 --- a/google/api_core/grpc_helpers_async.py +++ b/google/api_core/grpc_helpers_async.py @@ -159,7 +159,6 @@ class _WrappedStreamStreamCall( def _wrap_unary_errors(callable_): """Map errors for Unary-Unary async callables.""" - grpc_helpers._patch_callable_name(callable_) @functools.wraps(callable_) def error_remapped_callable(*args, **kwargs): @@ -169,23 +168,13 @@ def error_remapped_callable(*args, **kwargs): return error_remapped_callable -def _wrap_stream_errors(callable_): +def _wrap_stream_errors(callable_, wrapper_type): """Map errors for streaming RPC async callables.""" - grpc_helpers._patch_callable_name(callable_) @functools.wraps(callable_) async def error_remapped_callable(*args, **kwargs): call = callable_(*args, **kwargs) - - if isinstance(call, aio.UnaryStreamCall): - call = _WrappedUnaryStreamCall().with_call(call) - elif isinstance(call, aio.StreamUnaryCall): - call = _WrappedStreamUnaryCall().with_call(call) - elif isinstance(call, aio.StreamStreamCall): - call = _WrappedStreamStreamCall().with_call(call) - else: - raise TypeError("Unexpected type of call %s" % type(call)) - + call = wrapper_type().with_call(call) await call.wait_for_connection() return call @@ -207,10 +196,17 @@ def wrap_errors(callable_): Returns: Callable: The wrapped gRPC callable. """ + grpc_helpers._patch_callable_name(callable_) if isinstance(callable_, aio.UnaryUnaryMultiCallable): return _wrap_unary_errors(callable_) + elif isinstance(callable_, aio.UnaryStreamMultiCallable): + return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall) + elif isinstance(callable_, aio.StreamUnaryMultiCallable): + return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall) + elif isinstance(callable_, aio.StreamStreamMultiCallable): + return _wrap_stream_errors(callable_, _WrappedStreamStreamCall) else: - return _wrap_stream_errors(callable_) + raise TypeError("Unexpected type of callable: {}".format(type(callable_))) def create_channel( diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py index 6bde59ca..6e08f10a 100644 --- a/tests/asyncio/test_grpc_helpers_async.py +++ b/tests/asyncio/test_grpc_helpers_async.py @@ -97,12 +97,40 @@ async def test_common_methods_in_wrapped_call(): assert mock_call.wait_for_connection.call_count == 1 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "callable_type,expected_wrapper_type", + [ + (grpc.aio.UnaryStreamMultiCallable, grpc_helpers_async._WrappedUnaryStreamCall), + (grpc.aio.StreamUnaryMultiCallable, grpc_helpers_async._WrappedStreamUnaryCall), + ( + grpc.aio.StreamStreamMultiCallable, + grpc_helpers_async._WrappedStreamStreamCall, + ), + ], +) +async def test_wrap_errors_w_stream_type(callable_type, expected_wrapper_type): + class ConcreteMulticallable(callable_type): + def __call__(self, *args, **kwargs): + raise NotImplementedError("Should not be called") + + with mock.patch.object( + grpc_helpers_async, "_wrap_stream_errors" + ) as wrap_stream_errors: + callable_ = ConcreteMulticallable() + grpc_helpers_async.wrap_errors(callable_) + assert wrap_stream_errors.call_count == 1 + wrap_stream_errors.assert_called_once_with(callable_, expected_wrapper_type) + + @pytest.mark.asyncio async def test_wrap_stream_errors_unary_stream(): mock_call = mock.Mock(aio.UnaryStreamCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedUnaryStreamCall + ) await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -114,7 +142,9 @@ async def test_wrap_stream_errors_stream_unary(): mock_call = mock.Mock(aio.StreamUnaryCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamUnaryCall + ) await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -126,7 +156,9 @@ async def test_wrap_stream_errors_stream_stream(): mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -134,14 +166,16 @@ async def test_wrap_stream_errors_stream_stream(): @pytest.mark.asyncio -async def test_wrap_stream_errors_type_error(): +async def test_wrap_errors_type_error(): + """ + If wrap_errors is called with an unexpected type, it should raise a TypeError. + """ mock_call = mock.Mock() multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) - - with pytest.raises(TypeError): - await wrapped_callable() + with pytest.raises(TypeError) as exc: + grpc_helpers_async.wrap_errors(multicallable) + assert "Unexpected type" in str(exc.value) @pytest.mark.asyncio @@ -151,7 +185,9 @@ async def test_wrap_stream_errors_raised(): mock_call.wait_for_connection = mock.AsyncMock(side_effect=[grpc_error]) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) with pytest.raises(exceptions.InvalidArgument): await wrapped_callable() @@ -166,7 +202,9 @@ async def test_wrap_stream_errors_read(): mock_call.read = mock.AsyncMock(side_effect=grpc_error) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -189,7 +227,9 @@ async def test_wrap_stream_errors_aiter(): mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() with pytest.raises(exceptions.InvalidArgument) as exc_info: @@ -210,7 +250,9 @@ async def test_wrap_stream_errors_aiter_non_rpc_error(): mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() with pytest.raises(TypeError) as exc_info: @@ -224,7 +266,9 @@ async def test_wrap_stream_errors_aiter_called_multiple_times(): mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() assert wrapped_call.__aiter__() == wrapped_call.__aiter__() @@ -239,7 +283,9 @@ async def test_wrap_stream_errors_write(): mock_call.done_writing = mock.AsyncMock(side_effect=[None, grpc_error]) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() @@ -295,7 +341,9 @@ def test_wrap_errors_streaming(wrap_stream_errors): result = grpc_helpers_async.wrap_errors(callable_) assert result == wrap_stream_errors.return_value - wrap_stream_errors.assert_called_once_with(callable_) + wrap_stream_errors.assert_called_once_with( + callable_, grpc_helpers_async._WrappedUnaryStreamCall + ) @pytest.mark.parametrize(