Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(storage): fix timeout error and related unit tests #9992

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions storage/google/cloud/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, client):
self._requests = []
self._target_objects = []

def _do_request(self, method, url, headers, data, target_object):
def _do_request(self, method, url, headers, data, target_object, timeout=None):
"""Override Connection: defer actual HTTP request.

Only allow up to ``_MAX_BATCH_SIZE`` requests to be deferred.
Expand All @@ -173,6 +173,13 @@ def _do_request(self, method, url, headers, data, target_object):
connection. Here we defer an HTTP request and complete
initialization of the object at a later time.

:type timeout: float or tuple
:param timeout: (optional) The amount of time, in seconds, to wait
for the server response. By default, the method waits indefinitely.

Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.

:rtype: tuple of ``response`` (a dictionary of sorts)
and ``content`` (a string).
:returns: The HTTP response object and the content of the response.
Expand All @@ -181,7 +188,7 @@ def _do_request(self, method, url, headers, data, target_object):
raise ValueError(
"Too many deferred requests (max %d)" % self._MAX_BATCH_SIZE
)
self._requests.append((method, url, headers, data))
self._requests.append((method, url, headers, data, timeout))
result = _FutureDict()
self._target_objects.append(target_object)
if target_object is not None:
Expand All @@ -200,7 +207,7 @@ def _prepare_batch_request(self):

multi = MIMEMultipart()

for method, uri, headers, body in self._requests:
for method, uri, headers, body, timeout in self._requests:
subrequest = MIMEApplicationHTTP(method, uri, headers, body)
multi.attach(subrequest)

Expand All @@ -215,7 +222,7 @@ def _prepare_batch_request(self):

# Strip off redundant header text
_, body = payload.split("\n\n", 1)
return dict(multi._headers), body
return dict(multi._headers), body, timeout

def _finish_futures(self, responses):
"""Apply all the batch responses to the futures created.
Expand Down Expand Up @@ -251,15 +258,16 @@ def finish(self):
:rtype: list of tuples
:returns: one ``(headers, payload)`` tuple per deferred request.
"""
headers, body = self._prepare_batch_request()
headers, body, timeout = self._prepare_batch_request()

url = "%s/batch/storage/v1" % self.API_BASE_URL

# Use the private ``_base_connection`` rather than the property
# ``_connection``, since the property may be this
# current batch.

response = self._client._base_connection._make_request(
"POST", url, data=body, headers=headers
"POST", url, data=body, headers=headers, timeout=timeout
)
responses = list(_unpack_batch_response(response))
self._finish_futures(responses)
Expand Down
6 changes: 5 additions & 1 deletion storage/tests/unit/test__http.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def test_extra_headers(self):
}
expected_uri = conn.build_api_url("/rainbow")
http.request.assert_called_once_with(
data=req_data, headers=expected_headers, method="GET", url=expected_uri
data=req_data,
headers=expected_headers,
method="GET",
timeout=None,
url=expected_uri,
)

def test_build_api_url_no_extra_query_params(self):
Expand Down
60 changes: 38 additions & 22 deletions storage/tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test__make_request_GET_normal(self):
batch = self._make_one(connection)
target = _MockObject()

response = batch._make_request("GET", url, target_object=target)
response = batch._make_request("GET", url, target_object=target, timeout=None)

# Check the respone
self.assertEqual(response.status_code, 204)
Expand All @@ -147,10 +147,11 @@ def test__make_request_GET_normal(self):
# Check the queued request
self.assertEqual(len(batch._requests), 1)
request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, request_timeout = request
self.assertEqual(request_method, "GET")
self.assertEqual(request_url, url)
self.assertIsNone(request_data)
self.assertIsNone(request_timeout)

def test__make_request_POST_normal(self):
from google.cloud.storage.batch import _FutureDict
Expand All @@ -163,7 +164,7 @@ def test__make_request_POST_normal(self):
target = _MockObject()

response = batch._make_request(
"POST", url, data={"foo": 1}, target_object=target
"POST", url, data={"foo": 1}, target_object=target, timeout=None
)

self.assertEqual(response.status_code, 204)
Expand All @@ -174,10 +175,11 @@ def test__make_request_POST_normal(self):
http.request.assert_not_called()

request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, request_timeout = request
self.assertEqual(request_method, "POST")
self.assertEqual(request_url, url)
self.assertEqual(request_data, data)
self.assertIsNone(request_timeout)

def test__make_request_PATCH_normal(self):
from google.cloud.storage.batch import _FutureDict
Expand All @@ -190,7 +192,7 @@ def test__make_request_PATCH_normal(self):
target = _MockObject()

response = batch._make_request(
"PATCH", url, data={"foo": 1}, target_object=target
"PATCH", url, data={"foo": 1}, target_object=target, timeout=None
)

self.assertEqual(response.status_code, 204)
Expand All @@ -201,10 +203,11 @@ def test__make_request_PATCH_normal(self):
http.request.assert_not_called()

request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, request_timeout = request
self.assertEqual(request_method, "PATCH")
self.assertEqual(request_url, url)
self.assertEqual(request_data, data)
self.assertIsNone(request_timeout)

def test__make_request_DELETE_normal(self):
from google.cloud.storage.batch import _FutureDict
Expand All @@ -214,8 +217,11 @@ def test__make_request_DELETE_normal(self):
connection = _Connection(http=http)
batch = self._make_one(connection)
target = _MockObject()
timeout = 1

response = batch._make_request("DELETE", url, target_object=target)
response = batch._make_request(
"DELETE", url, target_object=target, timeout=timeout
)

# Check the respone
self.assertEqual(response.status_code, 204)
Expand All @@ -228,22 +234,22 @@ def test__make_request_DELETE_normal(self):
# Check the queued request
self.assertEqual(len(batch._requests), 1)
request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, request_timeout = request
self.assertEqual(request_method, "DELETE")
self.assertEqual(request_url, url)
self.assertIsNone(request_data)
self.assertEqual(request_timeout, timeout)

def test__make_request_POST_too_many_requests(self):
url = "http://example.com/api"
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)

batch._MAX_BATCH_SIZE = 1
batch._requests.append(("POST", url, {}, {"bar": 2}))
batch._requests.append(("POST", url, {}, {"bar": 2}, 1))

with self.assertRaises(ValueError):
batch._make_request("POST", url, data={"foo": 1})
batch._make_request("POST", url, data={"foo": 1}, timeout=1)

def test_finish_empty(self):
http = _make_requests_session([])
Expand Down Expand Up @@ -314,9 +320,9 @@ def test_finish_nonempty(self):
batch = self._make_one(client)
batch.API_BASE_URL = "http://api.example.com"

batch._do_request("POST", url, {}, {"foo": 1, "bar": 2}, None)
batch._do_request("PATCH", url, {}, {"bar": 3}, None)
batch._do_request("DELETE", url, {}, None, None)
batch._do_request("POST", url, {}, {"foo": 1, "bar": 2}, None, None)
batch._do_request("PATCH", url, {}, {"bar": 3}, None, None)
batch._do_request("DELETE", url, {}, None, None, None)
result = batch.finish()

self.assertEqual(len(result), len(batch._requests))
Expand All @@ -340,7 +346,11 @@ def test_finish_nonempty(self):

expected_url = "{}/batch/storage/v1".format(batch.API_BASE_URL)
http.request.assert_called_once_with(
method="POST", url=expected_url, headers=mock.ANY, data=mock.ANY
method="POST",
url=expected_url,
headers=mock.ANY,
data=mock.ANY,
timeout=mock.ANY,
)

request_info = self._get_mutlipart_request(http)
Expand Down Expand Up @@ -369,7 +379,7 @@ def test_finish_responses_mismatch(self):
batch = self._make_one(client)
batch.API_BASE_URL = "http://api.example.com"

batch._requests.append(("GET", url, {}, None))
batch._requests.append(("GET", url, {}, None, 1))
with self.assertRaises(ValueError):
batch.finish()

Expand All @@ -389,8 +399,8 @@ def test_finish_nonempty_with_status_failure(self):
target1 = _MockObject()
target2 = _MockObject()

batch._do_request("GET", url, {}, None, target1)
batch._do_request("GET", url, {}, None, target2)
batch._do_request("GET", url, {}, None, target1, None)
batch._do_request("GET", url, {}, None, target2, None)

# Make sure futures are not populated.
self.assertEqual(
Expand All @@ -406,7 +416,11 @@ def test_finish_nonempty_with_status_failure(self):

expected_url = "{}/batch/storage/v1".format(batch.API_BASE_URL)
http.request.assert_called_once_with(
method="POST", url=expected_url, headers=mock.ANY, data=mock.ANY
method="POST",
url=expected_url,
headers=mock.ANY,
data=mock.ANY,
timeout=mock.ANY,
)

_, request_body, _, boundary = self._get_mutlipart_request(http)
Expand All @@ -422,7 +436,7 @@ def test_finish_nonempty_non_multipart_response(self):
connection = _Connection(http=http)
client = _Client(connection)
batch = self._make_one(client)
batch._requests.append(("POST", url, {}, {"foo": 1, "bar": 2}))
batch._requests.append(("POST", url, {}, {"foo": 1, "bar": 2}, 1))

with self.assertRaises(ValueError):
batch.finish()
Expand Down Expand Up @@ -620,8 +634,10 @@ class _Connection(object):
def __init__(self, **kw):
self.__dict__.update(kw)

def _make_request(self, method, url, data=None, headers=None):
return self.http.request(url=url, method=method, headers=headers, data=data)
def _make_request(self, method, url, data=None, headers=None, timeout=None):
return self.http.request(
url=url, method=method, headers=headers, data=data, timeout=timeout
)


class _MockObject(object):
Expand Down
Loading