diff --git a/contrib/playlib/src/mill/playlib/Twirl.scala b/contrib/playlib/src/mill/playlib/Twirl.scala
index 64da53b3114..adf123fdd75 100644
--- a/contrib/playlib/src/mill/playlib/Twirl.scala
+++ b/contrib/playlib/src/mill/playlib/Twirl.scala
@@ -7,21 +7,17 @@ trait Twirl extends TwirlModule with Layout {
override def twirlSources=T.sources{ app() }
- override def twirlAdditionalImports = Seq(
- "_root_.play.twirl.api.TwirlFeatureImports._",
- "_root_.play.twirl.api.TwirlHelperImports._",
- "_root_.play.twirl.api.Html",
- "_root_.play.twirl.api.JavaScript",
- "_root_.play.twirl.api.Txt",
- "_root_.play.twirl.api.Xml",
- "models._",
- "controllers._",
- "play.api.i18n._",
- "views.html._",
- "play.api.templates.PlayMagic._",
- "play.api.mvc._",
- "play.api.data._"
- )
+ override def twirlImports = T {
+ super.twirlImports() ++ Seq(
+ "models._",
+ "controllers._",
+ "play.api.i18n._",
+ "views.html._",
+ "play.api.templates.PlayMagic._",
+ "play.api.mvc._",
+ "play.api.data._"
+ )
+ }
def twirlOutput = T{Seq(compileTwirl().classes)}
diff --git a/contrib/twirllib/src/TwirlModule.scala b/contrib/twirllib/src/TwirlModule.scala
index 56b1f52b7b5..987b2c5ef38 100644
--- a/contrib/twirllib/src/TwirlModule.scala
+++ b/contrib/twirllib/src/TwirlModule.scala
@@ -33,7 +33,11 @@ trait TwirlModule extends mill.Module {
)
}
- def twirlAdditionalImports: Seq[String] = Nil
+ def twirlImports: T[Seq[String]] = T {
+ TwirlWorkerApi.twirlWorker.defaultImports(twirlClasspath().map(_.path))
+ }
+
+ def twirlFormats: T[Map[String, String]] = TwirlWorkerApi.twirlWorker.defaultFormats
def twirlConstructorAnnotations: Seq[String] = Nil
@@ -47,7 +51,8 @@ trait TwirlModule extends mill.Module {
twirlClasspath().map(_.path),
twirlSources().map(_.path),
T.dest,
- twirlAdditionalImports,
+ twirlImports(),
+ twirlFormats(),
twirlConstructorAnnotations,
twirlCodec,
twirlInclusiveDot)
diff --git a/contrib/twirllib/src/TwirlWorker.scala b/contrib/twirllib/src/TwirlWorker.scala
index 19eb47251e7..a516745f6f7 100644
--- a/contrib/twirllib/src/TwirlWorker.scala
+++ b/contrib/twirllib/src/TwirlWorker.scala
@@ -9,12 +9,15 @@ import java.nio.charset.Charset
import mill.api.PathRef
import mill.scalalib.api.CompilationResult
+import scala.jdk.CollectionConverters._
import scala.io.Codec
+import scala.util.matching.Regex
+
class TwirlWorker {
- private var twirlInstanceCache = Option.empty[(Long, TwirlWorkerApi)]
+ private var twirlInstanceCache = Option.empty[(Long, (TwirlWorkerApi, Class[_]))]
- private def twirl(twirlClasspath: Agg[os.Path]) = {
+ private def twirlCompilerAndClass(twirlClasspath: Agg[os.Path]): (TwirlWorkerApi, Class[_]) = {
val classloaderSig = twirlClasspath.map(p => p.toString().hashCode + os.mtime(p)).sum
twirlInstanceCache match {
case Some((sig, instance)) if sig == classloaderSig => instance
@@ -23,7 +26,7 @@ class TwirlWorker {
// Switched to using the java api because of the hack-ish thing going on later.
//
- // * we'll need to construct a collection of additional imports (will need to also include the defaults and add the user-provided additional imports)
+ // * we'll need to construct a collection of imports
// * we'll need to construct a collection of constructor annotations// *
// * the default collection in scala api is a Seq[String]
// * but it is defined in a different classloader (namely in cl)
@@ -57,26 +60,20 @@ class TwirlWorker {
cl.loadClass("scala.io.Codec"),
classOf[Boolean])
- val defaultImportsMethod = twirlCompilerClass.getField("DEFAULT_IMPORTS")
-
- val hashSetConstructor = hashSetClass.getConstructor(cl.loadClass("java.util.Collection"))
-
val instance = new TwirlWorkerApi {
override def compileTwirl(source: File,
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
- additionalImports: Seq[String],
+ imports: Seq[String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean) {
- // val defaultImports = play.japi.twirl.compiler.TwirlCompiler.DEFAULT_IMPORTS()
- // val twirlAdditionalImports = new HashSet(defaultImports)
- // additionalImports.foreach(twirlAdditionalImports.add)
- val defaultImports = defaultImportsMethod.get(null) // unmodifiable collection
- val twirlAdditionalImports = hashSetConstructor.newInstance(defaultImports).asInstanceOf[Object]
- val hashSetAddMethod = twirlAdditionalImports.getClass.getMethod("add", classOf[Object])
- additionalImports.foreach(hashSetAddMethod.invoke(twirlAdditionalImports, _))
+ // val twirlImports = new HashSet()
+ // imports.foreach(twirlImports.add)
+ val twirlImports = hashSetClass.newInstance().asInstanceOf[Object]
+ val hashSetAddMethod = twirlImports.getClass.getMethod("add", classOf[Object])
+ imports.foreach(hashSetAddMethod.invoke(twirlImports, _))
// Codec.apply(Charset.forName(codec.charSet.name()))
val twirlCodec = codecApplyMethod.invoke(null, charsetForNameMethod.invoke(null, codec.charSet.name()))
@@ -102,37 +99,56 @@ class TwirlWorker {
sourceDirectory,
generatedDirectory,
formatterType,
- twirlAdditionalImports,
+ twirlImports,
twirlConstructorAnnotations,
twirlCodec,
Boolean.box(inclusiveDot)
)
}
}
- twirlInstanceCache = Some((classloaderSig, instance))
- instance
+ twirlInstanceCache = Some(classloaderSig -> (instance -> twirlCompilerClass))
+ (instance, twirlCompilerClass)
}
}
+ private def twirl(twirlClasspath: Agg[os.Path]): TwirlWorkerApi =
+ twirlCompilerAndClass(twirlClasspath)._1
+
+ private def twirlClass(twirlClasspath: Agg[os.Path]): Class[_] =
+ twirlCompilerAndClass(twirlClasspath)._2
+
+ def defaultImports(twirlClasspath: Agg[os.Path]): Seq[String] =
+ twirlClass(twirlClasspath).getField("DEFAULT_IMPORTS")
+ .get(null).asInstanceOf[java.util.Set[String]].asScala.toSeq
+
+ def defaultFormats: Map[String, String] =
+ Map(
+ "html" -> "play.twirl.api.HtmlFormat",
+ "xml" -> "play.twirl.api.XmlFormat",
+ "js" -> "play.twirl.api.JavaScriptFormat",
+ "txt" -> "play.twirl.api.TxtFormat")
+
def compile(twirlClasspath: Agg[os.Path],
sourceDirectories: Seq[os.Path],
dest: os.Path,
- additionalImports: Seq[String],
+ imports: Seq[String],
+ formats: Map[String, String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean)
(implicit ctx: mill.api.Ctx): mill.api.Result[CompilationResult] = {
val compiler = twirl(twirlClasspath)
+ val formatExtsRegex = formats.keys.map(Regex.quote).mkString("|")
def compileTwirlDir(inputDir: os.Path) {
- os.walk(inputDir).filter(_.last.matches(".*.scala.(html|xml|js|txt)"))
+ os.walk(inputDir).filter(_.last.matches(s".*.scala.($formatExtsRegex)"))
.foreach { template =>
- val extFormat = twirlExtensionFormat(template.last)
+ val extClass = twirlExtensionClass(template.last, formats)
compiler.compileTwirl(template.toIO,
inputDir.toIO,
dest.toIO,
- s"play.twirl.api.$extFormat",
- additionalImports,
+ extClass,
+ imports,
constructorAnnotations,
codec,
inclusiveDot
@@ -148,11 +164,10 @@ class TwirlWorker {
mill.api.Result.Success(CompilationResult(zincFile, PathRef(classesDir)))
}
- private def twirlExtensionFormat(name: String) =
- if (name.endsWith("html")) "HtmlFormat"
- else if (name.endsWith("xml")) "XmlFormat"
- else if (name.endsWith("js")) "JavaScriptFormat"
- else "TxtFormat"
+ private def twirlExtensionClass(name: String, formats: Map[String, String]) =
+ formats.collectFirst { case (ext, klass) if name.endsWith(ext) => klass }.getOrElse {
+ throw new IllegalStateException(s"Unknown twirl extension for file: $name. Known extensions: ${formats.keys.mkString(", ")}")
+ }
}
trait TwirlWorkerApi {
@@ -160,7 +175,7 @@ trait TwirlWorkerApi {
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
- additionalImports: Seq[String],
+ imports: Seq[String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean)
diff --git a/contrib/twirllib/test/resources/hello-world-inclusive-dot/core/views/test.scala.svg b/contrib/twirllib/test/resources/hello-world-inclusive-dot/core/views/test.scala.svg
new file mode 100644
index 00000000000..714efc7ef0b
--- /dev/null
+++ b/contrib/twirllib/test/resources/hello-world-inclusive-dot/core/views/test.scala.svg
@@ -0,0 +1 @@
+
diff --git a/contrib/twirllib/test/resources/hello-world/core/views/test.scala.svg b/contrib/twirllib/test/resources/hello-world/core/views/test.scala.svg
new file mode 100644
index 00000000000..714efc7ef0b
--- /dev/null
+++ b/contrib/twirllib/test/resources/hello-world/core/views/test.scala.svg
@@ -0,0 +1 @@
+
diff --git a/contrib/twirllib/test/src/HelloWorldTests.scala b/contrib/twirllib/test/src/HelloWorldTests.scala
index 22c05886199..7f62916a1af 100644
--- a/contrib/twirllib/test/src/HelloWorldTests.scala
+++ b/contrib/twirllib/test/src/HelloWorldTests.scala
@@ -21,7 +21,8 @@ object HelloWorldTests extends TestSuite {
object HelloWorld extends HelloBase {
object core extends HelloWorldModule {
- override def twirlAdditionalImports: Seq[String] = testAdditionalImports
+ override def twirlImports = super.twirlImports() ++ testAdditionalImports
+ override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat")
override def twirlConstructorAnnotations: Seq[String] = testConstructorAnnotations
}
@@ -31,6 +32,7 @@ object HelloWorldTests extends TestSuite {
object core extends HelloWorldModule {
override def twirlInclusiveDot: Boolean = true
+ override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat")
}
}
@@ -51,8 +53,9 @@ object HelloWorldTests extends TestSuite {
}
def compileClassfiles: Seq[os.RelPath] = Seq[os.RelPath](
- os.rel / "hello.template.scala",
- os.rel / "wrapper.template.scala"
+ os.rel / 'html / "hello.template.scala",
+ os.rel / 'html / "wrapper.template.scala",
+ os.rel / 'svg / "test.template.scala"
)
def expectedDefaultImports: Seq[String] = Seq(
@@ -92,14 +95,14 @@ object HelloWorldTests extends TestSuite {
val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map(
- eval.outPath / 'core / 'compileTwirl / 'dest / 'html / _
+ eval.outPath / 'core / 'compileTwirl / 'dest / _
)
assert(
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
- outputFiles.size == 2,
+ outputFiles.size == 3,
evalCount > 0,
outputFiles.forall { p =>
val lines = os.read.lines(p).map(_.trim)
@@ -124,7 +127,7 @@ object HelloWorldTests extends TestSuite {
val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map( name =>
- eval.outPath / 'core / 'compileTwirl / 'dest / 'html / name.toString().replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
+ eval.outPath / 'core / 'compileTwirl / 'dest / name / os.RelPath.up / name.last.replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
)
println(s"outputFiles: $outputFiles")
@@ -133,7 +136,7 @@ object HelloWorldTests extends TestSuite {
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
- outputFiles.size == 2,
+ outputFiles.size == 3,
evalCount > 0,
outputFiles.filter(_.toString().contains("hello.template.scala")).forall { p =>
val lines = os.read.lines(p).map(_.trim)
diff --git a/docs/pages/9 - Contrib Modules.md b/docs/pages/9 - Contrib Modules.md
index 16ab85ba379..80fea16e8c9 100644
--- a/docs/pages/9 - Contrib Modules.md
+++ b/docs/pages/9 - Contrib Modules.md
@@ -177,7 +177,7 @@ object project extends BuildInfo {
The contrib.bsp module was created in order to integrate the Mill build tool
with IntelliJ IDEA via the Build Server Protocol (BSP). It implements most of
-the server side functionality described in BSP, and can therefore connect to a
+the server side functionality described in BSP, and can therefore connect to a
BSP client, including the one behind IntelliJ IDEA. This allows a lot of mill
tasks to be executed from the IDE.
@@ -818,14 +818,6 @@ object app extends ScalaModule with TwirlModule {
}
```
-### Twirl configuration options
-
-* `def twirlVersion: T[String]` (mandatory) - the version of the twirl compiler to use, like "1.3.15"
-* `def twirlAdditionalImports: Seq[String] = Nil` - the additional imports that will be added by twirl compiler to the top of all templates
-* `def twirlConstructorAnnotations: Seq[String] = Nil` - annotations added to the generated classes' constructors (note it only applies to templates with `@this(...)` constructors)
-* `def twirlCodec = Codec(Properties.sourceEncoding)` - the codec used to generate the files (the default is the same sbt plugin uses)
-* `def twirlInclusiveDot: Boolean = false`
-
### Details
The following filesystem layout is expected:
@@ -833,9 +825,9 @@ The following filesystem layout is expected:
```text
build.sc
app/
- views/
- view1.scala.html
- view2.scala.html
+ views/
+ view1.scala.html
+ view2.scala.html
```
`TwirlModule` adds the `compileTwirl` task to the module:
@@ -861,7 +853,28 @@ object app extends ScalaModule with TwirlModule {
}
```
-To add additional imports to all of the twirl templates:
+### Twirl configuration options
+
+#### `def twirlVersion: T[String]`
+
+Mandatory - the version of the twirl compiler to use, like "1.3.15".
+
+#### `def twirlImports: T[Seq[String]]`
+
+The imports that will be added by the twirl compiler to the top of all templates, defaults to [twirl's default imports](https://github.com/playframework/twirl/blob/1.5.0/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala#L166-L173):
+
+```scala
+Seq(
+ "_root_.play.twirl.api.TwirlFeatureImports._",
+ "_root_.play.twirl.api.TwirlHelperImports._",
+ "_root_.play.twirl.api.Html",
+ "_root_.play.twirl.api.JavaScript",
+ "_root_.play.twirl.api.Txt",
+ "_root_.play.twirl.api.Xml"
+)
+```
+
+To add additional imports to all of the twirl templates, override `twirlImports` in your build:
```scala
// build.sc
@@ -871,32 +884,62 @@ import $ivy.`com.lihaoyi::mill-contrib-twirllib:$MILL_VERSION`, mill.twirllib._
object app extends ScalaModule with TwirlModule {
def twirlVersion = "1.3.15"
- override def twirlAdditionalImports = Seq("my.additional.stuff._", "my.other.stuff._")
+ override def twirlImports = super.twirlImports() ++ Seq("my.additional.stuff._", "my.other.stuff._")
def generatedSources = T{ Seq(compileTwirl().classes) }
}
+
+// out.template.scala
+@import _root_.play.twirl.api.TwirlFeatureImports._
+// ...
+@import _root_.play.twirl.api.Xml
+@import my.additional.stuff._
+@import my.other.stuff._
```
-as the result all templates will get this line at the top:
+To exclude the default imports, simply override `twirlImports` without calling `super`:
```scala
-@import "my.additional.stuff._"
-@import "my.other.stuff._"
+// build.sc
+object app extends ScalaModule with TwirlModule {
+ // ...
+ override def twirlImports = Seq("my.stuff._")
+}
+
+// out.template.scala
+@import my.stuff._
```
-Besides that, twirl compiler has default imports, at the moment these:
+#### `def twirlFormats: Map[String, String]`
+
+A mapping of file extensions to class names that will be compiled by twirl, e.g. `Map("html" -> "play.twirl.api.HtmlFormat")`.
+By default `html`, `xml`, `js`, and `txt` files will be compiled using the corresponding [twirl format](https://github.com/playframework/twirl/blob/1.5.0/api/shared/src/main/scala/play/twirl/api/Formats.scala).
+
+To add additional formats, override `twirlFormats` in your build:
```scala
-Seq(
- "_root_.play.twirl.api.TwirlFeatureImports._",
- "_root_.play.twirl.api.TwirlHelperImports._",
- "_root_.play.twirl.api.Html",
- "_root_.play.twirl.api.JavaScript",
- "_root_.play.twirl.api.Txt",
- "_root_.play.twirl.api.Xml"
-)
+// build.sc
+import mill.scalalib._
+
+import $ivy.`com.lihaoyi::mill-contrib-twirllib:$MILL_VERSION`, mill.twirllib._
+
+object app extends ScalaModule with TwirlModule {
+ def twirlVersion = "1.3.15"
+ override def twirlFormats = super.twirlFormats() + Map("svg" -> "play.twirl.api.HtmlFormat")
+ def generatedSources = T{ Seq(compileTwirl().classes) }
+}
```
-These imports will always be added to every template. You don't need to list them if you override `twirlAdditionalImports`.
+#### `def twirlConstructorAnnotations: Seq[String] = Nil`
+
+Annotations added to the generated classes' constructors (note it only applies to templates with `@this(...)` constructors).
+
+#### `def twirlCodec = Codec(Properties.sourceEncoding)`
+
+The codec used to generate the files (the default is the same sbt plugin uses).
+
+#### `def twirlInclusiveDot: Boolean = false`
+
+Whether the twirl parser should parse with an inclusive dot.
### Example
There's an [example project](https://github.com/lihaoyi/cask/tree/master/example/twirl)