Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add timeout parameter to Batch interface to match google-cloud-core #10010

Merged
merged 5 commits into from
Dec 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions storage/google/cloud/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, client):
self._requests = []
self._target_objects = []

def _do_request(self, method, url, headers, data, target_object):
def _do_request(self, method, url, headers, data, target_object, timeout=None):
"""Override Connection: defer actual HTTP request.

Only allow up to ``_MAX_BATCH_SIZE`` requests to be deferred.
Expand All @@ -173,6 +173,12 @@ def _do_request(self, method, url, headers, data, target_object):
connection. Here we defer an HTTP request and complete
initialization of the object at a later time.

:type timeout: float or tuple
:param timeout: (optional) The amount of time, in seconds, to wait
for the server response. By default, the method waits indefinitely.
Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.

:rtype: tuple of ``response`` (a dictionary of sorts)
and ``content`` (a string).
:returns: The HTTP response object and the content of the response.
Expand All @@ -181,7 +187,7 @@ def _do_request(self, method, url, headers, data, target_object):
raise ValueError(
"Too many deferred requests (max %d)" % self._MAX_BATCH_SIZE
)
self._requests.append((method, url, headers, data))
self._requests.append((method, url, headers, data, timeout))
result = _FutureDict()
self._target_objects.append(target_object)
if target_object is not None:
Expand All @@ -200,9 +206,12 @@ def _prepare_batch_request(self):

multi = MIMEMultipart()

for method, uri, headers, body in self._requests:
# Use timeout of last request, default to None (indefinite)
timeout = None
for method, uri, headers, body, _timeout in self._requests:
subrequest = MIMEApplicationHTTP(method, uri, headers, body)
multi.attach(subrequest)
timeout = _timeout

# The `email` package expects to deal with "native" strings
if six.PY3: # pragma: NO COVER Python3
Expand All @@ -215,7 +224,7 @@ def _prepare_batch_request(self):

# Strip off redundant header text
_, body = payload.split("\n\n", 1)
return dict(multi._headers), body
return dict(multi._headers), body, timeout

def _finish_futures(self, responses):
"""Apply all the batch responses to the futures created.
Expand All @@ -230,7 +239,7 @@ def _finish_futures(self, responses):
# until all futures have been populated.
exception_args = None

if len(self._target_objects) != len(responses):
if len(self._target_objects) != len(responses): # pragma: NO COVER
raise ValueError("Expected a response for every request.")

