Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make polymorphic functions more efficient and expressive #17548

Merged
merged 8 commits into from
Jun 21, 2023
105 changes: 42 additions & 63 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,40 @@ object desugar {
name
}

/** Strip parens and empty blocks around the body of `tree`. */
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
def stripped(body: Tree): Tree = body match
case Parens(body1) =>
stripped(body1)
case Block(Nil, body1) =>
stripped(body1)
case _ => body
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]

/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val funFlags = fun match
case fun: FunctionWithMods =>
fun.mods.flags
case _ => EmptyFlags

// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val paramFlags = funFlags.toTermFlags & Given
val vparams = vparamTypes.zipWithIndex.map:
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
)).withSpan(tree.span)
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
val str = impl match
Expand Down Expand Up @@ -1429,17 +1463,20 @@ object desugar {
}

/** Make closure corresponding to function.
* params => body
* [tparams] => params => body
* ==>
* def $anonfun(params) = body
* def $anonfun[tparams](params) = body
* Closure($anonfun)
*/
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, isContextual: Boolean, span: Span)(using Context): Block =
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
val paramss: List[ParamClause] =
if tparams.isEmpty then vparams :: Nil
else tparams :: vparams :: Nil
Block(
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
.withSpan(span)
.withMods(synthetic | Artifact),
Closure(Nil, Ident(nme.ANON_FUN), if (isContextual) ContextualEmptyTree else EmptyTree))
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))

/** If `nparams` == 1, expand partial function
*
Expand Down Expand Up @@ -1728,62 +1765,6 @@ object desugar {
}
}

def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
case Parens(body1) =>
makePolyFunction(targs, body1, pt)
case Block(Nil, body1) =>
makePolyFunction(targs, body1, pt)
case Function(vargs, res) =>
assert(targs.nonEmpty)
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
val mods = body match {
case body: FunctionWithMods => body.mods
case _ => untpd.EmptyModifiers
}
val polyFunctionTpt = ref(defn.PolyFunctionType)
val applyTParams = targs.asInstanceOf[List[TypeDef]]
if (ctx.mode.is(Mode.Type)) {
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }

val applyVParams = vargs.zipWithIndex.map {
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
}
RefinedTypeTree(polyFunctionTpt, List(
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
))
}
else {
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.

def typeTree(tp: Type) = tp match
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
var bail = false
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
case tp: TypeRef => ref(tp)
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
val mapped = mapper(mt.resultType, topLevel = true)
if bail then TypeTree() else mapped
case _ => TypeTree()

val applyVParams = vargs.asInstanceOf[List[ValDef]]
.map(varg => varg.withAddedFlags(mods.flags | Param))
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
))
}
case _ =>
// may happen for erroneous input. An error will already have been reported.
assert(ctx.reporter.errorsReported)
EmptyTree
}

// begin desugar

// Special case for `Parens` desugaring: unlike all the desugarings below,
Expand All @@ -1796,8 +1777,6 @@ object desugar {
}

val desugared = tree match {
case PolyFunction(targs, body) =>
makePolyFunction(targs, body, pt) orElse tree
case SymbolLit(str) =>
Apply(
ref(defn.ScalaSymbolClass.companionModule.termRef),
Expand Down
5 changes: 1 addition & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,7 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
case Closure(_, meth, _) => true
case Block(Nil, expr) => isContextualClosure(expr)
case Block(DefDef(nme.ANON_FUN, params :: _, _, _) :: Nil, cl: Closure) =>
if params.isEmpty then
cl.tpt.eq(untpd.ContextualEmptyTree) || defn.isContextFunctionType(cl.tpt.typeOpt)
else
isUsingClause(params)
isUsingClause(params)
case _ => false
}

Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,6 @@ object Trees {

@sharable val EmptyTree: Thicket = genericEmptyTree
@sharable val EmptyValDef: ValDef = genericEmptyValDef
@sharable val ContextualEmptyTree: Thicket = new EmptyTree() // an empty tree marking a contextual closure

// ----- Auxiliary creation methods ------------------

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree

/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree
case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we simplify this so that a DependentTypeTree just takes a (tp: List[Symbol] => Type)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a List[Symbol] means that we need to spend extra time concatenating and partitioning the parameters, and we need to document this behavior, so I'm not sure it's really more simple, but I'm happy to do the change if you prefer it that way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, in the end it does not seem to be a simplification. So OK to keep as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I started my comment stream I thought it would come out simpler than it did 😄


@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] {
override def isEmpty: Boolean = true
Expand Down
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,12 @@ object NameOps {
*/
def isPlainFunction(using Context): Boolean = functionArity >= 0

