Skip to content

Commit

Permalink
Improve and make scripted parallel
Browse files Browse the repository at this point in the history
These are the improvements that I've added to scripted:

1. Scripted is now parallel and does batch execution.
2. Scripted logs to both a file and the console (if `bufferLog == true`).
   All the logs can be inspected locally by going to a directory in
   `/tmp`. This directory is shown to the user at the end of the
   execution.
3. Scripted UI has been improved.
   3.1. Colors are used for `+` and `x`.
   3.1. It shows the command that actually failed, not `Command failed {line 1}`.
   3.2. It trims the stack traces of the wrapping exceptions
        (corresponding to the scripted infrastructure). Only the stack traces
        of the causing exceptions are shown (which are the ones we're
        interested in and are usually assertion errors).

I think these improvements enhance the current dev workflow
considerably. I invite you to give them a try.

This change combined with sbt#429, gives a really fast execution of
scripted. Testing just one test is under 7 seconds in my machine (note
that in those 7 seconds we have to fetch the bridge, etc).
  • Loading branch information
jvican committed Oct 22, 2017
1 parent aa1a628 commit dd20e7c
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package sbt.internal.inc

import org.scalatest.exceptions.TestFailedException
import sbt.internal.inc
import sbt.internal.scripted._
import sbt.internal.inc.BatchScriptRunner.States

/** Defines an alternative script runner that allows batch execution. */
private[sbt] class BatchScriptRunner extends ScriptRunner {

/** Defines a method to run batched execution.
*
* @param statements The list of handlers and statements.
* @param states The states of the runner. In case it's empty, inherited apply is called.
*/
def apply(statements: List[(StatementHandler, Statement)], states: States): Unit = {
if (states.isEmpty) super.apply(statements)
else statements.foreach(st => processStatement(st._1, st._2, states))
}

def initStates(states: States, handlers: Seq[StatementHandler]): Unit =
handlers.foreach(handler => states(handler) = handler.initialState)

def cleanUpHandlers(handlers: Seq[StatementHandler], states: States): Unit = {
for (handler <- handlers; state <- states.get(handler)) {
try handler.finish(state.asInstanceOf[handler.State])
catch { case _: Exception => () }
}
}

import BatchScriptRunner.PreciseScriptedError
def processStatement(handler: StatementHandler, statement: Statement, states: States): Unit = {
val state = states(handler).asInstanceOf[handler.State]
val nextState =
try { Right(handler(statement.command, statement.arguments, state)) } catch {
case e: Exception => Left(e)
}
nextState match {
case Left(err) =>
if (statement.successExpected) {
err match {
case t: TestFailed =>
val errorMessage = s"${t.getMessage} produced by"
throw new PreciseScriptedError(statement, errorMessage, null)
case _ => throw new PreciseScriptedError(statement, "Command failed", err)
}
} else ()
case Right(s) =>
if (statement.successExpected) states(handler) = s
else throw new PreciseScriptedError(statement, "Expecting error at", null)
}
}
}

