Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OptimizationEntities: move TreeOps.statementStarts #4408

Merged
merged 1 commit into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
if (
emptyQueueSpots.contains(idx) ||
optimizer.dequeueOnNewStatements && !(depth == 0 &&
noOptZone) && statementStarts.contains(idx)
noOptZone) && optimizationEntities.statementStarts.contains(idx)
) Q.addGeneration()

val noBlockClose = start == curr && 0 != maxCost || !noOptZone ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class FormatOps(
FormatTokens(topSourceTree.tokens, owners)(initStyle)
import tokens._

private[internal] val soft = new SoftKeywordClasses(dialect)
private[internal] val statementStarts = getStatementStarts(topSourceTree, soft)
private[internal] implicit val soft: SoftKeywordClasses =
new SoftKeywordClasses(dialect)

val (forceConfigStyle, emptyQueueSpots) = getForceConfigStyle

Expand Down Expand Up @@ -149,8 +149,9 @@ class FormatOps(
case _ => false
})

val StartsStatementRight =
new ExtractFromMeta[Tree](meta => statementStarts.get(meta.idx + 1))
val StartsStatementRight = new ExtractFromMeta[Tree](meta =>
optimizationEntities.statementStarts.get(meta.idx + 1),
)

def parensTuple(token: T): TokenRanges = matchingOpt(token)
.fold(TokenRanges.empty)(other => TokenRanges(TokenRange(token, other.left)))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package org.scalafmt.internal

import org.scalafmt.Error
import org.scalafmt.util._

import scala.meta._
import scala.meta.tokens.{Token => T}

import scala.collection.mutable
import scala.reflect.ClassTag

