Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support type selectors in Resolve #2997

Merged
merged 20 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions main/resolve/src/mill/resolve/ParseArgs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,23 @@ object ParseArgs {
}

private def selector[_p: P]: P[(Option[Segments], Segments)] = {
def ident2 = P(CharsWhileIn("a-zA-Z0-9_\\-.")).!
def segment = P(mill.define.Reflect.ident).map(Segment.Label)
def crossSegment = P("[" ~ ident2.rep(1, sep = ",") ~ "]").map(Segment.Cross)
def wildcard = P("__" | "_")
def label = mill.define.Reflect.ident

def typeQualifier(simple: Boolean) = {
val maxSegments = if (simple) 0 else Int.MaxValue
P(("^" | "!").? ~~ label ~~ ("." ~~ label).rep(max = maxSegments)).!
}

def typePattern(simple: Boolean) = P(wildcard ~~ (":" ~~ typeQualifier(simple)).rep(1)).!

def segment0(simple: Boolean) = P(typePattern(simple) | label).map(Segment.Label)
def segment = P("(" ~ segment0(false) ~ ")" | segment0(true))

def identCross = P(CharsWhileIn("a-zA-Z0-9_\\-.")).!
def crossSegment = P("[" ~ identCross.rep(1, sep = ",") ~ "]").map(Segment.Cross)
def defaultCrossSegment = P("[]").map(_ => Segment.Cross(Seq()))

def simpleQuery = P(segment ~ ("." ~ segment | crossSegment | defaultCrossSegment).rep).map {
case (h, rest) => Segments(h +: rest)
}
Expand Down
27 changes: 21 additions & 6 deletions main/resolve/src/mill/resolve/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object Resolve {
Right(resolved.map(_.segments))
}

private[mill] override def deduplicate(items: List[Segments]) = items.distinct
private[mill] override def deduplicate(items: List[Segments]): List[Segments] = items.distinct
}

object Tasks extends Resolve[NamedTask[Any]] {
Expand Down Expand Up @@ -83,11 +83,11 @@ object Resolve {
)
}

private[mill] override def deduplicate(items: List[NamedTask[Any]]) =
private[mill] override def deduplicate(items: List[NamedTask[Any]]): List[NamedTask[Any]] =
items.distinctBy(_.ctx.segments)
}

