Skip to content

Commit

Permalink
FormatTokens: use FormatToken in matching
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Nov 14, 2024
1 parent 7f88d86 commit b7885f2
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,22 +312,22 @@ object BestFirstSearch {
val result = mutable.Map.empty[Int, Boolean]
var expire: FT = null
@inline
def addRange(t: T): Unit = expire = tokens.matching(t)
def addRange(ft: FT): Unit = expire = tokens.matchingLeft(ft)
@inline
def addBlock(idx: Int): Unit = result.getOrElseUpdate(idx, false)
tokens.foreach {
case ft if expire ne null =>
if (ft eq expire) expire = null else result.update(ft.idx, true)
case FT(t: T.LeftParen, _, m) if (m.leftOwner match {
case ft @ FT(t: T.LeftParen, _, m) if (m.leftOwner match {
case lo: Term.ArgClause => !lo.parent.is[Term.ApplyInfix] &&
!styleMap.at(t).newlines.keep
case _: Term.Apply => true // legacy: when enclosed in parens
case _ => false
}) => addRange(t)
case FT(t: T.LeftBrace, _, m) => m.leftOwner match {
}) => addRange(ft)
case ft @ FT(t: T.LeftBrace, _, m) => m.leftOwner match {
// Type compounds can be inside defn.defs
case lo: meta.Stat.Block if lo.parent.is[Type.Refine] => addRange(t)
case _: Type.Refine => addRange(t)
case lo: meta.Stat.Block if lo.parent.is[Type.Refine] => addRange(ft)
case _: Type.Refine => addRange(ft)
case lo: Term.PartialFunction
if lo.cases.lengthCompare(1) == 0 &&
styleMap.at(t).newlines.fold => addBlock(m.idx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class FormatOps(

implicit val dialect: Dialect = initStyle.dialect
implicit val (tokens: FormatTokens, styleMap: StyleMap) =
FormatTokens(topSourceTree.tokens, owners)(initStyle)
FormatTokens(topSourceTree.tokens, ownersMap)(initStyle)
import tokens._

private[internal] implicit val soft: SoftKeywordClasses =
Expand All @@ -49,9 +49,6 @@ class FormatOps(

val optimizationEntities = OptimizationEntities(topSourceTree)

@inline
def owners(token: T): Tree = ownersMap(hash(token))

@tailrec
final def findFirst(start: FT, end: FT)(f: FT => Boolean): Option[FT] =
if (start.idx >= end.idx) None
Expand Down Expand Up @@ -83,14 +80,14 @@ class FormatOps(
!style.newlines.isBeforeOpenParenCallSite
case t => !isJustBeforeTree(start)(t)
}) => null
case t: T.RightParen =>
case _: T.RightParen =>
if (start.left.is[T.LeftParen]) null
else {
val owner = start.rightOwner
val isDefnSite = isParamClauseSite(owner)
implicit val clauseSiteFlags: ClauseSiteFlags =
ClauseSiteFlags(owner, isDefnSite)
val bpFlags = getBinpackSiteFlags(matching(t), start, false)
val bpFlags = getBinpackSiteFlags(matchingRight(start), start, false)
if (bpFlags.scalaJsStyle) scalaJsOptCloseOnRight(start, bpFlags)
else if (
!start.left.is[T.RightParen] ||
Expand Down Expand Up @@ -156,7 +153,7 @@ class FormatOps(
optimizationEntities.statementStarts.get(meta.idx + 1),
)

def parensTuple(ft: FT): TokenRanges = matchingOpt(ft.left)
def parensTuple(ft: FT): TokenRanges = matchingOptLeft(ft)
.fold(TokenRanges.empty)(other => TokenRanges(TokenRange(ft, other)))
def parensTuple(tree: Tree): TokenRanges = parensTuple(getLast(tree))

Expand All @@ -165,7 +162,7 @@ class FormatOps(
): TokenRanges = insideBlock(start, end, x => classifier(x.left))

def insideBlock(start: FT, end: FT, matches: FT => Boolean): TokenRanges =
insideBlock(x => if (matches(x)) matchingOpt(x.left) else None)(start, end)
insideBlock(x => if (matches(x)) matchingOptLeft(x) else None)(start, end)

def insideBracesBlock(start: FT, end: FT, parensToo: Boolean = false)(implicit
style: ScalafmtConfig,
Expand Down Expand Up @@ -317,7 +314,7 @@ class FormatOps(
val ok = initStyle.newlines.alwaysBeforeElseAfterCurlyIf || ! {
ft.leftOwner.is[Term.Block] && ft.left.is[T.RightBrace] &&
!prevNonCommentSameLineBefore(ft).left.is[T.LeftBrace] &&
matchingOpt(ft.left).exists { lb =>
matchingOptLeft(ft).exists { lb =>
prev(lb).left.start < term.thenp.pos.start
}
}
Expand All @@ -335,7 +332,7 @@ class FormatOps(
term.elsep match {
case t: Term.If => getElseChain(t, newRes)
case b @ Term.Block((t: Term.If) :: Nil)
if !areMatching(ftElsep.right)(getLast(b).left) =>
if !matchingOptRight(ftElsep).exists(_ eq getLast(b)) =>
getElseChain(t, newRes)
case _ => newRes
}
Expand Down Expand Up @@ -644,7 +641,7 @@ class FormatOps(
if (isAfterOp) infixSequenceMaxPrecedence(fullInfix) else 0 // 0 unused
val breakPenalty = if (isAfterOp) maxPrecedence - app.precedence else 1

val closeOpt = matchingOpt(ft.right)
val closeOpt = matchingOptRight(ft)
val finalExpireCost = fullExpire -> 0
val expires =
if (closeOpt.isDefined) finalExpireCost :: Nil
Expand Down Expand Up @@ -879,7 +876,7 @@ class FormatOps(
val toMatch =
if (tok.right.is[T.RightParen])
// Hack to allow any annotations with arguments like @foo(1)
tokens(matching(tok.right), -2)
tokens(matchingRight(tok), -2)
else tok
toMatch match {
case FT(T.At(), _: T.Ident, _) => true
Expand Down Expand Up @@ -1006,7 +1003,7 @@ class FormatOps(
val forces = Set.newBuilder[Int]
def process(clause: Member.SyntaxValuesClause, ftOpen: FT)(
cfg: ScalafmtOptimizer.ClauseElement,
): Unit = if (cfg.isEnabled) matchingOpt(ftOpen.left).foreach { close =>
): Unit = if (cfg.isEnabled) matchingOptLeft(ftOpen).foreach { close =>
val values = clause.values
if (
values.lengthCompare(cfg.minCount) >= 0 &&
Expand Down Expand Up @@ -1041,7 +1038,7 @@ class FormatOps(

val FT(open, r, _) = ft
val nft = next(ft)
val close = matching(open)
val close = matchingLeft(ft)
val beforeClose = prev(close)
val indentParam = Num(style.indent.getDefnSite(lpOwner))
val indentSep = Num((indentParam.n - 2).max(0))
Expand Down Expand Up @@ -1387,24 +1384,24 @@ class FormatOps(
case Some(x) => x
case None => findXmlLastLineIndent(prev(ft))
}
case t: T.Xml.SpliceEnd => findXmlLastLineIndent(prev(matching(t)))
case _: T.Xml.SpliceEnd => findXmlLastLineIndent(prev(matchingLeft(ft)))
case _ => findXmlLastLineIndent(prev(ft))
}

def withIndentOnXmlStart(tok: T.Xml.Start, splits: Seq[Split])(implicit
def withIndentOnXmlStart(xmlEnd: => FT, splits: Seq[Split])(implicit
style: ScalafmtConfig,
): Seq[Split] =
if (style.xmlLiterals.assumeFormatted) {
val end = matching(tok)
val end = xmlEnd
val indent = Num(findXmlLastLineIndent(prev(end)), true)
splits.map(_.withIndent(indent, end, ExpiresOn.After))
} else splits

def withIndentOnXmlSpliceStart(ft: FT, splits: Seq[Split])(implicit
style: ScalafmtConfig,
): Seq[Split] = ft.left match {
case t: T.Xml.SpliceStart if style.xmlLiterals.assumeFormatted =>
val end = matching(t)
case _: T.Xml.SpliceStart if style.xmlLiterals.assumeFormatted =>
val end = matchingLeft(ft)
val indent = Num(findXmlLastLineIndent(prev(ft)), true)
splits.map(_.withIndent(indent, end, ExpiresOn.After))
case _ => splits
Expand Down Expand Up @@ -1685,8 +1682,8 @@ class FormatOps(
ft: FT,
style: ScalafmtConfig,
): Split = withNLPolicy(endFt) {
val right = nextNonComment(ft).right
val rpOpt = if (right.is[T.LeftParen]) matchingOpt(right) else None
val nft = nextNonComment(ft)
val rpOpt = if (nft.right.is[T.LeftParen]) matchingOptRight(nft) else None
val expire = nextNonCommentSameLine(rpOpt.fold(endFt) { rp =>
if (rp.left.end >= endFt.left.end) rp else endFt
})
Expand Down Expand Up @@ -2117,7 +2114,7 @@ class FormatOps(
case t: Term.While => t.expr match {
case b: Term.Block
if isMultiStatBlock(b) &&
!matchingOpt(nft.right).exists(_.left.end >= b.pos.end) =>
!matchingOptRight(nft).exists(_.left.end >= b.pos.end) =>
Some(new OptionalBracesRegion {
def owner = Some(t)
def splits = Some {
Expand Down Expand Up @@ -2291,18 +2288,16 @@ class FormatOps(
nft: FT,
)(implicit style: ScalafmtConfig, ft: FT): Option[OptionalBracesRegion] =
ft.meta.leftOwner match {
case t: Term.If =>
val nr = nft.right
t.cond match {
case b: Term.Block if (matchingOpt(nr) match {
case t: Term.If => t.cond match {
case b: Term.Block if (matchingOptRight(nft) match {
case Some(t) => t.left.end < b.pos.end
case None => isMultiStatBlock(b)
}) =>
Some(new OptionalBracesRegion {
def owner = Some(t)
def splits = Some {
val dangle = style.danglingParentheses.ctrlSite
val forceNL = !nr.is[T.LeftParen]
val forceNL = !nft.right.is[T.LeftParen]
getSplits(b, forceNL, dangle)
}
def rightBrace = blockLast(b)
Expand Down Expand Up @@ -2411,7 +2406,7 @@ class FormatOps(
case _: T.KwThen => true
case _: T.LeftBrace => false
case _ => !isTreeSingleExpr(thenp) &&
(!before.right.is[T.LeftBrace] || matchingOpt(before.right)
(!before.right.is[T.LeftBrace] || matchingOptRight(before)
.exists(_.left.end < thenp.pos.end))
}
}
Expand Down Expand Up @@ -2651,8 +2646,8 @@ class FormatOps(
def getEndOfBlock(ft: FT, parensToo: => Boolean)(implicit
style: ScalafmtConfig,
): Option[FT] = ft.left match {
case x: T.LeftBrace => matchingOpt(x)
case x: T.LeftParen => if (parensToo) matchingOpt(x) else None
case _: T.LeftBrace => matchingOptLeft(ft)
case _: T.LeftParen => if (parensToo) matchingOptLeft(ft) else None
case _ => OptionalBraces.get(ft)
.flatMap(_.rightBrace.map(x => nextNonCommentSameLine(x)))
}
Expand Down Expand Up @@ -2803,7 +2798,7 @@ class FormatOps(
ftAfterClose.right.is[T.RightParen] && ftAfterClose.noBreak &&
isArgClauseSite(ftAfterClose.meta.rightOwner)
if (continue) {
val open = matching(ftAfterClose.right)
val open = matchingRight(ftAfterClose)
implicit val style: ScalafmtConfig = styleMap.at(open)
implicit val clauseSiteFlags: ClauseSiteFlags = ClauseSiteFlags
.atCallSite(ftAfterClose.meta.rightOwner)
Expand Down Expand Up @@ -2839,7 +2834,7 @@ class FormatOps(
@tailrec
def iter(currft: FT): Option[Policy] = {
val prevft = prevNonComment(currft)
val breakBeforeClose = matchingOpt(prevft.left) match {
val breakBeforeClose = matchingOptLeft(prevft) match {
case Some(open) =>
val cfg = styleMap.at(open)
def cfgStyle = cfg.configStyleCallSite.prefer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FT])
result.result()
}(arr)

private lazy val matchingParentheses: Map[TokenHash, FT] = TreeOps
.getMatchingParentheses(arr.view)(_.left)
private lazy val matchingParentheses: Map[Int, FT] = TreeOps
.getMatchingParentheses(arr.view)(_.idx)(_.left)

override def length: Int = arr.length
override def apply(idx: Int): FT = arr(idx)
Expand All @@ -39,11 +39,10 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FT])
else at(idx + 1)
}

private def get(tok: T, isBefore: Boolean): FT =
getAt(tok, isBefore)(leftTok2tok.getOrElse(
FormatTokens.thash(tok),
FormatTokens.throwNoToken(tok, "Missing token index"),
))
private def get(tok: T, isBefore: Boolean): FT = getAt(tok, isBefore)(
leftTok2tok
.getOrElse(hash(tok), FormatTokens.throwNoToken(tok, "Missing token index")),
)

def at(off: Int): FT =
if (off < 0) arr.head else if (off < arr.length) arr(off) else arr.last
Expand Down Expand Up @@ -77,27 +76,26 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FT])
def next(ft: FT): FT = apply(ft, 1)

@inline
def matching(token: T): FT = matchingParentheses.getOrElse(
FormatTokens.thash(token),
private def matching(idx: Int, token: T): FT = matchingParentheses.getOrElse(
idx,
FormatTokens.throwNoToken(token, "Missing matching token index"),
)
@inline
def matchingOpt(token: T): Option[FT] = matchingParentheses
.get(FormatTokens.thash(token))
def matchingLeft(ft: FT): FT = matching(ft.idx, ft.left)
@inline
def hasMatching(token: T): Boolean = matchingParentheses
.contains(FormatTokens.thash(token))
def matchingRight(ft: FT): FT = matching(ft.idx + 1, ft.right)
@inline
def areMatching(t1: T)(t2: => T): Boolean = matchingOpt(t1) match {
case Some(x) => x.left eq t2
case _ => false
}
private def matchingOpt(idx: Int): Option[FT] = matchingParentheses.get(idx)
@inline
def matchingOptLeft(ft: FT): Option[FT] = matchingOpt(ft.idx)
@inline
def matchingOptRight(ft: FT): Option[FT] = matchingOpt(ft.idx + 1)

def getHeadAndLastIfEnclosed(
tokens: Tokens,
tree: Tree,
): Option[(FT, Option[FT])] = getHeadOpt(tokens, tree).map { head =>
head -> matchingOpt(head.left).flatMap { other =>
head -> matchingOptLeft(head).flatMap { other =>
val last = getLastNonTrivial(tokens, tree)
if (last eq other) Some(last) else None
}
Expand Down Expand Up @@ -142,11 +140,11 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FT])
getClosingIfWithinParens(tree) != Left(false)

def getClosingIfWithinParens(last: FT)(head: FT): Either[Boolean, FT] = {
val innerMatched = areMatching(last.left)(head.left)
val innerMatched = matchingOptLeft(last).contains(head)
if (innerMatched && last.left.is[T.RightParen]) Right(prev(last))
else {
val afterLast = nextNonComment(last)
if (areMatching(afterLast.right)(prevNonCommentBefore(head).left))
if (matchingOptRight(afterLast).exists(_ eq prevNonCommentBefore(head)))
if (afterLast.right.is[T.RightParen]) Right(afterLast) else Left(true)
else Left(innerMatched)
}
Expand Down Expand Up @@ -409,7 +407,7 @@ object FormatTokens {
* Since tokens might be very large, we try to allocate as little memory as
* possible.
*/
def apply(tokens: Tokens, owner: T => Tree)(implicit
def apply(tokens: Tokens, owners: collection.Map[TokenHash, Tree])(implicit
style: ScalafmtConfig,
): (FormatTokens, StyleMap) = {
var left: T = null
Expand All @@ -421,7 +419,7 @@ object FormatTokens {
var fmtWasOff = false
val arr = tokens.toArray
def process(right: T): Unit = {
val rmeta = FT.TokenMeta(owner(right), right.text)
val rmeta = FT.TokenMeta(owners(hash(right)), right.text)
if (left eq null) fmtWasOff = isFormatOff(right)
else {
val between = arr.slice(wsIdx, tokIdx)
Expand Down Expand Up @@ -454,13 +452,10 @@ object FormatTokens {
s"$msg ${t.structure} @${t.pos.startLine}:${t.pos.startColumn}: `$t`",
)

@inline
def thash(token: T): TokenHash = hash(token)

class TokenToIndexMapBuilder {
private val builder = Map.newBuilder[TokenHash, Int]
def sizeHint(size: Int): Unit = builder.sizeHint(size)
def add(idx: Int)(token: T): Unit = builder += thash(token) -> idx
def add(idx: Int)(token: T): Unit = builder += hash(token) -> idx
def result(): Map[TokenHash, Int] = builder.result()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ class FormatWriter(formatOps: FormatOps) {
val loc = locations(idx)
val tok = loc.formatToken
tok.left match {
case rb: T.RightBrace // look for "foo { bar }"
case _: T.RightBrace // look for "foo { bar }"
if RedundantBraces.canRewriteWithParensOnRightBrace(tok) =>
val beg = matching(rb).idx
val beg = matchingLeft(tok).idx
val bloc = locations(beg)
val style = bloc.style
if (
Expand Down
Loading

0 comments on commit b7885f2

Please sign in to comment.