for target_object, subresponse in zip(self._target_objects, responses):
Expand All @@ -251,15 +260,15 @@ def finish(self):
:rtype: list of tuples
:returns: one ``(headers, payload)`` tuple per deferred request.
"""
headers, body = self._prepare_batch_request()
headers, body, timeout = self._prepare_batch_request()

url = "%s/batch/storage/v1" % self.API_BASE_URL

# Use the private ``_base_connection`` rather than the property
# ``_connection``, since the property may be this
# current batch.
response = self._client._base_connection._make_request(
"POST", url, data=body, headers=headers
"POST", url, data=body, headers=headers, timeout=timeout
)
responses = list(_unpack_batch_response(response))
self._finish_futures(responses)
Expand Down Expand Up @@ -313,7 +322,7 @@ def _unpack_batch_response(response):
parser = Parser()
message = _generate_faux_mime_message(parser, response)

if not isinstance(message._payload, list):
if not isinstance(message._payload, list): # pragma: NO COVER
raise ValueError("Bad response: not multi-part")

for subrequest in message._payload:
Expand Down
4 changes: 2 additions & 2 deletions storage/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
# 'Development Status :: 5 - Production/Stable'
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
"google-auth >= 1.2.0",
"google-cloud-core >= 1.0.3, < 2.0dev",
"google-auth >= 1.9.0, < 2.0dev",
"google-cloud-core >= 1.1.0, < 2.0dev",
"google-resumable-media >= 0.5.0, < 0.6dev",
]
extras = {}
Expand Down
6 changes: 5 additions & 1 deletion storage/tests/unit/test__http.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def test_extra_headers(self):
}
expected_uri = conn.build_api_url("/rainbow")
http.request.assert_called_once_with(
data=req_data, headers=expected_headers, method="GET", url=expected_uri
data=req_data,
headers=expected_headers,
method="GET",
url=expected_uri,
timeout=None,
)

def test_build_api_url_no_extra_query_params(self):
Expand Down
26 changes: 18 additions & 8 deletions storage/tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test__make_request_GET_normal(self):
# Check the queued request
self.assertEqual(len(batch._requests), 1)
request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "GET")
self.assertEqual(request_url, url)
self.assertIsNone(request_data)
Expand All @@ -174,7 +174,7 @@ def test__make_request_POST_normal(self):
http.request.assert_not_called()

request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "POST")
self.assertEqual(request_url, url)
self.assertEqual(request_data, data)
Expand All @@ -201,7 +201,7 @@ def test__make_request_PATCH_normal(self):
http.request.assert_not_called()

request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "PATCH")
self.assertEqual(request_url, url)
self.assertEqual(request_data, data)
Expand All @@ -228,7 +228,7 @@ def test__make_request_DELETE_normal(self):
# Check the queued request
self.assertEqual(len(batch._requests), 1)
request = batch._requests[0]
request_method, request_url, _, request_data = request
request_method, request_url, _, request_data, _ = request
self.assertEqual(request_method, "DELETE")
self.assertEqual(request_url, url)
self.assertIsNone(request_data)
Expand Down Expand Up @@ -340,7 +340,11 @@ def test_finish_nonempty(self):

expected_url = "{}/batch/storage/v1".format(batch.API_BASE_URL)
http.request.assert_called_once_with(
method="POST", url=expected_url, headers=mock.ANY, data=mock.ANY
method="POST",
url=expected_url,
headers=mock.ANY,
data=mock.ANY,
timeout=mock.ANY,
)

request_info = self._get_mutlipart_request(http)
Expand Down Expand Up @@ -406,7 +410,11 @@ def test_finish_nonempty_with_status_failure(self):

expected_url = "{}/batch/storage/v1".format(batch.API_BASE_URL)
http.request.assert_called_once_with(
method="POST", url=expected_url, headers=mock.ANY, data=mock.ANY
method="POST",
url=expected_url,
headers=mock.ANY,
data=mock.ANY,
timeout=mock.ANY,
)

_, request_body, _, boundary = self._get_mutlipart_request(http)
Expand Down Expand Up @@ -620,8 +628,10 @@ class _Connection(object):
def __init__(self, **kw):
self.__dict__.update(kw)

def _make_request(self, method, url, data=None, headers=None):
return self.http.request(url=url, method=method, headers=headers, data=data)
def _make_request(self, method, url, data=None, headers=None, timeout=None):
return self.http.request(
url=url, method=method, headers=headers, data=data, timeout=timeout
)


class _MockObject(object):
Expand Down
54 changes: 35 additions & 19 deletions storage/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_get_service_account_email_wo_project(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_get_service_account_email_w_project(self):
Expand All @@ -297,7 +297,7 @@ def test_get_service_account_email_w_project(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_bucket(self):
Expand Down Expand Up @@ -366,7 +366,7 @@ def test_get_bucket_with_string_miss(self):
client.get_bucket(NONESUCH)

http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_get_bucket_with_string_hit(self):
Expand Down Expand Up @@ -396,7 +396,7 @@ def test_get_bucket_with_string_hit(self):
self.assertIsInstance(bucket, Bucket)
self.assertEqual(bucket.name, BUCKET_NAME)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_get_bucket_with_object_miss(self):
Expand Down Expand Up @@ -427,7 +427,7 @@ def test_get_bucket_with_object_miss(self):
client.get_bucket(bucket_obj)

http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_get_bucket_with_object_hit(self):
Expand Down Expand Up @@ -458,7 +458,7 @@ def test_get_bucket_with_object_hit(self):
self.assertIsInstance(bucket, Bucket)
self.assertEqual(bucket.name, bucket_name)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_lookup_bucket_miss(self):
Expand All @@ -485,7 +485,7 @@ def test_lookup_bucket_miss(self):

self.assertIsNone(bucket)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_lookup_bucket_hit(self):
Expand Down Expand Up @@ -514,7 +514,7 @@ def test_lookup_bucket_hit(self):
self.assertIsInstance(bucket, Bucket)
self.assertEqual(bucket.name, BUCKET_NAME)
http.request.assert_called_once_with(
method="GET", url=URI, data=mock.ANY, headers=mock.ANY
method="GET", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)

def test_create_bucket_w_missing_client_project(self):
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_create_bucket_w_string_success(self):
self.assertEqual(bucket.name, bucket_name)
self.assertTrue(bucket.requester_pays)
http.request.assert_called_once_with(
method="POST", url=URI, data=mock.ANY, headers=mock.ANY
method="POST", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)
json_sent = http.request.call_args_list[0][1]["data"]
self.assertEqual(json_expected, json.loads(json_sent))
Expand Down Expand Up @@ -706,7 +706,7 @@ def test_create_bucket_w_object_success(self):
self.assertEqual(bucket.name, bucket_name)
self.assertTrue(bucket.requester_pays)
http.request.assert_called_once_with(
method="POST", url=URI, data=mock.ANY, headers=mock.ANY
method="POST", url=URI, data=mock.ANY, headers=mock.ANY, timeout=mock.ANY
)
json_sent = http.request.call_args_list[0][1]["data"]
self.assertEqual(json_expected, json.loads(json_sent))
Expand Down Expand Up @@ -848,7 +848,11 @@ def test_list_buckets_empty(self):
self.assertEqual(len(buckets), 0)

http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

requested_url = http.request.mock_calls[0][2]["url"]
Expand Down Expand Up @@ -883,7 +887,11 @@ def test_list_buckets_explicit_project(self):
self.assertEqual(len(buckets), 0)

http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

requested_url = http.request.mock_calls[0][2]["url"]
Expand Down Expand Up @@ -918,7 +926,11 @@ def test_list_buckets_non_empty(self):
self.assertEqual(buckets[0].name, BUCKET_NAME)

http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

def test_list_buckets_all_arguments(self):
Expand Down Expand Up @@ -948,7 +960,11 @@ def test_list_buckets_all_arguments(self):
buckets = list(iterator)
self.assertEqual(buckets, [])
http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=mock.ANY, headers=mock.ANY
method="GET",
url=mock.ANY,
data=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
)

requested_url = http.request.mock_calls[0][2]["url"]
Expand Down Expand Up @@ -1077,7 +1093,7 @@ def _create_hmac_key_helper(self, explicit_project=None, user_project=None):

FULL_URI = "{}?{}".format(URI, urlencode(qs_params))
http.request.assert_called_once_with(
method="POST", url=FULL_URI, data=None, headers=mock.ANY
method="POST", url=FULL_URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_create_hmac_key_defaults(self):
Expand Down Expand Up @@ -1112,7 +1128,7 @@ def test_list_hmac_keys_defaults_empty(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_list_hmac_keys_explicit_non_empty(self):
Expand Down Expand Up @@ -1176,7 +1192,7 @@ def test_list_hmac_keys_explicit_non_empty(self):
"userProject": USER_PROJECT,
}
http.request.assert_called_once_with(
method="GET", url=mock.ANY, data=None, headers=mock.ANY
method="GET", url=mock.ANY, data=None, headers=mock.ANY, timeout=mock.ANY
)
kwargs = http.request.mock_calls[0].kwargs
uri = kwargs["url"]
Expand Down Expand Up @@ -1223,7 +1239,7 @@ def test_get_hmac_key_metadata_wo_project(self):
]
)
http.request.assert_called_once_with(
method="GET", url=URI, data=None, headers=mock.ANY
method="GET", url=URI, data=None, headers=mock.ANY, timeout=mock.ANY
)

def test_get_hmac_key_metadata_w_project(self):
Expand Down Expand Up @@ -1273,5 +1289,5 @@ def test_get_hmac_key_metadata_w_project(self):
FULL_URI = "{}?{}".format(URI, urlencode(qs_params))

http.request.assert_called_once_with(
method="GET", url=FULL_URI, data=None, headers=mock.ANY
method="GET", url=FULL_URI, data=None, headers=mock.ANY, timeout=mock.ANY
)