/** Is a function name that contains `mustHave` as a substring */
private def isSpecificFunction(mustHave: String)(using Context): Boolean =
/** Is a function name that contains `mustHave` as a substring
* and has arity `minArity` or greater.
*/
private def isSpecificFunction(mustHave: String, minArity: Int = 0)(using Context): Boolean =
val suffixStart = functionSuffixStart
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= minArity

def isContextFunction(using Context): Boolean = isSpecificFunction("Context")
def isImpureFunction(using Context): Boolean = isSpecificFunction("Impure")
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,8 @@ object Types {
if alwaysDependent || mt.isResultDependent then
RefinedType(funType, nme.apply, mt)
else funType
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
RefinedType(defn.PolyFunctionType, nme.apply, poly)
}

/** The signature of this type. This is by default NotAMethod,
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,7 @@ object Parsers {
TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType)
else if imods.isOneOf(Given | Impure) || erasedArgs.contains(true) then
if imods.is(Given) && params.isEmpty then
imods &~= Given
syntaxError(em"context function types require at least one parameter", paramSpan)
FunctionWithMods(params, resultType, imods, erasedArgs.toList)
else if !ctx.settings.YkindProjector.isDefault then
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ class PlainPrinter(_ctx: Context) extends Printer {

protected def paramsText(lam: LambdaType): Text = {
val erasedParams = lam.erasedParams
def paramText(name: Name, tp: Type, erased: Boolean) =
keywordText("erased ").provided(erased) ~ toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true)
Text(lam.paramNames.lazyZip(lam.paramInfos).lazyZip(erasedParams).map(paramText), ", ")
def paramText(ref: ParamRef, erased: Boolean) =
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true)
Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ")
}

protected def ParamRefNameString(name: Name): String = nameString(name)
Expand Down Expand Up @@ -363,7 +363,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
case tp @ ConstantType(value) =>
toText(value)
case pref: TermParamRef =>
nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder)
ParamRefNameString(pref) ~ lambdaHash(pref.binder)
case tp: RecThis =>
val idx = openRecs.reverse.indexOf(tp.binder)
if (idx >= 0) selfRecName(idx + 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
~ " " ~ argText(args.last)
}

private def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match
protected def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match
case info: MethodType =>
val capturesRoot = refs == rootSetText
changePrec(GlobalPrec) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
case AmbiguousExtensionMethodID // errorNumber 180
case UnqualifiedCallToAnyRefMethodID // errorNumber: 181
case NotConstantID // errorNumber: 182
case ClosureCannotHaveInternalParameterDependenciesID // errorNumber: 183

def errorNumber = ordinal - 1

Expand Down
35 changes: 33 additions & 2 deletions compiler/src/dotty/tools/dotc/reporting/Message.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ object Message:
*/
private class Seen(disambiguate: Boolean):

/** The set of lambdas that were opened at some point during printing. */
private val openedLambdas = new collection.mutable.HashSet[LambdaType]

/** Register that `tp` was opened during printing. */
def openLambda(tp: LambdaType): Unit =
openedLambdas += tp

val seen = new collection.mutable.HashMap[SeenKey, List[Recorded]]:
override def default(key: SeenKey) = Nil

Expand Down Expand Up @@ -89,8 +96,22 @@ object Message:
val existing = seen(key)
lazy val dealiased = followAlias(entry)

// alts: The alternatives in `existing` that are equal, or follow (an alias of) `entry`
var alts = existing.dropWhile(alt => dealiased ne followAlias(alt))
/** All lambda parameters with the same name are given the same superscript as
* long as their corresponding binder has been printed.
* See tests/neg/lambda-rename.scala for test cases.
*/
def sameSuperscript(cur: Recorded, existing: Recorded) =
(cur eq existing) ||
(cur, existing).match
case (cur: ParamRef, existing: ParamRef) =>
(cur.paramName eq existing.paramName) &&
openedLambdas.contains(cur.binder) &&
openedLambdas.contains(existing.binder)
case _ =>
false

