From 9a7e801b9d571567b2975fae44b34ecb84ae1656 Mon Sep 17 00:00:00 2001 From: Viktor Lovgren Date: Tue, 12 May 2020 08:29:25 +0200 Subject: [PATCH] Change union decoding to accept schemas in union --- .../core/src/main/scala/vulcan/Codec.scala | 100 +++++----- .../src/test/scala/vulcan/CodecSpec.scala | 26 ++- .../main/scala/vulcan/generic/package.scala | 176 ++++++++---------- .../examples/SealedTraitCaseClass.scala | 5 + .../test/scala/vulcan/generic/CodecSpec.scala | 37 +++- 5 files changed, 180 insertions(+), 164 deletions(-) diff --git a/modules/core/src/main/scala/vulcan/Codec.scala b/modules/core/src/main/scala/vulcan/Codec.scala index 61d2e583..4f92f148 100644 --- a/modules/core/src/main/scala/vulcan/Codec.scala +++ b/modules/core/src/main/scala/vulcan/Codec.scala @@ -1445,64 +1445,54 @@ final object Codec { Left(AvroError.encodeExhaustedAlternatives(a, None)) }, (value, schema) => { - schema.getType() match { - case Schema.Type.UNION => - value match { - case container: GenericContainer => - val altName = - container.getSchema.getName - - val altUnionSchema = - schema.getTypes.asScala - .find(_.getName == altName) - .toRight(AvroError.decodeMissingUnionSchema(altName, None)) - - def altMatching = - alts - .find(_.codec.schema.exists(_.getName == altName)) - .toRight(AvroError.decodeMissingUnionAlternative(altName, None)) - - altUnionSchema.flatMap { altSchema => - altMatching.flatMap { alt => - alt.codec - .decode(container, altSchema) - .map(alt.prism.reverseGet) - } - } + val schemaTypes = + schema.getType() match { + case Schema.Type.UNION => schema.getTypes.asScala + case _ => Seq(schema) + } - case other => - val schemaTypes = - schema.getTypes.asScala - - alts - .collectFirstSome { alt => - alt.codec.schema - .traverse { altSchema => - val altName = altSchema.getName - schemaTypes - .find(_.getName == altName) - .flatMap { schema => - alt.codec - .decode(other, schema) - .map(alt.prism.reverseGet) - .toOption - } - } - } - .getOrElse { - Left(AvroError.decodeExhaustedAlternatives(other, None)) - } + value match { + case container: GenericContainer => + val altName = + container.getSchema.getName + + val altUnionSchema = + schemaTypes + .find(_.getName == altName) + .toRight(AvroError.decodeMissingUnionSchema(altName, None)) + + def altMatching = + alts + .find(_.codec.schema.exists(_.getName == altName)) + .toRight(AvroError.decodeMissingUnionAlternative(altName, None)) + + altUnionSchema.flatMap { altSchema => + altMatching.flatMap { alt => + alt.codec + .decode(container, altSchema) + .map(alt.prism.reverseGet) + } } - case schemaType => - Left { - AvroError - .decodeUnexpectedSchemaType( - "union", - schemaType, - Schema.Type.UNION - ) - } + case other => + alts + .collectFirstSome { alt => + alt.codec.schema + .traverse { altSchema => + val altName = altSchema.getName + schemaTypes + .find(_.getName == altName) + .flatMap { schema => + alt.codec + .decode(other, schema) + .map(alt.prism.reverseGet) + .toOption + } + } + } + .getOrElse { + Left(AvroError.decodeExhaustedAlternatives(other, None)) + } } } ) diff --git a/modules/core/src/test/scala/vulcan/CodecSpec.scala b/modules/core/src/test/scala/vulcan/CodecSpec.scala index 7461b845..16370163 100644 --- a/modules/core/src/test/scala/vulcan/CodecSpec.scala +++ b/modules/core/src/test/scala/vulcan/CodecSpec.scala @@ -1473,11 +1473,19 @@ final class CodecSpec extends BaseSpec { } describe("decode") { - it("should error if schema is not union") { + it("should error if schema is not in union") { assertDecodeError[Option[Int]]( unsafeEncode(Option(1)), - unsafeSchema[Int], - "Got unexpected schema type INT while decoding union, expected schema type UNION" + unsafeSchema[String], + "Exhausted alternatives for type java.lang.Integer" + ) + } + + it("should decode if schema is part of union") { + assertDecodeIs[Option[Int]]( + unsafeEncode(Option(1)), + Right(Some(1)), + Some(unsafeSchema[Int]) ) } @@ -2585,11 +2593,19 @@ final class CodecSpec extends BaseSpec { } describe("decode") { - it("should error if schema is not union") { + it("should error if schema is not in union") { assertDecodeError[SealedTraitCaseClass]( unsafeEncode[SealedTraitCaseClass](FirstInSealedTraitCaseClass(0)), unsafeSchema[String], - "Got unexpected schema type STRING while decoding union, expected schema type UNION" + "Missing schema FirstInSealedTraitCaseClass in union" + ) + } + + it("should decode if schema is part of union") { + assertDecodeIs[SealedTraitCaseClass]( + unsafeEncode[SealedTraitCaseClass](FirstInSealedTraitCaseClass(0)), + Right(FirstInSealedTraitCaseClass(0)), + Some(unsafeSchema[FirstInSealedTraitCaseClass]) ) } diff --git a/modules/generic/src/main/scala/vulcan/generic/package.scala b/modules/generic/src/main/scala/vulcan/generic/package.scala index 82bfa6d4..1151dfac 100644 --- a/modules/generic/src/main/scala/vulcan/generic/package.scala +++ b/modules/generic/src/main/scala/vulcan/generic/package.scala @@ -46,61 +46,51 @@ package object generic { tailCodec.value.encode ), (value, schema) => { - schema.getType() match { - case Schema.Type.UNION => - value match { - case container: GenericContainer => - headCodec.schema.flatMap { - headSchema => - val name = container.getSchema.getName - if (headSchema.getName == name) { - val subschema = - schema.getTypes.asScala - .find(_.getName == name) - .toRight(AvroError.decodeMissingUnionSchema(name, Some("Coproduct"))) + val schemaTypes = + schema.getType() match { + case Schema.Type.UNION => schema.getTypes.asScala + case _ => Seq(schema) + } - subschema - .flatMap(headCodec.decode(container, _)) - .map(Inl(_)) - } else { - tailCodec.value - .decode(container, schema) - .map(Inr(_)) - } - } + value match { + case container: GenericContainer => + headCodec.schema.flatMap { + headSchema => + val name = container.getSchema.getName + if (headSchema.getName == name) { + val subschema = + schemaTypes + .find(_.getName == name) + .toRight(AvroError.decodeMissingUnionSchema(name, Some("Coproduct"))) - case other => - val schemaTypes = - schema.getTypes.asScala + subschema + .flatMap(headCodec.decode(container, _)) + .map(Inl(_)) + } else { + tailCodec.value + .decode(container, schema) + .map(Inr(_)) + } + } - headCodec.schema - .traverse { headSchema => - val headName = headSchema.getName - schemaTypes - .find(_.getName == headName) - .flatMap { schema => - headCodec - .decode(other, schema) - .map(Inl(_)) - .toOption - } - } - .getOrElse { - tailCodec.value + case other => + headCodec.schema + .traverse { headSchema => + val headName = headSchema.getName + schemaTypes + .find(_.getName == headName) + .flatMap { schema => + headCodec .decode(other, schema) - .map(Inr(_)) + .map(Inl(_)) + .toOption } - } - - case schemaType => - Left { - AvroError - .decodeUnexpectedSchemaType( - "Coproduct", - schemaType, - Schema.Type.UNION - ) - } + } + .getOrElse { + tailCodec.value + .decode(other, schema) + .map(Inr(_)) + } } } ) @@ -241,61 +231,51 @@ package object generic { subtype.typeclass.encode(subtype.cast(a)) }, (value, schema) => { - schema.getType() match { - case Schema.Type.UNION => - value match { - case container: GenericContainer => - val subtypeName = - container.getSchema.getName + val schemaTypes = + schema.getType() match { + case Schema.Type.UNION => schema.getTypes.asScala + case _ => Seq(schema) + } - val subtypeUnionSchema = - schema.getTypes.asScala - .find(_.getName == subtypeName) - .toRight(AvroError.decodeMissingUnionSchema(subtypeName, Some(typeName))) + value match { + case container: GenericContainer => + val subtypeName = + container.getSchema.getName - def subtypeMatching = - sealedTrait.subtypes - .find(_.typeclass.schema.exists(_.getName == subtypeName)) - .toRight(AvroError.decodeMissingUnionAlternative(subtypeName, Some(typeName))) + val subtypeUnionSchema = + schemaTypes + .find(_.getName == subtypeName) + .toRight(AvroError.decodeMissingUnionSchema(subtypeName, Some(typeName))) - subtypeUnionSchema.flatMap { subtypeSchema => - subtypeMatching.flatMap { subtype => - subtype.typeclass.decode(container, subtypeSchema) - } - } + def subtypeMatching = + sealedTrait.subtypes + .find(_.typeclass.schema.exists(_.getName == subtypeName)) + .toRight(AvroError.decodeMissingUnionAlternative(subtypeName, Some(typeName))) - case other => - val schemaTypes = - schema.getTypes.asScala + subtypeUnionSchema.flatMap { subtypeSchema => + subtypeMatching.flatMap { subtype => + subtype.typeclass.decode(container, subtypeSchema) + } + } - sealedTrait.subtypes.toList - .collectFirstSome { subtype => - subtype.typeclass.schema - .traverse { subtypeSchema => - val subtypeName = subtypeSchema.getName - schemaTypes - .find(_.getName == subtypeName) - .flatMap { schema => - subtype.typeclass - .decode(other, schema) - .toOption - } + case other => + sealedTrait.subtypes.toList + .collectFirstSome { subtype => + subtype.typeclass.schema + .traverse { subtypeSchema => + val subtypeName = subtypeSchema.getName + schemaTypes + .find(_.getName == subtypeName) + .flatMap { schema => + subtype.typeclass + .decode(other, schema) + .toOption } } - .getOrElse { - Left(AvroError.decodeExhaustedAlternatives(other, Some(typeName))) - } - } - - case schemaType => - Left { - AvroError - .decodeUnexpectedSchemaType( - typeName, - schemaType, - Schema.Type.UNION - ) - } + } + .getOrElse { + Left(AvroError.decodeExhaustedAlternatives(other, Some(typeName))) + } } } ) diff --git a/modules/generic/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala b/modules/generic/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala index a47fd851..b6e33ad4 100644 --- a/modules/generic/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala +++ b/modules/generic/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala @@ -7,6 +7,11 @@ sealed trait SealedTraitCaseClass final case class CaseClassInSealedTrait(value: Int) extends SealedTraitCaseClass +object CaseClassInSealedTrait { + implicit val codec: Codec[CaseClassInSealedTrait] = + Codec.derive +} + object SealedTraitCaseClass { implicit val codec: Codec[SealedTraitCaseClass] = Codec.derive diff --git a/modules/generic/src/test/scala/vulcan/generic/CodecSpec.scala b/modules/generic/src/test/scala/vulcan/generic/CodecSpec.scala index 41d769fc..d6b63ad2 100644 --- a/modules/generic/src/test/scala/vulcan/generic/CodecSpec.scala +++ b/modules/generic/src/test/scala/vulcan/generic/CodecSpec.scala @@ -94,12 +94,21 @@ final class CodecSpec extends AnyFunSpec with ScalaCheckPropertyChecks with Eith } describe("decode") { - it("should error if schema is not union") { + it("should error if schema is not in union") { type A = Int :+: String :+: CNil assertDecodeError[A]( unsafeEncode(Coproduct[A](123)), unsafeSchema[String], - "Got unexpected schema type STRING while decoding Coproduct, expected schema type UNION" + "Exhausted alternatives for type java.lang.Integer while decoding Coproduct" + ) + } + + it("should decode if schema is part of union") { + type A = Int :+: String :+: CNil + assertDecodeIs[A]( + unsafeEncode(Coproduct[A](123)), + Right(Coproduct[A](123)), + Some(unsafeSchema[Int]) ) } @@ -338,11 +347,19 @@ final class CodecSpec extends AnyFunSpec with ScalaCheckPropertyChecks with Eith } describe("decode") { - it("should error if schema is not union") { + it("should error if schema is not in union") { assertDecodeError[SealedTraitCaseClass]( unsafeEncode[SealedTraitCaseClass](CaseClassInSealedTrait(0)), unsafeSchema[String], - "Got unexpected schema type STRING while decoding vulcan.examples.SealedTraitCaseClass, expected schema type UNION" + "Missing schema CaseClassInSealedTrait in union for type vulcan.examples.SealedTraitCaseClass" + ) + } + + it("should decode if schema is part of union") { + assertDecodeIs[SealedTraitCaseClass]( + unsafeEncode[SealedTraitCaseClass](CaseClassInSealedTrait(0)), + Right(CaseClassInSealedTrait(0)), + Some(unsafeSchema[CaseClassInSealedTrait]) ) } @@ -401,9 +418,17 @@ final class CodecSpec extends AnyFunSpec with ScalaCheckPropertyChecks with Eith def assertDecodeIs[A]( value: Any, - decoded: Either[AvroError, A] + decoded: Either[AvroError, A], + schema: Option[Schema] = None )(implicit codec: Codec[A]): Assertion = - assert(unsafeDecode[A](value) === decoded.value) + assert { + val decode = + schema + .map(codec.decode(value, _).value) + .getOrElse(unsafeDecode[A](value)) + + decode === decoded.value + } def assertSchemaError[A]( expectedErrorMessage: String