Skip to content

Commit

Permalink
BestFirstSearch: skip states which have NL and SLB
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Dec 20, 2024
1 parent 8fb149e commit 20072fa
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import org.scalafmt.util._
import scala.meta._
import scala.meta.tokens.{Token => T}

import java.util

import scala.annotation.tailrec
import scala.collection.mutable

Expand All @@ -25,11 +27,16 @@ private class BestFirstSearch private (range: Set[Range])(implicit

/** Precomputed table of splits for each token.
*/
val routes: Array[Seq[Split]] = {
val (routes, hasNLOnly): (Array[Seq[Split]], Array[Int]) = {
val router = new Router(formatOps)
val result = Array.newBuilder[Seq[Split]]
tokens.foreach(t => result += router.getSplits(t))
result.result()
val nlOnly = Array.newBuilder[Int]
tokens.foreach { t =>
val splits = router.getSplits(t)
result += splits
if (splits.forall(_.isNL)) nlOnly += t.idx
}
(result.result(), nlOnly.result())
}
private val noOptZones =
if (useNoOptZones(initStyle)) getNoOptZones(tokens) else null
Expand Down Expand Up @@ -133,8 +140,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
if (noBlockClose) None
else getBlockCloseToRecurse(splitToken)
.flatMap(shortestPathMemo(curr, _, depth + 1, isOpt))
if (blockCloseState.nonEmpty) blockCloseState
.foreach(_.foreach(Q.enqueue))
if (blockCloseState.nonEmpty) blockCloseState.foreach(_.foreach(enqueue))
else {
if (optimizer.escapeInPathologicalCases && isSeqMulti(routes(idx)))
stats.explode(splitToken, optimizer.maxVisitsPerToken)(
Expand All @@ -161,10 +167,8 @@ private class BestFirstSearch private (range: Set[Range])(implicit
}
case _ => nextState
}
if (null ne stateToQueue) {
stats.updateBest(nextState, stateToQueue)
Q.enqueue(stateToQueue)
}
if (enqueue(stateToQueue)) stats
.updateBest(nextState, stateToQueue)
} else preFork = false
}

Expand Down Expand Up @@ -194,6 +198,29 @@ private class BestFirstSearch private (range: Set[Range])(implicit
state.next(split)
}

private def getNLOnlyIndexAfter(ftIdx: Int) = {
val idx = util.Arrays.binarySearch(hasNLOnly, ftIdx + 1)
if (idx >= 0) idx else -(idx + 1)
}

private def hasImpossibleSlb(state: State, optFtIdx: Int = -1) = {
val maxFtIdx = state.policy.maxEndPos.fold(optFtIdx)(_.endIdx.max(optFtIdx))
@tailrec
def iter(nlOnlyIdx: Int): Boolean = nlOnlyIdx < hasNLOnly.length && {
val nlOnlyFtIdx = hasNLOnly(nlOnlyIdx)
state.hasSlbOn(tokens(nlOnlyFtIdx)) ||
nlOnlyFtIdx < maxFtIdx && iter(nlOnlyIdx + 1)
}
iter(getNLOnlyIndexAfter(state.depth))
}

private def enqueue(state: State)(implicit q: StateQueue): Boolean =
(state ne null) && {
val ok = !hasImpossibleSlb(state)
if (ok) q.enqueue(state)
ok
}

private def willKillOnFail(kill: Boolean, end: => FT)(implicit
nextState: State,
): Boolean = kill || nextState.hasSlbUntil(end)
Expand All @@ -209,6 +236,7 @@ private class BestFirstSearch private (range: Set[Range])(implicit
val optIdx = opt.token.idx
val nextNextState =
if (optIdx <= nextState.depth) nextState
else if (hasImpossibleSlb(nextState, optIdx)) return Left(null)
else if (tokens.width(nextState.depth, optIdx) > 3 * style.maxColumn)
return Left(killOnFail(opt))
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class PolicySummary(val policies: Seq[Policy]) extends AnyVal {

@inline
def exists(f: Policy => Boolean): Boolean = policies.exists(f)

def maxEndPos = policies.iterator.map(_.maxEndPos).maxOption
}

object PolicySummary {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ final class State(
def hasSlbUntil(ft: FT): Boolean = policy
.exists(_.appliesUntil(ft)(_.isInstanceOf[PolicyOps.SingleLineBlock]))

def hasSlbOn(ft: FT): Boolean = policy.exists(
_.appliesOn(ft)(_.isInstanceOf[PolicyOps.SingleLineBlock]).contains(true),
)

def hasSlb(): Boolean = policy
.exists(_.exists(_.isInstanceOf[PolicyOps.SingleLineBlock]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ private def withNewLocalDefs = {
}))
}))
}
>>> { stateVisits = 5621, stateVisits2 = 5621 }
>>> { stateVisits = 5267, stateVisits2 = 5267 }
val createIsArrayOfStat = {
envFieldDef(
"isArrayOf",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8897,7 +8897,7 @@ object a {
)
)
}
>>> { stateVisits = 4176, stateVisits2 = 3915 }
>>> { stateVisits = 3915, stateVisits2 = 3915 }
object a {
div(cls := "cover")(
div(cls := "doc")(bodyContents),
Expand Down Expand Up @@ -9352,7 +9352,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil
}
}
>>> { stateVisits = 2626, stateVisits2 = 1038 }
>>> { stateVisits = 2618, stateVisits2 = 1038 }
class UDFRegistration private[sql] (
functionRegistry: FunctionRegistry
) extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9436,7 +9436,7 @@ object a {
}
}
}
>>> { stateVisits = 2151, stateVisits2 = 2151 }
>>> { stateVisits = 2001, stateVisits2 = 2001 }
object a {
private object MemoMap {
def make(implicit trace: Trace): UIO[MemoMap] = Ref.Synchronized
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7195,7 +7195,7 @@ final case class UserCheck(
case b =>
}
)
>>> { stateVisits = 3818, stateVisits2 = 3818 }
>>> { stateVisits = 3889, stateVisits2 = 3889 }
final case class UserCheck(
options: PublishSetupOptions,
configDb: () => ConfigDb,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions {
val explored = Debug.explored.get()
logger.debug(s"Total explored: $explored")
if (!onlyUnit && !onlyManual)
assertEquals(explored, 1208516, "total explored")
assertEquals(explored, 1163630, "total explored")
val results = debugResults.result()
// TODO(olafur) don't block printing out test results.
// I don't want to deal with scalaz's Tasks :'(
Expand Down

0 comments on commit 20072fa

Please sign in to comment.