From 8ad9a64af196402907129730aed747fea091319d Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 20 Mar 2024 00:49:26 +0100 Subject: [PATCH 1/2] Support basic generics --- guinep/src/main/scala/macros.scala | 79 +++++++++++++++++++++-------- testcases/src/main/scala/main.scala | 14 ++++- 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/guinep/src/main/scala/macros.scala b/guinep/src/main/scala/macros.scala index 44cf8b9..c4df3c5 100644 --- a/guinep/src/main/scala/macros.scala +++ b/guinep/src/main/scala/macros.scala @@ -22,7 +22,7 @@ private[guinep] object macros { case _ => None } - def wrongParamsListError(f: Expr[Any]): Nothing = + private def wrongParamsListError(f: Expr[Any]): Nothing = report.errorAndAbort(s"Wrong params list, expected a function reference, got: ${f.show}", f.asTerm.pos) private def unsupportedFunctionParamType(t: TypeRepr, pos: Option[Position] = None): Nothing = pos match { @@ -35,6 +35,10 @@ private[guinep] object macros { private def select(s: String): Term = t.select(t.tpe.typeSymbol.methodMember(s).head) + extension (s: Symbol) + private def prettyName: String = + s.name.stripSuffix("$") + private def functionNameImpl(f: Expr[Any]): Expr[String] = { val name = f.asTerm match { case Inlined(_, _, Lambda(_, body)) => @@ -83,26 +87,28 @@ private[guinep] object macros { val isEnumCaseNonClassDef = typeSymbol.flags.is(Flags.Enum) && typeSymbol.flags.is(Flags.Case) && !typeSymbol.isClassDef isModule || isEnumCaseNonClassDef - private def tpeArguments(tpe: TypeRepr): List[TypeRepr] = tpe match { - case AppliedType(tpe, args) => args - case _ => Nil - } - private def functionFormElementFromTree(paramName: String, paramType: TypeRepr): FormElement = paramType match { case ntpe: NamedType if ntpe.name == "String" => FormElement.TextInput(paramName) case ntpe: NamedType if ntpe.name == "Int" => FormElement.NumberInput(paramName) case ntpe: NamedType if ntpe.name == "Boolean" => FormElement.CheckboxInput(paramName) case ntpe if isProductTpe(ntpe) => val classSymbol = ntpe.typeSymbol + val typeDefParams = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTypeParam) val fields = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isValDef).map(_.tree).collect { case v: ValDef => v } - FormElement.FieldSet(paramName, fields.map(v => functionFormElementFromTree(v.name, v.tpt.tpe))) + FormElement.FieldSet( + paramName, + fields.map { valdef => + functionFormElementFromTree( + valdef.name, + valdef.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs) + ) + } + ) case ntpe if isSumTpe(ntpe) => val classSymbol = ntpe.typeSymbol - val typeParamSyms = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isType) - val tpeArgs = tpeArguments(ntpe) - val childrenAppliedTpes = classSymbol.children.map(_.typeRef) + val childrenAppliedTpes = classSymbol.children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)) val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTree("value", t)) - val options = classSymbol.children.map(_.name).zip(childrenFormElements) + val options = classSymbol.children.map(_.prettyName).zip(childrenFormElements) FormElement.Dropdown(paramName, options) case _ => unsupportedFunctionParamType(paramType) @@ -113,37 +119,68 @@ private[guinep] object macros { functionParams(f).map { case ValDef(name, tpt, _) => functionFormElementFromTree(name, tpt.tpe) } .map(Expr(_)) ) + private def appliedChild(childSym: Symbol, parentSym: Symbol, parentArgs: List[TypeRepr]): TypeRepr = childSym.tree match { + case classDef @ ClassDef(_, _, parents, _, _) => + parents + .collect { + case tpt: TypeTree => tpt.tpe + } + .collectFirst { + case AppliedType(tpe, args) if tpe.typeSymbol == parentSym => args + case tpe if tpe.typeSymbol == parentSym => Nil + }.match + case None => + report.errorAndAbort(s"""PANIC: Could not find applied parent for ${childSym.name}, parents: ${parents.map(_.show).mkString(",")}""", classDef.pos) + case Some(parentExtendsArgs) => + val childDefArgs = classDef.symbol.primaryConstructor.paramSymss.flatten.filter(_.isTypeParam).map(_.typeRef) + val childArgTpes = childDefArgs.map { arg => + arg.substituteTypes(parentExtendsArgs.map(_.typeSymbol), parentArgs) + } + // TODO(kπ) might want to handle the case when there are unsubstituted type parameters left + val childTpe = childSym.typeRef.appliedTo(childArgTpes) + childTpe + case _ => + childSym.typeRef + } + private def constructArg(paramTpe: TypeRepr, param: Term): Term = { paramTpe match { case ntpe: NamedType if ntpe.name == "String" => param.select("asInstanceOf").appliedToType(ntpe) case ntpe: NamedType if ntpe.name == "Int" => param.select("asInstanceOf").appliedToType(ntpe) case ntpe: NamedType if ntpe.name == "Boolean" => param.select("asInstanceOf").appliedToType(ntpe) + case ntpe if isCaseObjectTpe(ntpe) && ntpe.typeSymbol.flags.is(Flags.Module) => + Ref(ntpe.typeSymbol.companionModule) case ntpe if isCaseObjectTpe(ntpe) => - Ident(ntpe.typeSymbol.termRef) + Ref(ntpe.typeSymbol) case ntpe if isProductTpe(ntpe) => - val classSymbol = ntpe.classSymbol.getOrElse(unsupportedFunctionParamType(paramTpe, Some(param.pos))) + val classSymbol = ntpe.typeSymbol + val typeDefParams = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTypeParam) val fields = classSymbol.primaryConstructor.paramSymss.flatten.filter(_.isValDef).map(_.tree) val paramValue = '{ ${param.asExpr}.asInstanceOf[Map[String, Any]] }.asTerm val args = fields.collect { case field: ValDef => - val fieldName = field.asInstanceOf[ValDef].name + val fieldName = field.name val fieldValue = paramValue.select("apply").appliedTo(Literal(StringConstant(fieldName))) - constructArg(field.tpt.tpe, fieldValue) + constructArg( + field.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs), + fieldValue + ) } - New(Inferred(ntpe)).select(classSymbol.primaryConstructor).appliedToArgs(args) + New(Inferred(ntpe.typeSymbol.typeRef)).select(classSymbol.primaryConstructor).appliedToTypes(ntpe.typeArgs).appliedToArgs(args) case ntpe if isSumTpe(ntpe) => - val classSymbol = ntpe.classSymbol.getOrElse(unsupportedFunctionParamType(paramTpe, Some(param.pos))) + val classSymbol = ntpe.typeSymbol val className = classSymbol.name val children = classSymbol.children + val childrenAppliedTpes = children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)) val paramMap = '{ ${param.asExpr}.asInstanceOf[Map[String, Any]] }.asTerm val paramName = paramMap.select("apply").appliedTo(Literal(StringConstant("name"))) val paramValue = paramMap.select("apply").appliedTo(Literal(StringConstant("value"))) - children.foldRight[Term]{ + children.zip(childrenAppliedTpes).foldRight[Term]{ '{ throw new RuntimeException(s"Class ${${paramName.asExpr}} is not a child of ${${Expr(className)}}") }.asTerm - } { (child, acc) => - val childName = Literal(StringConstant(child.name)) + } { case ((child, childAppliedTpe), acc) => + val childName = Literal(StringConstant(child.prettyName)) If( paramName.select("equals").appliedTo(childName), - constructArg(child.typeRef, paramValue), + constructArg(childAppliedTpe, paramValue), acc ) } diff --git a/testcases/src/main/scala/main.scala b/testcases/src/main/scala/main.scala index 774c968..b3ddf44 100644 --- a/testcases/src/main/scala/main.scala +++ b/testcases/src/main/scala/main.scala @@ -57,6 +57,15 @@ def roll20: Int = def roll6(): Int = scala.util.Random.nextInt(6) + 1 +sealed trait WeirdGADT[+A] +case class IntValue(value: Int) extends WeirdGADT[Int] +case class SomeValue[+A](value: A) extends WeirdGADT[A] +case class SomeOtherValue[+A, +B](value: A, value2: B) extends WeirdGADT[A] + +def printsWeirdGADT(g: WeirdGADT[String]): String = g match + case SomeValue(value) => s"SomeValue($value)" + case SomeOtherValue(value, value2) => s"SomeOtherValue($value, $value2)" + @main def run: Unit = guinep.web( @@ -66,10 +75,11 @@ def run: Unit = concat, giveALongText, addObj, - // greetMaybeName, + greetMaybeName, greetInLanguage, nameWithPossiblePrefix, nameWithPossiblePrefix1, roll20, - roll6() + roll6(), + // printsWeirdGADT ) From f41e91d35920b7db3831b44b40cdaf696726e6d3 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 20 Mar 2024 00:55:54 +0100 Subject: [PATCH 2/2] Strip annotations from types --- guinep/src/main/scala/macros.scala | 12 +++++++++--- testcases/src/main/scala/main.scala | 6 ++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/guinep/src/main/scala/macros.scala b/guinep/src/main/scala/macros.scala index c4df3c5..6127ce5 100644 --- a/guinep/src/main/scala/macros.scala +++ b/guinep/src/main/scala/macros.scala @@ -39,6 +39,12 @@ private[guinep] object macros { private def prettyName: String = s.name.stripSuffix("$") + extension (tpe: TypeRepr) + private def stripAnnots: TypeRepr = tpe match { + case AnnotatedType(tpe, _) => tpe.stripAnnots + case _ => tpe + } + private def functionNameImpl(f: Expr[Any]): Expr[String] = { val name = f.asTerm match { case Inlined(_, _, Lambda(_, body)) => @@ -100,13 +106,13 @@ private[guinep] object macros { fields.map { valdef => functionFormElementFromTree( valdef.name, - valdef.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs) + valdef.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs).stripAnnots ) } ) case ntpe if isSumTpe(ntpe) => val classSymbol = ntpe.typeSymbol - val childrenAppliedTpes = classSymbol.children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)) + val childrenAppliedTpes = classSymbol.children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)).map(_.stripAnnots) val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTree("value", t)) val options = classSymbol.children.map(_.prettyName).zip(childrenFormElements) FormElement.Dropdown(paramName, options) @@ -170,7 +176,7 @@ private[guinep] object macros { val classSymbol = ntpe.typeSymbol val className = classSymbol.name val children = classSymbol.children - val childrenAppliedTpes = children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)) + val childrenAppliedTpes = children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)).map(_.stripAnnots) val paramMap = '{ ${param.asExpr}.asInstanceOf[Map[String, Any]] }.asTerm val paramName = paramMap.select("apply").appliedTo(Literal(StringConstant("name"))) val paramValue = paramMap.select("apply").appliedTo(Literal(StringConstant("value"))) diff --git a/testcases/src/main/scala/main.scala b/testcases/src/main/scala/main.scala index b3ddf44..b884e37 100644 --- a/testcases/src/main/scala/main.scala +++ b/testcases/src/main/scala/main.scala @@ -62,10 +62,15 @@ case class IntValue(value: Int) extends WeirdGADT[Int] case class SomeValue[+A](value: A) extends WeirdGADT[A] case class SomeOtherValue[+A, +B](value: A, value2: B) extends WeirdGADT[A] +// This fails on unknown type params def printsWeirdGADT(g: WeirdGADT[String]): String = g match case SomeValue(value) => s"SomeValue($value)" case SomeOtherValue(value, value2) => s"SomeOtherValue($value, $value2)" +// This loops forever +def concatAll(elems: List[String]): String = + elems.mkString + @main def run: Unit = guinep.web( @@ -82,4 +87,5 @@ def run: Unit = roll20, roll6(), // printsWeirdGADT + // concatAll )