Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #8736/1b88af2 backport][3.10] Improve performance of WebSocketReader #8743

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/8736.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved performance of the WebSocket reader -- by :user:`bdraco`.
200 changes: 104 additions & 96 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ class WSMsgType(IntEnum):
error = ERROR


MESSAGE_TYPES_WITH_CONTENT: Final = (
WSMsgType.BINARY,
WSMsgType.TEXT,
WSMsgType.CONTINUATION,
)

WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


Expand Down Expand Up @@ -313,17 +319,101 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return True, data

try:
return self._feed_data(data)
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return True, b""

def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return False, b""

def _feed_data(self, data: bytes) -> None:
for fin, opcode, payload, compressed in self.parse_frame(data):
if compressed and not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
if opcode == WSMsgType.CLOSE:
if opcode in MESSAGE_TYPES_WITH_CONTENT:
# load text/binary
is_continuation = opcode == WSMsgType.CONTINUATION
if not fin:
# got partial frame payload
if not is_continuation:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
continue

has_partial = bool(self._partial)
if is_continuation:
if self._opcode is None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = None
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
)

if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload

if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(assembled_payload), self._max_msg_size
),
)

# Decompress process must to be done after all packets
# received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(
suppress_deflate_header=True
)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
)
else:
payload_merged = bytes(assembled_payload)

if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text))
continue

self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ""), len(payload_merged)
)
elif opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
Expand Down Expand Up @@ -358,90 +448,10 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
WSMessage(WSMsgType.PONG, payload, ""), len(payload)
)

elif (
opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
and self._opcode is None
):
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
else:
# load text/binary
if not fin:
# got partial frame payload
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode
self._partial.extend(payload)
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
else:
# previous frame was non finished
# we should get continuation opcode
if self._partial:
if opcode != WSMsgType.CONTINUATION:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
)

if opcode == WSMsgType.CONTINUATION:
assert self._opcode is not None
opcode = self._opcode
self._opcode = None

self._partial.extend(payload)
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)

# Decompress process must to be done after all packets
# received.
if compressed:
assert self._decompressobj is not None
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress_sync(
self._partial, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
)
else:
payload_merged = bytes(self._partial)

self._partial.clear()

if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode("utf-8")
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ""), len(text)
)
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
else:
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ""),
len(payload_merged),
)

return False, b""

def parse_frame(
self, buf: bytes
Expand Down Expand Up @@ -521,23 +531,21 @@ def parse_frame(

# read payload length
if self._state is WSParserState.READ_PAYLOAD_LENGTH:
length = self._payload_length_flag
if length == 126:
length_flag = self._payload_length_flag
if length_flag == 126:
if buf_length - start_pos < 2:
break
data = buf[start_pos : start_pos + 2]
start_pos += 2
length = UNPACK_LEN2(data)[0]
self._payload_length = length
elif length > 126:
self._payload_length = UNPACK_LEN2(data)[0]
elif length_flag > 126:
if buf_length - start_pos < 8:
break
data = buf[start_pos : start_pos + 8]
start_pos += 8
length = UNPACK_LEN3(data)[0]
self._payload_length = length
self._payload_length = UNPACK_LEN3(data)[0]
else:
self._payload_length = length
self._payload_length = length_flag

self._state = (
WSParserState.READ_PAYLOAD_MASK
Expand All @@ -560,11 +568,11 @@ def parse_frame(
chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
payload.extend(buf[start_pos:])
payload += buf[start_pos:]
start_pos = buf_length
else:
self._payload_length = 0
payload.extend(buf[start_pos : start_pos + length])
payload += buf[start_pos : start_pos + length]
start_pos = start_pos + length

if self._payload_length != 0:
Expand Down
Loading