diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala index 9e05afa16..0b1f0413f 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala @@ -100,7 +100,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit if ( emptyQueueSpots.contains(idx) || optimizer.dequeueOnNewStatements && !(depth == 0 && - noOptZone) && statementStarts.contains(idx) + noOptZone) && optimizationEntities.statementStarts.contains(idx) ) Q.addGeneration() val noBlockClose = start == curr && 0 != maxCost || !noOptZone || diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala index 66abdf5f6..db263908f 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala @@ -41,8 +41,8 @@ class FormatOps( FormatTokens(topSourceTree.tokens, owners)(initStyle) import tokens._ - private[internal] val soft = new SoftKeywordClasses(dialect) - private[internal] val statementStarts = getStatementStarts(topSourceTree, soft) + private[internal] implicit val soft: SoftKeywordClasses = + new SoftKeywordClasses(dialect) val (forceConfigStyle, emptyQueueSpots) = getForceConfigStyle @@ -149,8 +149,9 @@ class FormatOps( case _ => false }) - val StartsStatementRight = - new ExtractFromMeta[Tree](meta => statementStarts.get(meta.idx + 1)) + val StartsStatementRight = new ExtractFromMeta[Tree](meta => + optimizationEntities.statementStarts.get(meta.idx + 1), + ) def parensTuple(token: T): TokenRanges = matchingOpt(token) .fold(TokenRanges.empty)(other => TokenRanges(TokenRange(token, other.left))) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/OptimizationEntities.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/OptimizationEntities.scala index 87ceff637..2b2c866d0 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/OptimizationEntities.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/OptimizationEntities.scala @@ -1,12 +1,18 @@ package org.scalafmt.internal +import org.scalafmt.Error +import org.scalafmt.util._ + import scala.meta._ +import scala.meta.tokens.{Token => T} import scala.collection.mutable +import scala.reflect.ClassTag class OptimizationEntities( argumentStarts: Map[Int, Tree], optionalNewlines: Set[Int], + val statementStarts: Map[Int, Tree], ) { def argumentAt(idx: Int): Option[Tree] = argumentStarts.get(idx) def argument(implicit ft: FormatToken): Option[Tree] = argumentAt(ft.meta.idx) @@ -15,25 +21,36 @@ class OptimizationEntities( } object OptimizationEntities { - def apply(tree: Tree)(implicit ftoks: FormatTokens): OptimizationEntities = - new Builder(tree).build() + def apply(tree: Tree)(implicit + ftoks: FormatTokens, + soft: SoftKeywordClasses, + ): OptimizationEntities = new Builder(tree).build() - private class Builder(topSourceTree: Tree)(implicit tokens: FormatTokens) { + private class Builder(topSourceTree: Tree)(implicit + ftoks: FormatTokens, + soft: SoftKeywordClasses, + ) { private val arguments = mutable.Map.empty[Int, Tree] private val optional = Set.newBuilder[Int] + private val statements = Map.newBuilder[Int, Tree] def build(): OptimizationEntities = { val queue = new mutable.ListBuffer[Seq[Tree]] queue += topSourceTree :: Nil while (queue.nonEmpty) queue.remove(0).foreach { tree => processForArguments(tree) + processForStatements(tree) queue += tree.children } - new OptimizationEntities(arguments.toMap, optional.result()) + new OptimizationEntities( + arguments.toMap, + optional.result(), + statements.result(), + ) } - private def getHeadIndex(tree: Tree): Option[Int] = tokens.getHeadOpt(tree) + private def getHeadIndex(tree: Tree): Option[Int] = ftoks.getHeadOpt(tree) .map(_.meta.idx - 1) private def addArgWith(key: Tree)(value: Tree): Unit = getHeadIndex(key) .foreach(arguments.getOrElseUpdate(_, value)) @@ -68,6 +85,101 @@ object OptimizationEntities { case t: Term => addArg(t) case _ => } + + private def addStmtFT(stmt: Tree)(ft: FormatToken): Unit = { + val isComment = ft.left.is[Token.Comment] + val nft = if (isComment) ftoks.nextAfterNonComment(ft) else ft + statements += nft.meta.idx -> stmt + } + private def addStmtTok(stmt: Tree)(token: Token) = + addStmtFT(stmt)(ftoks.after(token)) + private def addStmtTree(t: Tree, stmt: Tree) = ftoks.getHeadOpt(t) + .foreach(addStmtFT(stmt)) + private def addOneStmt(t: Tree) = addStmtTree(t, t) + private def addAllStmts(trees: Seq[Tree]) = trees.foreach(addOneStmt) + + private def addDefnTokens( + mods: Seq[Mod], + tree: Tree, + what: String, + isMatch: Token => Boolean, + ): Unit = { + // Each @annotation gets a separate line + val annotations = mods.filter(_.is[Mod.Annot]) + addAllStmts(annotations) + mods.find(!_.is[Mod.Annot]) match { + // Non-annotation modifier, for example `sealed`/`abstract` + case Some(x) => addStmtTree(x, tree) + case _ => + // No non-annotation modifier exists, fallback to keyword like `object` + tree.tokens.find(isMatch) + .fold(throw Error.CantFindDefnToken(what, tree))(addStmtTok(tree)) + } + } + + private def addDefn[T](mods: Seq[Mod], tree: Tree)(implicit + tag: ClassTag[T], + ): Unit = { + val runtimeClass = tag.runtimeClass + addDefnTokens( + mods, + tree, + runtimeClass.getSimpleName, + runtimeClass.isInstance, + ) + } + + private def processForStatements(tree: Tree): Unit = tree match { + case t: Defn.Class => addDefn[T.KwClass](t.mods, t) + case t: Decl.Def => addDefn[T.KwDef](t.mods, t) + case t: Defn.Def => addDefn[T.KwDef](t.mods, t) + case t: Defn.Macro => addDefn[T.KwDef](t.mods, t) + case t: Decl.Given => addDefn[T.KwGiven](t.mods, t) + case t: Defn.Given => addDefn[T.KwGiven](t.mods, t) + case t: Defn.GivenAlias => addDefn[T.KwGiven](t.mods, t) + case t: Defn.Enum => addDefn[T.KwEnum](t.mods, t) + case t: Defn.ExtensionGroup => + addDefnTokens(Nil, t, "extension", soft.KwExtension.unapply) + case t: Defn.Object => addDefn[T.KwObject](t.mods, t) + case t: Defn.Trait => addDefn[T.KwTrait](t.mods, t) + case t: Defn.Type => addDefn[T.KwType](t.mods, t) + case t: Decl.Type => addDefn[T.KwType](t.mods, t) + case t: Defn.Val => addDefn[T.KwVal](t.mods, t) + case t: Decl.Val => addDefn[T.KwVal](t.mods, t) + case t: Defn.Var => addDefn[T.KwVar](t.mods, t) + case t: Decl.Var => addDefn[T.KwVar](t.mods, t) + case t: Ctor.Secondary => + addDefn[T.KwDef](t.mods, t) + addAllStmts(t.body.stats) + // special handling for rewritten blocks + case t @ Term.Block(_ :: Nil) if t.tokens.headOption.exists { x => + // ignore single-stat block if opening brace was removed + x.is[Token.LeftBrace] && ftoks(x).left.ne(x) + } => + case t: Term.EnumeratorsBlock => + var wasGuard = false + t.enums.tail.foreach { x => + val isGuard = x.is[Enumerator.Guard] + // Only guard that follows another guard starts a statement. + if (wasGuard || !isGuard) addOneStmt(x) + wasGuard = isGuard + } + case t: Term.PartialFunction => t.cases match { + case _ :: Nil => + case x => addAllStmts(x) + } + case t @ Term.Block(s) => + if (t.parent.is[CaseTree]) addAllStmts( + if (TreeOps.getSingleStatExceptEndMarker(s).isEmpty) s else s.drop(1), + ) + else s match { + case (_: Term.FunctionTerm) :: Nil => + case _ => addAllStmts(s) + } + case Tree.Block(s) => addAllStmts(s) + case _ => // Nothing + } + } } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala index 22076a1db..a40b22c3d 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala @@ -1,24 +1,20 @@ package org.scalafmt.util import org.scalafmt.Error -import org.scalafmt.config.ScalafmtConfig -import org.scalafmt.internal.FormatToken -import org.scalafmt.internal.FormatTokens -import org.scalafmt.internal.Modification -import org.scalafmt.internal.Space +import org.scalafmt.config._ +import org.scalafmt.internal._ import org.scalafmt.util.InfixApp._ import org.scalafmt.util.LoggerOps._ import scala.meta._ import scala.meta.classifiers.Classifier import scala.meta.tokens.Token -import scala.meta.tokens.Token._ +import scala.meta.tokens.Token.{Space => _, _} import scala.meta.tokens.Tokens import scala.annotation.tailrec import scala.collection.immutable.HashMap import scala.collection.mutable -import scala.reflect.ClassTag /** Stateless helper functions on `scala.meta.Tree`. */ @@ -92,108 +88,6 @@ object TreeOps { ftoks: FormatTokens, ): Boolean = SingleArgInBraces.orBlock(parent).exists(_._2 eq expr) - def getStatementStarts(tree: Tree, soft: SoftKeywordClasses)(implicit - ftoks: FormatTokens, - ): Map[Int, Tree] = { - val ret = Map.newBuilder[Int, Tree] - ret.sizeHint(tree.tokens.length) - - def addFT(stmt: Tree)(ft: FormatToken): Unit = { - val isComment = ft.left.is[Token.Comment] - val nft = if (isComment) ftoks.nextAfterNonComment(ft) else ft - ret += nft.meta.idx -> stmt - } - def addTok(stmt: Tree)(token: Token) = addFT(stmt)(ftoks.after(token)) - def addTree(t: Tree, stmt: Tree) = ftoks.getHeadOpt(t).foreach(addFT(stmt)) - def addOne(t: Tree) = addTree(t, t) - def addAll(trees: Seq[Tree]) = trees.foreach(addOne) - - def addDefnTokens( - mods: Seq[Mod], - tree: Tree, - what: String, - isMatch: Token => Boolean, - ): Unit = { - // Each @annotation gets a separate line - val annotations = mods.filter(_.is[Mod.Annot]) - addAll(annotations) - mods.find(!_.is[Mod.Annot]) match { - // Non-annotation modifier, for example `sealed`/`abstract` - case Some(x) => addTree(x, tree) - case _ => - // No non-annotation modifier exists, fallback to keyword like `object` - tree.tokens.find(isMatch) - .fold(throw Error.CantFindDefnToken(what, tree))(addTok(tree)) - } - } - def addDefn[T](mods: Seq[Mod], tree: Tree)(implicit - tag: ClassTag[T], - ): Unit = { - val runtimeClass = tag.runtimeClass - addDefnTokens( - mods, - tree, - runtimeClass.getSimpleName, - runtimeClass.isInstance, - ) - } - - def loop(subtree: Tree): Unit = { - subtree match { - case t: Defn.Class => addDefn[KwClass](t.mods, t) - case t: Decl.Def => addDefn[KwDef](t.mods, t) - case t: Defn.Def => addDefn[KwDef](t.mods, t) - case t: Defn.Macro => addDefn[KwDef](t.mods, t) - case t: Decl.Given => addDefn[KwGiven](t.mods, t) - case t: Defn.Given => addDefn[KwGiven](t.mods, t) - case t: Defn.GivenAlias => addDefn[KwGiven](t.mods, t) - case t: Defn.Enum => addDefn[KwEnum](t.mods, t) - case t: Defn.ExtensionGroup => - addDefnTokens(Nil, t, "extension", soft.KwExtension.unapply) - case t: Defn.Object => addDefn[KwObject](t.mods, t) - case t: Defn.Trait => addDefn[KwTrait](t.mods, t) - case t: Defn.Type => addDefn[KwType](t.mods, t) - case t: Decl.Type => addDefn[KwType](t.mods, t) - case t: Defn.Val => addDefn[KwVal](t.mods, t) - case t: Decl.Val => addDefn[KwVal](t.mods, t) - case t: Defn.Var => addDefn[KwVar](t.mods, t) - case t: Decl.Var => addDefn[KwVar](t.mods, t) - case t: Ctor.Secondary => - addDefn[KwDef](t.mods, t) - addAll(t.stats) - // special handling for rewritten blocks - case t @ Term.Block(_ :: Nil) if t.tokens.headOption.exists { x => - // ignore single-stat block if opening brace was removed - x.is[Token.LeftBrace] && ftoks(x).left.ne(x) - } => - case t: Term.EnumeratorsBlock => - var wasGuard = false - t.enums.tail.foreach { x => - val isGuard = x.is[Enumerator.Guard] - // Only guard that follows another guard starts a statement. - if (wasGuard || !isGuard) addOne(x) - wasGuard = isGuard - } - case t: Term.PartialFunction => t.cases match { - case _ :: Nil => - case x => addAll(x) - } - case t @ Term.Block(s) => - if (t.parent.is[CaseTree]) - addAll(if (getSingleStatExceptEndMarker(s).isEmpty) s else s.drop(1)) - else s match { - case (_: Term.FunctionTerm) :: Nil => - case _ => addAll(s) - } - case Tree.Block(s) => addAll(s) - case _ => // Nothing - } - subtree.children.foreach(loop) - } - loop(tree) - ret.result() - } - /** Finds matching parens [({})]. * * Contains lookup keys in both directions, opening [({ and closing })].