diff --git a/.gitignore b/.gitignore index 0b4d3394..35140cfc 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ target/ .metals/ .vscode/ .bloop/ -metals.sbt \ No newline at end of file +metals.sbt +.idea/ \ No newline at end of file diff --git a/build.sbt b/build.sbt index eaa66e08..be52a454 100644 --- a/build.sbt +++ b/build.sbt @@ -94,6 +94,10 @@ lazy val generic = project scalaSettings ++ Seq( crossScalaVersions += scala3 ), + // magnolia requires compilation with the -Yretain-trees flag to support case class field default values on Scala 3 + Test / scalacOptions ++= (if (CrossVersion.partialVersion(scalaVersion.value).exists(_._1 == 3)) + Seq("-Yretain-trees") + else Nil), testSettings ) .dependsOn(core % "compile->compile;test->test") diff --git a/modules/generic/src/main/scala-2/vulcan/generic/package.scala b/modules/generic/src/main/scala-2/vulcan/generic/package.scala index 27b5fe0e..09274bc2 100644 --- a/modules/generic/src/main/scala-2/vulcan/generic/package.scala +++ b/modules/generic/src/main/scala-2/vulcan/generic/package.scala @@ -97,8 +97,11 @@ package object generic { doc = param.annotations.collectFirst { case AvroDoc(doc) => doc }, - default = (if (codec.schema.exists(_.isNullable) && nullDefaultField) Some(None) - else None).asInstanceOf[Option[param.PType]] // TODO: remove cast + default = param.default.orElse( + if (codec.schema.exists(_.isNullable) && nullDefaultField) + Some(None.asInstanceOf[param.PType]) // TODO: remove cast + else None + ) ).widen } .map(caseClass.rawConstruct(_)) diff --git a/modules/generic/src/main/scala-3/vulcan/generic/package.scala b/modules/generic/src/main/scala-3/vulcan/generic/package.scala index ed59ef38..2649d3aa 100644 --- a/modules/generic/src/main/scala-3/vulcan/generic/package.scala +++ b/modules/generic/src/main/scala-3/vulcan/generic/package.scala @@ -68,8 +68,11 @@ package object generic { doc = param.annotations.collectFirst { case AvroDoc(doc) => doc }, - default = (if (codec.schema.exists(_.isNullable) && nullDefaultField) Some(None) - else None).asInstanceOf[Option[param.PType]] // TODO: remove cast + default = param.default.orElse( + Option.when(codec.schema.exists(_.isNullable) && nullDefaultField)( + None.asInstanceOf[param.PType] // TODO: remove cast + ) + ) ).widen } .map(caseClass.rawConstruct(_)) diff --git a/modules/generic/src/test/scala/vulcan/generic/AvroFieldDefaultSpec.scala b/modules/generic/src/test/scala/vulcan/generic/AvroFieldDefaultSpec.scala new file mode 100644 index 00000000..04a2b996 --- /dev/null +++ b/modules/generic/src/test/scala/vulcan/generic/AvroFieldDefaultSpec.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2019-2023 OVO Energy Limited + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package vulcan.generic + +import examples.AvroRecordDefault._ +import org.apache.avro.JsonProperties + +final class AvroFieldDefaultSpec extends CodecBase { + describe("AvroFieldDefault") { + it("should create a schema with a default for a field") { + assert(Foo.codec.schema.exists(_.getField("a").defaultVal() == 1)) + assert(Foo.codec.schema.exists(_.getField("b").defaultVal() == "foo")) + assert(Foo.codec.schema.exists(_.getField("c").defaultVal() == JsonProperties.NULL_VALUE)) + } + + it("should fail when annotating an Option") { + assertSchemaError[InvalidDefault2] + } + + it("should succeed when annotating an enum first element") { + assert(HasSFirst.codec.schema.exists(_.getField("s").defaultVal() == "A")) + } + + it("should succeed when annotating an enum second element") { + assert(HasSSecond.codec.schema.exists(_.getField("s").defaultVal() == "B")) + } + + it("should succeed with the first member of a union") { + assertSchemaIs[HasUnion]( + """{"type":"record","name":"HasUnion","namespace":"vulcan.generic.examples.AvroRecordDefault","fields":[{"name":"u","type":[{"type":"record","name":"A","namespace":"vulcan.generic.examples.AvroRecordDefault.Union","fields":[{"name":"a","type":"int"}]},{"type":"record","name":"B","namespace":"vulcan.generic.examples.AvroRecordDefault.Union","fields":[{"name":"b","type":"string"}]}],"default":{"a":1}}]}""" + ) + val result = unsafeDecode[HasUnion](unsafeEncode[Empty](Empty())) + assert(result == HasUnion(Union.A(1))) + } + + it("should fail with the second member of a union") { + assertSchemaError[HasUnionSecond] + } + } +} diff --git a/modules/generic/src/test/scala/vulcan/generic/CodecBase.scala b/modules/generic/src/test/scala/vulcan/generic/CodecBase.scala index 3b66e529..c96ad6db 100644 --- a/modules/generic/src/test/scala/vulcan/generic/CodecBase.scala +++ b/modules/generic/src/test/scala/vulcan/generic/CodecBase.scala @@ -51,6 +51,9 @@ class CodecBase extends AnyFunSpec with ScalaCheckPropertyChecks with EitherValu )(implicit codec: Codec[A]): Assertion = assert(codec.schema.swap.value.message == expectedErrorMessage) + def assertSchemaError[A](implicit codec: Codec[A]): Assertion = + assert(codec.schema.isLeft, codec.schema) + def assertDecodeError[A]( value: Any, schema: Schema, diff --git a/modules/generic/src/test/scala/vulcan/generic/examples/AvroRecordDefault.scala b/modules/generic/src/test/scala/vulcan/generic/examples/AvroRecordDefault.scala new file mode 100644 index 00000000..5bb29e3d --- /dev/null +++ b/modules/generic/src/test/scala/vulcan/generic/examples/AvroRecordDefault.scala @@ -0,0 +1,89 @@ +package vulcan.generic.examples + +import vulcan.{AvroError, Codec} +import vulcan.generic._ + +object AvroRecordDefault { + sealed trait Enum extends Product { + self => + def value: String = self.productPrefix + } + + object Enum { + case object A extends Enum + + case object B extends Enum + + implicit val codec: Codec[Enum] = deriveEnum( + symbols = List(A.value, B.value), + encode = _.value, + decode = { + case "A" => Right(A) + case "B" => Right(B) + case other => Left(AvroError(s"Invalid S: $other")) + } + ) + } + + sealed trait Union + + object Union { + case class A(a: Int) extends Union + + case class B(b: String) extends Union + + implicit val codec: Codec[Union] = Codec.derive + } + + case class Foo( + a: Int = 1, + b: String = "foo", + c: Option[String] = None + ) + + object Foo { + implicit val codec: Codec[Foo] = Codec.derive + } + + case class InvalidDefault2( + a: Option[String] = Some("foo") + ) + object InvalidDefault2 { + implicit val codec: Codec[InvalidDefault2] = Codec.derive + } + + case class HasSFirst( + s: Enum = Enum.A + ) + object HasSFirst { + implicit val codec: Codec[HasSFirst] = Codec.derive + } + + case class HasSSecond( + s: Enum = Enum.B + ) + object HasSSecond { + implicit val codec: Codec[HasSSecond] = Codec.derive + } + + case class HasUnion( + u: Union = Union.A(1) + ) + object HasUnion { + implicit val codec: Codec[HasUnion] = Codec.derive + } + + case class Empty() + object Empty { + implicit val codec: Codec[Empty] = Codec.derive + } + + case class HasUnionSecond( + u: Union = Union.B("foo") + ) + object HasUnionSecond { + implicit val codec: Codec[HasUnionSecond] = Codec.derive + } +} + + diff --git a/project/build.properties b/project/build.properties index ef3d2662..0aa5c39b 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version = 1.8.3 +sbt.version = 1.9.8