From 7d517d9295c8d45b23b1942dec4009a2d63b3df5 Mon Sep 17 00:00:00 2001 From: Travis Hobrla Date: Mon, 11 May 2015 15:42:45 -0700 Subject: [PATCH] Rewind original stream body when refreshing. When refreshing credentials, the original request is re-sent after the credentials are refreshed. If the body of that request is a stream, the stream contents are read in the initial request, and the stream must be rewound before the request is re-sent. Otherwise, the original message body will be different (because stream data was skipped). --- oauth2client/client.py | 8 ++++++++ tests/http_mock.py | 9 ++++----- tests/test_file.py | 43 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index d808ed4e0..abe342e84 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -554,6 +554,11 @@ def new_request(uri, method='GET', body=None, headers=None, else: headers['user-agent'] = self.user_agent + body_stream_position = None + if all(getattr(body, stream_prop, None) for stream_prop in + ('read', 'seek', 'tell')): + body_stream_position = body.tell() + resp, content = request_orig(uri, method, body, clean_headers(headers), redirections, connection_type) @@ -567,6 +572,9 @@ def new_request(uri, method='GET', body=None, headers=None, refresh_attempt + 1, max_refresh_attempts) self._refresh(request_orig) self.apply(headers) + if body_stream_position is not None: + body.seek(body_stream_position) + resp, content = request_orig(uri, method, body, clean_headers(headers), redirections, connection_type) diff --git a/tests/http_mock.py b/tests/http_mock.py index b059b5fa8..4040d48b2 100644 --- a/tests/http_mock.py +++ b/tests/http_mock.py @@ -100,17 +100,16 @@ def request(self, uri, connection_type=None): resp, content = self._iterable.pop(0) self.requests.append({'uri': uri, 'body': body, 'headers': headers}) + # Read any underlying stream before sending the request. + body_stream_content = body.read() if getattr(body, 'read', None) else None if content == 'echo_request_headers': content = headers elif content == 'echo_request_headers_as_json': content = json.dumps(headers) elif content == 'echo_request_body': - if hasattr(body, 'read'): - content = body.read() - else: - content = body + content = body if body_stream_content is None else body_stream_content elif content == 'echo_request_uri': content = uri elif not isinstance(content, bytes): - raise TypeError("http content should be bytes: %r" % (content,)) + raise TypeError('http content should be bytes: %r' % (content,)) return httplib2.Response(resp), content diff --git a/tests/test_file.py b/tests/test_file.py index 8efe6cfed..89c3e02c7 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -32,12 +32,15 @@ import unittest from .http_mock import HttpMockSequence +import six + from oauth2client import file from oauth2client import locked_file from oauth2client import multistore_file from oauth2client import util from oauth2client.client import AccessTokenCredentials from oauth2client.client import OAuth2Credentials +from six.moves import http_client try: # Python2 from future_builtins import oct @@ -154,15 +157,17 @@ def test_token_refresh_store_expires_soon(self): access_token = '1/3w' token_response = {'access_token': access_token, 'expires_in': 3600} http = HttpMockSequence([ - ({'status': '401'}, b'Initial token expired'), - ({'status': '401'}, b'Store token expired'), - ({'status': '200'}, json.dumps(token_response).encode('utf-8')), - ({'status': '200'}, b'Valid response to original request') + ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'), + ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'), + ({'status': str(http_client.OK)}, + json.dumps(token_response).encode('utf-8')), + ({'status': str(http_client.OK)}, + b'Valid response to original request') ]) credentials.authorize(http) http.request('https://example.com') - self.assertEquals(credentials.access_token, access_token) + self.assertEqual(credentials.access_token, access_token) def test_token_refresh_good_store(self): expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) @@ -178,6 +183,34 @@ def test_token_refresh_good_store(self): credentials._refresh(lambda x: x) self.assertEquals(credentials.access_token, 'bar') + def test_token_refresh_stream_body(self): + expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15) + credentials = self.create_test_credentials(expiration=expiration) + + s = file.Storage(FILENAME) + s.put(credentials) + credentials = s.get() + new_cred = copy.copy(credentials) + new_cred.access_token = 'bar' + s.put(new_cred) + + valid_access_token = '1/3w' + token_response = {'access_token': valid_access_token, 'expires_in': 3600} + http = HttpMockSequence([ + ({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'), + ({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'), + ({'status': str(http_client.OK)}, + json.dumps(token_response).encode('utf-8')), + ({'status': str(http_client.OK)}, 'echo_request_body') + ]) + + body = six.StringIO('streaming body') + + credentials.authorize(http) + _, content = http.request('https://example.com', body=body) + self.assertEqual(content, 'streaming body') + self.assertEqual(credentials.access_token, valid_access_token) + def test_credentials_delete(self): credentials = self.create_test_credentials()