Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #892 on branch 1.x (Make ChannelQueue.get_msg true async) #893

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import datetime
import json
import os
from logging import Logger
from queue import Queue
from queue import Empty, Queue
from threading import Thread
from time import monotonic
from typing import Any, Dict, Optional

import websocket
Expand Down Expand Up @@ -503,9 +505,24 @@ def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log:
self.channel_socket = channel_socket
self.log = log

async def _async_get(self, timeout=None):
if timeout is None:
timeout = float("inf")
elif timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
end_time = monotonic() + timeout

while True:
try:
return self.get(block=False)
except Empty:
if monotonic() > end_time:
raise
await asyncio.sleep(0)

async def get_msg(self, *args: Any, **kwargs: Any) -> dict:
timeout = kwargs.get("timeout", 1)
msg = self.get(timeout=timeout)
msg = await self._async_get(timeout=timeout)
self.log.debug(
"Received message on channel: {}, msg_id: {}, msg_type: {}".format(
self.channel_name, msg["msg_id"], msg["msg_type"] if msg else "null"
Expand Down
38 changes: 36 additions & 2 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
"""Test GatewayClient"""
import asyncio
import json
import logging
import os
import uuid
from datetime import datetime
from io import BytesIO
from unittest.mock import patch
from queue import Empty
from unittest.mock import MagicMock, patch

import pytest
import tornado
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.web import HTTPError

from jupyter_server.gateway.managers import GatewayClient
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient
from jupyter_server.utils import ensure_async

from .utils import expected_http_error
Expand Down Expand Up @@ -318,6 +321,37 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke
assert await is_kernel_running(jp_fetch, k2) is False


async def test_channel_queue_get_msg_with_invalid_timeout():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())

with pytest.raises(ValueError):
await queue.get_msg(timeout=-1)


async def test_channel_queue_get_msg_raises_empty_after_timeout():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())

with pytest.raises(Empty):
await asyncio.wait_for(queue.get_msg(timeout=0.1), 2)


async def test_channel_queue_get_msg_without_timeout():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())

with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(queue.get_msg(timeout=None), 1)


async def test_channel_queue_get_msg_with_existing_item():
sent_message = {"msg_id": 1, "msg_type": 2}
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
queue.put_nowait(sent_message)

received_message = await asyncio.wait_for(queue.get_msg(timeout=None), 1)

assert received_message == sent_message


#
# Test methods below...
#
Expand Down