Skip to content

Commit

Permalink
Merge pull request #428 from AVSystem/mongoref-transparent
Browse files Browse the repository at this point in the history
proper support for transparent wrappers in typed mongo
  • Loading branch information
ddworak authored Jan 18, 2023
2 parents d97f4d6 + 022cd1a commit 8802150
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ class MongoMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) {
}
}

private def isTransparentUnwrap(prefixTpe: Type, fieldSym: Symbol): Boolean = {
val sym = prefixTpe.typeSymbol
sym.isClass && sym.asClass.isCaseClass && (primaryConstructorOf(prefixTpe).asMethod.paramLists match {
case List(param) :: _ if param.name == fieldSym.name =>
val paramTpe = fieldSym.typeSignatureIn(prefixTpe).finalResultType
val wrappingTpe = getType(tq"$SerializationPkg.TransparentWrapping[$paramTpe, $prefixTpe]")
inferImplicitValue(wrappingTpe) != EmptyTree
case _ => false
})
}

def isOptionLike(fullTpe: Type, wrappedTpe: Type): Boolean =
fullTpe != null && fullTpe != NoType && wrappedTpe != null && wrappedTpe != NoType &&
c.inferImplicitValue(getType(tq"$CommonsPkg.meta.OptionLike.Aux[$fullTpe, $wrappedTpe]")) != EmptyTree
Expand Down Expand Up @@ -63,7 +74,9 @@ class MongoMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) {
val prefixTpe = prefix.tpe.widen
val bodyTpe = body.tpe.widen

if (termSym.isCaseAccessor || isSealedHierarchySharedField(prefixTpe, body.symbol.asTerm))
if (isTransparentUnwrap(prefixTpe, body.symbol.asTerm))
q"$newPrefixRef.unwrap"
else if (termSym.isCaseAccessor || isSealedHierarchySharedField(prefixTpe, body.symbol.asTerm))
q"$newPrefixRef.asAdtRef.fieldRefFor[$bodyTpe](${name.decodedName.toString})"
else if (name == TermName("get") && isOptionLike(prefixTpe, bodyTpe))
q"$newPrefixRef.get"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ private final class FilterDocBuilder(prefixPath: Opt[String], filterDocs: BsonAr
appendToPrefix(prefix.rawPath).addOperator(MongoQueryOperator.Ne(optionLike.none, prefix.format))
addImpliedFilters(prefix)

case MongoRef.TransparentUnwrap(prefix, _, _) =>
addImpliedFilters(prefix)

case MongoRef.FieldRef(prefix, _, _, _) =>
addImpliedFilters(prefix)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ sealed trait MongoFormat[T] {
case union: MongoAdtFormat.UnionFormat[T] => union
case _ => throw new IllegalArgumentException(
"Encountered a non-union MongoFormat for an union type (sealed hierarchy) -" +
"do you have any custom implicit MongoFormat for that type?"
"do you have any custom implicit MongoFormat for that type?",
)
}

Expand All @@ -44,51 +44,68 @@ sealed trait MongoFormat[T] {
"Encountered a non-optional MongoFormat for an Option-like type - " +
"do you have a custom implicit MongoFormat for that type?")
}

def assumeTransparent[R]: MongoFormat.TransparentFormat[T, R] = this match {
case transparent: MongoFormat.TransparentFormat[T@unchecked, R@unchecked] => transparent
case _ => throw new IllegalArgumentException(
"Encountered a non-transparent MongoFormat for a transparent wrapper type - " +
"do you have a custom implicit MongoFormat for that type?")
}
}
object MongoFormat extends MetadataCompanion[MongoFormat] with MongoFormatLowPriority {
final case class Opaque[T](
codec: GenCodec[T]
codec: GenCodec[T],
) extends MongoFormat[T]

final case class CollectionFormat[C[X] <: Iterable[X], T](
codec: GenCodec[C[T]],
elementFormat: MongoFormat[T]
elementFormat: MongoFormat[T],
) extends MongoFormat[C[T]]

final case class DictionaryFormat[M[X, Y] <: BMap[X, Y], K, V](
codec: GenCodec[M[K, V]],
keyCodec: GenKeyCodec[K],
valueFormat: MongoFormat[V]
valueFormat: MongoFormat[V],
) extends MongoFormat[M[K, V]]

final case class TypedMapFormat[K[_]](
codec: GenCodec[TypedMap[K]],
keyCodec: GenKeyCodec[K[_]],
valueFormats: MongoFormatMapping[K]
valueFormats: MongoFormatMapping[K],
) extends MongoFormat[TypedMap[K]]

final case class OptionalFormat[O, T](
codec: GenCodec[O],
optionLike: OptionLike.Aux[O, T],
wrappedFormat: MongoFormat[T]
wrappedFormat: MongoFormat[T],
) extends MongoFormat[O]

final case class TransparentFormat[T, R](
codec: GenCodec[T],
wrapping: TransparentWrapping[R, T],
wrappedFormat: MongoFormat[R],
) extends MongoFormat[T]

implicit def collectionFormat[C[X] <: Iterable[X], T](
implicit collectionCodec: GenCodec[C[T]], elementFormat: MongoFormat[T]
implicit collectionCodec: GenCodec[C[T]], elementFormat: MongoFormat[T],
): MongoFormat[C[T]] = CollectionFormat(collectionCodec, elementFormat)

implicit def dictionaryFormat[M[X, Y] <: BMap[X, Y], K, V](
implicit mapCodec: GenCodec[M[K, V]], keyCodec: GenKeyCodec[K], valueFormat: MongoFormat[V]
implicit mapCodec: GenCodec[M[K, V]], keyCodec: GenKeyCodec[K], valueFormat: MongoFormat[V],
): MongoFormat[M[K, V]] = DictionaryFormat(mapCodec, keyCodec, valueFormat)

implicit def typedMapFormat[K[_]](
implicit keyCodec: GenKeyCodec[K[_]], valueFormats: MongoFormatMapping[K]
implicit keyCodec: GenKeyCodec[K[_]], valueFormats: MongoFormatMapping[K],
): MongoFormat[TypedMap[K]] = TypedMapFormat[K](TypedMap.typedMapCodec, keyCodec, valueFormats)

implicit def optionalFormat[O, T](
implicit optionLike: OptionLike.Aux[O, T], optionCodec: GenCodec[O], wrappedFormat: MongoFormat[T]
implicit optionLike: OptionLike.Aux[O, T], optionCodec: GenCodec[O], wrappedFormat: MongoFormat[T],
): MongoFormat[O] = OptionalFormat(optionCodec, optionLike, wrappedFormat)

implicit def transparentFormat[R, T](
implicit codec: GenCodec[T], wrapping: TransparentWrapping[R, T], wrappedFormat: MongoFormat[R],
): MongoFormat[T] = TransparentFormat(codec, wrapping, wrappedFormat)

implicit class collectionFormatOps[C[X] <: Iterable[X], T](private val format: MongoFormat[C[T]]) extends AnyVal {
def assumeCollection: CollectionFormat[C, T] = format match {
case coll: CollectionFormat[C@unchecked, T@unchecked] => coll
Expand Down Expand Up @@ -134,7 +151,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@infer val codec: GenObjectCodec[T],
@infer val dataClassTag: ClassTag[T],
@reifyAnnot val flattenAnnot: flatten,
@multi @adtCaseMetadata val cases: List[Case[_]]
@multi @adtCaseMetadata val cases: List[Case[_]],
) extends MongoAdtFormat[T] {

lazy val casesByClass: Map[Class[_], Case[_]] =
Expand Down Expand Up @@ -163,7 +180,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@tailrec def loop(cases: List[Case[_]], rawName: Opt[String]): Unit = cases match {
case cse :: tail =>
val field = cse.getField(scalaFieldName).getOrElse(throw new NoSuchElementException(
s"Field $scalaFieldName not found in at least one case class/object."
s"Field $scalaFieldName not found in at least one case class/object.",
))
if (rawName.exists(_ != field.info.rawName)) {
throw new IllegalArgumentException(s"Field $scalaFieldName has different raw name across case classes")
Expand Down Expand Up @@ -207,7 +224,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@positioned(positioned.here)
final class RecordFormat[T](
@composite val record: RecordCase[T],
@infer val codec: GenObjectCodec[T]
@infer val codec: GenObjectCodec[T],
) extends MongoAdtFormat[T] {
def dataClassTag: ClassTag[T] = record.classTag

Expand All @@ -218,7 +235,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@positioned(positioned.here)
final class SingletonFormat[T](
@composite val singleton: SingletonCase[T],
@infer val codec: GenObjectCodec[T]
@infer val codec: GenObjectCodec[T],
) extends MongoAdtFormat[T] {
def dataClassTag: ClassTag[T] = singleton.classTag

Expand All @@ -242,7 +259,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@composite val info: GenCaseInfo[T],
@infer val classTag: ClassTag[T],
@multi @adtParamMetadata val fields: List[Field[_]],
@multi @adtCaseSealedParentMetadata val sealedParents: List[SealedParent[_]]
@multi @adtCaseSealedParentMetadata val sealedParents: List[SealedParent[_]],
) extends Case[T] {
def asAdtFormat(codec: GenObjectCodec[T]): MongoAdtFormat[T] =
new RecordFormat(this, codec)
Expand All @@ -258,7 +275,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {

def fieldRefFor[E, T0](prefix: MongoRef[E, T], scalaFieldName: String): MongoPropertyRef[E, T0] = {
val field = fieldsByScalaName.getOrElse(scalaFieldName,
throw new NoSuchElementException(s"Field $scalaFieldName not found")
throw new NoSuchElementException(s"Field $scalaFieldName not found"),
).asInstanceOf[MongoAdtFormat.Field[T0]]
prefix match {
case fieldRef: MongoRef.FieldRef[E, _, T] if transparentWrapper =>
Expand All @@ -274,7 +291,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@composite val info: GenCaseInfo[T],
@infer val classTag: ClassTag[T],
@multi @adtCaseSealedParentMetadata val sealedParents: List[SealedParent[_]],
@infer @checked val value: ValueOf[T]
@infer @checked val value: ValueOf[T],
) extends Case[T] {
def asAdtFormat(codec: GenObjectCodec[T]): MongoAdtFormat[T] =
new SingletonFormat(this, codec)
Expand All @@ -290,7 +307,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@composite val info: GenParamInfo[T],
@optional @reifyDefaultValue defaultValue: Opt[DefaultValue[T]],
@optional @reifyAnnot whenAbsentAnnot: Opt[whenAbsent[T]],
@infer val format: MongoFormat.Lazy[T]
@infer val format: MongoFormat.Lazy[T],
) extends TypedMetadata[T] {
lazy val fallbackBson: Opt[BsonValue] = {
if (info.optional) Opt(BsonNull.VALUE)
Expand All @@ -301,13 +318,13 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {

final class SealedParent[T](
@composite val info: GenUnionInfo[T],
@infer val classTag: ClassTag[T]
@infer val classTag: ClassTag[T],
) extends TypedMetadata[T]
}

final class MongoEntityMeta[E <: BaseMongoEntity](
@infer val format: MongoAdtFormat[E],
@infer val idMode: EntityIdMode[E, E#IDType]
@infer val idMode: EntityIdMode[E, E#IDType],
) {
def idRef: MongoPropertyRef[E, E#IDType] = idMode.idRef(format)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.avsystem.commons.misc.TypedMap
import com.avsystem.commons.mongo.typed.MongoPropertyRef.Separator
import com.avsystem.commons.mongo.{BsonValueInput, KeyEscaper}
import com.avsystem.commons.serialization.GenCodec.ReadFailure
import com.avsystem.commons.serialization.TransparentWrapping
import org.bson.{BsonDocument, BsonValue}

/**
Expand Down Expand Up @@ -192,7 +193,7 @@ sealed trait MongoPropertyRef[E, T] extends MongoRef[E, T]
private def computePath[T0](
onlyUpToArray: Boolean,
ref: MongoPropertyRef[E, T0],
acc: List[String]
acc: List[String],
): List[String] = ref match {
case FieldRef(_: MongoToplevelRef[_, _], fieldName, _, _) =>
KeyEscaper.escape(fieldName) :: acc
Expand All @@ -207,6 +208,9 @@ sealed trait MongoPropertyRef[E, T] extends MongoRef[E, T]
case GetFromOptional(prefix, _, _) =>
computePath(onlyUpToArray, prefix, acc)

case TransparentUnwrap(prefix, _, _) =>
computePath(onlyUpToArray, prefix, acc)

case PropertySubtypeRef(prefix, _, _, _) =>
computePath(onlyUpToArray, prefix, acc)
}
Expand Down Expand Up @@ -234,6 +238,9 @@ sealed trait MongoPropertyRef[E, T] extends MongoRef[E, T]
case GetFromOptional(prefix, _, _) =>
prefix.extractBson(doc)

case TransparentUnwrap(prefix, _, _) =>
prefix.extractBson(doc)

case PropertySubtypeRef(prefix, _, _, _) =>
prefix.extractBson(doc)
}
Expand Down Expand Up @@ -274,12 +281,23 @@ object MongoPropertyRef {
MongoRef.GetFromOptional(ref, format.wrappedFormat, format.optionLike)
}
}

implicit def transparentRefOps[E, T, R](ref: MongoPropertyRef[E, T])(implicit wrapping: TransparentWrapping[R, T]): TransparentRefOps[E, T, R] =
new TransparentRefOps[E, T, R](ref)

class TransparentRefOps[E, T, R](private val ref: MongoPropertyRef[E, T]) extends AnyVal {
def unwrap: MongoPropertyRef[E, R] = {
val format = ref.format.assumeTransparent[R]
MongoRef.TransparentUnwrap(ref, format.wrappedFormat, format.wrapping)
}
}

}

object MongoRef {
// Deliberately not calling this IdentityRef so that it doesn't get confused with IdRef (for database ID field)
final case class RootRef[T](
format: MongoAdtFormat[T]
format: MongoAdtFormat[T],
) extends MongoToplevelRef[T, T] {
def fullRef: RootRef[T] = this
def compose[P](prefix: MongoRef[P, T]): MongoRef[P, T] = prefix
Expand All @@ -289,7 +307,7 @@ object MongoRef {
fullRef: RootRef[E],
caseFieldName: String,
caseNames: List[String],
format: MongoAdtFormat[T]
format: MongoAdtFormat[T],
) extends MongoToplevelRef[E, T] {
def compose[P](prefix: MongoRef[P, E]): MongoRef[P, T] = prefix match {
case _: MongoToplevelRef[P, E] =>
Expand All @@ -304,7 +322,7 @@ object MongoRef {
prefix: MongoRef[E, E0],
fieldName: String,
format: MongoFormat[T],
fallbackBson: Opt[BsonValue]
fallbackBson: Opt[BsonValue],
) extends MongoPropertyRef[E, T] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
copy(prefix = this.prefix compose newPrefix)
Expand All @@ -313,7 +331,7 @@ object MongoRef {
final case class ArrayIndexRef[E, C[X] <: Iterable[X], T](
prefix: MongoPropertyRef[E, C[T]],
index: Int,
format: MongoFormat[T]
format: MongoFormat[T],
) extends MongoPropertyRef[E, T] {
require(index >= 0, "array index must be non-negative")
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
Expand All @@ -323,17 +341,26 @@ object MongoRef {
final case class GetFromOptional[E, O, T](
prefix: MongoPropertyRef[E, O],
format: MongoFormat[T],
optionLike: OptionLike.Aux[O, T]
optionLike: OptionLike.Aux[O, T],
) extends MongoPropertyRef[E, T] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
copy(prefix = prefix compose newPrefix)
}

final case class TransparentUnwrap[E, R, T](
prefix: MongoPropertyRef[E, T],
format: MongoFormat[R],
transparentWrapping: TransparentWrapping[R, T],
) extends MongoPropertyRef[E, R] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, R] =
copy(prefix = prefix compose newPrefix)
}

final case class PropertySubtypeRef[E, T0, T <: T0](
prefix: MongoPropertyRef[E, T0],
caseFieldName: String,
caseNames: List[String],
format: MongoAdtFormat[T]
format: MongoAdtFormat[T],
) extends MongoPropertyRef[E, T] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
copy(prefix = prefix compose newPrefix)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class MongoRefTest extends AnyFunSuite {
assert(Rte.ref(_.inner.int).rawPath == "inner.int")
assert(Rte.ref(_.complex.get.apply(InnerId("foo")).apply(5).int).rawPath == "complex.foo.5.int")
assert(Rte.ref(_.complex).ref(_.get).ref(_.apply(InnerId("foo"))).ref(_.apply(5)).ref(_.int).rawPath == "complex.foo.5.int")
assert(Rte.ref(_.props.map("key")).rawPath == "props.key")
assert(Rte.ref(_.union.str).rawPath == "union.str")
assert(Rte.ref(_.union.as[CaseOne]).rawPath == "union")
assert(Rte.ref(_.union.as[CaseOne].str).rawPath == "union.str")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TypedMongoCollectionTest extends AnyFunSuite with ScalaFutures with Before
List(ir),
Map(InnerId("iid") -> ir),
Opt(Map(InnerId("iid") -> List(ir))),
Props(Map("foo" -> "bar")),
i % 3 match {
case 0 => CaseOne(s"uid$i", "ustr", i % 2 == 0)
case 1 => CaseTwo(s"uid$i", "ustr", i, Rte.Example)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ case class InnerRecord(
strOpt: Opt[String],
@optionalParam intOpt: Opt[Int],
intList: List[Int],
intMap: Map[String, Int]
intMap: Map[String, Int],
)
object InnerRecord extends MongoDataCompanion[InnerRecord] {
final val Example = InnerRecord(
24, "istr", Opt("istropt"), Opt.Empty, List(3, 4, 5), Map("ione" -> 1, "ithree" -> 3))
}

case class Props(map: Map[String, String]) extends AnyVal
object Props extends TransparentWrapperCompanion[Map[String, String], Props]

case class RecordTestEntity(
id: String,
int: Int,
Expand All @@ -48,14 +51,16 @@ case class RecordTestEntity(
innerList: List[InnerRecord],
innerMap: Map[InnerId, InnerRecord],
complex: Opt[Map[InnerId, List[InnerRecord]]],
@transientDefault union: UnionTestEntity = CaseOne("uid", "ustr", data = false)
props: Props,
@transientDefault union: UnionTestEntity = CaseOne("uid", "ustr", data = false),
) extends MongoEntity[String]
object RecordTestEntity extends MongoEntityCompanion[RecordTestEntity] {
final val Example = RecordTestEntity(
"rid", 42, "str", Timestamp.Zero, Opt("stropt"), Opt.Empty,
List(1, 2, 3), Map("one" -> 1, "two" -> 2), TypedMap(PKey.IntKey -> 42, PKey.InnerKey -> InnerRecord.Example),
InnerRecord.Example, Opt(InnerRecord.Example), List(InnerRecord.Example),
Map(InnerId("iid") -> InnerRecord.Example), Opt(Map(InnerId("iid") -> List(InnerRecord.Example)))
Map(InnerId("iid") -> InnerRecord.Example), Opt(Map(InnerId("iid") -> List(InnerRecord.Example))),
Props(Map.empty),
)
}

Expand All @@ -75,12 +80,12 @@ object TestAutoId extends ObjectIdWrapperCompanion[TestAutoId]

case class RecordTestAutoIdEntity(
str: String,
int: Int
int: Int,
) extends AutoIdMongoEntity[TestAutoId]
object RecordTestAutoIdEntity extends MongoEntityCompanion[RecordTestAutoIdEntity]

case class AutoObjectIdEntity(
str: String,
int: Int
int: Int,
) extends AutoIdMongoEntity[ObjectId]
object AutoObjectIdEntity extends MongoEntityCompanion[AutoObjectIdEntity]

0 comments on commit 8802150

Please sign in to comment.