From 328606e3716cf8d2457b1a5c4a436c757620c9fd Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Wed, 6 Jul 2022 18:49:56 +0300 Subject: [PATCH] Try to fix errors of using WebSockets closed from client side --- platform_monitoring/api.py | 193 ++++++++++++++++------------------ tests/integration/test_api.py | 96 +++++------------ 2 files changed, 121 insertions(+), 168 deletions(-) diff --git a/platform_monitoring/api.py b/platform_monitoring/api.py index 982c3b73..c0a0626e 100644 --- a/platform_monitoring/api.py +++ b/platform_monitoring/api.py @@ -2,7 +2,7 @@ import json import logging import random -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine from contextlib import AsyncExitStack, asynccontextmanager, suppress from importlib.metadata import version from pathlib import Path @@ -65,7 +65,7 @@ ContainerRuntimeClientRegistry, ) from .jobs_service import JobException, JobNotRunningException, JobsService -from .kube_client import JobError, KubeClient, KubeTelemetry +from .kube_client import KubeClient, KubeTelemetry from .log_cleanup_poller import LogCleanupPoller from .logs import ( DEFAULT_ARCHIVE_DELAY, @@ -81,6 +81,7 @@ ) WS_ATTACH_PROTOCOL = "v2.channels.neu.ro" +HEARTBEAT = 30 logger = logging.getLogger(__name__) @@ -223,12 +224,6 @@ async def ws_log(self, request: Request) -> StreamResponse: if separator is None: separator = "=== Live logs ===" + _getrandbytes(30).hex() - response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30) - - async def _ws_reader() -> None: - while True: - await response.receive() - async def stop_func() -> bool: return self._jobs_helper.is_job_finished(await self._get_job(job.id)) @@ -241,17 +236,13 @@ async def stop_func() -> bool: archive_delay_s=archive_delay_s, stop_func=stop_func, ) as it: + response = WebSocketResponse(heartbeat=HEARTBEAT) await response.prepare(request) - ws_reader_task = asyncio.create_task(_ws_reader()) - try: - async for chunk in it: - await response.send_bytes(chunk) - finally: - ws_reader_task.cancel() - with suppress(asyncio.CancelledError): - await ws_reader_task - - return response + await _run_concurrently( + _listen(response), + _forward_bytes_iterating(response, it), + ) + return response async def drop_log(self, request: Request) -> Response: job = await self._resolve_job(request, "write") @@ -263,55 +254,44 @@ async def drop_log(self, request: Request) -> Response: async def stream_top(self, request: Request) -> WebSocketResponse: job = await self._resolve_job(request, "read") - logger.info("Websocket connection starting") - ws = WebSocketResponse() - await ws.prepare(request) - logger.info("Websocket connection ready") - - # TODO (truskovskiyk 09/12/18) remove CancelledError - # https://github.com/aio-libs/aiohttp/issues/3443 - - # TODO expose configuration - sleep_timeout = 1 - telemetry = await self._get_job_telemetry(job) - async with telemetry: + response = WebSocketResponse() + await response.prepare(request) + await _run_concurrently( + _listen(response), + self._send_telemetry(response, job.id, telemetry), + ) + return response - try: - while True: - # client closed connection - assert request.transport is not None - if request.transport.is_closing(): - break - - # TODO (A Yushkovskiy 06-Jun-2019) don't make slow HTTP requests to - # platform-api to check job's status every iteration: we better - # retrieve this information directly form kubernetes - job = await self._get_job(job.id) - - if self._jobs_helper.is_job_running(job): - job_stats = await telemetry.get_latest_stats() - if job_stats: - message = self._convert_job_stats_to_ws_message(job_stats) - await ws.send_json(message) + async def _send_telemetry( + self, ws: WebSocketResponse, job_id: str, telemetry: Telemetry + ) -> None: + # TODO expose configuration + sleep_timeout = 1 - if self._jobs_helper.is_job_finished(job): + while not ws.closed: + # TODO (A Yushkovskiy 06-Jun-2019) don't make slow HTTP requests to + # platform-api to check job's status every iteration: we better + # retrieve this information directly from kubernetes + job = await self._get_job(job_id) + if ws.closed: + break + + if self._jobs_helper.is_job_running(job): + job_stats = await telemetry.get_latest_stats() + if ws.closed: + break + if job_stats: + message = self._convert_job_stats_to_ws_message(job_stats) + await ws.send_json(message) + if ws.closed: break - await asyncio.sleep(sleep_timeout) - - except JobError as e: - raise JobError(f"Failed to get telemetry for job {job.id}: {e}") from e + if self._jobs_helper.is_job_finished(job): + break - except asyncio.CancelledError as ex: - logger.info(f"got cancelled error {ex}") - - finally: - if not ws.closed: - await ws.close() - - return ws + await asyncio.sleep(sleep_timeout) async def _get_job(self, job_id: str) -> Job: return await self._jobs_service.get(job_id) @@ -418,7 +398,9 @@ async def ws_attach(self, request: Request) -> StreamResponse: job = await self._resolve_job(request, "write") - response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30) + response = WebSocketResponse( + protocols=[WS_ATTACH_PROTOCOL], heartbeat=HEARTBEAT + ) async with self._jobs_service.attach( job, tty=tty, stdin=stdin, stdout=stdout, stderr=stderr @@ -447,7 +429,9 @@ async def ws_exec(self, request: Request) -> StreamResponse: job = await self._resolve_job(request, "write") - response = WebSocketResponse(protocols=[WS_ATTACH_PROTOCOL], heartbeat=30) + response = WebSocketResponse( + protocols=[WS_ATTACH_PROTOCOL], heartbeat=HEARTBEAT + ) async with self._jobs_service.exec( job, cmd=cmd, tty=tty, stdin=stdin, stdout=stdout, stderr=stderr @@ -486,34 +470,41 @@ async def port_forward(self, request: Request) -> StreamResponse: ) try: - response = WebSocketResponse(heartbeat=30) + response = WebSocketResponse(heartbeat=HEARTBEAT) await response.prepare(request) + await _run_concurrently( + _forward_reading(response, reader), + _forward_writing(response, writer), + ) + return response + finally: + writer.close() + await writer.wait_closed() - tasks = [] - try: - tasks.append(asyncio.create_task(_forward_reading(response, reader))) - tasks.append(asyncio.create_task(_forward_writing(response, writer))) - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) +async def _listen(ws: WebSocketResponse) -> None: + # Maintain the WebSocket connetion. + # Process ping-pong game and perform closing handshake. + async for msg in ws: + logger.info(f"Received unexpected WebSocket message: {msg!r}") - return response - finally: - for task in tasks: - if not task.done(): - task.cancel() - with suppress(asyncio.CancelledError): - await task - finally: - writer.close() - await writer.wait_closed() +async def _forward_bytes_iterating( + ws: WebSocketResponse, it: AsyncIterator[bytes] +) -> None: + async for chunk in it: + if ws.closed: + break + await ws.send_bytes(chunk) + if ws.closed: + break async def _forward_reading(ws: WebSocketResponse, reader: asyncio.StreamReader) -> None: - while True: - # 4-6 MB is the typical default socket receive buffer size of Lunix + while not ws.closed: + # 4-6 MB is the typical default socket receive buffer size on Linux data = await reader.read(4 * 1024 * 1024) - if not data: + if not data or ws.closed: break await ws.send_bytes(data) @@ -540,23 +531,15 @@ def __init__( self._closing = False async def transfer(self) -> None: - tasks = [ - asyncio.create_task(self._proxy(self._resp, self._client_resp)), - asyncio.create_task(self._proxy(self._client_resp, self._resp)), - ] - try: - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + await _run_concurrently( + self._proxy(self._resp, self._client_resp), + self._proxy(self._client_resp, self._resp), + ) finally: await self._resp.close() await self._client_resp.close() - for task in tasks: - if not task.done(): - task.cancel() - with suppress(asyncio.CancelledError): - await task - async def _proxy( self, src: Union[WebSocketResponse, ClientWebSocketResponse], @@ -564,7 +547,7 @@ async def _proxy( ) -> None: try: async for msg in src: - if self._closing: + if self._closing or dst.closed: break if msg.type == aiohttp.WSMsgType.BINARY: @@ -576,8 +559,6 @@ async def _proxy( ) else: raise ValueError(f"Unsupported WS message type {msg.type}") - except StopAsyncIteration: - pass finally: self._closing = True @@ -632,10 +613,8 @@ async def handle_exceptions( e.headers["X-Error"] = e.text raise e except Exception as e: - msg_str = ( - f"Unexpected exception: {str(e)}. " f"Path with query: {request.path_qs}." - ) - logging.exception(msg_str) + msg_str = f"Unexpected exception: {str(e)}. Path with query: {request.path_qs}." + logger.exception(msg_str) payload = {"error": msg_str} return json_response( payload, @@ -956,3 +935,17 @@ def _get_bool_param(request: Request, name: str, default: bool = False) -> bool: def _getrandbytes(size: int) -> bytes: return random.getrandbits(size * 8).to_bytes(size, "big") + + +async def _run_concurrently(*coros: Coroutine[Any, Any, None]) -> None: + tasks: list[asyncio.Task[None]] = [] + try: + for coro in coros: + tasks.append(asyncio.create_task(coro)) + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + finally: + for task in tasks: + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 989b630f..4bb93592 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -36,21 +36,17 @@ async def expect_prompt(ws: aiohttp.ClientWebSocketResponse) -> bytes: _ansi_re = re.compile(rb"\033\[[;?0-9]*[a-zA-Z]") - _exit_re = re.compile(rb"exit \d+\Z") + _exit_re = re.compile(rb"exit \d+") try: ret: bytes = b"" async with timeout(3): - while not ret.strip().endswith(b"#") and not _exit_re.match(ret.strip()): - msg = await ws.receive() - if msg.type in ( - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - aiohttp.WSMsgType.CLOSED, - ): - break - print(msg.data) + async for msg in ws: + assert msg.type == aiohttp.WSMsgType.BINARY assert msg.data[0] == 1 ret += _ansi_re.sub(b"", msg.data[1:]) + ret_strip = ret.strip() + if ret_strip.endswith(b"#") or _exit_re.fullmatch(ret_strip): + break return ret except asyncio.TimeoutError: raise AssertionError(f"[Timeout] {ret!r}") @@ -535,19 +531,10 @@ async def test_top_ok( url = monitoring_api.generate_top_url(job_id=infinite_job) async with client.ws_connect(url, headers=jobs_client.headers) as ws: # TODO move this ws communication to JobClient - while True: - msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.CLOSE: - break - else: - records.append(json.loads(msg.data)) - + async for msg in ws: + assert msg.type == aiohttp.WSMsgType.TEXT + records.append(json.loads(msg.data)) if len(records) == num_request: - # TODO (truskovskiyk 09/12/18) do not use protected prop - # https://github.com/aio-libs/aiohttp/issues/3443 - proto = ws._writer.protocol - assert proto.transport is not None - proto.transport.close() break assert len(records) == num_request @@ -572,10 +559,8 @@ async def test_top_shared_by_name( await share_job(jobs_client.user, regular_user2, job_name) url = monitoring_api.generate_top_url(named_infinite_job) - async with client.ws_connect(url, headers=regular_user2.headers) as ws: - proto = ws._writer.protocol - assert proto.transport is not None - proto.transport.close() + async with client.ws_connect(url, headers=regular_user2.headers): + pass async def test_top_no_permissions_unauthorized( self, @@ -617,28 +602,10 @@ async def test_top_non_running_job( job_id=infinite_job, status="cancelled" ) - num_request = 2 - records = [] - url = monitoring_api.generate_top_url(job_id=infinite_job) async with client.ws_connect(url, headers=jobs_client.headers) as ws: - # TODO move this ws communication to JobClient - while True: - msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.CLOSE: - break - else: - records.append(json.loads(msg.data)) - - if len(records) == num_request: - # TODO (truskovskiyk 09/12/18) do not use protected prop - # https://github.com/aio-libs/aiohttp/issues/3443 - proto = ws._writer.protocol - assert proto.transport is not None - proto.transport.close() - break - - assert not records + msg = await ws.receive() + assert msg.type == aiohttp.WSMsgType.CLOSE async def test_top_non_existing_job( self, @@ -675,7 +642,7 @@ async def test_top_silently_wait_when_job_pending( job_id = await jobs_client.run_job(job_submit) num_request = 2 - records = [] + records: list[Any] = [] headers = jobs_client.headers job_top_url = monitoring_api.generate_top_url(job_id) @@ -683,25 +650,13 @@ async def test_top_silently_wait_when_job_pending( job = await jobs_client.get_job_by_id(job_id=job_id) assert job["status"] == "pending" - # silently waiting for a job becomes running - msg = await ws.receive() - job = await jobs_client.get_job_by_id(job_id=job_id) - assert job["status"] == "running" - assert msg.type == aiohttp.WSMsgType.TEXT - - while True: - msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.CLOSE: - break - else: - records.append(json.loads(msg.data)) - + async for msg in ws: + assert msg.type == aiohttp.WSMsgType.TEXT + if not records: # First message. + job = await jobs_client.get_job_by_id(job_id=job_id) + assert job["status"] == "running" + records.append(json.loads(msg.data)) if len(records) == num_request: - # TODO (truskovskiyk 09/12/18) do not use protected prop - # https://github.com/aio-libs/aiohttp/issues/3443 - proto = ws._writer.protocol - assert proto.transport is not None - proto.transport.close() break assert len(records) == num_request @@ -731,9 +686,8 @@ async def test_top_close_when_job_succeeded( job_top_url = monitoring_api.generate_top_url(job_id) async with client.ws_connect(job_top_url, headers=headers) as ws: msg = await ws.receive() - job = await jobs_client.get_job_by_id(job_id=job_id) - assert msg.type == aiohttp.WSMsgType.CLOSE + job = await jobs_client.get_job_by_id(job_id=job_id) assert job["status"] == "succeeded" await jobs_client.delete_job(job_id=job_id) @@ -763,7 +717,7 @@ async def test_log_no_auth_token_provided_unauthorized( jobs_client: JobsClient, infinite_job: str, ) -> None: - url = monitoring_api.generate_top_url(job_id=infinite_job) + url = monitoring_api.generate_log_url(job_id=infinite_job) async with client.get(url) as resp: assert resp.status == HTTPUnauthorized.status_code @@ -809,6 +763,7 @@ async def test_job_log_ws( async with client.ws_connect(url, headers=headers) as ws: ws_data = [] async for msg in ws: + assert msg.type == aiohttp.WSMsgType.BINARY ws_data.append(msg.data) actual_payload = b"".join(ws_data) @@ -1108,6 +1063,7 @@ async def test_attach_nontty_stdout( content = [] async for msg in ws: + assert msg.type == aiohttp.WSMsgType.BINARY content.append(msg.data) expected = ( @@ -1146,6 +1102,7 @@ async def test_attach_nontty_stdout_shared_by_name( content = [] async for msg in ws: + assert msg.type == aiohttp.WSMsgType.BINARY content.append(msg.data) expected = ( @@ -1178,6 +1135,7 @@ async def test_attach_nontty_stderr( content = [] async for msg in ws: + assert msg.type == aiohttp.WSMsgType.BINARY content.append(msg.data) expected = ( @@ -1252,6 +1210,7 @@ async def test_attach_tty_exit_code( ): break + assert msg.type == aiohttp.WSMsgType.BINARY if msg.data[0] == 3: payload = json.loads(msg.data[1:]) # Zero code is returned even if we exited with non zero code. @@ -1427,6 +1386,7 @@ async def test_exec_tty_exit_code( ): break + assert msg.type == aiohttp.WSMsgType.BINARY if msg.data[0] == 3: payload = json.loads(msg.data[1:]) assert payload["exit_code"] == 42