Skip to content

Commit

Permalink
internal/typeparams: add a helper to return the origin method
Browse files Browse the repository at this point in the history
With instantiated types, method objects are no longer unique: they may
be instantiations of methods with generic receiver. However, some
use-cases require finding the canonical method representing the method
in the source. For these use-cases, provide an OriginMethod helper.

For golang/go#50447

Change-Id: I6f8af3fb5c5eeefb11f8f3bdba54cd6692ca389f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/380554
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
gopls-CI: kokoro <noreply+kokoro@google.com>
Reviewed-by: Tim King <taking@google.com>
  • Loading branch information
findleyr committed Feb 3, 2022
1 parent a739c97 commit ea5e1dc
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 3 deletions.
32 changes: 32 additions & 0 deletions internal/typeparams/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,35 @@ func IsTypeParam(t types.Type) bool {
_, ok := t.(*TypeParam)
return ok
}

// OriginMethod returns the origin method associated with the method fn.
// For methods on a non-generic receiver base type, this is just
// fn. However, for methods with a generic receiver, OriginMethod returns the
// corresponding method in the method set of the origin type.
//
// As a special case, if fn is not a method (has no receiver), OriginMethod
// returns fn.
func OriginMethod(fn *types.Func) *types.Func {
recv := fn.Type().(*types.Signature).Recv()
if recv == nil {

return fn
}
base := recv.Type()
p, isPtr := base.(*types.Pointer)
if isPtr {
base = p.Elem()
}
named, isNamed := base.(*types.Named)
if !isNamed {
// Receiver is a *types.Interface.
return fn
}
if ForNamed(named).Len() == 0 {
// Receiver base has no type parameters, so we can avoid the lookup below.
return fn
}
orig := NamedTypeOrigin(named)
gfn, _, _ := types.LookupFieldOrMethod(orig, true, fn.Pkg(), fn.Name())
return gfn.(*types.Func)
}
128 changes: 125 additions & 3 deletions internal/typeparams/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,28 @@ package typeparams_test

import (
"go/ast"
"go/parser"
"go/token"
"go/types"
"testing"

"golang.org/x/tools/internal/typeparams"
"golang.org/x/tools/internal/testenv"
. "golang.org/x/tools/internal/typeparams"
)

func TestGetIndexExprData(t *testing.T) {
x := &ast.Ident{}
i := &ast.Ident{}

want := &typeparams.IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2}
want := &IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2}
tests := map[ast.Node]bool{
&ast.IndexExpr{X: x, Lbrack: 1, Index: i, Rbrack: 2}: true,
want: true,
&ast.Ident{}: false,
}

for n, isIndexExpr := range tests {
X, lbrack, indices, rbrack := typeparams.UnpackIndexExpr(n)
X, lbrack, indices, rbrack := UnpackIndexExpr(n)
if got := X != nil; got != isIndexExpr {
t.Errorf("UnpackIndexExpr(%v) = %v, _, _, _; want nil: %t", n, x, !isIndexExpr)
}
Expand All @@ -35,3 +39,121 @@ func TestGetIndexExprData(t *testing.T) {
}
}
}

func TestOriginMethodRecursive(t *testing.T) {
testenv.NeedsGo1Point(t, 18)
src := `package p
type N[A any] int
func (r N[B]) m() { r.m(); r.n() }
func (r *N[C]) n() { }
`
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "p.go", src, 0)
if err != nil {
t.Fatal(err)
}
info := types.Info{
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
}
var conf types.Config
if _, err := conf.Check("p", fset, []*ast.File{f}, &info); err != nil {
t.Fatal(err)
}

// Collect objects from types.Info.
var m, n *types.Func // the 'origin' methods in Info.Defs
var mm, mn *types.Func // the methods used in the body of m

for _, decl := range f.Decls {
fdecl, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
def := info.Defs[fdecl.Name].(*types.Func)
switch fdecl.Name.Name {
case "m":
m = def
ast.Inspect(fdecl.Body, func(n ast.Node) bool {
if call, ok := n.(*ast.CallExpr); ok {
sel := call.Fun.(*ast.SelectorExpr)
use := info.Uses[sel.Sel].(*types.Func)
switch sel.Sel.Name {
case "m":
mm = use
case "n":
mn = use
}
}
return true
})
case "n":
n = def
}
}

tests := []struct {
name string
input, want *types.Func
}{
{"declared m", m, m},
{"declared n", n, n},
{"used m", mm, m},
{"used n", mn, n},
}

for _, test := range tests {
if got := OriginMethod(test.input); got != test.want {
t.Errorf("OriginMethod(%q) = %v, want %v", test.name, test.input, test.want)
}
}
}

func TestOriginMethodUses(t *testing.T) {
testenv.NeedsGo1Point(t, 18)

tests := []string{
`type T interface { m() }; func _(t T) { t.m() }`,
`type T[P any] interface { m() P }; func _[A any](t T[A]) { t.m() }`,
`type T[P any] interface { m() P }; func _(t T[int]) { t.m() }`,
`type T[P any] int; func (r T[A]) m() { r.m() }`,
`type T[P any] int; func (r *T[A]) m() { r.m() }`,
`type T[P any] int; func (r *T[A]) m() {}; func _(t T[int]) { t.m() }`,
`type T[P any] int; func (r *T[A]) m() {}; func _[A any](t T[A]) { t.m() }`,
}

for _, src := range tests {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "p.go", "package p; "+src, 0)
if err != nil {
t.Fatal(err)
}
info := types.Info{
Uses: make(map[*ast.Ident]types.Object),
}
var conf types.Config
pkg, err := conf.Check("p", fset, []*ast.File{f}, &info)
if err != nil {
t.Fatal(err)
}

T := pkg.Scope().Lookup("T").Type()
obj, _, _ := types.LookupFieldOrMethod(T, true, pkg, "m")
m := obj.(*types.Func)

ast.Inspect(f, func(n ast.Node) bool {
if call, ok := n.(*ast.CallExpr); ok {
sel := call.Fun.(*ast.SelectorExpr)
use := info.Uses[sel.Sel].(*types.Func)
orig := OriginMethod(use)
if orig != m {
t.Errorf("%s:\nUses[%v] = %v, want %v", src, types.ExprString(sel), use, m)
}
}
return true
})
}
}

0 comments on commit ea5e1dc

Please sign in to comment.