diff --git a/core/src/main/scala/scalafix/rewrite/Xor2Either.scala b/core/src/main/scala/scalafix/rewrite/Xor2Either.scala index e8e697c757..543a051ea6 100644 --- a/core/src/main/scala/scalafix/rewrite/Xor2Either.scala +++ b/core/src/main/scala/scalafix/rewrite/Xor2Either.scala @@ -1,28 +1,53 @@ package scalafix.rewrite +import scala.collection.immutable.Seq import scala.meta.parsers.Parse import scala.{meta => m} -import scalafix.util.Patch +import scalafix.util._ +import scala.meta._ + +class Desugared[T <: Tree: Parse](implicit semantic: SemanticApi) { + def unapply(original: T): Option[T] = semantic.desugared(original) +} case object Xor2Either extends Rewrite { override def rewrite(ast: m.Tree, ctx: RewriteCtx): Seq[Patch] = { - import scala.meta._ - val semantic = getSemanticApi(ctx) - class Desugared[T <: Tree: Parse] { - def unapply(original: T): Option[T] = semantic.desugared(original) - } + implicit val semanticApi: SemanticApi = getSemanticApi(ctx) object DType extends Desugared[Type] object DTerm extends Desugared[Term] - // NOTE. This approach is super inefficient, since we run semantic.desugar on - // every case for every node in the tree. Ideally, we first match on the - // syntax structure we want and then run semantic.desugar. - ast.collect { - case t @ DType(t"cats.data.Xor") => - Patch(t.tokens.head, t.tokens.last, s"Either") - case t @ DTerm(q"cats.data.Xor.Right.apply[..$_]") => - Patch(t.tokens.head, t.tokens.last, s"Right") - case t @ DTerm(q"cats.data.Xor.Left.apply[..$_]") => - Patch(t.tokens.head, t.tokens.last, s"Left") - } + + val typeChanger = new ChangeType(ast, ctx) + val methodChanger = new ChangeMethod(ast, ctx) + val importAdder = new AddImport(ast, ctx) + + //Create a sequence of type replacements + val replacementTypes = List( + ReplaceType(t"cats.data.XorT", t"cats.data.EitherT", "EitherT"), + ReplaceType(t"cats.data.Xor", t"scala.util.Either", "Either"), + ReplaceType(t"cats.data.Xor.Left", t"scala.util.Left", "Left"), + ReplaceType(t"cats.data.Xor.Right", t"scala.util.Either.Right", "Right") + ) + + //Add in some method replacements + val replacementTerms = List( + ReplaceTerm(t"cats.data.Xor.Right.apply", "Right"), + ReplaceTerm(t"cats.data.Xor.Left.apply", "Left") + ) + + //Then add needed imports. + //todo - derive this from patches created, types + //and terms replaced + //Only add if they are not already imported + val additionalImports = List( + "cats.data.EitherT", + "cats.implicits._", + "scala.util.Either" + ) + + val typeReplacements = typeChanger.gatherPatches(replacementTypes) + val termReplacements = methodChanger.gatherPatches(replacementTerms) + val addedImports = importAdder.gatherPatches(additionalImports) + + addedImports ++ typeReplacements ++ termReplacements } } diff --git a/core/src/main/scala/scalafix/util/AddImport.scala b/core/src/main/scala/scalafix/util/AddImport.scala new file mode 100644 index 0000000000..be526085d2 --- /dev/null +++ b/core/src/main/scala/scalafix/util/AddImport.scala @@ -0,0 +1,37 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scala.meta._ +import scalafix.rewrite._ + +class AddImport(ast: m.Tree, ctx: RewriteCtx)(implicit sApi: SemanticApi) { + val allImports = ast.collect { + case t @ q"import ..$importersnel" => t -> importersnel + } + + val firstImport = allImports.headOption + val firstImportFirstToken = firstImport.flatMap { + case (importStatement, _) => importStatement.tokens.headOption + } + val tokenBeforeFirstImport = firstImportFirstToken.flatMap { stopAt => + ast.tokens.takeWhile(_ != stopAt).lastOption + } + + //This is currently a very dumb implementation. + //It does no checking for existing imports and makes + //no attempt to consolidate imports + def addedImports(importString: String): Seq[Patch] = + tokenBeforeFirstImport + .map( + beginImportsLocation => + Patch + .insertAfter(beginImportsLocation, importString) + ) + .toList + + def gatherPatches(imports: Seq[String]): Seq[Patch] = { + val importStrings = imports.map("import " + _).mkString("\n", "\n", "\n") + addedImports(importStrings) + } +} diff --git a/core/src/main/scala/scalafix/util/AnyDiff.scala b/core/src/main/scala/scalafix/util/AnyDiff.scala new file mode 100644 index 0000000000..8ce724c1bc --- /dev/null +++ b/core/src/main/scala/scalafix/util/AnyDiff.scala @@ -0,0 +1,45 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.meta.Tree + +/** Helper class to create textual diff between two objects */ +case class AnyDiff(a: Any, b: Any) extends Exception { + override def toString: String = s"""$a != $b $mismatchClass""" + def detailed: String = compare(a, b) + + /** Best effort attempt to find a line number for scala.meta.Tree */ + def lineNumber: Int = + 1 + (a match { + case e: Tree => e.pos.start.line + case Some(t: Tree) => t.pos.start.line + case lst: Seq[_] => + lst match { + case (head: Tree) :: tail => head.pos.start.line + case _ => -2 + } + case _ => -2 + }) + def mismatchClass: String = + if (clsName(a) != clsName(b)) s"(${clsName(a)} != ${clsName(b)})" + else s"same class ${clsName(a)}" + + private def clsName(a: Any) = a.getClass.getName + + private def compare(a: Any, b: Any): String = + (a, b) match { + case (t1: Tree, t2: Tree) => + s"""$toString + |Syntax diff: + |${t1.syntax} + |${t2.syntax} + | + |Structure diff: + |${t1.structure} + |${t2.structure} + """.stripMargin + case (t1: Seq[_], t2: Seq[_]) => + t1.zip(t2).map { case (a, b) => compare(a, b) }.mkString + case _ => toString + } +} diff --git a/core/src/main/scala/scalafix/util/ChangeMethod.scala b/core/src/main/scala/scalafix/util/ChangeMethod.scala new file mode 100644 index 0000000000..d68d01267b --- /dev/null +++ b/core/src/main/scala/scalafix/util/ChangeMethod.scala @@ -0,0 +1,41 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scalafix.rewrite.{Desugared, RewriteCtx, SemanticApi} + +case class ReplaceTerm(original: m.Type, newString: String) + +class ChangeMethod(ast: m.Tree, ctx: RewriteCtx)(implicit sApi: SemanticApi) { + object DType extends Desugared[m.Type] + object DTerm extends Desugared[m.Term] + + //This isn't very satisfying. I am trying to match on a method, without including + //it's arguments + def partialTermMatch(rt: ReplaceTerm) + : PartialFunction[m.Tree, (scala.meta.Term, ReplaceTerm)] = { + case t @ DTerm(desugared) + if desugared.syntax.startsWith(rt.original.syntax) && + desugared.syntax.endsWith("]") => + t -> rt + } + + def partialTermMatches(replacementTerms: Seq[ReplaceTerm]) + : PartialFunction[m.Tree, (m.Term, ReplaceTerm)] = + replacementTerms.map(partialTermMatch).reduce(_ orElse _) + + def terms(ptm: PartialFunction[m.Tree, (m.Term, ReplaceTerm)]) = + ast.collect { + ptm + } + + def termReplacements(trms: Seq[(m.Term, ReplaceTerm)]): Seq[Patch] = + trms.map { + case (t, rt) => + Patch.replace(t, rt.newString) + } + + //Non-EmptyList? + def gatherPatches(tr: Seq[ReplaceTerm]): Seq[Patch] = + termReplacements(terms(partialTermMatches(tr))) +} diff --git a/core/src/main/scala/scalafix/util/ChangeType.scala b/core/src/main/scala/scalafix/util/ChangeType.scala new file mode 100644 index 0000000000..2919ba434a --- /dev/null +++ b/core/src/main/scala/scalafix/util/ChangeType.scala @@ -0,0 +1,43 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scalafix.rewrite._ + +//Provide a little structure to the replacements we will be performing +case class ReplaceType(original: m.Type, + replacement: m.Type, + newString: String) { + def toPatch(t: m.Type): Patch = Patch.replace(t, newString) +} + +class ChangeType(ast: m.Tree, ctx: RewriteCtx)(implicit sApi: SemanticApi) { + object DType extends Desugared[m.Type] + object DTerm extends Desugared[m.Term] + + def partialTypeMatch( + rt: ReplaceType): PartialFunction[m.Tree, (m.Type, ReplaceType)] = { + case t @ DType(desugared) + if StructurallyEqual(desugared, rt.original).isRight => + t -> rt + } + + def partialTypeMatches(replacementTypes: Seq[ReplaceType]) + : PartialFunction[m.Tree, (m.Type, ReplaceType)] = + replacementTypes.map(partialTypeMatch).reduce(_ orElse _) + + // NOTE. This approach is super inefficient, since we run semantic.desugar on + // every case for every node in the tree. Ideally, we first match on the + // syntax structure we want and then run semantic.desugar. + def tpes(ptm: PartialFunction[m.Tree, (m.Type, ReplaceType)]) + : Seq[(m.Type, ReplaceType)] = ast.collect { ptm } + + //This is unsafe, come up with something better + def typeReplacements(tpes: Seq[(m.Type, ReplaceType)]): Seq[Patch] = + tpes.map { + case (tree, rt) => rt.toPatch(tree) + } + + def gatherPatches(tr: Seq[ReplaceType]): Seq[Patch] = + typeReplacements(tpes(partialTypeMatches(tr))) +} diff --git a/core/src/main/scala/scalafix/util/Patch.scala b/core/src/main/scala/scalafix/util/Patch.scala index 6be15ae2a9..392c30b3d6 100644 --- a/core/src/main/scala/scalafix/util/Patch.scala +++ b/core/src/main/scala/scalafix/util/Patch.scala @@ -35,4 +35,24 @@ object Patch { .map(_.syntax) .mkString("") } + + def replace(token: Token, replacement: String): Patch = + Patch(token, token, replacement) + + def replace(tree: Tree, replacement: String): Patch = + Patch(tree.tokens.head, tree.tokens.last, replacement) + + def insertBefore(token: Token, toPrepend: String) = + replace(token, s"$toPrepend${token.syntax}") + + def insertBefore(tree: Tree, toPrepend: String): Patch = + replace(tree, s"$toPrepend${tree.syntax}") + + def insertAfter(token: Token, toAppend: String) = + replace(token, s"$toAppend${token.syntax}") + + def insertAfter(tree: Tree, toAppend: String): Patch = + replace(tree, s"${tree.syntax}$toAppend") + + def delete(tree: Tree): Patch = replace(tree, "") } diff --git a/core/src/main/scala/scalafix/util/StructurallyEqual.scala b/core/src/main/scala/scalafix/util/StructurallyEqual.scala new file mode 100644 index 0000000000..4fb1147d17 --- /dev/null +++ b/core/src/main/scala/scalafix/util/StructurallyEqual.scala @@ -0,0 +1,42 @@ +package scalafix.util + +import scala.collection.immutable.Seq + +object StructurallyEqual { + import scala.meta.Tree + + /** Test if two trees are structurally equal. + * @return Left(errorMessage with minimal diff) if trees are not structurally + * different, otherwise Right(Unit). To convert into exception with + * meaningful error message, + * val Right(_) = StructurallyEqual(a, b) + **/ + def apply(a: Tree, b: Tree): Either[AnyDiff, Unit] = { + def loop(x: Any, y: Any): Boolean = { + val ok: Boolean = (x, y) match { + case (x, y) if x == null || y == null => x == null && y == null + case (x: Some[_], y: Some[_]) => loop(x.get, y.get) + case (x: None.type, y: None.type) => true + case (xs: Seq[_], ys: Seq[_]) => + xs.length == ys.length && + xs.zip(ys).forall { + case (x, y) => loop(x, y) + } + case (x: Tree, y: Tree) => + def sameStructure = + x.productPrefix == y.productPrefix && + loop(x.productIterator.toList, y.productIterator.toList) + sameStructure + case _ => x == y + } + if (!ok) throw AnyDiff(x, y) + else true + } + try { + loop(a, b) + Right(Unit) + } catch { + case t: AnyDiff => Left(t) + } + } +} diff --git a/core/src/test/resources/Xor/basic.source b/core/src/test/resources/Xor/basic.source index 9cbd7683f1..c2d49b304c 100644 --- a/core/src/test/resources/Xor/basic.source +++ b/core/src/test/resources/Xor/basic.source @@ -1,15 +1,26 @@ rewrites = [Xor2Either] <<< xor 1 -import cats.data.Xor +import scala.concurrent.Future +import cats.data.{ Xor, XorT } trait A { - val r: Xor[Int, String] = Xor.Right("") - val s: Xor[Int, String] = Xor.Left(1 /* comment */) +type MyDisjunction = Xor[Int, String] + val r: MyDisjunction = Xor.Right.apply("") + val s: Xor[Int, String] = cats.data.Xor.Left(1 /* comment */) + val t: Xor[Int, String] = r.map(_ + "!") val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]] + val u: XorT[Future, Int, String] = ??? } >>> -import cats.data.Xor +import cats.data.EitherT +import cats.implicits._ +import scala.util.Either +import scala.concurrent.Future +import cats.data.{ Xor, XorT } trait A { - val r: Either[Int, String] = Right("") +type MyDisjunction = Either[Int, String] + val r: MyDisjunction = Right("") val s: Either[Int, String] = Left(1 /* comment */) + val t: Either[Int, String] = r.map(_ + "!") val nest: Seq[Either[Int, Either[String, Int]]] + val u: EitherT[Future, Int, String] = ??? } diff --git a/scalafix-nsc/src/test/scala/cats/data/Xor.scala b/scalafix-nsc/src/test/scala/cats/data/Xor.scala index 9395fbe29a..8e1e16cb39 100644 --- a/scalafix-nsc/src/test/scala/cats/data/Xor.scala +++ b/scalafix-nsc/src/test/scala/cats/data/Xor.scala @@ -1,6 +1,9 @@ package cats.data +import scala.language.higherKinds -sealed abstract class Xor[+A, +B] extends Product with Serializable +sealed abstract class Xor[+A, +B] extends Product with Serializable { + def map[C](f: B => C) = ??? +} object Xor { def left[A, B](a: A): A Xor B = Xor.Left(a) @@ -8,3 +11,7 @@ object Xor { final case class Left[+A](a: A) extends (A Xor Nothing) final case class Right[+B](b: B) extends (Nothing Xor B) } + +sealed abstract class XorT[F[_], A, B](value: F[A Xor B]) + +sealed abstract class EitherT[F[_], A, B](value: F[Either[A, B]]) diff --git a/scalafix-nsc/src/test/scala/cats/implicits/package.scala b/scalafix-nsc/src/test/scala/cats/implicits/package.scala new file mode 100644 index 0000000000..2cb979e2f9 --- /dev/null +++ b/scalafix-nsc/src/test/scala/cats/implicits/package.scala @@ -0,0 +1,9 @@ +package cats + +import scala.language.implicitConversions + +package object implicits { + implicit class EitherOps[A, B](from: Either[A, B]) { + def map[C](f: B => C): Either[A, C] = ??? + } +}