Skip to content

Commit

Permalink
Support display id in display(value) API. (#385)
Browse files Browse the repository at this point in the history
* Support display id in `DISPLAY(value)` API.
* Support async display update

Co-authored-by: Ilya Muradyan <Ilya.Muradyan@jetbrains.com>
(cherry picked from commit 00699df)
  • Loading branch information
nikitinas authored and ileasile committed Nov 16, 2022
1 parent 4799218 commit b7cd227
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ interface KotlinKernelHost {
* Try to display the given value. It is only displayed if it's an instance of [Renderable]
* or may be converted to it
*/
fun display(value: Any)
fun display(value: Any, id: String? = null)

/**
* Updates display data with given [id] with the new [value]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ interface DisplayResult : Renderable {
* @param additionalMetadata Additional reply metadata
* @return Display JSON
*/
fun toJson(additionalMetadata: JsonObject = JsonObject(mapOf())): JsonObject
fun toJson(additionalMetadata: JsonObject = JsonObject(mapOf()), overrideId: String? = null): JsonObject

/**
* Renders display result, generally should return `this`
Expand Down Expand Up @@ -83,6 +83,12 @@ fun DisplayResult?.toJson(): JsonObject {
return Json.encodeToJsonElement(mapOf("data" to null, "metadata" to JsonObject(mapOf()))) as JsonObject
}

@Suppress("unused")
fun DisplayResult.withId(id: String) = if (id == this.id) this else object : DisplayResult {
override fun toJson(additionalMetadata: JsonObject, overrideId: String?) = this@withId.toJson(additionalMetadata, overrideId ?: id)
override val id = id
}

/**
* Sets display ID to JSON.
* If ID was not set, sets it to [id] and returns it back
Expand Down Expand Up @@ -112,7 +118,7 @@ class MimeTypedResult(
var isolatedHtml: Boolean = false,
override val id: String? = null
) : Map<String, String> by mimeData, DisplayResult {
override fun toJson(additionalMetadata: JsonObject): JsonObject {
override fun toJson(additionalMetadata: JsonObject, overrideId: String?): JsonObject {
val metadata = HashMap<String, JsonElement>().apply {
if (isolatedHtml) put("text/html", Json.encodeToJsonElement(mapOf("isolated" to true)))
additionalMetadata.forEach { key, value ->
Expand All @@ -124,17 +130,9 @@ class MimeTypedResult(
"data" to Json.encodeToJsonElement(mimeData),
"metadata" to Json.encodeToJsonElement(metadata)
)
result.setDisplayId(id)
result.setDisplayId(overrideId ?: id)
return Json.encodeToJsonElement(result) as JsonObject
}

/**
* Adds an [id] to this [MimeTypedResult] object
*/
@Suppress("unused")
fun withId(id: String?): MimeTypedResult {
return MimeTypedResult(mimeData, isolatedHtml, id)
}
}

// Convenience methods for displaying results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ abstract class ScriptTemplateWithDisplayHelpers(

val notebook get() = userHandlesProvider.notebook

fun DISPLAY(value: Any) = host.display(value)
fun DISPLAY(value: Any, id: String? = null) = host.display(value, id)

fun UPDATE_DISPLAY(value: Any, id: String?) = host.updateDisplay(value, id)

Expand Down
38 changes: 27 additions & 11 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
import org.jetbrains.kotlinx.jupyter.api.libraries.LibraryResolutionRequest
import org.jetbrains.kotlinx.jupyter.repl.impl.SharedReplContext

interface MutableDisplayResultWithCell : DisplayResultWithCell {
override val cell: MutableCodeCell
}

interface MutableDisplayContainer : DisplayContainer {
fun add(display: DisplayResultWrapper)

fun add(display: DisplayResult, cell: MutableCodeCell)

fun update(id: String?, display: DisplayResult)

override fun getAll(): List<MutableDisplayResultWithCell>

override fun getById(id: String?): List<MutableDisplayResultWithCell>
}

interface MutableCodeCell : CodeCell {
Expand Down Expand Up @@ -53,7 +61,7 @@ interface MutableNotebook : Notebook {
class DisplayResultWrapper private constructor(
val display: DisplayResult,
override val cell: MutableCodeCell,
) : DisplayResult by display, DisplayResultWithCell {
) : DisplayResult by display, MutableDisplayResultWithCell {
companion object {
fun create(display: DisplayResult, cell: MutableCodeCell): DisplayResultWrapper {
return if (display is DisplayResultWrapper) DisplayResultWrapper(display.display, cell)
Expand All @@ -66,27 +74,35 @@ class DisplayContainerImpl : MutableDisplayContainer {
private val displays: MutableMap<String?, MutableList<DisplayResultWrapper>> = mutableMapOf()

override fun add(display: DisplayResultWrapper) {
val list = displays.getOrPut(display.id) { mutableListOf() }
list.add(display)
synchronized(displays) {
val list = displays.getOrPut(display.id) { mutableListOf() }
list.add(display)
}
}

override fun add(display: DisplayResult, cell: MutableCodeCell) {
add(DisplayResultWrapper.create(display, cell))
}

override fun getAll(): List<DisplayResultWithCell> {
return displays.flatMap { it.value }
override fun getAll(): List<MutableDisplayResultWithCell> {
synchronized(displays) {
return displays.flatMap { it.value }
}
}

override fun getById(id: String?): List<DisplayResultWithCell> {
return displays[id].orEmpty()
override fun getById(id: String?): List<MutableDisplayResultWithCell> {
synchronized(displays) {
return displays[id].orEmpty()
}
}

override fun update(id: String?, display: DisplayResult) {
val initialDisplays = displays[id] ?: return
val updated = initialDisplays.map { DisplayResultWrapper.create(display, it.cell) }
initialDisplays.clear()
initialDisplays.addAll(updated)
synchronized(displays) {
val initialDisplays = displays[id] ?: return
val updated = initialDisplays.map { DisplayResultWrapper.create(display, it.cell) }
initialDisplays.clear()
initialDisplays.addAll(updated)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class JupyterConnectionImpl(
private fun getInput(): String {
stdin.sendMessage(
makeReplyMessage(
contextMessage,
contextMessage!!,
MessageType.INPUT_REQUEST,
content = InputRequest("stdin:")
)
Expand Down Expand Up @@ -156,7 +156,7 @@ class JupyterConnectionImpl(
override fun setContextMessage(message: RawMessage?) {
_contextMessage = message
}
override val contextMessage: RawMessage get() = _contextMessage!!
override val contextMessage: RawMessage? get() = _contextMessage

override val executor: JupyterExecutor = JupyterExecutorImpl()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import java.io.InputStream

interface JupyterConnectionInternal : JupyterConnection {
val config: KernelConfig
val contextMessage: RawMessage
val contextMessage: RawMessage?

val heartbeat: JupyterSocket
val shell: JupyterSocket
Expand Down
51 changes: 24 additions & 27 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage
import org.jetbrains.kotlinx.jupyter.api.libraries.portField
import org.jetbrains.kotlinx.jupyter.api.setDisplayId
import org.jetbrains.kotlinx.jupyter.api.textResult
import org.jetbrains.kotlinx.jupyter.api.withId
import org.jetbrains.kotlinx.jupyter.common.looksLikeReplCommand
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
import org.jetbrains.kotlinx.jupyter.config.KernelStreams
Expand Down Expand Up @@ -62,10 +63,10 @@ abstract class Response(
abstract val state: ResponseState

fun send(connection: JupyterConnectionInternal, requestCount: Long, requestMsg: RawMessage, startedTime: String) {
if (stdOut != null && stdOut.isNotEmpty()) {
if (!stdOut.isNullOrEmpty()) {
connection.sendOut(requestMsg, JupyterOutType.STDOUT, stdOut)
}
if (stdErr != null && stdErr.isNotEmpty()) {
if (!stdErr.isNullOrEmpty()) {
connection.sendOut(requestMsg, JupyterOutType.STDERR, stdErr)
}
sendBody(connection, requestCount, requestMsg, startedTime)
Expand Down Expand Up @@ -118,12 +119,12 @@ class OkResponseWithMessage(
}

interface DisplayHandler {
fun handleDisplay(value: Any, host: ExecutionHost)
fun handleDisplay(value: Any, host: ExecutionHost, id: String? = null)
fun handleUpdate(value: Any, host: ExecutionHost, id: String? = null)
}

object NoOpDisplayHandler : DisplayHandler {
override fun handleDisplay(value: Any, host: ExecutionHost) {
override fun handleDisplay(value: Any, host: ExecutionHost, id: String?) {
}

override fun handleUpdate(value: Any, host: ExecutionHost, id: String?) {
Expand All @@ -141,44 +142,40 @@ class SocketDisplayHandler(
return renderedValue.toDisplayResult(notebook)
}

override fun handleDisplay(value: Any, host: ExecutionHost) {
val display = render(host, value) ?: return
override fun handleDisplay(value: Any, host: ExecutionHost, id: String?) {
val display = render(host, value)?.let { if (id != null) it.withId(id) else it } ?: return
val json = display.toJson()

notebook.currentCell?.addDisplay(display)

socket.sendMessage(
makeReplyMessage(
connection.contextMessage,
MessageType.DISPLAY_DATA,
content = DisplayDataResponse(
json["data"],
json["metadata"],
json["transient"]
)
)
val content = DisplayDataResponse(
json["data"],
json["metadata"],
json["transient"]
)
val message = makeReplyMessage(connection.contextMessage!!, MessageType.DISPLAY_DATA, content = content)
socket.sendMessage(message)
}

override fun handleUpdate(value: Any, host: ExecutionHost, id: String?) {
val display = render(host, value) ?: return
val json = display.toJson().toMutableMap()

notebook.currentCell?.displays?.update(id, display)
val container = notebook.displays
container.update(id, display)
container.getById(id).distinctBy { it.cell.id }.forEach {
it.cell.displays.update(id, display)
}

json.setDisplayId(id) ?: throw RuntimeException("`update_display_data` response should provide an id of data being updated")

socket.sendMessage(
makeReplyMessage(
connection.contextMessage,
MessageType.UPDATE_DISPLAY_DATA,
content = DisplayDataResponse(
json["data"],
json["metadata"],
json["transient"]
)
)
val content = DisplayDataResponse(
json["data"],
json["metadata"],
json["transient"]
)
val message = connection.makeSimpleMessage(MessageType.UPDATE_DISPLAY_DATA, content)
socket.sendMessage(message)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ internal class CellExecutorImpl(private val replContext: SharedReplContext) : Ce

override fun execute(code: Code) = executor.execute(code, displayHandler, processVariables = false, invokeAfterCallbacks = false, stackFrame = stackFrame).result

override fun display(value: Any) {
displayHandler.handleDisplay(value, this)
override fun display(value: Any, id: String?) {
displayHandler.handleDisplay(value, this, id)
}

override fun updateDisplay(value: Any, id: String?) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class InMemoryLibraryResolver(
}

class TestDisplayHandler(val list: MutableList<Any> = mutableListOf()) : DisplayHandler {
override fun handleDisplay(value: Any, host: ExecutionHost) {
override fun handleDisplay(value: Any, host: ExecutionHost, id: String?) {
list.add(value)
}

Expand Down

0 comments on commit b7cd227

Please sign in to comment.