Skip to content

Commit

Permalink
Pushing a few experiments for feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
ShaneDelmore committed Dec 21, 2016
1 parent 445b2cd commit b5aaec5
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 23 deletions.
59 changes: 42 additions & 17 deletions core/src/main/scala/scalafix/rewrite/Xor2Either.scala
Original file line number Diff line number Diff line change
@@ -1,28 +1,53 @@
package scalafix.rewrite

import scala.collection.immutable.Seq
import scala.meta.parsers.Parse
import scala.{meta => m}
import scalafix.util.Patch
import scalafix.util._
import scala.meta._

class Desugared[T <: Tree: Parse](implicit semantic: SemanticApi) {
def unapply(original: T): Option[T] = semantic.desugared(original)
}

case object Xor2Either extends Rewrite {
override def rewrite(ast: m.Tree, ctx: RewriteCtx): Seq[Patch] = {
import scala.meta._
val semantic = getSemanticApi(ctx)
class Desugared[T <: Tree: Parse] {
def unapply(original: T): Option[T] = semantic.desugared(original)
}
implicit val semanticApi: SemanticApi = getSemanticApi(ctx)
object DType extends Desugared[Type]
object DTerm extends Desugared[Term]
// NOTE. This approach is super inefficient, since we run semantic.desugar on
// every case for every node in the tree. Ideally, we first match on the
// syntax structure we want and then run semantic.desugar.
ast.collect {
case t @ DType(t"cats.data.Xor") =>
Patch(t.tokens.head, t.tokens.last, s"Either")
case t @ DTerm(q"cats.data.Xor.Right.apply[..$_]") =>
Patch(t.tokens.head, t.tokens.last, s"Right")
case t @ DTerm(q"cats.data.Xor.Left.apply[..$_]") =>
Patch(t.tokens.head, t.tokens.last, s"Left")
}

val typeChanger = new ChangeType(ast, ctx)
val methodChanger = new ChangeMethod(ast, ctx)
val importAdder = new AddImport(ast, ctx)

//Create a sequence of type replacements
val replacementTypes = List(
ReplaceType(t"cats.data.XorT", t"cats.data.EitherT", "EitherT"),
ReplaceType(t"cats.data.Xor", t"scala.util.Either", "Either"),
ReplaceType(t"cats.data.Xor.Left", t"scala.util.Left", "Left"),
ReplaceType(t"cats.data.Xor.Right", t"scala.util.Either.Right", "Right")
)

//Add in some method replacements
val replacementTerms = List(
ReplaceTerm(t"cats.data.Xor.Right.apply", "Right"),
ReplaceTerm(t"cats.data.Xor.Left.apply", "Left")
)

//Then add needed imports.
//todo - derive this from patches created, types
//and terms replaced
//Only add if they are not already imported
val additionalImports = List(
"cats.data.EitherT",
"cats.implicits._",
"scala.util.Either"
)

val typeReplacements = typeChanger.gatherPatches(replacementTypes)
val termReplacements = methodChanger.gatherPatches(replacementTerms)
val addedImports = importAdder.gatherPatches(additionalImports)

addedImports ++ typeReplacements ++ termReplacements
}
}
37 changes: 37 additions & 0 deletions core/src/main/scala/scalafix/util/AddImport.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.{meta => m}
import scala.meta._
import scalafix.rewrite._

class AddImport(ast: m.Tree, ctx: RewriteCtx)(implicit sApi: SemanticApi) {
val allImports = ast.collect {
case t @ q"import ..$importersnel" => t -> importersnel
}

val firstImport = allImports.headOption
val firstImportFirstToken = firstImport.flatMap {
case (importStatement, _) => importStatement.tokens.headOption
}
val tokenBeforeFirstImport = firstImportFirstToken.flatMap { stopAt =>
ast.tokens.takeWhile(_ != stopAt).lastOption
}

//This is currently a very dumb implementation.
//It does no checking for existing imports and makes
//no attempt to consolidate imports
def addedImports(importString: String): Seq[Patch] =
tokenBeforeFirstImport
.map(
beginImportsLocation =>
Patch
.insertAfter(beginImportsLocation, importString)
)
.toList

def gatherPatches(imports: Seq[String]): Seq[Patch] = {
val importStrings = imports.map("import " + _).mkString("\n", "\n", "\n")
addedImports(importStrings)
}
}
45 changes: 45 additions & 0 deletions core/src/main/scala/scalafix/util/AnyDiff.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.meta.Tree

