diff --git a/README.md b/README.md index ae940973d..48ec7d2f9 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,9 @@ The following line magics are supported: - `%use , ...` - injects code for supported libraries: artifact resolution, default imports, initialization code, type renderers - `%trackClasspath` - logs any changes of current classpath. Useful for debugging artifact resolution failures - `%trackExecution` - logs pieces of code that are going to be executed. Useful for debugging of libraries support + - `%output [options]` - output capturing settings. + + See detailed info about line magics [here](doc/magics.md). ### Supported Libraries diff --git a/build.gradle b/build.gradle index 82792f8aa..0433efd04 100644 --- a/build.gradle +++ b/build.gradle @@ -37,10 +37,10 @@ allprojects { } dependencies { - compile "org.jetbrains.kotlin:kotlin-stdlib:$kotlinVersion" + implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlinVersion" - testCompile 'junit:junit:4.12' - testCompile "org.jetbrains.kotlin:kotlin-test:$kotlinVersion" + testImplementation 'junit:junit:4.12' + testImplementation "org.jetbrains.kotlin:kotlin-test:$kotlinVersion" } ext { @@ -166,6 +166,7 @@ dependencies { compile 'khttp:khttp:1.0.0' compile 'org.zeromq:jeromq:0.3.5' compile 'com.beust:klaxon:5.2' + compile 'com.github.ajalt:clikt:2.3.0' runtime 'org.slf4j:slf4j-simple:1.7.25' runtime "org.jetbrains.kotlin:jcabi-aether:1.0-dev-3" runtime "org.sonatype.aether:aether-api:1.13.1" diff --git a/doc/magics.md b/doc/magics.md new file mode 100644 index 000000000..5bf49281b --- /dev/null +++ b/doc/magics.md @@ -0,0 +1,14 @@ +# Line magics +The following line magics are supported: + - `%use , ...` - injects code for supported libraries: artifact resolution, default imports, initialization code, type renderers + - `%trackClasspath` - logs any changes of current classpath. Useful for debugging artifact resolution failures + - `%trackExecution` - logs pieces of code that are going to be executed. Useful for debugging of libraries support + - `%output [--max-cell-size=N] [--max-buffer=N] [--max-buffer-newline=N] [--max-time=N] [--no-stdout] [--reset-to-defaults]` - + output capturing settings. + - `max-cell-size` specifies the characters count which may be printed to stdout. Default is 100000. + - `max-buffer` - max characters count stored in internal buffer before being sent to client. Default is 10000. + - `max-buffer-newline` - same as above, but trigger happens only if newline character was encountered. Default is 100. + - `max-time` - max time in milliseconds before the buffer is sent to client. Default is 100. + - `no-stdout` - don't capture output. Default is false. + - `reset-to-defaults` - reset all output settings that were set with magics to defaults + \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt b/src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt index 68476c894..a9be0f786 100644 --- a/src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt +++ b/src/main/kotlin/org/jetbrains/kotlin/jupyter/config.kt @@ -40,6 +40,22 @@ enum class JupyterSockets { iopub } +data class OutputConfig( + var captureOutput: Boolean = true, + var captureBufferTimeLimitMs: Long = 100, + var captureBufferMaxSize: Int = 1000, + var cellOutputMaxSize: Int = 100000, + var captureNewlineBufferSize: Int = 100 +) { + fun assign(other: OutputConfig) { + captureOutput = other.captureOutput + captureBufferTimeLimitMs = other.captureBufferTimeLimitMs + captureBufferMaxSize = other.captureBufferMaxSize + cellOutputMaxSize = other.cellOutputMaxSize + captureNewlineBufferSize = other.captureNewlineBufferSize + } +} + data class RuntimeKernelProperties(val map: Map) { val version: String get() = map["version"] ?: "unspecified" diff --git a/src/main/kotlin/org/jetbrains/kotlin/jupyter/connection.kt b/src/main/kotlin/org/jetbrains/kotlin/jupyter/connection.kt index 78bf2dba2..f227b8380 100644 --- a/src/main/kotlin/org/jetbrains/kotlin/jupyter/connection.kt +++ b/src/main/kotlin/org/jetbrains/kotlin/jupyter/connection.kt @@ -5,9 +5,7 @@ import com.beust.klaxon.Parser import org.jetbrains.kotlin.com.intellij.openapi.Disposable import org.jetbrains.kotlin.com.intellij.openapi.util.Disposer import org.zeromq.ZMQ -import java.io.ByteArrayOutputStream import java.io.Closeable -import java.io.PrintStream import java.security.SignatureException import java.util.* import javax.crypto.Mac @@ -20,6 +18,14 @@ class JupyterConnection(val config: KernelConfig): Closeable { init { val port = config.ports[socket.ordinal] bind("${config.transport}://*:$port") + if (type == ZMQ.PUB) { + // Workaround to prevent losing few first messages on kernel startup + // For more information on losing messages see this scheme: + // http://zguide.zeromq.org/page:all#Missing-Message-Problem-Solver + // It seems we cannot do correct sync because messaging protocol + // doesn't support this. Value of 500 ms was chosen experimentally. + Thread.sleep(500) + } log.debug("[$name] listen: ${config.transport}://*:$port") } @@ -150,13 +156,9 @@ class HMAC(algo: String, key: String?) { operator fun invoke(vararg data: ByteArray): String? = invoke(data.asIterable()) } -fun JupyterConnection.Socket.logWireMessage(msg: ByteArray) { - log.debug("[$name] >in: ${String(msg)}") -} - fun ByteArray.toHexString(): String = joinToString("", transform = { "%02x".format(it) }) -fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC): Unit { +fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC) { msg.id.forEach { sendMore(it) } sendMore(DELIM) val signableMsg = listOf(msg.header, msg.parentHeader, msg.metadata, msg.content) diff --git a/src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt b/src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt index 63cb02c67..6de9554ec 100644 --- a/src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt +++ b/src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt @@ -1,12 +1,19 @@ package org.jetbrains.kotlin.jupyter +import com.github.ajalt.clikt.core.CliktCommand +import com.github.ajalt.clikt.parameters.options.default +import com.github.ajalt.clikt.parameters.options.flag +import com.github.ajalt.clikt.parameters.options.option +import com.github.ajalt.clikt.parameters.types.int +import com.github.ajalt.clikt.parameters.types.long import org.jetbrains.kotlin.jupyter.repl.spark.ClassWriter enum class ReplLineMagics(val desc: String, val argumentsUsage: String? = null, val visibleInHelp: Boolean = true) { use("include supported libraries", "klaxon(5.0.1), lets-plot"), trackClasspath("log current classpath changes"), trackExecution("log code that is going to be executed in repl", visibleInHelp = false), - dumpClassesForSpark("stores compiled repl classes in special folder for Spark integration", visibleInHelp = false) + dumpClassesForSpark("stores compiled repl classes in special folder for Spark integration", visibleInHelp = false), + output("setup output settings", "--max-cell-size=1000 --no-stdout --max-time=100 --max-buffer=400") } fun processMagics(repl: ReplForJupyter, code: String): String { @@ -15,6 +22,35 @@ fun processMagics(repl: ReplForJupyter, code: String): String { var nextSearchIndex = 0 var nextCopyIndex = 0 + val outputParser = repl.outputConfig.let { conf -> + object : CliktCommand() { + val defaultConfig = OutputConfig() + + val max: Int by option("--max-cell-size", help = "Maximum cell output").int().default(conf.cellOutputMaxSize) + val maxBuffer: Int by option("--max-buffer", help = "Maximum buffer size").int().default(conf.captureBufferMaxSize) + val maxBufferNewline: Int by option("--max-buffer-newline", help = "Maximum buffer size when newline got").int().default(conf.captureNewlineBufferSize) + val maxTimeInterval: Long by option("--max-time", help = "Maximum time wait for output to accumulate").long().default(conf.captureBufferTimeLimitMs) + val dontCaptureStdout: Boolean by option("--no-stdout", help = "Don't capture output").flag(default = !conf.captureOutput) + val reset: Boolean by option("--reset-to-defaults", help = "Reset to defaults").flag() + + override fun run() { + if (reset) { + conf.assign(defaultConfig) + return + } + conf.assign( + OutputConfig( + !dontCaptureStdout, + maxTimeInterval, + maxBuffer, + max, + maxBufferNewline + ) + ) + } + } + } + while (true) { var magicStart: Int @@ -55,6 +91,9 @@ fun processMagics(repl: ReplForJupyter, code: String): String { if (arg == null) throw ReplCompilerException("Need some arguments for 'use' command") repl.librariesCodeGenerator.processNewLibraries(repl, arg) } + ReplLineMagics.output -> { + outputParser.parse((arg ?: "").split(" ")) + } } nextCopyIndex = magicEnd nextSearchIndex = magicEnd diff --git a/src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt b/src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt index 7345a7b00..cee0ae008 100644 --- a/src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt +++ b/src/main/kotlin/org/jetbrains/kotlin/jupyter/protocol.kt @@ -4,25 +4,38 @@ import com.beust.klaxon.JsonObject import jupyter.kotlin.DisplayResult import jupyter.kotlin.MimeTypedResult import jupyter.kotlin.textResult +import org.jetbrains.annotations.TestOnly import org.jetbrains.kotlin.config.KotlinCompilerVersion import java.io.ByteArrayOutputStream import java.io.OutputStream import java.io.PrintStream import java.lang.reflect.InvocationTargetException import java.util.concurrent.atomic.AtomicLong +import kotlin.concurrent.timer enum class ResponseState { Ok, Error } +enum class JupyterOutType { + STDOUT, STDERR; + fun optionName() = name.toLowerCase() +} + data class ResponseWithMessage(val state: ResponseState, val result: MimeTypedResult?, val displays: List = emptyList(), val stdOut: String? = null, val stdErr: String? = null) { val hasStdOut: Boolean = stdOut != null && stdOut.isNotEmpty() val hasStdErr: Boolean = stdErr != null && stdErr.isNotEmpty() } +fun JupyterConnection.Socket.sendOut(msg:Message, stream: JupyterOutType, text: String) { + connection.iopub.send(makeReplyMessage(msg, header = makeHeader("stream", msg), + content = jsonObject( + "name" to stream.optionName(), + "text" to text))) +} + fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJupyter?, executionCount: AtomicLong) { - val msgType = msg.header!!["msg_type"] - when (msgType) { + when (msg.header!!["msg_type"]) { "kernel_info_request" -> sendWrapped(msg, makeReplyMessage(msg, "kernel_info_reply", content = jsonObject( @@ -70,23 +83,16 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup val res: ResponseWithMessage = if (isCommand(code.toString())) { runCommand(code.toString(), repl) } else { - connection.evalWithIO { + connection.evalWithIO (repl?.outputConfig) { repl?.eval(code.toString(), count.toInt()) } } - fun sendOut(stream: String, text: String) { - connection.iopub.send(makeReplyMessage(msg, header = makeHeader("stream", msg), - content = jsonObject( - "name" to stream, - "text" to text))) - } - if (res.hasStdOut) { - sendOut("stdout", res.stdOut!!) + sendOut(msg, JupyterOutType.STDOUT, res.stdOut!!) } if (res.hasStdErr) { - sendOut("stderr", res.stdErr!!) + sendOut(msg, JupyterOutType.STDERR, res.stdErr!!) } when (res.state) { @@ -169,12 +175,73 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup } } -class CapturingOutputStream(val stdout: PrintStream, val captureOutput: Boolean) : OutputStream() { - val capturedOutput = ByteArrayOutputStream() +class CapturingOutputStream(private val stdout: PrintStream, + private val conf: OutputConfig, + private val captureOutput: Boolean, + val onCaptured: (String) -> Unit) : OutputStream() { + private val capturedLines = ByteArrayOutputStream() + private val capturedNewLine = ByteArrayOutputStream() + private var overallOutputSize = 0 + private var newlineFound = false + + private val timer = timer( + initialDelay = conf.captureBufferTimeLimitMs, + period = conf.captureBufferTimeLimitMs, + action = { + flush() + }) + val contents: ByteArray + @TestOnly + get() = capturedLines.toByteArray() + capturedNewLine.toByteArray() + + private fun flushIfNeeded(b: Int) { + val c = b.toChar() + if (c == '\n' || c == '\r') { + newlineFound = true + capturedNewLine.writeTo(capturedLines) + capturedNewLine.reset() + } + + val size = capturedLines.size() + capturedNewLine.size() + + if (newlineFound && size >= conf.captureNewlineBufferSize) + return flushBuffers(capturedLines) + if (size >= conf.captureBufferMaxSize) + return flush() + } + + @Synchronized override fun write(b: Int) { + ++overallOutputSize stdout.write(b) - if (captureOutput) capturedOutput.write(b) + + if (captureOutput && overallOutputSize <= conf.cellOutputMaxSize) { + capturedNewLine.write(b) + flushIfNeeded(b) + } + } + + @Synchronized + private fun flushBuffers(vararg buffers: ByteArrayOutputStream) { + newlineFound = false + val str = buffers.map { stream -> + val str = stream.toString("UTF-8") + stream.reset() + str + }.reduce { acc, s -> acc + s } + if (str.isNotEmpty()) { + onCaptured(str) + } + } + + override fun flush() { + flushBuffers(capturedLines, capturedNewLine) + } + + override fun close() { + super.close() + timer.cancel() } } @@ -185,17 +252,25 @@ fun Any.toMimeTypedResult(): MimeTypedResult? = when (this) { else -> textResult(this.toString()) } -fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage { +fun JupyterConnection.evalWithIO(maybeConfig: OutputConfig?, body: () -> EvalResult?): ResponseWithMessage { val out = System.out val err = System.err + val config = maybeConfig ?: OutputConfig() + + fun getCapturingStream(stream: PrintStream, outType: JupyterOutType, captureOutput: Boolean): CapturingOutputStream { + return CapturingOutputStream( + stream, + config, + captureOutput) { text -> + this.iopub.sendOut(contextMessage!!, outType, text) + } + } - // TODO: make configuration option of whether to pipe back stdout and stderr - // TODO: make a configuration option to limit the total stdout / stderr possibly returned (in case it goes wild...) - val forkedOut = CapturingOutputStream(out, true) - val forkedError = CapturingOutputStream(err, false) + val forkedOut = getCapturingStream(out, JupyterOutType.STDOUT, config.captureOutput) + val forkedError = getCapturingStream(err, JupyterOutType.STDERR, false) - System.setOut(PrintStream(forkedOut, true, "UTF-8")) - System.setErr(PrintStream(forkedError, true, "UTF-8")) + System.setOut(PrintStream(forkedOut, false, "UTF-8")) + System.setErr(PrintStream(forkedError, false, "UTF-8")) val `in` = System.`in` System.setIn(stdinIn) @@ -205,26 +280,26 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage { if (exec == null) { ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, "NO REPL!") } else { - val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull() - val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull() + forkedOut.flush() + forkedError.flush() try { var result: MimeTypedResult? = null - var displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() } + val displays = exec.displayValues.mapNotNull { it.toMimeTypedResult() }.toMutableList() if (exec.resultValue is DisplayResult) { val resultDisplay = exec.resultValue.value.toMimeTypedResult() if (resultDisplay != null) displays += resultDisplay } else result = exec.resultValue?.toMimeTypedResult() - ResponseWithMessage(ResponseState.Ok, result, displays, stdOut, stdErr) + ResponseWithMessage(ResponseState.Ok, result, displays, null, null) } catch (e: Exception) { - ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut, - joinLines(stdErr, "error: Unable to convert result to a string: ${e}")) + ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, + "error: Unable to convert result to a string: $e") } } } catch (ex: ReplCompilerException) { - val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull() - val stdErr = forkedError.capturedOutput.toString("UTF-8").emptyWhenNull() + forkedOut.flush() + forkedError.flush() // handle runtime vs. compile time and send back correct format of response, now we just send text /* @@ -235,10 +310,10 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage { 'traceback' : list(str), # traceback frames as strings } */ - ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut, - joinLines(stdErr, ex.errorResult.message)) + ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, + ex.errorResult.message) } catch (ex: ReplEvalRuntimeException) { - val stdOut = forkedOut.capturedOutput.toString("UTF-8").emptyWhenNull() + forkedOut.flush() // handle runtime vs. compile time and send back correct format of response, now we just send text /* @@ -251,7 +326,6 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage { */ val stdErr = StringBuilder() with(stdErr) { - forkedError.capturedOutput.toString("UTF-8")?.nullWhenEmpty()?.also { appendln(it) } val cause = ex.errorResult.cause if (cause == null) appendln(ex.errorResult.message) else { @@ -265,16 +339,13 @@ fun JupyterConnection.evalWithIO(body: () -> EvalResult?): ResponseWithMessage { } } } - ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), stdOut, stdErr.toString()) + ResponseWithMessage(ResponseState.Error, textResult("Error!"), emptyList(), null, stdErr.toString()) } } finally { + forkedOut.close() + forkedError.close() System.setIn(`in`) System.setErr(err) System.setOut(out) } } - -fun joinLines(vararg parts: String): String = parts.filter(String::isNotBlank).joinToString("\n") -fun String.nullWhenEmpty(): String? = if (this.isBlank()) null else this -fun String?.emptyWhenNull(): String = if (this == null || this.isBlank()) "" else this - diff --git a/src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt b/src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt index 1e26b868f..536fa7531 100644 --- a/src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt +++ b/src/main/kotlin/org/jetbrains/kotlin/jupyter/repl.kt @@ -40,6 +40,8 @@ class ReplCompilerException(val errorResult: ReplCompileResult.Error) : ReplExce class ReplForJupyter(val scriptClasspath: List = emptyList(), val config: ResolverConfig? = null) { + val outputConfig = OutputConfig() + private val resolver = JupyterScriptDependenciesResolver(config) private val renderers = config?.let { diff --git a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/capturingStreamTests.kt b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/capturingStreamTests.kt new file mode 100644 index 000000000..fc30cb71d --- /dev/null +++ b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/capturingStreamTests.kt @@ -0,0 +1,108 @@ +package org.jetbrains.kotlin.jupyter.test + +import org.jetbrains.kotlin.jupyter.CapturingOutputStream +import org.jetbrains.kotlin.jupyter.OutputConfig +import org.junit.Assert.assertArrayEquals +import org.junit.Assert.assertEquals +import org.junit.Test +import java.io.OutputStream +import java.io.PrintStream +import java.util.concurrent.atomic.AtomicInteger + +class CapturingStreamTests { + private val nullOStream = object: OutputStream() { + override fun write(b: Int) { + } + } + + private fun getStream(stdout: OutputStream = nullOStream, + captureOutput: Boolean = true, + maxBufferLifeTimeMs: Long = 1000, + maxBufferSize: Int = 1000, + maxOutputSize: Int = 1000, + maxBufferNewlineSize: Int = 1, + onCaptured: (String) -> Unit = {}): CapturingOutputStream { + + val printStream = PrintStream(stdout, false, "UTF-8") + val config = OutputConfig(captureOutput, maxBufferLifeTimeMs, maxBufferSize, maxOutputSize, maxBufferNewlineSize) + return CapturingOutputStream(printStream, config, captureOutput, onCaptured) + } + + @Test + fun testMaxOutputSizeOk() { + val s = getStream(maxOutputSize = 6) + s.write("kotlin".toByteArray()) + } + + @Test + fun testMaxOutputSizeError() { + val s = getStream(maxOutputSize = 3) + s.write("java".toByteArray()) + assertArrayEquals("jav".toByteArray(), s.contents) + } + + @Test + fun testOutputCapturingFlag() { + val contents = "abc".toByteArray() + + val s1 = getStream(captureOutput = false) + s1.write(contents) + assertEquals(0, s1.contents.size) + + val s2 = getStream(captureOutput = true) + s2.write(contents) + assertArrayEquals(contents, s2.contents) + } + + @Test + fun testMaxBufferSize() { + val contents = "0123456789\nfortran".toByteArray() + val expected = arrayOf("012", "345", "678", "9\n", "for", "tra", "n") + + val i = AtomicInteger() + val s = getStream(maxBufferSize = 3) { + assertEquals(expected[i.getAndIncrement()], it) + } + + s.write(contents) + s.flush() + + assertEquals(expected.size, i.get()) + } + + @Test + fun testNewlineBufferSize() { + val contents = "12345\n12\n3451234567890".toByteArray() + val expected = arrayOf("12345\n", "12\n", "345123456", "7890") + + val i = AtomicInteger() + val s = getStream(maxBufferSize = 9, maxBufferNewlineSize = 6) { + assertEquals(expected[i.getAndIncrement()], it) + } + + s.write(contents) + s.flush() + + assertEquals(expected.size, i.get()) + } + + @Test + fun testMaxBufferLifeTime() { + val strings = arrayOf("11", "22", "33", "44", "55", "66") + val expected = arrayOf("1122", "3344", "5566") + + val i = AtomicInteger(0) + val s = getStream(maxBufferLifeTimeMs = 1000) { + assertEquals(expected[i.getAndIncrement()], it) + } + + strings.forEach { + Thread.sleep(450) + s.write(it.toByteArray()) + } + + s.flush() + + assertEquals(expected.size, i.get()) + } +} \ No newline at end of file diff --git a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/executeTests.kt b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/executeTests.kt index eaf74b6b8..46e0b4491 100644 --- a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/executeTests.kt +++ b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/executeTests.kt @@ -6,29 +6,40 @@ import org.junit.Assert import org.junit.Test import org.zeromq.ZMQ +fun Message.type(): String { + return header!!["msg_type"] as String +} + class ExecuteTests : KernelServerTestsBase() { - fun doExecute(code : String) : Any? { + private fun doExecute(code : String, hasResult: Boolean = true, ioPubChecker : (ZMQ.Socket) -> Unit = {}) : Any? { val context = ZMQ.context(1) - var shell = context.socket(ZMQ.REQ) - var ioPub = context.socket(ZMQ.SUB) + val shell = context.socket(ZMQ.REQ) + val ioPub = context.socket(ZMQ.SUB) ioPub.subscribe(byteArrayOf()) try { shell.connect("${config.transport}://*:${config.ports[JupyterSockets.shell.ordinal]}") ioPub.connect("${config.transport}://*:${config.ports[JupyterSockets.iopub.ordinal]}") shell.sendMessage("execute_request", content = jsonObject("code" to code)) var msg = shell.receiveMessage() - Assert.assertEquals("execute_reply", msg.header!!["msg_type"]) + Assert.assertEquals("execute_reply", msg.type()) msg = ioPub.receiveMessage() - Assert.assertEquals("status", msg.header!!["msg_type"]) + Assert.assertEquals("status", msg.type()) Assert.assertEquals("busy", msg.content["execution_state"]) msg = ioPub.receiveMessage() - Assert.assertEquals("execute_input", msg.header!!["msg_type"]) - msg = ioPub.receiveMessage() - Assert.assertEquals("execute_result", msg.header!!["msg_type"]) - var response = msg.content["data"] + Assert.assertEquals("execute_input", msg.type()) + + ioPubChecker(ioPub) + + var response: Any? = null + if (hasResult) { + msg = ioPub.receiveMessage() + Assert.assertEquals("execute_result", msg.type()) + response = msg.content["data"] + } + msg = ioPub.receiveMessage() - Assert.assertEquals("status", msg.header!!["msg_type"]) + Assert.assertEquals("status", msg.type()) Assert.assertEquals("idle", msg.content["execution_state"]) return response } finally { @@ -43,4 +54,48 @@ class ExecuteTests : KernelServerTestsBase() { val res = doExecute("2+2") as JsonObject Assert.assertEquals("4", res["text/plain"]) } + + @Test + fun testOutput(){ + val code = """ + for (i in 1..5) { + Thread.sleep(200) + print(i) + } + """.trimIndent() + + fun checker(ioPub: ZMQ.Socket) { + for (i in 1..5) { + val msg = ioPub.receiveMessage() + Assert.assertEquals("stream", msg.type()) + Assert.assertEquals(i.toString(), msg.content!!["text"]) + } + } + + val res = doExecute(code, false, ::checker) + Assert.assertNull(res) + } + + @Test + fun testOutputMagic(){ + val code = """ + %output --max-buffer=2 --max-time=10000 + for (i in 1..5) { + print(i) + } + """.trimIndent() + + val expected = arrayOf("12","34","5") + + fun checker(ioPub: ZMQ.Socket) { + for (el in expected) { + val msg = ioPub.receiveMessage() + Assert.assertEquals("stream", msg.type()) + Assert.assertEquals(el, msg.content!!["text"]) + } + } + + val res = doExecute(code, false, ::checker) + Assert.assertNull(res) + } } \ No newline at end of file diff --git a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/kernelServerTestsBase.kt b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/kernelServerTestsBase.kt index bd7cd3d9f..f6760145f 100644 --- a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/kernelServerTestsBase.kt +++ b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/kernelServerTestsBase.kt @@ -22,7 +22,7 @@ open class KernelServerTestsBase { protected val hmac = HMAC(config.signatureScheme, config.signatureKey) - protected var server: Thread? = null + private var server: Thread? = null protected val messageId = listOf(byteArrayOf(1)) @@ -33,11 +33,10 @@ open class KernelServerTestsBase { @After fun teardownServer() { - Thread.sleep(100) server?.interrupt() } - fun ZMQ.Socket.sendMessage(msgType: String, content : JsonObject): Unit { + fun ZMQ.Socket.sendMessage(msgType: String, content : JsonObject) { sendMessage(Message(id = messageId, header = makeHeader(msgType), content = content), hmac) } @@ -45,18 +44,22 @@ open class KernelServerTestsBase { companion object { private val rng = Random() - private val portRangeStart = 32768 - private val portRangeEnd = 65536 - - fun randomPort(): Int = - generateSequence { portRangeStart + rng.nextInt(portRangeEnd - portRangeStart) }.find { - try { - ServerSocket(it).close() - true - } - catch (e: IOException) { - false - } - }!! + private val usedPorts = mutableSetOf() + private const val portRangeStart = 32768 + private const val portRangeEnd = 65536 + + @Synchronized + fun randomPort(): Int { + val res = generateSequence { portRangeStart + rng.nextInt(portRangeEnd - portRangeStart) }.find { + try { + ServerSocket(it).close() + !usedPorts.contains(it) + } catch (e: IOException) { + false + } + }!! + usedPorts.add(res) + return res + } } } diff --git a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/replTests.kt b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/replTests.kt index 3c4cd8795..6191339b2 100644 --- a/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/replTests.kt +++ b/src/test/kotlin/org/jetbrains/kotlin/jupyter/test/replTests.kt @@ -3,8 +3,6 @@ package org.jetbrains.kotlin.jupyter.test import com.beust.klaxon.JsonObject import com.beust.klaxon.Parser import jupyter.kotlin.MimeTypedResult -import kotlinx.coroutines.GlobalScope -import kotlinx.coroutines.async import org.jetbrains.kotlin.jupyter.* import org.jetbrains.kotlin.jupyter.repl.completion.CompletionResultSuccess import org.junit.Assert @@ -74,6 +72,28 @@ class ReplTest { assertFails { repl.eval("Out[3]") } } + @Test + fun TestOutputMagic() { + val repl = ReplForJupyter(classpath) + repl.preprocessCode("%output --max-cell-size=100500 --no-stdout") + assertEquals(OutputConfig( + cellOutputMaxSize = 100500, + captureOutput = false + ), repl.outputConfig) + + repl.preprocessCode("%output --max-buffer=42 --max-buffer-newline=33 --max-time=2000") + assertEquals(OutputConfig( + cellOutputMaxSize = 100500, + captureOutput = false, + captureBufferMaxSize = 42, + captureNewlineBufferSize = 33, + captureBufferTimeLimitMs = 2000 + ), repl.outputConfig) + + repl.preprocessCode("%output --reset-to-defaults") + assertEquals(OutputConfig(), repl.outputConfig) + } + @Test fun TestUseMagic() { val lib1 = "mylib" to """