From 8c6daaf99c42679e257633357ce1e8331370c9fe Mon Sep 17 00:00:00 2001 From: noti0na1 <8036790+noti0na1@users.noreply.github.com> Date: Fri, 22 Mar 2024 04:55:04 +0000 Subject: [PATCH 1/4] Find universal capability from parents --- .../src/dotty/tools/dotc/cc/CheckCaptures.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 9b6217033ede..c56d5ff090d8 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -519,6 +519,16 @@ class CheckCaptures extends Recheck, SymTransformer: if sym.isConstructor then val cls = sym.owner.asClass + /** Check if the class or one of its parents has a root capability, + * which means that the class has a capability annotation or an impure + * function type. + */ + def hasUniversalCapability(tp: Type): Boolean = tp match + case CapturingType(parent, ref) => + ref.isUniversal || hasUniversalCapability(parent) + case tp => + tp.isCapabilityClassRef || tp.parents.exists(hasUniversalCapability) + /** First half of result pair: * Refine the type of a constructor call `new C(t_1, ..., t_n)` * to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked @@ -528,7 +538,8 @@ class CheckCaptures extends Recheck, SymTransformer: */ def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) = var refined: Type = core - var allCaptures: CaptureSet = initCs + var allCaptures: CaptureSet = if hasUniversalCapability(core) + then CaptureSet.universal else initCs for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol if getter.termRef.isTracked && !getter.is(Private) then From 09b7f166f3dcb68ed5264396a8c463bbc665dc10 Mon Sep 17 00:00:00 2001 From: noti0na1 <8036790+noti0na1@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:11:43 +0000 Subject: [PATCH 2/4] Attempt to pass and check capability from parents correctly --- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 11 +++++++++++ .../dotty/tools/dotc/cc/CheckCaptures.scala | 12 +----------- compiler/src/dotty/tools/dotc/cc/Setup.scala | 9 ++------- .../dotty/tools/dotc/core/TypeComparer.scala | 9 ++++++++- .../captures/extending-cap-classes.scala | 15 +++++++++++++++ .../extending-impure-function.scala.scala | 18 ++++++++++++++++++ 6 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 tests/neg-custom-args/captures/extending-cap-classes.scala create mode 100644 tests/neg-custom-args/captures/extending-impure-function.scala.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 1c951a0c0846..6a8874839fb5 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -207,6 +207,17 @@ extension (tp: Type) case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot) case _ => false + /** Check if the class has universal capability, which means: + * 1. the class has a capability annotation, + * 2. the class is an impure function type, + * 3. or one of its base classes has universal capability. + */ + def hasUniversalCapability(using Context): Boolean = tp match + case CapturingType(parent, ref) => + ref.isUniversal || parent.hasUniversalCapability + case tp => + tp.isCapabilityClassRef || tp.parents.exists(_.hasUniversalCapability) + /** Drop @retains annotations everywhere */ def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling val tm = new TypeMap: diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index c56d5ff090d8..845d4bfdc8c2 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -519,16 +519,6 @@ class CheckCaptures extends Recheck, SymTransformer: if sym.isConstructor then val cls = sym.owner.asClass - /** Check if the class or one of its parents has a root capability, - * which means that the class has a capability annotation or an impure - * function type. - */ - def hasUniversalCapability(tp: Type): Boolean = tp match - case CapturingType(parent, ref) => - ref.isUniversal || hasUniversalCapability(parent) - case tp => - tp.isCapabilityClassRef || tp.parents.exists(hasUniversalCapability) - /** First half of result pair: * Refine the type of a constructor call `new C(t_1, ..., t_n)` * to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked @@ -538,7 +528,7 @@ class CheckCaptures extends Recheck, SymTransformer: */ def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) = var refined: Type = core - var allCaptures: CaptureSet = if hasUniversalCapability(core) + var allCaptures: CaptureSet = if core.hasUniversalCapability then CaptureSet.universal else initCs for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 9ab41859f170..082074c84ffc 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -269,12 +269,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: CapturingType(fntpe, cs, boxed = false) else fntpe - /** Map references to capability classes C to C^ */ - private def expandCapabilityClass(tp: Type): Type = - if tp.isCapabilityClassRef - then CapturingType(tp, defn.expandedUniversalSet, boxed = false) - else tp - private def recur(t: Type): Type = normalizeCaptures(mapOver(t)) def apply(t: Type) = @@ -297,7 +291,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: case t: TypeVar => this(t.underlying) case t => - if t.isCapabilityClassRef + // Map references to capability classes C to C^ + if t.hasUniversalCapability then CapturingType(t, defn.expandedUniversalSet, boxed = false) else recur(t) end expandAliases diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index cee1ec7fffa8..0c237a0a5fd0 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -895,13 +895,20 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling canWidenAbstract && acc(true, tp) def tryBaseType(cls2: Symbol) = - val base = nonExprBaseType(tp1, cls2) + var base = nonExprBaseType(tp1, cls2) if base.exists && (base ne tp1) && (!caseLambda.exists || widenAbstractOKFor(tp2) || tp1.widen.underlyingClassRef(refinementOK = true).exists) then def checkBase = + // Strip existing capturing set from base type + base = base.stripCapturing + // Pass capture set of tp1 to base type + tp1 match + case tp1 @ CapturingType(_, refs1) => + base = CapturingType(base, refs1, tp1.isBoxed) + case _ => isSubType(base, tp2, if tp1.isRef(cls2) then approx else approx.addLow) && recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp1) } if tp1.widenDealias.isInstanceOf[AndType] || base.isInstanceOf[OrType] then diff --git a/tests/neg-custom-args/captures/extending-cap-classes.scala b/tests/neg-custom-args/captures/extending-cap-classes.scala new file mode 100644 index 000000000000..17497e415a1e --- /dev/null +++ b/tests/neg-custom-args/captures/extending-cap-classes.scala @@ -0,0 +1,15 @@ +import annotation.capability + +class C1 +@capability class C2 extends C1 +class C3 extends C2 + +def test = + val x1: C1 = new C1 + val x2: C1 = new C2 // error + val x3: C1 = new C3 // error + + val y1: C2 = new C2 + val y2: C2 = new C3 + + val z1: C3 = new C3 \ No newline at end of file diff --git a/tests/neg-custom-args/captures/extending-impure-function.scala.scala b/tests/neg-custom-args/captures/extending-impure-function.scala.scala new file mode 100644 index 000000000000..25e7e035c9df --- /dev/null +++ b/tests/neg-custom-args/captures/extending-impure-function.scala.scala @@ -0,0 +1,18 @@ +class F extends (Int => Unit) { + def apply(x: Int): Unit = () +} + +def test = + val x1 = new (Int => Unit) { + def apply(x: Int): Unit = () + } + + val x2: Int -> Unit = new (Int => Unit) { // error + def apply(x: Int): Unit = () + } + + val y1: Int => Unit = new F + val y2: Int -> Unit = new F // error + + val z1 = () => () + val z2: () -> Unit = () => () From 83a409d85120e92a9d20ac9a1b3e1aec01b714ec Mon Sep 17 00:00:00 2001 From: noti0na1 <8036790+noti0na1@users.noreply.github.com> Date: Thu, 4 Apr 2024 03:57:56 +0000 Subject: [PATCH 3/4] Ignore capturing from parents when computing base type --- compiler/src/dotty/tools/dotc/core/SymDenotations.scala | 4 ++-- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 9 +-------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index f01d2faf86c4..05c16c38d646 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import config.Config import reporting.* import collection.mutable -import cc.{CapturingType, derivedCapturingType} +import cc.{CapturingType, derivedCapturingType, stripCapturing} import scala.annotation.internal.sharable import scala.compiletime.uninitialized @@ -2232,7 +2232,7 @@ object SymDenotations { tp match { case tp @ TypeRef(prefix, _) => def foldGlb(bt: Type, ps: List[Type]): Type = ps match { - case p :: ps1 => foldGlb(bt & recur(p), ps1) + case p :: ps1 => foldGlb(bt & recur(p.stripCapturing), ps1) case _ => bt } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 0c237a0a5fd0..cee1ec7fffa8 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -895,20 +895,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling canWidenAbstract && acc(true, tp) def tryBaseType(cls2: Symbol) = - var base = nonExprBaseType(tp1, cls2) + val base = nonExprBaseType(tp1, cls2) if base.exists && (base ne tp1) && (!caseLambda.exists || widenAbstractOKFor(tp2) || tp1.widen.underlyingClassRef(refinementOK = true).exists) then def checkBase = - // Strip existing capturing set from base type - base = base.stripCapturing - // Pass capture set of tp1 to base type - tp1 match - case tp1 @ CapturingType(_, refs1) => - base = CapturingType(base, refs1, tp1.isBoxed) - case _ => isSubType(base, tp2, if tp1.isRef(cls2) then approx else approx.addLow) && recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp1) } if tp1.widenDealias.isInstanceOf[AndType] || base.isInstanceOf[OrType] then From f6529c46f38cd6fc1a2adfa534d69ad203a2ca23 Mon Sep 17 00:00:00 2001 From: noti0na1 <8036790+noti0na1@users.noreply.github.com> Date: Mon, 22 Apr 2024 02:09:03 +0000 Subject: [PATCH 4/4] Store capability class information in a hash map during cc --- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 15 ---------- .../dotty/tools/dotc/cc/CheckCaptures.scala | 2 +- compiler/src/dotty/tools/dotc/cc/Setup.scala | 28 ++++++++++++++++- .../captures/extending-impure-function.scala | 30 +++++++++++++++++++ .../extending-impure-function.scala.scala | 18 ----------- 5 files changed, 58 insertions(+), 35 deletions(-) create mode 100644 tests/neg-custom-args/captures/extending-impure-function.scala delete mode 100644 tests/neg-custom-args/captures/extending-impure-function.scala.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 6a8874839fb5..42483599f1e6 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -203,21 +203,6 @@ extension (tp: Type) case _ => false - def isCapabilityClassRef(using Context) = tp.dealiasKeepAnnots match - case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot) - case _ => false - - /** Check if the class has universal capability, which means: - * 1. the class has a capability annotation, - * 2. the class is an impure function type, - * 3. or one of its base classes has universal capability. - */ - def hasUniversalCapability(using Context): Boolean = tp match - case CapturingType(parent, ref) => - ref.isUniversal || parent.hasUniversalCapability - case tp => - tp.isCapabilityClassRef || tp.parents.exists(_.hasUniversalCapability) - /** Drop @retains annotations everywhere */ def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling val tm = new TypeMap: diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 845d4bfdc8c2..3b241f751403 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -528,7 +528,7 @@ class CheckCaptures extends Recheck, SymTransformer: */ def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) = var refined: Type = core - var allCaptures: CaptureSet = if core.hasUniversalCapability + var allCaptures: CaptureSet = if setup.isCapabilityClassRef(core) then CaptureSet.universal else initCs for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 082074c84ffc..fef88a8ba6de 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -23,6 +23,7 @@ trait SetupAPI: def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit def isPreCC(sym: Symbol)(using Context): Boolean def postCheck()(using Context): Unit + def isCapabilityClassRef(tp: Type)(using Context): Boolean object Setup: @@ -67,6 +68,31 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: && !sym.owner.is(CaptureChecked) && !defn.isFunctionSymbol(sym.owner) + private val capabilityClassMap = new util.HashMap[Symbol, Boolean] + + /** Check if the class is capability, which means: + * 1. the class has a capability annotation, + * 2. or at least one of its parent type has universal capability. + */ + def isCapabilityClassRef(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match + case _: TypeRef | _: AppliedType => + val sym = tp.classSymbol + def checkSym: Boolean = + sym.hasAnnotation(defn.CapabilityAnnot) + || sym.info.parents.exists(hasUniversalCapability) + sym.isClass && capabilityClassMap.getOrElseUpdate(sym, checkSym) + case _ => false + + private def hasUniversalCapability(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match + case CapturingType(parent, refs) => + refs.isUniversal || hasUniversalCapability(parent) + case AnnotatedType(parent, ann) => + if ann.symbol.isRetains then + try ann.tree.toCaptureSet.isUniversal || hasUniversalCapability(parent) + catch case ex: IllegalCaptureRef => false + else hasUniversalCapability(parent) + case tp => isCapabilityClassRef(tp) + private def fluidify(using Context) = new TypeMap with IdempotentCaptRefMap: def apply(t: Type): Type = t match case t: MethodType => @@ -292,7 +318,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI: this(t.underlying) case t => // Map references to capability classes C to C^ - if t.hasUniversalCapability + if isCapabilityClassRef(t) then CapturingType(t, defn.expandedUniversalSet, boxed = false) else recur(t) end expandAliases diff --git a/tests/neg-custom-args/captures/extending-impure-function.scala b/tests/neg-custom-args/captures/extending-impure-function.scala new file mode 100644 index 000000000000..e491b31caed5 --- /dev/null +++ b/tests/neg-custom-args/captures/extending-impure-function.scala @@ -0,0 +1,30 @@ +class F1 extends (Int => Unit) { + def apply(x: Int): Unit = () +} + +class F2 extends (Int -> Unit) { + def apply(x: Int): Unit = () +} + +def test = + val x1 = new (Int => Unit) { + def apply(x: Int): Unit = () + } + + val x2: Int -> Unit = new (Int => Unit) { // error + def apply(x: Int): Unit = () + } + + val x3: Int -> Unit = new (Int -> Unit) { + def apply(x: Int): Unit = () + } + + val y1: Int => Unit = new F1 + val y2: Int -> Unit = new F1 // error + val y3: Int => Unit = new F2 + val y4: Int -> Unit = new F2 + + val z1 = () => () + val z2: () -> Unit = () => () + val z3: () -> Unit = z1 + val z4: () => Unit = () => () diff --git a/tests/neg-custom-args/captures/extending-impure-function.scala.scala b/tests/neg-custom-args/captures/extending-impure-function.scala.scala deleted file mode 100644 index 25e7e035c9df..000000000000 --- a/tests/neg-custom-args/captures/extending-impure-function.scala.scala +++ /dev/null @@ -1,18 +0,0 @@ -class F extends (Int => Unit) { - def apply(x: Int): Unit = () -} - -def test = - val x1 = new (Int => Unit) { - def apply(x: Int): Unit = () - } - - val x2: Int -> Unit = new (Int => Unit) { // error - def apply(x: Int): Unit = () - } - - val y1: Int => Unit = new F - val y2: Int -> Unit = new F // error - - val z1 = () => () - val z2: () -> Unit = () => ()