Skip to content

Commit

Permalink
Merge pull request #334 from fd4s/alias-support
Browse files Browse the repository at this point in the history
Support decoding aliased schemas and fields
  • Loading branch information
bplommer authored May 3, 2021
2 parents a3e51f7 + e29e94a commit 59375e8
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 68 deletions.
140 changes: 74 additions & 66 deletions modules/core/src/main/scala/vulcan/Codec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -936,14 +936,14 @@ object Codec extends CodecCompanionCompat {
free.foldMap {
new (Field[A, *] ~> Either[AvroError, *]) {
def apply[B](field: Field[A, B]): Either[AvroError, B] =
record.getSchema.getField(field.name) match {
case null =>
field.default.toRight {
AvroError.decodeMissingRecordField(field.name)
}
case schemaField =>
field.codec.decode(record.get(schemaField.pos), schemaField.schema)
}
(field.name +: field.aliases.toList)
.collectFirstSome { name =>
Option(record.getSchema.getField(name))
}
.fold(field.default.toRight(AvroError.decodeMissingRecordField(field.name))) {
schemaField =>
field.codec.decode(record.get(schemaField.pos), schemaField.schema)
}
}
}
}
Expand Down Expand Up @@ -1079,68 +1079,76 @@ object Codec extends CodecCompanionCompat {
.map(schemas => Schema.createUnion(schemas.asJava))
}

Codec.instance[Any, A](
schema,
a =>
alts
.foldMapK { alt =>
alt.prism.getOption(a).map(alt.codec.encode(_))
}
.getOrElse {
Left(AvroError.encodeExhaustedAlternatives(a))
},
(value, schema) => {
val schemaTypes =
schema.getType() match {
case UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altUnionSchema =
schemaTypes
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName))

def altMatching =
alts
.find(_.codec.schema.exists(_.getName == altName))
.toRight(AvroError.decodeMissingUnionAlternative(altName))

altUnionSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
Codec
.instance[Any, A](
schema,
a =>
alts
.foldMapK { alt =>
alt.prism.getOption(a).map(alt.codec.encode(_))
}
.getOrElse {
Left(AvroError.encodeExhaustedAlternatives(a))
},
(value, schema) => {
val schemaTypes =
schema.getType() match {
case UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

case other =>
alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other))
value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altWriterSchema =
schemaTypes
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName))

def altMatching =
alts
.find(_.codec.schema.exists { schema =>
schema.getType match {
case RECORD | FIXED | ENUM =>
schema.getName == altName || schema.getAliases.asScala
.exists(alias => alias == altName || alias.endsWith(s".$altName"))
case _ => false
}
})
.toRight(AvroError.decodeMissingUnionAlternative(altName))

altWriterSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
}

case other =>
alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other))
}
}
}
}
)
)
}.withTypeName("union")

/**
Expand Down
44 changes: 42 additions & 2 deletions modules/core/src/test/scala/vulcan/CodecSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ package vulcan

import cats.data._
import cats.implicits._

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.time.{Instant, LocalDate, LocalTime}
import java.util.concurrent.TimeUnit
import java.time.temporal.ChronoUnit
import java.util.UUID

import org.apache.avro.{Conversions, LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.generic.GenericData
import org.apache.avro.util.Utf8
import org.scalacheck.Gen
import org.scalatest.Assertion
import vulcan.examples._
import vulcan.examples.{SecondInSealedTraitCaseClass, _}
import vulcan.internal.converters.collection._

import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -2508,6 +2508,19 @@ final class CodecSpec extends BaseSpec with CodecSpecHelpers {
Right(Test(None))
)
}

it("should decode field with aliased name") {
case class Aliased(aliasedField: Int)
implicit val codec: Codec[Aliased] =
Codec.record("CaseClassField", "") { field =>
field("aliasedField", _.aliasedField, aliases = Seq("value")).map(Aliased(_))
}

assertDecodeIs[Aliased](
unsafeEncode(CaseClassField(3)),
Right(Aliased(3))
)
}
}
}

Expand Down Expand Up @@ -2854,6 +2867,33 @@ final class CodecSpec extends BaseSpec with CodecSpecHelpers {
Right(FirstInSealedTraitCaseClass(0))
)
}

it("should decode using schema with aliased name") {

implicit val secondCodec: Codec[SecondInSealedTraitCaseClass] =
Codec.record(
name = "AliasedInSealedTraitCaseClass",
namespace = "com.example",
aliases = Seq("SecondInSealedTraitCaseClass")
) { field =>
field("value", _.value).map(SecondInSealedTraitCaseClass(_))
}

implicit val codec: Codec[SealedTraitCaseClass] = Codec.union(
alt =>
alt[FirstInSealedTraitCaseClass]
|+| alt[SecondInSealedTraitCaseClass]
|+| alt[ThirdInSealedTraitCaseClass]
)

assertDecodeIs[SealedTraitCaseClass](
unsafeEncode[SealedTraitCaseClass](SecondInSealedTraitCaseClass("foo"))(
SealedTraitCaseClass.sealedTraitCaseClassCodec
),
Right(SecondInSealedTraitCaseClass("foo")),
Some(SealedTraitCaseClass.sealedTraitCaseClassCodec.schema.value)
)
}
}
}

Expand Down

0 comments on commit 59375e8

Please sign in to comment.