Skip to content

Commit

Permalink
Complete Client websocket deflate support
Browse files Browse the repository at this point in the history
Add tests
  • Loading branch information
fanthos committed Sep 17, 2017
1 parent 52de28c commit 048a8f2
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 48 deletions.
51 changes: 32 additions & 19 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
deprecated_noop, sentinel)
from .http import WS_KEY, WebSocketReader, WebSocketWriter
from .streams import FlowControlDataQueue
from .http_websocket import (extensions_parse as ws_ext_parse,
extensions_gen as ws_ext_gen)


__all__ = (client_exceptions.__all__ + # noqa
Expand Down Expand Up @@ -380,7 +382,7 @@ def _ws_connect(self, url, *,
headers=None,
proxy=None,
proxy_auth=None,
compress=False):
compress=0):

if headers is None:
headers = CIMultiDict()
Expand All @@ -403,7 +405,11 @@ def _ws_connect(self, url, *,
if origin is not None:
headers[hdrs.ORIGIN] = origin
if compress:
headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = 'permessage-deflate'
if compress is True:
compress = 15
extstr = ws_ext_gen(compress=compress)
if extstr:
headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr

# send request
resp = yield from self.get(url, headers=headers,
Expand Down Expand Up @@ -463,22 +469,27 @@ def _ws_connect(self, url, *,
break

# websocket compress
compress_notakeover = False
notakeover = False
if compress:
if hdrs.SEC_WEBSOCKET_EXTENSIONS not in resp.headers:
compress = False
else:
exts = resp.headers[
hdrs.SEC_WEBSOCKET_EXTENSIONS].split(',')
for ext in exts:
params = [x.strip() for x in ext.split(';')]
if params[0] == 'permessage-deflate':
for param in params:
if param == 'client_no_context_takeover':
compress_notakeover = True
break
if compress_notakeover:
break
compress, notakeover = ws_ext_parse(
resp.headers[hdrs.SEC_WEBSOCKET_EXTENSIONS]
)
if compress == 0:
pass
elif compress < 0:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message='Invalid deflate extension',
code=resp.status,
headers=resp.headers)
elif compress < 8 or compress > 15:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message='Invalid window size',
code=resp.status,
headers=resp.headers)

proto = resp.connection.protocol
reader = FlowControlDataQueue(
Expand All @@ -487,7 +498,7 @@ def _ws_connect(self, url, *,
resp.connection.writer.set_tcp_nodelay(True)
writer = WebSocketWriter(
resp.connection.writer, use_mask=True,
compress=compress, notakeover=compress_notakeover)
compress=compress, notakeover=notakeover)
except Exception:
resp.close()
raise
Expand All @@ -501,7 +512,9 @@ def _ws_connect(self, url, *,
autoping,
self._loop,
receive_timeout=receive_timeout,
heartbeat=heartbeat)
heartbeat=heartbeat,
compress=compress,
compress_notakeover=notakeover)

def _prepare_headers(self, headers):
""" Add default headers and transform it to CIMultiDict
Expand Down
13 changes: 12 additions & 1 deletion aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class ClientWebSocketResponse:

def __init__(self, reader, writer, protocol,
response, timeout, autoclose, autoping, loop, *,
receive_timeout=None, heartbeat=None):
receive_timeout=None, heartbeat=None,
compress=0, compress_notakeover=False):
self._response = response
self._conn = response.connection

Expand All @@ -35,6 +36,8 @@ def __init__(self, reader, writer, protocol,
self._loop = loop
self._waiting = None
self._exception = None
self._compress = compress
self._compress_notakeover = compress_notakeover

self._reset_heartbeat()

Expand Down Expand Up @@ -82,6 +85,14 @@ def close_code(self):
def protocol(self):
return self._protocol

@property
def compress(self):
return self._compress

@property
def compress_notakeover(self):
return self._compress_notakeover

def get_extra_info(self, name, default=None):
"""extra info from connection transport"""
try:
Expand Down
86 changes: 60 additions & 26 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,48 @@ def _websocket_mask_python(mask, data):
_WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xff, 0xff])


def extensions_parse(extstr):
if not extstr:
return 0, False

extensions = [[s.strip() for s in s1.split(';')]
for s1 in extstr.split(',')]
compress = 0
compress_notakeover = False
for ext in extensions:
if ext[0] == 'permessage-deflate':
compress = 15
for param in ext[1:]:
if param.startswith('server_max_window_bits'):
compress = int(param.split('=')[1])
elif param == 'server_no_context_takeover':
compress_notakeover = True
# Ignore Client Takeover
elif param not in ('client_no_context_takeover',
'client_max_window_bits'):
return -1, False
if compress > 15:
raise HttpBadRequest(
message='Handshake error: PMCE window > 15') from None
break

return compress, compress_notakeover


def extensions_gen(compress=0, server_notakeover=False,
client_notakeover=False):
if compress < 8 or compress > 15:
return False
enabledext = 'permessage-deflate'
if compress < 15:
enabledext += '; server_max_window_bits=' + str(compress)
if server_notakeover:
enabledext += '; server_no_context_takeover'
if client_notakeover:
enabledext += '; client_no_context_takeover'
return enabledext


class WSParserState(IntEnum):
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
Expand Down Expand Up @@ -260,6 +302,8 @@ def _feed_data(self, data):

payload_merged = b''.join(self._partial)

# Decompress process must to be done after all packets
# received.
if compressed:
payload_merged = self._decompressobj.decompress(
payload_merged + _WS_DEFLATE_TRAILING)
Expand Down Expand Up @@ -435,7 +479,7 @@ class WebSocketWriter:

def __init__(self, stream, *,
use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
compress=False, notakeover=False):
compress=0, notakeover=False):
self.stream = stream
self.writer = stream.transport
self.use_mask = use_mask
Expand All @@ -458,11 +502,12 @@ def _send_frame(self, message, opcode):
# Does small packet needs to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if self.compress and opcode < 8:
if not self._compressobj or self.notakeover:
if not self._compressobj:
self._compressobj = zlib.compressobj(wbits=-self.compress)

message = self._compressobj.compress(message)
message = message + self._compressobj.flush(zlib.Z_FULL_FLUSH)
message = message + self._compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
rsv = rsv | 0x40
Expand Down Expand Up @@ -607,29 +652,18 @@ def do_handshake(method, headers, stream,
compress_notakeover = False

extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
if extensions:
extensions = [[s.strip() for s in s1.split(';')]
for s1 in extensions.split(',')]

for ext in extensions:
if ext[0] == 'permessage-deflate':
enabledext = ['permessage-deflate']
compress = 15
for param in ext[1:]:
if param.startswith('server_max_window_bits'):
compress = int(param.split('=')[1])
enabledext.append((
'server_max_window_bits=' + str(compress)))
elif param == 'server_no_context_takeover':
compress_notakeover = True
enabledext.append(('server_no_context_takeover'))
# Ignore Client Takeover
# elif param == 'client_no_context_takeover':
# compress_notakeover |= WSCompressNoTakeover.NT_CLIENT

response_headers.append((
hdrs.SEC_WEBSOCKET_EXTENSIONS, '; '.join(enabledext)))
break
compress, compress_notakeover = extensions_parse(extensions)
if compress:
if compress < 0:
raise HttpBadRequest(
message='Handshake error: PMCE bad extensions') from None
if compress > 15:
raise HttpBadRequest(
message='Handshake error: PMCE window > 15') from None

enabledext = extensions_gen(compress=compress,
server_notakeover=compress_notakeover)
response_headers.append((hdrs.SEC_WEBSOCKET_EXTENSIONS, enabledext))

if protocol:
response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol))
Expand Down
28 changes: 28 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,31 @@ def handler(request):
yield from resp.receive()

assert ping_received


@asyncio.coroutine
def test_send_recv_compress(loop, test_client):

@asyncio.coroutine
def handler(request):
ws = web.WebSocketResponse()
yield from ws.prepare(request)

msg = yield from ws.receive_str()
yield from ws.send_str(msg+'/answer')
yield from ws.close()
return ws

app = web.Application()
app.router.add_route('GET', '/', handler)
client = yield from test_client(app)
resp = yield from client.ws_connect('/', compress=True)
yield from resp.send_str('ask')

assert resp.compress == 15

data = yield from resp.receive_str()
assert data == 'ask/answer'

yield from resp.close()
assert resp.get_extra_info('socket') is None
39 changes: 38 additions & 1 deletion tests/test_websocket_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@ def message():
True, None, True, False, URL('/path'))


def gen_ws_headers(protocols=''):
def gen_ws_headers(protocols='', compress=0, compress_notakeover=False):
key = base64.b64encode(os.urandom(16)).decode()
hdrs = [('Upgrade', 'websocket'),
('Connection', 'upgrade'),
('Sec-Websocket-Version', '13'),
('Sec-Websocket-Key', key)]
if protocols:
hdrs += [('Sec-Websocket-Protocol', protocols)]
if compress:
params = 'permessage-deflate'
if compress < 15:
params += '; server_max_window_bits=' + str(compress)
if compress_notakeover:
params += '; server_no_context_takeover'
hdrs += [('Sec-Websocket-Extensions', params)]
return hdrs, key


Expand Down Expand Up @@ -149,3 +156,33 @@ def test_handshake_protocol_unsupported(log, message, transport):
assert protocol is None
assert (ctx.records[-1].msg ==
'Client protocols %r don’t overlap server-known ones %r')


def test_handshake_compress(message, transport):
hdrs, sec_key = gen_ws_headers(compress=15)

message.headers.extend(hdrs)
status, headers, parser, writer, protocol = do_handshake(
message.method, message.headers, transport)

headers = dict(headers)
assert 'Sec-Websocket-Extensions' in headers
assert headers['Sec-Websocket-Extensions'] == 'permessage-deflate'

assert writer.compress == 15


def test_handshake_compress_notakeover(message, transport):
hdrs, sec_key = gen_ws_headers(compress=15, compress_notakeover=True)

message.headers.extend(hdrs)
status, headers, parser, writer, protocol = do_handshake(
message.method, message.headers, transport)

headers = dict(headers)
assert 'Sec-Websocket-Extensions' in headers
assert headers['Sec-Websocket-Extensions'] == (
'permessage-deflate; server_no_context_takeover')

assert writer.compress == 15
assert writer.notakeover is True
1 change: 0 additions & 1 deletion tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ def test_parse_compress_frame_single(parser):
assert (1, 1, b'1', True) == (fin, opcode, payload, not not compress)



def test_parse_compress_frame_multi(parser):
parser.parse_frame(struct.pack('!BB', 0b01000001, 126))
parser.parse_frame(struct.pack('!H', 4))
Expand Down
16 changes: 16 additions & 0 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,19 @@ def test_send_text_masked(stream, writer):
random=random.Random(123))
writer.send(b'text')
stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12')


def test_send_compress_text(stream, writer):
writer = WebSocketWriter(stream, compress=15)
writer.send(b'text')
stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
stream.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00')


def test_send_compress_text_notakeover(stream, writer):
writer = WebSocketWriter(stream, compress=15, notakeover=True)
writer.send(b'text')
stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')

0 comments on commit 048a8f2

Please sign in to comment.