Skip to content

Commit

Permalink
Check user defined PolyFunction refinements (#18457)
Browse files Browse the repository at this point in the history
Fixes #18302
  • Loading branch information
nicolasstucki committed Oct 12, 2023
2 parents d0fb2b3 + 9966ced commit 12a373f
Show file tree
Hide file tree
Showing 19 changed files with 119 additions and 3 deletions.
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1165,10 +1165,12 @@ class Definitions {
Some(mt)
case _ => None

private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
def isValidPolyFunctionInfo(info: Type)(using Context): Boolean =
def isValidMethodType(info: Type) = info match
case info: MethodType =>
!info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list
!info.resType.isInstanceOf[MethodOrPoly] && // Has only one parameter list
!info.isVarArgsMethod &&
!info.isMethodWithByNameArgs // No by-name parameters
case _ => false
info match
case info: PolyType => isValidMethodType(info.resType)
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ object Types {
case _ => false
}

/** Is this the type of a method that has a by-name parameters? */
def isMethodWithByNameArgs(using Context): Boolean = stripPoly match {
case mt: MethodType => mt.paramInfos.exists(_.isInstanceOf[ExprType])
case _ => false
}

/** Is this the type of a method with a leading empty parameter list?
*/
def isNullaryMethod(using Context): Boolean = stripPoly match {
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,15 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case tree: ValDef =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
checkStableSelection(tree.rhs)
processValOrDefDef(super.transform(tree1))
case tree: DefDef =>
registerIfHasMacroAnnotations(tree)
checkErasedDef(tree)
Checking.checkPolyFunctionType(tree.tpt)
annotateContextResults(tree)
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
Expand Down Expand Up @@ -492,6 +494,9 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
)
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
superAcc.withInvalidCurrentClass(super.transform(tree))
case tree: RefinedTypeTree =>
Checking.checkPolyFunctionType(tree)
super.transform(tree)
case _: Quote | _: QuotePattern =>
ctx.compilationUnit.needsStaging = true
super.transform(tree)
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,31 @@ object Checking {
else Feature.checkExperimentalFeature("features", imp.srcPos)
case _ =>
end checkExperimentalImports

/** Checks that PolyFunction only have valid refinements.
*
* It only supports `apply` methods with one parameter list and optional type arguments.
*/
def checkPolyFunctionType(tree: Tree)(using Context): Unit = new TreeTraverser {
def traverse(tree: Tree)(using Context): Unit = tree match
case tree: RefinedTypeTree if tree.tpe.derivesFrom(defn.PolyFunctionClass) =>
if tree.refinements.isEmpty then
reportNoRefinements(tree.srcPos)
tree.refinements.foreach {
case refinement: DefDef if refinement.name != nme.apply =>
report.error("PolyFunction only supports apply method refinements", refinement.srcPos)
case refinement: DefDef if !defn.PolyFunctionOf.isValidPolyFunctionInfo(refinement.tpe.widen) =>
report.error("Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.", refinement.srcPos)
case _ =>
}
case _: RefTree if tree.symbol == defn.PolyFunctionClass =>
reportNoRefinements(tree.srcPos)
case _ =>
traverseChildren(tree)

def reportNoRefinements(pos: SrcPos) =
report.error("PolyFunction subtypes must refine the apply method", pos)
}.traverse(tree)
}

trait Checking {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ object Nullables:
def postProcessByNameArgs(fn: TermRef, app: Tree)(using Context): Tree =
fn.widen match
case mt: MethodType
if mt.paramInfos.exists(_.isInstanceOf[ExprType]) && !fn.symbol.is(Inline) =>
if mt.isMethodWithByNameArgs && !fn.symbol.is(Inline) =>
app match
case Apply(fn, args) =>
object dropNotNull extends TreeMap:
Expand Down
4 changes: 4 additions & 0 deletions tests/neg/i18302b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302b.scala:3:32 ---------------------------------------------------------------------------------
3 |def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.
5 changes: 5 additions & 0 deletions tests/neg/i18302b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def test = polyFun(1)(2)

def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error
new PolyFunction:
def apply(x: Int)(y: Int): Int = x + y
4 changes: 4 additions & 0 deletions tests/neg/i18302c.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302c.scala:4:32 ---------------------------------------------------------------------------------
4 |def polyFun: PolyFunction { def foo(x: Int): Int } = // error
| ^^^^^^^^^^^^^^^^^^^^
| PolyFunction only supports apply method refinements
5 changes: 5 additions & 0 deletions tests/neg/i18302c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import scala.reflect.Selectable.reflectiveSelectable

def test = polyFun.foo(1)
def polyFun: PolyFunction { def foo(x: Int): Int } = // error
new PolyFunction { def foo(x: Int): Int = x + 1 }
4 changes: 4 additions & 0 deletions tests/neg/i18302d.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i18302d.scala:1:32 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction { def apply: Int } = // error
| ^^^^^^^^^^^^^^
|Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.
2 changes: 2 additions & 0 deletions tests/neg/i18302d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def polyFun: PolyFunction { def apply: Int } = // error
new PolyFunction { def apply: Int = 1 }
8 changes: 8 additions & 0 deletions tests/neg/i18302e.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/i18302e.scala:1:13 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction { } = // error
| ^^^^^^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
-- Error: tests/neg/i18302e.scala:4:15 ---------------------------------------------------------------------------------
4 |def polyFun(f: PolyFunction { }) = () // error
| ^^^^^^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
4 changes: 4 additions & 0 deletions tests/neg/i18302e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def polyFun: PolyFunction { } = // error
new PolyFunction { }

def polyFun(f: PolyFunction { }) = () // error
12 changes: 12 additions & 0 deletions tests/neg/i18302f.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- Error: tests/neg/i18302f.scala:1:13 ---------------------------------------------------------------------------------
1 |def polyFun: PolyFunction = // error
| ^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
-- Error: tests/neg/i18302f.scala:4:16 ---------------------------------------------------------------------------------
4 |def polyFun2(a: PolyFunction) = () // error
| ^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
-- Error: tests/neg/i18302f.scala:6:14 ---------------------------------------------------------------------------------
6 |val polyFun3: PolyFunction = // error
| ^^^^^^^^^^^^
| PolyFunction subtypes must refine the apply method
7 changes: 7 additions & 0 deletions tests/neg/i18302f.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def polyFun: PolyFunction = // error
new PolyFunction { }

def polyFun2(a: PolyFunction) = () // error

val polyFun3: PolyFunction = // error
new PolyFunction { }
6 changes: 6 additions & 0 deletions tests/neg/i18302i.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def polyFun1: Option[PolyFunction] = ??? // error
def polyFun2: PolyFunction & Any = ??? // error
def polyFun3: Any & PolyFunction = ??? // error
def polyFun4: PolyFunction | Any = ??? // error
def polyFun5: Any | PolyFunction = ??? // error
def polyFun6(a: Any | PolyFunction) = ??? // error
5 changes: 5 additions & 0 deletions tests/neg/i18302j.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def polyFunByName: PolyFunction { def apply(thunk: => Int): Int } = // error
new PolyFunction { def apply(thunk: => Int): Int = 1 }

def polyFunVarArgs: PolyFunction { def apply(args: Int*): Int } = // error
new PolyFunction { def apply(thunk: Int*): Int = 1 }
8 changes: 8 additions & 0 deletions tests/neg/i8299.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package example

object Main {
def main(a: Array[String]): Unit = {
val p: PolyFunction = // error: PolyFunction subtypes must refine the apply method
[A] => (xs: List[A]) => xs.headOption
}
}
4 changes: 4 additions & 0 deletions tests/pos/i18302a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = polyFun(1)

def polyFun: PolyFunction { def apply(x: Int): Int } =
new PolyFunction { def apply(x: Int): Int = x + 1 }

0 comments on commit 12a373f

Please sign in to comment.