Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to fix errors of using WebSockets closed from client side #737

Merged
merged 1 commit into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 93 additions & 100 deletions platform_monitoring/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import random
from collections.abc import AsyncIterator, Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from importlib.metadata import version
from pathlib import Path
Expand Down Expand Up @@ -65,7 +65,7 @@
ContainerRuntimeClientRegistry,
)
from .jobs_service import JobException, JobNotRunningException, JobsService
from .kube_client import JobError, KubeClient, KubeTelemetry
from .kube_client import KubeClient, KubeTelemetry
from .log_cleanup_poller import LogCleanupPoller
from .logs import (
DEFAULT_ARCHIVE_DELAY,
Expand All @@ -81,6 +81,7 @@
)

WS_ATTACH_PROTOCOL = "v2.channels.neu.ro"
HEARTBEAT = 30


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -223,12 +224,6 @@ async def ws_log(self, request: Request) -> StreamResponse:
if separator is None:
separator = "=== Live logs ===" + _getrandbytes(30).hex()

response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30)

async def _ws_reader() -> None:
while True:
await response.receive()

async def stop_func() -> bool:
return self._jobs_helper.is_job_finished(await self._get_job(job.id))

Expand All @@ -241,17 +236,13 @@ async def stop_func() -> bool:
archive_delay_s=archive_delay_s,
stop_func=stop_func,
) as it:
response = WebSocketResponse(heartbeat=HEARTBEAT)
await response.prepare(request)
ws_reader_task = asyncio.create_task(_ws_reader())
try:
async for chunk in it:
await response.send_bytes(chunk)
finally:
ws_reader_task.cancel()
with suppress(asyncio.CancelledError):
await ws_reader_task

return response
await _run_concurrently(
_listen(response),
_forward_bytes_iterating(response, it),
)
return response

async def drop_log(self, request: Request) -> Response:
job = await self._resolve_job(request, "write")
Expand All @@ -263,55 +254,44 @@ async def drop_log(self, request: Request) -> Response:
async def stream_top(self, request: Request) -> WebSocketResponse:
job = await self._resolve_job(request, "read")

logger.info("Websocket connection starting")
ws = WebSocketResponse()
await ws.prepare(request)
logger.info("Websocket connection ready")

# TODO (truskovskiyk 09/12/18) remove CancelledError
# https://github.com/aio-libs/aiohttp/issues/3443

# TODO expose configuration
sleep_timeout = 1

telemetry = await self._get_job_telemetry(job)

async with telemetry:
response = WebSocketResponse()
await response.prepare(request)
await _run_concurrently(
_listen(response),
self._send_telemetry(response, job.id, telemetry),
)
return response

try:
while True:
# client closed connection
assert request.transport is not None
if request.transport.is_closing():
break

# TODO (A Yushkovskiy 06-Jun-2019) don't make slow HTTP requests to
# platform-api to check job's status every iteration: we better
# retrieve this information directly form kubernetes
job = await self._get_job(job.id)

if self._jobs_helper.is_job_running(job):
job_stats = await telemetry.get_latest_stats()
if job_stats:
message = self._convert_job_stats_to_ws_message(job_stats)
await ws.send_json(message)
async def _send_telemetry(
self, ws: WebSocketResponse, job_id: str, telemetry: Telemetry
) -> None:
# TODO expose configuration
sleep_timeout = 1

if self._jobs_helper.is_job_finished(job):
while not ws.closed:
# TODO (A Yushkovskiy 06-Jun-2019) don't make slow HTTP requests to
# platform-api to check job's status every iteration: we better
# retrieve this information directly from kubernetes
job = await self._get_job(job_id)
if ws.closed:
break

if self._jobs_helper.is_job_running(job):
job_stats = await telemetry.get_latest_stats()
if ws.closed:
break
if job_stats:
message = self._convert_job_stats_to_ws_message(job_stats)
await ws.send_json(message)
if ws.closed:
break

await asyncio.sleep(sleep_timeout)

except JobError as e:
raise JobError(f"Failed to get telemetry for job {job.id}: {e}") from e
if self._jobs_helper.is_job_finished(job):
break

except asyncio.CancelledError as ex:
logger.info(f"got cancelled error {ex}")

finally:
if not ws.closed:
await ws.close()

return ws
await asyncio.sleep(sleep_timeout)

async def _get_job(self, job_id: str) -> Job:
return await self._jobs_service.get(job_id)
Expand Down Expand Up @@ -418,7 +398,9 @@ async def ws_attach(self, request: Request) -> StreamResponse:

job = await self._resolve_job(request, "write")

response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30)
response = WebSocketResponse(
protocols=[WS_ATTACH_PROTOCOL], heartbeat=HEARTBEAT
)

async with self._jobs_service.attach(
job, tty=tty, stdin=stdin, stdout=stdout, stderr=stderr
Expand Down Expand Up @@ -447,7 +429,9 @@ async def ws_exec(self, request: Request) -> StreamResponse:

job = await self._resolve_job(request, "write")

response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30)
response = WebSocketResponse(
protocols=[WS_ATTACH_PROTOCOL], heartbeat=HEARTBEAT
)

async with self._jobs_service.exec(
job, cmd=cmd, tty=tty, stdin=stdin, stdout=stdout, stderr=stderr
Expand Down Expand Up @@ -486,34 +470,41 @@ async def port_forward(self, request: Request) -> StreamResponse:
)

