From b5b4f50ef8aef38ab19b92adf7cfd23b38c1b757 Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:00:46 -0700 Subject: [PATCH] RemoveScala3OptionalBraces: handle fewer braces In some cases, invoke this rule from redundant-braces, since that rule is guaranteed to run before remove-optional-braces. --- .../scalafmt/rewrite/RedundantBraces.scala | 31 ++++- .../rewrite/RemoveScala3OptionalBraces.scala | 114 +++++++++++++++++- .../test/resources/scala3/FewerBraces.stat | 90 ++++++-------- .../resources/scala3/FewerBraces_fold.stat | 74 ++++++------ .../resources/scala3/FewerBraces_keep.stat | 108 +++++++---------- .../resources/scala3/FewerBraces_unfold.stat | 76 +++++------- 6 files changed, 293 insertions(+), 200 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index 437b58ae0e..655e3628d1 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -88,7 +88,7 @@ class RedundantBraces(implicit val ftoks: FormatTokens) ): Option[(Replacement, Replacement)] = Option { ft.right match { case _: Token.RightBrace => onRightBrace(left) - case _: Token.RightParen => onRightParen(left) + case _: Token.RightParen => onRightParen(left, hasFormatOff) case _ => null } } @@ -142,14 +142,28 @@ class RedundantBraces(implicit val ftoks: FormatTokens) lpFunction.orElse(lpPartialFunction).orNull } - private def onRightParen(left: Replacement)(implicit + private def onRightParen(left: Replacement, hasFormatOff: Boolean)(implicit ft: FormatToken, session: Session, style: ScalafmtConfig ): (Replacement, Replacement) = left.how match { case ReplacementType.Remove => val resOpt = getRightBraceBeforeRightParen(false).map { rb => - // we'll use right brace later, when applying fewer-braces rewrite + ft.meta.rightOwner match { + case ac: Term.ArgClause => + ftoks.matchingOpt(rb.left).map(ftoks.justBefore).foreach { lb => + session.rule[RemoveScala3OptionalBraces].foreach { r => + session.getClaimed(lb.meta.idx).foreach { case (leftIdx, _) => + val repl = r.onLeftForArgClause(ac)(lb, left.style) + if (null ne repl) { + implicit val ft: FormatToken = ftoks.prev(rb) + repl.onRightAndClaim(hasFormatOff, leftIdx) + } + } + } + } + case _ => + } (left, removeToken) } resOpt.orNull @@ -170,7 +184,16 @@ class RedundantBraces(implicit val ftoks: FormatTokens) new Token.RightBrace(rb.input, rb.dialect, rb.start + 1) } } - replaceIfAfterRightBrace.orNull // don't know how to Replace + (ft.meta.rightOwner match { + case ac: Term.ArgClause => + session.rule[RemoveScala3OptionalBraces].flatMap { r => + val repl = r.onLeftForArgClause(ac)(left.ft, left.style) + if (repl eq null) None else repl.onRight(hasFormatOff) + } + case _ => None + }).getOrElse { + replaceIfAfterRightBrace.orNull // don't know how to Replace + } case _ => null } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala index 7400e5eb10..56c09f846d 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala @@ -1,5 +1,7 @@ package org.scalafmt.rewrite +import scala.reflect.ClassTag + import scala.meta._ import scala.meta.tokens.Token @@ -47,6 +49,18 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) if (t.parent.exists(_.is[Defn.Given])) removeToken else replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start)) + case t: Term.ArgClause => onLeftForArgClause(t) + case t: Term.PartialFunction => + t.parent match { + case Some(p: Term.ArgClause) if (p.tokens.head match { + case px: Token.LeftBrace => px eq x + case px: Token.LeftParen => + shouldRewriteArgClauseWithLeftParen[RedundantBraces](px) + case _ => false + }) => + onLeftForArgClause(p) + case _ => null + } case _: Term.For if allowOldSyntax || { val rbFt = ftoks(ftoks.matching(ft.right)) ftoks.nextNonComment(rbFt).right.is[Token.KwDo] @@ -82,9 +96,14 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) case _ => false } case _ => false + }) || (left.ft.right match { + case _: Token.Colon => !shouldRewriteColonOnRight(left) + case _ => false }) ft.right match { case _ if notOkToRewrite => None + case _: Token.RightParen if RewriteTrailingCommas.checkIfPrevious => + Some((left, removeToken)) case x: Token.RightBrace => val replacement = ft.meta.rightOwner match { case _: Term.For if allowOldSyntax && !nextFt.right.is[Token.KwDo] => @@ -96,9 +115,11 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) } } - private def onLeftForBlock( - tree: Term.Block - )(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = + private def onLeftForBlock(tree: Term.Block)(implicit + ft: FormatToken, + session: Session, + style: ScalafmtConfig + ): Replacement = tree.parent.fold(null: Replacement) { case t: Term.If => val ok = ftoks.prevNonComment(ft).left match { @@ -136,7 +157,94 @@ private class RemoveScala3OptionalBraces(implicit val ftoks: FormatTokens) else if (ftoks.prevNonComment(ft).left.is[Token.Equals]) removeToken else null case p: Tree.WithBody => if (p.body eq tree) removeToken else null + case p: Term.ArgClause => + p.tokens.head match { + case _: Token.LeftBrace => + onLeftForArgClause(p) + case px: Token.LeftParen + if shouldRewriteArgClauseWithLeftParen[RedundantParens](px) => + onLeftForArgClause(p) + case _ => null + } case _ => null } + private def shouldRewriteArgClauseWithLeftParen[A <: Rule]( + lp: Token + )(implicit ft: FormatToken, session: Session, tag: ClassTag[A]) = { + val prevFt = ftoks.prevNonComment(ft) + prevFt.left.eq(lp) && session + .claimedRule(prevFt.meta.idx - 1) + .exists(x => tag.runtimeClass.isInstance(x.rule)) + } + + private[rewrite] def onLeftForArgClause( + tree: Term.ArgClause + )(implicit ft: FormatToken, style: ScalafmtConfig): Replacement = { + val ok = style.dialect.allowFewerBraces && + style.rewrite.scala3.removeOptionalBraces.fewerBracesMaxSpan > 0 && + isSeqSingle(tree.values) + if (!ok) return null + + tree.parent match { + case Some(p: Term.Apply) if (p.parent match { + case Some(pp: Term.Apply) => pp.fun ne p + case _ => true + }) => + val x = ft.right // `{` or `(` + replaceToken(":")(new Token.Colon(x.input, x.dialect, x.start)) + case _ => null + } + } + + private def shouldRewriteColonOnRight(left: Replacement)(implicit + ft: FormatToken, + session: Session, + style: ScalafmtConfig + ): Boolean = { + val lft = left.ft + lft.meta.rightOwner match { + case t: Term.ArgClause => shouldRewriteArgClauseColonOnRight(t, lft) + case t @ (_: Term.Block | _: Term.PartialFunction) => + t.parent match { + case Some(p: Term.ArgClause) => + shouldRewriteArgClauseColonOnRight(p, lft) + case _ => false + } + case _ => true // template etc + } + } + + private def shouldRewriteArgClauseColonOnRight( + ac: Term.ArgClause, + lft: FormatToken + )(implicit + ft: FormatToken, + session: Session, + style: ScalafmtConfig + ): Boolean = ac.values match { + case arg :: Nil => + val begIdx = math.max(ftoks.getHead(arg).meta.idx - 1, lft.meta.idx + 1) + val endIdx = math.min(ftoks.getLast(arg).meta.idx, ft.meta.idx) + var span = 0 + val rob = style.rewrite.scala3.removeOptionalBraces + val maxStats = rob.fewerBracesMaxSpan + (begIdx until endIdx).foreach { idx => + val tokOpt = session.claimedRule(idx) match { + case Some(x) if x.ft.meta.idx == idx => + if (x.how == ReplacementType.Remove) None + else Some(x.ft.right) + case _ => + val tok = ftoks(idx).right + if (tok.is[Token.Whitespace]) None else Some(tok) + } + tokOpt.foreach { tok => + span += tok.end - tok.start + if (span > maxStats) return false // RETURNING!!! + } + } + span >= rob.fewerBracesMinSpan + case _ => false + } + } diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat index bb953e7608..ba53b17ef7 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces.stat @@ -1827,10 +1827,9 @@ foo .mtd1 { x + 1 } - .mtd2 { - x + 1 - x + 2 - } + .mtd2: + x + 1 + x + 2 .mtd3 { x + 1 x + 2 @@ -1861,10 +1860,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1902,10 +1900,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1945,10 +1942,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1988,10 +1984,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -2028,15 +2023,14 @@ object a: mtd1 { x => x + 1 } - + mtd2 { x => - x + 1 - x + 2 - } - + mtd3 { x => - x + 1 - x + 2 - x + 3 - } + + mtd2: x => + x + 1 + x + 2 + + mtd3 { x => + x + 1 + x + 2 + x + 3 + } <<< rewrite to fewer braces: func in parens and braces rewrite.rules = [RedundantBraces, RedundantParens] rewrite.scala3.removeOptionalBraces = { @@ -2069,10 +2063,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -2109,15 +2102,14 @@ object a: mtd1 { x => x + 1 } - + mtd2 { x => - x + 1 - x + 2 - } - + mtd3 { x => - x + 1 - x + 2 - x + 3 - } + + mtd2: x => + x + 1 + x + 2 + + mtd3 { x => + x + 1 + x + 2 + x + 3 + } <<< rewrite to fewer braces: partial func rewrite.scala3.removeOptionalBraces = { enabled = yes @@ -2143,10 +2135,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2184,10 +2175,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2225,11 +2215,10 @@ foo bar match case x => x + 1 } - .mtd2 { - bar match - case x => x + 1 - case y => y + 1 - } + .mtd2: + bar match + case x => x + 1 + case y => y + 1 .mtd3 { bar match case x => x + 1 @@ -2272,12 +2261,11 @@ foo def x = x + 3 } - .mtd2 { - x + 1 - def x = - x + 3 - x + 4 - } + .mtd2: + x + 1 + def x = + x + 3 + x + 4 .mtd3 { x + 1 def x = diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat index 6d0a0d991e..12222f1701 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces_fold.stat @@ -1598,10 +1598,10 @@ foo x + 3 } >>> -foo.mtd1 { x + 1 }.mtd2 { - x + 1 - x + 2 -}.mtd3 { +foo.mtd1 { x + 1 }.mtd2: + x + 1 + x + 2 +.mtd3 { x + 1 x + 2 x + 3 @@ -1627,10 +1627,10 @@ foo x + 3 } >>> -foo.mtd1 { x => x + 1 }.mtd2 { x => +foo.mtd1 { x => x + 1 }.mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1663,10 +1663,10 @@ foo } ) >>> -foo.mtd1(x => x + 1).mtd2 { x => +foo.mtd1(x => x + 1).mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1701,10 +1701,10 @@ foo }, ) >>> -foo.mtd1(x => x + 1).mtd2 { x => +foo.mtd1(x => x + 1).mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1739,10 +1739,10 @@ foo }, ) >>> -foo.mtd1 { x => x + 1 }.mtd2 { x => +foo.mtd1 { x => x + 1 }.mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1775,10 +1775,10 @@ object a { } >>> object a: - mtd1(x => x + 1) + mtd2 { x => + mtd1(x => x + 1) + mtd2: x => x + 1 x + 2 - } + mtd3 { x => + + mtd3 { x => x + 1 x + 2 x + 3 @@ -1811,10 +1811,10 @@ foo } ) >>> -foo.mtd1(x => x + 1).mtd2 { x => +foo.mtd1(x => x + 1).mtd2: x => x + 1 x + 2 -}.mtd3 { x => +.mtd3 { x => x + 1 x + 2 x + 3 @@ -1847,10 +1847,10 @@ object a { } >>> object a: - mtd1(x => x + 1) + mtd2 { x => + mtd1(x => x + 1) + mtd2: x => x + 1 x + 2 - } + mtd3 { x => + + mtd3 { x => x + 1 x + 2 x + 3 @@ -1876,10 +1876,10 @@ foo case z => z + 1 } >>> -foo.mtd1 { case x => x + 1 }.mtd2 { - case x => x + 1 - case y => y + 1 -}.mtd3 { +foo.mtd1 { case x => x + 1 }.mtd2: + case x => x + 1 + case y => y + 1 +.mtd3 { case x => x + 1 case y => y + 1 case z => z + 1 @@ -1912,10 +1912,10 @@ foo } ) >>> -foo.mtd1 { case x => x + 1 }.mtd2 { - case x => x + 1 - case y => y + 1 -}.mtd3 { +foo.mtd1 { case x => x + 1 }.mtd2: + case x => x + 1 + case y => y + 1 +.mtd3 { case x => x + 1 case y => y + 1 case z => z + 1 @@ -1950,11 +1950,11 @@ foo foo.mtd1 { bar match case x => x + 1 -}.mtd2 { - bar match - case x => x + 1 - case y => y + 1 -}.mtd3 { +}.mtd2: + bar match + case x => x + 1 + case y => y + 1 +.mtd3 { bar match case x => x + 1 case y => y + 1 @@ -1993,12 +1993,12 @@ foo foo.mtd1 { x + 1 def x = x + 3 -}.mtd2 { - x + 1 - def x = - x + 3 - x + 4 -}.mtd3 { +}.mtd2: + x + 1 + def x = + x + 3 + x + 4 +.mtd3 { x + 1 def x = x + 3 diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat index fa17a16dc1..50747c1a34 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces_keep.stat @@ -1795,10 +1795,9 @@ foo .mtd1 { x + 1 } - .mtd2 { - x + 1 - x + 2 - } + .mtd2: + x + 1 + x + 2 .mtd3 { x + 1 x + 2 @@ -1829,10 +1828,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1871,11 +1869,9 @@ foo x => x + 1 } - .mtd2 { - x => - x + 1 - x + 2 - } + .mtd2: x => + x + 1 + x + 2 .mtd3 { x => x + 1 @@ -1917,11 +1913,9 @@ foo x => x + 1 } - .mtd2 { - x => - x + 1 - x + 2 - } + .mtd2: x => + x + 1 + x + 2 .mtd3 { x => x + 1 @@ -1963,11 +1957,9 @@ foo x => x + 1 } - .mtd2 { - x => - x + 1 - x + 2 - } + .mtd2: x => + x + 1 + x + 2 .mtd3 { x => x + 1 @@ -2006,17 +1998,15 @@ object a: x => x + 1 } - + mtd2 { - x => - x + 1 - x + 2 - } - + mtd3 { - x => - x + 1 - x + 2 - x + 3 - } + + mtd2: x => + x + 1 + x + 2 + + mtd3 { + x => + x + 1 + x + 2 + x + 3 + } <<< rewrite to fewer braces: func in parens and braces rewrite.rules = [RedundantBraces, RedundantParens] rewrite.scala3.removeOptionalBraces = { @@ -2049,10 +2039,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -2089,15 +2078,14 @@ object a: mtd1 { x => x + 1 } - + mtd2 { x => - x + 1 - x + 2 - } - + mtd3 { x => - x + 1 - x + 2 - x + 3 - } + + mtd2: x => + x + 1 + x + 2 + + mtd3 { x => + x + 1 + x + 2 + x + 3 + } <<< rewrite to fewer braces: partial func rewrite.scala3.removeOptionalBraces = { enabled = yes @@ -2123,10 +2111,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2164,10 +2151,9 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => x + 1 - case y => y + 1 - } + .mtd2: + case x => x + 1 + case y => y + 1 .mtd3 { case x => x + 1 case y => y + 1 @@ -2205,11 +2191,10 @@ foo bar match case x => x + 1 } - .mtd2 { - bar match - case x => x + 1 - case y => y + 1 - } + .mtd2: + bar match + case x => x + 1 + case y => y + 1 .mtd3 { bar match case x => x + 1 @@ -2252,12 +2237,11 @@ foo def x = x + 3 } - .mtd2 { - x + 1 - def x = - x + 3 - x + 4 - } + .mtd2: + x + 1 + def x = + x + 3 + x + 4 .mtd3 { x + 1 def x = diff --git a/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat b/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat index 2ee28d9d31..7aa5fcc254 100644 --- a/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat +++ b/scalafmt-tests/src/test/resources/scala3/FewerBraces_unfold.stat @@ -1818,10 +1818,9 @@ foo .mtd1 { x + 1 } - .mtd2 { - x + 1 - x + 2 - } + .mtd2: + x + 1 + x + 2 .mtd3 { x + 1 x + 2 @@ -1852,10 +1851,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1893,10 +1891,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1936,10 +1933,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -1979,10 +1975,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -2019,10 +2014,10 @@ object a: mtd1 { x => x + 1 } + - mtd2 { x => + mtd2: x => x + 1 x + 2 - } + + + mtd3 { x => x + 1 x + 2 @@ -2060,10 +2055,9 @@ foo .mtd1 { x => x + 1 } - .mtd2 { x => + .mtd2: x => x + 1 x + 2 - } .mtd3 { x => x + 1 x + 2 @@ -2100,10 +2094,10 @@ object a: mtd1 { x => x + 1 } + - mtd2 { x => + mtd2: x => x + 1 x + 2 - } + + + mtd3 { x => x + 1 x + 2 @@ -2134,12 +2128,11 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => - x + 1 - case y => - y + 1 - } + .mtd2: + case x => + x + 1 + case y => + y + 1 .mtd3 { case x => x + 1 @@ -2180,12 +2173,11 @@ foo .mtd1 { case x => x + 1 } - .mtd2 { - case x => - x + 1 - case y => - y + 1 - } + .mtd2: + case x => + x + 1 + case y => + y + 1 .mtd3 { case x => x + 1 @@ -2227,13 +2219,12 @@ foo case x => x + 1 } - .mtd2 { - bar match - case x => - x + 1 - case y => - y + 1 - } + .mtd2: + bar match + case x => + x + 1 + case y => + y + 1 .mtd3 { bar match case x => @@ -2278,12 +2269,11 @@ foo x + 1 def x = x + 3 } - .mtd2 { - x + 1 - def x = - x + 3 - x + 4 - } + .mtd2: + x + 1 + def x = + x + 3 + x + 4 .mtd3 { x + 1 def x =