Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak macro to allow more than 22 inputs #1623

Merged
merged 6 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,11 @@ object main extends MillModule {
override def compileIvyDeps = Agg(
Deps.scalaReflect(scalaVersion())
)
override def generatedSources = T {
Seq(PathRef(shared.generateCoreSources(T.ctx.dest)))
}
override def testArgs = Seq(
"-DMILL_VERSION=" + publishVersion()
)
override val test = new Tests(implicitly)
class Tests(ctx0: mill.define.Ctx) extends super.Tests(ctx0) {
override def generatedSources = T {
Seq(PathRef(shared.generateCoreTestSources(T.ctx.dest)))
}
}
object api extends MillApiModule {
override def ivyDeps = Agg(
Expand Down Expand Up @@ -301,7 +295,6 @@ object main extends MillModule {
millBinPlatform = millBinPlatform(),
artifacts = T.traverse(dev.moduleDeps)(_.publishSelfDependency)()
)
shared.generateCoreSources(dest)
Seq(PathRef(dest))
}

Expand Down
107 changes: 0 additions & 107 deletions ci/shared.sc
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,6 @@

import $ivy.`org.scalaj::scalaj-http:2.4.2`

def argNames(n: Int) = {
val uppercases = (0 until n).map("T" + _)
val lowercases = uppercases.map(_.toLowerCase)
val typeArgs = uppercases.mkString(", ")
val zipArgs = lowercases.mkString(", ")
(lowercases, uppercases, typeArgs, zipArgs)
}
def generateApplyer(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")

val body = s"mapCtx(zip($zipArgs)) { case (($zipArgs), z) => cb($zipArgs, z) }"
val zipmap = s"def zipMap[$typeArgs, Res]($parameters)(cb: ($typeArgs, Ctx) => Z[Res]) = $body"
val zip = s"def zip[$typeArgs]($parameters): TT[($typeArgs)]"

if (n < 22) List(zipmap, zip).mkString("\n") else zip
}
os.write(
dir / "ApplicativeGenerated.scala",
s"""package mill.define
|import scala.language.higherKinds
|trait ApplyerGenerated[TT[_], Z[_], Ctx] {
| def mapCtx[A, B](a: TT[A])(f: (A, Ctx) => Z[B]): TT[B]
| ${(2 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateTarget(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")
val body = uppercases.zipWithIndex.map { case (t, i) => s"args[$t]($i)" }.mkString(", ")

s"def zip[$typeArgs]($parameters) = makeT[($typeArgs)](Seq($zipArgs), (args: mill.api.Ctx) => ($body))"
}

os.write(
dir / "TaskGenerated.scala",
s"""package mill.define
|import scala.language.higherKinds
|trait TargetGenerated {
| type TT[+X]
| def makeT[X](inputs: Seq[TT[_]], evaluate: mill.api.Ctx => mill.api.Result[X]): TT[X]
| ${(3 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateEval(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")
val extract = uppercases.zipWithIndex.map { case (t, i) => s"result($i).asInstanceOf[$t]" }.mkString(", ")

s"""def eval[$typeArgs]($parameters):($typeArgs) = {
| val result = evaluator.evaluate(Agg($zipArgs)).values
| (${extract})
|}
""".stripMargin
}

os.write(
dir / "EvalGenerated.scala",
s"""package mill.main
|import mill.eval.Evaluator
|import mill.define.Task
|import mill.api.Strict.Agg
|class EvalGenerated(evaluator: Evaluator) {
| type TT[+X] = Task[X]
| ${(1 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateApplicativeTest(dir: os.Path) = {
def generate(n: Int): String = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: Option[$upper]" }.mkString(", ")
val forArgs = lowercases.map(i => s"$i <- $i").mkString("; ")
s"def zip[$typeArgs]($parameters) = { for ($forArgs) yield ($zipArgs) }"
}

os.write(
dir / "ApplicativeTestsGenerated.scala",
s"""package mill.define
|trait OptGenerated {
| ${(2 to 22).map(generate).mkString("\n")}
|}
""".stripMargin
)
}

def unpackZip(zipDest: os.Path, url: String) = {
println(s"Unpacking zip $url into $zipDest")
os.makeDir.all(zipDest)
Expand Down Expand Up @@ -132,19 +38,6 @@ def unpackZip(zipDest: os.Path, url: String) = {
})()
}

@main
def generateCoreSources(p: os.Path) = {
generateApplyer(p)
generateTarget(p)
generateEval(p)
p
}

@main
def generateCoreTestSources(p: os.Path) = {
generateApplicativeTest(p)
p
}


@main
Expand Down
27 changes: 13 additions & 14 deletions main/core/src/mill/define/Applicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,9 @@ object Applicative {

type Id[+T] = T

trait Applyer[W[_], T[_], Z[_], Ctx] extends ApplyerGenerated[T, Z, Ctx] {
trait Applyer[W[_], T[_], Z[_], Ctx] {
def ctx()(implicit c: Ctx) = c
def underlying[A](v: W[A]): T[_]

def zipMap[R]()(cb: Ctx => Z[R]) = mapCtx(zip()) { (_, ctx) => cb(ctx) }
def zipMap[A, R](a: T[A])(f: (A, Ctx) => Z[R]) = mapCtx(a)(f)
def zip(): T[Unit]
def zip[A](a: T[A]): T[Tuple1[A]]
def traverseCtx[I, R](xs: Seq[W[I]])(f: (IndexedSeq[I], Ctx) => Z[R]): T[R]
}

def impl[M[_], T: c.WeakTypeTag, Ctx: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[M[T]] = {
Expand All @@ -45,9 +40,13 @@ object Applicative {
import c.universe._
def rec(t: Tree): Iterator[c.Tree] = Iterator(t) ++ t.children.flatMap(rec(_))

val bound = collection.mutable.Buffer.empty[(c.Tree, ValDef)]
val exprs = collection.mutable.Buffer.empty[c.Tree]
val targetApplySym = typeOf[Applyable[Nothing, _]].member(TermName("apply"))

val itemsName = c.freshName(TermName("items"))
val itemsSym = c.internal.newTermSymbol(c.internal.enclosingOwner, itemsName)
c.internal.setFlag(itemsSym, (1L << 44).asInstanceOf[c.universe.FlagSet])
c.internal.setInfo(itemsSym, typeOf[Seq[Any]])
// Derived from @olafurpg's
// https://gist.github.com/olafurpg/596d62f87bf3360a29488b725fbc7608
val defs = rec(t).filter(_.isDef).map(_.symbol).toSet
Expand All @@ -74,8 +73,9 @@ object Applicative {
val tempIdent = Ident(tempSym)
c.internal.setType(tempIdent, t.tpe)
c.internal.setFlag(tempSym, (1L << 44).asInstanceOf[c.universe.FlagSet])
bound.append((q"${c.prefix}.underlying($fun)", c.internal.valDef(tempSym)))
tempIdent
val itemsIdent = Ident(itemsSym)
exprs.append(q"$fun")
c.typecheck(q"$itemsIdent(${exprs.size-1}).asInstanceOf[${t.tpe}]")
case (t, api)
if t.symbol != null
&& t.symbol.annotations.exists(_.tree.tpe =:= typeOf[mill.api.Ctx.ImplicitStub]) =>
Expand All @@ -87,13 +87,12 @@ object Applicative {
case (t, api) => api.default(t)
}

val (exprs, bindings) = bound.unzip

val ctxBinding = c.internal.valDef(ctxSym)

val callback = c.typecheck(q"(..$bindings, $ctxBinding) => $transformed ")
val itemsBinding = c.internal.valDef(itemsSym)
val callback = c.typecheck(q"{(${itemsBinding}, ${ctxBinding}) => $transformed}")

val res = q"${c.prefix}.zipMap(..$exprs){ $callback }"
val res = q"${c.prefix}.traverseCtx[_root_.scala.Any, ${weakTypeOf[T]}](${exprs.toList}){ $callback }"

c.internal.changeOwner(transformed, c.internal.enclosingOwner, callback.symbol)

Expand Down
31 changes: 16 additions & 15 deletions main/core/src/mill/define/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ trait Target[+T] extends NamedTask[T] {
def readWrite: RW[_]
}

object Target extends TargetGenerated with Applicative.Applyer[Task, Task, Result, mill.api.Ctx] {
object Target extends Applicative.Applyer[Task, Task, Result, mill.api.Ctx] {
// convenience
def dest(implicit ctx: mill.api.Ctx.Dest): os.Path = ctx.dest
def log(implicit ctx: mill.api.Ctx.Log): Logger = ctx.log
Expand Down Expand Up @@ -283,22 +283,13 @@ object Target extends TargetGenerated with Applicative.Applyer[Task, Task, Resul
)
}

type TT[+X] = Task[X]
def makeT[X](inputs0: Seq[TT[_]], evaluate0: mill.api.Ctx => Result[X]) = new Task[X] {
val inputs = inputs0
def evaluate(x: mill.api.Ctx) = evaluate0(x)
}

def underlying[A](v: Task[A]) = v
def mapCtx[A, B](t: Task[A])(f: (A, mill.api.Ctx) => Result[B]) = t.mapDest(f)
def zip() = new Task.Task0(())
def zip[A](a: Task[A]) = a.map(Tuple1(_))
def zip[A, B](a: Task[A], b: Task[B]) = a.zip(b)

def traverse[T, V](source: Seq[T])(f: T => Task[V]) = {
new Task.Sequence[V](source.map(f))
}
def sequence[T](source: Seq[Task[T]]) = new Task.Sequence[T](source)
def traverseCtx[I, R](xs: Seq[Task[I]])(f: (IndexedSeq[I], mill.api.Ctx) => Result[R]): Task[R] = {
new Task.TraverseCtx[I, R](xs, f)
}
}

abstract class NamedTaskImpl[+T](ctx0: mill.define.Ctx, t: Task[T]) extends NamedTask[T] {
Expand Down Expand Up @@ -372,9 +363,19 @@ object Task {
val inputs = inputs0
def evaluate(args: mill.api.Ctx) = {
for (i <- 0 until args.length)
yield args(i).asInstanceOf[T]
yield args(i).asInstanceOf[T]
}
}
class TraverseCtx[+T, V](inputs0: Seq[Task[T]],
f: (IndexedSeq[T], mill.api.Ctx) => Result[V]) extends Task[V] {
val inputs = inputs0
def evaluate(args: mill.api.Ctx) = {
f(
for (i <- 0 until args.length)
yield args(i).asInstanceOf[T],
args
)
}

}
class Mapped[+T, +V](source: Task[T], f: T => V) extends Task[V] {
def evaluate(args: mill.api.Ctx) = f(args(0))
Expand Down
2 changes: 0 additions & 2 deletions main/src/mill/main/ReplApplyHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ class ReplApplyHandler(
}
}

val generatedEval = new EvalGenerated(evaluator)

val millHandlers: PartialFunction[Any, pprint.Tree] = {
case c: Cross[_] =>
ReplApplyHandler.pprintCross(c, evaluator)
Expand Down
39 changes: 33 additions & 6 deletions main/test/src/define/ApplicativeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ import scala.language.experimental.macros

object ApplicativeTests extends TestSuite {
implicit def optionToOpt[T](o: Option[T]): Opt[T] = new Opt(o)
class Opt[T](val self: Option[T]) extends Applicative.Applyable[Option, T]
object Opt extends OptGenerated with Applicative.Applyer[Opt, Option, Applicative.Id, String] {
class Opt[+T](val self: Option[T]) extends Applicative.Applyable[Option, T]
object Opt extends Applicative.Applyer[Opt, Option, Applicative.Id, String] {

val injectedCtx = "helloooo"
def underlying[A](v: Opt[A]) = v.self
def apply[T](t: T): Option[T] = macro Applicative.impl[Option, T, String]

def mapCtx[A, B](a: Option[A])(f: (A, String) => B): Option[B] = a.map(f(_, injectedCtx))
def zip() = Some(())
def zip[A](a: Option[A]) = a.map(Tuple1(_))
def traverseCtx[I, R](xs: Seq[Opt[I]])(f: (IndexedSeq[I], String) => Applicative.Id[R]): Option[R] = {
if (xs.exists(_.self.isEmpty)) None
else Some(f(xs.map(_.self.get).toVector, injectedCtx))
}
}
class Counter {
var value = 0
Expand All @@ -39,6 +39,33 @@ object ApplicativeTests extends TestSuite {
"twoSomes" - assert(Opt(Some("lol ")() + Some("hello")()) == Some("lol hello"))
"singleNone" - assert(Opt("lol " + None()) == None)
"twoNones" - assert(Opt("lol " + None() + None()) == None)
"moreThan22" - {
assert(
Opt(
"lol " +
None() + None() + None() + None() + None() +
None() + None() + None() + None() + Some(" world")() +
None() + None() + None() + None() + None() +
None() + None() + None() + None() + None() +
None() + None() + None() + None() + Some(" moo")()
) == None
)
assert(
Opt(
"lol " +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")()
) == Some("lol abcdeabcdeabcdeabcdeabcdeabcdeabcdeabcdeabcdeabcde")
)
}
}
"context" - {
assert(Opt(Opt.ctx() + Some("World")()) == Some("hellooooWorld"))
Expand Down
8 changes: 4 additions & 4 deletions main/test/src/eval/EvaluationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
// cached target
val check = new Checker(build)
assert(leftCount == 0, rightCount == 0)
check(down, expValue = 10101, expEvaled = Agg(up, right, down), extraEvaled = 8)
check(down, expValue = 10101, expEvaled = Agg(up, right, down), extraEvaled = 5)
assert(leftCount == 1, middleCount == 1, rightCount == 1)

// If the upstream `up` doesn't change, the entire block of tasks
Expand All @@ -340,7 +340,7 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
// because tasks have no cached value that can be used. `right`, which
// is a cached Target, does not recompute
up.inputs(0).asInstanceOf[Test].counter += 1
check(down, expValue = 10102, expEvaled = Agg(up, down), extraEvaled = 6)
check(down, expValue = 10102, expEvaled = Agg(up, down), extraEvaled = 4)
assert(leftCount == 2, middleCount == 2, rightCount == 1)

// Running the tasks themselves results in them being recomputed every
Expand All @@ -350,9 +350,9 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
check(left, expValue = 2, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 2, rightCount == 1)

check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 2, secondRunNoOp = false)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 3, rightCount == 1)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 2, secondRunNoOp = false)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 4, rightCount == 1)
}
}
Expand Down