diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py index e098491fe..4d9097ff9 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py @@ -20,7 +20,7 @@ import logging import threading import typing -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Callable, Iterable, List, Optional, Tuple, Union import uuid import grpc # type: ignore @@ -49,7 +49,6 @@ if typing.TYPE_CHECKING: # pragma: NO COVER from google.cloud.pubsub_v1 import subscriber - from google.protobuf.internal import containers _LOGGER = logging.getLogger(__name__) @@ -141,9 +140,7 @@ def _get_status(exc: exceptions.GoogleAPICallError,) -> Optional["status_pb2.Sta return None -def _get_ack_errors( - exc: exceptions.GoogleAPICallError, -) -> Optional["containers.ScalarMap"]: +def _get_ack_errors(exc: exceptions.GoogleAPICallError,) -> Optional[Dict[str, str]]: status = _get_status(exc) if not status: _LOGGER.debug("Unable to get status of errored RPC.") @@ -159,8 +156,8 @@ def _get_ack_errors( def _process_requests( error_status: Optional["status_pb2.Status"], - ack_reqs_dict: "containers.ScalarMap", - errors_dict: Optional["containers.ScalarMap"], + ack_reqs_dict: Dict[str, requests.AckRequest], + errors_dict: Optional[Dict[str, str]], ): """Process requests by referring to error_status and errors_dict. @@ -182,9 +179,9 @@ def _process_requests( exc = AcknowledgeError(AcknowledgeStatus.INVALID_ACK_ID, info=None) else: exc = AcknowledgeError(AcknowledgeStatus.OTHER, exactly_once_error) - future = ack_reqs_dict[ack_id].future - future.set_exception(exc) + if future is not None: + future.set_exception(exc) requests_completed.append(ack_reqs_dict[ack_id]) # Temporary GRPC errors are retried elif ( @@ -201,12 +198,14 @@ def _process_requests( else: exc = AcknowledgeError(AcknowledgeStatus.OTHER, str(error_status)) future = ack_reqs_dict[ack_id].future - future.set_exception(exc) + if future is not None: + future.set_exception(exc) requests_completed.append(ack_reqs_dict[ack_id]) # Since no error occurred, requests with futures are completed successfully. elif ack_reqs_dict[ack_id].future: future = ack_reqs_dict[ack_id].future # success + assert future is not None future.set_result(AcknowledgeStatus.SUCCESS) requests_completed.append(ack_reqs_dict[ack_id]) # All other requests are considered completed. diff --git a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py index 36f82b621..e9554deda 100644 --- a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py +++ b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py @@ -1713,6 +1713,21 @@ def test_process_requests_no_errors(): assert not requests_to_retry +def test_process_requests_no_errors_no_future(): + # no errors, request should be completed, even when future is None. + ack_reqs_dict = { + "ackid1": requests.AckRequest( + ack_id="ackid1", byte_size=0, time_to_ack=20, ordering_key="", future=None + ) + } + errors_dict = {} + requests_completed, requests_to_retry = streaming_pull_manager._process_requests( + None, ack_reqs_dict, errors_dict + ) + assert requests_completed[0].ack_id == "ackid1" + assert not requests_to_retry + + def test_process_requests_permanent_error_raises_exception(): # a permanent error raises an exception future = futures.Future() @@ -1735,6 +1750,40 @@ def test_process_requests_permanent_error_raises_exception(): assert not requests_to_retry +def test_process_requests_permanent_error_other_raises_exception(): + # a permanent error of other raises an exception + future = futures.Future() + ack_reqs_dict = { + "ackid1": requests.AckRequest( + ack_id="ackid1", byte_size=0, time_to_ack=20, ordering_key="", future=future + ) + } + errors_dict = {"ackid1": "PERMANENT_FAILURE_OTHER"} + requests_completed, requests_to_retry = streaming_pull_manager._process_requests( + None, ack_reqs_dict, errors_dict + ) + assert requests_completed[0].ack_id == "ackid1" + with pytest.raises(subscriber_exceptions.AcknowledgeError) as exc_info: + future.result() + assert exc_info.value.error_code == subscriber_exceptions.AcknowledgeStatus.OTHER + assert not requests_to_retry + + +def test_process_requests_permanent_error_other_raises_exception_no_future(): + # with a permanent error, request is completed even when future is None. + ack_reqs_dict = { + "ackid1": requests.AckRequest( + ack_id="ackid1", byte_size=0, time_to_ack=20, ordering_key="", future=None + ) + } + errors_dict = {"ackid1": "PERMANENT_FAILURE_OTHER"} + requests_completed, requests_to_retry = streaming_pull_manager._process_requests( + None, ack_reqs_dict, errors_dict + ) + assert requests_completed[0].ack_id == "ackid1" + assert not requests_to_retry + + def test_process_requests_transient_error_returns_request_for_retrying(): # a transient error returns the request in `requests_to_retry` future = futures.Future() @@ -1872,6 +1921,23 @@ def test_process_requests_other_error_status_raises_exception(): assert not requests_to_retry +def test_process_requests_other_error_status_raises_exception_no_future(): + # with an unrecognized error status, requests are completed, even when + # future is None. + ack_reqs_dict = { + "ackid1": requests.AckRequest( + ack_id="ackid1", byte_size=0, time_to_ack=20, ordering_key="", future=None + ) + } + st = status_pb2.Status() + st.code = code_pb2.Code.OUT_OF_RANGE + requests_completed, requests_to_retry = streaming_pull_manager._process_requests( + st, ack_reqs_dict, None + ) + assert requests_completed[0].ack_id == "ackid1" + assert not requests_to_retry + + def test_process_requests_mixed_success_and_failure_acks(): # mixed success and failure (acks) future1 = futures.Future()