diff --git a/munit/js/src/main/scala/munit/internal/PlatformCompat.scala b/munit/js/src/main/scala/munit/internal/PlatformCompat.scala index f5fb39dd..b5ed7e24 100644 --- a/munit/js/src/main/scala/munit/internal/PlatformCompat.scala +++ b/munit/js/src/main/scala/munit/internal/PlatformCompat.scala @@ -9,11 +9,17 @@ import sbt.testing.EventHandler import sbt.testing.Logger import scala.concurrent.Promise import scala.concurrent.duration.Duration +import scala.concurrent.Await +import scala.concurrent.Awaitable import scala.concurrent.ExecutionContext import scala.scalajs.js.timers import java.util.concurrent.TimeoutException object PlatformCompat { + def awaitResult[T](awaitable: Awaitable[T]): T = { + Await.result(awaitable, Duration.Inf) + } + def executeAsync( task: Task, eventHandler: EventHandler, diff --git a/munit/js-native/src/main/scala/munit/internal/junitinterface/JUnitTask.scala b/munit/js/src/main/scala/munit/internal/junitinterface/JUnitTask.scala similarity index 100% rename from munit/js-native/src/main/scala/munit/internal/junitinterface/JUnitTask.scala rename to munit/js/src/main/scala/munit/internal/junitinterface/JUnitTask.scala diff --git a/munit/jvm/src/main/scala/munit/internal/PlatformCompat.scala b/munit/jvm/src/main/scala/munit/internal/PlatformCompat.scala index f56ef6dc..ae6b19dc 100644 --- a/munit/jvm/src/main/scala/munit/internal/PlatformCompat.scala +++ b/munit/jvm/src/main/scala/munit/internal/PlatformCompat.scala @@ -12,6 +12,8 @@ import java.util.concurrent.{ TimeUnit, TimeoutException } +import scala.concurrent.Await +import scala.concurrent.Awaitable import scala.concurrent.Promise import scala.concurrent.ExecutionContext import java.util.concurrent.atomic.AtomicInteger @@ -28,6 +30,11 @@ object PlatformCompat { } } ) + + def awaitResult[T](awaitable: Awaitable[T]): T = { + Await.result(awaitable, Duration.Inf) + } + def executeAsync( task: Task, eventHandler: EventHandler, diff --git a/munit/native/src/main/scala/munit/internal/PlatformCompat.scala b/munit/native/src/main/scala/munit/internal/PlatformCompat.scala index 79112d82..9ddbd8e3 100644 --- a/munit/native/src/main/scala/munit/internal/PlatformCompat.scala +++ b/munit/native/src/main/scala/munit/internal/PlatformCompat.scala @@ -7,10 +7,17 @@ import scala.scalanative.reflect.Reflect import sbt.testing.Task import sbt.testing.EventHandler import sbt.testing.Logger +import scala.concurrent.Await +import scala.concurrent.Awaitable import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext object PlatformCompat { + def awaitResult[T](awaitable: Awaitable[T]): T = { + scalanative.runtime.loop() + Await.result(awaitable, Duration.Inf) + } + def executeAsync( task: Task, eventHandler: EventHandler, @@ -24,9 +31,7 @@ object PlatformCompat { duration: Duration, ec: ExecutionContext ): Future[T] = { - val f = startFuture() - scala.scalanative.runtime.loop() - f + startFuture() } def setTimeout(ms: Int)(body: => Unit): () => Unit = { Thread.sleep(ms) diff --git a/munit/native/src/main/scala/munit/internal/junitinterface/JUnitTask.scala b/munit/native/src/main/scala/munit/internal/junitinterface/JUnitTask.scala new file mode 100644 index 00000000..791ae277 --- /dev/null +++ b/munit/native/src/main/scala/munit/internal/junitinterface/JUnitTask.scala @@ -0,0 +1,43 @@ +/* + * Adapted from https://github.com/scala-js/scala-js, see NOTICE.md. + */ + +package munit.internal.junitinterface + +import munit.internal.PlatformCompat +import org.junit.runner.notification.RunNotifier +import sbt.testing._ +import scala.concurrent.ExecutionContext.Implicits.global + +/* Implementation note: In JUnitTask we use Future[Try[Unit]] instead of simply + * Future[Unit]. This is to prevent Scala's Future implementation to box/wrap + * fatal errors (most importantly AssertionError) in ExecutionExceptions. We + * need to prevent the wrapping in order to hide the fact that we use async + * under the hood and stay consistent with JVM JUnit. + */ +final class JUnitTask( + _taskDef: TaskDef, + runSettings: RunSettings, + classLoader: ClassLoader +) extends Task { + + override def taskDef(): TaskDef = _taskDef + override def tags(): Array[String] = Array.empty + + def execute( + eventHandler: EventHandler, + loggers: Array[Logger] + ): Array[Task] = { + PlatformCompat.newRunner(taskDef(), classLoader) match { + case None => + case Some(runner) => + runner.filter(runSettings.tags) + val reporter = + new JUnitReporter(eventHandler, loggers, runSettings, taskDef()) + val notifier: RunNotifier = new MUnitRunNotifier(reporter) + runner.run(notifier) + } + Array() + } + +} diff --git a/munit/shared/src/main/scala/munit/MUnitRunner.scala b/munit/shared/src/main/scala/munit/MUnitRunner.scala index 275aa98c..16081c7f 100644 --- a/munit/shared/src/main/scala/munit/MUnitRunner.scala +++ b/munit/shared/src/main/scala/munit/MUnitRunner.scala @@ -16,7 +16,6 @@ import org.junit.runner.notification.RunNotifier import java.lang.reflect.Modifier import scala.collection.mutable -import scala.concurrent.Await import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.duration.Duration @@ -113,7 +112,7 @@ class MUnitRunner(val cls: Class[_ <: Suite], newInstance: () => Suite) } override def run(notifier: RunNotifier): Unit = { - Await.result(runAsync(notifier), Duration.Inf) + PlatformCompat.awaitResult(runAsync(notifier)) } def runAsync(notifier: RunNotifier): Future[Unit] = { val description = getDescription() diff --git a/tests/native/src/test/scala/munit/Issue695Suite.scala b/tests/native/src/test/scala/munit/Issue695Suite.scala new file mode 100644 index 00000000..ee1324e0 --- /dev/null +++ b/tests/native/src/test/scala/munit/Issue695Suite.scala @@ -0,0 +1,17 @@ +package munit + +import scala.concurrent._ + +class Issue695Suite extends FunSuite { + override def munitExecutionContext = ExecutionContext.global + + test("await task on global EC") { + val p = Promise[Unit]() + ExecutionContext.global.execute { () => + Thread.sleep(1000) + p.success(()) + } + p.future + } + +}