From 82ff24f8fdaa6b0bd743dae202bd8272ee041aad Mon Sep 17 00:00:00 2001 From: Wojciech Mazur Date: Thu, 20 Jun 2024 15:17:31 +0200 Subject: [PATCH 1/2] Check user defined PolyFunction refinements `PolyFunction` must be refined with an `apply` method that has a single parameter list with no by-name nor varargs parameters. It may optionally have type parameters. Some of these restrictions could be lifted later, but for now these features are not properly handled by the compiler. Fixes #8299 Fixes #18302 [Cherry-picked e5ca0c42f85bff00f9fdd6e79f8db293d84966dd][modified] --- .../dotty/tools/dotc/core/Definitions.scala | 11 ++++++++ .../tools/dotc/transform/PostTyper.scala | 5 ++++ .../src/dotty/tools/dotc/typer/Checking.scala | 25 +++++++++++++++++++ tests/neg/i18302b.check | 4 +++ tests/neg/i18302b.scala | 5 ++++ tests/neg/i18302c.check | 4 +++ tests/neg/i18302c.scala | 5 ++++ tests/neg/i18302d.check | 4 +++ tests/neg/i18302d.scala | 2 ++ tests/neg/i18302e.check | 8 ++++++ tests/neg/i18302e.scala | 4 +++ tests/neg/i18302f.check | 12 +++++++++ tests/neg/i18302f.scala | 7 ++++++ tests/neg/i18302i.scala | 6 +++++ tests/neg/i18302j.scala | 5 ++++ tests/neg/i8299.scala | 8 ++++++ tests/pos/i18302a.scala | 4 +++ 17 files changed, 119 insertions(+) create mode 100644 tests/neg/i18302b.check create mode 100644 tests/neg/i18302b.scala create mode 100644 tests/neg/i18302c.check create mode 100644 tests/neg/i18302c.scala create mode 100644 tests/neg/i18302d.check create mode 100644 tests/neg/i18302d.scala create mode 100644 tests/neg/i18302e.check create mode 100644 tests/neg/i18302e.scala create mode 100644 tests/neg/i18302f.check create mode 100644 tests/neg/i18302f.scala create mode 100644 tests/neg/i18302i.scala create mode 100644 tests/neg/i18302j.scala create mode 100644 tests/neg/i8299.scala create mode 100644 tests/pos/i18302a.scala diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index c5a798e2dcd7..6ee40c9f9706 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1157,6 +1157,17 @@ class Definitions { if tpe.refinedName == nme.apply && tpe.parent.derivesFrom(defn.PolyFunctionClass) => Some(mt) case _ => None + + def isValidPolyFunctionInfo(info: Type)(using Context): Boolean = + def isValidMethodType(info: Type) = info match + case info: MethodType => + !info.resType.isInstanceOf[MethodOrPoly] && // Has only one parameter list + !info.isVarArgsMethod && + !info.paramInfos.exists(_.isInstanceOf[ExprType]) // No by-name parameters + case _ => false + info match + case info: PolyType => isValidMethodType(info.resType) + case _ => isValidMethodType(info) } object ErasedFunctionOf { diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 822d679b4954..39f8ae6e757b 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -375,6 +375,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase case tree: ValDef => registerIfHasMacroAnnotations(tree) checkErasedDef(tree) + Checking.checkPolyFunctionType(tree.tpt) val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) if tree1.removeAttachment(desugar.UntupledParam).isDefined then checkStableSelection(tree.rhs) @@ -382,6 +383,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase case tree: DefDef => registerIfHasMacroAnnotations(tree) checkErasedDef(tree) + Checking.checkPolyFunctionType(tree.tpt) annotateContextResults(tree) val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef])) @@ -483,6 +485,9 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase ) case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) => superAcc.withInvalidCurrentClass(super.transform(tree)) + case tree: RefinedTypeTree => + Checking.checkPolyFunctionType(tree) + super.transform(tree) case _: Quote => ctx.compilationUnit.needsStaging = true super.transform(tree) diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index 93acc01d28ad..3948fecb2a0e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -804,6 +804,31 @@ object Checking { else Feature.checkExperimentalFeature("features", imp.srcPos) case _ => end checkExperimentalImports + + /** Checks that PolyFunction only have valid refinements. + * + * It only supports `apply` methods with one parameter list and optional type arguments. + */ + def checkPolyFunctionType(tree: Tree)(using Context): Unit = new TreeTraverser { + def traverse(tree: Tree)(using Context): Unit = tree match + case tree: RefinedTypeTree if tree.tpe.derivesFrom(defn.PolyFunctionClass) => + if tree.refinements.isEmpty then + reportNoRefinements(tree.srcPos) + tree.refinements.foreach { + case refinement: DefDef if refinement.name != nme.apply => + report.error("PolyFunction only supports apply method refinements", refinement.srcPos) + case refinement: DefDef if !defn.PolyFunctionOf.isValidPolyFunctionInfo(refinement.tpe.widen) => + report.error("Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed.", refinement.srcPos) + case _ => + } + case _: RefTree if tree.symbol == defn.PolyFunctionClass => + reportNoRefinements(tree.srcPos) + case _ => + traverseChildren(tree) + + def reportNoRefinements(pos: SrcPos) = + report.error("PolyFunction subtypes must refine the apply method", pos) + }.traverse(tree) } trait Checking { diff --git a/tests/neg/i18302b.check b/tests/neg/i18302b.check new file mode 100644 index 000000000000..0dc3ba6c054a --- /dev/null +++ b/tests/neg/i18302b.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/i18302b.scala:3:32 --------------------------------------------------------------------------------- +3 |def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + |Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed. diff --git a/tests/neg/i18302b.scala b/tests/neg/i18302b.scala new file mode 100644 index 000000000000..71c7992c178b --- /dev/null +++ b/tests/neg/i18302b.scala @@ -0,0 +1,5 @@ +def test = polyFun(1)(2) + +def polyFun: PolyFunction { def apply(x: Int)(y: Int): Int } = // error + new PolyFunction: + def apply(x: Int)(y: Int): Int = x + y diff --git a/tests/neg/i18302c.check b/tests/neg/i18302c.check new file mode 100644 index 000000000000..4610145a30b2 --- /dev/null +++ b/tests/neg/i18302c.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/i18302c.scala:4:32 --------------------------------------------------------------------------------- +4 |def polyFun: PolyFunction { def foo(x: Int): Int } = // error + | ^^^^^^^^^^^^^^^^^^^^ + | PolyFunction only supports apply method refinements diff --git a/tests/neg/i18302c.scala b/tests/neg/i18302c.scala new file mode 100644 index 000000000000..a5d182d7ad0c --- /dev/null +++ b/tests/neg/i18302c.scala @@ -0,0 +1,5 @@ +import scala.reflect.Selectable.reflectiveSelectable + +def test = polyFun.foo(1) +def polyFun: PolyFunction { def foo(x: Int): Int } = // error + new PolyFunction { def foo(x: Int): Int = x + 1 } diff --git a/tests/neg/i18302d.check b/tests/neg/i18302d.check new file mode 100644 index 000000000000..976db59763c1 --- /dev/null +++ b/tests/neg/i18302d.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/i18302d.scala:1:32 --------------------------------------------------------------------------------- +1 |def polyFun: PolyFunction { def apply: Int } = // error + | ^^^^^^^^^^^^^^ + |Implementation restriction: PolyFunction apply must have exactly one parameter list and optionally type arguments. No by-name nor varags are allowed. diff --git a/tests/neg/i18302d.scala b/tests/neg/i18302d.scala new file mode 100644 index 000000000000..a7f9a5bec286 --- /dev/null +++ b/tests/neg/i18302d.scala @@ -0,0 +1,2 @@ +def polyFun: PolyFunction { def apply: Int } = // error + new PolyFunction { def apply: Int = 1 } diff --git a/tests/neg/i18302e.check b/tests/neg/i18302e.check new file mode 100644 index 000000000000..aae101875845 --- /dev/null +++ b/tests/neg/i18302e.check @@ -0,0 +1,8 @@ +-- Error: tests/neg/i18302e.scala:1:13 --------------------------------------------------------------------------------- +1 |def polyFun: PolyFunction { } = // error + | ^^^^^^^^^^^^^^^^^ + | PolyFunction subtypes must refine the apply method +-- Error: tests/neg/i18302e.scala:4:15 --------------------------------------------------------------------------------- +4 |def polyFun(f: PolyFunction { }) = () // error + | ^^^^^^^^^^^^^^^^^ + | PolyFunction subtypes must refine the apply method diff --git a/tests/neg/i18302e.scala b/tests/neg/i18302e.scala new file mode 100644 index 000000000000..1ffab2586048 --- /dev/null +++ b/tests/neg/i18302e.scala @@ -0,0 +1,4 @@ +def polyFun: PolyFunction { } = // error + new PolyFunction { } + +def polyFun(f: PolyFunction { }) = () // error diff --git a/tests/neg/i18302f.check b/tests/neg/i18302f.check new file mode 100644 index 000000000000..df0d76c2f157 --- /dev/null +++ b/tests/neg/i18302f.check @@ -0,0 +1,12 @@ +-- Error: tests/neg/i18302f.scala:1:13 --------------------------------------------------------------------------------- +1 |def polyFun: PolyFunction = // error + | ^^^^^^^^^^^^ + | PolyFunction subtypes must refine the apply method +-- Error: tests/neg/i18302f.scala:4:16 --------------------------------------------------------------------------------- +4 |def polyFun2(a: PolyFunction) = () // error + | ^^^^^^^^^^^^ + | PolyFunction subtypes must refine the apply method +-- Error: tests/neg/i18302f.scala:6:14 --------------------------------------------------------------------------------- +6 |val polyFun3: PolyFunction = // error + | ^^^^^^^^^^^^ + | PolyFunction subtypes must refine the apply method diff --git a/tests/neg/i18302f.scala b/tests/neg/i18302f.scala new file mode 100644 index 000000000000..2f86f0e1eb62 --- /dev/null +++ b/tests/neg/i18302f.scala @@ -0,0 +1,7 @@ +def polyFun: PolyFunction = // error + new PolyFunction { } + +def polyFun2(a: PolyFunction) = () // error + +val polyFun3: PolyFunction = // error + new PolyFunction { } diff --git a/tests/neg/i18302i.scala b/tests/neg/i18302i.scala new file mode 100644 index 000000000000..e64330879e55 --- /dev/null +++ b/tests/neg/i18302i.scala @@ -0,0 +1,6 @@ +def polyFun1: Option[PolyFunction] = ??? // error +def polyFun2: PolyFunction & Any = ??? // error +def polyFun3: Any & PolyFunction = ??? // error +def polyFun4: PolyFunction | Any = ??? // error +def polyFun5: Any | PolyFunction = ??? // error +def polyFun6(a: Any | PolyFunction) = ??? // error diff --git a/tests/neg/i18302j.scala b/tests/neg/i18302j.scala new file mode 100644 index 000000000000..8c63aa573c9b --- /dev/null +++ b/tests/neg/i18302j.scala @@ -0,0 +1,5 @@ +def polyFunByName: PolyFunction { def apply(thunk: => Int): Int } = // error + new PolyFunction { def apply(thunk: => Int): Int = 1 } + +def polyFunVarArgs: PolyFunction { def apply(args: Int*): Int } = // error + new PolyFunction { def apply(thunk: Int*): Int = 1 } diff --git a/tests/neg/i8299.scala b/tests/neg/i8299.scala new file mode 100644 index 000000000000..e3e41515ff29 --- /dev/null +++ b/tests/neg/i8299.scala @@ -0,0 +1,8 @@ +package example + +object Main { + def main(a: Array[String]): Unit = { + val p: PolyFunction = // error: PolyFunction subtypes must refine the apply method + [A] => (xs: List[A]) => xs.headOption + } +} diff --git a/tests/pos/i18302a.scala b/tests/pos/i18302a.scala new file mode 100644 index 000000000000..c087b63543f4 --- /dev/null +++ b/tests/pos/i18302a.scala @@ -0,0 +1,4 @@ +def test = polyFun(1) + +def polyFun: PolyFunction { def apply(x: Int): Int } = + new PolyFunction { def apply(x: Int): Int = x + 1 } From 54f5421accffbfcf4f8ff23ca247da5833bd4512 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Fri, 25 Aug 2023 16:50:02 +0200 Subject: [PATCH 2/2] Add `isMethodWithByNameArgs` [Cherry-picked 9966ced284fec8a7e601c3de1edd328a10a451fc] --- compiler/src/dotty/tools/dotc/core/Definitions.scala | 2 +- compiler/src/dotty/tools/dotc/core/Types.scala | 6 ++++++ compiler/src/dotty/tools/dotc/typer/Nullables.scala | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6ee40c9f9706..e9b8d2e7affa 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1163,7 +1163,7 @@ class Definitions { case info: MethodType => !info.resType.isInstanceOf[MethodOrPoly] && // Has only one parameter list !info.isVarArgsMethod && - !info.paramInfos.exists(_.isInstanceOf[ExprType]) // No by-name parameters + !info.isMethodWithByNameArgs // No by-name parameters case _ => false info match case info: PolyType => isValidMethodType(info.resType) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 8b2749fc1254..545bb138969a 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -419,6 +419,12 @@ object Types { case _ => false } + /** Is this the type of a method that has a by-name parameters? */ + def isMethodWithByNameArgs(using Context): Boolean = stripPoly match { + case mt: MethodType => mt.paramInfos.exists(_.isInstanceOf[ExprType]) + case _ => false + } + /** Is this the type of a method with a leading empty parameter list? */ def isNullaryMethod(using Context): Boolean = stripPoly match { diff --git a/compiler/src/dotty/tools/dotc/typer/Nullables.scala b/compiler/src/dotty/tools/dotc/typer/Nullables.scala index 9104418d406f..722dc2186693 100644 --- a/compiler/src/dotty/tools/dotc/typer/Nullables.scala +++ b/compiler/src/dotty/tools/dotc/typer/Nullables.scala @@ -507,7 +507,7 @@ object Nullables: def postProcessByNameArgs(fn: TermRef, app: Tree)(using Context): Tree = fn.widen match case mt: MethodType - if mt.paramInfos.exists(_.isInstanceOf[ExprType]) && !fn.symbol.is(Inline) => + if mt.isMethodWithByNameArgs && !fn.symbol.is(Inline) => app match case Apply(fn, args) => object dropNotNull extends TreeMap: