Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Twirl enhancements #952

Merged
merged 3 commits into from
Aug 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions contrib/playlib/src/mill/playlib/Twirl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@ trait Twirl extends TwirlModule with Layout {

override def twirlSources=T.sources{ app() }

override def twirlAdditionalImports = Seq(
"_root_.play.twirl.api.TwirlFeatureImports._",
"_root_.play.twirl.api.TwirlHelperImports._",
"_root_.play.twirl.api.Html",
"_root_.play.twirl.api.JavaScript",
"_root_.play.twirl.api.Txt",
"_root_.play.twirl.api.Xml",
"models._",
"controllers._",
"play.api.i18n._",
"views.html._",
"play.api.templates.PlayMagic._",
"play.api.mvc._",
"play.api.data._"
)
override def twirlImports = T {
super.twirlImports() ++ Seq(
"models._",
"controllers._",
"play.api.i18n._",
"views.html._",
"play.api.templates.PlayMagic._",
"play.api.mvc._",
"play.api.data._"
)
}

def twirlOutput = T{Seq(compileTwirl().classes)}

Expand Down
9 changes: 7 additions & 2 deletions contrib/twirllib/src/TwirlModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ trait TwirlModule extends mill.Module {
)
}

def twirlAdditionalImports: Seq[String] = Nil
def twirlImports: T[Seq[String]] = T {
TwirlWorkerApi.twirlWorker.defaultImports(twirlClasspath().map(_.path))
}

def twirlFormats: T[Map[String, String]] = TwirlWorkerApi.twirlWorker.defaultFormats

def twirlConstructorAnnotations: Seq[String] = Nil

Expand All @@ -47,7 +51,8 @@ trait TwirlModule extends mill.Module {
twirlClasspath().map(_.path),
twirlSources().map(_.path),
T.dest,
twirlAdditionalImports,
twirlImports(),
twirlFormats(),
twirlConstructorAnnotations,
twirlCodec,
twirlInclusiveDot)
Expand Down
73 changes: 44 additions & 29 deletions contrib/twirllib/src/TwirlWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ import java.nio.charset.Charset
import mill.api.PathRef
import mill.scalalib.api.CompilationResult

import scala.jdk.CollectionConverters._
import scala.io.Codec
import scala.util.matching.Regex

class TwirlWorker {

private var twirlInstanceCache = Option.empty[(Long, TwirlWorkerApi)]
private var twirlInstanceCache = Option.empty[(Long, (TwirlWorkerApi, Class[_]))]

private def twirl(twirlClasspath: Agg[os.Path]) = {
private def twirlCompilerAndClass(twirlClasspath: Agg[os.Path]): (TwirlWorkerApi, Class[_]) = {
val classloaderSig = twirlClasspath.map(p => p.toString().hashCode + os.mtime(p)).sum
twirlInstanceCache match {
case Some((sig, instance)) if sig == classloaderSig => instance
Expand All @@ -23,7 +26,7 @@ class TwirlWorker {

// Switched to using the java api because of the hack-ish thing going on later.
//
// * we'll need to construct a collection of additional imports (will need to also include the defaults and add the user-provided additional imports)
// * we'll need to construct a collection of imports
// * we'll need to construct a collection of constructor annotations// *
// * the default collection in scala api is a Seq[String]
// * but it is defined in a different classloader (namely in cl)
Expand Down Expand Up @@ -57,26 +60,20 @@ class TwirlWorker {
cl.loadClass("scala.io.Codec"),
classOf[Boolean])

val defaultImportsMethod = twirlCompilerClass.getField("DEFAULT_IMPORTS")

val hashSetConstructor = hashSetClass.getConstructor(cl.loadClass("java.util.Collection"))

val instance = new TwirlWorkerApi {
override def compileTwirl(source: File,
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
additionalImports: Seq[String],
imports: Seq[String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean) {
// val defaultImports = play.japi.twirl.compiler.TwirlCompiler.DEFAULT_IMPORTS()
// val twirlAdditionalImports = new HashSet(defaultImports)
// additionalImports.foreach(twirlAdditionalImports.add)
val defaultImports = defaultImportsMethod.get(null) // unmodifiable collection
val twirlAdditionalImports = hashSetConstructor.newInstance(defaultImports).asInstanceOf[Object]
val hashSetAddMethod = twirlAdditionalImports.getClass.getMethod("add", classOf[Object])
additionalImports.foreach(hashSetAddMethod.invoke(twirlAdditionalImports, _))
// val twirlImports = new HashSet()
// imports.foreach(twirlImports.add)
val twirlImports = hashSetClass.newInstance().asInstanceOf[Object]
val hashSetAddMethod = twirlImports.getClass.getMethod("add", classOf[Object])
imports.foreach(hashSetAddMethod.invoke(twirlImports, _))

// Codec.apply(Charset.forName(codec.charSet.name()))
val twirlCodec = codecApplyMethod.invoke(null, charsetForNameMethod.invoke(null, codec.charSet.name()))
Expand All @@ -102,37 +99,56 @@ class TwirlWorker {
sourceDirectory,
generatedDirectory,
formatterType,
twirlAdditionalImports,
twirlImports,
twirlConstructorAnnotations,
twirlCodec,
Boolean.box(inclusiveDot)
)
}
}
twirlInstanceCache = Some((classloaderSig, instance))
instance
twirlInstanceCache = Some(classloaderSig -> (instance -> twirlCompilerClass))
(instance, twirlCompilerClass)
}
}

private def twirl(twirlClasspath: Agg[os.Path]): TwirlWorkerApi =
twirlCompilerAndClass(twirlClasspath)._1

private def twirlClass(twirlClasspath: Agg[os.Path]): Class[_] =
twirlCompilerAndClass(twirlClasspath)._2

def defaultImports(twirlClasspath: Agg[os.Path]): Seq[String] =
twirlClass(twirlClasspath).getField("DEFAULT_IMPORTS")
.get(null).asInstanceOf[java.util.Set[String]].asScala.toSeq

def defaultFormats: Map[String, String] =
Map(
"html" -> "play.twirl.api.HtmlFormat",
"xml" -> "play.twirl.api.XmlFormat",
"js" -> "play.twirl.api.JavaScriptFormat",
"txt" -> "play.twirl.api.TxtFormat")

def compile(twirlClasspath: Agg[os.Path],
sourceDirectories: Seq[os.Path],
dest: os.Path,
additionalImports: Seq[String],
imports: Seq[String],
formats: Map[String, String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean)
(implicit ctx: mill.api.Ctx): mill.api.Result[CompilationResult] = {
val compiler = twirl(twirlClasspath)
val formatExtsRegex = formats.keys.map(Regex.quote).mkString("|")

def compileTwirlDir(inputDir: os.Path) {
os.walk(inputDir).filter(_.last.matches(".*.scala.(html|xml|js|txt)"))
os.walk(inputDir).filter(_.last.matches(s".*.scala.($formatExtsRegex)"))
.foreach { template =>
val extFormat = twirlExtensionFormat(template.last)
val extClass = twirlExtensionClass(template.last, formats)
compiler.compileTwirl(template.toIO,
inputDir.toIO,
dest.toIO,
s"play.twirl.api.$extFormat",
additionalImports,
extClass,
imports,
constructorAnnotations,
codec,
inclusiveDot
Expand All @@ -148,19 +164,18 @@ class TwirlWorker {
mill.api.Result.Success(CompilationResult(zincFile, PathRef(classesDir)))
}

private def twirlExtensionFormat(name: String) =
if (name.endsWith("html")) "HtmlFormat"
else if (name.endsWith("xml")) "XmlFormat"
else if (name.endsWith("js")) "JavaScriptFormat"
else "TxtFormat"
private def twirlExtensionClass(name: String, formats: Map[String, String]) =
formats.collectFirst { case (ext, klass) if name.endsWith(ext) => klass }.getOrElse {
throw new IllegalStateException(s"Unknown twirl extension for file: $name. Known extensions: ${formats.keys.mkString(", ")}")
}
}

trait TwirlWorkerApi {
def compileTwirl(source: File,
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
additionalImports: Seq[String],
imports: Seq[String],
constructorAnnotations: Seq[String],
codec: Codec,
inclusiveDot: Boolean)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 10 additions & 7 deletions contrib/twirllib/test/src/HelloWorldTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ object HelloWorldTests extends TestSuite {
object HelloWorld extends HelloBase {

object core extends HelloWorldModule {
override def twirlAdditionalImports: Seq[String] = testAdditionalImports
override def twirlImports = super.twirlImports() ++ testAdditionalImports
override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat")
override def twirlConstructorAnnotations: Seq[String] = testConstructorAnnotations
}

Expand All @@ -31,6 +32,7 @@ object HelloWorldTests extends TestSuite {

object core extends HelloWorldModule {
override def twirlInclusiveDot: Boolean = true
override def twirlFormats = super.twirlFormats() ++ Map("svg" -> "play.twirl.api.HtmlFormat")
}

}
Expand All @@ -51,8 +53,9 @@ object HelloWorldTests extends TestSuite {
}

def compileClassfiles: Seq[os.RelPath] = Seq[os.RelPath](
os.rel / "hello.template.scala",
os.rel / "wrapper.template.scala"
os.rel / 'html / "hello.template.scala",
os.rel / 'html / "wrapper.template.scala",
os.rel / 'svg / "test.template.scala"
)

def expectedDefaultImports: Seq[String] = Seq(
Expand Down Expand Up @@ -92,14 +95,14 @@ object HelloWorldTests extends TestSuite {

val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map(
eval.outPath / 'core / 'compileTwirl / 'dest / 'html / _
eval.outPath / 'core / 'compileTwirl / 'dest / _
)

assert(
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
outputFiles.size == 2,
outputFiles.size == 3,
evalCount > 0,
outputFiles.forall { p =>
val lines = os.read.lines(p).map(_.trim)
Expand All @@ -124,7 +127,7 @@ object HelloWorldTests extends TestSuite {

val outputFiles = os.walk(result.classes.path).filter(_.last.endsWith(".scala"))
val expectedClassfiles = compileClassfiles.map( name =>
eval.outPath / 'core / 'compileTwirl / 'dest / 'html / name.toString().replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
eval.outPath / 'core / 'compileTwirl / 'dest / name / os.RelPath.up / name.last.replace(".template.scala", "$$TwirlInclusiveDot.template.scala")
)

println(s"outputFiles: $outputFiles")
Expand All @@ -133,7 +136,7 @@ object HelloWorldTests extends TestSuite {
result.classes.path == eval.outPath / 'core / 'compileTwirl / 'dest,
outputFiles.nonEmpty,
outputFiles.forall(expectedClassfiles.contains),
outputFiles.size == 2,
outputFiles.size == 3,
evalCount > 0,
outputFiles.filter(_.toString().contains("hello.template.scala")).forall { p =>
val lines = os.read.lines(p).map(_.trim)
Expand Down
Loading