Skip to content

Commit

Permalink
Scala 3 support (#11)
Browse files Browse the repository at this point in the history
Closes #2 

## Macros

Just like @KuceraMartin in #3, I ran into scala/scala3#19493 and
scala/scala3#19436 when trying to resolve a `TypeMapper` by importing
from a `DialectTypeMappers`. As a workaround, I introduced [additional
`implicit def`s in the `TableMapper` companion
object](https://github.com/com-lihaoyi/scalasql/blob/a7d6c531bf7b9cc2f5e2c175906d2a1e961de206/scalasql/core/src/TypeMapper.scala#L58-L121)
that instead rely on an implicit instance of `DialectTypeMappers`, i.e.
in a macro:

```scala
// bad, causes a compiler crash
// TableMacro.scala
(dialect: DialectTypeMappers) => {
  import dialect.*
  summonInline[TypeMapper[t]]
}

// good
// TypeMapper.scala
implicit def stringFromDialectTypeMappers(implicit d: DialectTypeMappers): TypeMapper[String] = d.StringType

// TableMacro.scala
(dialect: DialectTypeMappers) => {
  given d: DialectTypeMappers = dialect
  summonInline[TypeMapper[t]]
}
```

## Supporting changes

In addition to building out the macros in Scala 3, the following changes
were necessary:

1. Update the generated code to ensure `def`s aren't too far to the left
-- this is to silence Scala 3 warnings
2. Convert `CharSequence`s to `String`s explicitly -- see the [error the
Scala 3 compiler reported
here](9ffeb06)
3. Remove `try` block without a corresponding `catch` -- see the
[warning the Scala 3 compiler reported
here](011c3f6)
4. Add types to implicit definitions
5. Mark `renderSql` as `private[scalasql]` instead of `protected` -- see
the [error the Scala 3 compiler reported
here](8e767e3)
6. Use Scala 3.4 -- this is a little unfortunate since it's not the LTS
but it's necessary for the Scala 3 macros to [match on higher kinded
types like
this](https://github.com/com-lihaoyi/scalasql/blob/a7d6c531bf7b9cc2f5e2c175906d2a1e961de206/scalasql/query/src-3/TableMacro.scala#L48-L52).
This type of match doesn't work in Scala 3.3
7. Replace `_` wildcards with `?` -- this is to silence Scala 3 warnings
8. Replace `Foo with Bar` in types with `Foo & Bar` -- this is to
silence Scala 3 warnings
9. Add the `-Xsource:3` compiler option for Scala 2 -- this is necessary
to use the language features mentioned in points 7 and 8
10. Add a number of type annotations to method overrides -- this is to
silence warnings reported by the Scala 2 compiler as a result of
enabling `-Xsource:3`. All of the warnings relate to the inferred type
of the method changing between Scala 2 and 3
  • Loading branch information
mrdziuban authored May 25, 2024
1 parent 7318f00 commit 7cbc5cc
Show file tree
Hide file tree
Showing 50 changed files with 572 additions and 251 deletions.
9 changes: 7 additions & 2 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = "3.7.15"
version = "3.8.1"

align.preset = none
align.openParenCallSite = false
Expand All @@ -15,5 +15,10 @@ docstrings.wrap = no

maxColumn = 100
newlines.implicitParamListModifierPrefer = before
runner.dialect = scala213
runner.dialect = scala3

fileOverride {
"glob:**/src-2/**" {
runner.dialect = scala213source3
}
}
33 changes: 19 additions & 14 deletions build.sc
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import $file.docs.generateDocs
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.4.0`
import $ivy.`com.github.lolgab::mill-mima::0.1.0`
import $ivy.`com.goyeau::mill-scalafix::0.3.1`
import $ivy.`com.goyeau::mill-scalafix::0.4.0`
import de.tobiasroeser.mill.vcs.version.VcsVersion
import com.goyeau.mill.scalafix.ScalafixModule
import mill._, scalalib._, publish._

val scalaVersions = Seq("2.13.12"/*, "3.3.1"*/)
val scalaVersions = Seq("2.13.12", "3.4.2")

trait Common extends CrossScalaModule with PublishModule with ScalafixModule{
def scalaVersion = crossScalaVersion
Expand All @@ -27,12 +27,15 @@ trait Common extends CrossScalaModule with PublishModule with ScalafixModule{
)
)

def scalacOptions = Seq("-Xlint:unused")
def scalacOptions = T {
Seq("-Wunused:privates,locals,explicits,implicits,params") ++
Option.when(scalaVersion().startsWith("2."))("-Xsource:3")
}
}


object scalasql extends Cross[ScalaSql](scalaVersions)
trait ScalaSql extends Common{
trait ScalaSql extends Common{ common =>
def moduleDeps = Seq(query, operations)
def ivyDeps = Agg(
ivy"org.apache.logging.log4j:log4j-api:2.20.0",
Expand All @@ -45,7 +48,7 @@ trait ScalaSql extends Common{


object test extends ScalaTests with ScalafixModule{
def scalacOptions = Seq("-Xlint:unused")
def scalacOptions = common.scalacOptions
def ivyDeps = Agg(
ivy"com.github.vertical-blank:sql-formatter:2.0.4",
ivy"com.lihaoyi::mainargs:0.4.0",
Expand All @@ -66,6 +69,9 @@ trait ScalaSql extends Common{
def forkArgs = Seq("-Duser.timezone=Asia/Singapore")
}

private def indent(code: Iterable[String]): String =
code.map(_.split("\n").map(" " + _).mkString("\n")).mkString("\n")

object core extends Common with CrossValue {
def ivyDeps = Agg(
ivy"com.lihaoyi::geny:1.0.0",
Expand Down Expand Up @@ -100,7 +106,7 @@ trait ScalaSql extends Common{
s"""package scalasql.core.generated
|import scalasql.core.Queryable
|trait QueryableRow{
| ${queryableRowDefs.mkString("\n")}
|${indent(queryableRowDefs)}
|}
|""".stripMargin
)
Expand Down Expand Up @@ -147,7 +153,6 @@ trait ScalaSql extends Common{
| implicit
| ${commaSep(j => s"q$j: Queryable.Row[Q$j, R$j]")}
|): Queryable.Row[(${commaSep(j => s"Q$j")}), (${commaSep(j => s"R$j")})] = {
| import scalasql.core.SqlStr.SqlStringSyntax
| new Queryable.Row.TupleNQueryable(
| Seq(${commaSep(j => s"q$j.walkLabels()")}),
| t => Seq(${commaSep(j => s"q$j.walkExprs(t._$j)")}),
Expand All @@ -166,7 +171,7 @@ trait ScalaSql extends Common{
s"""
|implicit def append$i[$commaSepQ, QA, $commaSepR, RA](
| implicit qr0: Queryable.Row[($commaSepQ, QA), ($commaSepR, RA)],
| qr20: Queryable.Row[QA, RA]): $joinAppendType = new $joinAppendType {
| @annotation.nowarn("msg=never used") qr20: Queryable.Row[QA, RA]): $joinAppendType = new $joinAppendType {
| override def appendTuple(t: ($commaSepQ), v: QA): ($commaSepQ, QA) = (${commaSep(j => s"t._$j")}, v)
|
| def qr: Queryable.Row[($commaSepQ, QA), ($commaSepR, RA)] = qr0
Expand All @@ -179,23 +184,23 @@ trait ScalaSql extends Common{
|import scalasql.core.{Queryable, Expr}
|import scalasql.query.Column
|trait Insert[V[_[_]], R]{
| ${defs(false).mkString("\n")}
|${indent(defs(false))}
|}
|trait InsertImpl[V[_[_]], R] extends Insert[V, R]{ this: scalasql.query.Insert[V, R] =>
| def newInsertValues[R](
| insert: scalasql.query.Insert[V, R],
| columns: Seq[Column[_]],
| valuesLists: Seq[Seq[Expr[_]]]
| columns: Seq[Column[?]],
| valuesLists: Seq[Seq[Expr[?]]]
| )(implicit qr: Queryable[V[Column], R]): scalasql.query.InsertColumns[V, R]
| ${defs(true).mkString("\n")}
|${indent(defs(true))}
|}
|
|trait QueryableRow{
| ${queryableRowDefs.mkString("\n")}
|${indent(queryableRowDefs)}
|}
|
|trait JoinAppend extends scalasql.query.JoinAppendLowPriority{
| ${joinAppendDefs.mkString("\n")}
|${indent(joinAppendDefs)}
|}
|""".stripMargin
)
Expand Down
8 changes: 4 additions & 4 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ dbClient.transaction { implicit db =>
db.run(Purchase.delete(_.id <= 3)) ==> 3
db.run(Purchase.select.size) ==> 4

db.savepoint { sp =>
db.savepoint { _ =>
db.run(Purchase.delete(_ => true)) ==> 4
db.run(Purchase.select.size) ==> 0
}
Expand Down Expand Up @@ -499,7 +499,7 @@ dbClient.transaction { implicit db =>
db.run(Purchase.select.size) ==> 4

try {
db.savepoint { sp =>
db.savepoint { _ =>
db.run(Purchase.delete(_ => true)) ==> 4
db.run(Purchase.select.size) ==> 0
throw new FooException
Expand Down Expand Up @@ -533,11 +533,11 @@ dbClient.transaction { implicit db =>
db.run(Purchase.delete(_.id <= 2)) ==> 2
db.run(Purchase.select.size) ==> 5

db.savepoint { sp1 =>
db.savepoint { _ =>
db.run(Purchase.delete(_.id <= 4)) ==> 2
db.run(Purchase.select.size) ==> 3

db.savepoint { sp2 =>
db.savepoint { _ =>
db.run(Purchase.delete(_.id <= 6)) ==> 2
db.run(Purchase.select.size) ==> 1
}
Expand Down
54 changes: 26 additions & 28 deletions scalasql/core/src/DbApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ trait DbApi extends AutoCloseable {
* Runs the given [[SqlStr]] of the form `sql"..."` and returns a value of type [[R]]
*/
def runSql[R](query: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R]
Expand All @@ -48,7 +48,7 @@ trait DbApi extends AutoCloseable {
* arbitrary [[SqlStr]] of the form `sql"..."` and streams the results back to you
*/
def streamSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): Generator[R]
Expand All @@ -63,7 +63,7 @@ trait DbApi extends AutoCloseable {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R]
Expand All @@ -78,7 +78,7 @@ trait DbApi extends AutoCloseable {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): Generator[R]
Expand All @@ -104,7 +104,7 @@ trait DbApi extends AutoCloseable {
)(implicit fileName: sourcecode.FileName, lineNum: sourcecode.Line): Int

def updateGetGeneratedKeysSql[R](sql: SqlStr, fetchSize: Int = -1, queryTimeoutSeconds: Int = -1)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R]
Expand All @@ -115,7 +115,7 @@ trait DbApi extends AutoCloseable {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R]
Expand Down Expand Up @@ -202,22 +202,20 @@ object DbApi {
.asInstanceOf[R]
else if (qr.isExecuteUpdate(query)) updateSql(flattened).asInstanceOf[R]
else {
try {
val res = stream(query, fetchSize, queryTimeoutSeconds)(
qr.asInstanceOf[Queryable[Q, Seq[_]]],
fileName,
lineNum
val res = stream(query, fetchSize, queryTimeoutSeconds)(
qr.asInstanceOf[Queryable[Q, Seq[?]]],
fileName,
lineNum
)
if (qr.isSingleRow(query)) {
val results = res.take(2).toVector
assert(
results.size == 1,
s"Single row query must return 1 result, not ${results.size}"
)
if (qr.isSingleRow(query)) {
val results = res.take(2).toVector
assert(
results.size == 1,
s"Single row query must return 1 result, not ${results.size}"
)
results.head.asInstanceOf[R]
} else {
res.toVector.asInstanceOf[R]
}
results.head.asInstanceOf[R]
} else {
res.toVector.asInstanceOf[R]
}
}
}
Expand Down Expand Up @@ -248,7 +246,7 @@ object DbApi {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R] = streamSql(sql, fetchSize, queryTimeoutSeconds).toVector
Expand All @@ -258,7 +256,7 @@ object DbApi {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): Generator[R] = {
Expand Down Expand Up @@ -292,7 +290,7 @@ object DbApi {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R] = {
Expand All @@ -314,7 +312,7 @@ object DbApi {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R] = {
Expand All @@ -327,7 +325,7 @@ object DbApi {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): Generator[R] = {
Expand Down Expand Up @@ -362,7 +360,7 @@ object DbApi {
fetchSize: Int = -1,
queryTimeoutSeconds: Int = -1
)(
implicit qr: Queryable.Row[_, R],
implicit qr: Queryable.Row[?, R],
fileName: sourcecode.FileName,
lineNum: sourcecode.Line
): IndexedSeq[R] = runRawUpdateGetGeneratedKeys0(
Expand Down Expand Up @@ -466,7 +464,7 @@ object DbApi {
queryTimeoutSeconds: Int,
fileName: sourcecode.FileName,
lineNum: sourcecode.Line,
qr: Queryable.Row[_, R]
qr: Queryable.Row[?, R]
): IndexedSeq[R] = {
val statement = connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS)
for ((v, i) <- variables.iterator.zipWithIndex) v(statement, i + 1)
Expand Down
10 changes: 5 additions & 5 deletions scalasql/core/src/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import scalasql.core.SqlStr.SqlStringSyntax
* a Scala value of a particular type [[T]]
*/
trait Expr[T] extends SqlStr.Renderable {
protected final def renderSql(ctx: Context): SqlStr = {
private[scalasql] final def renderSql(ctx: Context): SqlStr = {
ctx.exprNaming.get(this.exprIdentity).getOrElse(renderToSql0(ctx))
}

Expand Down Expand Up @@ -37,15 +37,15 @@ object Expr {
def identity[T](e: Expr[T]): Identity = e.exprIdentity
class Identity()

implicit def ExprQueryable[E[_] <: Expr[_], T](
implicit def ExprQueryable[E[_] <: Expr[?], T](
implicit mt: TypeMapper[T]
): Queryable.Row[E[T], T] = new ExprQueryable[E, T]()

class ExprQueryable[E[_] <: Expr[_], T](
class ExprQueryable[E[_] <: Expr[?], T](
implicit tm: TypeMapper[T]
) extends Queryable.Row[E[T], T] {
def walkLabels() = Seq(Nil)
def walkExprs(q: E[T]) = Seq(q)
def walkLabels(): Seq[List[String]] = Seq(Nil)
def walkExprs(q: E[T]): Seq[Expr[?]] = Seq(q)

override def construct(args: Queryable.ResultSetIterator): T = args.get(tm)

Expand Down
2 changes: 1 addition & 1 deletion scalasql/core/src/ExprsToSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object ExprsToSql {
}
}

def booleanExprs(prefix: SqlStr, exprs: Seq[Expr[_]])(implicit ctx: Context) = {
def booleanExprs(prefix: SqlStr, exprs: Seq[Expr[?]])(implicit ctx: Context) = {
SqlStr.optSeq(exprs.filter(!Expr.isLiteralTrue(_))) { having =>
prefix + SqlStr.join(having.map(Renderable.renderSql(_)), sql" AND ")
}
Expand Down
8 changes: 4 additions & 4 deletions scalasql/core/src/JoinNullable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import scalasql.core.SqlStr.SqlStringSyntax
*/
trait JoinNullable[Q] {
def get: Q
def isEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, _]): Expr[Boolean]
def nonEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, _]): Expr[Boolean]
def isEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, ?]): Expr[Boolean]
def nonEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, ?]): Expr[Boolean]
def map[V](f: Q => V): JoinNullable[V]

}
Expand All @@ -19,11 +19,11 @@ object JoinNullable {

def apply[Q](t: Q): JoinNullable[Q] = new JoinNullable[Q] {
def get: Q = t
def isEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, _]): Expr[Boolean] = Expr {
def isEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, ?]): Expr[Boolean] = Expr {
implicit ctx =>
sql"(${f(t)} IS NULL)"
}
def nonEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, _]): Expr[Boolean] = Expr {
def nonEmpty[T](f: Q => Expr[T])(implicit qr: Queryable[Q, ?]): Expr[Boolean] = Expr {
implicit ctx =>
sql"(${f(t)} IS NOT NULL)"
}
Expand Down
Loading

0 comments on commit 7cbc5cc

Please sign in to comment.