diff --git a/internal/typeparams/common.go b/internal/typeparams/common.go index 1222764b6a3..53b0696a1d8 100644 --- a/internal/typeparams/common.go +++ b/internal/typeparams/common.go @@ -77,3 +77,35 @@ func IsTypeParam(t types.Type) bool { _, ok := t.(*TypeParam) return ok } + +// OriginMethod returns the origin method associated with the method fn. +// For methods on a non-generic receiver base type, this is just +// fn. However, for methods with a generic receiver, OriginMethod returns the +// corresponding method in the method set of the origin type. +// +// As a special case, if fn is not a method (has no receiver), OriginMethod +// returns fn. +func OriginMethod(fn *types.Func) *types.Func { + recv := fn.Type().(*types.Signature).Recv() + if recv == nil { + + return fn + } + base := recv.Type() + p, isPtr := base.(*types.Pointer) + if isPtr { + base = p.Elem() + } + named, isNamed := base.(*types.Named) + if !isNamed { + // Receiver is a *types.Interface. + return fn + } + if ForNamed(named).Len() == 0 { + // Receiver base has no type parameters, so we can avoid the lookup below. + return fn + } + orig := NamedTypeOrigin(named) + gfn, _, _ := types.LookupFieldOrMethod(orig, true, fn.Pkg(), fn.Name()) + return gfn.(*types.Func) +} diff --git a/internal/typeparams/common_test.go b/internal/typeparams/common_test.go index 1bd15d794bc..da084d173f4 100644 --- a/internal/typeparams/common_test.go +++ b/internal/typeparams/common_test.go @@ -6,16 +6,20 @@ package typeparams_test import ( "go/ast" + "go/parser" + "go/token" + "go/types" "testing" - "golang.org/x/tools/internal/typeparams" + "golang.org/x/tools/internal/testenv" + . "golang.org/x/tools/internal/typeparams" ) func TestGetIndexExprData(t *testing.T) { x := &ast.Ident{} i := &ast.Ident{} - want := &typeparams.IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2} + want := &IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2} tests := map[ast.Node]bool{ &ast.IndexExpr{X: x, Lbrack: 1, Index: i, Rbrack: 2}: true, want: true, @@ -23,7 +27,7 @@ func TestGetIndexExprData(t *testing.T) { } for n, isIndexExpr := range tests { - X, lbrack, indices, rbrack := typeparams.UnpackIndexExpr(n) + X, lbrack, indices, rbrack := UnpackIndexExpr(n) if got := X != nil; got != isIndexExpr { t.Errorf("UnpackIndexExpr(%v) = %v, _, _, _; want nil: %t", n, x, !isIndexExpr) } @@ -35,3 +39,121 @@ func TestGetIndexExprData(t *testing.T) { } } } + +func TestOriginMethodRecursive(t *testing.T) { + testenv.NeedsGo1Point(t, 18) + src := `package p + +type N[A any] int + +func (r N[B]) m() { r.m(); r.n() } + +func (r *N[C]) n() { } +` + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", src, 0) + if err != nil { + t.Fatal(err) + } + info := types.Info{ + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + var conf types.Config + if _, err := conf.Check("p", fset, []*ast.File{f}, &info); err != nil { + t.Fatal(err) + } + + // Collect objects from types.Info. + var m, n *types.Func // the 'origin' methods in Info.Defs + var mm, mn *types.Func // the methods used in the body of m + + for _, decl := range f.Decls { + fdecl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + def := info.Defs[fdecl.Name].(*types.Func) + switch fdecl.Name.Name { + case "m": + m = def + ast.Inspect(fdecl.Body, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + sel := call.Fun.(*ast.SelectorExpr) + use := info.Uses[sel.Sel].(*types.Func) + switch sel.Sel.Name { + case "m": + mm = use + case "n": + mn = use + } + } + return true + }) + case "n": + n = def + } + } + + tests := []struct { + name string + input, want *types.Func + }{ + {"declared m", m, m}, + {"declared n", n, n}, + {"used m", mm, m}, + {"used n", mn, n}, + } + + for _, test := range tests { + if got := OriginMethod(test.input); got != test.want { + t.Errorf("OriginMethod(%q) = %v, want %v", test.name, test.input, test.want) + } + } +} + +func TestOriginMethodUses(t *testing.T) { + testenv.NeedsGo1Point(t, 18) + + tests := []string{ + `type T interface { m() }; func _(t T) { t.m() }`, + `type T[P any] interface { m() P }; func _[A any](t T[A]) { t.m() }`, + `type T[P any] interface { m() P }; func _(t T[int]) { t.m() }`, + `type T[P any] int; func (r T[A]) m() { r.m() }`, + `type T[P any] int; func (r *T[A]) m() { r.m() }`, + `type T[P any] int; func (r *T[A]) m() {}; func _(t T[int]) { t.m() }`, + `type T[P any] int; func (r *T[A]) m() {}; func _[A any](t T[A]) { t.m() }`, + } + + for _, src := range tests { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", "package p; "+src, 0) + if err != nil { + t.Fatal(err) + } + info := types.Info{ + Uses: make(map[*ast.Ident]types.Object), + } + var conf types.Config + pkg, err := conf.Check("p", fset, []*ast.File{f}, &info) + if err != nil { + t.Fatal(err) + } + + T := pkg.Scope().Lookup("T").Type() + obj, _, _ := types.LookupFieldOrMethod(T, true, pkg, "m") + m := obj.(*types.Func) + + ast.Inspect(f, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + sel := call.Fun.(*ast.SelectorExpr) + use := info.Uses[sel.Sel].(*types.Func) + orig := OriginMethod(use) + if orig != m { + t.Errorf("%s:\nUses[%v] = %v, want %v", src, types.ExprString(sel), use, m) + } + } + return true + }) + } +}