Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change union decoding to accept schemas in union #191

Merged
merged 1 commit into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 45 additions & 55 deletions modules/core/src/main/scala/vulcan/Codec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1445,64 +1445,54 @@ final object Codec {
Left(AvroError.encodeExhaustedAlternatives(a, None))
},
(value, schema) => {
schema.getType() match {
case Schema.Type.UNION =>
value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altUnionSchema =
schema.getTypes.asScala
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName, None))

def altMatching =
alts
.find(_.codec.schema.exists(_.getName == altName))
.toRight(AvroError.decodeMissingUnionAlternative(altName, None))

altUnionSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
}
val schemaTypes =
schema.getType() match {
case Schema.Type.UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

case other =>
val schemaTypes =
schema.getTypes.asScala

alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, None))
}
value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altUnionSchema =
schemaTypes
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName, None))

def altMatching =
alts
.find(_.codec.schema.exists(_.getName == altName))
.toRight(AvroError.decodeMissingUnionAlternative(altName, None))

altUnionSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
}

case schemaType =>
Left {
AvroError
.decodeUnexpectedSchemaType(
"union",
schemaType,
Schema.Type.UNION
)
}
case other =>
alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, None))
}
}
}
)
Expand Down
26 changes: 21 additions & 5 deletions modules/core/src/test/scala/vulcan/CodecSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1473,11 +1473,19 @@ final class CodecSpec extends BaseSpec {
}

