diff --git a/scalalib/src/mill/scalalib/TestModule.scala b/scalalib/src/mill/scalalib/TestModule.scala index 4a323ced7cf..3eea4d1bf3d 100644 --- a/scalalib/src/mill/scalalib/TestModule.scala +++ b/scalalib/src/mill/scalalib/TestModule.scala @@ -3,7 +3,7 @@ package mill.scalalib import mill.api.{Ctx, PathRef, Result} import mill.define.{Command, Task, TaskModule} import mill.scalalib.bsp.{BspBuildTarget, BspModule} -import mill.testrunner.{Framework, TestArgs, TestResult, TestRunner} +import mill.testrunner.{Framework, TestArgs, TestResult, TestRunner, TestRunnerUtils} import mill.util.Jvm import mill.{Agg, T} import sbt.testing.Status @@ -46,6 +46,25 @@ trait TestModule */ def testFramework: T[String] + def discoveredTestClasses: T[Seq[String]] = T { + val classes = Jvm.inprocess( + runClasspath().map(_.path), + classLoaderOverrideSbtTesting = true, + isolated = true, + closeContextClassLoaderWhenDone = true, + cl => { + val framework = Framework.framework(testFramework())(cl) + val classes = TestRunnerUtils.discoverTests(cl, framework, testClasspath().map(_.path)) + classes.toSeq.map(_._1.getName()) + .map { + case s if s.endsWith("$") => s.dropRight(1) + case s => s + } + } + ) + classes.sorted + } + /** * Discovers and runs the module's tests in a subprocess, reporting the * results to the console. diff --git a/scalalib/test/src/mill/scalalib/TestRunnerTests.scala b/scalalib/test/src/mill/scalalib/TestRunnerTests.scala index fde5271ec78..cfbbd1b232e 100644 --- a/scalalib/test/src/mill/scalalib/TestRunnerTests.scala +++ b/scalalib/test/src/mill/scalalib/TestRunnerTests.scala @@ -78,9 +78,9 @@ object TestRunnerTests extends TestSuite { } override def tests: Tests = Tests { - "TestRunner" - { - "utest" - { - "test case lookup" - workspaceTest(testrunner) { eval => + test("TestRunner") - { + test("utest") - { + test("test case lookup") - workspaceTest(testrunner) { eval => val Right((result, _)) = eval.apply(testrunner.utest.test()) val test = result.asInstanceOf[(String, Seq[mill.testrunner.TestResult])] assert( @@ -88,7 +88,17 @@ object TestRunnerTests extends TestSuite { ) junitReportIn(eval.outPath, "utest").shouldHave(3 -> Status.Success) } - "testOnly" - { + test("discoveredTestClasses") - workspaceTest(testrunner) { eval => + val Right((res, _)) = eval.apply(testrunner.utest.discoveredTestClasses) + val expected = Seq( + "mill.scalalib.BarTests", + "mill.scalalib.FooTests", + "mill.scalalib.FoobarTests" + ) + assert(res == expected) + expected + } + test("testOnly") - { def testOnly(eval: TestEvaluator, args: Seq[String], size: Int) = { val Right((result1, _)) = eval.apply(testrunner.utest.testOnly(args: _*)) val testOnly = result1.asInstanceOf[(String, Seq[mill.testrunner.TestResult])] @@ -144,6 +154,12 @@ object TestRunnerTests extends TestSuite { junitReportIn(eval.outPath, "scalatest").shouldHave(2 -> Status.Success) } } + test("discoveredTestClasses") - workspaceTest(testrunner) { eval => + val Right((res, _)) = eval.apply(testrunner.scalatest.discoveredTestClasses) + val expected = Seq("mill.scalalib.ScalaTestSpec") + assert(res == expected) + expected + } } "ZioTest" - { @@ -154,6 +170,12 @@ object TestRunnerTests extends TestSuite { junitReportIn(eval.outPath, "ziotest").shouldHave(1 -> Status.Success) } } + test("discoveredTestClasses") - workspaceTest(testrunner) { eval => + val Right((res, _)) = eval.apply(testrunner.ziotest.discoveredTestClasses) + val expected = Seq("mill.scalalib.ZioTestSpec") + assert(res == expected) + expected + } } } }