/** Helper class to create textual diff between two objects */
case class AnyDiff(a: Any, b: Any) extends Exception {
override def toString: String = s"""$a != $b $mismatchClass"""
def detailed: String = compare(a, b)

/** Best effort attempt to find a line number for scala.meta.Tree */
def lineNumber: Int =
1 + (a match {
case e: Tree => e.pos.start.line
case Some(t: Tree) => t.pos.start.line
case lst: Seq[_] =>
lst match {
case (head: Tree) :: tail => head.pos.start.line
case _ => -2
}
case _ => -2
})
def mismatchClass: String =
if (clsName(a) != clsName(b)) s"(${clsName(a)} != ${clsName(b)})"
else s"same class ${clsName(a)}"

private def clsName(a: Any) = a.getClass.getName

private def compare(a: Any, b: Any): String =
(a, b) match {
case (t1: Tree, t2: Tree) =>
s"""$toString
|Syntax diff:
|${t1.syntax}
|${t2.syntax}
|
|Structure diff:
|${t1.structure}
|${t2.structure}
""".stripMargin
case (t1: Seq[_], t2: Seq[_]) =>
t1.zip(t2).map { case (a, b) => compare(a, b) }.mkString
case _ => toString
}
}
41 changes: 41 additions & 0 deletions core/src/main/scala/scalafix/util/ChangeMethod.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.{meta => m}
import scalafix.rewrite.{Desugared, RewriteCtx, SemanticApi}

case class ReplaceTerm(original: m.Type, newString: String)

class ChangeMethod(ast: m.Tree, ctx: RewriteCtx)(implicit sApi: SemanticApi) {
object DType extends Desugared[m.Type]
object DTerm extends Desugared[m.Term]

//This isn't very satisfying. I am trying to match on a method, without including
//it's arguments
def partialTermMatch(rt: ReplaceTerm)
: PartialFunction[m.Tree, (scala.meta.Term, ReplaceTerm)] = {
case t @ DTerm(desugared)
if desugared.syntax.startsWith(rt.original.syntax) &&
desugared.syntax.endsWith("]") =>
t -> rt
}

def partialTermMatches(replacementTerms: Seq[ReplaceTerm])
: PartialFunction[m.Tree, (m.Term, ReplaceTerm)] =
replacementTerms.map(partialTermMatch).reduce(_ orElse _)

def terms(ptm: PartialFunction[m.Tree, (m.Term, ReplaceTerm)]) =
ast.collect {
ptm
}

def termReplacements(trms: Seq[(m.Term, ReplaceTerm)]): Seq[Patch] =
trms.map {
case (t, rt) =>
Patch.replace(t, rt.newString)
}

//Non-EmptyList?
def gatherPatches(tr: Seq[ReplaceTerm]): Seq[Patch] =
termReplacements(terms(partialTermMatches(tr)))
}
43 changes: 43 additions & 0 deletions core/src/main/scala/scalafix/util/ChangeType.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.{meta => m}
import scalafix.rewrite._

//Provide a little structure to the replacements we will be performing
case class ReplaceType(original: m.Type,
replacement: m.Type,
newString: String) {
def toPatch(t: m.Type): Patch = Patch.replace(t, newString)
}

class ChangeType(ast: m.Tree, ctx: RewriteCtx)(implicit sApi: SemanticApi) {
object DType extends Desugared[m.Type]
object DTerm extends Desugared[m.Term]

def partialTypeMatch(
rt: ReplaceType): PartialFunction[m.Tree, (m.Type, ReplaceType)] = {
case t @ DType(desugared)
if StructurallyEqual(desugared, rt.original).isRight =>
t -> rt
}

def partialTypeMatches(replacementTypes: Seq[ReplaceType])
: PartialFunction[m.Tree, (m.Type, ReplaceType)] =
replacementTypes.map(partialTypeMatch).reduce(_ orElse _)

// NOTE. This approach is super inefficient, since we run semantic.desugar on
// every case for every node in the tree. Ideally, we first match on the
// syntax structure we want and then run semantic.desugar.
def tpes(ptm: PartialFunction[m.Tree, (m.Type, ReplaceType)])
: Seq[(m.Type, ReplaceType)] = ast.collect { ptm }

//This is unsafe, come up with something better
def typeReplacements(tpes: Seq[(m.Type, ReplaceType)]): Seq[Patch] =
tpes.map {
case (tree, rt) => rt.toPatch(tree)
}

def gatherPatches(tr: Seq[ReplaceType]): Seq[Patch] =
typeReplacements(tpes(partialTypeMatches(tr)))
}
20 changes: 20 additions & 0 deletions core/src/main/scala/scalafix/util/Patch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,24 @@ object Patch {
.map(_.syntax)
.mkString("")
}

