Skip to content

Commit

Permalink
Try to fix errors of using WebSockets closed from client side
Browse files Browse the repository at this point in the history
  • Loading branch information
serhiy-storchaka committed Jul 15, 2022
1 parent 1a963cf commit 51877fc
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 171 deletions.
199 changes: 96 additions & 103 deletions platform_monitoring/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import random
from collections.abc import AsyncIterator, Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from importlib.metadata import version
from pathlib import Path
Expand Down Expand Up @@ -65,7 +65,7 @@
ContainerRuntimeClientRegistry,
)
from .jobs_service import JobException, JobNotRunningException, JobsService
from .kube_client import JobError, KubeClient, KubeTelemetry
from .kube_client import KubeClient, KubeTelemetry
from .log_cleanup_poller import LogCleanupPoller
from .logs import (
DEFAULT_ARCHIVE_DELAY,
Expand All @@ -81,6 +81,7 @@
)

# WebSocket subprotocol offered by the attach and exec endpoints
# (passed as ``protocols=[WS_ATTACH_PROTOCOL]``).
WS_ATTACH_PROTOCOL = "v2.channels.neu.ro"
# Shared ``heartbeat=`` value used when creating WebSocketResponse objects
# (presumably seconds, as in aiohttp's ``heartbeat`` parameter -- confirm).
HEARTBEAT = 30


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -223,12 +224,6 @@ async def ws_log(self, request: Request) -> StreamResponse:
if separator is None:
separator = "=== Live logs ===" + _getrandbytes(30).hex()

response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30)

async def _ws_reader() -> None:
while True:
await response.receive()

async def stop_func() -> bool:
return self._jobs_helper.is_job_finished(await self._get_job(job.id))

Expand All @@ -241,17 +236,13 @@ async def stop_func() -> bool:
archive_delay_s=archive_delay_s,
stop_func=stop_func,
) as it:
response = WebSocketResponse(heartbeat=HEARTBEAT)
await response.prepare(request)
ws_reader_task = asyncio.create_task(_ws_reader())
try:
async for chunk in it:
await response.send_bytes(chunk)
finally:
ws_reader_task.cancel()
with suppress(asyncio.CancelledError):
await ws_reader_task

return response
await _run_concurrently(
_listen(response),
_forward_bytes_iterating(response, it),
)
return response

async def drop_log(self, request: Request) -> Response:
job = await self._resolve_job(request, "write")
Expand All @@ -263,55 +254,44 @@ async def drop_log(self, request: Request) -> Response:
async def stream_top(self, request: Request) -> WebSocketResponse:
    """Stream telemetry of a job to the client over a WebSocket.

    The job is resolved with "read" access, its telemetry source is opened,
    and two coroutines are then run side by side: ``_listen`` keeps reading
    from the socket while ``_send_telemetry`` pushes stats to it.  As soon
    as either of them finishes the other one is cancelled by
    ``_run_concurrently`` and the response is returned.

    Args:
        request: the incoming HTTP request to upgrade to a WebSocket.

    Returns:
        The prepared WebSocket response.
    """
    job = await self._resolve_job(request, "read")
    telemetry = await self._get_job_telemetry(job)

    async with telemetry:
        # The socket is prepared only after the telemetry source has been
        # opened, so failures above are raised before the HTTP upgrade.
        response = WebSocketResponse()
        await response.prepare(request)
        await _run_concurrently(
            _listen(response),
            self._send_telemetry(response, job.id, telemetry),
        )
        return response
async def _send_telemetry(
    self, ws: WebSocketResponse, job_id: str, telemetry: Telemetry
) -> None:
    """Send the latest job stats to *ws* in a polling loop.

    On every iteration the job is re-fetched by id.  While it is running,
    the latest stats (if any) are converted to a message and sent as JSON.
    The loop ends once the job is reported as finished; otherwise it sleeps
    for ``sleep_timeout`` seconds and polls again.

    The loop does not inspect ``ws.closed`` itself.

    Args:
        ws: WebSocket to send the JSON messages to.
        job_id: id of the job whose state is polled.
        telemetry: source of the stats, queried via ``get_latest_stats()``.
    """
    # TODO expose configuration
    sleep_timeout = 1

    while True:
        # TODO (A Yushkovskiy 06-Jun-2019) don't make slow HTTP requests to
        # platform-api to check job's status every iteration: we better
        # retrieve this information directly from kubernetes
        job = await self._get_job(job_id)

        if self._jobs_helper.is_job_running(job):
            job_stats = await telemetry.get_latest_stats()
            if job_stats:
                message = self._convert_job_stats_to_ws_message(job_stats)
                await ws.send_json(message)

        if self._jobs_helper.is_job_finished(job):
            break

        await asyncio.sleep(sleep_timeout)

async def _get_job(self, job_id: str) -> Job:
    """Look up the job with id *job_id* through the jobs service."""
    job = await self._jobs_service.get(job_id)
    return job
Expand Down Expand Up @@ -418,7 +398,9 @@ async def ws_attach(self, request: Request) -> StreamResponse:

job = await self._resolve_job(request, "write")

response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30)
response = WebSocketResponse(
protocols=[WS_ATTACH_PROTOCOL], heartbeat=HEARTBEAT
)

async with self._jobs_service.attach(
job, tty=tty, stdin=stdin, stdout=stdout, stderr=stderr
Expand Down Expand Up @@ -447,7 +429,9 @@ async def ws_exec(self, request: Request) -> StreamResponse:

job = await self._resolve_job(request, "write")

response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30)
response = WebSocketResponse(
protocols=[WS_ATTACH_PROTOCOL], heartbeat=HEARTBEAT
)

