Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Carry and check universal capability from parents correctly #20004

Merged
merged 4 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,6 @@ extension (tp: Type)
case _ =>
false

def isCapabilityClassRef(using Context) = tp.dealiasKeepAnnots match
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
case _ => false

/** Drop @retains annotations everywhere */
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
val tm = new TypeMap:
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ class CheckCaptures extends Recheck, SymTransformer:
*/
def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) =
var refined: Type = core
var allCaptures: CaptureSet = initCs
var allCaptures: CaptureSet = if setup.isCapabilityClassRef(core)
then CaptureSet.universal else initCs
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol
if getter.termRef.isTracked && !getter.is(Private) then
Expand Down
35 changes: 28 additions & 7 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ trait SetupAPI:
def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit
def isPreCC(sym: Symbol)(using Context): Boolean
def postCheck()(using Context): Unit
def isCapabilityClassRef(tp: Type)(using Context): Boolean

object Setup:

Expand Down Expand Up @@ -67,6 +68,31 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
&& !sym.owner.is(CaptureChecked)
&& !defn.isFunctionSymbol(sym.owner)

private val capabilityClassMap = new util.HashMap[Symbol, Boolean]

/** Check if the class is capability, which means:
* 1. the class has a capability annotation,
* 2. or at least one of its parent type has universal capability.
*/
def isCapabilityClassRef(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
case _: TypeRef | _: AppliedType =>
val sym = tp.classSymbol
def checkSym: Boolean =
sym.hasAnnotation(defn.CapabilityAnnot)
|| sym.info.parents.exists(hasUniversalCapability)
sym.isClass && capabilityClassMap.getOrElseUpdate(sym, checkSym)
case _ => false

private def hasUniversalCapability(tp: Type)(using Context): Boolean = tp.dealiasKeepAnnots match
case CapturingType(parent, refs) =>
refs.isUniversal || hasUniversalCapability(parent)
case AnnotatedType(parent, ann) =>
if ann.symbol.isRetains then
try ann.tree.toCaptureSet.isUniversal || hasUniversalCapability(parent)
catch case ex: IllegalCaptureRef => false
else hasUniversalCapability(parent)
case tp => isCapabilityClassRef(tp)

private def fluidify(using Context) = new TypeMap with IdempotentCaptRefMap:
def apply(t: Type): Type = t match
case t: MethodType =>
Expand Down Expand Up @@ -269,12 +295,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
CapturingType(fntpe, cs, boxed = false)
else fntpe

/** Map references to capability classes C to C^ */
private def expandCapabilityClass(tp: Type): Type =
if tp.isCapabilityClassRef
then CapturingType(tp, defn.expandedUniversalSet, boxed = false)
else tp

private def recur(t: Type): Type = normalizeCaptures(mapOver(t))

def apply(t: Type) =
Expand All @@ -297,7 +317,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
case t: TypeVar =>
this(t.underlying)
case t =>
if t.isCapabilityClassRef
// Map references to capability classes C to C^
if isCapabilityClassRef(t)
then CapturingType(t, defn.expandedUniversalSet, boxed = false)
else recur(t)
end expandAliases
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.util.control.NonFatal
import config.Config
import reporting.*
import collection.mutable
import cc.{CapturingType, derivedCapturingType}
import cc.{CapturingType, derivedCapturingType, stripCapturing}

import scala.annotation.internal.sharable
import scala.compiletime.uninitialized
Expand Down Expand Up @@ -2232,7 +2232,7 @@ object SymDenotations {
tp match {
case tp @ TypeRef(prefix, _) =>
def foldGlb(bt: Type, ps: List[Type]): Type = ps match {
case p :: ps1 => foldGlb(bt & recur(p), ps1)
case p :: ps1 => foldGlb(bt & recur(p.stripCapturing), ps1)
case _ => bt
}

Expand Down
15 changes: 15 additions & 0 deletions tests/neg-custom-args/captures/extending-cap-classes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import annotation.capability

class C1
@capability class C2 extends C1
class C3 extends C2

def test =
val x1: C1 = new C1
val x2: C1 = new C2 // error
val x3: C1 = new C3 // error

val y1: C2 = new C2
val y2: C2 = new C3

val z1: C3 = new C3
30 changes: 30 additions & 0 deletions tests/neg-custom-args/captures/extending-impure-function.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class F1 extends (Int => Unit) {
def apply(x: Int): Unit = ()
}

class F2 extends (Int -> Unit) {
def apply(x: Int): Unit = ()
}

def test =
val x1 = new (Int => Unit) {
def apply(x: Int): Unit = ()
}

val x2: Int -> Unit = new (Int => Unit) { // error
def apply(x: Int): Unit = ()
}

val x3: Int -> Unit = new (Int -> Unit) {
def apply(x: Int): Unit = ()
}

val y1: Int => Unit = new F1
val y2: Int -> Unit = new F1 // error
val y3: Int => Unit = new F2
val y4: Int -> Unit = new F2

val z1 = () => ()
val z2: () -> Unit = () => ()
val z3: () -> Unit = z1
val z4: () => Unit = () => ()
Loading