Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not remove type discriminator in unsafe manner during polymorphic … #810

Merged
merged 2 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ internal fun <T> JsonInput.decodeSerializableValuePolymorphic(deserializer: Dese
}

val jsonTree = cast<JsonObject>(decodeJson())
val type = jsonTree.getValue(json.configuration.classDiscriminator).content
(jsonTree.content as MutableMap).remove(json.configuration.classDiscriminator)
val discriminator = json.configuration.classDiscriminator
val type = jsonTree.getValue(discriminator).content
val actualSerializer = deserializer.findPolymorphicSerializer(this, type).cast<T>()
return json.readJson(jsonTree, actualSerializer)
return json.readPolymorphicJson(discriminator, jsonTree, actualSerializer)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ internal fun <T> Json.readJson(element: JsonElement, deserializer: Deserializati
return input.decode(deserializer)
}

internal fun <T> Json.readPolymorphicJson(
discriminator: String,
element: JsonObject,
deserializer: DeserializationStrategy<T>
): T {
return JsonTreeInput(this, element, discriminator, deserializer.descriptor).decode(deserializer)
}

private sealed class AbstractJsonTreeInput(
override val json: Json,
open val value: JsonElement
Expand Down Expand Up @@ -140,7 +148,12 @@ private class JsonPrimitiveInput(json: Json, override val value: JsonPrimitive)
}
}

private open class JsonTreeInput(json: Json, override val value: JsonObject) : AbstractJsonTreeInput(json, value) {
private open class JsonTreeInput(
json: Json,
override val value: JsonObject,
private val polyDiscriminator: String? = null,
private val polyDescriptor: SerialDescriptor? = null
) : AbstractJsonTreeInput(json, value) {
private var position = 0

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
Expand All @@ -155,12 +168,23 @@ private open class JsonTreeInput(json: Json, override val value: JsonObject) : A

override fun currentElement(tag: String): JsonElement = value.getValue(tag)

override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
/*
* For polymorphic serialization we'd like to avoid excessive decoder creating in
* beginStructure to properly preserve 'polyDiscriminator' field and filter it out.
*/
if (descriptor === polyDescriptor) return this
return super.beginStructure(descriptor)
}

override fun endStructure(descriptor: SerialDescriptor) {
if (configuration.ignoreUnknownKeys || descriptor.kind is PolymorphicKind) return
// Validate keys
val names = descriptor.cachedSerialNames()
for (key in value.keys) {
if (key !in names) throw UnknownKeyException(key, value.toString())
if (key !in names && key != polyDiscriminator) {
throw UnknownKeyException(key, value.toString())
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PolymorphismTest : JsonTestBase() {
object PolyDefaultSerializer : JsonTransformingSerializer<PolyDefault>(PolyDefault.serializer(), "foo") {
override fun readTransform(element: JsonElement): JsonElement {
return buildJson {
add("json", element)
add("json", JsonObject(element.jsonObject.filterKeys { it != "type" }))
add("id", 42)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,24 @@ import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.test.assertStringFormAndRestored
import kotlin.test.Test

@Serializable
data class FooHolder(
val someMetadata: Int,
val payload: List<@Polymorphic Foo>
)

@Serializable
sealed class Foo {
@Serializable
data class Bar(val bar: Int) : Foo()
class SealedPolymorphismTest {

@Serializable
data class Baz(val baz: String) : Foo()
}
data class FooHolder(
val someMetadata: Int,
val payload: List<@Polymorphic Foo>
)

class SealedPolymorphismTest {
@Serializable
@SerialName("Foo")
sealed class Foo {
@Serializable
@SerialName("Bar")
data class Bar(val bar: Int) : Foo()
@Serializable
@SerialName("Baz")
data class Baz(val baz: String) : Foo()
}

val sealedModule = SerializersModule {
polymorphic(Foo::class) {
Expand All @@ -39,8 +42,8 @@ class SealedPolymorphismTest {
fun testSaveSealedClassesList() {
assertStringFormAndRestored(
"""{"someMetadata":42,"payload":[
|{"type":"kotlinx.serialization.features.Foo.Bar","bar":1},
|{"type":"kotlinx.serialization.features.Foo.Baz","baz":"2"}]}""".trimMargin().replace("\n", ""),
|{"type":"Bar","bar":1},
|{"type":"Baz","baz":"2"}]}""".trimMargin().replace("\n", ""),
FooHolder(42, listOf(Foo.Bar(1), Foo.Baz("2"))),
FooHolder.serializer(),
json,
Expand All @@ -51,7 +54,7 @@ class SealedPolymorphismTest {
@Test
fun testCanSerializeSealedClassPolymorphicallyOnTopLevel() {
assertStringFormAndRestored(
"""{"type":"kotlinx.serialization.features.Foo.Bar","bar":1}""",
"""{"type":"Bar","bar":1}""",
Foo.Bar(1),
PolymorphicSerializer(Foo::class),
json
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.json.polymorphic

import kotlinx.serialization.Serializable
import kotlinx.serialization.json.*
import kotlinx.serialization.test.*
import kotlin.test.*

class JsonDeserializePolymorphicTwiceTest {

@Serializable
sealed class Foo {
@Serializable
data class Bar(val a: Int) : Foo()
}

@Test
fun testDeserializeTwice() { // #812
val json = Json.toJson(Foo.serializer(), Foo.Bar(1))
assertEquals(Foo.Bar(1), Json.fromJson(Foo.serializer(), json))
assertEquals(Foo.Bar(1), Json.fromJson(Foo.serializer(), json))
}
}