Skip to content

Commit

Permalink
[PR #8682/490fca61 backport][3.10] Reduce WebSocket frame parser comp…
Browse files Browse the repository at this point in the history
…lexity (#8727)

Co-authored-by: J. Nick Koston <nick@koston.org>
  • Loading branch information
patchback[bot] and bdraco authored Aug 17, 2024
1 parent ebec945 commit 635ae62
Showing 1 changed file with 93 additions and 107 deletions.
200 changes: 93 additions & 107 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def __init__(
self._frame_opcode: Optional[int] = None
self._frame_payload = bytearray()

self._tail = b""
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_length = 0
Expand Down Expand Up @@ -447,126 +447,113 @@ def parse_frame(
self, buf: bytes
) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
"""Return the next frame from the socket."""
frames = []
frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = []
if self._tail:
buf, self._tail = self._tail + buf, b""

start_pos = 0
start_pos: int = 0
buf_length = len(buf)

while True:
# read header
if self._state == WSParserState.READ_HEADER:
if buf_length - start_pos >= 2:
data = buf[start_pos : start_pos + 2]
start_pos += 2
first_byte, second_byte = data

fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF

# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if self._state is WSParserState.READ_HEADER:
if buf_length - start_pos < 2:
break
data = buf[start_pos : start_pos + 2]
start_pos += 2
first_byte, second_byte = data

fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF

# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)

if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)

has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F

# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be " "larger than 125 bytes",
)
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be " "larger than 125 bytes",
)

# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed is None:
self._compressed = True if rsv1 else False
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed is None:
self._compressed = True if rsv1 else False
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)

self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_length_flag = length
self._state = WSParserState.READ_PAYLOAD_LENGTH
else:
break
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_length_flag = length
self._state = WSParserState.READ_PAYLOAD_LENGTH

# read payload length
if self._state == WSParserState.READ_PAYLOAD_LENGTH:
if self._state is WSParserState.READ_PAYLOAD_LENGTH:
length = self._payload_length_flag
if length == 126:
if buf_length - start_pos >= 2:
data = buf[start_pos : start_pos + 2]
start_pos += 2
length = UNPACK_LEN2(data)[0]
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)
else:
if buf_length - start_pos < 2:
break
data = buf[start_pos : start_pos + 2]
start_pos += 2
length = UNPACK_LEN2(data)[0]
self._payload_length = length
elif length > 126:
if buf_length - start_pos >= 8:
data = buf[start_pos : start_pos + 8]
start_pos += 8
length = UNPACK_LEN3(data)[0]
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)
else:
if buf_length - start_pos < 8:
break
data = buf[start_pos : start_pos + 8]
start_pos += 8
length = UNPACK_LEN3(data)[0]
self._payload_length = length
else:
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)

self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD
)

# read payload mask
if self._state == WSParserState.READ_PAYLOAD_MASK:
if buf_length - start_pos >= 4:
self._frame_mask = buf[start_pos : start_pos + 4]
start_pos += 4
self._state = WSParserState.READ_PAYLOAD
else:
if self._state is WSParserState.READ_PAYLOAD_MASK:
if buf_length - start_pos < 4:
break
self._frame_mask = buf[start_pos : start_pos + 4]
start_pos += 4
self._state = WSParserState.READ_PAYLOAD

if self._state == WSParserState.READ_PAYLOAD:
if self._state is WSParserState.READ_PAYLOAD:
length = self._payload_length
payload = self._frame_payload

Expand All @@ -580,19 +567,18 @@ def parse_frame(
payload.extend(buf[start_pos : start_pos + length])
start_pos = start_pos + length

if self._payload_length == 0:
if self._has_mask:
assert self._frame_mask is not None
_websocket_mask(self._frame_mask, payload)
if self._payload_length != 0:
break

frames.append(
(self._frame_fin, self._frame_opcode, payload, self._compressed)
)
if self._has_mask:
assert self._frame_mask is not None
_websocket_mask(self._frame_mask, payload)

self._frame_payload = bytearray()
self._state = WSParserState.READ_HEADER
else:
break
frames.append(
(self._frame_fin, self._frame_opcode, payload, self._compressed)
)
self._frame_payload = bytearray()
self._state = WSParserState.READ_HEADER

self._tail = buf[start_pos:]

Expand Down

0 comments on commit 635ae62

Please sign in to comment.