Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proper support for transparent wrappers in typed mongo #428

Merged
merged 1 commit into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ class MongoMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) {
}
}

private def isTransparentUnwrap(prefixTpe: Type, fieldSym: Symbol): Boolean = {
val sym = prefixTpe.typeSymbol
sym.isClass && sym.asClass.isCaseClass && (primaryConstructorOf(prefixTpe).asMethod.paramLists match {
case List(param) :: _ if param.name == fieldSym.name =>
val paramTpe = fieldSym.typeSignatureIn(prefixTpe).finalResultType
val wrappingTpe = getType(tq"$SerializationPkg.TransparentWrapping[$paramTpe, $prefixTpe]")
inferImplicitValue(wrappingTpe) != EmptyTree
case _ => false
})
}

def isOptionLike(fullTpe: Type, wrappedTpe: Type): Boolean =
fullTpe != null && fullTpe != NoType && wrappedTpe != null && wrappedTpe != NoType &&
c.inferImplicitValue(getType(tq"$CommonsPkg.meta.OptionLike.Aux[$fullTpe, $wrappedTpe]")) != EmptyTree
Expand Down Expand Up @@ -63,7 +74,9 @@ class MongoMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) {
val prefixTpe = prefix.tpe.widen
val bodyTpe = body.tpe.widen

if (termSym.isCaseAccessor || isSealedHierarchySharedField(prefixTpe, body.symbol.asTerm))
if (isTransparentUnwrap(prefixTpe, body.symbol.asTerm))
q"$newPrefixRef.unwrap"
else if (termSym.isCaseAccessor || isSealedHierarchySharedField(prefixTpe, body.symbol.asTerm))
q"$newPrefixRef.asAdtRef.fieldRefFor[$bodyTpe](${name.decodedName.toString})"
else if (name == TermName("get") && isOptionLike(prefixTpe, bodyTpe))
q"$newPrefixRef.get"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ private final class FilterDocBuilder(prefixPath: Opt[String], filterDocs: BsonAr
appendToPrefix(prefix.rawPath).addOperator(MongoQueryOperator.Ne(optionLike.none, prefix.format))
addImpliedFilters(prefix)

case MongoRef.TransparentUnwrap(prefix, _, _) =>
addImpliedFilters(prefix)

case MongoRef.FieldRef(prefix, _, _, _) =>
addImpliedFilters(prefix)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ sealed trait MongoFormat[T] {
case union: MongoAdtFormat.UnionFormat[T] => union
case _ => throw new IllegalArgumentException(
"Encountered a non-union MongoFormat for an union type (sealed hierarchy) -" +
"do you have any custom implicit MongoFormat for that type?"
"do you have any custom implicit MongoFormat for that type?",
)
}

Expand All @@ -44,51 +44,68 @@ sealed trait MongoFormat[T] {
"Encountered a non-optional MongoFormat for an Option-like type - " +
"do you have a custom implicit MongoFormat for that type?")
}

def assumeTransparent[R]: MongoFormat.TransparentFormat[T, R] = this match {
case transparent: MongoFormat.TransparentFormat[T@unchecked, R@unchecked] => transparent
case _ => throw new IllegalArgumentException(
"Encountered a non-transparent MongoFormat for a transparent wrapper type - " +
"do you have a custom implicit MongoFormat for that type?")
}
}
object MongoFormat extends MetadataCompanion[MongoFormat] with MongoFormatLowPriority {
final case class Opaque[T](
codec: GenCodec[T]
codec: GenCodec[T],
) extends MongoFormat[T]

final case class CollectionFormat[C[X] <: Iterable[X], T](
codec: GenCodec[C[T]],
elementFormat: MongoFormat[T]
elementFormat: MongoFormat[T],
) extends MongoFormat[C[T]]

final case class DictionaryFormat[M[X, Y] <: BMap[X, Y], K, V](
codec: GenCodec[M[K, V]],
keyCodec: GenKeyCodec[K],
valueFormat: MongoFormat[V]
valueFormat: MongoFormat[V],
) extends MongoFormat[M[K, V]]

final case class TypedMapFormat[K[_]](
codec: GenCodec[TypedMap[K]],
keyCodec: GenKeyCodec[K[_]],
valueFormats: MongoFormatMapping[K]
valueFormats: MongoFormatMapping[K],
) extends MongoFormat[TypedMap[K]]

final case class OptionalFormat[O, T](
codec: GenCodec[O],
optionLike: OptionLike.Aux[O, T],
wrappedFormat: MongoFormat[T]
wrappedFormat: MongoFormat[T],
) extends MongoFormat[O]

final case class TransparentFormat[T, R](
codec: GenCodec[T],
wrapping: TransparentWrapping[R, T],
wrappedFormat: MongoFormat[R],
) extends MongoFormat[T]

implicit def collectionFormat[C[X] <: Iterable[X], T](
implicit collectionCodec: GenCodec[C[T]], elementFormat: MongoFormat[T]
implicit collectionCodec: GenCodec[C[T]], elementFormat: MongoFormat[T],
): MongoFormat[C[T]] = CollectionFormat(collectionCodec, elementFormat)

implicit def dictionaryFormat[M[X, Y] <: BMap[X, Y], K, V](
implicit mapCodec: GenCodec[M[K, V]], keyCodec: GenKeyCodec[K], valueFormat: MongoFormat[V]
implicit mapCodec: GenCodec[M[K, V]], keyCodec: GenKeyCodec[K], valueFormat: MongoFormat[V],
): MongoFormat[M[K, V]] = DictionaryFormat(mapCodec, keyCodec, valueFormat)

implicit def typedMapFormat[K[_]](
implicit keyCodec: GenKeyCodec[K[_]], valueFormats: MongoFormatMapping[K]
implicit keyCodec: GenKeyCodec[K[_]], valueFormats: MongoFormatMapping[K],
): MongoFormat[TypedMap[K]] = TypedMapFormat[K](TypedMap.typedMapCodec, keyCodec, valueFormats)

implicit def optionalFormat[O, T](
implicit optionLike: OptionLike.Aux[O, T], optionCodec: GenCodec[O], wrappedFormat: MongoFormat[T]
implicit optionLike: OptionLike.Aux[O, T], optionCodec: GenCodec[O], wrappedFormat: MongoFormat[T],
): MongoFormat[O] = OptionalFormat(optionCodec, optionLike, wrappedFormat)

implicit def transparentFormat[R, T](
implicit codec: GenCodec[T], wrapping: TransparentWrapping[R, T], wrappedFormat: MongoFormat[R],
): MongoFormat[T] = TransparentFormat(codec, wrapping, wrappedFormat)

implicit class collectionFormatOps[C[X] <: Iterable[X], T](private val format: MongoFormat[C[T]]) extends AnyVal {
def assumeCollection: CollectionFormat[C, T] = format match {
case coll: CollectionFormat[C@unchecked, T@unchecked] => coll
Expand Down Expand Up @@ -134,7 +151,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@infer val codec: GenObjectCodec[T],
@infer val dataClassTag: ClassTag[T],
@reifyAnnot val flattenAnnot: flatten,
@multi @adtCaseMetadata val cases: List[Case[_]]
@multi @adtCaseMetadata val cases: List[Case[_]],
) extends MongoAdtFormat[T] {

lazy val casesByClass: Map[Class[_], Case[_]] =
Expand Down Expand Up @@ -163,7 +180,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@tailrec def loop(cases: List[Case[_]], rawName: Opt[String]): Unit = cases match {
case cse :: tail =>
val field = cse.getField(scalaFieldName).getOrElse(throw new NoSuchElementException(
s"Field $scalaFieldName not found in at least one case class/object."
s"Field $scalaFieldName not found in at least one case class/object.",
))
if (rawName.exists(_ != field.info.rawName)) {
throw new IllegalArgumentException(s"Field $scalaFieldName has different raw name across case classes")
Expand Down Expand Up @@ -207,7 +224,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@positioned(positioned.here)
final class RecordFormat[T](
@composite val record: RecordCase[T],
@infer val codec: GenObjectCodec[T]
@infer val codec: GenObjectCodec[T],
) extends MongoAdtFormat[T] {
def dataClassTag: ClassTag[T] = record.classTag

Expand All @@ -218,7 +235,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@positioned(positioned.here)
final class SingletonFormat[T](
@composite val singleton: SingletonCase[T],
@infer val codec: GenObjectCodec[T]
@infer val codec: GenObjectCodec[T],
) extends MongoAdtFormat[T] {
def dataClassTag: ClassTag[T] = singleton.classTag

Expand All @@ -242,7 +259,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@composite val info: GenCaseInfo[T],
@infer val classTag: ClassTag[T],
@multi @adtParamMetadata val fields: List[Field[_]],
@multi @adtCaseSealedParentMetadata val sealedParents: List[SealedParent[_]]
@multi @adtCaseSealedParentMetadata val sealedParents: List[SealedParent[_]],
) extends Case[T] {
def asAdtFormat(codec: GenObjectCodec[T]): MongoAdtFormat[T] =
new RecordFormat(this, codec)
Expand All @@ -258,7 +275,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {

def fieldRefFor[E, T0](prefix: MongoRef[E, T], scalaFieldName: String): MongoPropertyRef[E, T0] = {
val field = fieldsByScalaName.getOrElse(scalaFieldName,
throw new NoSuchElementException(s"Field $scalaFieldName not found")
throw new NoSuchElementException(s"Field $scalaFieldName not found"),
).asInstanceOf[MongoAdtFormat.Field[T0]]
prefix match {
case fieldRef: MongoRef.FieldRef[E, _, T] if transparentWrapper =>
Expand All @@ -274,7 +291,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@composite val info: GenCaseInfo[T],
@infer val classTag: ClassTag[T],
@multi @adtCaseSealedParentMetadata val sealedParents: List[SealedParent[_]],
@infer @checked val value: ValueOf[T]
@infer @checked val value: ValueOf[T],
) extends Case[T] {
def asAdtFormat(codec: GenObjectCodec[T]): MongoAdtFormat[T] =
new SingletonFormat(this, codec)
Expand All @@ -290,7 +307,7 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {
@composite val info: GenParamInfo[T],
@optional @reifyDefaultValue defaultValue: Opt[DefaultValue[T]],
@optional @reifyAnnot whenAbsentAnnot: Opt[whenAbsent[T]],
@infer val format: MongoFormat.Lazy[T]
@infer val format: MongoFormat.Lazy[T],
) extends TypedMetadata[T] {
lazy val fallbackBson: Opt[BsonValue] = {
if (info.optional) Opt(BsonNull.VALUE)
Expand All @@ -301,13 +318,13 @@ object MongoAdtFormat extends AdtMetadataCompanion[MongoAdtFormat] {

final class SealedParent[T](
@composite val info: GenUnionInfo[T],
@infer val classTag: ClassTag[T]
@infer val classTag: ClassTag[T],
) extends TypedMetadata[T]
}

final class MongoEntityMeta[E <: BaseMongoEntity](
@infer val format: MongoAdtFormat[E],
@infer val idMode: EntityIdMode[E, E#IDType]
@infer val idMode: EntityIdMode[E, E#IDType],
) {
def idRef: MongoPropertyRef[E, E#IDType] = idMode.idRef(format)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.avsystem.commons.misc.TypedMap
import com.avsystem.commons.mongo.typed.MongoPropertyRef.Separator
import com.avsystem.commons.mongo.{BsonValueInput, KeyEscaper}
import com.avsystem.commons.serialization.GenCodec.ReadFailure
import com.avsystem.commons.serialization.TransparentWrapping
import org.bson.{BsonDocument, BsonValue}

/**
Expand Down Expand Up @@ -192,7 +193,7 @@ sealed trait MongoPropertyRef[E, T] extends MongoRef[E, T]
private def computePath[T0](
onlyUpToArray: Boolean,
ref: MongoPropertyRef[E, T0],
acc: List[String]
acc: List[String],
): List[String] = ref match {
case FieldRef(_: MongoToplevelRef[_, _], fieldName, _, _) =>
KeyEscaper.escape(fieldName) :: acc
Expand All @@ -207,6 +208,9 @@ sealed trait MongoPropertyRef[E, T] extends MongoRef[E, T]
case GetFromOptional(prefix, _, _) =>
computePath(onlyUpToArray, prefix, acc)

case TransparentUnwrap(prefix, _, _) =>
computePath(onlyUpToArray, prefix, acc)

case PropertySubtypeRef(prefix, _, _, _) =>
computePath(onlyUpToArray, prefix, acc)
}
Expand Down Expand Up @@ -234,6 +238,9 @@ sealed trait MongoPropertyRef[E, T] extends MongoRef[E, T]
case GetFromOptional(prefix, _, _) =>
prefix.extractBson(doc)

case TransparentUnwrap(prefix, _, _) =>
prefix.extractBson(doc)

case PropertySubtypeRef(prefix, _, _, _) =>
prefix.extractBson(doc)
}
Expand Down Expand Up @@ -274,12 +281,23 @@ object MongoPropertyRef {
MongoRef.GetFromOptional(ref, format.wrappedFormat, format.optionLike)
}
}

implicit def transparentRefOps[E, T, R](ref: MongoPropertyRef[E, T])(implicit wrapping: TransparentWrapping[R, T]): TransparentRefOps[E, T, R] =
new TransparentRefOps[E, T, R](ref)

class TransparentRefOps[E, T, R](private val ref: MongoPropertyRef[E, T]) extends AnyVal {
def unwrap: MongoPropertyRef[E, R] = {
val format = ref.format.assumeTransparent[R]
MongoRef.TransparentUnwrap(ref, format.wrappedFormat, format.wrapping)
}
}

}

object MongoRef {
// Deliberately not calling this IdentityRef so that it doesn't get confused with IdRef (for database ID field)
final case class RootRef[T](
format: MongoAdtFormat[T]
format: MongoAdtFormat[T],
) extends MongoToplevelRef[T, T] {
def fullRef: RootRef[T] = this
def compose[P](prefix: MongoRef[P, T]): MongoRef[P, T] = prefix
Expand All @@ -289,7 +307,7 @@ object MongoRef {
fullRef: RootRef[E],
caseFieldName: String,
caseNames: List[String],
format: MongoAdtFormat[T]
format: MongoAdtFormat[T],
) extends MongoToplevelRef[E, T] {
def compose[P](prefix: MongoRef[P, E]): MongoRef[P, T] = prefix match {
case _: MongoToplevelRef[P, E] =>
Expand All @@ -304,7 +322,7 @@ object MongoRef {
prefix: MongoRef[E, E0],
fieldName: String,
format: MongoFormat[T],
fallbackBson: Opt[BsonValue]
fallbackBson: Opt[BsonValue],
) extends MongoPropertyRef[E, T] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
copy(prefix = this.prefix compose newPrefix)
Expand All @@ -313,7 +331,7 @@ object MongoRef {
final case class ArrayIndexRef[E, C[X] <: Iterable[X], T](
prefix: MongoPropertyRef[E, C[T]],
index: Int,
format: MongoFormat[T]
format: MongoFormat[T],
) extends MongoPropertyRef[E, T] {
require(index >= 0, "array index must be non-negative")
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
Expand All @@ -323,17 +341,26 @@ object MongoRef {
final case class GetFromOptional[E, O, T](
prefix: MongoPropertyRef[E, O],
format: MongoFormat[T],
optionLike: OptionLike.Aux[O, T]
optionLike: OptionLike.Aux[O, T],
) extends MongoPropertyRef[E, T] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
copy(prefix = prefix compose newPrefix)
}

final case class TransparentUnwrap[E, R, T](
prefix: MongoPropertyRef[E, T],
format: MongoFormat[R],
transparentWrapping: TransparentWrapping[R, T],
) extends MongoPropertyRef[E, R] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, R] =
copy(prefix = prefix compose newPrefix)
}

final case class PropertySubtypeRef[E, T0, T <: T0](
prefix: MongoPropertyRef[E, T0],
caseFieldName: String,
caseNames: List[String],
format: MongoAdtFormat[T]
format: MongoAdtFormat[T],
) extends MongoPropertyRef[E, T] {
def compose[P](newPrefix: MongoRef[P, E]): MongoPropertyRef[P, T] =
copy(prefix = prefix compose newPrefix)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class MongoRefTest extends AnyFunSuite {
assert(Rte.ref(_.inner.int).rawPath == "inner.int")
assert(Rte.ref(_.complex.get.apply(InnerId("foo")).apply(5).int).rawPath == "complex.foo.5.int")
assert(Rte.ref(_.complex).ref(_.get).ref(_.apply(InnerId("foo"))).ref(_.apply(5)).ref(_.int).rawPath == "complex.foo.5.int")
assert(Rte.ref(_.props.map("key")).rawPath == "props.key")
assert(Rte.ref(_.union.str).rawPath == "union.str")
assert(Rte.ref(_.union.as[CaseOne]).rawPath == "union")
assert(Rte.ref(_.union.as[CaseOne].str).rawPath == "union.str")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TypedMongoCollectionTest extends AnyFunSuite with ScalaFutures with Before
List(ir),
Map(InnerId("iid") -> ir),
Opt(Map(InnerId("iid") -> List(ir))),
Props(Map("foo" -> "bar")),
i % 3 match {
case 0 => CaseOne(s"uid$i", "ustr", i % 2 == 0)
case 1 => CaseTwo(s"uid$i", "ustr", i, Rte.Example)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ case class InnerRecord(
strOpt: Opt[String],
@optionalParam intOpt: Opt[Int],
intList: List[Int],
intMap: Map[String, Int]
intMap: Map[String, Int],
)
object InnerRecord extends MongoDataCompanion[InnerRecord] {
final val Example = InnerRecord(
24, "istr", Opt("istropt"), Opt.Empty, List(3, 4, 5), Map("ione" -> 1, "ithree" -> 3))
}

case class Props(map: Map[String, String]) extends AnyVal
object Props extends TransparentWrapperCompanion[Map[String, String], Props]

case class RecordTestEntity(
id: String,
int: Int,
Expand All @@ -48,14 +51,16 @@ case class RecordTestEntity(
innerList: List[InnerRecord],
innerMap: Map[InnerId, InnerRecord],
complex: Opt[Map[InnerId, List[InnerRecord]]],
@transientDefault union: UnionTestEntity = CaseOne("uid", "ustr", data = false)
props: Props,
@transientDefault union: UnionTestEntity = CaseOne("uid", "ustr", data = false),
) extends MongoEntity[String]
object RecordTestEntity extends MongoEntityCompanion[RecordTestEntity] {
final val Example = RecordTestEntity(
"rid", 42, "str", Timestamp.Zero, Opt("stropt"), Opt.Empty,
List(1, 2, 3), Map("one" -> 1, "two" -> 2), TypedMap(PKey.IntKey -> 42, PKey.InnerKey -> InnerRecord.Example),
InnerRecord.Example, Opt(InnerRecord.Example), List(InnerRecord.Example),
Map(InnerId("iid") -> InnerRecord.Example), Opt(Map(InnerId("iid") -> List(InnerRecord.Example)))
Map(InnerId("iid") -> InnerRecord.Example), Opt(Map(InnerId("iid") -> List(InnerRecord.Example))),
Props(Map.empty),
)
}

Expand All @@ -75,12 +80,12 @@ object TestAutoId extends ObjectIdWrapperCompanion[TestAutoId]

case class RecordTestAutoIdEntity(
str: String,
int: Int
int: Int,
) extends AutoIdMongoEntity[TestAutoId]
object RecordTestAutoIdEntity extends MongoEntityCompanion[RecordTestAutoIdEntity]

case class AutoObjectIdEntity(
str: String,
int: Int
int: Int,
) extends AutoIdMongoEntity[ObjectId]
object AutoObjectIdEntity extends MongoEntityCompanion[AutoObjectIdEntity]