def replace(token: Token, replacement: String): Patch =
Patch(token, token, replacement)

def replace(tree: Tree, replacement: String): Patch =
Patch(tree.tokens.head, tree.tokens.last, replacement)

def insertBefore(token: Token, toPrepend: String) =
replace(token, s"$toPrepend${token.syntax}")

def insertBefore(tree: Tree, toPrepend: String): Patch =
replace(tree, s"$toPrepend${tree.syntax}")

def insertAfter(token: Token, toAppend: String) =
replace(token, s"$toAppend${token.syntax}")

def insertAfter(tree: Tree, toAppend: String): Patch =
replace(tree, s"${tree.syntax}$toAppend")

def delete(tree: Tree): Patch = replace(tree, "")
}
42 changes: 42 additions & 0 deletions core/src/main/scala/scalafix/util/StructurallyEqual.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package scalafix.util

import scala.collection.immutable.Seq

object StructurallyEqual {
import scala.meta.Tree

/** Test if two trees are structurally equal.
* @return Left(errorMessage with minimal diff) if trees are not structurally
* different, otherwise Right(Unit). To convert into exception with
* meaningful error message,
* val Right(_) = StructurallyEqual(a, b)
**/
def apply(a: Tree, b: Tree): Either[AnyDiff, Unit] = {
def loop(x: Any, y: Any): Boolean = {
val ok: Boolean = (x, y) match {
case (x, y) if x == null || y == null => x == null && y == null
case (x: Some[_], y: Some[_]) => loop(x.get, y.get)
case (x: None.type, y: None.type) => true
case (xs: Seq[_], ys: Seq[_]) =>
xs.length == ys.length &&
xs.zip(ys).forall {
case (x, y) => loop(x, y)
}
case (x: Tree, y: Tree) =>
def sameStructure =
x.productPrefix == y.productPrefix &&
loop(x.productIterator.toList, y.productIterator.toList)
sameStructure
case _ => x == y
}
if (!ok) throw AnyDiff(x, y)
else true
}
try {
loop(a, b)
Right(Unit)
} catch {
case t: AnyDiff => Left(t)
}
}
}
21 changes: 16 additions & 5 deletions core/src/test/resources/Xor/basic.source
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
rewrites = [Xor2Either]
<<< xor 1
import cats.data.Xor
import scala.concurrent.Future
import cats.data.{ Xor, XorT }
trait A {
val r: Xor[Int, String] = Xor.Right("")
val s: Xor[Int, String] = Xor.Left(1 /* comment */)
type MyDisjunction = Xor[Int, String]
val r: MyDisjunction = Xor.Right.apply("")
val s: Xor[Int, String] = cats.data.Xor.Left(1 /* comment */)
val t: Xor[Int, String] = r.map(_ + "!")
val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]]
val u: XorT[Future, Int, String] = ???
}
>>>
import cats.data.Xor
import cats.data.EitherT
import cats.implicits._
import scala.util.Either
import scala.concurrent.Future
import cats.data.{ Xor, XorT }
trait A {
val r: Either[Int, String] = Right("")
type MyDisjunction = Either[Int, String]
val r: MyDisjunction = Right("")
val s: Either[Int, String] = Left(1 /* comment */)
val t: Either[Int, String] = r.map(_ + "!")
val nest: Seq[Either[Int, Either[String, Int]]]
val u: EitherT[Future, Int, String] = ???
}
9 changes: 8 additions & 1 deletion scalafix-nsc/src/test/scala/cats/data/Xor.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
package cats.data
import scala.language.higherKinds

sealed abstract class Xor[+A, +B] extends Product with Serializable
sealed abstract class Xor[+A, +B] extends Product with Serializable {
def map[C](f: B => C) = ???
}

object Xor {
def left[A, B](a: A): A Xor B = Xor.Left(a)
def right[A, B](b: B): A Xor B = Xor.Right(b)
final case class Left[+A](a: A) extends (A Xor Nothing)
final case class Right[+B](b: B) extends (Nothing Xor B)
}

sealed abstract class XorT[F[_], A, B](value: F[A Xor B])

sealed abstract class EitherT[F[_], A, B](value: F[Either[A, B]])
9 changes: 9 additions & 0 deletions scalafix-nsc/src/test/scala/cats/implicits/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package cats

import scala.language.implicitConversions

package object implicits {
implicit class EitherOps[A, B](from: Either[A, B]) {
def map[C](f: B => C): Either[A, C] = ???
}
}

0 comments on commit b5aaec5

Please sign in to comment.