Skip to content

Commit

Permalink
Add experimental flexible types feature on top of explicit nulls
Browse files Browse the repository at this point in the history
Enabled by -Yflexible-types with -Yexplicit-nulls.
A flexible type T! is a non-denotable type such that T <: T! <: T|Null and T|Null <: T! <: T. Here we patch return types and parameter types of Java methods and fields to use flexible types.
This is unsound and kills subtyping transitivity but makes interop with Java play more nicely with the explicit nulls experimental feature (i.e. fewer nullability casts).

Co-authored-by: Ondřej Lhoták <olhotak@uwaterloo.ca>
Co-authored-by: Evan Girardin <evangirardin@gmail.com>
  • Loading branch information
3 people committed Nov 30, 2023
1 parent c88c0fe commit 3e7d5ae
Show file tree
Hide file tree
Showing 57 changed files with 650 additions and 37 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ private sealed trait YSettings:
// Experimental language features
val YnoKindPolymorphism: Setting[Boolean] = BooleanSetting("-Yno-kind-polymorphism", "Disable kind polymorphism.")
val YexplicitNulls: Setting[Boolean] = BooleanSetting("-Yexplicit-nulls", "Make reference types non-nullable. Nullable types can be expressed with unions: e.g. String|Null.")
val YflexibleTypes: Setting[Boolean] = BooleanSetting("-Yflexible-types", "Make Java return types and parameter types use flexible types, which have a nullable lower bound and non-null upper bound.")
val YcheckInit: Setting[Boolean] = BooleanSetting("-Ysafe-init", "Ensure safe initialization of objects.")
val YcheckInitGlobal: Setting[Boolean] = BooleanSetting("-Ysafe-init-global", "Check safe initialization of global objects.")
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation.")
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ object Contexts {
/** Is the explicit nulls option set? */
def explicitNulls: Boolean = base.settings.YexplicitNulls.value

/** Is the flexible types option set? */
def flexibleTypes: Boolean = base.settings.YexplicitNulls.value && base.settings.YflexibleTypes.value

/** A fresh clone of this context embedded in this context. */
def fresh: FreshContext = freshOver(this)

Expand Down
30 changes: 17 additions & 13 deletions compiler/src/dotty/tools/dotc/core/JavaNullInterop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ object JavaNullInterop {
* but the result type is not nullable.
*/
private def nullifyExceptReturnType(tp: Type)(using Context): Type =
new JavaNullMap(true)(tp)
new JavaNullMap(outermostLevelAlreadyNullable = true)(tp)

/** Nullifies a Java type by adding `| Null` in the relevant places. */
private def nullifyType(tp: Type)(using Context): Type =
new JavaNullMap(false)(tp)
new JavaNullMap(outermostLevelAlreadyNullable = false)(tp)

/** A type map that implements the nullification function on types. Given a Java-sourced type, this adds `| Null`
* in the right places to make the nulls explicit in Scala.
Expand All @@ -96,25 +96,29 @@ object JavaNullInterop {
* to `(A & B) | Null`, instead of `(A | Null & B | Null) | Null`.
*/
private class JavaNullMap(var outermostLevelAlreadyNullable: Boolean)(using Context) extends TypeMap {
def nullify(tp: Type): Type = if ctx.flexibleTypes then FlexibleType(tp) else OrNull(tp)

/** Should we nullify `tp` at the outermost level? */
def needsNull(tp: Type): Boolean =
!outermostLevelAlreadyNullable && (tp match {
!(outermostLevelAlreadyNullable || (tp match {
case tp: TypeRef =>
// We don't modify value types because they're non-nullable even in Java.
!tp.symbol.isValueClass &&
tp.symbol.isValueClass
// We don't modify unit types.
|| tp.isRef(defn.UnitClass)
// We don't modify `Any` because it's already nullable.
!tp.isRef(defn.AnyClass) &&
|| tp.isRef(defn.AnyClass)
// We don't nullify Java varargs at the top level.
// Example: if `setNames` is a Java method with signature `void setNames(String... names)`,
// then its Scala signature will be `def setNames(names: (String|Null)*): Unit`.
// This is because `setNames(null)` passes as argument a single-element array containing the value `null`,
// and not a `null` array.
!tp.isRef(defn.RepeatedParamClass)
case _ => true
})
|| !ctx.flexibleTypes && tp.isRef(defn.RepeatedParamClass)
case _ => false
}))

override def apply(tp: Type): Type = tp match {
case tp: TypeRef if needsNull(tp) => OrNull(tp)
case tp: TypeRef if needsNull(tp) => nullify(tp)
case appTp @ AppliedType(tycon, targs) =>
val oldOutermostNullable = outermostLevelAlreadyNullable
// We don't make the outmost levels of type arguments nullable if tycon is Java-defined.
Expand All @@ -124,7 +128,7 @@ object JavaNullInterop {
val targs2 = targs map this
outermostLevelAlreadyNullable = oldOutermostNullable
val appTp2 = derivedAppliedType(appTp, tycon, targs2)
if needsNull(tycon) then OrNull(appTp2) else appTp2
if needsNull(tycon) then nullify(appTp2) else appTp2
case ptp: PolyType =>
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
case mtp: MethodType =>
Expand All @@ -138,12 +142,12 @@ object JavaNullInterop {
// nullify(A & B) = (nullify(A) & nullify(B)) | Null, but take care not to add
// duplicate `Null`s at the outermost level inside `A` and `B`.
outermostLevelAlreadyNullable = true
OrNull(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
case tp: TypeParamRef if needsNull(tp) => OrNull(tp)
nullify(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
case tp: TypeParamRef if needsNull(tp) => nullify(tp)
// In all other cases, return the type unchanged.
// In particular, if the type is a ConstantType, then we don't nullify it because it is the
// type of a final non-nullable field.
case _ => tp
}
}
}
}
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import Types.*
object NullOpsDecorator:

extension (self: Type)
def stripFlexible(using Context): Type = {
self match {
case FlexibleType(tp) => tp
case _ => self
}
}
/** Syntactically strips the nullability from this type.
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
* then return `T1 | ... | Ti-1 | Ti+1 | ... | Tn`.
Expand All @@ -33,6 +39,7 @@ object NullOpsDecorator:
if (tp1s ne tp1) && (tp2s ne tp2) then
tp.derivedAndType(tp1s, tp2s)
else tp
case tp: FlexibleType => tp.hi
case tp @ TypeBounds(lo, hi) =>
tp.derivedTypeBounds(strip(lo), strip(hi))
case tp => tp
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
case CapturingType(parent, refs) =>
val parent1 = recur(parent)
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
case tp: FlexibleType =>
val underlying = recur(tp.underlying)
if underlying ne tp.underlying then tp.derivedFlexibleType(underlying) else tp
case tp: AnnotatedType =>
val parent1 = recur(tp.parent)
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Contexts.ctx
import dotty.tools.dotc.reporting.trace
import config.Feature.migrateTo3
import config.Printers.*
import dotty.tools.dotc.core.NullOpsDecorator.stripFlexible

trait PatternTypeConstrainer { self: TypeComparer =>

Expand Down Expand Up @@ -163,7 +164,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
}
}

def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match {
def dealiasDropNonmoduleRefs(tp: Type): Type = tp.dealias match {
case tp: TermRef =>
// we drop TermRefs that don't have a class symbol, as they can't
// meaningfully participate in GADT reasoning and just get in the way.
Expand All @@ -172,6 +173,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
// additional trait - argument-less enum cases desugar to vals.
// See run/enum-Tree.scala.
if tp.classSymbol.exists then tp else tp.info
case FlexibleType(tp) => dealiasDropNonmoduleRefs(tp)
case tp => tp
}

Expand Down
29 changes: 19 additions & 10 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import reporting.trace
import annotation.constructorOnly
import cc.*
import NameKinds.WildcardParamName
import NullOpsDecorator.stripFlexible

/** Provides methods to compare types.
*/
Expand Down Expand Up @@ -524,7 +525,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
constraint = constraint.hardenTypeVars(tp2)

res

case tp1 @ CapturingType(parent1, refs1) =>
def compareCapturing =
if tp2.isAny then true
Expand Down Expand Up @@ -863,6 +863,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
false
}
compareClassInfo
case tp2: FlexibleType =>
recur(tp1, tp2.lo)
case _ =>
fourthTry
}
Expand Down Expand Up @@ -1058,6 +1060,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case tp1: ExprType if ctx.phaseId > gettersPhase.id =>
// getters might have converted T to => T, need to compensate.
recur(tp1.widenExpr, tp2)
case tp1: FlexibleType =>
recur(tp1.hi, tp2)
case _ =>
false
}
Expand Down Expand Up @@ -2499,15 +2503,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
NoType
}

