From ef9d446264bf1c8bd13e25bec6e6e02d95fb5830 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Sun, 15 Sep 2024 10:06:31 +0200 Subject: [PATCH] refactor: rework direct decoding for more clear & compact resolving unions --- .../github/avrokotlin/avro4k/AvroDecoder.kt | 11 + .../direct/AbstractAvroDirectDecoder.kt | 414 +++++------------- .../encoder/direct/RecordDirectEncoder.kt | 43 +- .../generic/AbstractAvroGenericEncoder.kt | 2 +- .../encoder/generic/RecordGenericEncoder.kt | 102 +++-- 5 files changed, 219 insertions(+), 353 deletions(-) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt index 569bf6a..7305c4e 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt @@ -1,6 +1,7 @@ package com.github.avrokotlin.avro4k import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationException import kotlinx.serialization.encoding.Decoder import org.apache.avro.Schema import org.apache.avro.generic.GenericFixed @@ -317,4 +318,14 @@ internal inline fun AvroDecoder.findValueDecoder( resolver(schema) } return foundResolver ?: throw error() +} + +internal fun AvroDecoder.unsupportedWriterTypeError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) ", and also not matching to any compatible type (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException( + "Unsupported schema '${currentWriterSchema.fullName}' for decoded type of ${mainType.getName()}$fallbacksStr. Actual schema: $currentWriterSchema" + ) } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index c8fdde6..7bc3cc5 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -1,21 +1,7 @@ package com.github.avrokotlin.avro4k.internal.decoder.direct -import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.BooleanValueDecoder -import com.github.avrokotlin.avro4k.CharValueDecoder -import com.github.avrokotlin.avro4k.DoubleValueDecoder -import com.github.avrokotlin.avro4k.FloatValueDecoder -import com.github.avrokotlin.avro4k.IntValueDecoder -import com.github.avrokotlin.avro4k.LongValueDecoder import com.github.avrokotlin.avro4k.UnionDecoder -import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.decodeResolvingBoolean -import com.github.avrokotlin.avro4k.decodeResolvingChar -import com.github.avrokotlin.avro4k.decodeResolvingDouble -import com.github.avrokotlin.avro4k.decodeResolvingFloat -import com.github.avrokotlin.avro4k.decodeResolvingInt -import com.github.avrokotlin.avro4k.decodeResolvingLong import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder @@ -23,9 +9,9 @@ import com.github.avrokotlin.avro4k.internal.getElementIndexNullable import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch import 
com.github.avrokotlin.avro4k.internal.nonNullSerialName import com.github.avrokotlin.avro4k.internal.toByteExact -import com.github.avrokotlin.avro4k.internal.toFloatExact import com.github.avrokotlin.avro4k.internal.toIntExact import com.github.avrokotlin.avro4k.internal.toShortExact +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PolymorphicKind @@ -59,38 +45,39 @@ internal abstract class AbstractAvroDirectDecoder( } override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + decodeAndResolveUnion() + return when (descriptor.kind) { StructureKind.LIST -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.ARRAY) }) { - when (it.type) { - Schema.Type.ARRAY -> { - AnyValueDecoder { ArrayBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.ARRAY -> + ArrayBlockDirectDecoder( + currentWriterSchema, + decodeFirstBlock = decodedCollectionSize == -1, + { decodedCollectionSize = it }, + avro, + binaryDecoder + ) + else -> throw unsupportedWriterTypeError(Schema.Type.ARRAY) } StructureKind.MAP -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.MAP) }) { - when (it.type) { - Schema.Type.MAP -> { - AnyValueDecoder { MapBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.MAP -> + MapBlockDirectDecoder( + currentWriterSchema, + decodeFirstBlock = decodedCollectionSize == -1, + { decodedCollectionSize = it }, + avro, + binaryDecoder + ) + else -> throw unsupportedWriterTypeError(Schema.Type.MAP) } StructureKind.CLASS, StructureKind.OBJECT -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.RECORD) }) { - when (it.type) { - Schema.Type.RECORD -> { - AnyValueDecoder { RecordDirectDecoder(it, descriptor, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.RECORD -> RecordDirectDecoder(currentWriterSchema, descriptor, avro, binaryDecoder) + else -> throw unsupportedWriterTypeError(Schema.Type.RECORD) } is PolymorphicKind -> PolymorphicDecoder(avro, descriptor, currentWriterSchema, binaryDecoder) @@ -106,46 +93,27 @@ internal abstract class AbstractAvroDirectDecoder( override fun decodeNotNullMark(): Boolean { decodeAndResolveUnion() + return currentWriterSchema.type != Schema.Type.NULL } override fun decodeNull(): Nothing? 
{ - decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "null", - Schema.Type.NULL - ) - }) { - when (it.type) { - Schema.Type.NULL -> { - AnyValueDecoder { binaryDecoder.readNull() } - } + decodeAndResolveUnion() - else -> null - } + if (currentWriterSchema.type != Schema.Type.NULL) { + throw unsupportedWriterTypeError(Schema.Type.NULL) } + binaryDecoder.readNull() return null } override fun decodeBoolean(): Boolean { - return decodeResolvingBoolean({ - UnexpectedDecodeSchemaError( - "boolean", - Schema.Type.BOOLEAN, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.BOOLEAN -> { - BooleanValueDecoder { binaryDecoder.readBoolean() } - } - - Schema.Type.STRING -> { - BooleanValueDecoder { binaryDecoder.readString().toBooleanStrict() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BOOLEAN -> binaryDecoder.readBoolean() + Schema.Type.STRING -> binaryDecoder.readString().toBooleanStrict() + else -> throw unsupportedWriterTypeError(Schema.Type.BOOLEAN, Schema.Type.STRING) } } @@ -158,284 +126,122 @@ internal abstract class AbstractAvroDirectDecoder( } override fun decodeInt(): Int { - return decodeResolvingInt({ - UnexpectedDecodeSchemaError( - "int", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - IntValueDecoder { binaryDecoder.readInt() } - } - - Schema.Type.LONG -> { - IntValueDecoder { binaryDecoder.readLong().toIntExact() } - } - - Schema.Type.FLOAT -> { - IntValueDecoder { binaryDecoder.readDouble().toInt() } - } - - Schema.Type.DOUBLE -> { - IntValueDecoder { binaryDecoder.readDouble().toInt() } - } - - Schema.Type.STRING -> { - IntValueDecoder { binaryDecoder.readString().toInt() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt() + Schema.Type.LONG -> binaryDecoder.readLong().toIntExact() + Schema.Type.STRING -> binaryDecoder.readString().toInt() + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun decodeLong(): Long { - return decodeResolvingLong({ - UnexpectedDecodeSchemaError( - "long", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - LongValueDecoder { binaryDecoder.readInt().toLong() } - } - - Schema.Type.LONG -> { - LongValueDecoder { binaryDecoder.readLong() } - } - - Schema.Type.FLOAT -> { - LongValueDecoder { binaryDecoder.readFloat().toLong() } - } - - Schema.Type.DOUBLE -> { - LongValueDecoder { binaryDecoder.readDouble().toLong() } - } - - Schema.Type.STRING -> { - LongValueDecoder { binaryDecoder.readString().toLong() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toLong() + Schema.Type.LONG -> binaryDecoder.readLong() + Schema.Type.STRING -> binaryDecoder.readString().toLong() + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) } } override fun decodeFloat(): Float { - return decodeResolvingFloat({ - UnexpectedDecodeSchemaError( - "float", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - FloatValueDecoder { binaryDecoder.readInt().toFloat() } - } - - Schema.Type.LONG -> { - FloatValueDecoder 
{ binaryDecoder.readLong().toFloat() } - } - - Schema.Type.FLOAT -> { - FloatValueDecoder { binaryDecoder.readFloat() } - } - - Schema.Type.DOUBLE -> { - FloatValueDecoder { binaryDecoder.readDouble().toFloatExact() } - } - - Schema.Type.STRING -> { - FloatValueDecoder { binaryDecoder.readString().toFloat() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toFloat() + Schema.Type.LONG -> binaryDecoder.readLong().toFloat() + Schema.Type.FLOAT -> binaryDecoder.readFloat() + Schema.Type.STRING -> binaryDecoder.readString().toFloat() + else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun decodeDouble(): Double { - return decodeResolvingDouble({ - UnexpectedDecodeSchemaError( - "double", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - DoubleValueDecoder { binaryDecoder.readInt().toDouble() } - } - - Schema.Type.LONG -> { - DoubleValueDecoder { binaryDecoder.readLong().toDouble() } - } - - Schema.Type.FLOAT -> { - DoubleValueDecoder { binaryDecoder.readFloat().toDouble() } - } - - Schema.Type.DOUBLE -> { - DoubleValueDecoder { binaryDecoder.readDouble() } - } - - Schema.Type.STRING -> { - DoubleValueDecoder { binaryDecoder.readString().toDouble() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toDouble() + Schema.Type.LONG -> binaryDecoder.readLong().toDouble() + Schema.Type.FLOAT -> binaryDecoder.readFloat().toDouble() + Schema.Type.DOUBLE -> binaryDecoder.readDouble() + Schema.Type.STRING -> binaryDecoder.readString().toDouble() + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.STRING) } } override fun decodeChar(): Char { - return decodeResolvingChar({ - UnexpectedDecodeSchemaError( - "char", - Schema.Type.INT, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - CharValueDecoder { binaryDecoder.readInt().toChar() } - } - - Schema.Type.STRING -> { - CharValueDecoder { binaryDecoder.readString(null).single() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toChar() + Schema.Type.STRING -> binaryDecoder.readString(null).single() + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) } } override fun decodeString(): String { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "string", - Schema.Type.STRING, - Schema.Type.BYTES, - Schema.Type.FIXED - ) - }) { - when (it.type) { - Schema.Type.STRING, - Schema.Type.BYTES, - -> { - AnyValueDecoder { binaryDecoder.readString() } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.STRING -> binaryDecoder.readString(null).toString() + Schema.Type.BYTES -> binaryDecoder.readBytes(null).array().decodeToString() + Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() + else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } } override fun decodeEnum(enumDescriptor: 
SerialDescriptor): Int { - return decodeResolvingInt({ - UnexpectedDecodeSchemaError( - enumDescriptor.nonNullSerialName, - Schema.Type.ENUM, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.ENUM -> - if (it.isFullNameOrAliasMatch(enumDescriptor)) { - IntValueDecoder { - val enumName = it.enumSymbols[binaryDecoder.readEnum()] - enumDescriptor.getElementIndexNullable(enumName) - ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException( - "Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema" - ) - } - } else { - null - } - - Schema.Type.STRING -> { - IntValueDecoder { - val enumSymbol = binaryDecoder.readString() - enumDescriptor.getElementIndex(enumSymbol) - .takeIf { index -> index >= 0 } - ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}'") - } - } + decodeAndResolveUnion() - else -> null + return when (currentWriterSchema.type) { + Schema.Type.ENUM -> + if (currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { + val enumName = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] + enumDescriptor.getElementIndexNullable(enumName) + ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) + ?: throw SerializationException( + "Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema" + ) + } else { + throw UnexpectedDecodeSchemaError( + enumDescriptor.nonNullSerialName, + Schema.Type.ENUM, + Schema.Type.STRING + ) + } + + Schema.Type.STRING -> { + val enumSymbol = binaryDecoder.readString() + enumDescriptor.getElementIndex(enumSymbol).takeIf { index -> index >= 0 } + ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) + ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}'") } + + else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } } override fun decodeBytes(): ByteArray { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "ByteArray", - Schema.Type.BYTES, - Schema.Type.FIXED, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.BYTES -> { - AnyValueDecoder { binaryDecoder.readBytes(null).array() } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } } - } - - Schema.Type.STRING -> { - AnyValueDecoder { binaryDecoder.readString(null).bytes } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BYTES -> binaryDecoder.readBytes(null).array() + Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } + Schema.Type.STRING -> binaryDecoder.readString(null).bytes + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING) } } override fun decodeFixed(): GenericFixed { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "GenericFixed", - Schema.Type.BYTES, - Schema.Type.FIXED - ) - }) { - when (it.type) { - Schema.Type.BYTES -> { - AnyValueDecoder { GenericData.Fixed(it, binaryDecoder.readBytes(null).array()) } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { GenericData.Fixed(it, ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BYTES -> 
GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes(null).array()) + Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED) } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt index 5c1f67e..086830a 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt @@ -15,27 +15,29 @@ internal fun RecordDirectEncoder( avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ): CompositeEncoder { - val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding - when (encodingWorkflow) { - is EncodingWorkflow.ExactMatch -> return RecordExactDirectEncoder(schema, avro, binaryEncoder) - is EncodingWorkflow.ContiguousWithSkips -> return RecordSkippingDirectEncoder(encodingWorkflow.fieldsToSkip, schema, avro, binaryEncoder) - is EncodingWorkflow.NonContiguous -> return ReorderingCompositeEncoder( - schema.fields.size, - RecordNonContiguousDirectEncoder( - encodingWorkflow.descriptorToWriterFieldIndex, - schema, - avro, - binaryEncoder - ) - ) { _, index -> - encodingWorkflow.descriptorToWriterFieldIndex[index] - } + return when (val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding) { + is EncodingWorkflow.ExactMatch -> RecordContiguousExactEncoder(schema, avro, binaryEncoder) + is EncodingWorkflow.ContiguousWithSkips -> RecordContiguousSkippingEncoder(encodingWorkflow.fieldsToSkip, schema, avro, binaryEncoder) + is EncodingWorkflow.NonContiguous -> + ReorderingCompositeEncoder( + schema.fields.size, + RecordNonContiguousEncoder( + encodingWorkflow.descriptorToWriterFieldIndex, + schema, + avro, + binaryEncoder + ) + ) { _, index -> + encodingWorkflow.descriptorToWriterFieldIndex[index] + } - is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") + is EncodingWorkflow.MissingWriterFields -> throw SerializationException( + "Missing writer fields ${schema.fields.filter { it.pos() in encodingWorkflow.missingWriterFields }}} from the descriptor $descriptor" + ) } } -private class RecordNonContiguousDirectEncoder( +private class RecordNonContiguousEncoder( private val descriptorToWriterFieldIndex: IntArray, private val schema: Schema, avro: Avro, @@ -57,12 +59,13 @@ private class RecordNonContiguousDirectEncoder( } } -private class RecordSkippingDirectEncoder( +private class RecordContiguousSkippingEncoder( private val skippedElements: BooleanArray, private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ) : AbstractAvroDirectEncoder(avro, binaryEncoder) { + private var nextWriterFieldIndex = 0 override lateinit var currentWriterSchema: Schema override fun encodeElement( @@ -73,12 +76,12 @@ private class RecordSkippingDirectEncoder( return false } super.encodeElement(descriptor, index) - currentWriterSchema = schema.fields[index].schema() + currentWriterSchema = schema.fields[nextWriterFieldIndex++].schema() return true } } -private class RecordExactDirectEncoder( +private class RecordContiguousExactEncoder( private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, diff --git 
a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt index 06a5f41..a10eab4 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt @@ -18,7 +18,7 @@ internal abstract class AbstractAvroGenericEncoder : AbstractAvroEncoder() { get() = avro.serializersModule override fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder { - return RecordGenericEncoder(avro, descriptor, currentWriterSchema) { encodeValue(it) } + return RecordGenericEncoder(descriptor, currentWriterSchema, avro) { encodeValue(it) } } override fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt index 5f7706a..9d44d59 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt @@ -5,52 +5,98 @@ import com.github.avrokotlin.avro4k.ListRecord import com.github.avrokotlin.avro4k.internal.EncodingWorkflow import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.CompositeEncoder import org.apache.avro.Schema import org.apache.avro.generic.GenericRecord -internal class RecordGenericEncoder( - override val avro: Avro, +@Suppress("FunctionName") +internal fun RecordGenericEncoder( descriptor: SerialDescriptor, - private val schema: Schema, - private val onEncoded: (GenericRecord) -> Unit, -) : AbstractAvroGenericEncoder() { - private val fieldValues: Array = Array(schema.fields.size) { null } + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +): CompositeEncoder { + return when (val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding) { + is EncodingWorkflow.ExactMatch -> RecordContiguousExactEncoder(schema, avro, onEncoded) + is EncodingWorkflow.ContiguousWithSkips -> RecordContiguousSkippingEncoder(encodingWorkflow.fieldsToSkip, schema, avro, onEncoded) + is EncodingWorkflow.NonContiguous -> RecordNonContiguousEncoder(encodingWorkflow.descriptorToWriterFieldIndex, schema, avro, onEncoded) + is EncodingWorkflow.MissingWriterFields -> throw SerializationException( + "Missing writer fields ${schema.fields.filter { it.pos() in encodingWorkflow.missingWriterFields }}} from the descriptor $descriptor" + ) + } +} - private val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding - private lateinit var currentField: Schema.Field +private class RecordNonContiguousEncoder( + private val descriptorToWriterFieldIndex: IntArray, + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +) : AbstractRecordGenericEncoder(avro, schema, onEncoded) { + override fun encodeElement( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + val writerFieldIndex = descriptorToWriterFieldIndex[index] + if (writerFieldIndex == -1) { + return false + } + super.encodeElement(descriptor, index) + setWriterField(writerFieldIndex) + return true + } +} - override lateinit var 
currentWriterSchema: Schema +private class RecordContiguousSkippingEncoder( + private val skippedElements: BooleanArray, + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +) : AbstractRecordGenericEncoder(avro, schema, onEncoded) { + private var nextWriterFieldIndex = 0 + + override fun encodeElement( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + if (skippedElements[index]) { + return false + } + super.encodeElement(descriptor, index) + setWriterField(nextWriterFieldIndex++) + return true + } +} +private class RecordContiguousExactEncoder( + schema: Schema, + avro: Avro, + onEncoded: (GenericRecord) -> Unit, +) : AbstractRecordGenericEncoder(avro, schema, onEncoded) { override fun encodeElement( descriptor: SerialDescriptor, index: Int, ): Boolean { super.encodeElement(descriptor, index) - val writerFieldIndex = - when (encodingWorkflow) { - EncodingWorkflow.ExactMatch -> index + setWriterField(index) + return true + } +} - is EncodingWorkflow.ContiguousWithSkips -> { - if (encodingWorkflow.fieldsToSkip[index]) { - return false - } - index - } +private abstract class AbstractRecordGenericEncoder( + override val avro: Avro, + private val schema: Schema, + private val onEncoded: (GenericRecord) -> Unit, +) : AbstractAvroGenericEncoder() { + private val fieldValues: Array = Array(schema.fields.size) { null } - is EncodingWorkflow.NonContiguous -> { - val writerFieldIndex = encodingWorkflow.descriptorToWriterFieldIndex[index] - if (writerFieldIndex == -1) { - return false - } - writerFieldIndex - } + private lateinit var currentField: Schema.Field + + override lateinit var currentWriterSchema: Schema - is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") - } + protected fun setWriterField(writerFieldIndex: Int) { val field = schema.fields[writerFieldIndex] currentField = field currentWriterSchema = field.schema() - return true } override fun endStructure(descriptor: SerialDescriptor) {
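
Note on the decoder rework: every decodeXxx method above now follows the same shape — call decodeAndResolveUnion() first, then dispatch on currentWriterSchema.type, throwing unsupportedWriterTypeError for anything unexpected. The patch does not include the body of decodeAndResolveUnion(), so the standalone sketch below assumes it resolves a writer union by reading the branch index; decodeBooleanFrom is a hypothetical name used only for illustration, not part of the library.

    import kotlinx.serialization.SerializationException
    import org.apache.avro.Schema
    import org.apache.avro.io.Decoder

    // Illustrative sketch of the pattern, not library code.
    fun decodeBooleanFrom(writerSchema: Schema, decoder: Decoder): Boolean {
        // Assumed union handling: pick the concrete branch by reading its index,
        // which is what decodeAndResolveUnion() is expected to do for
        // currentWriterSchema inside the decoder.
        val resolved =
            if (writerSchema.type == Schema.Type.UNION) {
                writerSchema.types[decoder.readIndex()]
            } else {
                writerSchema
            }
        // Dispatch on the resolved writer type, mirroring decodeBoolean() above.
        return when (resolved.type) {
            Schema.Type.BOOLEAN -> decoder.readBoolean()
            Schema.Type.STRING -> decoder.readString().toBooleanStrict()
            else -> throw SerializationException(
                "Unsupported schema '${resolved.fullName}' for decoded type of BOOLEAN"
            )
        }
    }

The other primitives follow the same dispatch-or-throw shape, with the first argument to unsupportedWriterTypeError naming the primary expected type and the remaining arguments the accepted fallback types.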
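
Note on the record encoders: RecordDirectEncoder and RecordGenericEncoder both select one of three implementations from the resolved EncodingWorkflow — an exact encoder when descriptor elements and writer fields match one-to-one, a contiguous-skipping encoder that advances a nextWriterFieldIndex counter so skipped elements do not shift later writer fields, and a non-contiguous encoder driven by descriptorToWriterFieldIndex, where -1 marks an element with no matching writer field. The snippet below is a minimal sketch of that mapping with made-up values; writerFieldFor is a hypothetical helper, not part of the patch.

    // Hypothetical mapping: descriptor element i -> writer field position,
    // or -1 when the element has no matching writer field and is skipped.
    val descriptorToWriterFieldIndex = intArrayOf(2, -1, 0, 1)

    fun writerFieldFor(descriptorElementIndex: Int): Int? =
        descriptorToWriterFieldIndex[descriptorElementIndex].takeIf { it != -1 }

The running counter in the skipping encoders matters because, once an element is skipped, the descriptor element index no longer equals the writer field position; that is why the direct encoder now indexes schema.fields by nextWriterFieldIndex++ instead of by the element index.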