Skip to content

Commit

Permalink
Tweak macro to allow more than 22 inputs (#1623)
Browse files Browse the repository at this point in the history
Fixes #910

Basically we avoid using the statically typed `zipMap` methods which cap out at 22, and instead use an untyped `zipMapLong` method that works with `IndexedSeq[T[Any]]` and `IndexedSeq[Any]`s. This could conceivably result in runtime errors if the macro has bugs, but assuming the macro is correct then it's as safe for the user as the existing macro is

By removing the old implementation, we are able to cut out a considerable amount of complexity from the codebase while still maintaining the existing semantics. Performance is probably better after, since we're just removing a layer of indirection: previously we ended up going through this dynamic `Seq[Any]` anyway, we just had a statically-typed facade wrapping it

Tested with a unit test, also tested manually in the `scratch` folder on the following build:

```scala
def t1 = T{ 1 }
def t2 = T{ 2 }
def t3 = T{ 3 }
def t4 = T{ 4 }
def t5 = T{ 5 }
def t6 = T{ 6 }
def t7 = T{ 7 }
def t8 = T{ 8 }
def t9 = T{ 9 }
def t10 = T{ 10 }
def t11 = T{ 11 }
def t12 = T{ 12 }
def t13 = T{ 13 }
def t14 = T{ 14 }
def t15 = T{ 15 }
def t16 = T{ 16 }
def t17 = T{ 17 }
def t18 = T{ 18 }
def t19 = T{ 19 }
def t20 = T{ 20 }
def t21 = T{ 21 }
def t22 = T{ 22 }
def t23 = T{ 23 }

def sum = T{
  t1() + t2() + t3() + t4() + t5() + t6() + t7() + t8() + t9() + t10() + t11() + t12() +
  t13() + t14() + t15() + t16() + t17() + t18() + t19() + t20() + t21() + t22() + t23()
}
```

This passes on this PR, fails on master.
  • Loading branch information
lihaoyi authored Dec 16, 2021
1 parent 578166e commit bb37e70
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 155 deletions.
7 changes: 0 additions & 7 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,11 @@ object main extends MillModule {
override def compileIvyDeps = Agg(
Deps.scalaReflect(scalaVersion())
)
override def generatedSources = T {
Seq(PathRef(shared.generateCoreSources(T.ctx.dest)))
}
override def testArgs = Seq(
"-DMILL_VERSION=" + publishVersion()
)
override val test = new Tests(implicitly)
class Tests(ctx0: mill.define.Ctx) extends super.Tests(ctx0) {
override def generatedSources = T {
Seq(PathRef(shared.generateCoreTestSources(T.ctx.dest)))
}
}
object api extends MillApiModule {
override def ivyDeps = Agg(
Expand Down Expand Up @@ -301,7 +295,6 @@ object main extends MillModule {
millBinPlatform = millBinPlatform(),
artifacts = T.traverse(dev.moduleDeps)(_.publishSelfDependency)()
)
shared.generateCoreSources(dest)
Seq(PathRef(dest))
}

Expand Down
107 changes: 0 additions & 107 deletions ci/shared.sc
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,6 @@

import $ivy.`org.scalaj::scalaj-http:2.4.2`

def argNames(n: Int) = {
val uppercases = (0 until n).map("T" + _)
val lowercases = uppercases.map(_.toLowerCase)
val typeArgs = uppercases.mkString(", ")
val zipArgs = lowercases.mkString(", ")
(lowercases, uppercases, typeArgs, zipArgs)
}
def generateApplyer(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")

val body = s"mapCtx(zip($zipArgs)) { case (($zipArgs), z) => cb($zipArgs, z) }"
val zipmap = s"def zipMap[$typeArgs, Res]($parameters)(cb: ($typeArgs, Ctx) => Z[Res]) = $body"
val zip = s"def zip[$typeArgs]($parameters): TT[($typeArgs)]"

if (n < 22) List(zipmap, zip).mkString("\n") else zip
}
os.write(
dir / "ApplicativeGenerated.scala",
s"""package mill.define
|import scala.language.higherKinds
|trait ApplyerGenerated[TT[_], Z[_], Ctx] {
| def mapCtx[A, B](a: TT[A])(f: (A, Ctx) => Z[B]): TT[B]
| ${(2 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateTarget(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")
val body = uppercases.zipWithIndex.map { case (t, i) => s"args[$t]($i)" }.mkString(", ")

s"def zip[$typeArgs]($parameters) = makeT[($typeArgs)](Seq($zipArgs), (args: mill.api.Ctx) => ($body))"
}

os.write(
dir / "TaskGenerated.scala",
s"""package mill.define
|import scala.language.higherKinds
|trait TargetGenerated {
| type TT[+X]
| def makeT[X](inputs: Seq[TT[_]], evaluate: mill.api.Ctx => mill.api.Result[X]): TT[X]
| ${(3 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateEval(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")
val extract = uppercases.zipWithIndex.map { case (t, i) => s"result($i).asInstanceOf[$t]" }.mkString(", ")

s"""def eval[$typeArgs]($parameters):($typeArgs) = {
| val result = evaluator.evaluate(Agg($zipArgs)).values
| (${extract})
|}
""".stripMargin
}

os.write(
dir / "EvalGenerated.scala",
s"""package mill.main
|import mill.eval.Evaluator
|import mill.define.Task
|import mill.api.Strict.Agg
|class EvalGenerated(evaluator: Evaluator) {
| type TT[+X] = Task[X]
| ${(1 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateApplicativeTest(dir: os.Path) = {
def generate(n: Int): String = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: Option[$upper]" }.mkString(", ")
val forArgs = lowercases.map(i => s"$i <- $i").mkString("; ")
s"def zip[$typeArgs]($parameters) = { for ($forArgs) yield ($zipArgs) }"
}

os.write(
dir / "ApplicativeTestsGenerated.scala",
s"""package mill.define
|trait OptGenerated {
| ${(2 to 22).map(generate).mkString("\n")}
|}
""".stripMargin
)
}

def unpackZip(zipDest: os.Path, url: String) = {
println(s"Unpacking zip $url into $zipDest")
os.makeDir.all(zipDest)
Expand Down Expand Up @@ -132,19 +38,6 @@ def unpackZip(zipDest: os.Path, url: String) = {
})()
}

@main
def generateCoreSources(p: os.Path) = {
generateApplyer(p)
generateTarget(p)
generateEval(p)
p
}

@main
def generateCoreTestSources(p: os.Path) = {
generateApplicativeTest(p)
p
}


@main
Expand Down
27 changes: 13 additions & 14 deletions main/core/src/mill/define/Applicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,9 @@ object Applicative {

type Id[+T] = T

trait Applyer[W[_], T[_], Z[_], Ctx] extends ApplyerGenerated[T, Z, Ctx] {
trait Applyer[W[_], T[_], Z[_], Ctx] {
def ctx()(implicit c: Ctx) = c
def underlying[A](v: W[A]): T[_]

def zipMap[R]()(cb: Ctx => Z[R]) = mapCtx(zip()) { (_, ctx) => cb(ctx) }
def zipMap[A, R](a: T[A])(f: (A, Ctx) => Z[R]) = mapCtx(a)(f)
def zip(): T[Unit]
def zip[A](a: T[A]): T[Tuple1[A]]
def traverseCtx[I, R](xs: Seq[W[I]])(f: (IndexedSeq[I], Ctx) => Z[R]): T[R]
}

def impl[M[_], T: c.WeakTypeTag, Ctx: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[M[T]] = {
Expand All @@ -45,9 +40,13 @@ object Applicative {
import c.universe._
def rec(t: Tree): Iterator[c.Tree] = Iterator(t) ++ t.children.flatMap(rec(_))

val bound = collection.mutable.Buffer.empty[(c.Tree, ValDef)]
val exprs = collection.mutable.Buffer.empty[c.Tree]
val targetApplySym = typeOf[Applyable[Nothing, _]].member(TermName("apply"))

val itemsName = c.freshName(TermName("items"))
val itemsSym = c.internal.newTermSymbol(c.internal.enclosingOwner, itemsName)
c.internal.setFlag(itemsSym, (1L << 44).asInstanceOf[c.universe.FlagSet])
c.internal.setInfo(itemsSym, typeOf[Seq[Any]])
// Derived from @olafurpg's
// https://gist.github.com/olafurpg/596d62f87bf3360a29488b725fbc7608
val defs = rec(t).filter(_.isDef).map(_.symbol).toSet
Expand All @@ -74,8 +73,9 @@ object Applicative {
val tempIdent = Ident(tempSym)
c.internal.setType(tempIdent, t.tpe)
c.internal.setFlag(tempSym, (1L << 44).asInstanceOf[c.universe.FlagSet])
bound.append((q"${c.prefix}.underlying($fun)", c.internal.valDef(tempSym)))
tempIdent
val itemsIdent = Ident(itemsSym)
exprs.append(q"$fun")
c.typecheck(q"$itemsIdent(${exprs.size-1}).asInstanceOf[${t.tpe}]")
case (t, api)
if t.symbol != null
&& t.symbol.annotations.exists(_.tree.tpe =:= typeOf[mill.api.Ctx.ImplicitStub]) =>
Expand All @@ -87,13 +87,12 @@ object Applicative {
case (t, api) => api.default(t)
}

val (exprs, bindings) = bound.unzip

val ctxBinding = c.internal.valDef(ctxSym)

val callback = c.typecheck(q"(..$bindings, $ctxBinding) => $transformed ")
val itemsBinding = c.internal.valDef(itemsSym)
val callback = c.typecheck(q"{(${itemsBinding}, ${ctxBinding}) => $transformed}")

val res = q"${c.prefix}.zipMap(..$exprs){ $callback }"
val res = q"${c.prefix}.traverseCtx[_root_.scala.Any, ${weakTypeOf[T]}](${exprs.toList}){ $callback }"

c.internal.changeOwner(transformed, c.internal.enclosingOwner, callback.symbol)

Expand Down
31 changes: 16 additions & 15 deletions main/core/src/mill/define/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ trait Target[+T] extends NamedTask[T] {
def readWrite: RW[_]
}

object Target extends TargetGenerated with Applicative.Applyer[Task, Task, Result, mill.api.Ctx] {
object Target extends Applicative.Applyer[Task, Task, Result, mill.api.Ctx] {
// convenience
def dest(implicit ctx: mill.api.Ctx.Dest): os.Path = ctx.dest
def log(implicit ctx: mill.api.Ctx.Log): Logger = ctx.log
Expand Down Expand Up @@ -283,22 +283,13 @@ object Target extends TargetGenerated with Applicative.Applyer[Task, Task, Resul
)
}

type TT[+X] = Task[X]
def makeT[X](inputs0: Seq[TT[_]], evaluate0: mill.api.Ctx => Result[X]) = new Task[X] {
val inputs = inputs0
def evaluate(x: mill.api.Ctx) = evaluate0(x)
}

def underlying[A](v: Task[A]) = v
def mapCtx[A, B](t: Task[A])(f: (A, mill.api.Ctx) => Result[B]) = t.mapDest(f)
def zip() = new Task.Task0(())
def zip[A](a: Task[A]) = a.map(Tuple1(_))
def zip[A, B](a: Task[A], b: Task[B]) = a.zip(b)

def traverse[T, V](source: Seq[T])(f: T => Task[V]) = {
new Task.Sequence[V](source.map(f))
}
def sequence[T](source: Seq[Task[T]]) = new Task.Sequence[T](source)
def traverseCtx[I, R](xs: Seq[Task[I]])(f: (IndexedSeq[I], mill.api.Ctx) => Result[R]): Task[R] = {
new Task.TraverseCtx[I, R](xs, f)
}
}

abstract class NamedTaskImpl[+T](ctx0: mill.define.Ctx, t: Task[T]) extends NamedTask[T] {
Expand Down Expand Up @@ -372,9 +363,19 @@ object Task {
val inputs = inputs0
def evaluate(args: mill.api.Ctx) = {
for (i <- 0 until args.length)
yield args(i).asInstanceOf[T]
yield args(i).asInstanceOf[T]
}
}
class TraverseCtx[+T, V](inputs0: Seq[Task[T]],
f: (IndexedSeq[T], mill.api.Ctx) => Result[V]) extends Task[V] {
val inputs = inputs0
def evaluate(args: mill.api.Ctx) = {
f(
for (i <- 0 until args.length)
yield args(i).asInstanceOf[T],
args
)
}

}
class Mapped[+T, +V](source: Task[T], f: T => V) extends Task[V] {
def evaluate(args: mill.api.Ctx) = f(args(0))
Expand Down
2 changes: 0 additions & 2 deletions main/src/mill/main/ReplApplyHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ class ReplApplyHandler(
}
}

val generatedEval = new EvalGenerated(evaluator)

val millHandlers: PartialFunction[Any, pprint.Tree] = {
case c: Cross[_] =>
ReplApplyHandler.pprintCross(c, evaluator)
Expand Down
39 changes: 33 additions & 6 deletions main/test/src/define/ApplicativeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ import scala.language.experimental.macros

object ApplicativeTests extends TestSuite {
implicit def optionToOpt[T](o: Option[T]): Opt[T] = new Opt(o)
class Opt[T](val self: Option[T]) extends Applicative.Applyable[Option, T]
object Opt extends OptGenerated with Applicative.Applyer[Opt, Option, Applicative.Id, String] {
class Opt[+T](val self: Option[T]) extends Applicative.Applyable[Option, T]
object Opt extends Applicative.Applyer[Opt, Option, Applicative.Id, String] {

val injectedCtx = "helloooo"
def underlying[A](v: Opt[A]) = v.self
def apply[T](t: T): Option[T] = macro Applicative.impl[Option, T, String]

def mapCtx[A, B](a: Option[A])(f: (A, String) => B): Option[B] = a.map(f(_, injectedCtx))
def zip() = Some(())
def zip[A](a: Option[A]) = a.map(Tuple1(_))
def traverseCtx[I, R](xs: Seq[Opt[I]])(f: (IndexedSeq[I], String) => Applicative.Id[R]): Option[R] = {
if (xs.exists(_.self.isEmpty)) None
else Some(f(xs.map(_.self.get).toVector, injectedCtx))
}
}
class Counter {
var value = 0
Expand All @@ -39,6 +39,33 @@ object ApplicativeTests extends TestSuite {
"twoSomes" - assert(Opt(Some("lol ")() + Some("hello")()) == Some("lol hello"))
"singleNone" - assert(Opt("lol " + None()) == None)
"twoNones" - assert(Opt("lol " + None() + None()) == None)
"moreThan22" - {
assert(
Opt(
"lol " +
None() + None() + None() + None() + None() +
None() + None() + None() + None() + Some(" world")() +
None() + None() + None() + None() + None() +
None() + None() + None() + None() + None() +
None() + None() + None() + None() + Some(" moo")()
) == None
)
assert(
Opt(
"lol " +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")() +
Some("a")() + Some("b")() + Some("c")() + Some("d")() + Some("e")()
) == Some("lol abcdeabcdeabcdeabcdeabcdeabcdeabcdeabcdeabcdeabcde")
)
}
}
"context" - {
assert(Opt(Opt.ctx() + Some("World")()) == Some("hellooooWorld"))
Expand Down
8 changes: 4 additions & 4 deletions main/test/src/eval/EvaluationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
// cached target
val check = new Checker(build)
assert(leftCount == 0, rightCount == 0)
check(down, expValue = 10101, expEvaled = Agg(up, right, down), extraEvaled = 8)
check(down, expValue = 10101, expEvaled = Agg(up, right, down), extraEvaled = 5)
assert(leftCount == 1, middleCount == 1, rightCount == 1)

// If the upstream `up` doesn't change, the entire block of tasks
Expand All @@ -340,7 +340,7 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
// because tasks have no cached value that can be used. `right`, which
// is a cached Target, does not recompute
up.inputs(0).asInstanceOf[Test].counter += 1
check(down, expValue = 10102, expEvaled = Agg(up, down), extraEvaled = 6)
check(down, expValue = 10102, expEvaled = Agg(up, down), extraEvaled = 4)
assert(leftCount == 2, middleCount == 2, rightCount == 1)

// Running the tasks themselves results in them being recomputed every
Expand All @@ -350,9 +350,9 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
check(left, expValue = 2, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 2, rightCount == 1)

check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 2, secondRunNoOp = false)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 3, rightCount == 1)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 2, secondRunNoOp = false)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 4, rightCount == 1)
}
}
Expand Down

0 comments on commit bb37e70

Please sign in to comment.