private[sbt] object BatchScriptRunner {
import scala.collection.mutable
type States = mutable.HashMap[StatementHandler, Any]

// Should be used instead of sbt.internal.scripted.TestException that doesn't show failed command
final class PreciseScriptedError(st: Statement, msg: String, e: Throwable)
extends RuntimeException(s"$msg: '${st.command} ${st.arguments.mkString(" ")}'", e) {
override def fillInStackTrace = e
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,18 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
type IncCommand = (ProjectStructure, List[String], IncInstance) => Unit

val compiler = new IncrementalCompilerImpl
def initialState: Option[IncInstance] = None
def finish(state: Option[IncInstance]): Unit = ()

def initialState: Option[IncInstance] = {
initBuildStructure()
None
}

def finish(state: Option[IncInstance]): Unit = {
// Required so that next projects re-read the project structure
buildStructure.clear()
()
}

val buildStructure: mutable.Map[String, ProjectStructure] = mutable.Map.empty
def initBuildStructure(): Unit = {
val build = initBuild
Expand All @@ -70,8 +80,6 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
}
}

initBuildStructure()

private final val RootIdentifier = "root"
def initBuild: Build = {
if ((directory / "build.json").exists) {
Expand Down Expand Up @@ -108,14 +116,14 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
private final val noLogger = Logger.Null
private[this] def onNewIncInstance(p: ProjectStructure): IncInstance = {
val scalaVersion = p.scalaVersion
val (compilerBridge, si) = IncHandler.scriptedCompilerCache.get(scalaVersion) match {
val (compilerBridge, si) = IncHandler.getCompilerCacheFor(scalaVersion) match {
case Some(alreadyInstantiated) =>
alreadyInstantiated
case None =>
val compilerBridge = getCompilerBridge(cacheDir, noLogger, scalaVersion)
val si = scalaInstance(scalaVersion, cacheDir, noLogger)
val toCache = (compilerBridge, si)
IncHandler.scriptedCompilerCache.put(scalaVersion, toCache)
IncHandler.putCompilerCache(scalaVersion, toCache)
toCache
}
val analyzingCompiler = scalaCompiler(si, compilerBridge)
Expand Down Expand Up @@ -505,7 +513,12 @@ case class ProjectStructure(

object IncHandler {
type Cached = (File, xsbti.compile.ScalaInstance)
private[internal] final val scriptedCompilerCache = new mutable.WeakHashMap[String, Cached]()
private[this] final val scriptedCompilerCache = new mutable.WeakHashMap[String, Cached]()
def getCompilerCacheFor(scalaVersion: String): Option[Cached] =
synchronized(scriptedCompilerCache.get(scalaVersion))
def putCompilerCache(scalaVersion: String, cached: Cached): Option[Cached] =
synchronized(scriptedCompilerCache.put(scalaVersion, cached))

private[internal] final val classLoaderCache = Some(
new ClassLoaderCache(new URLClassLoader(Array())))
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,79 @@ package sbt.internal.inc

import java.io.File

import sbt.internal.scripted.ScriptedRunnerImpl
import sbt.internal.scripted.{ HandlersProvider, ListTests, ScriptedTest }
import sbt.io.IO
import sbt.util.Logger

import scala.collection.parallel.ParSeq

class IncScriptedRunner {
def run(resourceBaseDirectory: File, bufferLog: Boolean, tests: Array[String]): Unit = {
IO.withTemporaryDirectory { tempDir =>
// Create a global temporary directory to store the bridge et al
val handlers = new IncScriptedHandlers(tempDir)
ScriptedRunnerImpl.run(resourceBaseDirectory, bufferLog, tests, handlers);
ScriptedRunnerImpl.run(resourceBaseDirectory, bufferLog, tests, handlers, 4)
}
}
}

object ScriptedRunnerImpl {
type TestRunner = () => Seq[Option[String]]

def run(
resourceBaseDirectory: File,
bufferLog: Boolean,
tests: Array[String],
handlersProvider: HandlersProvider,
instances: Int
): Unit = {
val globalLogger = newLogger
val logsDir = newScriptedLogsDir
val runner = new ScriptedTests(resourceBaseDirectory, bufferLog, handlersProvider, logsDir)
val scriptedTests = get(tests, resourceBaseDirectory, globalLogger)
val scriptedRunners = runner.batchScriptedRunner(scriptedTests, instances)
val parallelRunners = scriptedRunners.toParArray
val pool = new java.util.concurrent.ForkJoinPool(instances)
parallelRunners.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(pool)
runAllInParallel(parallelRunners)
globalLogger.info(s"Log files can be found at ${logsDir.getAbsolutePath}")
}

private val nl = IO.Newline
private val nlt = nl + "\t"
class ScriptedFailure(tests: Seq[String]) extends RuntimeException(tests.mkString(nlt, nlt, nl)) {
// We are not interested in the stack trace here, only the failing tests
override def fillInStackTrace = this
}

private def reportErrors(errors: Seq[String]): Unit =
if (errors.nonEmpty) throw new ScriptedFailure(errors) else ()

def runAllInParallel(tests: ParSeq[TestRunner]): Unit = {
reportErrors(tests.flatMap(test => test.apply().flatten.toSeq).toList)
}

def get(tests: Seq[String], baseDirectory: File, log: Logger): Seq[ScriptedTest] =
if (tests.isEmpty) listTests(baseDirectory, log) else parseTests(tests)

def listTests(baseDirectory: File, log: Logger): Seq[ScriptedTest] =
(new ListTests(baseDirectory, _ => true, log)).listTests

def parseTests(in: Seq[String]): Seq[ScriptedTest] = for (testString <- in) yield {
testString.split("/").map(_.trim) match {
case Array(group, name) => ScriptedTest(group, name)
case elems =>
sys.error(s"Expected two arguments 'group/name', obtained ${elems.mkString("/")}")
}
}

private[sbt] def newLogger: Logger = sbt.internal.util.ConsoleLogger()

private[this] val random = new java.util.Random()
private[sbt] def newScriptedLogsDir: File = {
val randomName = "scripted-logs-" + java.lang.Integer.toHexString(random.nextInt)
val logsDir = new File(IO.temporaryDirectory, randomName)
IO.createDirectory(logsDir)
logsDir
}
}
Loading

0 comments on commit dd20e7c

Please sign in to comment.