diff --git a/aiohttp/client.py b/aiohttp/client.py index e261322f642..07f69de8237 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -228,7 +228,8 @@ def ws_connect(self, url, *, timeout=10.0, autoclose=True, autoping=True, - auth=None): + auth=None, + origin=None): """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect(url, @@ -236,7 +237,8 @@ def ws_connect(self, url, *, timeout=timeout, autoclose=autoclose, autoping=autoping, - auth=auth)) + auth=auth, + origin=origin)) @asyncio.coroutine def _ws_connect(self, url, *, @@ -244,7 +246,8 @@ def _ws_connect(self, url, *, timeout=10.0, autoclose=True, autoping=True, - auth=None): + auth=None, + origin=None): sec_key = base64.b64encode(os.urandom(16)) @@ -256,6 +259,8 @@ def _ws_connect(self, url, *, } if protocols: headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) + if origin is not None: + headers[hdrs.ORIGIN] = origin # send request resp = yield from self.request('get', url, headers=headers, @@ -659,7 +664,7 @@ def delete(url, **kwargs): def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, auth=None, ws_response_class=ClientWebSocketResponse, autoclose=True, - autoping=True, loop=None): + autoping=True, loop=None, origin=None): if loop is None: loop = asyncio.get_event_loop() @@ -675,5 +680,6 @@ def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, auth=None, protocols=protocols, timeout=timeout, autoclose=autoclose, - autoping=autoping), + autoping=autoping, + origin=origin), session=session) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 343a024312d..7feaf4ebd97 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -305,7 +305,9 @@ The client session supports context manager protocol for self closing. .. coroutinemethod:: ws_connect(url, *, protocols=(), timeout=10.0,\ auth=None,\ - autoclose=True, autoping=True) + autoclose=True,\ + autoping=True,\ + origin=None) Create a websocket connection. Returns a :class:`ClientWebSocketResponse` object. @@ -327,6 +329,8 @@ The client session supports context manager protocol for self closing. :param bool autoping: automatically send `pong` on `ping` message from server + :param str origin: Origin header to send to server + .. versionadded:: 0.16 Add :meth:`ws_connect`. @@ -335,6 +339,10 @@ The client session supports context manager protocol for self closing. Add *auth* parameter. + .. versionadded:: 0.19 + + Add *origin* parameter. + .. method:: close() Close underlying connector. diff --git a/docs/client_websockets.rst b/docs/client_websockets.rst index c39942e095e..1383f3d5bec 100644 --- a/docs/client_websockets.rst +++ b/docs/client_websockets.rst @@ -62,7 +62,8 @@ coroutines, do not create an instance of class .. coroutinefunction:: ws_connect(url, *, protocols=(), \ timeout=10.0, connector=None, auth=None,\ ws_response_class=ClientWebSocketResponse,\ - autoclose=True, autoping=True, loop=None) + autoclose=True, autoping=True, loop=None,\ + origin=None) This function creates a websocket connection, checks the response and returns a :class:`ClientWebSocketResponse` object. In case of failure @@ -98,10 +99,16 @@ coroutines, do not create an instance of class used for getting default event loop, but we strongly recommend to use explicit loops everywhere. + :param str origin: Origin header to send to server + .. versionadded:: 0.18 Add *auth* parameter. + .. versionadded:: 0.19 + + Add *origin* parameter. + ClientWebSocketResponse ----------------------- diff --git a/tests/test_websocket_client.py b/tests/test_websocket_client.py index 1b598f31d27..99a2685048a 100644 --- a/tests/test_websocket_client.py +++ b/tests/test_websocket_client.py @@ -46,6 +46,27 @@ def test_ws_connect(self, m_req, m_os): self.assertIsInstance(res, websocket_client.ClientWebSocketResponse) self.assertEqual(res.protocol, 'chat') + self.assertNotIn(hdrs.ORIGIN, m_req.call_args[1]["headers"]) + + @mock.patch('aiohttp.client.os') + @mock.patch('aiohttp.client.ClientSession.request') + def test_ws_connect_with_origin(self, m_req, m_os): + resp = mock.Mock() + resp.status = 403 + m_os.urandom.return_value = self.key_data + m_req.return_value = asyncio.Future(loop=self.loop) + m_req.return_value.set_result(resp) + + origin = 'https://example.org/page.html' + with self.assertRaises(errors.WSServerHandshakeError): + self.loop.run_until_complete( + aiohttp.ws_connect( + 'http://test.org', + loop=self.loop, + origin=origin)) + + self.assertIn(hdrs.ORIGIN, m_req.call_args[1]["headers"]) + self.assertEqual(m_req.call_args[1]["headers"][hdrs.ORIGIN], origin) @mock.patch('aiohttp.client.os') @mock.patch('aiohttp.client.ClientSession.request')