diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala index 225d104a9..5af6c3a5b 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala @@ -153,45 +153,43 @@ class FormatWriter(formatOps: FormatOps) { while (0 <= idx) { val loc = locations(idx) val tok = loc.formatToken - tok.left match { - case _: T.RightBrace // look for "foo { bar }" - if RedundantBraces.canRewriteWithParensOnRightBrace(tok) => - val beg = matchingLeft(tok).idx - val bloc = locations(beg) - val style = bloc.style - if ( - style.rewrite.trailingCommas.isOptional && - loc.leftLineId == bloc.leftLineId - ) { - val state = bloc.state - val inParentheses = style.spaces.inParentheses - // remove space before "{" - if (0 != beg && state.prev.mod.length != 0) { - val prevState = state.prev - prevState.split = prevState.split.withMod(NoSplit) - locations(beg - 1).shift -= 1 - } - - // update "{" - bloc.replace = "(" - if (!inParentheses && state.mod.length != 0) { - // remove space after "{" - state.split = state.split.withMod(NoSplit) - bloc.shift -= 1 - } + if (tok.left.is[T.RightBrace]) { // look for "foo { bar }" + val beg = matchingLeft(tok).idx + val bloc = locations(beg) + implicit val style = bloc.style + if ( + RedundantBraces.canRewriteWithParensOnRightBrace(tok) && + style.rewrite.trailingCommas.isOptional && + loc.leftLineId == bloc.leftLineId + ) { + val state = bloc.state + val inParentheses = style.spaces.inParentheses + // remove space before "{" + if (0 != beg && state.prev.mod.length != 0) { + val prevState = state.prev + prevState.split = prevState.split.withMod(NoSplit) + locations(beg - 1).shift -= 1 + } - val prevEndLoc = locations(idx - 1) - val prevEndState = prevEndLoc.state - if (!inParentheses && prevEndState.mod.length != 0) { - // remove space before "}" - prevEndState.split = prevEndState.split.withMod(NoSplit) - prevEndLoc.shift -= 1 - } + // update "{" + bloc.replace = "(" + if (!inParentheses && state.mod.length != 0) { + // remove space after "{" + state.split = state.split.withMod(NoSplit) + bloc.shift -= 1 + } - // update "}" - loc.replace = ")" + val prevEndLoc = locations(idx - 1) + val prevEndState = prevEndLoc.state + if (!inParentheses && prevEndState.mod.length != 0) { + // remove space before "}" + prevEndState.split = prevEndState.split.withMod(NoSplit) + prevEndLoc.shift -= 1 } - case _ => + + // update "}" + loc.replace = ")" + } } idx -= 1 } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index e1b77e9a1..480612765 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -28,35 +28,46 @@ object RedundantBraces extends Rewrite with FormatTokensRewrite.RuleFactory { case _ => false } - private def canRewriteBlockWithParens(b: Term.Block)(implicit - ftoks: FormatTokens, - ): Boolean = getBlockSingleStat(b).exists(canRewriteStatWithParens) - - private def canRewriteStatWithParens( - t: Stat, - )(implicit ftoks: FormatTokens): Boolean = t match { - case f: Term.FunctionTerm => canRewriteFuncWithParens(f) - case _: Term.Assign => false // disallowed in 2.13 - case _: Defn => false - case _: Term.PartialFunction => false - case Term.Block(s :: Nil) if !ftoks.isEnclosedInMatching(t) => - canRewriteStatWithParens(s) - case _ => true + def canRewriteStatWithParens( + stat: Tree, + )(implicit ftoks: FormatTokens): Boolean = { + @tailrec + def stripTopBlock(tree: Tree, singleStatOnly: Boolean): Option[Tree] = + tree match { + case Term.Block(s :: Nil) => + val ko = (tree eq stat) && ftoks.tokenAfter(s).right.is[T.Semicolon] + if (ko) None else stripTopBlock(s, singleStatOnly) + case b: Term.Block + if b.stats.isEmpty || singleStatOnly || + !ftoks.isEnclosedInBraces(b) => None + /* guard for statements requiring a wrapper block + * "foo { x => y; z }" can't become "foo(x => y; z)" + * "foo { x1 => x2 => y; z }" can't become "foo(x1 => x2 => y; z)" + */ + case t: Term.FunctionTerm => + if (needParensAroundParams(t)) None + else stripTopBlock(t.body, singleStatOnly = t eq stat) + case _ => Some(tree) + } + @tailrec + def iter(trees: List[Tree]): Boolean = trees match { + case head :: rest => head match { + case _: Term.Repeated => iter(rest) + case _: Term.PartialFunction | _: Defn | _: Term.Assign => false + case b @ Term.Block(s :: ss) => + if (ss.isEmpty) iter(s :: rest) + else ftoks.isEnclosedInBraces(b) && iter(rest) + case t: Term.If => iter(t.thenp :: t.elsep :: rest) + case t: Term.FunctionTerm if needParensAroundParams(t) => false + case t: Tree.WithBody => iter(t.body :: rest) + case t: Term.AnonymousFunction => iter(t.body :: rest) + case _ => iter(rest) + } + case _ => true + } + stripTopBlock(stat, singleStatOnly = true).exists(t => iter(t :: Nil)) } - /* guard for statements requiring a wrapper block - * "foo { x => y; z }" can't become "foo(x => y; z)" - * "foo { x1 => x2 => y; z }" can't become "foo(x1 => x2 => y; z)" - */ - @tailrec - private def canRewriteFuncWithParens(f: Term.FunctionTerm): Boolean = - !needParensAroundParams(f) && - (getTreeSingleStat(f.body) match { - case Some(t: Term.FunctionTerm) => canRewriteFuncWithParens(t) - case Some(_: Defn) => false - case x => x.isDefined - }) - private def checkApply(t: Tree): Boolean = t.parent match { case Some(p @ Term.ArgClause(`t` :: Nil, _)) => isParentAnApply(p) case _ => false @@ -66,9 +77,9 @@ object RedundantBraces extends Rewrite with FormatTokensRewrite.RuleFactory { ftoks: FormatTokens, ): Boolean = !ftoks.prevNonCommentBefore(rb).left.is[T.Semicolon] && (rb.meta.leftOwner match { // look for "foo { bar }" - case b: Term.Block => checkApply(b) && canRewriteBlockWithParens(b) && + case b: Term.Block => checkApply(b) && canRewriteStatWithParens(b) && b.parent.exists(ftoks.getLast(_) eq rb) - case f: Term.FunctionTerm => checkApply(f) && canRewriteFuncWithParens(f) + case f: Term.FunctionTerm => checkApply(f) && canRewriteStatWithParens(f) case t @ Term.ArgClause(arg :: Nil, _) => isParentAnApply(t) && ftoks.getDelimsIfEnclosed(t).exists(_._2 eq rb) && canRewriteStatWithParens(arg) diff --git a/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala b/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala index 9e2b11bc5..f7b7b0322 100644 --- a/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala +++ b/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala @@ -13,7 +13,7 @@ abstract class CommunityIntellijScalaSuite(name: String) class CommunityIntellijScala_2024_2_Suite extends CommunityIntellijScalaSuite("intellij-scala-2024.2") { - override protected def totalStatesVisited: Option[Int] = Some(59358052) + override protected def totalStatesVisited: Option[Int] = Some(59357916) override protected def builds = Seq(getBuild( "2024.2.28", @@ -52,7 +52,7 @@ class CommunityIntellijScala_2024_2_Suite class CommunityIntellijScala_2024_3_Suite extends CommunityIntellijScalaSuite("intellij-scala-2024.3") { - override protected def totalStatesVisited: Option[Int] = Some(59575207) + override protected def totalStatesVisited: Option[Int] = Some(59575073) override protected def builds = Seq(getBuild( "2024.3.4", diff --git a/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala b/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala index e8a1c7efd..20f27ce0b 100644 --- a/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala +++ b/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala @@ -9,7 +9,7 @@ abstract class CommunityScala2Suite(name: String) class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") { - override protected def totalStatesVisited: Option[Int] = Some(42585159) + override protected def totalStatesVisited: Option[Int] = Some(42584993) override protected def builds = Seq(getBuild("v2.12.20", dialects.Scala212, 1277)) @@ -18,7 +18,7 @@ class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") { class CommunityScala2_13Suite extends CommunityScala2Suite("scala-2.13") { - override protected def totalStatesVisited: Option[Int] = Some(53251428) + override protected def totalStatesVisited: Option[Int] = Some(53251209) override protected def builds = Seq(getBuild("v2.13.14", dialects.Scala213, 1287)) diff --git a/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala b/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala index 6b533cd1b..2cc2f4709 100644 --- a/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala +++ b/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala @@ -9,7 +9,7 @@ abstract class CommunityScala3Suite(name: String) class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") { - override protected def totalStatesVisited: Option[Int] = Some(39246112) + override protected def totalStatesVisited: Option[Int] = Some(39245552) override protected def builds = Seq(getBuild("3.2.2", dialects.Scala32, 791)) @@ -17,7 +17,7 @@ class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") { class CommunityScala3_3Suite extends CommunityScala3Suite("scala-3.3") { - override protected def totalStatesVisited: Option[Int] = Some(42433011) + override protected def totalStatesVisited: Option[Int] = Some(42432284) override protected def builds = Seq(getBuild("3.3.3", dialects.Scala33, 861)) diff --git a/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala b/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala index 67753e0f6..944368bd7 100644 --- a/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala +++ b/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala @@ -9,7 +9,7 @@ abstract class CommunitySparkSuite(name: String) class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") { - override protected def totalStatesVisited: Option[Int] = Some(86406517) + override protected def totalStatesVisited: Option[Int] = Some(86405215) override protected def builds = Seq(getBuild("v3.4.1", dialects.Scala213, 2585)) @@ -17,7 +17,7 @@ class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") { class CommunitySpark3_5Suite extends CommunitySparkSuite("spark-3.5") { - override protected def totalStatesVisited: Option[Int] = Some(91407413) + override protected def totalStatesVisited: Option[Int] = Some(91405983) override protected def builds = Seq(getBuild("v3.5.3", dialects.Scala213, 2756)) diff --git a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala index c92af9be1..5b55fc8c7 100644 --- a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { val explored = Debug.explored.get() logger.debug(s"Total explored: $explored") if (!onlyUnit && !onlyManual) - assertEquals(explored, 1209508, "total explored") + assertEquals(explored, 1209520, "total explored") val results = debugResults.result() // TODO(olafur) don't block printing out test results. // I don't want to deal with scalaz's Tasks :'(