Skip to content

Commit

Permalink
RedundantBraces: improve check stat OK in parens
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Dec 20, 2024
1 parent b9d7b3b commit 932189d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,45 +153,43 @@ class FormatWriter(formatOps: FormatOps) {
while (0 <= idx) {
val loc = locations(idx)
val tok = loc.formatToken
tok.left match {
case _: T.RightBrace // look for "foo { bar }"
if RedundantBraces.canRewriteWithParensOnRightBrace(tok) =>
val beg = matchingLeft(tok).idx
val bloc = locations(beg)
val style = bloc.style
if (
style.rewrite.trailingCommas.isOptional &&
loc.leftLineId == bloc.leftLineId
) {
val state = bloc.state
val inParentheses = style.spaces.inParentheses
// remove space before "{"
if (0 != beg && state.prev.mod.length != 0) {
val prevState = state.prev
prevState.split = prevState.split.withMod(NoSplit)
locations(beg - 1).shift -= 1
}

// update "{"
bloc.replace = "("
if (!inParentheses && state.mod.length != 0) {
// remove space after "{"
state.split = state.split.withMod(NoSplit)
bloc.shift -= 1
}
if (tok.left.is[T.RightBrace]) { // look for "foo { bar }"
val beg = matchingLeft(tok).idx
val bloc = locations(beg)
implicit val style = bloc.style
if (
RedundantBraces.canRewriteWithParensOnRightBrace(tok) &&
style.rewrite.trailingCommas.isOptional &&
loc.leftLineId == bloc.leftLineId
) {
val state = bloc.state
val inParentheses = style.spaces.inParentheses
// remove space before "{"
if (0 != beg && state.prev.mod.length != 0) {
val prevState = state.prev
prevState.split = prevState.split.withMod(NoSplit)
locations(beg - 1).shift -= 1
}

val prevEndLoc = locations(idx - 1)
val prevEndState = prevEndLoc.state
if (!inParentheses && prevEndState.mod.length != 0) {
// remove space before "}"
prevEndState.split = prevEndState.split.withMod(NoSplit)
prevEndLoc.shift -= 1
}
// update "{"
bloc.replace = "("
if (!inParentheses && state.mod.length != 0) {
// remove space after "{"
state.split = state.split.withMod(NoSplit)
bloc.shift -= 1
}

// update "}"
loc.replace = ")"
val prevEndLoc = locations(idx - 1)
val prevEndState = prevEndLoc.state
if (!inParentheses && prevEndState.mod.length != 0) {
// remove space before "}"
prevEndState.split = prevEndState.split.withMod(NoSplit)
prevEndLoc.shift -= 1
}
case _ =>

// update "}"
loc.replace = ")"
}
}
idx -= 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,46 @@ object RedundantBraces extends Rewrite with FormatTokensRewrite.RuleFactory {
case _ => false
}

private def canRewriteBlockWithParens(b: Term.Block)(implicit
ftoks: FormatTokens,
): Boolean = getBlockSingleStat(b).exists(canRewriteStatWithParens)

private def canRewriteStatWithParens(
t: Stat,
)(implicit ftoks: FormatTokens): Boolean = t match {
case f: Term.FunctionTerm => canRewriteFuncWithParens(f)
case _: Term.Assign => false // disallowed in 2.13
case _: Defn => false
case _: Term.PartialFunction => false
case Term.Block(s :: Nil) if !ftoks.isEnclosedInMatching(t) =>
canRewriteStatWithParens(s)
case _ => true
def canRewriteStatWithParens(
stat: Tree,
)(implicit ftoks: FormatTokens): Boolean = {
@tailrec
def stripTopBlock(tree: Tree, singleStatOnly: Boolean): Option[Tree] =
tree match {
case Term.Block(s :: Nil) =>
val ko = (tree eq stat) && ftoks.tokenAfter(s).right.is[T.Semicolon]
if (ko) None else stripTopBlock(s, singleStatOnly)
case b: Term.Block
if b.stats.isEmpty || singleStatOnly ||
!ftoks.isEnclosedInBraces(b) => None
/* guard for statements requiring a wrapper block
* "foo { x => y; z }" can't become "foo(x => y; z)"
* "foo { x1 => x2 => y; z }" can't become "foo(x1 => x2 => y; z)"
*/
case t: Term.FunctionTerm =>
if (needParensAroundParams(t)) None
else stripTopBlock(t.body, singleStatOnly = t eq stat)
case _ => Some(tree)
}
@tailrec
def iter(trees: List[Tree]): Boolean = trees match {
case head :: rest => head match {
case _: Term.Repeated => iter(rest)
case _: Term.PartialFunction | _: Defn | _: Term.Assign => false
case b @ Term.Block(s :: ss) =>
if (ss.isEmpty) iter(s :: rest)
else ftoks.isEnclosedInBraces(b) && iter(rest)
case t: Term.If => iter(t.thenp :: t.elsep :: rest)
case t: Term.FunctionTerm if needParensAroundParams(t) => false
case t: Tree.WithBody => iter(t.body :: rest)
case t: Term.AnonymousFunction => iter(t.body :: rest)
case _ => iter(rest)
}
case _ => true
}
stripTopBlock(stat, singleStatOnly = true).exists(t => iter(t :: Nil))
}

/* guard for statements requiring a wrapper block
* "foo { x => y; z }" can't become "foo(x => y; z)"
* "foo { x1 => x2 => y; z }" can't become "foo(x1 => x2 => y; z)"
*/
@tailrec
private def canRewriteFuncWithParens(f: Term.FunctionTerm): Boolean =
!needParensAroundParams(f) &&
(getTreeSingleStat(f.body) match {
case Some(t: Term.FunctionTerm) => canRewriteFuncWithParens(t)
case Some(_: Defn) => false
case x => x.isDefined
})

private def checkApply(t: Tree): Boolean = t.parent match {
case Some(p @ Term.ArgClause(`t` :: Nil, _)) => isParentAnApply(p)
case _ => false
Expand All @@ -66,9 +77,9 @@ object RedundantBraces extends Rewrite with FormatTokensRewrite.RuleFactory {
ftoks: FormatTokens,
): Boolean = !ftoks.prevNonCommentBefore(rb).left.is[T.Semicolon] &&
(rb.meta.leftOwner match { // look for "foo { bar }"
case b: Term.Block => checkApply(b) && canRewriteBlockWithParens(b) &&
case b: Term.Block => checkApply(b) && canRewriteStatWithParens(b) &&
b.parent.exists(ftoks.getLast(_) eq rb)
case f: Term.FunctionTerm => checkApply(f) && canRewriteFuncWithParens(f)
case f: Term.FunctionTerm => checkApply(f) && canRewriteStatWithParens(f)
case t @ Term.ArgClause(arg :: Nil, _) => isParentAnApply(t) &&
ftoks.getDelimsIfEnclosed(t).exists(_._2 eq rb) &&
canRewriteStatWithParens(arg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract class CommunityIntellijScalaSuite(name: String)
class CommunityIntellijScala_2024_2_Suite
extends CommunityIntellijScalaSuite("intellij-scala-2024.2") {

override protected def totalStatesVisited: Option[Int] = Some(59358052)
override protected def totalStatesVisited: Option[Int] = Some(59357916)

override protected def builds = Seq(getBuild(
"2024.2.28",
Expand Down Expand Up @@ -52,7 +52,7 @@ class CommunityIntellijScala_2024_2_Suite
class CommunityIntellijScala_2024_3_Suite
extends CommunityIntellijScalaSuite("intellij-scala-2024.3") {

override protected def totalStatesVisited: Option[Int] = Some(59575207)
override protected def totalStatesVisited: Option[Int] = Some(59575073)

override protected def builds = Seq(getBuild(
"2024.3.4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract class CommunityScala2Suite(name: String)

class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") {

override protected def totalStatesVisited: Option[Int] = Some(42585159)
override protected def totalStatesVisited: Option[Int] = Some(42584993)

override protected def builds =
Seq(getBuild("v2.12.20", dialects.Scala212, 1277))
Expand All @@ -18,7 +18,7 @@ class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") {

class CommunityScala2_13Suite extends CommunityScala2Suite("scala-2.13") {

override protected def totalStatesVisited: Option[Int] = Some(53251428)
override protected def totalStatesVisited: Option[Int] = Some(53251209)

override protected def builds =
Seq(getBuild("v2.13.14", dialects.Scala213, 1287))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ abstract class CommunityScala3Suite(name: String)

class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") {

override protected def totalStatesVisited: Option[Int] = Some(39246112)
override protected def totalStatesVisited: Option[Int] = Some(39245552)

override protected def builds = Seq(getBuild("3.2.2", dialects.Scala32, 791))

}

class CommunityScala3_3Suite extends CommunityScala3Suite("scala-3.3") {

override protected def totalStatesVisited: Option[Int] = Some(42433011)
override protected def totalStatesVisited: Option[Int] = Some(42432284)

override protected def builds = Seq(getBuild("3.3.3", dialects.Scala33, 861))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ abstract class CommunitySparkSuite(name: String)

class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") {

override protected def totalStatesVisited: Option[Int] = Some(86406517)
override protected def totalStatesVisited: Option[Int] = Some(86405215)

override protected def builds = Seq(getBuild("v3.4.1", dialects.Scala213, 2585))

}

class CommunitySpark3_5Suite extends CommunitySparkSuite("spark-3.5") {

override protected def totalStatesVisited: Option[Int] = Some(91407413)
override protected def totalStatesVisited: Option[Int] = Some(91405983)

override protected def builds = Seq(getBuild("v3.5.3", dialects.Scala213, 2756))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions {
val explored = Debug.explored.get()
logger.debug(s"Total explored: $explored")
if (!onlyUnit && !onlyManual)
assertEquals(explored, 1209508, "total explored")
assertEquals(explored, 1209520, "total explored")
val results = debugResults.result()
// TODO(olafur) don't block printing out test results.
// I don't want to deal with scalaz's Tasks :'(
Expand Down

0 comments on commit 932189d

Please sign in to comment.