class OptimizationEntities(
argumentStarts: Map[Int, Tree],
optionalNewlines: Set[Int],
val statementStarts: Map[Int, Tree],
) {
def argumentAt(idx: Int): Option[Tree] = argumentStarts.get(idx)
def argument(implicit ft: FormatToken): Option[Tree] = argumentAt(ft.meta.idx)
Expand All @@ -15,25 +21,36 @@ class OptimizationEntities(
}

object OptimizationEntities {
def apply(tree: Tree)(implicit ftoks: FormatTokens): OptimizationEntities =
new Builder(tree).build()
def apply(tree: Tree)(implicit
ftoks: FormatTokens,
soft: SoftKeywordClasses,
): OptimizationEntities = new Builder(tree).build()

private class Builder(topSourceTree: Tree)(implicit tokens: FormatTokens) {
private class Builder(topSourceTree: Tree)(implicit
ftoks: FormatTokens,
soft: SoftKeywordClasses,
) {

private val arguments = mutable.Map.empty[Int, Tree]
private val optional = Set.newBuilder[Int]
private val statements = Map.newBuilder[Int, Tree]

def build(): OptimizationEntities = {
val queue = new mutable.ListBuffer[Seq[Tree]]
queue += topSourceTree :: Nil
while (queue.nonEmpty) queue.remove(0).foreach { tree =>
processForArguments(tree)
processForStatements(tree)
queue += tree.children
}
new OptimizationEntities(arguments.toMap, optional.result())
new OptimizationEntities(
arguments.toMap,
optional.result(),
statements.result(),
)
}

private def getHeadIndex(tree: Tree): Option[Int] = tokens.getHeadOpt(tree)
private def getHeadIndex(tree: Tree): Option[Int] = ftoks.getHeadOpt(tree)
.map(_.meta.idx - 1)
private def addArgWith(key: Tree)(value: Tree): Unit = getHeadIndex(key)
.foreach(arguments.getOrElseUpdate(_, value))
Expand Down Expand Up @@ -68,6 +85,101 @@ object OptimizationEntities {
case t: Term => addArg(t)
case _ =>
}

private def addStmtFT(stmt: Tree)(ft: FormatToken): Unit = {
val isComment = ft.left.is[Token.Comment]
val nft = if (isComment) ftoks.nextAfterNonComment(ft) else ft
statements += nft.meta.idx -> stmt
}
private def addStmtTok(stmt: Tree)(token: Token) =
addStmtFT(stmt)(ftoks.after(token))
private def addStmtTree(t: Tree, stmt: Tree) = ftoks.getHeadOpt(t)
.foreach(addStmtFT(stmt))
private def addOneStmt(t: Tree) = addStmtTree(t, t)
private def addAllStmts(trees: Seq[Tree]) = trees.foreach(addOneStmt)

private def addDefnTokens(
mods: Seq[Mod],
tree: Tree,
what: String,
isMatch: Token => Boolean,
): Unit = {
// Each @annotation gets a separate line
val annotations = mods.filter(_.is[Mod.Annot])
addAllStmts(annotations)
mods.find(!_.is[Mod.Annot]) match {
// Non-annotation modifier, for example `sealed`/`abstract`
case Some(x) => addStmtTree(x, tree)
case _ =>
// No non-annotation modifier exists, fallback to keyword like `object`
tree.tokens.find(isMatch)
.fold(throw Error.CantFindDefnToken(what, tree))(addStmtTok(tree))
}
}

private def addDefn[T](mods: Seq[Mod], tree: Tree)(implicit
tag: ClassTag[T],
): Unit = {
val runtimeClass = tag.runtimeClass
addDefnTokens(
mods,
tree,
runtimeClass.getSimpleName,
runtimeClass.isInstance,
)
}

private def processForStatements(tree: Tree): Unit = tree match {
case t: Defn.Class => addDefn[T.KwClass](t.mods, t)
case t: Decl.Def => addDefn[T.KwDef](t.mods, t)
case t: Defn.Def => addDefn[T.KwDef](t.mods, t)
case t: Defn.Macro => addDefn[T.KwDef](t.mods, t)
case t: Decl.Given => addDefn[T.KwGiven](t.mods, t)
case t: Defn.Given => addDefn[T.KwGiven](t.mods, t)
case t: Defn.GivenAlias => addDefn[T.KwGiven](t.mods, t)
case t: Defn.Enum => addDefn[T.KwEnum](t.mods, t)
case t: Defn.ExtensionGroup =>
addDefnTokens(Nil, t, "extension", soft.KwExtension.unapply)
case t: Defn.Object => addDefn[T.KwObject](t.mods, t)
case t: Defn.Trait => addDefn[T.KwTrait](t.mods, t)
case t: Defn.Type => addDefn[T.KwType](t.mods, t)
case t: Decl.Type => addDefn[T.KwType](t.mods, t)
case t: Defn.Val => addDefn[T.KwVal](t.mods, t)
case t: Decl.Val => addDefn[T.KwVal](t.mods, t)
case t: Defn.Var => addDefn[T.KwVar](t.mods, t)
case t: Decl.Var => addDefn[T.KwVar](t.mods, t)
case t: Ctor.Secondary =>
addDefn[T.KwDef](t.mods, t)
addAllStmts(t.body.stats)
// special handling for rewritten blocks
case t @ Term.Block(_ :: Nil) if t.tokens.headOption.exists { x =>
// ignore single-stat block if opening brace was removed
x.is[Token.LeftBrace] && ftoks(x).left.ne(x)
} =>
case t: Term.EnumeratorsBlock =>
var wasGuard = false
t.enums.tail.foreach { x =>
val isGuard = x.is[Enumerator.Guard]
// Only guard that follows another guard starts a statement.
if (wasGuard || !isGuard) addOneStmt(x)
wasGuard = isGuard
}
case t: Term.PartialFunction => t.cases match {
case _ :: Nil =>
case x => addAllStmts(x)
}
case t @ Term.Block(s) =>
if (t.parent.is[CaseTree]) addAllStmts(
if (TreeOps.getSingleStatExceptEndMarker(s).isEmpty) s else s.drop(1),
)
else s match {
case (_: Term.FunctionTerm) :: Nil =>
case _ => addAllStmts(s)
}
case Tree.Block(s) => addAllStmts(s)
case _ => // Nothing
}

}

}
112 changes: 3 additions & 109 deletions scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
package org.scalafmt.util

import org.scalafmt.Error
import org.scalafmt.config.ScalafmtConfig
import org.scalafmt.internal.FormatToken
import org.scalafmt.internal.FormatTokens
import org.scalafmt.internal.Modification
import org.scalafmt.internal.Space
import org.scalafmt.config._
import org.scalafmt.internal._
import org.scalafmt.util.InfixApp._
import org.scalafmt.util.LoggerOps._

import scala.meta._
import scala.meta.classifiers.Classifier
import scala.meta.tokens.Token
import scala.meta.tokens.Token._
import scala.meta.tokens.Token.{Space => _, _}
import scala.meta.tokens.Tokens

import scala.annotation.tailrec
import scala.collection.immutable.HashMap
import scala.collection.mutable
import scala.reflect.ClassTag

/** Stateless helper functions on `scala.meta.Tree`.
*/
Expand Down Expand Up @@ -92,108 +88,6 @@ object TreeOps {
ftoks: FormatTokens,
): Boolean = SingleArgInBraces.orBlock(parent).exists(_._2 eq expr)

def getStatementStarts(tree: Tree, soft: SoftKeywordClasses)(implicit
ftoks: FormatTokens,
): Map[Int, Tree] = {
val ret = Map.newBuilder[Int, Tree]
ret.sizeHint(tree.tokens.length)

def addFT(stmt: Tree)(ft: FormatToken): Unit = {
val isComment = ft.left.is[Token.Comment]
val nft = if (isComment) ftoks.nextAfterNonComment(ft) else ft
ret += nft.meta.idx -> stmt
}
def addTok(stmt: Tree)(token: Token) = addFT(stmt)(ftoks.after(token))
def addTree(t: Tree, stmt: Tree) = ftoks.getHeadOpt(t).foreach(addFT(stmt))
def addOne(t: Tree) = addTree(t, t)
def addAll(trees: Seq[Tree]) = trees.foreach(addOne)

def addDefnTokens(
mods: Seq[Mod],
tree: Tree,
what: String,
isMatch: Token => Boolean,
): Unit = {
// Each @annotation gets a separate line
val annotations = mods.filter(_.is[Mod.Annot])
addAll(annotations)
mods.find(!_.is[Mod.Annot]) match {
// Non-annotation modifier, for example `sealed`/`abstract`
case Some(x) => addTree(x, tree)
case _ =>
// No non-annotation modifier exists, fallback to keyword like `object`
tree.tokens.find(isMatch)
.fold(throw Error.CantFindDefnToken(what, tree))(addTok(tree))
}
}
def addDefn[T](mods: Seq[Mod], tree: Tree)(implicit
tag: ClassTag[T],
): Unit = {
val runtimeClass = tag.runtimeClass
addDefnTokens(
mods,
tree,
runtimeClass.getSimpleName,
runtimeClass.isInstance,
)
}

def loop(subtree: Tree): Unit = {
subtree match {
case t: Defn.Class => addDefn[KwClass](t.mods, t)
case t: Decl.Def => addDefn[KwDef](t.mods, t)
case t: Defn.Def => addDefn[KwDef](t.mods, t)
case t: Defn.Macro => addDefn[KwDef](t.mods, t)
case t: Decl.Given => addDefn[KwGiven](t.mods, t)
case t: Defn.Given => addDefn[KwGiven](t.mods, t)
case t: Defn.GivenAlias => addDefn[KwGiven](t.mods, t)
case t: Defn.Enum => addDefn[KwEnum](t.mods, t)
case t: Defn.ExtensionGroup =>
addDefnTokens(Nil, t, "extension", soft.KwExtension.unapply)
case t: Defn.Object => addDefn[KwObject](t.mods, t)
case t: Defn.Trait => addDefn[KwTrait](t.mods, t)
case t: Defn.Type => addDefn[KwType](t.mods, t)
case t: Decl.Type => addDefn[KwType](t.mods, t)
case t: Defn.Val => addDefn[KwVal](t.mods, t)
case t: Decl.Val => addDefn[KwVal](t.mods, t)
case t: Defn.Var => addDefn[KwVar](t.mods, t)
case t: Decl.Var => addDefn[KwVar](t.mods, t)
case t: Ctor.Secondary =>
addDefn[KwDef](t.mods, t)
addAll(t.stats)
// special handling for rewritten blocks
case t @ Term.Block(_ :: Nil) if t.tokens.headOption.exists { x =>
// ignore single-stat block if opening brace was removed
x.is[Token.LeftBrace] && ftoks(x).left.ne(x)
} =>
case t: Term.EnumeratorsBlock =>
var wasGuard = false
t.enums.tail.foreach { x =>
val isGuard = x.is[Enumerator.Guard]
// Only guard that follows another guard starts a statement.
if (wasGuard || !isGuard) addOne(x)
wasGuard = isGuard
}
case t: Term.PartialFunction => t.cases match {
case _ :: Nil =>
case x => addAll(x)
}
case t @ Term.Block(s) =>
if (t.parent.is[CaseTree])
addAll(if (getSingleStatExceptEndMarker(s).isEmpty) s else s.drop(1))
else s match {
case (_: Term.FunctionTerm) :: Nil =>
case _ => addAll(s)
}
case Tree.Block(s) => addAll(s)
case _ => // Nothing
}
subtree.children.foreach(loop)
}
loop(tree)
ret.result()
}

/** Finds matching parens [({})].
*
* Contains lookup keys in both directions, opening [({ and closing })].
Expand Down
Loading