Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic function declarations #2463

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 5 additions & 52 deletions runtime/parser/declaration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1246,13 +1246,7 @@ func TestParseFunctionDeclaration(t *testing.T) {

t.Parallel()

result, errs := ParseDeclarations(
nil,
[]byte("fun foo < > () {}"),
Config{
TypeParametersEnabled: true,
},
)
result, errs := testParseDeclarations("fun foo < > () {}")
require.Empty(t, errs)

utils.AssertEqualWithDiff(t,
Expand Down Expand Up @@ -1295,13 +1289,7 @@ func TestParseFunctionDeclaration(t *testing.T) {

t.Parallel()

result, errs := ParseDeclarations(
nil,
[]byte("fun foo < A > () {}"),
Config{
TypeParametersEnabled: true,
},
)
result, errs := testParseDeclarations("fun foo < A > () {}")
require.Empty(t, errs)

utils.AssertEqualWithDiff(t,
Expand Down Expand Up @@ -1351,13 +1339,7 @@ func TestParseFunctionDeclaration(t *testing.T) {

t.Parallel()

result, errs := ParseDeclarations(
nil,
[]byte("fun foo < A , B : C > () {}"),
Config{
TypeParametersEnabled: true,
},
)
result, errs := testParseDeclarations("fun foo < A , B : C > () {}")
require.Empty(t, errs)

utils.AssertEqualWithDiff(t,
Expand Down Expand Up @@ -1418,34 +1400,11 @@ func TestParseFunctionDeclaration(t *testing.T) {
)
})

t.Run("with type parameters, disabled", func(t *testing.T) {

t.Parallel()

_, errs := testParseDeclarations("fun foo<A>() {}")

utils.AssertEqualWithDiff(t,
[]error{
&SyntaxError{
Message: "expected '(' as start of parameter list, got '<'",
Pos: ast.Position{Offset: 7, Line: 1, Column: 7},
},
},
errs,
)
})

t.Run("missing type parameter list end, enabled", func(t *testing.T) {

t.Parallel()

_, errs := ParseDeclarations(
nil,
[]byte("fun foo < "),
Config{
TypeParametersEnabled: true,
},
)
_, errs := testParseDeclarations("fun foo < ")

utils.AssertEqualWithDiff(t,
[]error{
Expand All @@ -1462,13 +1421,7 @@ func TestParseFunctionDeclaration(t *testing.T) {

t.Parallel()

_, errs := ParseDeclarations(
nil,
[]byte("fun foo < A B > () { } "),
Config{
TypeParametersEnabled: true,
},
)
_, errs := testParseDeclarations("fun foo < A B > () { } ")

utils.AssertEqualWithDiff(t,
[]error{
Expand Down
10 changes: 4 additions & 6 deletions runtime/parser/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,10 @@ func parseFunctionDeclaration(

var typeParameterList *ast.TypeParameterList

if p.config.TypeParametersEnabled {
var err error
typeParameterList, err = parseTypeParameterList(p)
if err != nil {
return nil, err
}
var err error
typeParameterList, err = parseTypeParameterList(p)
if err != nil {
return nil, err
}

parameterList, returnTypeAnnotation, functionBlock, err :=
Expand Down
2 changes: 0 additions & 2 deletions runtime/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ type Config struct {
StaticModifierEnabled bool
// NativeModifierEnabled determines if the native modifier is enabled
NativeModifierEnabled bool
// TypeParametersEnabled determines if type parameters are enabled
TypeParametersEnabled bool
}

type parser struct {
Expand Down
10 changes: 4 additions & 6 deletions runtime/parser/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,10 @@ func parseFunctionDeclarationOrFunctionExpressionStatement(p *parser) (ast.State

var typeParameterList *ast.TypeParameterList

if p.config.TypeParametersEnabled {
var err error
typeParameterList, err = parseTypeParameterList(p)
if err != nil {
return nil, err
}
var err error
typeParameterList, err = parseTypeParameterList(p)
if err != nil {
return nil, err
}

parameterList, returnTypeAnnotation, functionBlock, err :=
Expand Down
7 changes: 6 additions & 1 deletion runtime/sema/check_composite_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,11 @@ func (checker *Checker) defaultMembersAndOrigins(

identifier := function.Identifier.Identifier

functionType := checker.functionType(function.ParameterList, function.ReturnTypeAnnotation)
functionType := checker.functionType(
function.TypeParameterList,
function.ParameterList,
function.ReturnTypeAnnotation,
)

argumentLabels := function.ParameterList.EffectiveArgumentLabels()

Expand Down Expand Up @@ -2035,6 +2039,7 @@ func (checker *Checker) checkSpecialFunction(
}

checker.checkFunction(
nil,
specialFunction.FunctionDeclaration.ParameterList,
nil,
functionType,
Expand Down
66 changes: 62 additions & 4 deletions runtime/sema/check_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package sema
import (
"github.com/onflow/cadence/runtime/ast"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/errors"
)

func (checker *Checker) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration) (_ struct{}) {
Expand Down Expand Up @@ -80,7 +81,11 @@ func (checker *Checker) visitFunctionDeclaration(

functionType := checker.Elaboration.FunctionDeclarationFunctionType(declaration)
if functionType == nil {
functionType = checker.functionType(declaration.ParameterList, declaration.ReturnTypeAnnotation)
functionType = checker.functionType(
declaration.TypeParameterList,
declaration.ParameterList,
declaration.ReturnTypeAnnotation,
)

if options.declareFunction {
checker.declareFunctionDeclaration(declaration, functionType)
Expand All @@ -90,6 +95,7 @@ func (checker *Checker) visitFunctionDeclaration(
checker.Elaboration.SetFunctionDeclarationFunctionType(declaration, functionType)

checker.checkFunction(
declaration.TypeParameterList,
declaration.ParameterList,
declaration.ReturnTypeAnnotation,
functionType,
Expand Down Expand Up @@ -125,6 +131,7 @@ func (checker *Checker) declareFunctionDeclaration(
}

func (checker *Checker) checkFunction(
typeParameterList *ast.TypeParameterList,
parameterList *ast.ParameterList,
returnTypeAnnotation *ast.TypeAnnotation,
functionType *FunctionType,
Expand All @@ -133,6 +140,47 @@ func (checker *Checker) checkFunction(
initializationInfo *InitializationInfo,
checkResourceLoss bool,
) {
// If type parameters are given,
// resolve generic types in the function type
// to the type bounds of the type parameters.
//
// Type parameters must have type bounds,
// to at least determine resource-kindedness
// (A function cannot be written in a way that it supports
// either resources or non-resources.)

typeParameters := functionType.TypeParameters
if len(typeParameters) > 0 {

typeArguments := &TypeParameterTypeOrderedMap{}

for typeParameterIndex, typeParameter := range typeParameters {

typeBound := typeParameter.TypeBound
if typeBound == nil {
astTypeParameter := typeParameterList.TypeParameters[typeParameterIndex]

checker.report(&MissingTypeParameterTypeBoundError{
Name: typeParameter.Name,
Range: ast.NewUnmeteredRangeFromPositioned(astTypeParameter),
})
continue
}

typeArguments.Set(typeParameter, typeBound)
}

resolvedType := functionType.Resolve(typeArguments)

if resolvedType != nil {
var ok bool
functionType, ok = resolvedType.(*FunctionType)
if !ok {
panic(errors.NewUnreachableError())
}
}
}

// check argument labels
checker.checkArgumentLabels(parameterList)

Expand Down Expand Up @@ -414,14 +462,24 @@ func (checker *Checker) declareBefore() {

func (checker *Checker) VisitFunctionExpression(expression *ast.FunctionExpression) Type {

// TODO: add support in parser
var typeParameterList *ast.TypeParameterList
parameterList := expression.ParameterList
returnTypeAnnotation := expression.ReturnTypeAnnotation

// TODO: infer
functionType := checker.functionType(expression.ParameterList, expression.ReturnTypeAnnotation)
functionType := checker.functionType(
typeParameterList,
parameterList,
returnTypeAnnotation,
)

checker.Elaboration.SetFunctionExpressionFunctionType(expression, functionType)

checker.checkFunction(
expression.ParameterList,
expression.ReturnTypeAnnotation,
typeParameterList,
parameterList,
returnTypeAnnotation,
functionType,
expression.FunctionBlock,
true,
Expand Down
8 changes: 8 additions & 0 deletions runtime/sema/check_interface_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ func (checker *Checker) checkInterfaceFunctions(
}
}

if function.TypeParameterList != nil {
checker.report(
&InvalidTypeParameterizedInterfaceFunctionError{
Range: ast.NewUnmeteredRangeFromPositioned(function.TypeParameterList),
},
)
}

checker.visitFunctionDeclaration(
function,
functionDeclarationOptions{
Expand Down
2 changes: 2 additions & 0 deletions runtime/sema/check_transaction_declaration.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ func (checker *Checker) visitTransactionPrepareFunction(
prepareFunctionType := transactionType.PrepareFunctionType()

checker.checkFunction(
nil,
prepareFunction.FunctionDeclaration.ParameterList,
nil,
prepareFunctionType,
Expand Down Expand Up @@ -231,6 +232,7 @@ func (checker *Checker) visitTransactionExecuteFunction(
executeFunctionType := transactionType.ExecuteFunctionType()

checker.checkFunction(
nil,
&ast.ParameterList{},
nil,
executeFunctionType,
Expand Down
Loading