Skip to content

Commit

Permalink
Merge pull request #486 from KeepSafe/skip_default_headers
Browse files Browse the repository at this point in the history
Skip default headers
  • Loading branch information
asvetlov committed Sep 3, 2015
2 parents 755a232 + 5c5743e commit c533d97
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ CHANGES
Using `force` parameter for the method is deprecated: use `.release()` instead.

* Properly requote URL's path #480

* add `skip_auto_headers` parameter for client API #486
17 changes: 15 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class ClientSession:
_connector = None

def __init__(self, *, connector=None, loop=None, cookies=None,
headers=None, auth=None, request_class=ClientRequest,
headers=None, skip_auto_headers=None,
auth=None, request_class=ClientRequest,
response_class=ClientResponse,
ws_response_class=ClientWebSocketResponse):

Expand Down Expand Up @@ -65,6 +66,11 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
else:
headers = CIMultiDict()
self._default_headers = headers
if skip_auto_headers is not None:
self._skip_auto_headers = frozenset([upstr(i)
for i in skip_auto_headers])
else:
self._skip_auto_headers = frozenset()

self._request_class = request_class
self._response_class = response_class
Expand All @@ -88,6 +94,7 @@ def request(self, method, url, *,
params=None,
data=None,
headers=None,
skip_auto_headers=None,
files=None,
auth=None,
allow_redirects=True,
Expand Down Expand Up @@ -119,9 +126,15 @@ def request(self, method, url, *,
raise ValueError("Can't combine `Authorization` header with "
"`auth` argument")

skip_headers = set(self._skip_auto_headers)
if skip_auto_headers is not None:
for i in skip_auto_headers:
skip_headers.add(upstr(i))

while True:
req = self._request_class(
method, url, params=params, headers=headers, data=data,
method, url, params=params, headers=headers,
skip_auto_headers=skip_headers, data=data,
cookies=self.cookies, files=files, encoding=encoding,
auth=auth, version=version, compress=compress, chunked=chunked,
expect100=expect100,
Expand Down
24 changes: 19 additions & 5 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from . import hdrs, helpers, streams
from .log import client_logger
from .streams import EOF_MARKER, FlowControlStreamReader
from .multidict import CIMultiDictProxy, MultiDictProxy, MultiDict, CIMultiDict
from .multidict import (CIMultiDictProxy, MultiDictProxy, MultiDict,
CIMultiDict)
from .multipart import MultipartWriter
from .protocol import HttpMessage

PY_341 = sys.version_info >= (3, 4, 1)

Expand All @@ -39,6 +41,8 @@ class ClientRequest:
hdrs.ACCEPT_ENCODING: 'gzip, deflate',
}

SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

body = b''
auth = None
response = None
Expand All @@ -53,7 +57,8 @@ class ClientRequest:
# Until writer has finished finalizer will not be called.

def __init__(self, method, url, *,
params=None, headers=None, data=None, cookies=None,
params=None, headers=None, skip_auto_headers=frozenset(),
data=None, cookies=None,
files=None, auth=None, encoding='utf-8',
version=aiohttp.HttpVersion11, compress=None,
chunked=None, expect100=False,
Expand All @@ -77,6 +82,7 @@ def __init__(self, method, url, *,
self.update_host(url)
self.update_path(params)
self.update_headers(headers)
self.update_auto_headers(skip_auto_headers)
self.update_cookies(cookies)
self.update_content_encoding()
self.update_auth(auth)
Expand Down Expand Up @@ -191,14 +197,21 @@ def update_headers(self, headers):
for key, value in headers:
self.headers.add(key, value)

def update_auto_headers(self, skip_auto_headers):
self.skip_auto_headers = skip_auto_headers
used_headers = set(self.headers) | skip_auto_headers

for hdr, val in self.DEFAULT_HEADERS.items():
if hdr not in self.headers:
self.headers[hdr] = val
if hdr not in used_headers:
self.headers.add(hdr, val)

# add host
if hdrs.HOST not in self.headers:
if hdrs.HOST not in used_headers:
self.headers[hdrs.HOST] = self.netloc

if hdrs.USER_AGENT not in used_headers:
self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

def update_cookies(self, cookies):
"""Update request cookies header."""
if not cookies:
Expand Down Expand Up @@ -445,6 +458,7 @@ def send(self, writer, reader):

# set default content-type
if (self.method in self.POST_METHODS and
hdrs.CONTENT_TYPE not in self.skip_auto_headers and
hdrs.CONTENT_TYPE not in self.headers):
self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

Expand Down
5 changes: 0 additions & 5 deletions aiohttp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,3 @@ def __init__(self, transport, method, path,
self.path = path
self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
method, path, http_version)

def _add_default_headers(self):
super()._add_default_headers()

self.headers.setdefault(hdrs.USER_AGENT, self.SERVER_SOFTWARE)
36 changes: 32 additions & 4 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ The client session supports context manager protocol for self closing::


.. class:: ClientSession(*, connector=None, loop=None, cookies=None,\
headers=None, auth=None, request_class=ClientRequest,\
headers=None, skip_auto_headers=None, \
auth=None, request_class=ClientRequest,\
response_class=ClientResponse, \
ws_response_class=ClientWebSocketResponse)

Expand All @@ -61,8 +62,23 @@ The client session supports context manager protocol for self closing::

:param dict cookies: Cookies to send with the request (optional)

:param dict headers: HTTP Headers to send with
the request (optional)
:param headers: HTTP Headers to send with
the request (optional).

May be either *iterable of key-value pairs* or
:class:`~collections.abc.Mapping`
(e.g. :class:`dict`,
:class:`~aiohttp.multidict.CIMultiDict`).

:param skip_auto_headers: set of headers for which autogeneration
should be skipped.

*aiohttp* autogenerates headers like ``User-Agent`` or
``Content-Type`` if these headers are not explicitly
passed. Using ``skip_auto_headers`` parameter allows to skip
that generation.

Iterable of :class:`str` or :class:`~aiohttp.multidict.upstr` (optional)

:param aiohttp.helpers.BasicAuth auth: BasicAuth named tuple that represents
HTTP Basic Authorization (optional)
Expand Down Expand Up @@ -106,7 +122,8 @@ The client session supports context manager protocol for self closing::


.. coroutinemethod:: request(method, url, *, params=None, data=None,\
headers=None, auth=None, allow_redirects=True,\
headers=None, skip_auto_headers=None, \
auth=None, allow_redirects=True,\
max_redirects=10, encoding='utf-8',\
version=HttpVersion(major=1, minor=1),\
compress=None, chunked=None, expect100=False,\
Expand All @@ -128,6 +145,17 @@ The client session supports context manager protocol for self closing::
:param dict headers: HTTP Headers to send with
the request (optional)

:param skip_auto_headers: set of headers for which autogeneration
should be skipped.

*aiohttp* autogenerates headers like ``User-Agent`` or
``Content-Type`` if these headers are not explicitly
passed. Using ``skip_auto_headers`` parameter allows to skip
that generation.

Iterable of :class:`str` or :class:`~aiohttp.multidict.upstr`
(optional)

:param aiohttp.helpers.BasicAuth auth: BasicAuth named tuple that
represents HTTP Basic Authorization
(optional)
Expand Down
116 changes: 116 additions & 0 deletions tests/test_client_functional2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
import socket
import unittest

import aiohttp
from aiohttp import hdrs, log, web


class TestClientFunctional2(unittest.TestCase):

def setUp(self):
self.handler = None
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.client = aiohttp.ClientSession(loop=self.loop)

def tearDown(self):
if self.handler:
self.loop.run_until_complete(self.handler.finish_connections())
self.client.close()
self.loop.stop()
self.loop.run_forever()
self.loop.close()

def find_unused_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]
s.close()
return port

@asyncio.coroutine
def create_server(self):
app = web.Application(loop=self.loop)

port = self.find_unused_port()
self.handler = app.make_handler(
debug=True, keep_alive_on=False,
access_log=log.access_logger)
srv = yield from self.loop.create_server(
self.handler, '127.0.0.1', port)
url = "http://127.0.0.1:{}".format(port)
self.addCleanup(srv.close)
return app, srv, url

def test_auto_header_user_agent(self):
@asyncio.coroutine
def handler(request):
self.assertIn('aiohttp', request.headers['user-agent'])
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)
resp = yield from self.client.get(url+'/')
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())

def test_skip_auto_headers_user_agent(self):
@asyncio.coroutine
def handler(request):
self.assertNotIn(hdrs.USER_AGENT, request.headers)
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)
resp = yield from self.client.get(url+'/',
skip_auto_headers=['user-agent'])
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())

