Skip to content

Commit

Permalink
Improving type guarantees when invoking (#233)
Browse files Browse the repository at this point in the history
With this pull request necessary casts are added when calling and
returning from lambdas and functions (inline and non-inline).

Before code like
```kotlin
fun <T> id(x: T) = x

inline fun <T, R> T.runWithId(block: T.() -> R) = id(this).block()

class Class {
  val member = 0
}

Class().runWithId { member }
```
wouldn't work: we only knew that the object returned from `id` had type
`Any?`. Therefore, we make casts when entering (inline calls) and
returning (from any calls) to keep `TypeEmbedding` consistent with
kotlin types (this is also needed to correctly represent `==` calls
etc.)
  • Loading branch information
GrigoriiSolnyshkin committed Aug 25, 2024
1 parent 4a397b9 commit 2da4599
Show file tree
Hide file tree
Showing 56 changed files with 1,088 additions and 324 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,20 +371,13 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
)
val stmtCtx = StmtConverter(methodCtx)
val body = stmtCtx.convert(firBody)

// In the end we ensure that returned value is of some type even if that type is Unit.
// However, for Unit we don't assign the result to any value.
// One of the simplest solutions is to do is directly in the beginning of the body.
val unitExtendedBody: ExpEmbedding =
if (signature.type.returnType != UnitTypeEmbedding) body
else blockOf(
Assign(stmtCtx.defaultResolvedReturnTarget.variable, UnitLit),
body,
)
val bodyExp = FunctionExp(signature, unitExtendedBody, returnTarget.label)
val bodyExp = FunctionExp(signature, body, returnTarget.label)
val seqnBuilder = SeqnBuilder(declaration.source)
val linearizer = Linearizer(SharedLinearizationState(anonVarProducer), seqnBuilder, declaration.source)
bodyExp.toViperUnusedResult(linearizer)
// note: we must guarantee somewhere that returned value is Unit
// as we may not encounter any `return` statement in the body
returnTarget.variable.withIsUnitInvariantIfUnit().toViperUnusedResult(linearizer)
return FunctionBodyEmbedding(seqnBuilder.block, returnTarget, bodyExp)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.jetbrains.kotlin.formver.embeddings.expression.*
import org.jetbrains.kotlin.formver.isCustom
import org.jetbrains.kotlin.formver.viper.ast.Label
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addIfNotNull
import org.jetbrains.kotlin.utils.addToStdlib.ifTrue
import org.jetbrains.kotlin.utils.filterIsInstanceAnd

Expand Down Expand Up @@ -131,18 +132,35 @@ fun StmtConversionContext.embedPropertyAccess(accessExpression: FirPropertyAcces
error("Property access symbol $calleeSymbol has unsupported type.")
}


fun StmtConversionContext.argumentDeclaration(arg: ExpEmbedding, callType: TypeEmbedding): Pair<Declare?, ExpEmbedding> =
when (arg.ignoringMetaNodes()) {
is LambdaExp -> null to arg
else -> {
val argWithInvariants = arg.withNewTypeInvariants(callType) {
proven = true
access = true
}
// If `argWithInvariants` is `Cast(...(Cast(someVariable))...)` it is fine to use it
// since in Viper it will always be translated to `someVariable`.
// On other hand, `TypeEmbedding` and invariants in Viper are guaranteed
// via previous line.
if (argWithInvariants.underlyingVariable != null) null to argWithInvariants
else declareAnonVar(callType, argWithInvariants).let {
it to it.variable
}
}
}

