Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: avoid checking instance on each stream call #529

Merged
merged 13 commits into from
May 3, 2024
24 changes: 10 additions & 14 deletions google/api_core/grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ class _WrappedStreamStreamCall(

def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary async callables."""
grpc_helpers._patch_callable_name(callable_)

@functools.wraps(callable_)
def error_remapped_callable(*args, **kwargs):
Expand All @@ -169,23 +168,13 @@ def error_remapped_callable(*args, **kwargs):
return error_remapped_callable


def _wrap_stream_errors(callable_):
def _wrap_stream_errors(callable_, wrapper_type):
"""Map errors for streaming RPC async callables."""
grpc_helpers._patch_callable_name(callable_)

@functools.wraps(callable_)
async def error_remapped_callable(*args, **kwargs):
call = callable_(*args, **kwargs)

if isinstance(call, aio.UnaryStreamCall):
call = _WrappedUnaryStreamCall().with_call(call)
elif isinstance(call, aio.StreamUnaryCall):
call = _WrappedStreamUnaryCall().with_call(call)
elif isinstance(call, aio.StreamStreamCall):
call = _WrappedStreamStreamCall().with_call(call)
else:
raise TypeError("Unexpected type of call %s" % type(call))

call = wrapper_type().with_call(call)
await call.wait_for_connection()
return call

Expand All @@ -207,10 +196,17 @@ def wrap_errors(callable_):

Returns: Callable: The wrapped gRPC callable.
"""
grpc_helpers._patch_callable_name(callable_)
if isinstance(callable_, aio.UnaryUnaryMultiCallable):
return _wrap_unary_errors(callable_)
elif isinstance(callable_, aio.UnaryStreamMultiCallable):
return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall)
elif isinstance(callable_, aio.StreamUnaryMultiCallable):
return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall)
elif isinstance(callable_, aio.StreamStreamMultiCallable):
return _wrap_stream_errors(callable_, _WrappedStreamStreamCall)
else:
return _wrap_stream_errors(callable_)
raise TypeError("Unexpected type of callable: {}".format(type(callable_)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit (not a blocker, consider for the future): consider in cases like this raising a custom subclass of TypeError, so that we can check the exception type in tests



def create_channel(
Expand Down
78 changes: 63 additions & 15 deletions tests/asyncio/test_grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,40 @@ async def test_common_methods_in_wrapped_call():
assert mock_call.wait_for_connection.call_count == 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"callable_type,expected_wrapper_type",
[
(grpc.aio.UnaryStreamMultiCallable, grpc_helpers_async._WrappedUnaryStreamCall),
(grpc.aio.StreamUnaryMultiCallable, grpc_helpers_async._WrappedStreamUnaryCall),
(
grpc.aio.StreamStreamMultiCallable,
grpc_helpers_async._WrappedStreamStreamCall,
),
],
)
async def test_wrap_errors_w_stream_type(callable_type, expected_wrapper_type):
class ConcreteMulticallable(callable_type):
def __call__(self, *args, **kwargs):
raise NotImplementedError("Should not be called")

with mock.patch.object(
grpc_helpers_async, "_wrap_stream_errors"
) as wrap_stream_errors:
callable_ = ConcreteMulticallable()
grpc_helpers_async.wrap_errors(callable_)
assert wrap_stream_errors.call_count == 1
wrap_stream_errors.assert_called_once_with(callable_, expected_wrapper_type)


@pytest.mark.asyncio
async def test_wrap_stream_errors_unary_stream():
mock_call = mock.Mock(aio.UnaryStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedUnaryStreamCall
)

await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
Expand All @@ -114,7 +142,9 @@ async def test_wrap_stream_errors_stream_unary():
mock_call = mock.Mock(aio.StreamUnaryCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamUnaryCall
)

await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
Expand All @@ -126,22 +156,26 @@ async def test_wrap_stream_errors_stream_stream():
mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
assert mock_call.wait_for_connection.call_count == 1


@pytest.mark.asyncio
async def test_wrap_stream_errors_type_error():
async def test_wrap_errors_type_error():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to comment that the explicit type of error you're checking for is the "Unexpected type of callable" (which now happens at wrapping time). As per my other comment, if that were a specialized subclass of TypeError, we could check for that and the code would be self-documenting, so we wouldn't need an extra comment.

"""
If wrap_errors is called with an unexpected type, it should raise a TypeError.
"""
mock_call = mock.Mock()
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)

with pytest.raises(TypeError):
await wrapped_callable()
with pytest.raises(TypeError) as exc:
grpc_helpers_async.wrap_errors(multicallable)
assert "Unexpected type" in str(exc.value)


@pytest.mark.asyncio
Expand All @@ -151,7 +185,9 @@ async def test_wrap_stream_errors_raised():
mock_call.wait_for_connection = mock.AsyncMock(side_effect=[grpc_error])
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

with pytest.raises(exceptions.InvalidArgument):
await wrapped_callable()
Expand All @@ -166,7 +202,9 @@ async def test_wrap_stream_errors_read():
mock_call.read = mock.AsyncMock(side_effect=grpc_error)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

wrapped_call = await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
Expand All @@ -189,7 +227,9 @@ async def test_wrap_stream_errors_aiter():
mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)
wrapped_call = await wrapped_callable()

with pytest.raises(exceptions.InvalidArgument) as exc_info:
Expand All @@ -210,7 +250,9 @@ async def test_wrap_stream_errors_aiter_non_rpc_error():
mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)
wrapped_call = await wrapped_callable()

with pytest.raises(TypeError) as exc_info:
Expand All @@ -224,7 +266,9 @@ async def test_wrap_stream_errors_aiter_called_multiple_times():
mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)
wrapped_call = await wrapped_callable()

assert wrapped_call.__aiter__() == wrapped_call.__aiter__()
Expand All @@ -239,7 +283,9 @@ async def test_wrap_stream_errors_write():
mock_call.done_writing = mock.AsyncMock(side_effect=[None, grpc_error])
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

wrapped_call = await wrapped_callable()

Expand Down Expand Up @@ -295,7 +341,9 @@ def test_wrap_errors_streaming(wrap_stream_errors):
result = grpc_helpers_async.wrap_errors(callable_)

assert result == wrap_stream_errors.return_value
wrap_stream_errors.assert_called_once_with(callable_)
wrap_stream_errors.assert_called_once_with(
callable_, grpc_helpers_async._WrappedUnaryStreamCall
)


@pytest.mark.parametrize(
Expand Down