async with self._jobs_service.exec(
job, cmd=cmd, tty=tty, stdin=stdin, stdout=stdout, stderr=stderr
Expand Down Expand Up @@ -486,34 +470,41 @@ async def port_forward(self, request: Request) -> StreamResponse:
)

try:
response = WebSocketResponse(heartbeat=30)
response = WebSocketResponse(heartbeat=HEARTBEAT)
await response.prepare(request)
await _run_concurrently(
_forward_reading(response, reader),
_forward_writing(response, writer),
)
return response
finally:
writer.close()
await writer.wait_closed()

tasks = []
try:
tasks.append(asyncio.create_task(_forward_reading(response, reader)))
tasks.append(asyncio.create_task(_forward_writing(response, writer)))

await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
async def _listen(ws: WebSocketResponse) -> None:
    """Keep reading from *ws* so the WebSocket connection is maintained.

    Reading is what processes the ping-pong exchange and performs the
    closing handshake.  No payload is expected from the peer, so every
    message that does arrive is only logged.
    """
    async for unexpected in ws:
        logger.info(f"Received unexpected WebSocket message: {unexpected!r}")

return response
finally:
for task in tasks:
if not task.done():
task.cancel()
with suppress(asyncio.CancelledError):
await task

finally:
writer.close()
await writer.wait_closed()
async def _forward_bytes_iterating(
    ws: WebSocketResponse, it: AsyncIterator[bytes]
) -> None:
    """Send every chunk produced by *it* to *ws* as a binary message.

    Forwarding stops when the iterator is exhausted or as soon as the
    socket is seen closed, which is checked both before and after each send.
    """
    async for payload in it:
        # The socket may have been closed while waiting for this chunk.
        if ws.closed:
            return
        await ws.send_bytes(payload)
        # Do not wait for another chunk if the send found the socket closed.
        if ws.closed:
            return


async def _forward_reading(ws: WebSocketResponse, reader: asyncio.StreamReader) -> None:
    """Copy data from *reader* to *ws* as binary messages.

    The loop ends when the reader reports EOF (an empty read) or when the
    WebSocket is seen closed, which is checked before each read and again
    after it, so data read while the socket was closing is not sent.

    Args:
        ws: destination WebSocket.
        reader: source stream.
    """
    while not ws.closed:
        # 4-6 MB is the typical default socket receive buffer size on Linux
        data = await reader.read(4 * 1024 * 1024)
        if not data or ws.closed:
            break
        await ws.send_bytes(data)

Expand All @@ -540,31 +531,23 @@ def __init__(
self._closing = False

async def transfer(self) -> None:
    """Proxy messages in both directions until either direction stops.

    Two ``_proxy`` coroutines are run concurrently, one per direction
    between ``self._resp`` and ``self._client_resp``.  When the first of
    them finishes, the other is cancelled by ``_run_concurrently``, and
    both connections are closed.
    """
    try:
        await _run_concurrently(
            self._proxy(self._resp, self._client_resp),
            self._proxy(self._client_resp, self._resp),
        )
    finally:
        # Nested so that a failure while closing the first connection
        # cannot prevent the second one from being closed.
        try:
            await self._resp.close()
        finally:
            await self._client_resp.close()

async def _proxy(
self,
src: Union[WebSocketResponse, ClientWebSocketResponse],
dst: Union[WebSocketResponse, ClientWebSocketResponse],
) -> None:
try:
async for msg in src:
if self._closing:
if self._closing or dst.closed:
break

if msg.type == aiohttp.WSMsgType.BINARY:
Expand All @@ -576,8 +559,6 @@ async def _proxy(
)
else:
raise ValueError(f"Unsupported WS message type {msg.type}")
except StopAsyncIteration:
pass
finally:
self._closing = True

Expand Down Expand Up @@ -632,10 +613,8 @@ async def handle_exceptions(
e.headers["X-Error"] = e.text
raise e
except Exception as e:
msg_str = (
f"Unexpected exception: {str(e)}. " f"Path with query: {request.path_qs}."
)
logging.exception(msg_str)
msg_str = f"Unexpected exception: {str(e)}. Path with query: {request.path_qs}."
logger.exception(msg_str)
payload = {"error": msg_str}
return json_response(
payload,
Expand Down Expand Up @@ -956,3 +935,17 @@ def _get_bool_param(request: Request, name: str, default: bool = False) -> bool:

def _getrandbytes(size: int) -> bytes:
    """Return *size* random bytes drawn from the ``random`` module.

    The bytes are the big-endian encoding of a ``size * 8``-bit random
    integer, so the result is always exactly *size* bytes long.
    """
    bit_count = size * 8
    value = random.getrandbits(bit_count)
    return value.to_bytes(size, "big")


async def _run_concurrently(*coros: Coroutine[Any, Any, None]) -> None:
    """Run *coros* as tasks until the first one finishes, then stop the rest.

    Every coroutine is wrapped in a task.  As soon as any task completes
    (normally or with an exception), or this coroutine itself is cancelled,
    all tasks that are still pending are cancelled and awaited.

    An exception raised by a task that had already finished is not
    re-raised here: ``asyncio.wait`` does not propagate task exceptions and
    the finished tasks' results are never inspected.

    Args:
        *coros: coroutines to run side by side; calling with none is a no-op.
    """
    if not coros:
        # asyncio.wait() raises ValueError for an empty collection.
        return

    tasks: list[asyncio.Task[None]] = []
    try:
        for coro in coros:
            tasks.append(asyncio.create_task(coro))
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Request cancellation of every unfinished task before awaiting any
        # of them, so an error surfacing from one task's shutdown cannot
        # leave the remaining tasks running.
        pending = [task for task in tasks if not task.done()]
        for task in pending:
            task.cancel()
        for task in pending:
            with suppress(asyncio.CancelledError):
                await task
Loading

0 comments on commit 51877fc

Please sign in to comment.