diff --git a/contrib/playlib/src/mill/playlib/Twirl.scala b/contrib/playlib/src/mill/playlib/Twirl.scala index 64da53b3114..adf123fdd75 100644 --- a/contrib/playlib/src/mill/playlib/Twirl.scala +++ b/contrib/playlib/src/mill/playlib/Twirl.scala @@ -7,21 +7,17 @@ trait Twirl extends TwirlModule with Layout { override def twirlSources=T.sources{ app() } - override def twirlAdditionalImports = Seq( - "_root_.play.twirl.api.TwirlFeatureImports._", - "_root_.play.twirl.api.TwirlHelperImports._", - "_root_.play.twirl.api.Html", - "_root_.play.twirl.api.JavaScript", - "_root_.play.twirl.api.Txt", - "_root_.play.twirl.api.Xml", - "models._", - "controllers._", - "play.api.i18n._", - "views.html._", - "play.api.templates.PlayMagic._", - "play.api.mvc._", - "play.api.data._" - ) + override def twirlImports = T { + super.twirlImports() ++ Seq( + "models._", + "controllers._", + "play.api.i18n._", + "views.html._", + "play.api.templates.PlayMagic._", + "play.api.mvc._", + "play.api.data._" + ) + } def twirlOutput = T{Seq(compileTwirl().classes)} diff --git a/contrib/twirllib/src/TwirlModule.scala b/contrib/twirllib/src/TwirlModule.scala index 56b1f52b7b5..987b2c5ef38 100644 --- a/contrib/twirllib/src/TwirlModule.scala +++ b/contrib/twirllib/src/TwirlModule.scala @@ -33,7 +33,11 @@ trait TwirlModule extends mill.Module { ) } - def twirlAdditionalImports: Seq[String] = Nil + def twirlImports: T[Seq[String]] = T { + TwirlWorkerApi.twirlWorker.defaultImports(twirlClasspath().map(_.path)) + } + + def twirlFormats: T[Map[String, String]] = TwirlWorkerApi.twirlWorker.defaultFormats def twirlConstructorAnnotations: Seq[String] = Nil @@ -47,7 +51,8 @@ trait TwirlModule extends mill.Module { twirlClasspath().map(_.path), twirlSources().map(_.path), T.dest, - twirlAdditionalImports, + twirlImports(), + twirlFormats(), twirlConstructorAnnotations, twirlCodec, twirlInclusiveDot) diff --git a/contrib/twirllib/src/TwirlWorker.scala b/contrib/twirllib/src/TwirlWorker.scala index 19eb47251e7..a516745f6f7 100644 --- a/contrib/twirllib/src/TwirlWorker.scala +++ b/contrib/twirllib/src/TwirlWorker.scala @@ -9,12 +9,15 @@ import java.nio.charset.Charset import mill.api.PathRef import mill.scalalib.api.CompilationResult +import scala.jdk.CollectionConverters._ import scala.io.Codec +import scala.util.matching.Regex + class TwirlWorker { - private var twirlInstanceCache = Option.empty[(Long, TwirlWorkerApi)] + private var twirlInstanceCache = Option.empty[(Long, (TwirlWorkerApi, Class[_]))] - private def twirl(twirlClasspath: Agg[os.Path]) = { + private def twirlCompilerAndClass(twirlClasspath: Agg[os.Path]): (TwirlWorkerApi, Class[_]) = { val classloaderSig = twirlClasspath.map(p => p.toString().hashCode + os.mtime(p)).sum twirlInstanceCache match { case Some((sig, instance)) if sig == classloaderSig => instance @@ -23,7 +26,7 @@ class TwirlWorker { // Switched to using the java api because of the hack-ish thing going on later. // - // * we'll need to construct a collection of additional imports (will need to also include the defaults and add the user-provided additional imports) + // * we'll need to construct a collection of imports // * we'll need to construct a collection of constructor annotations// * // * the default collection in scala api is a Seq[String] // * but it is defined in a different classloader (namely in cl) @@ -57,26 +60,20 @@ class TwirlWorker { cl.loadClass("scala.io.Codec"), classOf[Boolean]) - val defaultImportsMethod = twirlCompilerClass.getField("DEFAULT_IMPORTS") - - val hashSetConstructor = hashSetClass.getConstructor(cl.loadClass("java.util.Collection")) - val instance = new TwirlWorkerApi { override def compileTwirl(source: File, sourceDirectory: File, generatedDirectory: File, formatterType: String, - additionalImports: Seq[String], + imports: Seq[String], constructorAnnotations: Seq[String], codec: Codec, inclusiveDot: Boolean) { - // val defaultImports = play.japi.twirl.compiler.TwirlCompiler.DEFAULT_IMPORTS() - // val twirlAdditionalImports = new HashSet(defaultImports) - // additionalImports.foreach(twirlAdditionalImports.add) - val defaultImports = defaultImportsMethod.get(null) // unmodifiable collection - val twirlAdditionalImports = hashSetConstructor.newInstance(defaultImports).asInstanceOf[Object] - val hashSetAddMethod = twirlAdditionalImports.getClass.getMethod("add", classOf[Object]) - additionalImports.foreach(hashSetAddMethod.invoke(twirlAdditionalImports, _)) + // val twirlImports = new HashSet() + // imports.foreach(twirlImports.add) + val twirlImports = hashSetClass.newInstance().asInstanceOf[Object] + val hashSetAddMethod = twirlImports.getClass.getMethod("add", classOf[Object]) + imports.foreach(hashSetAddMethod.invoke(twirlImports, _)) // Codec.apply(Charset.forName(codec.charSet.name())) val twirlCodec = codecApplyMethod.invoke(null, charsetForNameMethod.invoke(null, codec.charSet.name())) @@ -102,37 +99,56 @@ class TwirlWorker { sourceDirectory, generatedDirectory, formatterType, - twirlAdditionalImports, + twirlImports, twirlConstructorAnnotations, twirlCodec, Boolean.box(inclusiveDot) ) } } - twirlInstanceCache = Some((classloaderSig, instance)) - instance + twirlInstanceCache = Some(classloaderSig -> (instance -> twirlCompilerClass)) + (instance, twirlCompilerClass) } } + private def twirl(twirlClasspath: Agg[os.Path]): TwirlWorkerApi = + twirlCompilerAndClass(twirlClasspath)._1 + + private def twirlClass(twirlClasspath: Agg[os.Path]): Class[_] = + twirlCompilerAndClass(twirlClasspath)._2 + + def defaultImports(twirlClasspath: Agg[os.Path]): Seq[String] = + twirlClass(twirlClasspath).getField("DEFAULT_IMPORTS") + .get(null).asInstanceOf[java.util.Set[String]].asScala.toSeq + + def defaultFormats: Map[String, String] = + Map( + "html" -> "play.twirl.api.HtmlFormat", + "xml" -> "play.twirl.api.XmlFormat", + "js" -> "play.twirl.api.JavaScriptFormat", + "txt" -> "play.twirl.api.TxtFormat") + def compile(twirlClasspath: Agg[os.Path], sourceDirectories: Seq[os.Path], dest: os.Path, - additionalImports: Seq[String], + imports: Seq[String], + formats: Map[String, String], constructorAnnotations: Seq[String], codec: Codec, inclusiveDot: Boolean) (implicit ctx: mill.api.Ctx): mill.api.Result[CompilationResult] = { val compiler = twirl(twirlClasspath) + val formatExtsRegex = formats.keys.map(Regex.quote).mkString("|") def compileTwirlDir(inputDir: os.Path) { - os.walk(inputDir).filter(_.last.matches(".*.scala.(html|xml|js|txt)")) + os.walk(inputDir).filter(_.last.matches(s".*.scala.($formatExtsRegex)")) .foreach { template => - val extFormat = twirlExtensionFormat(template.last) + val extClass = twirlExtensionClass(template.last, formats) compiler.compileTwirl(template.toIO, inputDir.toIO, dest.toIO, - s"play.twirl.api.$extFormat", - additionalImports, + extClass, + imports, constructorAnnotations, codec, inclusiveDot @@ -148,11 +164,10 @@ class TwirlWorker { mill.api.Result.Success(CompilationResult(zincFile, PathRef(classesDir))) } - private def twirlExtensionFormat(name: String) = - if (name.endsWith("html")) "HtmlFormat" - else if (name.endsWith("xml")) "XmlFormat" - else if (name.endsWith("js")) "JavaScriptFormat" - else "TxtFormat" + private def twirlExtensionClass(name: String, formats: Map[String, String]) = + formats.collectFirst { case (ext, klass) if name.endsWith(ext) => klass }.getOrElse { + throw new IllegalStateException(s"Unknown twirl extension for file: $name. Known extensions: ${formats.keys.mkString(", ")}") + } } trait TwirlWorkerApi { @@ -160,7 +175,7 @@ trait TwirlWorkerApi { sourceDirectory: File, generatedDirectory: File, formatterType: String, - additionalImports: Seq[String], + imports: Seq[String], constructorAnnotations: Seq[String], codec: Codec, inclusiveDot: Boolean) diff --git a/contrib/twirllib/test/resources/hello-world-inclusive-dot/core/views/test.scala.svg b/contrib/twirllib/test/resources/hello-world-inclusive-dot/core/views/test.scala.svg new file mode 100644 index 00000000000..714efc7ef0b --- /dev/null +++ b/contrib/twirllib/test/resources/hello-world-inclusive-dot/core/views/test.scala.svg @@ -0,0 +1 @@ + diff --git a/contrib/twirllib/test/resources/hello-world/core/views/test.scala.svg b/contrib/twirllib/test/resources/hello-world/core/views/test.scala.svg new file mode 100644 index 00000000000..714efc7ef0b --- /dev/null +++ b/contrib/twirllib/test/resources/hello-world/core/views/test.scala.svg @@ -0,0 +1 @@ + diff --git a/contrib/twirllib/test/src/HelloWorldTests.scala b/contrib/twirllib/test/src/HelloWorldTests.scala index 22c05886199..7f62916a1af 100644 --- a/contrib/twirllib/test/src/HelloWorldTests.scala +++ b/contrib/twirllib/test/src/HelloWorldTests.scala @@ -21,7 +21,8 @@ object HelloWorldTests extends TestSuite { object HelloWorld extends HelloBase { object core extends HelloWorldModule { - override def twirlAdditionalImports: Seq[String] = testAdditionalImports + override def twirlImports = super.twirlImports() ++ testAdditionalImports + override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat") override def twirlConstructorAnnotations: Seq[String] = testConstructorAnnotations } @@ -31,6 +32,7 @@ object HelloWorldTests extends TestSuite { object core extends HelloWorldModule { override def twirlInclusiveDot: Boolean = true + override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat") } } @@ -51,8 +53,9 @@ object HelloWorldTests extends TestSuite { } def compileClassfiles: Seq[os.RelPath] = Seq[os.RelPath]( - os.rel / "hello.template.scala", - os.rel / "wrapper.template.scala" + os.rel / 'html / "hello.template.scala", + os.rel / 'html / "wrapper.template.scala", + os.rel / 'svg / "test.template.scala" ) def expectedDefaultImports: Seq[String] = Seq( @@ -92,14 +95,14 @@ object HelloWorldTests extends TestSuite { val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala")) val expectedClassfiles = compileClassfiles.map( - eval.outPath / 'core / 'compileTwirl / 'dest / 'html / _ + eval.outPath / 'core / 'compileTwirl / 'dest / _ ) assert( result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest, outputFiles.nonEmpty, outputFiles.forall(expectedClassfiles.contains), - outputFiles.size == 2, + outputFiles.size == 3, evalCount > 0, outputFiles.forall { p => val lines = os.read.lines(p).map(_.trim) @@ -124,7 +127,7 @@ object HelloWorldTests extends TestSuite { val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala")) val expectedClassfiles = compileClassfiles.map( name => - eval.outPath / 'core / 'compileTwirl / 'dest / 'html / name.toString().replace(".template.scala", "$$TwirlInclusiveDot.template.scala") + eval.outPath / 'core / 'compileTwirl / 'dest / name / os.RelPath.up / name.last.replace(".template.scala", "$$TwirlInclusiveDot.template.scala") ) println(s"outputFiles: $outputFiles") @@ -133,7 +136,7 @@ object HelloWorldTests extends TestSuite { result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest, outputFiles.nonEmpty, outputFiles.forall(expectedClassfiles.contains), - outputFiles.size == 2, + outputFiles.size == 3, evalCount > 0, outputFiles.filter(_.toString().contains("hello.template.scala")).forall { p => val lines = os.read.lines(p).map(_.trim) diff --git a/docs/pages/9 - Contrib Modules.md b/docs/pages/9 - Contrib Modules.md index 16ab85ba379..80fea16e8c9 100644 --- a/docs/pages/9 - Contrib Modules.md +++ b/docs/pages/9 - Contrib Modules.md @@ -177,7 +177,7 @@ object project extends BuildInfo { The contrib.bsp module was created in order to integrate the Mill build tool with IntelliJ IDEA via the Build Server Protocol (BSP). It implements most of -the server side functionality described in BSP, and can therefore connect to a +the server side functionality described in BSP, and can therefore connect to a BSP client, including the one behind IntelliJ IDEA. This allows a lot of mill tasks to be executed from the IDE. @@ -818,14 +818,6 @@ object app extends ScalaModule with TwirlModule { } ``` -### Twirl configuration options - -* `def twirlVersion: T[String]` (mandatory) - the version of the twirl compiler to use, like "1.3.15" -* `def twirlAdditionalImports: Seq[String] = Nil` - the additional imports that will be added by twirl compiler to the top of all templates -* `def twirlConstructorAnnotations: Seq[String] = Nil` - annotations added to the generated classes' constructors (note it only applies to templates with `@this(...)` constructors) -* `def twirlCodec = Codec(Properties.sourceEncoding)` - the codec used to generate the files (the default is the same sbt plugin uses) -* `def twirlInclusiveDot: Boolean = false` - ### Details The following filesystem layout is expected: @@ -833,9 +825,9 @@ The following filesystem layout is expected: ```text build.sc app/ - views/ - view1.scala.html - view2.scala.html + views/ + view1.scala.html + view2.scala.html ``` `TwirlModule` adds the `compileTwirl` task to the module: @@ -861,7 +853,28 @@ object app extends ScalaModule with TwirlModule { } ``` -To add additional imports to all of the twirl templates: +### Twirl configuration options + +#### `def twirlVersion: T[String]` + +Mandatory - the version of the twirl compiler to use, like "1.3.15". + +#### `def twirlImports: T[Seq[String]]` + +The imports that will be added by the twirl compiler to the top of all templates, defaults to [twirl's default imports](https://github.com/playframework/twirl/blob/1.5.0/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala#L166-L173): + +```scala +Seq( + "_root_.play.twirl.api.TwirlFeatureImports._", + "_root_.play.twirl.api.TwirlHelperImports._", + "_root_.play.twirl.api.Html", + "_root_.play.twirl.api.JavaScript", + "_root_.play.twirl.api.Txt", + "_root_.play.twirl.api.Xml" +) +``` + +To add additional imports to all of the twirl templates, override `twirlImports` in your build: ```scala // build.sc @@ -871,32 +884,62 @@ import $ivy.`com.lihaoyi::mill-contrib-twirllib:$MILL_VERSION`, mill.twirllib._ object app extends ScalaModule with TwirlModule { def twirlVersion = "1.3.15" - override def twirlAdditionalImports = Seq("my.additional.stuff._", "my.other.stuff._") + override def twirlImports = super.twirlImports() ++ Seq("my.additional.stuff._", "my.other.stuff._") def generatedSources = T{ Seq(compileTwirl().classes) } } + +// out.template.scala +@import _root_.play.twirl.api.TwirlFeatureImports._ +// ... +@import _root_.play.twirl.api.Xml +@import my.additional.stuff._ +@import my.other.stuff._ ``` -as the result all templates will get this line at the top: +To exclude the default imports, simply override `twirlImports` without calling `super`: ```scala -@import "my.additional.stuff._" -@import "my.other.stuff._" +// build.sc +object app extends ScalaModule with TwirlModule { + // ... + override def twirlImports = Seq("my.stuff._") +} + +// out.template.scala +@import my.stuff._ ``` -Besides that, twirl compiler has default imports, at the moment these: +#### `def twirlFormats: Map[String, String]` + +A mapping of file extensions to class names that will be compiled by twirl, e.g. `Map("html" -> "play.twirl.api.HtmlFormat")`. +By default `html`, `xml`, `js`, and `txt` files will be compiled using the corresponding [twirl format](https://github.com/playframework/twirl/blob/1.5.0/api/shared/src/main/scala/play/twirl/api/Formats.scala). + +To add additional formats, override `twirlFormats` in your build: ```scala -Seq( - "_root_.play.twirl.api.TwirlFeatureImports._", - "_root_.play.twirl.api.TwirlHelperImports._", - "_root_.play.twirl.api.Html", - "_root_.play.twirl.api.JavaScript", - "_root_.play.twirl.api.Txt", - "_root_.play.twirl.api.Xml" -) +// build.sc +import mill.scalalib._ + +import $ivy.`com.lihaoyi::mill-contrib-twirllib:$MILL_VERSION`, mill.twirllib._ + +object app extends ScalaModule with TwirlModule { + def twirlVersion = "1.3.15" + override def twirlFormats = super.twirlFormats() + Map("svg" -> "play.twirl.api.HtmlFormat") + def generatedSources = T{ Seq(compileTwirl().classes) } +} ``` -These imports will always be added to every template. You don't need to list them if you override `twirlAdditionalImports`. +#### `def twirlConstructorAnnotations: Seq[String] = Nil` + +Annotations added to the generated classes' constructors (note it only applies to templates with `@this(...)` constructors). + +#### `def twirlCodec = Codec(Properties.sourceEncoding)` + +The codec used to generate the files (the default is the same sbt plugin uses). + +#### `def twirlInclusiveDot: Boolean = false` + +Whether the twirl parser should parse with an inclusive dot. ### Example There's an [example project](https://github.com/lihaoyi/cask/tree/master/example/twirl)