Skip to content

Commit

Permalink
Merge pull request #607 from rutsky/ws_origin_header
Browse files Browse the repository at this point in the history
add option to pass Origin header in ws_connect
  • Loading branch information
asvetlov committed Oct 29, 2015
2 parents 2abbc25 + 8d68d39 commit 944e0f6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
16 changes: 11 additions & 5 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,23 +228,26 @@ def ws_connect(self, url, *,
timeout=10.0,
autoclose=True,
autoping=True,
auth=None):
auth=None,
origin=None):
"""Initiate websocket connection."""
return _WSRequestContextManager(
self._ws_connect(url,
protocols=protocols,
timeout=timeout,
autoclose=autoclose,
autoping=autoping,
auth=auth))
auth=auth,
origin=origin))

@asyncio.coroutine
def _ws_connect(self, url, *,
protocols=(),
timeout=10.0,
autoclose=True,
autoping=True,
auth=None):
auth=None,
origin=None):

sec_key = base64.b64encode(os.urandom(16))

Expand All @@ -256,6 +259,8 @@ def _ws_connect(self, url, *,
}
if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
if origin is not None:
headers[hdrs.ORIGIN] = origin

# send request
resp = yield from self.request('get', url, headers=headers,
Expand Down Expand Up @@ -659,7 +664,7 @@ def delete(url, **kwargs):

def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, auth=None,
ws_response_class=ClientWebSocketResponse, autoclose=True,
autoping=True, loop=None):
autoping=True, loop=None, origin=None):

if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -675,5 +680,6 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, auth=None,
protocols=protocols,
timeout=timeout,
autoclose=autoclose,
autoping=autoping),
autoping=autoping,
origin=origin),
session=session)
10 changes: 9 additions & 1 deletion docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ The client session supports context manager protocol for self closing.

.. coroutinemethod:: ws_connect(url, *, protocols=(), timeout=10.0,\
auth=None,\
autoclose=True, autoping=True)
autoclose=True,\
autoping=True,\
origin=None)

Create a websocket connection. Returns a
:class:`ClientWebSocketResponse` object.
Expand All @@ -327,6 +329,8 @@ The client session supports context manager protocol for self closing.
:param bool autoping: automatically send `pong` on `ping`
message from server

:param str origin: Origin header to send to server

.. versionadded:: 0.16

Add :meth:`ws_connect`.
Expand All @@ -335,6 +339,10 @@ The client session supports context manager protocol for self closing.

Add *auth* parameter.

.. versionadded:: 0.19

Add *origin* parameter.

.. method:: close()

Close underlying connector.
Expand Down
9 changes: 8 additions & 1 deletion docs/client_websockets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ coroutines, do not create an instance of class
.. coroutinefunction:: ws_connect(url, *, protocols=(), \
timeout=10.0, connector=None, auth=None,\
ws_response_class=ClientWebSocketResponse,\
autoclose=True, autoping=True, loop=None)
autoclose=True, autoping=True, loop=None,\
origin=None)

This function creates a websocket connection, checks the response and
returns a :class:`ClientWebSocketResponse` object. In case of failure
Expand Down Expand Up @@ -98,10 +99,16 @@ coroutines, do not create an instance of class
used for getting default event loop, but we strongly
recommend to use explicit loops everywhere.

:param str origin: Origin header to send to server

.. versionadded:: 0.18

Add *auth* parameter.

.. versionadded:: 0.19

Add *origin* parameter.


ClientWebSocketResponse
-----------------------
Expand Down
21 changes: 21 additions & 0 deletions tests/test_websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ def test_ws_connect(self, m_req, m_os):

self.assertIsInstance(res, websocket_client.ClientWebSocketResponse)
self.assertEqual(res.protocol, 'chat')
self.assertNotIn(hdrs.ORIGIN, m_req.call_args[1]["headers"])

@mock.patch('aiohttp.client.os')
@mock.patch('aiohttp.client.ClientSession.request')
def test_ws_connect_with_origin(self, m_req, m_os):
resp = mock.Mock()
resp.status = 403
m_os.urandom.return_value = self.key_data
m_req.return_value = asyncio.Future(loop=self.loop)
m_req.return_value.set_result(resp)

origin = 'https://example.org/page.html'
with self.assertRaises(errors.WSServerHandshakeError):
self.loop.run_until_complete(
aiohttp.ws_connect(
'http://test.org',
loop=self.loop,
origin=origin))

self.assertIn(hdrs.ORIGIN, m_req.call_args[1]["headers"])
self.assertEqual(m_req.call_args[1]["headers"][hdrs.ORIGIN], origin)

@mock.patch('aiohttp.client.os')
@mock.patch('aiohttp.client.ClientSession.request')
Expand Down

0 comments on commit 944e0f6

Please sign in to comment.