def test_skip_default_auto_headers_user_agent(self):
@asyncio.coroutine
def handler(request):
self.assertNotIn(hdrs.USER_AGENT, request.headers)
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)

client = aiohttp.ClientSession(loop=self.loop,
skip_auto_headers=['user-agent'])
resp = yield from client.get(url+'/')
self.assertEqual(200, resp.status)
yield from resp.release()

client.close()

self.loop.run_until_complete(go())

def test_skip_auto_headers_content_type(self):
@asyncio.coroutine
def handler(request):
self.assertNotIn(hdrs.CONTENT_TYPE, request.headers)
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)
resp = yield from self.client.get(
url+'/',
skip_auto_headers=['content-type'])
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())
20 changes: 20 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import aiohttp
from aiohttp.client_reqrep import ClientRequest, ClientResponse
from aiohttp.multidict import upstr

PY_341 = sys.version_info >= (3, 4, 1)

Expand Down Expand Up @@ -116,6 +117,25 @@ def test_host_header(self):
self.assertEqual(req.headers['HOST'], 'example.com:99')
self.loop.run_until_complete(req.close())

def test_default_headers_useragent(self):
req = ClientRequest('get', 'http://python.org/', loop=self.loop)

self.assertNotIn('SERVER', req.headers)
self.assertIn('USER-AGENT', req.headers)

