From ec0a0afe7466f548742b235dfc7681a4e5884665 Mon Sep 17 00:00:00 2001 From: Ilya Ulanov Date: Wed, 12 Jan 2022 12:53:16 +0500 Subject: [PATCH] Add ability to add discriminator to JsonObjectWriter (#142) --- .../builder/WriterDescription.scala | 3 ++- .../derivation/SemiautoDerivationMacro.scala | 22 ++++++++++++++----- .../impl/derivation/WriterDerivation.scala | 15 +++++++++++-- .../SemiautoWriterDerivationTest.scala | 18 +++++++++++++++ 4 files changed, 49 insertions(+), 9 deletions(-) diff --git a/modules/macro-derivation/src/main/scala/tethys/derivation/builder/WriterDescription.scala b/modules/macro-derivation/src/main/scala/tethys/derivation/builder/WriterDescription.scala index f6b49db0..2452de89 100644 --- a/modules/macro-derivation/src/main/scala/tethys/derivation/builder/WriterDescription.scala +++ b/modules/macro-derivation/src/main/scala/tethys/derivation/builder/WriterDescription.scala @@ -2,8 +2,9 @@ package tethys.derivation.builder import tethys.derivation.builder.WriterDescription.BuilderOperation -case class WriterDerivationConfig(fieldStyle: Option[FieldStyle]) { +case class WriterDerivationConfig(fieldStyle: Option[FieldStyle], discriminator: Option[String] = None) { def withFieldStyle(fieldStyle: FieldStyle): WriterDerivationConfig = this.copy(fieldStyle = Some(fieldStyle)) + def withDiscriminator(discriminator: String): WriterDerivationConfig = this.copy(discriminator = Some(discriminator)) } object WriterDerivationConfig { diff --git a/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/SemiautoDerivationMacro.scala b/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/SemiautoDerivationMacro.scala index 9efac6d4..bf08b8d5 100644 --- a/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/SemiautoDerivationMacro.scala +++ b/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/SemiautoDerivationMacro.scala @@ -31,12 +31,22 @@ class SemiautoDerivationMacro(val c: blackbox.Context) } def jsonWriterWithConfig[A: WeakTypeTag](config: Expr[WriterDerivationConfig]): Expr[JsonObjectWriter[A]] = { - val description = MacroWriteDescription( - tpe = weakTypeOf[A], - config = c.Expr[WriterDerivationConfig](c.untypecheck(config.tree)), - operations = Seq.empty - ) - deriveWriter[A](description) + val tpe = weakTypeOf[A] + val clazz = classSym(tpe) + + if (isCaseClass(tpe)) { + deriveWriter[A]( + MacroWriteDescription( + tpe = tpe, + config = c.Expr[WriterDerivationConfig](c.untypecheck(config.tree)), + operations = Seq.empty + ) + ) + } else if (clazz.isSealed) { + deriveWriterForSealedClass[A](c.Expr[WriterDerivationConfig](c.untypecheck(config.tree))) + } else { + abort(s"Can't auto derive JsonWriter[$tpe]") + } } def describedJsonWriter[A: WeakTypeTag](description: Expr[WriterDescription[A]]): Expr[JsonObjectWriter[A]] = { diff --git a/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/WriterDerivation.scala b/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/WriterDerivation.scala index 4e3a61f6..0ad4279d 100644 --- a/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/WriterDerivation.scala +++ b/modules/macro-derivation/src/main/scala/tethys/derivation/impl/derivation/WriterDerivation.scala @@ -1,7 +1,7 @@ package tethys.derivation.impl.derivation import tethys.JsonObjectWriter -import tethys.derivation.builder.FieldStyle +import tethys.derivation.builder.{FieldStyle, WriterDerivationConfig} import tethys.derivation.impl.builder.{WriteBuilderUtils, WriterBuilderCommons} import tethys.derivation.impl.{BaseMacroDefinitions, CaseClassUtils} import tethys.writers.tokens.TokenWriter @@ -34,6 +34,10 @@ trait WriterDerivation } def deriveWriterForSealedClass[A: WeakTypeTag]: Expr[JsonObjectWriter[A]] = { + deriveWriterForSealedClass[A](emptyWriterConfig) + } + + def deriveWriterForSealedClass[A: WeakTypeTag](config: c.Expr[WriterDerivationConfig]): Expr[JsonObjectWriter[A]] = { val tpe = weakTypeOf[A] val types = collectDistinctSubtypes(tpe).sortBy(_.typeSymbol.fullName) @@ -51,7 +55,14 @@ trait WriterDerivation val subClassesCases = types.zip(terms).map { case (subtype, writer) => val term = TermName(c.freshName("sub")) - cq"$term: $subtype => $writer.writeValues($term, $tokenWriterTerm)" + val discriminatorTerm = TermName(c.freshName("discriminator")) + val typeName = subtype.typeSymbol.asClass.name.decodedName.toString.trim + cq"""$term: $subtype => { + $writer.writeValues($term, $tokenWriterTerm) + ${config.tree}.discriminator.foreach { $discriminatorTerm: String => + implicitly[$jsonWriterType[String]].write($discriminatorTerm, $typeName, $tokenWriterTerm) + } + }""" } c.Expr[JsonObjectWriter[A]] { diff --git a/modules/macro-derivation/src/test/scala/tethys/derivation/SemiautoWriterDerivationTest.scala b/modules/macro-derivation/src/test/scala/tethys/derivation/SemiautoWriterDerivationTest.scala index 2a752aa1..ff9b071b 100644 --- a/modules/macro-derivation/src/test/scala/tethys/derivation/SemiautoWriterDerivationTest.scala +++ b/modules/macro-derivation/src/test/scala/tethys/derivation/SemiautoWriterDerivationTest.scala @@ -170,4 +170,22 @@ class SemiautoWriterDerivationTest extends AnyFlatSpec with Matchers { write(JustObject) shouldBe obj("type" -> "JustObject") write(SubChild(3)) shouldBe obj("c" -> 3) } + + it should "derive writer for simple sealed trait with hierarchy with discriminator" in { + implicit val caseClassWriter: JsonObjectWriter[CaseClass] = jsonWriter[CaseClass] + implicit val simpleClassWriter: JsonObjectWriter[SimpleClass] = JsonWriter.obj[SimpleClass].addField("b")(_.b) + implicit val justObjectWriter: JsonObjectWriter[JustObject.type] = JsonWriter.obj + implicit val subChildWriter: JsonObjectWriter[SubChild] = jsonWriter[SubChild] + + implicit val sealedWriter: JsonWriter[SimpleSealedType] = jsonWriter[SimpleSealedType]( + WriterDerivationConfig.empty.withDiscriminator("__type") + ) + + def write(simpleSealedType: SimpleSealedType): List[TokenNode] = simpleSealedType.asTokenList + + write(CaseClass(1)) shouldBe obj("a" -> 1, "__type" -> "CaseClass") + write(new SimpleClass(2)) shouldBe obj("b" -> 2, "__type" -> "SimpleClass") + write(JustObject) shouldBe obj("__type" -> "JustObject") + write(SubChild(3)) shouldBe obj("c" -> 3, "__type" -> "SubChild") + } }