fun StmtConversionContext.getInlineFunctionCallArgs(
args: List<ExpEmbedding>,
formalArgTypes: List<TypeEmbedding>,
): Pair<List<Declare>, List<ExpEmbedding>> {
val declarations = mutableListOf<Declare>()
val storedArgs = args.map { arg ->
when (arg.ignoringMetaNodes()) {
is VariableEmbedding, is LambdaExp -> arg
else -> {
val paramVarDecl = declareAnonVar(arg.type, arg)
declarations.add(paramVarDecl)
paramVarDecl.variable
}
val storedArgs = args.zip(formalArgTypes).map { (arg, callType) ->
argumentDeclaration(arg, callType).let { (declaration, usage) ->
declarations.addIfNotNull(declaration)
usage
}
}
return Pair(declarations, storedArgs)
Expand All @@ -158,7 +176,7 @@ fun StmtConversionContext.insertInlineFunctionCall(
): ExpEmbedding {
// TODO: It seems like it may be possible to avoid creating a local here, but it is not clear how.
val returnTarget = returnTargetProducer.getFresh(calleeSignature.type.returnType)
val (declarations, callArgs) = getInlineFunctionCallArgs(args)
val (declarations, callArgs) = getInlineFunctionCallArgs(args, calleeSignature.type.formalArgTypes)
val subs = paramNames.zip(callArgs).toMap()
val methodCtxFactory = MethodContextFactory(
calleeSignature,
Expand All @@ -170,7 +188,8 @@ fun StmtConversionContext.insertInlineFunctionCall(
add(Declare(returnTarget.variable, null))
addAll(declarations)
add(FunctionExp(null, convert(body), returnTarget.label))
add(returnTarget.variable)
// if unit is what we return we might not guarantee it yet
add(returnTarget.variable.withIsUnitInvariantIfUnit())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@ import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition
import org.jetbrains.kotlin.fir.expressions.impl.FirUnitExpression
import org.jetbrains.kotlin.fir.references.toResolvedSymbol
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.isUnit
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.fir.visitors.FirVisitor
import org.jetbrains.kotlin.formver.UnsupportedFeatureBehaviour
import org.jetbrains.kotlin.formver.embeddings.TypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.buildType
import org.jetbrains.kotlin.formver.embeddings.callables.FullNamedFunctionSignature
import org.jetbrains.kotlin.formver.embeddings.callables.insertCall
import org.jetbrains.kotlin.formver.embeddings.expression.*
import org.jetbrains.kotlin.formver.functionCallArguments
import org.jetbrains.kotlin.text
Expand Down Expand Up @@ -48,7 +51,10 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
returnExpression: FirReturnExpression,
data: StmtConversionContext,
): ExpEmbedding {
val expr = data.convert(returnExpression.result)
val expr = when (returnExpression.result) {
is FirUnitExpression -> UnitLit
else -> data.convert(returnExpression.result)
}
// returnTarget is null when it is the implicit return of a lambda
val returnTargetName = returnExpression.target.labelName
val target = data.resolveReturnTarget(returnTargetName)
Expand All @@ -58,18 +64,25 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
)
}

override fun visitResolvedQualifier(resolvedQualifier: FirResolvedQualifier, data: StmtConversionContext): ExpEmbedding {
check(resolvedQualifier.resolvedType.isUnit) {
"Only `Unit` is supported among resolved qualifiers currently."
}
return UnitLit
}

override fun visitBlock(block: FirBlock, data: StmtConversionContext): ExpEmbedding =
block.statements.map(data::convert).toBlock()

override fun visitLiteralExpression(
constExpression: FirLiteralExpression,
literalExpression: FirLiteralExpression,
data: StmtConversionContext,
): ExpEmbedding =
when (constExpression.kind) {
ConstantValueKind.Int -> IntLit((constExpression.value as Long).toInt())
ConstantValueKind.Boolean -> BooleanLit(constExpression.value as Boolean)
when (literalExpression.kind) {
ConstantValueKind.Int -> IntLit((literalExpression.value as Long).toInt())
ConstantValueKind.Boolean -> BooleanLit(literalExpression.value as Boolean)
ConstantValueKind.Null -> NullLit
else -> handleUnimplementedElement("Constant Expression of type ${constExpression.kind} is not yet implemented.", data)
else -> handleUnimplementedElement("Constant Expression of type ${literalExpression.kind} is not yet implemented.", data)
}

override fun visitIntegerLiteralOperatorCall(
Expand Down Expand Up @@ -175,7 +188,11 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
override fun visitFunctionCall(functionCall: FirFunctionCall, data: StmtConversionContext): ExpEmbedding {
val symbol = functionCall.toResolvedCallableSymbol()
val callee = data.embedFunction(symbol as FirFunctionSymbol<*>)
return callee.insertCall(functionCall.functionCallArguments.map(data::convert), data)
return callee.insertCall(
functionCall.functionCallArguments.map(data::convert),
data,
data.embedType(functionCall.resolvedType),
)
}

override fun visitImplicitInvokeCall(
Expand All @@ -184,17 +201,17 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
): ExpEmbedding {
val receiver = implicitInvokeCall.dispatchReceiver as? FirPropertyAccessExpression
?: throw NotImplementedError("Implicit invoke calls only support a limited range of receivers at the moment.")
val returnType = data.embedType(implicitInvokeCall.resolvedType)
val receiverSymbol = receiver.calleeReference.toResolvedSymbol<FirBasedSymbol<*>>()!!
val args = implicitInvokeCall.argumentList.arguments.map(data::convert)
return when (val exp = data.embedLocalSymbol(receiverSymbol).ignoringMetaNodes()) {
is LambdaExp -> {
// The lambda is already the receiver, so we do not need to convert it.
// TODO: do this more uniformly: convert the receiver, see it is a lambda, use insertCall on it.
exp.insertCall(args, data)
exp.insertCall(args, data, returnType)
}
else -> {
val retType = data.embedType(implicitInvokeCall.toResolvedCallableSymbol()!!.resolvedReturnType)
InvokeFunctionObject(data.convert(receiver), args, retType)
InvokeFunctionObject(data.convert(receiver), args, returnType)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ package org.jetbrains.kotlin.formver.embeddings

import org.jetbrains.kotlin.formver.conversion.StmtConversionContext
import org.jetbrains.kotlin.formver.embeddings.callables.CallableEmbedding
import org.jetbrains.kotlin.formver.embeddings.callables.insertCall
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding

class CustomGetter(val getterMethod: CallableEmbedding) : GetterEmbedding {
override fun getValue(
receiver: ExpEmbedding,
ctx: StmtConversionContext,
): ExpEmbedding = getterMethod.insertCall(listOf(receiver), ctx)
): ExpEmbedding = getterMethod.insertCall(listOf(receiver), ctx, getterMethod.type.returnType)
}

class CustomSetter(val setterMethod: CallableEmbedding) : SetterEmbedding {
override fun setValue(
receiver: ExpEmbedding,
value: ExpEmbedding,
ctx: StmtConversionContext,
): ExpEmbedding = setterMethod.insertCall(listOf(receiver, value), ctx)
): ExpEmbedding = setterMethod.insertCall(listOf(receiver, value), ctx, setterMethod.type.returnType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ fun TypeBuilder.nullableAny(): AnyPretypeBuilder {

fun buildType(init: TypeBuilder.() -> PretypeBuilder): TypeEmbedding = TypeBuilder().complete(init)

fun TypeEmbedding.equalToType(init: TypeBuilder.() -> PretypeBuilder) = equals(buildType { init() })

fun buildFunctionType(init: FunctionPretypeBuilder.() -> Unit): FunctionTypeEmbedding =
buildType { function { init() } } as FunctionTypeEmbedding

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.jetbrains.kotlin.formver.names.NameMatcher
import org.jetbrains.kotlin.formver.viper.MangledName
import org.jetbrains.kotlin.formver.viper.ast.Exp
import org.jetbrains.kotlin.formver.viper.mangled
import org.jetbrains.kotlin.utils.addToStdlib.ifFalse

/**
* Represents our representation of a Kotlin type.
Expand Down Expand Up @@ -183,4 +184,3 @@ val TypeEmbedding.isCollectionInheritor
get() = isInheritorOfCollectionTypeNamed("Collection")

fun TypeEmbedding.subTypeInvariant() = SubTypeInvariantEmbedding(this)

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ package org.jetbrains.kotlin.formver.embeddings.callables

import org.jetbrains.kotlin.formver.conversion.StmtConversionContext
import org.jetbrains.kotlin.formver.embeddings.FunctionTypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.TypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.withNewTypeInvariants

/**
* Kotlin entity that can be called.
Expand All @@ -16,3 +18,9 @@ interface CallableEmbedding {
val type: FunctionTypeEmbedding
fun insertCall(args: List<ExpEmbedding>, ctx: StmtConversionContext): ExpEmbedding
}

fun CallableEmbedding.insertCall(args: List<ExpEmbedding>, ctx: StmtConversionContext, actualReturnType: TypeEmbedding) =
insertCall(args, ctx).withNewTypeInvariants(actualReturnType) {
access = true
proven = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ inline fun ExpEmbedding.withInvariants(block: InhaleInvariantsBuilder.() -> Unit
return builder.complete()
}

fun ExpEmbedding.withIsUnitInvariantIfUnit() = withInvariants {
proven = type.equalToType { unit() }
}

inline fun ExpEmbedding.withNewTypeInvariants(newType: TypeEmbedding, block: InhaleInvariantsBuilder.() -> Unit) =
if (this.type == newType) this else withType(newType).withInvariants(block)

Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ method f$compoundConditionalEffect$TF$T$Boolean(p$b: Ref)
ensures true ==> df$rt$boolFromRef(p$b) && false
{
inhale df$rt$isSubtype(df$rt$typeOf(p$b), df$rt$boolType())
ret$0 := df$rt$unitValue()
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

/cond_effects.kt:(190,220): warning: Cannot verify that if the function returns then (b && false).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ method f$nullableNotNonNullable$TF$Int(p$x: Ref) returns (ret$0: Ref)
ensures true ==> df$rt$isSubtype(df$rt$typeOf(p$x), df$rt$intType())
{
inhale df$rt$isSubtype(df$rt$typeOf(p$x), df$rt$nullable(df$rt$intType()))
ret$0 := df$rt$unitValue()
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

/is_type_contract.kt:(376,404): warning: Cannot verify that if the function returns then x is Int.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ method f$empty_list_expr_get$TF$() returns (ret$0: Ref)
{
var l0$s: Ref
var anon$0: Ref
ret$0 := df$rt$unitValue()
anon$0 := f$pkg$kotlin_collections$emptyList$TF$()
l0$s := f$pkg$kotlin_collections$c$List$get$TF$T$c$pkg$kotlin_collections$List$T$Int(anon$0,
df$rt$intToRef(0))
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

method f$pkg$kotlin_collections$c$List$get$TF$T$c$pkg$kotlin_collections$List$T$Int(this: Ref,
Expand Down Expand Up @@ -45,11 +45,11 @@ method f$empty_list_get$TF$() returns (ret$0: Ref)
{
var l0$myList: Ref
var l0$s: Ref
ret$0 := df$rt$unitValue()
l0$myList := f$pkg$kotlin_collections$emptyList$TF$()
l0$s := f$pkg$kotlin_collections$c$List$get$TF$T$c$pkg$kotlin_collections$List$T$Int(l0$myList,
df$rt$intToRef(0))
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

method f$pkg$kotlin_collections$c$List$get$TF$T$c$pkg$kotlin_collections$List$T$Int(this: Ref,
Expand Down Expand Up @@ -129,12 +129,12 @@ method f$add_get$TF$T$c$pkg$kotlin_collections$MutableList(p$l: Ref)
var l0$n: Ref
inhale df$rt$isSubtype(df$rt$typeOf(p$l), df$rt$T$c$pkg$kotlin_collections$MutableList())
inhale acc(p$pkg$kotlin_collections$c$MutableList$shared(p$l), wildcard)
ret$0 := df$rt$unitValue()
anon$0 := f$pkg$kotlin_collections$c$MutableList$add$TF$T$c$pkg$kotlin_collections$MutableList$T$Int(p$l,
df$rt$intToRef(1))
l0$n := f$pkg$kotlin_collections$c$MutableList$get$TF$T$c$pkg$kotlin_collections$MutableList$T$Int(p$l,
df$rt$intToRef(1))
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

method f$pkg$kotlin_collections$c$MutableList$add$TF$T$c$pkg$kotlin_collections$MutableList$T$Int(this: Ref,
Expand Down Expand Up @@ -173,11 +173,11 @@ method f$empty_list_sub$TF$() returns (ret$0: Ref)
{
var l0$l: Ref
var anon$0: Ref
ret$0 := df$rt$unitValue()
anon$0 := f$pkg$kotlin_collections$emptyList$TF$()
l0$l := f$pkg$kotlin_collections$c$List$subList$TF$T$c$pkg$kotlin_collections$List$T$Int$T$Int(anon$0,
df$rt$intToRef(0), df$rt$intToRef(1))
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

method f$pkg$kotlin_collections$c$List$subList$TF$T$c$pkg$kotlin_collections$List$T$Int$T$Int(this: Ref,
Expand Down Expand Up @@ -218,11 +218,11 @@ method f$empty_list_sub_negative$TF$() returns (ret$0: Ref)
{
var l0$l: Ref
var anon$0: Ref
ret$0 := df$rt$unitValue()
anon$0 := f$pkg$kotlin_collections$emptyList$TF$()
l0$l := f$pkg$kotlin_collections$c$List$subList$TF$T$c$pkg$kotlin_collections$List$T$Int$T$Int(anon$0,
df$rt$intToRef(-1), df$rt$intToRef(1))
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

method f$pkg$kotlin_collections$c$List$subList$TF$T$c$pkg$kotlin_collections$List$T$Int$T$Int(this: Ref,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
method f$verify_false$TF$() returns (ret$0: Ref)
ensures df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
{
ret$0 := df$rt$unitValue()
assert false
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

/viper_verify.kt:(153,158): warning: Viper verification error: Assert might fail. Assertion false might not hold.
Expand All @@ -14,13 +14,13 @@ method f$verify_compound$TF$() returns (ret$0: Ref)
ensures df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
{
var anon$0: Ref
ret$0 := df$rt$unitValue()
if (true) {
anon$0 := df$rt$boolToRef(false)
} else {
anon$0 := df$rt$boolToRef(false)}
assert df$rt$boolFromRef(anon$0)
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}

/viper_verify.kt:(212,225): warning: Viper verification error: Assert might fail. Assertion df$rt$boolFromRef(anon$0) might not hold.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ method f$test$TF$T$Int(p$n: Ref) returns (ret$0: Ref)
var l0$customList: Ref
var anon$0: Ref
inhale df$rt$isSubtype(df$rt$typeOf(p$n), df$rt$intType())
ret$0 := df$rt$unitValue()
l0$customList := con$c$CustomList$T$Int$T$Int(p$n, df$rt$intToRef(0))
anon$0 := f$c$CustomList$isEmpty$TF$T$c$CustomList(l0$customList)
if (!df$rt$boolFromRef(anon$0)) {
Expand All @@ -84,4 +83,5 @@ method f$test$TF$T$Int(p$n: Ref) returns (ret$0: Ref)
anon$3 := f$c$CustomList$get$TF$T$c$CustomList$T$Int(l0$customList, df$rt$intToRef(0))
}
label lbl$ret$0
inhale df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$unitType())
}
Loading

0 comments on commit 2da4599

Please sign in to comment.