Skip to content

Commit

Permalink
Change union decoding to accept schemas in union
Browse files Browse the repository at this point in the history
  • Loading branch information
vlovgr committed May 12, 2020
1 parent af88fc0 commit 9a7e801
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 164 deletions.
100 changes: 45 additions & 55 deletions modules/core/src/main/scala/vulcan/Codec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1445,64 +1445,54 @@ final object Codec {
Left(AvroError.encodeExhaustedAlternatives(a, None))
},
(value, schema) => {
schema.getType() match {
case Schema.Type.UNION =>
value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altUnionSchema =
schema.getTypes.asScala
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName, None))

def altMatching =
alts
.find(_.codec.schema.exists(_.getName == altName))
.toRight(AvroError.decodeMissingUnionAlternative(altName, None))

altUnionSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
}
val schemaTypes =
schema.getType() match {
case Schema.Type.UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

case other =>
val schemaTypes =
schema.getTypes.asScala

alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, None))
}
value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altUnionSchema =
schemaTypes
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName, None))

def altMatching =
alts
.find(_.codec.schema.exists(_.getName == altName))
.toRight(AvroError.decodeMissingUnionAlternative(altName, None))

altUnionSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
}

case schemaType =>
Left {
AvroError
.decodeUnexpectedSchemaType(
"union",
schemaType,
Schema.Type.UNION
)
}
case other =>
alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, None))
}
}
}
)
Expand Down
26 changes: 21 additions & 5 deletions modules/core/src/test/scala/vulcan/CodecSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1473,11 +1473,19 @@ final class CodecSpec extends BaseSpec {
}