private def instantiateTarget(r: Resolved.Target, p: Module) = {
private def instantiateTarget(r: Resolved.Target, p: Module): Either[String, Target[_]] = {
val definition = Reflect
.reflect(p.getClass, classOf[Target[_]], _ == r.segments.parts.last, true)
.head
Expand Down Expand Up @@ -230,11 +230,23 @@ trait Resolve[T] {
): Either[String, Seq[T]] = {
val rootResolved = ResolveCore.Resolved.Module(Segments(), rootModule.getClass)
val resolved =
ResolveCore.resolve(rootModule, sel.value.toList, rootResolved, Segments()) match {
ResolveCore.resolve(
rootModule = rootModule,
remainingQuery = sel.value.toList,
current = rootResolved,
querySoFar = Segments()
) match {
case ResolveCore.Success(value) => Right(value)
case ResolveCore.NotFound(segments, found, next, possibleNexts) =>
val allPossibleNames = rootModule.millDiscover.value.values.flatMap(_._1).toSet
Left(ResolveNotFoundHandler(sel, segments, found, next, possibleNexts, allPossibleNames))
Left(ResolveNotFoundHandler(
selector = sel,
segments = segments,
found = found,
next = next,
possibleNexts = possibleNexts,
allPossibleNames = allPossibleNames
))
case ResolveCore.Error(value) => Left(value)
}

Expand All @@ -245,7 +257,10 @@ trait Resolve[T] {

private[mill] def deduplicate(items: List[T]): List[T] = items

private[mill] def resolveRootModule(rootModule: BaseModule, scopedSel: Option[Segments]) = {
private[mill] def resolveRootModule(
rootModule: BaseModule,
scopedSel: Option[Segments]
): Either[String, BaseModule] = {
scopedSel match {
case None => Right(rootModule)
case Some(scoping) =>
Expand Down
124 changes: 112 additions & 12 deletions main/resolve/src/mill/resolve/ResolveCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,37 @@ private object ResolveCore {
case "__" =>
val self = Seq(Resolved.Module(m.segments, m.cls))
val transitiveOrErr =
resolveTransitiveChildren(rootModule, m.cls, None, current.segments)
resolveTransitiveChildren(rootModule, m.cls, None, current.segments, Nil)

transitiveOrErr.map(transitive => self ++ transitive)

case "_" =>
resolveDirectChildren(rootModule, m.cls, None, current.segments)

case pattern if pattern.startsWith("__:") =>
val typePattern = pattern.split(":").drop(1)
val self = Seq(Resolved.Module(m.segments, m.cls))

val transitiveOrErr = resolveTransitiveChildren(
rootModule,
m.cls,
None,
current.segments,
typePattern
)

transitiveOrErr.map(transitive => self ++ transitive)

case pattern if pattern.startsWith("_:") =>
val typePattern = pattern.split(":").drop(1)
resolveDirectChildren(
rootModule,
m.cls,
None,
current.segments,
typePattern
)

case _ =>
resolveDirectChildren(rootModule, m.cls, Some(singleLabel), current.segments)
}
Expand Down Expand Up @@ -186,22 +210,75 @@ private object ResolveCore {
cls: Class[_],
nameOpt: Option[String],
segments: Segments
): Either[String, Set[Resolved]] =
resolveTransitiveChildren(rootModule, cls, nameOpt, segments, Nil)

def resolveTransitiveChildren(
rootModule: Module,
cls: Class[_],
nameOpt: Option[String],
segments: Segments,
typePattern: Seq[String]
): Either[String, Set[Resolved]] = {
for {
direct <- resolveDirectChildren(rootModule, cls, nameOpt, segments)
indirect0 = direct
.collect { case m: Resolved.Module =>
resolveTransitiveChildren(rootModule, m.cls, nameOpt, m.segments)
}
indirect <- EitherOps.sequence(indirect0).map(_.flatten)
} yield direct ++ indirect
val direct = resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern)
direct.flatMap { direct =>
for {
directTraverse <- resolveDirectChildren(rootModule, cls, nameOpt, segments, Nil)
indirect0 = directTraverse
.collect { case m: Resolved.Module =>
resolveTransitiveChildren(rootModule, m.cls, nameOpt, m.segments, typePattern)
}
indirect <- EitherOps.sequence(indirect0).map(_.flatten)
} yield direct ++ indirect
}
}

private def resolveParents(c: Class[_]): Seq[Class[_]] =
Seq(c) ++
Option(c.getSuperclass).toSeq.flatMap(resolveParents) ++
c.getInterfaces.flatMap(resolveParents)

/**
* Check if the given class matches a given type selector as string
* @param cls
* @param typePattern
* @return
*/
private def classMatchesTypePred(typePattern: Seq[String])(cls: Class[_]): Boolean =
typePattern
.forall { pat =>
val negate = pat.startsWith("^") || pat.startsWith("!")
val clsPat = pat.drop(if (negate) 1 else 0)

// We split full class names by `.` and `$`
// a class matches a type patter, if the type pattern segments match from the right
// to express a full match, use `_root_` as first segment

val typeNames = clsPat.split("[.$]").toSeq.reverse

val parents = resolveParents(cls)
val classNames = parents.flatMap(c =>
("_root_$" + c.getName).split("[.$]").toSeq.reverse.inits.toSeq.filter(_.nonEmpty)
)

val isOfType = classNames.contains(typeNames)
if (negate) !isOfType else isOfType
}

def resolveDirectChildren(
rootModule: Module,
cls: Class[_],
nameOpt: Option[String],
segments: Segments
): Either[String, Set[Resolved]] =
resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern = Nil)

def resolveDirectChildren(
rootModule: Module,
cls: Class[_],
nameOpt: Option[String],
segments: Segments,
typePattern: Seq[String]
): Either[String, Set[Resolved]] = {

val crossesOrErr = if (classOf[Cross[_]].isAssignableFrom(cls) && nameOpt.isEmpty) {
Expand All @@ -216,15 +293,19 @@ private object ResolveCore {
} else Right(Nil)

crossesOrErr.flatMap { crosses =>
resolveDirectChildren0(rootModule, segments, cls, nameOpt)
val filteredCrosses = crosses.filter { c =>
classMatchesTypePred(typePattern)(c.cls)
}

resolveDirectChildren0(rootModule, segments, cls, nameOpt, typePattern)
.map(
_.map {
case (Resolved.Module(s, cls), _) => Resolved.Module(segments ++ s, cls)
case (Resolved.Target(s), _) => Resolved.Target(segments ++ s)
case (Resolved.Command(s), _) => Resolved.Command(segments ++ s)
}
.toSet
.++(crosses)
.++(filteredCrosses)
)
}
}
Expand All @@ -234,15 +315,25 @@ private object ResolveCore {
segments: Segments,
cls: Class[_],
nameOpt: Option[String]
): Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] =
resolveDirectChildren0(rootModule, segments, cls, nameOpt, Nil)

def resolveDirectChildren0(
rootModule: Module,
segments: Segments,
cls: Class[_],
nameOpt: Option[String],
typePattern: Seq[String]
): Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] = {
def namePred(n: String) = nameOpt.isEmpty || nameOpt.contains(n)

val modulesOrErr: Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] =
val modulesOrErr: Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] = {
if (classOf[DynamicModule].isAssignableFrom(cls)) {
instantiateModule(rootModule, segments).map {
case m: DynamicModule =>
m.millModuleDirectChildren
.filter(c => namePred(c.millModuleSegments.parts.last))
.filter(c => classMatchesTypePred(typePattern)(c.getClass))
.map(c =>
(
Resolved.Module(
Expand All @@ -256,6 +347,14 @@ private object ResolveCore {
} else Right {
Reflect
.reflectNestedObjects0[Module](cls, namePred)
.filter {
case (_, member) =>
val memberCls = member match {
case f: java.lang.reflect.Field => f.getType
case f: java.lang.reflect.Method => f.getReturnType
}
classMatchesTypePred(typePattern)(memberCls)
}
.map { case (name, member) =>
Resolved.Module(
Segments.labels(decode(name)),
Expand All @@ -274,6 +373,7 @@ private object ResolveCore {
)
}
}
}

val targets = Reflect
.reflect(cls, classOf[Target[_]], namePred, noParams = true)
Expand Down
Loading