private def andTypeGen(tp1: Type, tp2: Type, op: (Type, Type) => Type,
original: (Type, Type) => Type = _ & _, isErased: Boolean = ctx.erasedTypes): Type = trace(s"andTypeGen(${tp1.show}, ${tp2.show})", subtyping, show = true) {
val t1 = distributeAnd(tp1, tp2)
if (t1.exists) t1
else {
val t2 = distributeAnd(tp2, tp1)
if (t2.exists) t2
else if (isErased) erasedGlb(tp1, tp2)
else liftIfHK(tp1, tp2, op, original, _ | _)
private def andTypeGen(tp1orig: Type, tp2orig: Type, op: (Type, Type) => Type,
original: (Type, Type) => Type = _ & _, isErased: Boolean = ctx.erasedTypes): Type = trace(s"andTypeGen(${tp1orig.show}, ${tp2orig.show})", subtyping, show = true) {
val tp1 = tp1orig.stripFlexible
val tp2 = tp2orig.stripFlexible
val ret = {
val t1 = distributeAnd(tp1, tp2)
if (t1.exists) t1
else {
val t2 = distributeAnd(tp2, tp1)
if (t2.exists) t2
else if (isErased) erasedGlb(tp1, tp2)
else liftIfHK(tp1, tp2, op, original, _ | _)
// The ` | ` on variances is needed since variances are associated with bounds
// not lambdas. Example:
//
Expand All @@ -2517,7 +2524,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
//
// Here, `F` is treated as bivariant in `O`. That is, only bivariant implementation
// of `F` are allowed. See neg/hk-variance2s.scala test.
}
}
if(tp1orig.isInstanceOf[FlexibleType] && tp2orig.isInstanceOf[FlexibleType]) FlexibleType(ret) else ret
}

/** Form a normalized conjunction of two types.
Expand Down
54 changes: 54 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ object Types extends TypeUtils {
/** Is this type guaranteed not to have `null` as a value? */
final def isNotNull(using Context): Boolean = this match {
case tp: ConstantType => tp.value.value != null
case tp: FlexibleType => false
case tp: ClassInfo => !tp.cls.isNullableClass && tp.cls != defn.NothingClass
case tp: AppliedType => tp.superType.isNotNull
case tp: TypeBounds => tp.lo.isNotNull
Expand Down Expand Up @@ -372,6 +373,7 @@ object Types extends TypeUtils {
case AppliedType(tycon, args) => tycon.unusableForInference || args.exists(_.unusableForInference)
case RefinedType(parent, _, rinfo) => parent.unusableForInference || rinfo.unusableForInference
case TypeBounds(lo, hi) => lo.unusableForInference || hi.unusableForInference
case FlexibleType(underlying) => underlying.unusableForInference
case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference
case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference)
case WildcardType(optBounds) => optBounds.unusableForInference
Expand Down Expand Up @@ -3396,6 +3398,40 @@ object Types extends TypeUtils {
}
}

// --- FlexibleType -----------------------------------------------------------------

/* Represents a nullable type coming from Java code in a similar way to Platform Types
* in Kotlin. A FlexibleType(T) generally behaves like an abstract type with bad bounds
* T|Null .. T, so that T|Null <: FlexibleType(T) <: T.
*/
case class FlexibleType(original: Type, lo: Type, hi: Type) extends CachedProxyType with ValueType {
def underlying(using Context): Type = original

override def superType(using Context): Type = hi

def derivedFlexibleType(original: Type)(using Context): Type =
if this.original eq original then this else FlexibleType(original)

override def computeHash(bs: Binders): Int = doHash(bs, original)

override final def baseClasses(using Context): List[ClassSymbol] = original.baseClasses
}

object FlexibleType {
def apply(original: Type)(using Context): FlexibleType = original match {
case ft: FlexibleType => ft
case _ =>
val hi = original.stripNull
val lo = if hi eq original then OrNull(hi) else original
new FlexibleType(original, lo, hi)
}

def unapply(tp: Type)(using Context): Option[Type] = tp match {
case ft: FlexibleType => Some(ft.original)
case _ => None
}
}

// --- AndType/OrType ---------------------------------------------------------------

abstract class AndOrType extends CachedGroundType with ValueType {
Expand Down Expand Up @@ -5694,6 +5730,8 @@ object Types extends TypeUtils {
samClass(tp.underlying)
case tp: AnnotatedType =>
samClass(tp.underlying)
case tp: FlexibleType =>
samClass(tp.superType)
case _ =>
NoSymbol

Expand Down Expand Up @@ -5824,6 +5862,8 @@ object Types extends TypeUtils {
tp.derivedJavaArrayType(elemtp)
protected def derivedExprType(tp: ExprType, restpe: Type): Type =
tp.derivedExprType(restpe)
protected def derivedFlexibleType(tp: FlexibleType, under: Type): Type =
tp.derivedFlexibleType(under)
// note: currying needed because Scala2 does not support param-dependencies
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
tp.derivedLambdaType(tp.paramNames, formals, restpe)
Expand Down Expand Up @@ -5947,6 +5987,9 @@ object Types extends TypeUtils {
case tp: OrType =>
derivedOrType(tp, this(tp.tp1), this(tp.tp2))

case tp: FlexibleType =>
derivedFlexibleType(tp, this(tp.underlying))

case tp: MatchType =>
val bound1 = this(tp.bound)
val scrut1 = atVariance(0)(this(tp.scrutinee))
Expand Down Expand Up @@ -6234,6 +6277,14 @@ object Types extends TypeUtils {
if (underlying.isExactlyNothing) underlying
else tp.derivedAnnotatedType(underlying, annot)
}
override protected def derivedFlexibleType(tp: FlexibleType, underlying: Type): Type =
underlying match {
case Range(lo, hi) =>
range(tp.derivedFlexibleType(lo), tp.derivedFlexibleType(hi))
case _ =>
if (underlying.isExactlyNothing) underlying
else tp.derivedFlexibleType(underlying)
}
override protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type =
parent match // TODO ^^^ handle ranges in capture sets as well
case Range(lo, hi) =>
Expand Down Expand Up @@ -6375,6 +6426,9 @@ object Types extends TypeUtils {
case tp: TypeVar =>
this(x, tp.underlying)

case tp: FlexibleType =>
this(x, tp.underlying)

case ExprType(restpe) =>
this(x, restpe)

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
case tpe: OrType =>
writeByte(ORtype)
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }
case tpe: FlexibleType =>
writeByte(FLEXIBLEtype)
withLength { pickleType(tpe.underlying, richTypes) }
case tpe: ExprType =>
writeByte(BYNAMEtype)
pickleType(tpe.underlying)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ class TreeUnpickler(reader: TastyReader,
readTypeRef() match {
case binder: LambdaType => binder.paramRefs(readNat())
}
case FLEXIBLEtype =>
FlexibleType(readType())
}
assert(currentAddr == end, s"$start $currentAddr $end ${astTagToString(tag)}")
result
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
case AnnotatedType(tpe, annot) =>
if annot.symbol == defn.InlineParamAnnot || annot.symbol == defn.ErasedParamAnnot then toText(tpe)
else toTextLocal(tpe) ~ " " ~ toText(annot)
case FlexibleType(tpe) =>
"FlexibleType(" ~ toText(tpe) ~ ")"
case tp: TypeVar =>
def toTextCaret(tp: Type) = if printDebug then toTextLocal(tp) ~ Str("^") else toText(tp)
if (tp.isInstantiated)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ private class ExtractAPICollector(using Context) extends ThunkHolder {
case tp: OrType =>
val s = combineApiTypes(apiType(tp.tp1), apiType(tp.tp2))
withMarker(s, orMarker)
case tp: FlexibleType =>
apiType(tp.underlying)
case ExprType(resultType) =>
withMarker(apiType(resultType), byNameMarker)
case MatchType(bound, scrut, cases) =>
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -643,14 +643,15 @@ trait Applications extends Compatibility {
missingArg(n)
}

if (formal.isRepeatedParam)
val formal1 = formal.stripFlexible
if (formal1.isRepeatedParam)
args match {
case arg :: Nil if isVarArg(arg) =>
addTyped(arg)
case (arg @ Typed(Literal(Constant(null)), _)) :: Nil if ctx.isAfterTyper =>
addTyped(arg)
case _ =>
val elemFormal = formal.widenExpr.argTypesLo.head
val elemFormal = formal1.widenExpr.argTypesLo.head
val typedArgs =
harmonic(harmonizeArgs, elemFormal) {
args.map { arg =>
Expand Down
Loading

0 comments on commit 3e7d5ae

Please sign in to comment.