describe("decode") {
it("should error if schema is not union") {
it("should error if schema is not in union") {
assertDecodeError[Option[Int]](
unsafeEncode(Option(1)),
unsafeSchema[Int],
"Got unexpected schema type INT while decoding union, expected schema type UNION"
unsafeSchema[String],
"Exhausted alternatives for type java.lang.Integer"
)
}

it("should decode if schema is part of union") {
assertDecodeIs[Option[Int]](
unsafeEncode(Option(1)),
Right(Some(1)),
Some(unsafeSchema[Int])
)
}

Expand Down Expand Up @@ -2585,11 +2593,19 @@ final class CodecSpec extends BaseSpec {
}

describe("decode") {
it("should error if schema is not union") {
it("should error if schema is not in union") {
assertDecodeError[SealedTraitCaseClass](
unsafeEncode[SealedTraitCaseClass](FirstInSealedTraitCaseClass(0)),
unsafeSchema[String],
"Got unexpected schema type STRING while decoding union, expected schema type UNION"
"Missing schema FirstInSealedTraitCaseClass in union"
)
}

it("should decode if schema is part of union") {
assertDecodeIs[SealedTraitCaseClass](
unsafeEncode[SealedTraitCaseClass](FirstInSealedTraitCaseClass(0)),
Right(FirstInSealedTraitCaseClass(0)),
Some(unsafeSchema[FirstInSealedTraitCaseClass])
)
}

Expand Down
176 changes: 78 additions & 98 deletions modules/generic/src/main/scala/vulcan/generic/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,61 +46,51 @@ package object generic {
tailCodec.value.encode
),
(value, schema) => {
schema.getType() match {
case Schema.Type.UNION =>
value match {
case container: GenericContainer =>
headCodec.schema.flatMap {
headSchema =>
val name = container.getSchema.getName
if (headSchema.getName == name) {
val subschema =
schema.getTypes.asScala
.find(_.getName == name)
.toRight(AvroError.decodeMissingUnionSchema(name, Some("Coproduct")))
val schemaTypes =
schema.getType() match {
case Schema.Type.UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

subschema
.flatMap(headCodec.decode(container, _))
.map(Inl(_))
} else {
tailCodec.value
.decode(container, schema)
.map(Inr(_))
}
}
value match {
case container: GenericContainer =>
headCodec.schema.flatMap {
headSchema =>
val name = container.getSchema.getName
if (headSchema.getName == name) {
val subschema =
schemaTypes
.find(_.getName == name)
.toRight(AvroError.decodeMissingUnionSchema(name, Some("Coproduct")))

case other =>
val schemaTypes =
schema.getTypes.asScala
subschema
.flatMap(headCodec.decode(container, _))
.map(Inl(_))
} else {
tailCodec.value
.decode(container, schema)
.map(Inr(_))
}
}

headCodec.schema
.traverse { headSchema =>
val headName = headSchema.getName
schemaTypes
.find(_.getName == headName)
.flatMap { schema =>
headCodec
.decode(other, schema)
.map(Inl(_))
.toOption
}
}
.getOrElse {
tailCodec.value
case other =>
headCodec.schema
.traverse { headSchema =>
val headName = headSchema.getName
schemaTypes
.find(_.getName == headName)
.flatMap { schema =>
headCodec
.decode(other, schema)
.map(Inr(_))
.map(Inl(_))
.toOption
}
}

case schemaType =>
Left {
AvroError
.decodeUnexpectedSchemaType(
"Coproduct",
schemaType,
Schema.Type.UNION
)
}
}
.getOrElse {
tailCodec.value
.decode(other, schema)
.map(Inr(_))
}
}
}
)
Expand Down Expand Up @@ -241,61 +231,51 @@ package object generic {
subtype.typeclass.encode(subtype.cast(a))
},
(value, schema) => {
schema.getType() match {
case Schema.Type.UNION =>
value match {
case container: GenericContainer =>
val subtypeName =
container.getSchema.getName
val schemaTypes =
schema.getType() match {
case Schema.Type.UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

val subtypeUnionSchema =
schema.getTypes.asScala
.find(_.getName == subtypeName)
.toRight(AvroError.decodeMissingUnionSchema(subtypeName, Some(typeName)))
value match {
case container: GenericContainer =>
val subtypeName =
container.getSchema.getName

def subtypeMatching =
sealedTrait.subtypes
.find(_.typeclass.schema.exists(_.getName == subtypeName))
.toRight(AvroError.decodeMissingUnionAlternative(subtypeName, Some(typeName)))
val subtypeUnionSchema =
schemaTypes
.find(_.getName == subtypeName)
.toRight(AvroError.decodeMissingUnionSchema(subtypeName, Some(typeName)))

subtypeUnionSchema.flatMap { subtypeSchema =>
subtypeMatching.flatMap { subtype =>
subtype.typeclass.decode(container, subtypeSchema)
}
}
def subtypeMatching =
sealedTrait.subtypes
.find(_.typeclass.schema.exists(_.getName == subtypeName))
.toRight(AvroError.decodeMissingUnionAlternative(subtypeName, Some(typeName)))

case other =>
val schemaTypes =
schema.getTypes.asScala
subtypeUnionSchema.flatMap { subtypeSchema =>
subtypeMatching.flatMap { subtype =>
subtype.typeclass.decode(container, subtypeSchema)
}
}

sealedTrait.subtypes.toList
.collectFirstSome { subtype =>
subtype.typeclass.schema
.traverse { subtypeSchema =>
val subtypeName = subtypeSchema.getName
schemaTypes
.find(_.getName == subtypeName)
.flatMap { schema =>
subtype.typeclass
.decode(other, schema)
.toOption
}
case other =>
sealedTrait.subtypes.toList
.collectFirstSome { subtype =>
subtype.typeclass.schema
.traverse { subtypeSchema =>
val subtypeName = subtypeSchema.getName
schemaTypes
.find(_.getName == subtypeName)
.flatMap { schema =>
subtype.typeclass
.decode(other, schema)
.toOption
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, Some(typeName)))
}
}

case schemaType =>
Left {
AvroError
.decodeUnexpectedSchemaType(
typeName,
schemaType,
Schema.Type.UNION
)
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, Some(typeName)))
}
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ sealed trait SealedTraitCaseClass

final case class CaseClassInSealedTrait(value: Int) extends SealedTraitCaseClass

object CaseClassInSealedTrait {
implicit val codec: Codec[CaseClassInSealedTrait] =
Codec.derive
}

object SealedTraitCaseClass {
implicit val codec: Codec[SealedTraitCaseClass] =
Codec.derive
Expand Down
Loading

0 comments on commit 9a7e801

Please sign in to comment.