Skip to content

Commit

Permalink
Add workaround for generated syntax with missing type info
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Dec 14, 2023
1 parent 5480992 commit ab465f8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 6 deletions.
14 changes: 10 additions & 4 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
return v
}

signature := copyFunctionType(functionTypeOf(fn))
signature.TypeParams = nil

// The function syntax may be generic, requiring translation of type param
// placeholders to known type args.
var typeArg func(*types.TypeParam) types.Type
if g != nil {
typeArg = g.typeArgOf
}

signature := copyFunctionType(functionTypeOf(fn))
signature.TypeParams = nil

recv := copyFieldList(functionRecvOf(fn))
for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} {
if fields != nil {
Expand Down Expand Up @@ -126,7 +128,11 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
for _, name := range s.Names {
typ := p.TypesInfo.TypeOf(name)
if typ == nil {
scope.insert(name, s.Type)
// FIXME: this means that TypesInfo was not updated when syntax was
// generated or mutated. The following workaround is required as a
// result.
e := substituteTypeArgs(p, s.Type, typeArg)
scope.insert(name, e)
} else {
scope.insert(name, typeExpr(p, typ, typeArg))
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3577,8 +3577,8 @@ func init() {
D uintptr
X0 *struct {
IP int
X0 *IdentityGenericStruct[T]
X1 T
X0 *IdentityGenericStruct[int]
X1 int
}
}]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Closure.func2")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Run")
Expand Down
77 changes: 77 additions & 0 deletions compiler/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,80 @@ func newFields(p *packages.Package, tuple *types.Tuple, typeArg func(*types.Type
}
return fields
}

// substituteTypeArgs replaces all type parameter placeholders
// with type args.
//
// It returns a deep copy of the input expr.
func substituteTypeArgs(p *packages.Package, expr ast.Expr, typeArg func(*types.TypeParam) types.Type) ast.Expr {
if expr == nil {
return nil
}
switch e := expr.(type) {
case *ast.ArrayType:
return &ast.ArrayType{
Elt: substituteTypeArgs(p, e.Elt, typeArg),
Len: substituteTypeArgs(p, e.Len, typeArg),
}
case *ast.MapType:
return &ast.MapType{
Key: substituteTypeArgs(p, e.Key, typeArg),
Value: substituteTypeArgs(p, e.Value, typeArg),
}
case *ast.FuncType:
return &ast.FuncType{
TypeParams: substituteFieldList(p, e.TypeParams, typeArg),
Params: substituteFieldList(p, e.Params, typeArg),
Results: substituteFieldList(p, e.Results, typeArg),
}
case *ast.ChanType:
return &ast.ChanType{
Dir: e.Dir,
Value: substituteTypeArgs(p, e.Value, typeArg),
}
case *ast.StructType:
return &ast.StructType{
Fields: substituteFieldList(p, e.Fields, typeArg),
}
case *ast.StarExpr:
return &ast.StarExpr{
X: substituteTypeArgs(p, e.X, typeArg),
}
case *ast.SelectorExpr:
return &ast.SelectorExpr{
X: substituteTypeArgs(p, e.X, typeArg),
Sel: e.Sel,
}
case *ast.IndexExpr:
return &ast.IndexExpr{
X: substituteTypeArgs(p, e.X, typeArg),
Index: substituteTypeArgs(p, e.Index, typeArg),
}
case *ast.Ident:
t := p.TypesInfo.TypeOf(e)
tp, ok := t.(*types.TypeParam)
if !ok {
return e
}
return typeExpr(p, typeArg(tp), typeArg)
case *ast.BasicLit:
return e
default:
panic(fmt.Sprintf("not implemented: %T", e))
}
}

func substituteFieldList(p *packages.Package, f *ast.FieldList, typeArg func(*types.TypeParam) types.Type) *ast.FieldList {
if f == nil || f.List == nil {
return f
}
fields := make([]*ast.Field, len(f.List))
for i, field := range f.List {
fields[i] = &ast.Field{
Names: field.Names,
Type: substituteTypeArgs(p, field.Type, typeArg),
Tag: field.Tag,
}
}
return &ast.FieldList{List: fields}
}

0 comments on commit ab465f8

Please sign in to comment.