def test_default_headers_useragent_custom(self):
req = ClientRequest('get', 'http://python.org/', loop=self.loop,
headers={'user-agent': 'my custom agent'})

self.assertIn('USER-Agent', req.headers)
self.assertEqual('my custom agent', req.headers['User-Agent'])

def test_skip_default_useragent_header(self):
req = ClientRequest('get', 'http://python.org/', loop=self.loop,
skip_auto_headers=set([upstr('user-agent')]))

self.assertNotIn('User-Agent', req.headers)

def test_headers(self):
req = ClientRequest('get', 'http://python.org/',
headers={'Content-Type': 'text/plain'},
Expand Down
33 changes: 31 additions & 2 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ def test_init_headers_MultiDict(self):
("H3", "header3")]))
session.close()

def test_init_headers_list_of_tuples_with_duplicates(self):
session = ClientSession(
headers=[("h1", "header11"),
("h2", "header21"),
("h1", "header12")],
loop=self.loop)
self.assertEqual(
session._default_headers,
CIMultiDict([("H1", "header11"),
("H2", "header21"),
("H1", "header12")]))
session.close()

def test_init_cookies_with_simple_dict(self):
session = ClientSession(
cookies={
Expand Down Expand Up @@ -142,8 +155,24 @@ def test_merge_headers_with_list_of_tuples(self):
]))
session.close()

def _make_one(self):
session = ClientSession(loop=self.loop)
def test_merge_headers_with_list_of_tuples_duplicated_names(self):
session = ClientSession(
headers={
"h1": "header1",
"h2": "header2"
}, loop=self.loop)
headers = session._prepare_headers([("h1", "v1"),
("h1", "v2")])
self.assertIsInstance(headers, CIMultiDict)
self.assertEqual(headers, CIMultiDict([
("H2", "header2"),
("H1", "v1"),
("H1", "v2"),
]))
session.close()

def _make_one(self, **kwargs):
session = ClientSession(loop=self.loop, **kwargs)
params = dict(
headers={"Authorization": "Basic ..."},
max_redirects=2,
Expand Down
Loading

0 comments on commit c533d97

Please sign in to comment.