diff --git a/docs/source/api.rst b/docs/source/api.rst index 1a5b317..4b6798e 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -128,6 +128,21 @@ have a newline inside a header value, and ``Content-Length: hello`` is an error because `Content-Length` should always be an integer. We may add additional checks in the future. +While we make sure to expose header names as lowercased bytes, we also +preserve the original header casing that is used. Compliant HTTP +agents should always treat headers in a case insensitive manner, but +this may not always be the case. When sending bytes over the wire we +send headers preserving whatever original header casing was used. + +It is possible to access the headers in their raw original casing, +which may be useful for some user output or debugging purposes. + +.. ipython:: python + + original_headers = [("Host", "example.com")] + req = h11.Request(method="GET", target="/", headers=original_headers) + req.headers.raw_items() + .. _http_version-format: It's not just headers we normalize to being byte-strings: the same diff --git a/h11/_connection.py b/h11/_connection.py index fc6289a..410c4e9 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -534,7 +534,7 @@ def send_failed(self): def _clean_up_response_headers_for_sending(self, response): assert type(response) is Response - headers = list(response.headers) + headers = response.headers need_close = False # HEAD requests need some special handling: they always act like they @@ -560,13 +560,13 @@ def _clean_up_response_headers_for_sending(self, response): # but the HTTP spec says that if our peer does this then we have # to fix it instead of erroring out, so we'll accord the user the # same respect). - set_comma_header(headers, b"content-length", []) + headers = set_comma_header(headers, b"content-length", []) if self.their_http_version is None or self.their_http_version < b"1.1": # Either we never got a valid request and are sending back an # error (their_http_version is None), so we assume the worst; # or else we did get a valid HTTP/1.0 request, so we know that # they don't understand chunked encoding. - set_comma_header(headers, b"transfer-encoding", []) + headers = set_comma_header(headers, b"transfer-encoding", []) # This is actually redundant ATM, since currently we # unconditionally disable keep-alive when talking to HTTP/1.0 # peers. But let's be defensive just in case we add @@ -574,13 +574,13 @@ def _clean_up_response_headers_for_sending(self, response): if self._request_method != b"HEAD": need_close = True else: - set_comma_header(headers, b"transfer-encoding", ["chunked"]) + headers = set_comma_header(headers, b"transfer-encoding", ["chunked"]) if not self._cstate.keep_alive or need_close: # Make sure Connection: close is set connection = set(get_comma_header(headers, b"connection")) connection.discard(b"keep-alive") connection.add(b"close") - set_comma_header(headers, b"connection", sorted(connection)) + headers = set_comma_header(headers, b"connection", sorted(connection)) response.headers = headers diff --git a/h11/_headers.py b/h11/_headers.py index 878f63c..39612a0 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -62,6 +62,59 @@ _field_value_re = re.compile(field_value.encode("ascii")) +class Headers: + """ + A list-like interface that allows iterating over headers as byte-pairs + of (lowercased-name, value). + + Internally we actually store the representation as three-tuples, + including both the raw original casing, in order to preserve casing + over-the-wire, and the lowercased name, for case-insensitive comparisions. + + r = Request( + method="GET", + target="/", + headers=[("Host", "example.org"), ("Connection", "keep-alive")], + http_version="1.1", + ) + assert r.headers == [ + (b"host", b"example.org"), + (b"connection", b"keep-alive") + ] + assert r.headers.raw_items() == [ + (b"Host", b"example.org"), + (b"Connection", b"keep-alive") + ] + """ + __slots__ = '_full_items' + + def __init__(self, full_items): + self._full_items = full_items + + def __iter__(self): + for _, name, value in self._full_items: + yield name, value + + def __bool__(self): + return bool(self._full_items) + + def __eq__(self, other): + return list(self) == list(other) + + def __len__(self): + return len(self._full_items) + + def __repr__(self): + return "" % repr(list(self)) + + def __getitem__(self, idx): + _, name, value = self._full_items[idx] + return (name, value) + + def raw_items(self): + return [(raw_name, value) for raw_name, _, value in self._full_items] + + def normalize_and_validate(headers, _parsed=False): new_headers = [] saw_content_length = False @@ -75,6 +128,7 @@ def normalize_and_validate(headers, _parsed=False): value = bytesify(value) validate(_field_name_re, name, "Illegal header name {!r}", name) validate(_field_value_re, value, "Illegal header value {!r}", value) + raw_name = name name = name.lower() if name == b"content-length": if saw_content_length: @@ -99,8 +153,8 @@ def normalize_and_validate(headers, _parsed=False): error_status_hint=501, ) saw_transfer_encoding = True - new_headers.append((name, value)) - return new_headers + new_headers.append((raw_name, name, value)) + return Headers(new_headers) def get_comma_header(headers, name): @@ -140,7 +194,7 @@ def get_comma_header(headers, name): # "100-continue". Splitting on commas is harmless. Case insensitive. # out = [] - for found_name, found_raw_value in headers: + for _, found_name, found_raw_value in headers._full_items: if found_name == name: found_raw_value = found_raw_value.lower() for found_split_value in found_raw_value.split(b","): @@ -152,13 +206,21 @@ def get_comma_header(headers, name): def set_comma_header(headers, name, new_values): # The header name `name` is expected to be lower-case bytes. + # + # Note that when we store the header we use title casing for the header + # names, in order to match the conventional HTTP header style. + # + # Simply calling `.title()` is a blunt approach, but it's correct + # here given the cases where we're using `set_comma_header`... + # + # Connection, Content-Length, Transfer-Encoding. new_headers = [] - for found_name, found_raw_value in headers: + for found_raw_name, found_name, found_raw_value in headers._full_items: if found_name != name: - new_headers.append((found_name, found_raw_value)) + new_headers.append((found_raw_name, found_raw_value)) for new_value in new_values: - new_headers.append((name, new_value)) - headers[:] = normalize_and_validate(new_headers) + new_headers.append((name.title(), new_value)) + return normalize_and_validate(new_headers) def has_expect_100_continue(request): diff --git a/h11/_readers.py b/h11/_readers.py index 56d8915..cc86bff 100644 --- a/h11/_readers.py +++ b/h11/_readers.py @@ -58,7 +58,9 @@ def _decode_header_lines(lines): # Python 3, validate() takes either and returns matches as bytes. But # on Python 2, validate can return matches as bytearrays, so we have # to explicitly cast back. - matches = validate(header_field_re, bytes(line), "illegal header line: {!r}", bytes(line)) + matches = validate( + header_field_re, bytes(line), "illegal header line: {!r}", bytes(line) + ) yield (matches["field_name"], matches["field_value"]) @@ -71,7 +73,9 @@ def maybe_read_from_IDLE_client(buf): return None if not lines: raise LocalProtocolError("no request line received") - matches = validate(request_line_re, lines[0], "illegal request line: {!r}", lines[0]) + matches = validate( + request_line_re, lines[0], "illegal request line: {!r}", lines[0] + ) return Request( headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches ) @@ -152,7 +156,12 @@ def __call__(self, buf): chunk_header = buf.maybe_extract_until_next(b"\r\n") if chunk_header is None: return None - matches = validate(chunk_header_re, chunk_header, "illegal chunk header: {!r}", chunk_header) + matches = validate( + chunk_header_re, + chunk_header, + "illegal chunk header: {!r}", + chunk_header, + ) # XX FIXME: we discard chunk extensions. Does anyone care? # We convert to bytes because Python 2's `int()` function doesn't # work properly on bytearray objects. diff --git a/h11/_writers.py b/h11/_writers.py index 6a41100..7531579 100644 --- a/h11/_writers.py +++ b/h11/_writers.py @@ -38,12 +38,13 @@ def write_headers(headers, write): # "Since the Host field-value is critical information for handling a # request, a user agent SHOULD generate Host as the first header field # following the request-line." - RFC 7230 - for name, value in headers: + raw_items = headers._full_items + for raw_name, name, value in raw_items: if name == b"host": - write(bytesmod(b"%s: %s\r\n", (name, value))) - for name, value in headers: + write(bytesmod(b"%s: %s\r\n", (raw_name, value))) + for raw_name, name, value in raw_items: if name != b"host": - write(bytesmod(b"%s: %s\r\n", (name, value))) + write(bytesmod(b"%s: %s\r\n", (raw_name, value))) write(b"\r\n") diff --git a/h11/tests/test_connection.py b/h11/tests/test_connection.py index 13e6e2d..a43113e 100644 --- a/h11/tests/test_connection.py +++ b/h11/tests/test_connection.py @@ -96,7 +96,7 @@ def test_Connection_basics_and_content_length(): ), ) assert data == ( - b"GET / HTTP/1.1\r\n" b"host: example.com\r\n" b"content-length: 10\r\n\r\n" + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 10\r\n\r\n" ) for conn in p.conns: @@ -113,7 +113,7 @@ def test_Connection_basics_and_content_length(): assert data == b"HTTP/1.1 100 \r\n\r\n" data = p.send(SERVER, Response(status_code=200, headers=[("Content-Length", "11")])) - assert data == b"HTTP/1.1 200 \r\ncontent-length: 11\r\n\r\n" + assert data == b"HTTP/1.1 200 \r\nContent-Length: 11\r\n\r\n" for conn in p.conns: assert conn.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY} @@ -243,7 +243,7 @@ def test_server_talking_to_http10_client(): # We automatically Connection: close back at them assert ( c.send(Response(status_code=200, headers=[])) - == b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n" + == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" ) assert c.send(Data(data=b"12345")) == b"12345" @@ -303,7 +303,7 @@ def test_automatic_transfer_encoding_in_response(): receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") assert ( c.send(Response(status_code=200, headers=user_headers)) - == b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n" + == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" ) assert c.send(Data(data=b"12345")) == b"12345" @@ -876,7 +876,7 @@ def test_errors(): if role is SERVER: assert ( c.send(Response(status_code=400, headers=[])) - == b"HTTP/1.1 400 \r\nconnection: close\r\n\r\n" + == b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n" ) # After an error sending, you can no longer send @@ -988,14 +988,14 @@ def setup(method, http_version): c = setup(method, b"1.1") assert ( c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" - b"transfer-encoding: chunked\r\n\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" ) # No Content-Length, HTTP/1.0 peer, frame with connection: close c = setup(method, b"1.0") assert ( c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" - b"connection: close\r\n\r\n" + b"Connection: close\r\n\r\n" ) # Content-Length + Transfer-Encoding, TE wins @@ -1011,7 +1011,7 @@ def setup(method, http_version): ) ) == b"HTTP/1.1 200 \r\n" - b"transfer-encoding: chunked\r\n\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" ) diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index d6bb931..07ffc13 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -163,3 +163,19 @@ def test_intenum_status_code(): assert r.status_code == HTTPStatus.OK assert type(r.status_code) is not type(HTTPStatus.OK) assert type(r.status_code) is int + + +def test_header_casing(): + r = Request( + method="GET", + target="/", + headers=[("Host", "example.org"), ("Connection", "keep-alive")], + http_version="1.1", + ) + assert len(r.headers) == 2 + assert r.headers[0] == (b"host", b"example.org") + assert r.headers == [(b"host", b"example.org"), (b"connection", b"keep-alive")] + assert r.headers.raw_items() == [ + (b"Host", b"example.org"), + (b"Connection", b"keep-alive"), + ] diff --git a/h11/tests/test_headers.py b/h11/tests/test_headers.py index 67bcd7b..5c16102 100644 --- a/h11/tests/test_headers.py +++ b/h11/tests/test_headers.py @@ -83,7 +83,7 @@ def test_get_set_comma_header(): assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"] - set_comma_header(headers, b"newthing", ["a", "b"]) + headers = set_comma_header(headers, b"newthing", ["a", "b"]) with pytest.raises(LocalProtocolError): set_comma_header(headers, b"newthing", [" a", "b"]) @@ -96,7 +96,7 @@ def test_get_set_comma_header(): (b"newthing", b"b"), ] - set_comma_header(headers, b"whatever", ["different thing"]) + headers = set_comma_header(headers, b"whatever", ["different thing"]) assert headers == [ (b"connection", b"close"), diff --git a/h11/tests/test_io.py b/h11/tests/test_io.py index ef5e31b..5ade99b 100644 --- a/h11/tests/test_io.py +++ b/h11/tests/test_io.py @@ -1,7 +1,7 @@ import pytest from .._events import * -from .._headers import normalize_and_validate +from .._headers import normalize_and_validate, Headers from .._readers import ( _obsolete_line_fold, ChunkedReader, @@ -31,12 +31,12 @@ target="/a", headers=[("Host", "foo"), ("Connection", "close")], ), - b"GET /a HTTP/1.1\r\nhost: foo\r\nconnection: close\r\n\r\n", + b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", ), ( (SERVER, SEND_RESPONSE), Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"), - b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\n", + b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n", ), ( (SERVER, SEND_RESPONSE), @@ -48,7 +48,7 @@ InformationalResponse( status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade" ), - b"HTTP/1.1 101 Upgrade\r\nupgrade: websocket\r\n\r\n", + b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n", ), ( (SERVER, SEND_RESPONSE), @@ -121,7 +121,7 @@ def test_writers_unusual(): normalize_and_validate([("foo", "bar"), ("baz", "quux")]), b"foo: bar\r\nbaz: quux\r\n\r\n", ) - tw(write_headers, [], b"\r\n") + tw(write_headers, Headers([]), b"\r\n") # We understand HTTP/1.0, but we don't speak it with pytest.raises(LocalProtocolError): @@ -435,7 +435,7 @@ def test_ChunkedWriter(): assert ( dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")])) - == b"0\r\netag: asdf\r\na: b\r\n\r\n" + == b"0\r\nEtag: asdf\r\na: b\r\n\r\n" ) @@ -503,5 +503,5 @@ def test_host_comes_first(): tw( write_headers, normalize_and_validate([("foo", "bar"), ("Host", "example.com")]), - b"host: example.com\r\nfoo: bar\r\n\r\n", + b"Host: example.com\r\nfoo: bar\r\n\r\n", )