From c59018549e10475ce4a6d5572c24881a5bad9e99 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 20 Jun 2023 13:03:42 +0200 Subject: [PATCH] Intrinsify `constValueTuple` and `summonAll` The new implementation instantiates the TupleN/TupleXXL classes directly. This avoids the expensive construction of tuples using `*:`. Fixes #15988 --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 19 +++++ .../dotty/tools/dotc/core/Definitions.scala | 6 +- .../dotty/tools/dotc/inlines/Inliner.scala | 4 +- .../dotty/tools/dotc/inlines/Inlines.scala | 69 ++++++++++++++----- library/src/scala/compiletime/package.scala | 19 ++--- tests/neg/17211.check | 20 +++--- tests/neg/i14177a.scala | 2 +- tests/run/i15988a.scala | 6 ++ tests/run/i15988b.scala | 21 ++++++ 9 files changed, 119 insertions(+), 47 deletions(-) create mode 100644 tests/run/i15988a.scala create mode 100644 tests/run/i15988b.scala diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 76e16cc00a90..71f8dafa9206 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1514,6 +1514,25 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } } + /** Creates the tuple containing the elemets */ + def tupleTree(elems: List[Tree])(using Context): Tree = { + val arity = elems.length + if arity == 0 then + ref(defn.EmptyTupleModule) + else if arity <= Definitions.MaxTupleArity then + // TupleN[elem1Tpe, ...](elem1, ...) + ref(defn.TupleType(arity).nn.typeSymbol.companionModule) + .select(nme.apply) + .appliedToTypes(elems.map(_.tpe.widenIfUnstable)) + .appliedToArgs(elems) + else + // TupleXXL.apply(elems*) // TODO add and use Tuple.apply(elems*) ? + ref(defn.TupleXXLModule) + .select(nme.apply) + .appliedToVarargs(elems.map(_.asInstance(defn.ObjectType)), TypeTree(defn.ObjectType)) + .asInstance(defn.tupleType(elems.map(elem => elem.tpe.widenIfUnstable))) + } + /** Creates the tuple type tree representation of the type trees in `ts` */ def tupleTypeTree(elems: List[Tree])(using Context): Tree = { val arity = elems.length diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ca92364c5b9b..249d17de344d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -243,8 +243,10 @@ class Definitions { @tu lazy val Compiletime_requireConst : Symbol = CompiletimePackageClass.requiredMethod("requireConst") @tu lazy val Compiletime_constValue : Symbol = CompiletimePackageClass.requiredMethod("constValue") @tu lazy val Compiletime_constValueOpt: Symbol = CompiletimePackageClass.requiredMethod("constValueOpt") + @tu lazy val Compiletime_constValueTuple: Symbol = CompiletimePackageClass.requiredMethod("constValueTuple") @tu lazy val Compiletime_summonFrom : Symbol = CompiletimePackageClass.requiredMethod("summonFrom") - @tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline") + @tu lazy val Compiletime_summonInline : Symbol = CompiletimePackageClass.requiredMethod("summonInline") + @tu lazy val Compiletime_summonAll : Symbol = CompiletimePackageClass.requiredMethod("summonAll") @tu lazy val CompiletimeTestingPackage: Symbol = requiredPackage("scala.compiletime.testing") @tu lazy val CompiletimeTesting_typeChecks: Symbol = CompiletimeTestingPackage.requiredMethod("typeChecks") @tu lazy val CompiletimeTesting_typeCheckErrors: Symbol = CompiletimeTestingPackage.requiredMethod("typeCheckErrors") @@ -932,6 +934,8 @@ class Definitions { @tu lazy val TupleTypeRef: TypeRef = requiredClassRef("scala.Tuple") def TupleClass(using Context): ClassSymbol = TupleTypeRef.symbol.asClass @tu lazy val Tuple_cons: Symbol = TupleClass.requiredMethod("*:") + @tu lazy val TupleModule: Symbol = requiredModule("scala.Tuple") + @tu lazy val EmptyTupleClass: Symbol = requiredClass("scala.EmptyTuple") @tu lazy val EmptyTupleModule: Symbol = requiredModule("scala.EmptyTuple") @tu lazy val NonEmptyTupleTypeRef: TypeRef = requiredClassRef("scala.NonEmptyTuple") def NonEmptyTupleClass(using Context): ClassSymbol = NonEmptyTupleTypeRef.symbol.asClass diff --git a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala index 73fa2a2871a2..79497645bcf7 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala @@ -497,8 +497,8 @@ class Inliner(val call: tpd.Tree)(using Context): // assertAllPositioned(tree) // debug tree.changeOwner(originalOwner, ctx.owner) - def tryConstValue: Tree = - TypeComparer.constValue(callTypeArgs.head.tpe) match { + def tryConstValue(tpe: Type): Tree = + TypeComparer.constValue(tpe) match { case Some(c) => Literal(c).withSpan(call.span) case _ => EmptyTree } diff --git a/compiler/src/dotty/tools/dotc/inlines/Inlines.scala b/compiler/src/dotty/tools/dotc/inlines/Inlines.scala index bcc10ffa6db8..d7bfa7d1823f 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inlines.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inlines.scala @@ -408,36 +408,67 @@ object Inlines: return Intrinsics.codeOf(arg, call.srcPos) case _ => - // Special handling of `constValue[T]`, `constValueOpt[T], and summonInline[T]` + // Special handling of `constValue[T]`, `constValueOpt[T]`, `constValueTuple[T]`, and `summonInline[T]` if callTypeArgs.length == 1 then - if (inlinedMethod == defn.Compiletime_constValue) { - val constVal = tryConstValue + + def constValueOrError(tpe: Type): Tree = + val constVal = tryConstValue(tpe) if constVal.isEmpty then - val msg = NotConstant("cannot take constValue", callTypeArgs.head.tpe) - return ref(defn.Predef_undefined).withSpan(call.span).withType(ErrorType(msg)) + val msg = NotConstant("cannot take constValue", tpe) + ref(defn.Predef_undefined).withSpan(callTypeArgs.head.span).withType(ErrorType(msg)) else - return constVal + constVal + + def searchImplicitOrError(tpe: Type): Tree = + val evTyper = new Typer(ctx.nestingLevel + 1) + val evCtx = ctx.fresh.setTyper(evTyper) + inContext(evCtx) { + val evidence = evTyper.inferImplicitArg(tpe, callTypeArgs.head.span) + evidence.tpe match + case fail: Implicits.SearchFailureType => + errorTree(call, evTyper.missingArgMsg(evidence, tpe, "")) + case _ => + evidence + } + + def unrollTupleTypes(tpe: Type): Option[List[Type]] = tpe.dealias match + case AppliedType(tycon, args) if defn.isTupleClass(tycon.typeSymbol) => + Some(args) + case AppliedType(tycon, head :: tail :: Nil) if tycon.isRef(defn.PairClass) => + unrollTupleTypes(tail).map(head :: _) + case tpe: TermRef if tpe.symbol == defn.EmptyTupleModule => + Some(Nil) + case _ => + None + + if (inlinedMethod == defn.Compiletime_constValue) { + return constValueOrError(callTypeArgs.head.tpe) } else if (inlinedMethod == defn.Compiletime_constValueOpt) { - val constVal = tryConstValue + val constVal = tryConstValue(callTypeArgs.head.tpe) return ( if (constVal.isEmpty) ref(defn.NoneModule.termRef) else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil) ) } + else if (inlinedMethod == defn.Compiletime_constValueTuple) { + unrollTupleTypes(callTypeArgs.head.tpe) match + case Some(types) => + val constants = types.map(constValueOrError) + return Typed(tpd.tupleTree(constants), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span) + case _ => + return errorTree(call, em"Tuple element types must be known at compile time") + } else if (inlinedMethod == defn.Compiletime_summonInline) { - def searchImplicit(tpt: Tree) = - val evTyper = new Typer(ctx.nestingLevel + 1) - val evCtx = ctx.fresh.setTyper(evTyper) - inContext(evCtx) { - val evidence = evTyper.inferImplicitArg(tpt.tpe, tpt.span) - evidence.tpe match - case fail: Implicits.SearchFailureType => - errorTree(call, evTyper.missingArgMsg(evidence, tpt.tpe, "")) - case _ => - evidence - } - return searchImplicit(callTypeArgs.head) + return searchImplicitOrError(callTypeArgs.head.tpe) + } + else if (inlinedMethod == defn.Compiletime_summonAll) { + unrollTupleTypes(callTypeArgs.head.tpe) match + case Some(types) => + val implicits = types.map(searchImplicitOrError) + return Typed(tpd.tupleTree(implicits), TypeTree(callTypeArgs.head.tpe)).withSpan(call.span) + case _ => + return errorTree(call, em"Tuple element types must be known at compile time") } end if diff --git a/library/src/scala/compiletime/package.scala b/library/src/scala/compiletime/package.scala index ff00b83bcb79..3eca997554a0 100644 --- a/library/src/scala/compiletime/package.scala +++ b/library/src/scala/compiletime/package.scala @@ -117,13 +117,9 @@ transparent inline def constValue[T]: T = * `(constValue[X1], ..., constValue[Xn])`. */ inline def constValueTuple[T <: Tuple]: T = - val res = - inline erasedValue[T] match - case _: EmptyTuple => EmptyTuple - case _: (t *: ts) => constValue[t] *: constValueTuple[ts] - end match - res.asInstanceOf[T] -end constValueTuple + // implemented in dotty.tools.dotc.typer.Inliner + error("Compiler bug: `constValueTuple` was not evaluated by the compiler") + /** Summons first given matching one of the listed cases. E.g. in * @@ -168,13 +164,8 @@ transparent inline def summonInline[T]: T = * @return the given values typed as elements of the tuple */ inline def summonAll[T <: Tuple]: T = - val res = - inline erasedValue[T] match - case _: EmptyTuple => EmptyTuple - case _: (t *: ts) => summonInline[t] *: summonAll[ts] - end match - res.asInstanceOf[T] -end summonAll + // implemented in dotty.tools.dotc.typer.Inliner + error("Compiler bug: `summonAll` was not evaluated by the compiler") /** Assertion that an argument is by-name. Used for nullability checking. */ def byName[T](x: => T): T = x diff --git a/tests/neg/17211.check b/tests/neg/17211.check index 3c2f10a61957..be7086e3b3eb 100644 --- a/tests/neg/17211.check +++ b/tests/neg/17211.check @@ -1,14 +1,14 @@ --- [E182] Type Error: tests/neg/17211.scala:14:12 ---------------------------------------------------------------------- +-- [E182] Type Error: tests/neg/17211.scala:14:13 ---------------------------------------------------------------------- 14 | constValue[IsInt[Foo.Foo]] // error - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ - | IsInt[Foo.Foo] is not a constant type; cannot take constValue + | ^^^^^^^^^^^^^^ + | IsInt[Foo.Foo] is not a constant type; cannot take constValue | - | Note: a match type could not be fully reduced: + | Note: a match type could not be fully reduced: | - | trying to reduce IsInt[Foo.Foo] - | failed since selector Foo.Foo - | does not match case Int => (true : Boolean) - | and cannot be shown to be disjoint from it either. - | Therefore, reduction cannot advance to the remaining case + | trying to reduce IsInt[Foo.Foo] + | failed since selector Foo.Foo + | does not match case Int => (true : Boolean) + | and cannot be shown to be disjoint from it either. + | Therefore, reduction cannot advance to the remaining case | - | case _ => (false : Boolean) + | case _ => (false : Boolean) diff --git a/tests/neg/i14177a.scala b/tests/neg/i14177a.scala index 3031271c369b..237eaacb3b66 100644 --- a/tests/neg/i14177a.scala +++ b/tests/neg/i14177a.scala @@ -3,4 +3,4 @@ import scala.compiletime.* trait C[A] inline given [Tup <: Tuple]: C[Tup] with - val cs = summonAll[Tuple.Map[Tup, C]] // error cannot reduce inline match with + val cs = summonAll[Tuple.Map[Tup, C]] // error: Tuple element types must be known at compile time diff --git a/tests/run/i15988a.scala b/tests/run/i15988a.scala new file mode 100644 index 000000000000..dba5008fd950 --- /dev/null +++ b/tests/run/i15988a.scala @@ -0,0 +1,6 @@ +import scala.compiletime.constValueTuple + +@main def Test: Unit = + assert(constValueTuple[EmptyTuple] == EmptyTuple) + assert(constValueTuple[("foo", 5, 3.14, "bar", false)] == ("foo", 5, 3.14, "bar", false)) + assert(constValueTuple[(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)] == (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)) diff --git a/tests/run/i15988b.scala b/tests/run/i15988b.scala new file mode 100644 index 000000000000..4b7764d94a18 --- /dev/null +++ b/tests/run/i15988b.scala @@ -0,0 +1,21 @@ +import scala.compiletime.summonAll + +@main def Test: Unit = + assert(summonAll[EmptyTuple] == EmptyTuple) + assert(summonAll[(5, 5, 5)] == (5, 5, 5)) + assert( + summonAll[( + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + )] == ( + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, + )) + +given 5 = 5