describe("decode") {
it("should error if schema is not union") {
it("should error if schema is not in union") {
assertDecodeError[Option[Int]](
unsafeEncode(Option(1)),
unsafeSchema[Int],
"Got unexpected schema type INT while decoding union, expected schema type UNION"
unsafeSchema[String],
"Exhausted alternatives for type java.lang.Integer"
)
}

it("should decode if schema is part of union") {
assertDecodeIs[Option[Int]](
unsafeEncode(Option(1)),
Right(Some(1)),
Some(unsafeSchema[Int])
)
}

Expand Down Expand Up @@ -2585,11 +2593,19 @@ final class CodecSpec extends BaseSpec {
}

describe("decode") {
it("should error if schema is not union") {
it("should error if schema is not in union") {
assertDecodeError[SealedTraitCaseClass](
unsafeEncode[SealedTraitCaseClass](FirstInSealedTraitCaseClass(0)),
unsafeSchema[String],
"Got unexpected schema type STRING while decoding union, expected schema type UNION"
"Missing schema FirstInSealedTraitCaseClass in union"
)
}

it("should decode if schema is part of union") {
assertDecodeIs[SealedTraitCaseClass](
unsafeEncode[SealedTraitCaseClass](FirstInSealedTraitCaseClass(0)),
Right(FirstInSealedTraitCaseClass(0)),
Some(unsafeSchema[FirstInSealedTraitCaseClass])
)
}

Expand Down
176 changes: 78 additions & 98 deletions modules/generic/src/main/scala/vulcan/generic/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,61 +46,51 @@ package object generic {
tailCodec.value.encode
),
(value, schema) => {
schema.getType() match {
case Schema.Type.UNION =>
value match {
case container: GenericContainer =>
headCodec.schema.flatMap {
headSchema =>
val name = container.getSchema.getName
if (headSchema.getName == name) {
val subschema =
schema.getTypes.asScala
.find(_.getName == name)
.toRight(AvroError.decodeMissingUnionSchema(name, Some("Coproduct")))
val schemaTypes =
schema.getType() match {
case Schema.Type.UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

subschema
.flatMap(headCodec.decode(container, _))
.map(Inl(_))
} else {
tailCodec.value
.decode(container, schema)
.map(Inr(_))
}
}
value match {
case container: GenericContainer =>
headCodec.schema.flatMap {
headSchema =>
val name = container.getSchema.getName
if (headSchema.getName == name) {
val subschema =
schemaTypes
.find(_.getName == name)
.toRight(AvroError.decodeMissingUnionSchema(name, Some("Coproduct")))

case other =>
val schemaTypes =
schema.getTypes.asScala
subschema
.flatMap(headCodec.decode(container, _))
.map(Inl(_))
} else {
tailCodec.value
.decode(container, schema)
.map(Inr(_))
}
}

headCodec.schema
.traverse { headSchema =>
val headName = headSchema.getName
schemaTypes
.find(_.getName == headName)
.flatMap { schema =>
headCodec
.decode(other, schema)
.map(Inl(_))
.toOption
}
}
.getOrElse {
tailCodec.value
case other =>
headCodec.schema
.traverse { headSchema =>
val headName = headSchema.getName
schemaTypes
.find(_.getName == headName)
.flatMap { schema =>
headCodec
.decode(other, schema)
.map(Inr(_))
.map(Inl(_))
.toOption
}
}

case schemaType =>
Left {
AvroError
.decodeUnexpectedSchemaType(
"Coproduct",
schemaType,
Schema.Type.UNION
)
}
}
.getOrElse {
tailCodec.value
.decode(other, schema)
.map(Inr(_))
}
}
}
)
Expand Down Expand Up @@ -241,61 +231,51 @@ package object generic {
subtype.typeclass.encode(subtype.cast(a))
},
(value, schema) => {
schema.getType() match {
case Schema.Type.UNION =>
value match {
case container: GenericContainer =>
val subtypeName =
container.getSchema.getName
val schemaTypes =
schema.getType() match {
case Schema.Type.UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

val subtypeUnionSchema =
schema.getTypes.asScala
.find(_.getName == subtypeName)
.toRight(AvroError.decodeMissingUnionSchema(subtypeName, Some(typeName)))
value match {
case container: GenericContainer =>
val subtypeName =
container.getSchema.getName

def subtypeMatching =
sealedTrait.subtypes
.find(_.typeclass.schema.exists(_.getName == subtypeName))
.toRight(AvroError.decodeMissingUnionAlternative(subtypeName, Some(typeName)))
val subtypeUnionSchema =
schemaTypes
.find(_.getName == subtypeName)
.toRight(AvroError.decodeMissingUnionSchema(subtypeName, Some(typeName)))

subtypeUnionSchema.flatMap { subtypeSchema =>
subtypeMatching.flatMap { subtype =>
subtype.typeclass.decode(container, subtypeSchema)
}
}
def subtypeMatching =
sealedTrait.subtypes
.find(_.typeclass.schema.exists(_.getName == subtypeName))
.toRight(AvroError.decodeMissingUnionAlternative(subtypeName, Some(typeName)))

case other =>
val schemaTypes =
schema.getTypes.asScala
subtypeUnionSchema.flatMap { subtypeSchema =>
subtypeMatching.flatMap { subtype =>
subtype.typeclass.decode(container, subtypeSchema)
}
}

sealedTrait.subtypes.toList
.collectFirstSome { subtype =>
subtype.typeclass.schema
.traverse { subtypeSchema =>
val subtypeName = subtypeSchema.getName
schemaTypes
.find(_.getName == subtypeName)
.flatMap { schema =>
subtype.typeclass
.decode(other, schema)
.toOption
}
case other =>
sealedTrait.subtypes.toList
.collectFirstSome { subtype =>
subtype.typeclass.schema
.traverse { subtypeSchema =>
val subtypeName = subtypeSchema.getName
schemaTypes
.find(_.getName == subtypeName)
.flatMap { schema =>
subtype.typeclass
.decode(other, schema)
.toOption
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, Some(typeName)))
}
}

case schemaType =>
Left {
AvroError
.decodeUnexpectedSchemaType(
typeName,
schemaType,
Schema.Type.UNION
)
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other, Some(typeName)))
}
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ sealed trait SealedTraitCaseClass

final case class CaseClassInSealedTrait(value: Int) extends SealedTraitCaseClass

object CaseClassInSealedTrait {
implicit val codec: Codec[CaseClassInSealedTrait] =
Codec.derive
}

object SealedTraitCaseClass {
implicit val codec: Codec[SealedTraitCaseClass] =
Codec.derive
Expand Down
Loading