Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(v1): resolve issue handling protobuf responses in rest streaming #609

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions google/api_core/rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@

from collections import deque
import string
from typing import Deque
from typing import Deque, Union

import proto
import requests
import google.protobuf.message
from google.protobuf.json_format import Parse


class ResponseIterator:
"""Iterator over REST API responses.

Args:
response (requests.Response): An API response object.
response_message_cls (Callable[proto.Message]): A proto
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
class expected to be returned from an API.

Raises:
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""

def __init__(self, response: requests.Response, response_message_cls):
def __init__(
self,
response: requests.Response,
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
Expand Down Expand Up @@ -107,7 +117,14 @@ def __next__(self):

def _grab(self):
# Add extra quotes to make json.loads happy.
return self._response_message_cls.from_json(self._ready_objs.popleft())
if issubclass(self._response_message_cls, proto.Message):
return self._response_message_cls.from_json(self._ready_objs.popleft())
elif issubclass(self._response_message_cls, google.protobuf.message.Message):
return Parse(self._ready_objs.popleft(), self._response_message_cls())
else:
raise ValueError(
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
)

def __iter__(self):
return self
227 changes: 175 additions & 52 deletions tests/unit/test_rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import requests

from google.api_core import rest_streaming
from google.api import http_pb2
from google.api import httpbody_pb2
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from google.protobuf.json_format import MessageToJson


__protobuf__ = proto.module(package=__name__)
Expand Down Expand Up @@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
# json.dumps returns a string surrounded with quotes that need to be stripped
# in order to be an actual JSON.
json_responses = [
self._response_message_cls.to_json(r).strip('"') for r in responses
self._response_message_cls.to_json(r).strip('"')
if issubclass(self._response_message_cls, proto.Message)
else MessageToJson(r).strip('"')
for r in responses
]
logging.info(f"Sending JSON stream: {json_responses}")
ret_val = "[{}]".format(",".join(json_responses))
Expand All @@ -114,103 +120,220 @@ def iter_content(self, *args, **kwargs):
)


@pytest.mark.parametrize("random_split", [False])
def test_next_simple(random_split):
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[(False, True), (False, False)],
)
def test_next_simple(random_split, resp_message_is_proto_plus):
if resp_message_is_proto_plus:
response_type = EchoResponse
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
else:
response_type = httpbody_pb2.HttpBody
responses = [
httpbody_pb2.HttpBody(content_type="hello world"),
httpbody_pb2.HttpBody(content_type="yes"),
]

resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=EchoResponse
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_nested(random_split):
responses = [
Song(title="some song", composer=Composer(given_name="some name")),
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[
(True, True),
(False, True),
(True, False),
(False, False),
],
)
def test_next_nested(random_split, resp_message_is_proto_plus):
if resp_message_is_proto_plus:
response_type = Song
responses = [
Song(title="some song", composer=Composer(given_name="some name")),
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
]
else:
# Although `http_pb2.HttpRule`` is used in the response, any response message
# can be used which meets this criteria for the test of having a nested field.
response_type = http_pb2.HttpRule
responses = [
http_pb2.HttpRule(
selector="some selector",
custom=http_pb2.CustomHttpPattern(kind="some kind"),
),
http_pb2.HttpRule(
selector="another selector",
custom=http_pb2.CustomHttpPattern(path="some path"),
),
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_stress(random_split):
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[
(True, True),
(False, True),
(True, False),
(False, False),
],
)
def test_next_stress(random_split, resp_message_is_proto_plus):
n = 50
responses = [
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
for i in range(n)
]
if resp_message_is_proto_plus:
response_type = Song
responses = [
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
for i in range(n)
]
else:
response_type = http_pb2.HttpRule
responses = [
http_pb2.HttpRule(
selector="selector_%d" % i,
custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
)
for i in range(n)
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_escaped_characters_in_string(random_split):
composer_with_relateds = Composer()
relateds = ["Artist A", "Artist B"]
composer_with_relateds.relateds = relateds

responses = [
Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
Song(
title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
),
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
]
@pytest.mark.parametrize(
"random_split,resp_message_is_proto_plus",
[
(True, True),
(False, True),
(True, False),
(False, False),
],
)
def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus):
if resp_message_is_proto_plus:
response_type = Song
composer_with_relateds = Composer()
relateds = ["Artist A", "Artist B"]
composer_with_relateds.relateds = relateds

responses = [
Song(
title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
),
Song(
title='{"this is weird": "totally"}',
composer=Composer(given_name="\\{}\\"),
),
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
]
else:
response_type = http_pb2.Http
responses = [
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='ti"tle\nfoo\tbar{}',
custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
)
]
),
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='{"this is weird": "totally"}',
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
)
]
),
http_pb2.Http(
rules=[
http_pb2.HttpRule(
selector='\\{"key": ["value",]}\\',
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
)
]
),
]
resp = ResponseMock(
responses=responses, random_split=random_split, response_cls=Song
responses=responses, random_split=random_split, response_cls=response_type
)
itr = rest_streaming.ResponseIterator(resp, Song)
itr = rest_streaming.ResponseIterator(resp, response_type)
assert list(itr) == responses


def test_next_not_array():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_next_not_array(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
) as mock_method:

resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()


def test_cancel():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_cancel(response_type):
with patch.object(ResponseMock, "close", return_value=None) as mock_method:
resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
itr.cancel()
mock_method.assert_called_once()


def test_check_buffer():
@pytest.mark.parametrize(
"response_type,return_value",
[
(EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
(httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
],
)
def test_check_buffer(response_type, return_value):
with patch.object(
ResponseMock,
"_parse_responses",
return_value=bytes('[{"content": "hello"}, {', "utf-8"),
return_value=return_value,
):
resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
next(itr)


def test_next_html():
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
def test_next_html(response_type):
with patch.object(
ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
) as mock_method:

resp = ResponseMock(responses=[], response_cls=EchoResponse)
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
resp = ResponseMock(responses=[], response_cls=response_type)
itr = rest_streaming.ResponseIterator(resp, response_type)
with pytest.raises(ValueError):
next(itr)
mock_method.assert_called_once()


def test_invalid_response_class():
class SomeClass:
pass

resp = ResponseMock(responses=[], response_cls=SomeClass)
response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
with pytest.raises(
ValueError,
match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
):
response_iterator._grab()
Loading