Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Dec 16, 2021
1 parent 4db803d commit 9a0e7b4
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 162 deletions.
7 changes: 0 additions & 7 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,11 @@ object main extends MillModule {
override def compileIvyDeps = Agg(
Deps.scalaReflect(scalaVersion())
)
override def generatedSources = T {
Seq(PathRef(shared.generateCoreSources(T.ctx.dest)))
}
override def testArgs = Seq(
"-DMILL_VERSION=" + publishVersion()
)
override val test = new Tests(implicitly)
class Tests(ctx0: mill.define.Ctx) extends super.Tests(ctx0) {
override def generatedSources = T {
Seq(PathRef(shared.generateCoreTestSources(T.ctx.dest)))
}
}
object api extends MillApiModule {
override def ivyDeps = Agg(
Expand Down Expand Up @@ -301,7 +295,6 @@ object main extends MillModule {
millBinPlatform = millBinPlatform(),
artifacts = T.traverse(dev.moduleDeps)(_.publishSelfDependency)()
)
shared.generateCoreSources(dest)
Seq(PathRef(dest))
}

Expand Down
107 changes: 0 additions & 107 deletions ci/shared.sc
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,6 @@

import $ivy.`org.scalaj::scalaj-http:2.4.2`

def argNames(n: Int) = {
val uppercases = (0 until n).map("T" + _)
val lowercases = uppercases.map(_.toLowerCase)
val typeArgs = uppercases.mkString(", ")
val zipArgs = lowercases.mkString(", ")
(lowercases, uppercases, typeArgs, zipArgs)
}
def generateApplyer(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")

val body = s"mapCtx(zip($zipArgs)) { case (($zipArgs), z) => cb($zipArgs, z) }"
val zipmap = s"def zipMap[$typeArgs, Res]($parameters)(cb: ($typeArgs, Ctx) => Z[Res]) = $body"
val zip = s"def zip[$typeArgs]($parameters): TT[($typeArgs)]"

if (n < 22) List(zipmap, zip).mkString("\n") else zip
}
os.write(
dir / "ApplicativeGenerated.scala",
s"""package mill.define
|import scala.language.higherKinds
|trait ApplyerGenerated[TT[_], Z[_], Ctx] {
| def mapCtx[A, B](a: TT[A])(f: (A, Ctx) => Z[B]): TT[B]
| ${(2 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateTarget(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")
val body = uppercases.zipWithIndex.map { case (t, i) => s"args[$t]($i)" }.mkString(", ")

s"def zip[$typeArgs]($parameters) = makeT[($typeArgs)](Seq($zipArgs), (args: mill.api.Ctx) => ($body))"
}

os.write(
dir / "TaskGenerated.scala",
s"""package mill.define
|import scala.language.higherKinds
|trait TargetGenerated {
| type TT[+X]
| def makeT[X](inputs: Seq[TT[_]], evaluate: mill.api.Ctx => mill.api.Result[X]): TT[X]
| ${(3 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateEval(dir: os.Path) = {
def generate(n: Int) = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: TT[$upper]" }.mkString(", ")
val extract = uppercases.zipWithIndex.map { case (t, i) => s"result($i).asInstanceOf[$t]" }.mkString(", ")

s"""def eval[$typeArgs]($parameters):($typeArgs) = {
| val result = evaluator.evaluate(Agg($zipArgs)).values
| (${extract})
|}
""".stripMargin
}

os.write(
dir / "EvalGenerated.scala",
s"""package mill.main
|import mill.eval.Evaluator
|import mill.define.Task
|import mill.api.Strict.Agg
|class EvalGenerated(evaluator: Evaluator) {
| type TT[+X] = Task[X]
| ${(1 to 22).map(generate).mkString("\n")}
|}""".stripMargin
)
}

def generateApplicativeTest(dir: os.Path) = {
def generate(n: Int): String = {
val (lowercases, uppercases, typeArgs, zipArgs) = argNames(n)
val parameters = lowercases.zip(uppercases).map { case (lower, upper) => s"$lower: Option[$upper]" }.mkString(", ")
val forArgs = lowercases.map(i => s"$i <- $i").mkString("; ")
s"def zip[$typeArgs]($parameters) = { for ($forArgs) yield ($zipArgs) }"
}

os.write(
dir / "ApplicativeTestsGenerated.scala",
s"""package mill.define
|trait OptGenerated {
| ${(2 to 22).map(generate).mkString("\n")}
|}
""".stripMargin
)
}

def unpackZip(zipDest: os.Path, url: String) = {
println(s"Unpacking zip $url into $zipDest")
os.makeDir.all(zipDest)
Expand Down Expand Up @@ -132,19 +38,6 @@ def unpackZip(zipDest: os.Path, url: String) = {
})()
}

@main
def generateCoreSources(p: os.Path) = {
generateApplyer(p)
generateTarget(p)
generateEval(p)
p
}

@main
def generateCoreTestSources(p: os.Path) = {
generateApplicativeTest(p)
p
}


@main
Expand Down
29 changes: 4 additions & 25 deletions main/core/src/mill/define/Applicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,10 @@ object Applicative {

type Id[+T] = T

trait Applyer[W[_], T[_], Z[_], Ctx] extends ApplyerGenerated[T, Z, Ctx] {
trait Applyer[W[_], T[_], Z[_], Ctx] {
def ctx()(implicit c: Ctx) = c
def underlying[A](v: W[A]): T[_]

def zipMap[R]()(cb: Ctx => Z[R]) = mapCtx(zip()) { (_, ctx) => cb(ctx) }
def zipMap[A, R](a: T[A])(f: (A, Ctx) => Z[R]) = mapCtx(a)(f)
def zipMapLong[R](xs: IndexedSeq[T[Any]])(f: (IndexedSeq[Any], Ctx) => Z[R]) = {
var recursiveZipped: T[_] = zip()
for(x <- xs) {
recursiveZipped = zip(recursiveZipped, x)
}

mapCtx(recursiveZipped) { (nested, ctx) =>
var items = List.empty[Any]
var current: Any = nested
while(current != ()){
val (rest, v) = current
current = rest
items = v :: items
}
f(items.toArray[Any], ctx)
}
}
def zip(): T[Unit]
def zip[A](a: T[A]): T[Tuple1[A]]
def traverseCtx[I, R](xs: Seq[W[I]])(f: (IndexedSeq[I], Ctx) => Z[R]): T[R]
}

def impl[M[_], T: c.WeakTypeTag, Ctx: c.WeakTypeTag](c: Context)(t: c.Expr[T]): c.Expr[M[T]] = {
Expand Down Expand Up @@ -96,7 +75,7 @@ object Applicative {
c.internal.setType(tempIdent, t.tpe)
c.internal.setFlag(tempSym, (1L << 44).asInstanceOf[c.universe.FlagSet])
val itemsIdent = Ident(itemsSym)
exprs.append(q"${c.prefix}.underlying($fun)")
exprs.append(q"$fun")
c.typecheck(q"$itemsIdent(${exprs.size-1}).asInstanceOf[${t.tpe}]")
case (t, api)
if t.symbol != null
Expand All @@ -114,7 +93,7 @@ object Applicative {
val itemsBinding = c.internal.valDef(itemsSym)
val callback = c.typecheck(q"{(${itemsBinding}, ${ctxBinding}) => $transformed}")

val res = q"${c.prefix}.zipMapLong(${exprs.toList}){ $callback }"
val res = q"${c.prefix}.traverseCtx[_root_.scala.Any, ${weakTypeOf[T]}](${exprs.toList}){ $callback }"

c.internal.changeOwner(transformed, c.internal.enclosingOwner, callback.symbol)

Expand Down
29 changes: 17 additions & 12 deletions main/core/src/mill/define/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ trait Target[+T] extends NamedTask[T] {
def readWrite: RW[_]
}

object Target extends TargetGenerated with Applicative.Applyer[Task, Task, Result, mill.api.Ctx] {
object Target extends Applicative.Applyer[Task, Task, Result, mill.api.Ctx] {
// convenience
def dest(implicit ctx: mill.api.Ctx.Dest): os.Path = ctx.dest
def log(implicit ctx: mill.api.Ctx.Log): Logger = ctx.log
Expand Down Expand Up @@ -284,21 +284,16 @@ object Target extends TargetGenerated with Applicative.Applyer[Task, Task, Resul
}

type TT[+X] = Task[X]
def makeT[X](inputs0: Seq[TT[_]], evaluate0: mill.api.Ctx => Result[X]) = new Task[X] {
val inputs = inputs0
def evaluate(x: mill.api.Ctx) = evaluate0(x)
}

def underlying[A](v: Task[A]) = v
def mapCtx[A, B](t: Task[A])(f: (A, mill.api.Ctx) => Result[B]) = t.mapDest(f)
def zip() = new Task.Task0(())
def zip[A](a: Task[A]) = a.map(Tuple1(_))
def zip[A, B](a: Task[A], b: Task[B]) = a.zip(b)

def traverse[T, V](source: Seq[T])(f: T => Task[V]) = {
new Task.Sequence[V](source.map(f))
}
def sequence[T](source: Seq[Task[T]]) = new Task.Sequence[T](source)
def traverseCtx[I, R](xs: Seq[Task[I]])(f: (IndexedSeq[I], mill.api.Ctx) => Result[R]): Task[R] = {
new Task.TraverseCtx[I, R](xs, f)
}
}

abstract class NamedTaskImpl[+T](ctx0: mill.define.Ctx, t: Task[T]) extends NamedTask[T] {
Expand Down Expand Up @@ -368,13 +363,23 @@ object Task {

}

class Sequence[+T](inputs0: Seq[Task[T]]) extends Task[Seq[T]] {
class Sequence[+T](inputs0: Seq[Task[T]]) extends Task[IndexedSeq[T]] {
val inputs = inputs0
def evaluate(args: mill.api.Ctx) = {
for (i <- 0 until args.length)
yield args(i).asInstanceOf[T]
yield args(i).asInstanceOf[T]
}
}
class TraverseCtx[+T, V](inputs0: Seq[Task[T]],
f: (IndexedSeq[T], mill.api.Ctx) => Result[V]) extends Task[V] {
val inputs = inputs0
def evaluate(args: mill.api.Ctx) = {
f(
for (i <- 0 until args.length)
yield args(i).asInstanceOf[T],
args
)
}

}
class Mapped[+T, +V](source: Task[T], f: T => V) extends Task[V] {
def evaluate(args: mill.api.Ctx) = f(args(0))
Expand Down
2 changes: 0 additions & 2 deletions main/src/mill/main/ReplApplyHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ class ReplApplyHandler(
}
}

val generatedEval = new EvalGenerated(evaluator)

val millHandlers: PartialFunction[Any, pprint.Tree] = {
case c: Cross[_] =>
ReplApplyHandler.pprintCross(c, evaluator)
Expand Down
11 changes: 6 additions & 5 deletions main/test/src/define/ApplicativeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ import scala.language.experimental.macros

object ApplicativeTests extends TestSuite {
implicit def optionToOpt[T](o: Option[T]): Opt[T] = new Opt(o)
class Opt[T](val self: Option[T]) extends Applicative.Applyable[Option, T]
object Opt extends OptGenerated with Applicative.Applyer[Opt, Option, Applicative.Id, String] {
class Opt[+T](val self: Option[T]) extends Applicative.Applyable[Option, T]
object Opt extends Applicative.Applyer[Opt, Option, Applicative.Id, String] {

val injectedCtx = "helloooo"
def underlying[A](v: Opt[A]) = v.self
def apply[T](t: T): Option[T] = macro Applicative.impl[Option, T, String]

def mapCtx[A, B](a: Option[A])(f: (A, String) => B): Option[B] = a.map(f(_, injectedCtx))
def zip() = Some(())
def zip[A](a: Option[A]) = a.map(Tuple1(_))
def traverseCtx[I, R](xs: Seq[Opt[I]])(f: (IndexedSeq[I], String) => Applicative.Id[R]): Option[R] = {
if (xs.exists(_.self.isEmpty)) None
else Some(f(xs.map(_.self.get).toVector, injectedCtx))
}
}
class Counter {
var value = 0
Expand Down
8 changes: 4 additions & 4 deletions main/test/src/eval/EvaluationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
// cached target
val check = new Checker(build)
assert(leftCount == 0, rightCount == 0)
check(down, expValue = 10101, expEvaled = Agg(up, right, down), extraEvaled = 8)
check(down, expValue = 10101, expEvaled = Agg(up, right, down), extraEvaled = 5)
assert(leftCount == 1, middleCount == 1, rightCount == 1)

// If the upstream `up` doesn't change, the entire block of tasks
Expand All @@ -340,7 +340,7 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
// because tasks have no cached value that can be used. `right`, which
// is a cached Target, does not recompute
up.inputs(0).asInstanceOf[Test].counter += 1
check(down, expValue = 10102, expEvaled = Agg(up, down), extraEvaled = 6)
check(down, expValue = 10102, expEvaled = Agg(up, down), extraEvaled = 4)
assert(leftCount == 2, middleCount == 2, rightCount == 1)

// Running the tasks themselves results in them being recomputed every
Expand All @@ -350,9 +350,9 @@ class EvaluationTests(threadCount: Option[Int]) extends TestSuite {
check(left, expValue = 2, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 2, rightCount == 1)

check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 2, secondRunNoOp = false)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 3, rightCount == 1)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 2, secondRunNoOp = false)
check(middle, expValue = 100, expEvaled = Agg(), extraEvaled = 1, secondRunNoOp = false)
assert(leftCount == 4, middleCount == 4, rightCount == 1)
}
}
Expand Down

0 comments on commit 9a0e7b4

Please sign in to comment.