From 20072fad2bae6ea89c29bbf3fa5b5a89ec861775 Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Mon, 16 Dec 2024 20:47:50 -0800 Subject: [PATCH] BestFirstSearch: skip states which have NL and SLB --- .../scalafmt/internal/BestFirstSearch.scala | 46 +++++++++++++++---- .../org/scalafmt/internal/PolicySummary.scala | 2 + .../scala/org/scalafmt/internal/State.scala | 4 ++ .../src/test/resources/default/Advanced.stat | 2 +- .../resources/newlines/source_classic.stat | 4 +- .../test/resources/newlines/source_fold.stat | 2 +- .../resources/scala3/OptionalBraces_fold.stat | 2 +- .../test/scala/org/scalafmt/FormatTests.scala | 2 +- 8 files changed, 49 insertions(+), 15 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala index da1abc5c0..8f7df158d 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/BestFirstSearch.scala @@ -7,6 +7,8 @@ import org.scalafmt.util._ import scala.meta._ import scala.meta.tokens.{Token => T} +import java.util + import scala.annotation.tailrec import scala.collection.mutable @@ -25,11 +27,16 @@ private class BestFirstSearch private (range: Set[Range])(implicit /** Precomputed table of splits for each token. */ - val routes: Array[Seq[Split]] = { + val (routes, hasNLOnly): (Array[Seq[Split]], Array[Int]) = { val router = new Router(formatOps) val result = Array.newBuilder[Seq[Split]] - tokens.foreach(t => result += router.getSplits(t)) - result.result() + val nlOnly = Array.newBuilder[Int] + tokens.foreach { t => + val splits = router.getSplits(t) + result += splits + if (splits.forall(_.isNL)) nlOnly += t.idx + } + (result.result(), nlOnly.result()) } private val noOptZones = if (useNoOptZones(initStyle)) getNoOptZones(tokens) else null @@ -133,8 +140,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit if (noBlockClose) None else getBlockCloseToRecurse(splitToken) .flatMap(shortestPathMemo(curr, _, depth + 1, isOpt)) - if (blockCloseState.nonEmpty) blockCloseState - .foreach(_.foreach(Q.enqueue)) + if (blockCloseState.nonEmpty) blockCloseState.foreach(_.foreach(enqueue)) else { if (optimizer.escapeInPathologicalCases && isSeqMulti(routes(idx))) stats.explode(splitToken, optimizer.maxVisitsPerToken)( @@ -161,10 +167,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit } case _ => nextState } - if (null ne stateToQueue) { - stats.updateBest(nextState, stateToQueue) - Q.enqueue(stateToQueue) - } + if (enqueue(stateToQueue)) stats + .updateBest(nextState, stateToQueue) } else preFork = false } @@ -194,6 +198,29 @@ private class BestFirstSearch private (range: Set[Range])(implicit state.next(split) } + private def getNLOnlyIndexAfter(ftIdx: Int) = { + val idx = util.Arrays.binarySearch(hasNLOnly, ftIdx + 1) + if (idx >= 0) idx else -(idx + 1) + } + + private def hasImpossibleSlb(state: State, optFtIdx: Int = -1) = { + val maxFtIdx = state.policy.maxEndPos.fold(optFtIdx)(_.endIdx.max(optFtIdx)) + @tailrec + def iter(nlOnlyIdx: Int): Boolean = nlOnlyIdx < hasNLOnly.length && { + val nlOnlyFtIdx = hasNLOnly(nlOnlyIdx) + state.hasSlbOn(tokens(nlOnlyFtIdx)) || + nlOnlyFtIdx < maxFtIdx && iter(nlOnlyIdx + 1) + } + iter(getNLOnlyIndexAfter(state.depth)) + } + + private def enqueue(state: State)(implicit q: StateQueue): Boolean = + (state ne null) && { + val ok = !hasImpossibleSlb(state) + if (ok) q.enqueue(state) + ok + } + private def willKillOnFail(kill: Boolean, end: => FT)(implicit nextState: State, ): Boolean = kill || nextState.hasSlbUntil(end) @@ -209,6 +236,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit val optIdx = opt.token.idx val nextNextState = if (optIdx <= nextState.depth) nextState + else if (hasImpossibleSlb(nextState, optIdx)) return Left(null) else if (tokens.width(nextState.depth, optIdx) > 3 * style.maxColumn) return Left(killOnFail(opt)) else { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/PolicySummary.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/PolicySummary.scala index d5e3787d6..a31bded8e 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/PolicySummary.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/PolicySummary.scala @@ -28,6 +28,8 @@ class PolicySummary(val policies: Seq[Policy]) extends AnyVal { @inline def exists(f: Policy => Boolean): Boolean = policies.exists(f) + + def maxEndPos = policies.iterator.map(_.maxEndPos).maxOption } object PolicySummary { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala index 923ab1b32..5130d5402 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala @@ -55,6 +55,10 @@ final class State( def hasSlbUntil(ft: FT): Boolean = policy .exists(_.appliesUntil(ft)(_.isInstanceOf[PolicyOps.SingleLineBlock])) + def hasSlbOn(ft: FT): Boolean = policy.exists( + _.appliesOn(ft)(_.isInstanceOf[PolicyOps.SingleLineBlock]).contains(true), + ) + def hasSlb(): Boolean = policy .exists(_.exists(_.isInstanceOf[PolicyOps.SingleLineBlock])) diff --git a/scalafmt-tests/shared/src/test/resources/default/Advanced.stat b/scalafmt-tests/shared/src/test/resources/default/Advanced.stat index 1c74667f8..3efc9a57f 100644 --- a/scalafmt-tests/shared/src/test/resources/default/Advanced.stat +++ b/scalafmt-tests/shared/src/test/resources/default/Advanced.stat @@ -377,7 +377,7 @@ private def withNewLocalDefs = { })) })) } ->>> { stateVisits = 5621, stateVisits2 = 5621 } +>>> { stateVisits = 5267, stateVisits2 = 5267 } val createIsArrayOfStat = { envFieldDef( "isArrayOf", diff --git a/scalafmt-tests/shared/src/test/resources/newlines/source_classic.stat b/scalafmt-tests/shared/src/test/resources/newlines/source_classic.stat index 389c2f047..c2ce19597 100644 --- a/scalafmt-tests/shared/src/test/resources/newlines/source_classic.stat +++ b/scalafmt-tests/shared/src/test/resources/newlines/source_classic.stat @@ -8897,7 +8897,7 @@ object a { ) ) } ->>> { stateVisits = 4176, stateVisits2 = 3915 } +>>> { stateVisits = 3915, stateVisits2 = 3915 } object a { div(cls := "cover")( div(cls := "doc")(bodyContents), @@ -9352,7 +9352,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil } } ->>> { stateVisits = 2626, stateVisits2 = 1038 } +>>> { stateVisits = 2618, stateVisits2 = 1038 } class UDFRegistration private[sql] ( functionRegistry: FunctionRegistry ) extends Logging { diff --git a/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat b/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat index 29df224c3..c7db5b02b 100644 --- a/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat @@ -9436,7 +9436,7 @@ object a { } } } ->>> { stateVisits = 2151, stateVisits2 = 2151 } +>>> { stateVisits = 2001, stateVisits2 = 2001 } object a { private object MemoMap { def make(implicit trace: Trace): UIO[MemoMap] = Ref.Synchronized diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat index 331565af7..bdcba06c5 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat @@ -7195,7 +7195,7 @@ final case class UserCheck( case b => } ) ->>> { stateVisits = 3818, stateVisits2 = 3818 } +>>> { stateVisits = 3889, stateVisits2 = 3889 } final case class UserCheck( options: PublishSetupOptions, configDb: () => ConfigDb, diff --git a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala index bc53b08d0..abed78dfe 100644 --- a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { val explored = Debug.explored.get() logger.debug(s"Total explored: $explored") if (!onlyUnit && !onlyManual) - assertEquals(explored, 1208516, "total explored") + assertEquals(explored, 1163630, "total explored") val results = debugResults.result() // TODO(olafur) don't block printing out test results. // I don't want to deal with scalaz's Tasks :'(