Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Merge pull request #174 from thobrla/body-stream-idem
Browse files Browse the repository at this point in the history
Rewind original stream body when refreshing.
  • Loading branch information
nathanielmanistaatgoogle committed May 13, 2015
2 parents d93ed1e + 7d517d9 commit 219cf26
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
8 changes: 8 additions & 0 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,11 @@ def new_request(uri, method='GET', body=None, headers=None,
else:
headers['user-agent'] = self.user_agent

body_stream_position = None
if all(getattr(body, stream_prop, None) for stream_prop in
('read', 'seek', 'tell')):
body_stream_position = body.tell()

resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)

Expand All @@ -567,6 +572,9 @@ def new_request(uri, method='GET', body=None, headers=None,
refresh_attempt + 1, max_refresh_attempts)
self._refresh(request_orig)
self.apply(headers)
if body_stream_position is not None:
body.seek(body_stream_position)

resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)

Expand Down
9 changes: 4 additions & 5 deletions tests/http_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,16 @@ def request(self, uri,
connection_type=None):
resp, content = self._iterable.pop(0)
self.requests.append({'uri': uri, 'body': body, 'headers': headers})
# Read any underlying stream before sending the request.
body_stream_content = body.read() if getattr(body, 'read', None) else None
if content == 'echo_request_headers':
content = headers
elif content == 'echo_request_headers_as_json':
content = json.dumps(headers)
elif content == 'echo_request_body':
if hasattr(body, 'read'):
content = body.read()
else:
content = body
content = body if body_stream_content is None else body_stream_content
elif content == 'echo_request_uri':
content = uri
elif not isinstance(content, bytes):
raise TypeError("http content should be bytes: %r" % (content,))
raise TypeError('http content should be bytes: %r' % (content,))
return httplib2.Response(resp), content
43 changes: 38 additions & 5 deletions tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@
import unittest

from .http_mock import HttpMockSequence
import six

from oauth2client import file
from oauth2client import locked_file
from oauth2client import multistore_file
from oauth2client import util
from oauth2client.client import AccessTokenCredentials
from oauth2client.client import OAuth2Credentials
from six.moves import http_client
try:
# Python2
from future_builtins import oct
Expand Down Expand Up @@ -154,15 +157,17 @@ def test_token_refresh_store_expires_soon(self):
access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': '401'}, b'Initial token expired'),
({'status': '401'}, b'Store token expired'),
({'status': '200'}, json.dumps(token_response).encode('utf-8')),
({'status': '200'}, b'Valid response to original request')
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)},
b'Valid response to original request')
])

credentials.authorize(http)
http.request('https://example.com')
self.assertEquals(credentials.access_token, access_token)
self.assertEqual(credentials.access_token, access_token)

def test_token_refresh_good_store(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
Expand All @@ -178,6 +183,34 @@ def test_token_refresh_good_store(self):
credentials._refresh(lambda x: x)
self.assertEquals(credentials.access_token, 'bar')

def test_token_refresh_stream_body(self):
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=15)
credentials = self.create_test_credentials(expiration=expiration)

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)

valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token, 'expires_in': 3600}
http = HttpMockSequence([
({'status': str(http_client.UNAUTHORIZED)}, b'Initial token expired'),
({'status': str(http_client.UNAUTHORIZED)}, b'Store token expired'),
({'status': str(http_client.OK)},
json.dumps(token_response).encode('utf-8')),
({'status': str(http_client.OK)}, 'echo_request_body')
])

body = six.StringIO('streaming body')

credentials.authorize(http)
_, content = http.request('https://example.com', body=body)
self.assertEqual(content, 'streaming body')
self.assertEqual(credentials.access_token, valid_access_token)

def test_credentials_delete(self):
credentials = self.create_test_credentials()

Expand Down

0 comments on commit 219cf26

Please sign in to comment.