try:
response = WebSocketResponse(heartbeat=30)
response = WebSocketResponse(heartbeat=HEARTBEAT)
await response.prepare(request)
await _run_concurrently(
_forward_reading(response, reader),
_forward_writing(response, writer),
)
return response
finally:
writer.close()
await writer.wait_closed()

tasks = []
try:
tasks.append(asyncio.create_task(_forward_reading(response, reader)))
tasks.append(asyncio.create_task(_forward_writing(response, writer)))

await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
async def _listen(ws: WebSocketResponse) -> None:
    """Keep the WebSocket connection serviced while another task sends data.

    Iterating over the socket is what processes the ping-pong exchange and
    performs the closing handshake. The peer is not expected to send any
    payload, so every message that does arrive is only logged.
    """
    async for unexpected in ws:
        logger.info(f"Received unexpected WebSocket message: {unexpected!r}")

return response
finally:
for task in tasks:
if not task.done():
task.cancel()
with suppress(asyncio.CancelledError):
await task

finally:
writer.close()
await writer.wait_closed()
async def _forward_bytes_iterating(
    ws: WebSocketResponse, it: AsyncIterator[bytes]
) -> None:
    """Send every chunk produced by ``it`` over ``ws`` as a binary message.

    Forwarding stops when the iterator is exhausted or as soon as the
    WebSocket is seen to be closed. The closed flag is checked both before
    sending a chunk and right after it, so no further chunk is pulled from
    the iterator once the socket has gone away.
    """
    async for payload in it:
        if ws.closed:
            return
        await ws.send_bytes(payload)
        if ws.closed:
            return


async def _forward_reading(ws: WebSocketResponse, reader: asyncio.StreamReader) -> None:
while True:
# 4-6 MB is the typical default socket receive buffer size of Lunix
while not ws.closed:
# 4-6 MB is the typical default socket receive buffer size on Linux
data = await reader.read(4 * 1024 * 1024)
if not data:
if not data or ws.closed:
break
await ws.send_bytes(data)

Expand All @@ -540,31 +531,23 @@ def __init__(
self._closing = False

async def transfer(self) -> None:
tasks = [
asyncio.create_task(self._proxy(self._resp, self._client_resp)),
asyncio.create_task(self._proxy(self._client_resp, self._resp)),
]

try:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
await _run_concurrently(
self._proxy(self._resp, self._client_resp),
self._proxy(self._client_resp, self._resp),
)
finally:
await self._resp.close()
await self._client_resp.close()

for task in tasks:
if not task.done():
task.cancel()
with suppress(asyncio.CancelledError):
await task

async def _proxy(
self,
src: Union[WebSocketResponse, ClientWebSocketResponse],
dst: Union[WebSocketResponse, ClientWebSocketResponse],
) -> None:
try:
async for msg in src:
if self._closing:
if self._closing or dst.closed:
break

if msg.type == aiohttp.WSMsgType.BINARY:
Expand All @@ -576,8 +559,6 @@ async def _proxy(
)
else:
raise ValueError(f"Unsupported WS message type {msg.type}")
except StopAsyncIteration:
pass
finally:
self._closing = True

Expand Down Expand Up @@ -632,10 +613,8 @@ async def handle_exceptions(
e.headers["X-Error"] = e.text
raise e
except Exception as e:
msg_str = (
f"Unexpected exception: {str(e)}. " f"Path with query: {request.path_qs}."
)
logging.exception(msg_str)
msg_str = f"Unexpected exception: {str(e)}. Path with query: {request.path_qs}."
logger.exception(msg_str)
payload = {"error": msg_str}
return json_response(
payload,
Expand Down Expand Up @@ -956,3 +935,17 @@ def _get_bool_param(request: Request, name: str, default: bool = False) -> bool:

def _getrandbytes(size: int) -> bytes:
    """Return ``size`` pseudo-random bytes taken from the ``random`` module.

    The bytes are the big-endian encoding of a ``size * 8``-bit random integer.
    """
    bit_count = size * 8
    number = random.getrandbits(bit_count)
    return number.to_bytes(size, "big")


async def _run_concurrently(*coros: Coroutine[Any, Any, None]) -> None:
    """Run the coroutines concurrently until the first one of them finishes.

    Each coroutine is wrapped into a task. As soon as any task completes
    (normally or with an error), or this call itself is interrupted, every
    task that is still pending is cancelled and awaited so that none of them
    outlives the call.

    Args:
        coros: Coroutines to run side by side. With no coroutines the call
            returns immediately.

    Raises:
        Exception: Whatever a still-pending task raises while it is being
            cancelled, other than ``asyncio.CancelledError``.

    NOTE(review): a task that finished on its own is not awaited here, so an
    exception it raised is not re-raised to the caller -- confirm that this
    is the intended contract.
    """
    if not coros:
        # asyncio.wait() rejects an empty collection with ValueError.
        return

    tasks: list[asyncio.Task[None]] = []
    try:
        for coro in coros:
            tasks.append(asyncio.create_task(coro))
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        pending = [task for task in tasks if not task.done()]
        # Request cancellation of every pending task before awaiting any of
        # them: if awaiting one task raises, the others have still been
        # cancelled instead of being left running.
        for task in pending:
            task.cancel()
        for task in pending:
            with suppress(asyncio.CancelledError):
                await task
Loading