From b7cd2273630dbf263cb252f4e250ef85a279f8e2 Mon Sep 17 00:00:00 2001 From: Anatoly Nikitin Date: Mon, 24 Oct 2022 03:47:54 +0300 Subject: [PATCH] Support display id in `display(value)` API. (#385) * Support display id in `DISPLAY(value)` API. * Support async display update Co-authored-by: Ilya Muradyan (cherry picked from commit 00699df2ed4daa4db7eda4b499961db4495f8a68) --- .../kotlinx/jupyter/api/KotlinKernelHost.kt | 2 +- .../jetbrains/kotlinx/jupyter/api/results.kt | 20 ++++---- .../ScriptTemplateWithDisplayHelpers.kt | 2 +- .../org/jetbrains/kotlinx/jupyter/apiImpl.kt | 38 ++++++++++---- .../jetbrains/kotlinx/jupyter/connection.kt | 4 +- .../messaging/JupyterConnectionInternal.kt | 2 +- .../kotlinx/jupyter/messaging/protocol.kt | 51 +++++++++---------- .../jupyter/repl/impl/CellExecutorImpl.kt | 4 +- .../kotlinx/jupyter/test/testUtil.kt | 2 +- 9 files changed, 68 insertions(+), 57 deletions(-) diff --git a/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/KotlinKernelHost.kt b/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/KotlinKernelHost.kt index 5c40552b5..8ba3a467d 100644 --- a/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/KotlinKernelHost.kt +++ b/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/KotlinKernelHost.kt @@ -11,7 +11,7 @@ interface KotlinKernelHost { * Try to display the given value. It is only displayed if it's an instance of [Renderable] * or may be converted to it */ - fun display(value: Any) + fun display(value: Any, id: String? = null) /** * Updates display data with given [id] with the new [value] diff --git a/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/results.kt b/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/results.kt index ac9468f87..0afd12bd1 100644 --- a/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/results.kt +++ b/jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/results.kt @@ -47,7 +47,7 @@ interface DisplayResult : Renderable { * @param additionalMetadata Additional reply metadata * @return Display JSON */ - fun toJson(additionalMetadata: JsonObject = JsonObject(mapOf())): JsonObject + fun toJson(additionalMetadata: JsonObject = JsonObject(mapOf()), overrideId: String? = null): JsonObject /** * Renders display result, generally should return `this` @@ -83,6 +83,12 @@ fun DisplayResult?.toJson(): JsonObject { return Json.encodeToJsonElement(mapOf("data" to null, "metadata" to JsonObject(mapOf()))) as JsonObject } +@Suppress("unused") +fun DisplayResult.withId(id: String) = if (id == this.id) this else object : DisplayResult { + override fun toJson(additionalMetadata: JsonObject, overrideId: String?) = this@withId.toJson(additionalMetadata, overrideId ?: id) + override val id = id +} + /** * Sets display ID to JSON. * If ID was not set, sets it to [id] and returns it back @@ -112,7 +118,7 @@ class MimeTypedResult( var isolatedHtml: Boolean = false, override val id: String? = null ) : Map by mimeData, DisplayResult { - override fun toJson(additionalMetadata: JsonObject): JsonObject { + override fun toJson(additionalMetadata: JsonObject, overrideId: String?): JsonObject { val metadata = HashMap().apply { if (isolatedHtml) put("text/html", Json.encodeToJsonElement(mapOf("isolated" to true))) additionalMetadata.forEach { key, value -> @@ -124,17 +130,9 @@ class MimeTypedResult( "data" to Json.encodeToJsonElement(mimeData), "metadata" to Json.encodeToJsonElement(metadata) ) - result.setDisplayId(id) + result.setDisplayId(overrideId ?: id) return Json.encodeToJsonElement(result) as JsonObject } - - /** - * Adds an [id] to this [MimeTypedResult] object - */ - @Suppress("unused") - fun withId(id: String?): MimeTypedResult { - return MimeTypedResult(mimeData, isolatedHtml, id) - } } // Convenience methods for displaying results diff --git a/jupyter-lib/lib/src/main/kotlin/jupyter/kotlin/ScriptTemplateWithDisplayHelpers.kt b/jupyter-lib/lib/src/main/kotlin/jupyter/kotlin/ScriptTemplateWithDisplayHelpers.kt index 14ce15fd3..a7d5f29ab 100644 --- a/jupyter-lib/lib/src/main/kotlin/jupyter/kotlin/ScriptTemplateWithDisplayHelpers.kt +++ b/jupyter-lib/lib/src/main/kotlin/jupyter/kotlin/ScriptTemplateWithDisplayHelpers.kt @@ -14,7 +14,7 @@ abstract class ScriptTemplateWithDisplayHelpers( val notebook get() = userHandlesProvider.notebook - fun DISPLAY(value: Any) = host.display(value) + fun DISPLAY(value: Any, id: String? = null) = host.display(value, id) fun UPDATE_DISPLAY(value: Any, id: String?) = host.updateDisplay(value, id) diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt index 74c51e915..54c373885 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt @@ -18,12 +18,20 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection import org.jetbrains.kotlinx.jupyter.api.libraries.LibraryResolutionRequest import org.jetbrains.kotlinx.jupyter.repl.impl.SharedReplContext +interface MutableDisplayResultWithCell : DisplayResultWithCell { + override val cell: MutableCodeCell +} + interface MutableDisplayContainer : DisplayContainer { fun add(display: DisplayResultWrapper) fun add(display: DisplayResult, cell: MutableCodeCell) fun update(id: String?, display: DisplayResult) + + override fun getAll(): List + + override fun getById(id: String?): List } interface MutableCodeCell : CodeCell { @@ -53,7 +61,7 @@ interface MutableNotebook : Notebook { class DisplayResultWrapper private constructor( val display: DisplayResult, override val cell: MutableCodeCell, -) : DisplayResult by display, DisplayResultWithCell { +) : DisplayResult by display, MutableDisplayResultWithCell { companion object { fun create(display: DisplayResult, cell: MutableCodeCell): DisplayResultWrapper { return if (display is DisplayResultWrapper) DisplayResultWrapper(display.display, cell) @@ -66,27 +74,35 @@ class DisplayContainerImpl : MutableDisplayContainer { private val displays: MutableMap> = mutableMapOf() override fun add(display: DisplayResultWrapper) { - val list = displays.getOrPut(display.id) { mutableListOf() } - list.add(display) + synchronized(displays) { + val list = displays.getOrPut(display.id) { mutableListOf() } + list.add(display) + } } override fun add(display: DisplayResult, cell: MutableCodeCell) { add(DisplayResultWrapper.create(display, cell)) } - override fun getAll(): List { - return displays.flatMap { it.value } + override fun getAll(): List { + synchronized(displays) { + return displays.flatMap { it.value } + } } - override fun getById(id: String?): List { - return displays[id].orEmpty() + override fun getById(id: String?): List { + synchronized(displays) { + return displays[id].orEmpty() + } } override fun update(id: String?, display: DisplayResult) { - val initialDisplays = displays[id] ?: return - val updated = initialDisplays.map { DisplayResultWrapper.create(display, it.cell) } - initialDisplays.clear() - initialDisplays.addAll(updated) + synchronized(displays) { + val initialDisplays = displays[id] ?: return + val updated = initialDisplays.map { DisplayResultWrapper.create(display, it.cell) } + initialDisplays.clear() + initialDisplays.addAll(updated) + } } } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt index de2d08bb9..cdf0bba83 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt @@ -48,7 +48,7 @@ class JupyterConnectionImpl( private fun getInput(): String { stdin.sendMessage( makeReplyMessage( - contextMessage, + contextMessage!!, MessageType.INPUT_REQUEST, content = InputRequest("stdin:") ) @@ -156,7 +156,7 @@ class JupyterConnectionImpl( override fun setContextMessage(message: RawMessage?) { _contextMessage = message } - override val contextMessage: RawMessage get() = _contextMessage!! + override val contextMessage: RawMessage? get() = _contextMessage override val executor: JupyterExecutor = JupyterExecutorImpl() diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt index fa3e2706a..9d5f9356b 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt @@ -9,7 +9,7 @@ import java.io.InputStream interface JupyterConnectionInternal : JupyterConnection { val config: KernelConfig - val contextMessage: RawMessage + val contextMessage: RawMessage? val heartbeat: JupyterSocket val shell: JupyterSocket diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt index 413c868da..f110d49b1 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt @@ -23,6 +23,7 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage import org.jetbrains.kotlinx.jupyter.api.libraries.portField import org.jetbrains.kotlinx.jupyter.api.setDisplayId import org.jetbrains.kotlinx.jupyter.api.textResult +import org.jetbrains.kotlinx.jupyter.api.withId import org.jetbrains.kotlinx.jupyter.common.looksLikeReplCommand import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata import org.jetbrains.kotlinx.jupyter.config.KernelStreams @@ -62,10 +63,10 @@ abstract class Response( abstract val state: ResponseState fun send(connection: JupyterConnectionInternal, requestCount: Long, requestMsg: RawMessage, startedTime: String) { - if (stdOut != null && stdOut.isNotEmpty()) { + if (!stdOut.isNullOrEmpty()) { connection.sendOut(requestMsg, JupyterOutType.STDOUT, stdOut) } - if (stdErr != null && stdErr.isNotEmpty()) { + if (!stdErr.isNullOrEmpty()) { connection.sendOut(requestMsg, JupyterOutType.STDERR, stdErr) } sendBody(connection, requestCount, requestMsg, startedTime) @@ -118,12 +119,12 @@ class OkResponseWithMessage( } interface DisplayHandler { - fun handleDisplay(value: Any, host: ExecutionHost) + fun handleDisplay(value: Any, host: ExecutionHost, id: String? = null) fun handleUpdate(value: Any, host: ExecutionHost, id: String? = null) } object NoOpDisplayHandler : DisplayHandler { - override fun handleDisplay(value: Any, host: ExecutionHost) { + override fun handleDisplay(value: Any, host: ExecutionHost, id: String?) { } override fun handleUpdate(value: Any, host: ExecutionHost, id: String?) { @@ -141,44 +142,40 @@ class SocketDisplayHandler( return renderedValue.toDisplayResult(notebook) } - override fun handleDisplay(value: Any, host: ExecutionHost) { - val display = render(host, value) ?: return + override fun handleDisplay(value: Any, host: ExecutionHost, id: String?) { + val display = render(host, value)?.let { if (id != null) it.withId(id) else it } ?: return val json = display.toJson() notebook.currentCell?.addDisplay(display) - socket.sendMessage( - makeReplyMessage( - connection.contextMessage, - MessageType.DISPLAY_DATA, - content = DisplayDataResponse( - json["data"], - json["metadata"], - json["transient"] - ) - ) + val content = DisplayDataResponse( + json["data"], + json["metadata"], + json["transient"] ) + val message = makeReplyMessage(connection.contextMessage!!, MessageType.DISPLAY_DATA, content = content) + socket.sendMessage(message) } override fun handleUpdate(value: Any, host: ExecutionHost, id: String?) { val display = render(host, value) ?: return val json = display.toJson().toMutableMap() - notebook.currentCell?.displays?.update(id, display) + val container = notebook.displays + container.update(id, display) + container.getById(id).distinctBy { it.cell.id }.forEach { + it.cell.displays.update(id, display) + } json.setDisplayId(id) ?: throw RuntimeException("`update_display_data` response should provide an id of data being updated") - socket.sendMessage( - makeReplyMessage( - connection.contextMessage, - MessageType.UPDATE_DISPLAY_DATA, - content = DisplayDataResponse( - json["data"], - json["metadata"], - json["transient"] - ) - ) + val content = DisplayDataResponse( + json["data"], + json["metadata"], + json["transient"] ) + val message = connection.makeSimpleMessage(MessageType.UPDATE_DISPLAY_DATA, content) + socket.sendMessage(message) } } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/CellExecutorImpl.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/CellExecutorImpl.kt index 54975bc63..a781b0f15 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/CellExecutorImpl.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/CellExecutorImpl.kt @@ -210,8 +210,8 @@ internal class CellExecutorImpl(private val replContext: SharedReplContext) : Ce override fun execute(code: Code) = executor.execute(code, displayHandler, processVariables = false, invokeAfterCallbacks = false, stackFrame = stackFrame).result - override fun display(value: Any) { - displayHandler.handleDisplay(value, this) + override fun display(value: Any, id: String?) { + displayHandler.handleDisplay(value, this, id) } override fun updateDisplay(value: Any, id: String?) { diff --git a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/testUtil.kt b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/testUtil.kt index 31eb07eaa..e83930612 100644 --- a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/testUtil.kt +++ b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/testUtil.kt @@ -153,7 +153,7 @@ class InMemoryLibraryResolver( } class TestDisplayHandler(val list: MutableList = mutableListOf()) : DisplayHandler { - override fun handleDisplay(value: Any, host: ExecutionHost) { + override fun handleDisplay(value: Any, host: ExecutionHost, id: String?) { list.add(value) }