// The length of alts corresponds to the number of superscripts we need to print.
var alts = existing.dropWhile(alt => !sameSuperscript(dealiased, followAlias(alt)))
if alts.isEmpty then
alts = entry :: existing
seen(key) = alts
Expand Down Expand Up @@ -208,10 +229,20 @@ object Message:
case tp: SkolemType => seen.record(tp.repr.toString, isType = true, tp)
case _ => super.toTextRef(tp)

override def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text): Text =
info match
case info: LambdaType =>
seen.openLambda(info)
case _ =>
super.toTextMethodAsFunction(info, isPure, refs)

override def toText(tp: Type): Text =
if !tp.exists || tp.isErroneous then seen.nonSensical = true
tp match
case tp: TypeRef if useSourceModule(tp.symbol) => Str("object ") ~ super.toText(tp)
case tp: LambdaType =>
seen.openLambda(tp)
super.toText(tp)
case _ => super.toText(tp)

override def toText(sym: Symbol): Text =
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2920,3 +2920,10 @@ class MatchTypeScrutineeCannotBeHigherKinded(tp: Type)(using Context)
extends TypeMsg(MatchTypeScrutineeCannotBeHigherKindedID) :
def msg(using Context) = i"the scrutinee of a match type cannot be higher-kinded"
def explain(using Context) = ""

class ClosureCannotHaveInternalParameterDependencies(mt: Type)(using Context)
extends TypeMsg(ClosureCannotHaveInternalParameterDependenciesID):
def msg(using Context) =
i"""cannot turn method type $mt into closure
|because it has internal parameter dependencies"""
def explain(using Context) = ""
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ object Checking {
case tree: RefTree =>
checkRef(tree, tree.symbol)
foldOver(x, tree)
case tree: This =>
case tree: This if tree.tpe.classSymbol == refineCls =>
selfRef(tree)
case tree: TypeTree =>
val checkType = new TypeAccumulator[Unit] {
Expand Down
9 changes: 6 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1692,15 +1692,17 @@ class Namer { typer: Typer =>
def valOrDefDefSig(mdef: ValOrDefDef, sym: Symbol, paramss: List[List[Symbol]], paramFn: Type => Type)(using Context): Type = {

def inferredType = inferredResultType(mdef, sym, paramss, paramFn, WildcardType)
lazy val termParamss = paramss.collect { case TermSymbols(vparams) => vparams }

val tptProto = mdef.tpt match {
case _: untpd.DerivedTypeTree =>
WildcardType
case TypeTree() =>
checkMembersOK(inferredType, mdef.srcPos)
case DependentTypeTree(tpFun) =>
val tpe = tpFun(termParamss.head)
// A lambda has at most one type parameter list followed by exactly one term parameter list.
val tpe = (paramss: @unchecked) match
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we simplify DependentTypeTrees, these would become:

  case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams ++ vparams)
  case TermSymbols(vparams) :: Nil => tpFun(vparams)

case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
if (isFullyDefined(tpe, ForceDegree.none)) tpe
else typedAheadExpr(mdef.rhs, tpe).tpe
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
Expand All @@ -1724,7 +1726,8 @@ class Namer { typer: Typer =>
// So fixing levels at instantiation avoids the soundness problem but apparently leads
// to type inference problems since it comes too late.
if !Config.checkLevelsOnConstraints then
val hygienicType = TypeOps.avoid(rhsType, termParamss.flatten)
val termParams = paramss.collect { case TermSymbols(vparams) => vparams }.flatten
val hygienicType = TypeOps.avoid(rhsType, termParams)
if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe))
report.error(
em"""return type ${tpt.tpe} of lambda cannot be made hygienic
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ trait TypeAssigner {
*/
def qualifyingClass(tree: untpd.Tree, qual: Name, packageOK: Boolean)(using Context): Symbol = {
def qualifies(sym: Symbol) =
sym.isClass && (
sym.isClass &&
// `this` in a polymorphic function type never refers to the desugared refinement.
// In other refinements, `this` does refer to the refinement but is deprecated
// (see `Checking#checkRefinementNonCyclic`).
!(sym.isRefinementClass && sym.derivesFrom(defn.PolyFunctionClass)) && (
qual.isEmpty ||
sym.name == qual ||
sym.is(Module) && sym.name.stripModuleClassSuffix == qual)
Expand Down
Loading