Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scoped receivers #243

Merged
merged 3 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ interface MethodConversionContext : ProgramConversionContext {
fun resolveLocal(name: Name): VariableEmbedding
fun registerLocalProperty(symbol: FirPropertySymbol)
fun registerLocalVariable(symbol: FirVariableSymbol<*>)
fun resolveReceiver(isExtension: Boolean): ExpEmbedding?
fun resolveDispatchReceiver(): ExpEmbedding?
fun resolveExtensionReceiver(labelName: String): ExpEmbedding?

fun <R> withScopeImpl(scopeDepth: Int, action: () -> R): R
fun addLoopIdentifier(labelName: String, index: Int)
fun resolveLoopIndex(name: String): Int
fun resolveNamedReturnTarget(sourceName: String): ReturnTarget?
fun resolveNamedReturnTarget(labelName: String): ReturnTarget?
}

fun MethodConversionContext.resolveReturnTarget(targetSourceName: String?): ReturnTarget =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ class MethodConverter(
paramResolver.tryResolveParameter(name) ?: parent?.resolveParameter(name)
?: throw IllegalArgumentException("Parameter $name not found in scope.")

override fun resolveReceiver(isExtension: Boolean): ExpEmbedding? =
paramResolver.tryResolveReceiver(isExtension) ?: parent?.resolveReceiver(isExtension)
override fun resolveDispatchReceiver(): ExpEmbedding? =
paramResolver.tryResolveDispatchReceiver() ?: parent?.resolveDispatchReceiver()

override fun resolveExtensionReceiver(labelName: String): ExpEmbedding? =
paramResolver.tryResolveExtensionReceiver(labelName) ?: parent?.resolveExtensionReceiver(labelName)

override val defaultResolvedReturnTarget = paramResolver.defaultResolvedReturnTarget
override fun resolveNamedReturnTarget(sourceName: String): ReturnTarget? {
return paramResolver.resolveNamedReturnTarget(sourceName) ?: parent?.resolveNamedReturnTarget(sourceName)
}
override fun resolveNamedReturnTarget(labelName: String): ReturnTarget? =
paramResolver.resolveNamedReturnTarget(labelName) ?: parent?.resolveNamedReturnTarget(labelName)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,38 @@ import org.jetbrains.kotlin.utils.addToStdlib.ifTrue
*/
interface ParameterResolver {
fun tryResolveParameter(name: Name): ExpEmbedding?
fun tryResolveReceiver(isExtension: Boolean): ExpEmbedding?
fun tryResolveDispatchReceiver(): ExpEmbedding?
fun tryResolveExtensionReceiver(labelName: String): ExpEmbedding?

val sourceName: String?
val labelName: String?
val defaultResolvedReturnTarget: ReturnTarget
}

fun ParameterResolver.resolveNamedReturnTarget(returnPointName: String): ReturnTarget? =
(returnPointName == sourceName).ifTrue { defaultResolvedReturnTarget }
(returnPointName == labelName).ifTrue { defaultResolvedReturnTarget }

class RootParameterResolver(
val ctx: ProgramConversionContext,
private val signature: FunctionSignature,
override val sourceName: String?,
override val labelName: String?,
override val defaultResolvedReturnTarget: ReturnTarget,
) : ParameterResolver {
private val parameters = signature.params.associateBy { it.name }
override fun tryResolveParameter(name: Name): ExpEmbedding? = parameters[name.embedParameterName()]
override fun tryResolveReceiver(isExtension: Boolean) =
if (isExtension) signature.extensionReceiver
else signature.dispatchReceiver
override fun tryResolveDispatchReceiver() = signature.dispatchReceiver
override fun tryResolveExtensionReceiver(labelName: String) = (labelName == this.labelName).ifTrue {
signature.extensionReceiver
}
}

class InlineParameterResolver(
private val substitutions: Map<Name, ExpEmbedding>,
override val sourceName: String?,
override val labelName: String?,
override val defaultResolvedReturnTarget: ReturnTarget,
) : ParameterResolver {
override fun tryResolveParameter(name: Name): ExpEmbedding? = substitutions[name]
override fun tryResolveReceiver(isExtension: Boolean): ExpEmbedding? =
if (isExtension) substitutions[ExtraSpecialNames.EXTENSION_THIS]
else substitutions[ExtraSpecialNames.DISPATCH_THIS]
override fun tryResolveDispatchReceiver() = substitutions[ExtraSpecialNames.DISPATCH_THIS]
override fun tryResolveExtensionReceiver(labelName: String) = (labelName == this.labelName).ifTrue {
substitutions[ExtraSpecialNames.EXTENSION_THIS]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
private fun embedFullSignature(symbol: FirFunctionSymbol<*>): FullNamedFunctionSignature {
val subSignature = object : NamedFunctionSignature, FunctionSignature by embedFunctionSignature(symbol) {
override val name = symbol.embedName(this@ProgramConverter)
override val sourceName: String?
get() = super<NamedFunctionSignature>.sourceName
override val labelName: String
get() = super<NamedFunctionSignature>.labelName
override val symbol = symbol
}
val constructorParamSymbolsToFields = extractConstructorParamsAsFields(symbol)
val contractVisitor = ContractDescriptionConversionVisitor(this@ProgramConverter, subSignature)
Expand Down Expand Up @@ -371,7 +372,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
private fun processCallable(symbol: FirFunctionSymbol<*>, signature: FullNamedFunctionSignature): RichCallableEmbedding {
val body = symbol.fir.body
return if (symbol.isInline && body != null) {
InlineNamedFunction(signature, symbol, body)
InlineNamedFunction(signature, body)
} else {
// We generate a dummy method header here to ensure all required types are processed already. If we skip this, any types
// that are used only in contracts cause an error because they are not processed until too late.
Expand All @@ -387,7 +388,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
MethodConverter(
this,
signature,
RootParameterResolver(this, signature, signature.sourceName, returnTarget),
RootParameterResolver(this, signature, signature.labelName, returnTarget),
scopeDepth = scopeIndexProducer.getFresh(),
)
val stmtCtx = StmtConverter(methodCtx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition
import org.jetbrains.kotlin.fir.expressions.impl.FirUnitExpression
import org.jetbrains.kotlin.fir.references.toResolvedSymbol
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirAnonymousFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.types.coneType
Expand Down Expand Up @@ -319,14 +320,16 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
// `thisReceiverExpression` has a bound symbol which can be used for lookup
// for extensions `this`es the bound symbol is the function they originate from
// for member functions the bound symbol is a class they're defined in
// TODO: conduct more thorough lookup based on the name of this symbol as well
val isExtensionReceiver = when (thisReceiverExpression.calleeReference.boundSymbol) {
is FirClassSymbol<*> -> false
is FirFunctionSymbol<*> -> true
//
// since dispatch receiver can only originate from non-anonymous function we do not specify its name here
// as we have only one candidate to resolve it
val resolved = when (val symbol = thisReceiverExpression.calleeReference.boundSymbol) {
is FirClassSymbol<*> -> data.resolveDispatchReceiver()
is FirAnonymousFunctionSymbol -> data.resolveExtensionReceiver(symbol.label!!.name)
is FirFunctionSymbol<*> -> data.resolveExtensionReceiver(symbol.name.asString())
else -> error("Unsupported receiver expression type.")
}
return data.resolveReceiver(isExtensionReceiver)
?: throw IllegalArgumentException("Can't resolve the 'this' receiver since the function does not have one.")
return resolved ?: throw IllegalArgumentException("Can't resolve the 'this' receiver since the function does not have one.")
}

override fun visitTypeOperatorCall(
Expand Down Expand Up @@ -355,7 +358,7 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
data: StmtConversionContext,
): ExpEmbedding {
val function = anonymousFunctionExpression.anonymousFunction
return LambdaExp(data.embedFunctionSignature(function.symbol), function, data)
return LambdaExp(data.embedFunctionSignature(function.symbol), function, data, function.symbol.label!!.name)
}

override fun visitTryExpression(tryExpression: FirTryExpression, data: StmtConversionContext): ExpEmbedding {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.jetbrains.kotlin.formver.embeddings.callables

import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.formver.asPosition
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding
Expand Down Expand Up @@ -44,18 +45,22 @@ interface FullNamedFunctionSignature : NamedFunctionSignature {
*/
abstract class PropertyAccessorFunctionSignature(
override val name: MangledName,
symbol: FirPropertySymbol,
propertySymbol: FirPropertySymbol,
) : FullNamedFunctionSignature, GenericFunctionSignatureMixin() {
override fun getPreconditions(returnVariable: VariableEmbedding) = emptyList<ExpEmbedding>()
override fun getPostconditions(returnVariable: VariableEmbedding) = emptyList<ExpEmbedding>()
override val dispatchReceiver: VariableEmbedding
get() = PlaceholderVariableEmbedding(DispatchReceiverName, buildType { nullableAny() })
override val extensionReceiver = null
override val declarationSource: KtSourceElement? = symbol.source
override val declarationSource: KtSourceElement? = propertySymbol.source
}

class GetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :
PropertyAccessorFunctionSignature(name, symbol) {
override val symbol: FirFunctionSymbol<*>
get() = error {
"Getter symbol should not be accessed directly as it is allowed to be null in some cases."
}
override val callableType: FunctionTypeEmbedding = buildFunctionPretype {
withDispatchReceiver { nullableAny() }
withReturnType { nullableAny() }
Expand All @@ -64,6 +69,10 @@ class GetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :

class SetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :
PropertyAccessorFunctionSignature(name, symbol) {
override val symbol: FirFunctionSymbol<*>
get() = error {
"Setter symbol should not be accessed directly as it is allowed to be null in some cases."
}
override val callableType: FunctionTypeEmbedding = buildFunctionPretype {
withDispatchReceiver { nullableAny() }
withParam { nullableAny() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ interface FunctionSignature {

val params: List<VariableEmbedding>

val sourceName: String?
val labelName: String?
get() = null

val formalArgs: List<VariableEmbedding>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import org.jetbrains.kotlin.formver.names.ExtraSpecialNames

class InlineNamedFunction(
val signature: FullNamedFunctionSignature,
val symbol: FirFunctionSymbol<*>,
val firBody: FirBlock,
) : RichCallableEmbedding, FullNamedFunctionSignature by signature {
override fun insertCall(
Expand All @@ -28,7 +27,7 @@ class InlineNamedFunction(
add(ExtraSpecialNames.EXTENSION_THIS)
addAll(symbol.valueParameterSymbols.map { it.name })
}
return ctx.insertInlineFunctionCall(signature, paramNames, args, firBody, signature.sourceName)
return ctx.insertInlineFunctionCall(signature, paramNames, args, firBody, signature.labelName)
}

override fun toViperMethodHeader(): Nothing? = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@

package org.jetbrains.kotlin.formver.embeddings.callables

import org.jetbrains.kotlin.formver.names.FunctionKotlinName
import org.jetbrains.kotlin.formver.names.ScopedKotlinName
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.formver.viper.MangledName
import org.jetbrains.kotlin.formver.viper.ast.*

interface NamedFunctionSignature : FunctionSignature {
val name: MangledName
// TODO: Clean this up; if we want a source name, we should be storing a symbol.
override val sourceName: String?
get() = when (val signatureName = name) {
is FunctionKotlinName -> signatureName.name.asString()
is ScopedKotlinName -> (signatureName.name as? FunctionKotlinName)?.name?.asString()
else -> null
}

val symbol: FirFunctionSymbol<*>

override val labelName: String
get() = symbol.name.asString()
}

fun NamedFunctionSignature.toMethodCall(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class LambdaExp(
val signature: FunctionSignature,
val function: FirAnonymousFunction,
private val parentCtx: MethodConversionContext,
override val labelName: String,
) : CallableEmbedding, StoredResultExpEmbedding,
FunctionSignature by signature {
override val type: TypeEmbedding
Expand All @@ -44,7 +45,7 @@ class LambdaExp(
receiverParamNames + nonReceiverParamNames,
args,
inlineBody,
ctx.signature.sourceName,
labelName,
parentCtx,
)
}
Expand Down
Loading
Loading