Skip to content

Commit

Permalink
Backport "Check user defined PolyFunction refinements " to LTS (#20647)
Browse files Browse the repository at this point in the history
Backports #18457 to the LTS branch.

PR submitted by the release tooling.
[skip ci]
  • Loading branch information
WojciechMazur authored Jun 20, 2024
2 parents 26336ef + 917072b commit 739f55a
Show file tree
Hide file tree
Showing 19 changed files with 137 additions and 12 deletions.
33 changes: 22 additions & 11 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,18 @@ class Definitions {
case _ => None
}

object ErasedFunctionOf {
/** Matches a refined `ErasedFunction` type and extracts the apply info.
*
* Pattern: `ErasedFunction { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodType)
if parent.derivesFrom(defn.ErasedFunctionClass) =>
Some(mt)
case _ => None
}

object PolyFunctionOf {
/** Matches a refined `PolyFunction` type and extracts the apply info.
*
Expand All @@ -1157,18 +1169,17 @@ class Definitions {
if tpe.refinedName == nme.apply && tpe.parent.derivesFrom(defn.PolyFunctionClass) =>
Some(mt)
case _ => None
}

object ErasedFunctionOf {
/** Matches a refined `ErasedFunction` type and extracts the apply info.
*
* Pattern: `ErasedFunction { def apply: $mt }`
*/
def unapply(ft: Type)(using Context): Option[MethodType] = ft.dealias match
case RefinedType(parent, nme.apply, mt: MethodType)
if parent.derivesFrom(defn.ErasedFunctionClass) =>
Some(mt)
case _ => None
def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
def isValidMethodType(info: Type) = info match
case info: MethodType =>
!info.resType.isInstanceOf[MethodOrPoly] && // Has only one parameter list
!info.isVarArgsMethod &&
!info.isMethodWithByNameArgs // No by-name parameters
case _ => false
info match
case info: PolyType => isValidMethodType(info.resType)
case _ => isValidMethodType(info)
}

object PartialFunctionOf {
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,12 @@ object Types {
case _ => false
}

/** Is this the type of a method that has a by-name parameters? */
def isMethodWithByNameArgs(using Context): Boolean = stripPoly match {
case mt: MethodType => mt.paramInfos.exists(_.isInstanceOf[ExprType])
case _ => false
}

/** Is this the type of a method with a leading empty parameter list?
*/
def isNullaryMethod(using Context): Boolean = stripPoly match {
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -375,13 +375,15 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
case tree: ValDef =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
case tree: DefDef =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
Expand Down Expand Up @@ -483,6 +485,9 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
)
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
superAcc.withInvalidCurrentClass(super.transform(tree))
case tree: RefinedTypeTree =>
Checking.checkPolyFunctionType(tree)
super.transform(tree)
case _: Quote =>
ctx.compilationUnit.needsStaging = true
super.transform(tree)
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,31 @@ object Checking {
else Feature.checkExperimentalFeature("features", imp.srcPos)
case _ =>
end checkExperimentalImports

/** Checks that PolyFunction only have valid refinements.
*
* It only supports `apply` methods with one parameter list and optional type arguments.
*/
def checkPolyFunctionType(tree: Tree)(using Context): Unit = new TreeTraverser {
def traverse(tree: Tree)(using Context): Unit = tree match
case tree: RefinedTypeTree if tree.tpe.derivesFrom(defn.PolyFunctionClass) =>
if tree.refinements.isEmpty then
reportNoRefinements(tree.srcPos)
tree.refinements.foreach {
case refinement: DefDef if refinement.name != nme.apply =>
report.error("PolyFunction only supports apply method refinements", refinement.srcPos)
case refinement: DefDef if !defn.PolyFunctionOf.isValidPolyFunctionInfo(refinement.tpe.widen) =>
report.error("Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.", refinement.srcPos)
case _ =>
}
case _: RefTree if tree.symbol == defn.PolyFunctionClass =>
reportNoRefinements(tree.srcPos)
case _ =>
traverseChildren(tree)

def reportNoRefinements(pos: SrcPos) =
report.error("PolyFunction subtypes must refine the apply method", pos)
}.traverse(tree)
}

trait Checking {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ object Nullables:
def postProcessByNameArgs(fn: TermRef, app: Tree)(using Context): Tree =
fn.widen match
case mt: MethodType
if mt.paramInfos.exists(_.isInstanceOf[ExprType]) && !fn.symbol.is(Inline) =>
if mt.isMethodWithByNameArgs && !fn.symbol.is(Inline) =>
app match
case Apply(fn, args) =>
object dropNotNull extends TreeMap:
Expand Down
4 changes: 4 additions & 0 deletions tests/neg/i18302b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302b.scala:3:32 ---------------------------------------------------------------------------------
3 |def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.
5 changes: 5 additions & 0 deletions tests/neg/i18302b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def test = polyFun(1)(2)

def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
new PolyFunction:
def apply(x: Int)(y: Int): Int = x + y
4 changes: 4 additions & 0 deletions tests/neg/i18302c.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302c.scala:4:32 ---------------------------------------------------------------------------------
4 |def polyFun: PolyFunction { def foo(x: Int): Int } = // error
| ^^^^^^^^^^^^^^^^^^^^
| PolyFunction only supports apply method refinements
5 changes: 5 additions & 0 deletions tests/neg/i18302c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import scala.reflect.Selectable.reflectiveSelectable

def test = polyFun.foo(1)
def polyFun: PolyFunction { def foo(x: Int): Int } = // error
new PolyFunction { def foo(x: Int): Int = x + 1 }
4 changes: 4 additions & 0 deletions tests/neg/i18302d.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302d.scala:1:32 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction { def apply: Int } = // error
| ^^^^^^^^^^^^^^
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.
2 changes: 2 additions & 0 deletions tests/neg/i18302d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def polyFun: PolyFunction { def apply: Int } = // error
new PolyFunction { def apply: Int = 1 }
8 changes: 8 additions & 0 deletions tests/neg/i18302e.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/i18302e.scala:1:13 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction { } = // error
| ^^^^^^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
-- Error: tests/neg/i18302e.scala:4:15 ---------------------------------------------------------------------------------
4 |def polyFun(f: PolyFunction { }) = () // error
| ^^^^^^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
4 changes: 4 additions & 0 deletions tests/neg/i18302e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def polyFun: PolyFunction { } = // error
new PolyFunction { }

def polyFun(f: PolyFunction { }) = () // error
12 changes: 12 additions & 0 deletions tests/neg/i18302f.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Error: tests/neg/i18302f.scala:1:13 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction = // error
| ^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
-- Error: tests/neg/i18302f.scala:4:16 ---------------------------------------------------------------------------------
4 |def polyFun2(a: PolyFunction) = () // error
| ^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
-- Error: tests/neg/i18302f.scala:6:14 ---------------------------------------------------------------------------------
6 |val polyFun3: PolyFunction = // error
| ^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
7 changes: 7 additions & 0 deletions tests/neg/i18302f.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def polyFun: PolyFunction = // error
new PolyFunction { }

def polyFun2(a: PolyFunction) = () // error

val polyFun3: PolyFunction = // error
new PolyFunction { }
6 changes: 6 additions & 0 deletions tests/neg/i18302i.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def polyFun1: Option[PolyFunction] = ??? // error
def polyFun2: PolyFunction & Any = ??? // error
def polyFun3: Any & PolyFunction = ??? // error
def polyFun4: PolyFunction | Any = ??? // error
def polyFun5: Any | PolyFunction = ??? // error
def polyFun6(a: Any | PolyFunction) = ??? // error
5 changes: 5 additions & 0 deletions tests/neg/i18302j.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def polyFunByName: PolyFunction { def apply(thunk: => Int): Int } = // error
new PolyFunction { def apply(thunk: => Int): Int = 1 }

def polyFunVarArgs: PolyFunction { def apply(args: Int*): Int } = // error
new PolyFunction { def apply(thunk: Int*): Int = 1 }
8 changes: 8 additions & 0 deletions tests/neg/i8299.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package example

object Main {
def main(a: Array[String]): Unit = {
val p: PolyFunction = // error: PolyFunction subtypes must refine the apply method
[A] => (xs: List[A]) => xs.headOption
}
}
4 changes: 4 additions & 0 deletions tests/pos/i18302a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = polyFun(1)

def polyFun: PolyFunction { def apply(x: Int): Int } =
new PolyFunction { def apply(x: Int): Int = x + 1 }

0 comments on commit 739f55a

Please sign in to comment.