diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala index 8317e87840..eda6573f4f 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala @@ -9,6 +9,7 @@ import org.scalafmt.sysops.OsSpecific import java.io.InputStream import java.io.PrintStream +import java.io.PrintWriter import java.nio.file.Files import java.nio.file.NoSuchFileException import java.nio.file.Path @@ -35,23 +36,26 @@ object CliOptions { * directly from main. */ def auto(parsed: CliOptions): CliOptions = { - val usesOut = parsed.stdIn || parsed.writeMode.usesOut - val auxOut = - if (parsed.noStdErr || !usesOut) parsed.common.out else parsed.common.err - - parsed.copy(common = - parsed.common.copy( - out = guardPrintStream(parsed.quiet && !usesOut)(parsed.common.out), - info = guardPrintStream(parsed.quiet || usesOut)(auxOut), - debug = guardPrintStream(parsed.quiet)( - if (parsed.debug) auxOut else parsed.common.debug, - ), - err = guardPrintStream(parsed.quiet)(parsed.common.err), - ), + val info: Output.StreamOrWriter = + if (parsed.quiet) Output.NoopStream + else { + val usesOut = parsed.stdIn || parsed.writeMode.usesOut + new Output.StreamOrWriter.Stream( + if (parsed.noStdErr || !usesOut) parsed.common.out + else parsed.common.err, + ) + } + val common = parsed.common.copy( + out = guardPrintStream(parsed.quiet && !parsed.stdIn)(parsed.common.out), + info = info, + debug = (if (parsed.debug) info else Output.NoopStream).printWriter, + err = guardPrintStream(parsed.quiet)(parsed.common.err), ) + + parsed.copy(common = common) } - private def guardPrintStream(p: => Boolean)( + private def guardPrintStream(p: Boolean)( candidate: PrintStream, ): PrintStream = if (p) Output.NoopStream.printStream else candidate @@ -62,8 +66,8 @@ case class CommonOptions( out: PrintStream = System.out, in: InputStream = System.in, err: PrintStream = System.err, - debug: PrintStream = Output.NoopStream.printStream, - info: PrintStream = Output.NoopStream.printStream, + debug: PrintWriter = Output.NoopStream.printWriter, + info: Output.StreamOrWriter = Output.NoopStream, ) { private[cli] lazy val workingDirectory: AbsoluteFile = cwd .getOrElse(AbsoluteFile.userDir) diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala index 68f403fbda..1834377651 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala @@ -4,7 +4,7 @@ import java.io._ object Output { - object NoopStream extends OutputStream { + object NoopStream extends OutputStream with StreamOrWriter { self => override def write(b: Int): Unit = () @@ -15,7 +15,20 @@ object Output { def outputStream: OutputStream = self val printStream = new PrintStream(self) val printWriter = new PrintWriter(self) - val streamWriter = new OutputStreamWriter(self) + } + + trait StreamOrWriter { + def outputStream: OutputStream + def printStream: PrintStream + def printWriter: PrintWriter + } + + object StreamOrWriter { + class Stream(val obj: PrintStream) extends StreamOrWriter { + override def outputStream: OutputStream = obj + override def printStream: PrintStream = obj + override def printWriter: PrintWriter = new PrintWriter(obj) + } } } diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala index dfb0898d6c..a0abc94b62 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala @@ -41,11 +41,10 @@ class ScalafmtCliReporter(options: CliOptions) extends ScalafmtReporter { override def parsedConfig(config: Path, scalafmtVersion: String): Unit = options.common.debug.println(s"parsed config (v$scalafmtVersion): $config") - override def downloadWriter(): PrintWriter = - new PrintWriter(options.common.info) + override def downloadWriter(): PrintWriter = options.common.info.printWriter override def downloadOutputStreamWriter(): OutputStreamWriter = - new OutputStreamWriter(options.common.info) + new OutputStreamWriter(options.common.info.outputStream) } private class FailedToFormat(filename: String, cause: Throwable) diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala index b5ec5845ab..f75677d59d 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala @@ -4,7 +4,6 @@ import org.scalafmt.Error import org.scalafmt.sysops.AbsoluteFile import org.scalafmt.sysops.BatchPathFinder -import java.io.OutputStreamWriter import java.nio.file.Path trait ScalafmtRunner { @@ -19,7 +18,7 @@ trait ScalafmtRunner { msg: String, ): TermDisplay = { val termDisplay = new TermDisplay( - new OutputStreamWriter(options.common.info), + options.common.info.printWriter, fallbackMode = options.nonInteractive || TermDisplay.defaultFallbackMode, ) if ( diff --git a/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala b/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala index b7caeda755..0d5a1f41bc 100644 --- a/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala +++ b/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala @@ -108,13 +108,17 @@ class CliOptionsTest extends FunSuite { test("write info to out if not writing to stdout") { val options = Cli.getConfig(Array.empty[String], baseCliOptionsWithOut).get - assertEquals(options.common.info, System.out) + assertEquals(options.common.info.printStream, System.out) } Seq("--stdin", "--stdout").foreach { arg => test(s"don't write info when using $arg") { val options = Cli.getConfig(Array(arg), baseCliOptionsWithOut).get - assertEquals(options.common.info, Output.NoopStream.printStream) + options.common.info match { + case x: Output.StreamOrWriter.Stream + if x.obj eq Output.NoopStream.printStream => + case x => fail(s"info should be writing to NoopStream: $x") + } } }