From 46c3970fd6f2904ab9d37080b64cc25bc573a03c Mon Sep 17 00:00:00 2001 From: Kafonek Date: Sat, 1 Oct 2022 10:06:42 -0400 Subject: [PATCH] comments, docstrings, reordering code to be more intuitive to new readers --- sending/backends/jupyter.py | 184 ++++++++++++++++++++++-------------- 1 file changed, 114 insertions(+), 70 deletions(-) diff --git a/sending/backends/jupyter.py b/sending/backends/jupyter.py index b7da301..0f10839 100644 --- a/sending/backends/jupyter.py +++ b/sending/backends/jupyter.py @@ -4,6 +4,7 @@ from jupyter_client import AsyncKernelClient from jupyter_client.channels import ZMQSocketChannel +import jupyter_client.session import zmq from zmq.asyncio import Context @@ -22,7 +23,12 @@ def __init__( ): super().__init__() self.connection_info = connection_info + # If max_message_size is set, we'll disconnect (and reconnect immediately) to zmq + # channels that try to send a message greater than that size. It prevents applications + # from OOM crashing reading in large outputs or other messages self.max_message_size = max_message_size + # Tasks that ultiamtely watch the zmq channels for messages. Keep track of these + # for cleanup (unsubscribe_from_topic, shutdown) self.channel_tasks: Dict[str, List[asyncio.Task]] = collections.defaultdict(list) async def initialize( @@ -46,6 +52,8 @@ async def initialize( ) async def shutdown(self, now=False): + # Cancelling channel watching tasks here is equivalent to shutting down poll worker + # in other Sending backend implementations. for topic_name, task_list in self.channel_tasks.items(): for task in task_list: task.cancel() @@ -56,32 +64,84 @@ async def shutdown(self, now=False): def set_context_option(self, option: int, val: Union[int, bytes]): self._context.setsockopt(option, val) - async def watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel): + def send( + self, + topic_name: str, + msg_type: str, + content: Optional[dict], + parent: Optional[dict] = None, + header: Optional[dict] = None, + metadata: Optional[dict] = None, + ): """ - Read in any messages on a specific jupyter_client channel and drop them into the inbound - worker queue which will trigger registered callback functions by predicate / topic + Put a message onto the outbound queue which will be picked up by the outbound + worker and sent over zmq to the Kernel. Most messages will get sent over the shell + channel, although some may go over control as wel. + + Example: + mgr.send("shell", "execute_request", {"code": "print('hello')", "silent": False}) """ - while True: - msg: dict = await channel_obj.get_msg() - self.schedule_for_delivery(topic_name, msg) + # format the message into a Jupyter specced dictionary then drop into outbound queue + # to get sent over the wire when outbound worker calls ._publish + jupyter_session: jupyter_client.session.Session = self._client.session + jupyter_msg: dict = jupyter_session.msg(msg_type, content, parent, header, metadata) + self.outbound_queue.put_nowait(QueuedMessage(topic_name, jupyter_msg, None)) - async def watch_for_disconnect(self, monitor_socket: zmq.Socket): + async def _publish(self, message: QueuedMessage): """ - An awaitable task that ends when a particular socket has a disconnect event. Used in - conjunction with watch_for_channel_messages to cycle a socket when it's disconnected. + When the outbound worker observes a message on the outbound queue, it will call this + method to actually send the message over the wire. """ - while True: - msg: dict = await recv_monitor_message(monitor_socket) - event: zmq.Event = msg["event"] - if event == zmq.EVENT_DISCONNECTED: - return + topic_name = message.topic + if topic_name not in self.subscribed_topics: + await self._create_topic_subscription(topic_name) + if hasattr(self._client, f"{topic_name}_channel"): + channel_obj: ZMQSocketChannel = getattr(self._client, f"{topic_name}_channel") + channel_obj.send(message.contents) + + # Normally in Sending backends there is the concept of a poll_worker which calls into _poll + # as part of a custom _poll_loop implementation. The poll_worker is what reads data over the + # wire (redis, socket, websocket, etc. zmq in the case of Jupyter/ipykernel). However the way + # this backend is written, reading data from zmq after subscribe_to_topic is called is handled + # by _watch_channel task (and its child tasks). poll_worker and these _poll methods do nothing. + async def _poll(self): + pass + + async def _poll_loop(self): + pass - async def watch_channel(self, topic_name: str): + async def _create_topic_subscription(self, topic_name: str): + """ + Start observing messages on a zmq channel after a call to mgr.subscribe_to_topic('iopub') + """ + task = asyncio.create_task(self._watch_channel(topic_name)) + self.channel_tasks[topic_name].append(task) + + async def _cleanup_topic_subscription(self, topic_name: str): + """ + Clean up channel observing tasks after a call to mgr.unsubscribe_from_topic('iopub') + """ + if topic_name in self.channel_tasks: + for task in self.channel_tasks[topic_name]: + task.cancel() + self.channel_tasks[topic_name].remove(task) + await asyncio.sleep(0) + # Reset the channel object on our jupyter_client + setattr(self._client, f"_{topic_name}_channel", None) + else: + logger.warning( + f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict" + ) + + async def _watch_channel(self, topic_name: str): """ When a user subscribes to a topic (mgr.subscribe_to_topic('iopub')), this function starts - two tasks: + two child tasks: 1. Pull messages off the zmq channel and trigger any registered callbacks 2. Watch the monitor socket for disconnect events and reconnect / restart tasks + + If a disconnect is observed, the two tasks are both cancelled and restarted. + Unsubscribing from a topic cancels this task and the child tasks. """ channel_name = f"{topic_name}_channel" @@ -93,21 +153,25 @@ async def watch_channel(self, topic_name: str): await self.context_hook() while True: # The channel properties (e.g. self._client.iopub_channel) will connect the socket - # if self._client._iopub_channel is None. + # if self._client._iopub_channel is None. Channel objects have a monitor object + # to observe lifecycle of the socket such as handshake / disconnect channel_obj: ZMQSocketChannel = getattr(self._client, channel_name) - monitor_socket = channel_obj.socket.get_monitor_socket() - monitor_task = asyncio.create_task(self.watch_for_disconnect(monitor_socket)) message_task = asyncio.create_task( - self.watch_for_channel_messages(topic_name, channel_obj) + self._watch_for_channel_messages(topic_name, channel_obj) ) - # add tasks to self.channel_tasks so we can cleanup during topic unsubscribe / shutdown + monitor_socket = channel_obj.socket.get_monitor_socket() + monitor_task = asyncio.create_task(self._watch_for_disconnect(monitor_socket)) + + # If the _watch_channel task gets cancelled from a .unsubscribe_from_topic call, + # the two child tasks won't automatically be cancelled. Store these up at the class + # level so that _cleanup_topic_subscription can cancel them. self.channel_tasks[topic_name].append(monitor_task) self.channel_tasks[topic_name].append(message_task) - # Run the monitor and message tasks. Message task should run forever. - # If the monitor task returns then it means the socket was disconnected - # (max message size) and we need to cycle it. + # Await the monitor and message tasks. Message task should run forever. + # If the monitor task returns then it means the socket was disconnected, + # presumably from receiving a message larger than max message size. done, pending = await asyncio.wait( [monitor_task, message_task], return_when=asyncio.FIRST_COMPLETED, @@ -118,56 +182,36 @@ async def watch_channel(self, topic_name: str): if task.exception(): raise task.exception() + logger.info(f"Cycling topic {topic_name} after disconnect") self.channel_tasks[topic_name].remove(monitor_task) self.channel_tasks[topic_name].remove(message_task) - logger.info(f"Cycling topic {topic_name} after disconnect") + + # Emit an event so that callbacks registered to pickup the disconnect can do things like + # send user-facing messages that an output stream was too big and won't be displayed self._emit_system_event(topic_name, SystemEvents.FORCED_DISCONNECT) channel_obj.close() - setattr(self._client, f"_{channel_name}", None) - async def _create_topic_subscription(self, topic_name: str): - task = asyncio.create_task(self.watch_channel(topic_name)) - self.channel_tasks[topic_name].append(task) - - async def _cleanup_topic_subscription(self, topic_name: str): - if topic_name in self.channel_tasks: - for task in self.channel_tasks[topic_name]: - task.cancel() - self.channel_tasks[topic_name].remove(task) - await asyncio.sleep(0) - # Reset the channel object on our jupyter_client - setattr(self._client, f"_{topic_name}_channel", None) - else: - logger.warning( - f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict" - ) - - def send( - self, - topic_name: str, - msg_type: str, - content: Optional[dict], - parent: Optional[dict] = None, - header: Optional[dict] = None, - metadata: Optional[dict] = None, - ): - msg = self._client.session.msg(msg_type, content, parent, header, metadata) - self.outbound_queue.put_nowait(QueuedMessage(topic_name, msg, None)) - - async def _publish(self, message: QueuedMessage): - topic_name = message.topic - if topic_name not in self.subscribed_topics: - await self._create_topic_subscription(topic_name) - if hasattr(self._client, f"{topic_name}_channel"): - channel_obj = getattr(self._client, f"{topic_name}_channel") - channel_obj.send(message.contents) + # Setting jupyter_client._iopub_channel to None will cause the next reference to + # the jupyter_client.iopub_channel @property to reconnect the socket. + # (see top of this while loop!) + setattr(self._client, f"_{channel_name}", None) - # _poll and _poll_loop are designed to be used to define how a Sending backend - # will read incoming data over the wire (socket, websocket, etc). In this implementation - # when we subscribe to a topic, it starts a watch_channel task which handles reading - # data over the right jupyter_client / zmq channel. So _poll and _poll_loop aren't used. - async def _poll(self): - pass + async def _watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel): + """ + Read in any messages on a specific jupyter_client channel and drop them into the inbound + worker queue which will trigger registered callback functions by predicate / topic + """ + while True: + msg: dict = await channel_obj.get_msg() + self.schedule_for_delivery(topic_name, msg) - async def _poll_loop(self): - pass + async def _watch_for_disconnect(self, monitor_socket: zmq.Socket): + """ + An awaitable task that ends when a particular socket has a disconnect event. Used in + conjunction with watch_for_channel_messages to cycle a socket when it's disconnected. + """ + while True: + msg: dict = await recv_monitor_message(monitor_socket) + event: zmq.Event = msg["event"] + if event == zmq.EVENT_DISCONNECTED: + return