Skip to content

Commit

Permalink
Merge pull request #593 from soujiro32167/support-defaults
Browse files Browse the repository at this point in the history
Support default values for case class fields
  • Loading branch information
ayoub-benali authored Jun 11, 2024
2 parents be2d7c0 + f21a876 commit 13f7fc1
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ target/
.metals/
.vscode/
.bloop/
metals.sbt
metals.sbt
.idea/
4 changes: 4 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ lazy val generic = project
scalaSettings ++ Seq(
crossScalaVersions += scala3
),
// magnolia requires compilation with the -Yretain-trees flag to support case class field default values on Scala 3
Test / scalacOptions ++= (if (CrossVersion.partialVersion(scalaVersion.value).exists(_._1 == 3))
Seq("-Yretain-trees")
else Nil),
testSettings
)
.dependsOn(core % "compile->compile;test->test")
Expand Down
7 changes: 5 additions & 2 deletions modules/generic/src/main/scala-2/vulcan/generic/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ package object generic {
doc = param.annotations.collectFirst {
case AvroDoc(doc) => doc
},
default = (if (codec.schema.exists(_.isNullable) && nullDefaultField) Some(None)
else None).asInstanceOf[Option[param.PType]] // TODO: remove cast
default = param.default.orElse(
if (codec.schema.exists(_.isNullable) && nullDefaultField)
Some(None.asInstanceOf[param.PType]) // TODO: remove cast
else None
)
).widen
}
.map(caseClass.rawConstruct(_))
Expand Down
7 changes: 5 additions & 2 deletions modules/generic/src/main/scala-3/vulcan/generic/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ package object generic {
doc = param.annotations.collectFirst {
case AvroDoc(doc) => doc
},
default = (if (codec.schema.exists(_.isNullable) && nullDefaultField) Some(None)
else None).asInstanceOf[Option[param.PType]] // TODO: remove cast
default = param.default.orElse(
Option.when(codec.schema.exists(_.isNullable) && nullDefaultField)(
None.asInstanceOf[param.PType] // TODO: remove cast
)
)
).widen
}
.map(caseClass.rawConstruct(_))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2019-2023 OVO Energy Limited
*
* SPDX-License-Identifier: Apache-2.0
*/

package vulcan.generic

import examples.AvroRecordDefault._
import org.apache.avro.JsonProperties

final class AvroFieldDefaultSpec extends CodecBase {
describe("AvroFieldDefault") {
it("should create a schema with a default for a field") {
assert(Foo.codec.schema.exists(_.getField("a").defaultVal() == 1))
assert(Foo.codec.schema.exists(_.getField("b").defaultVal() == "foo"))
assert(Foo.codec.schema.exists(_.getField("c").defaultVal() == JsonProperties.NULL_VALUE))
}

it("should fail when annotating an Option") {
assertSchemaError[InvalidDefault2]
}

it("should succeed when annotating an enum first element") {
assert(HasSFirst.codec.schema.exists(_.getField("s").defaultVal() == "A"))
}

it("should succeed when annotating an enum second element") {
assert(HasSSecond.codec.schema.exists(_.getField("s").defaultVal() == "B"))
}

it("should succeed with the first member of a union") {
assertSchemaIs[HasUnion](
"""{"type":"record","name":"HasUnion","namespace":"vulcan.generic.examples.AvroRecordDefault","fields":[{"name":"u","type":[{"type":"record","name":"A","namespace":"vulcan.generic.examples.AvroRecordDefault.Union","fields":[{"name":"a","type":"int"}]},{"type":"record","name":"B","namespace":"vulcan.generic.examples.AvroRecordDefault.Union","fields":[{"name":"b","type":"string"}]}],"default":{"a":1}}]}"""
)
val result = unsafeDecode[HasUnion](unsafeEncode[Empty](Empty()))
assert(result == HasUnion(Union.A(1)))
}

it("should fail with the second member of a union") {
assertSchemaError[HasUnionSecond]
}
}
}
3 changes: 3 additions & 0 deletions modules/generic/src/test/scala/vulcan/generic/CodecBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class CodecBase extends AnyFunSpec with ScalaCheckPropertyChecks with EitherValu
)(implicit codec: Codec[A]): Assertion =
assert(codec.schema.swap.value.message == expectedErrorMessage)

def assertSchemaError[A](implicit codec: Codec[A]): Assertion =
assert(codec.schema.isLeft, codec.schema)

def assertDecodeError[A](
value: Any,
schema: Schema,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package vulcan.generic.examples

import vulcan.{AvroError, Codec}
import vulcan.generic._

object AvroRecordDefault {
sealed trait Enum extends Product {
self =>
def value: String = self.productPrefix
}

object Enum {
case object A extends Enum

case object B extends Enum

implicit val codec: Codec[Enum] = deriveEnum(
symbols = List(A.value, B.value),
encode = _.value,
decode = {
case "A" => Right(A)
case "B" => Right(B)
case other => Left(AvroError(s"Invalid S: $other"))
}
)
}

sealed trait Union

object Union {
case class A(a: Int) extends Union

case class B(b: String) extends Union

implicit val codec: Codec[Union] = Codec.derive
}

case class Foo(
a: Int = 1,
b: String = "foo",
c: Option[String] = None
)

object Foo {
implicit val codec: Codec[Foo] = Codec.derive
}

case class InvalidDefault2(
a: Option[String] = Some("foo")
)
object InvalidDefault2 {
implicit val codec: Codec[InvalidDefault2] = Codec.derive
}

case class HasSFirst(
s: Enum = Enum.A
)
object HasSFirst {
implicit val codec: Codec[HasSFirst] = Codec.derive
}

case class HasSSecond(
s: Enum = Enum.B
)
object HasSSecond {
implicit val codec: Codec[HasSSecond] = Codec.derive
}

case class HasUnion(
u: Union = Union.A(1)
)
object HasUnion {
implicit val codec: Codec[HasUnion] = Codec.derive
}

case class Empty()
object Empty {
implicit val codec: Codec[Empty] = Codec.derive
}

case class HasUnionSecond(
u: Union = Union.B("foo")
)
object HasUnionSecond {
implicit val codec: Codec[HasUnionSecond] = Codec.derive
}
}


2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version = 1.8.3
sbt.version = 1.9.8

0 comments on commit 13f7fc1

Please sign in to comment.