Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Improve sortables detection #1151

Merged
merged 5 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions internal/astutils/ast_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package astutils

import "go/ast"

// FuncSignatureIs returns true if the given func decl satisfies a signature characterized
// by the given name, parameters types and return types; false otherwise.
//
// Example: to check if a function declaration has the signature Foo(int, string) (bool,error)
// call to FuncSignatureIs(funcDecl,"Foo",[]string{"int","string"},[]string{"bool","error"})
func FuncSignatureIs(funcDecl *ast.FuncDecl, wantName string, wantParametersTypes, wantResultsTypes []string) bool {
if wantName != funcDecl.Name.String() {
return false // func name doesn't match expected one
}

funcParametersTypes := getTypeNames(funcDecl.Type.Params)
if len(wantParametersTypes) != len(funcParametersTypes) {
return false // func has not the expected number of parameters
}

funcResultsTypes := getTypeNames(funcDecl.Type.Results)
if len(wantResultsTypes) != len(funcResultsTypes) {
return false // func has not the expected number of return values
}

for i, wantType := range wantParametersTypes {
if wantType != funcParametersTypes[i] {
return false // type of a func's parameter does not match the type of the corresponding expected parameter
}
}

for i, wantType := range wantResultsTypes {
if wantType != funcResultsTypes[i] {
return false // type of a func's return value does not match the type of the corresponding expected return value
}
}

return true
}

func getTypeNames(fields *ast.FieldList) []string {
result := []string{}

if fields == nil {
return result
}

for _, field := range fields.List {
typeName := field.Type.(*ast.Ident).Name
if field.Names == nil { // unnamed field
result = append(result, typeName)
continue
}

for range field.Names { // add one type name for each field name
result = append(result, typeName)
}
}

return result
}
72 changes: 42 additions & 30 deletions lint/package.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lint

import (
"errors"
"go/ast"
"go/importer"
"go/token"
Expand All @@ -9,6 +10,7 @@ import (

goversion "github.com/hashicorp/go-version"

"github.com/mgechev/revive/internal/astutils"
"github.com/mgechev/revive/internal/typeparams"
)

Expand All @@ -31,7 +33,6 @@ type Package struct {
var (
trueValue = 1
falseValue = 2
notSet = 3

go121 = goversion.Must(goversion.NewVersion("1.21"))
go122 = goversion.Must(goversion.NewVersion("1.22"))
Expand Down Expand Up @@ -111,6 +112,11 @@ func (p *Package) TypeCheck() error {
astFiles = append(astFiles, f.AST)
}

if anyFile == nil {
// this is unlikely to happen, but technically guarantees anyFile to not be nil
return errors.New("no ast.File found")
}

typesPkg, err := check(config, anyFile.AST.Name.Name, p.fset, astFiles, info)

// Remember the typechecking info, even if config.Check failed,
Expand All @@ -135,47 +141,40 @@ func check(config *types.Config, n string, fset *token.FileSet, astFiles []*ast.
return config.Check(n, fset, astFiles, info)
}

// TypeOf returns the type of an expression.
// TypeOf returns the type of expression.
func (p *Package) TypeOf(expr ast.Expr) types.Type {
if p.typesInfo == nil {
return nil
}
return p.typesInfo.TypeOf(expr)
}

type walker struct {
nmap map[string]int
has map[string]int
}
type sortableMethodsFlags int

func (w *walker) Visit(n ast.Node) ast.Visitor {
fn, ok := n.(*ast.FuncDecl)
if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 {
return w
}
// TODO(dsymonds): We could check the signature to be more precise.
recv := typeparams.ReceiverType(fn)
if i, ok := w.nmap[fn.Name.Name]; ok {
w.has[recv] |= i
}
return w
}
// flags for sortable interface methods.
const (
bfLen sortableMethodsFlags = 1 << iota
bfLess
bfSwap
)

func (p *Package) scanSortable() {
p.sortable = map[string]bool{}

// bitfield for which methods exist on each type.
const (
bfLen = 1 << iota
bfLess
bfSwap
)
nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap}
has := map[string]int{}
sortableFlags := map[string]sortableMethodsFlags{}
for _, f := range p.files {
ast.Walk(&walker{nmap, has}, f.AST)
for _, decl := range f.AST.Decls {
fn, ok := decl.(*ast.FuncDecl)
isAMethodDeclaration := ok && fn.Recv != nil && len(fn.Recv.List) != 0
if !isAMethodDeclaration {
continue
}

recvType := typeparams.ReceiverType(fn)
sortableFlags[recvType] |= getSortableMethodFlagForFunction(fn)
}
}
for typ, ms := range has {

p.sortable = make(map[string]bool, len(sortableFlags))
for typ, ms := range sortableFlags {
if ms == bfLen|bfLess|bfSwap {
p.sortable[typ] = true
}
Expand Down Expand Up @@ -204,3 +203,16 @@ func (p *Package) IsAtLeastGo121() bool {
func (p *Package) IsAtLeastGo122() bool {
return p.goVersion.GreaterThanOrEqual(go122)
}

func getSortableMethodFlagForFunction(fn *ast.FuncDecl) sortableMethodsFlags {
switch {
case astutils.FuncSignatureIs(fn, "Len", []string{}, []string{"int"}):
return bfLen
case astutils.FuncSignatureIs(fn, "Less", []string{"int", "int"}, []string{"bool"}):
return bfLess
case astutils.FuncSignatureIs(fn, "Swap", []string{"int", "int"}, []string{}):
return bfSwap
default:
return 0
}
}
18 changes: 18 additions & 0 deletions testdata/golint/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,21 @@ func (u U) Less(i, j int) bool { return u[i] < u[j] }
func (u U) Swap(i, j int) { u[i], u[j] = u[j], u[i] }

func (u U) Other() {} // MATCH /exported method U.Other should have comment or be unexported/

// V is ...
type V []int

func (v V) Len() (result int) { return len(w) }
func (v V) Less(i int, j int) (result bool) { return w[i] < w[j] }
func (v V) Swap(i int, j int) { v[i], v[j] = v[j], v[i] }

// W is ...
type W []int

func (w W) Swap(i int, j int) {} // MATCH /exported method W.Swap should have comment or be unexported/

// Vv is ...
type Vv []int

func (vv Vv) Len() (result int) { return len(w) } // MATCH /exported method Vv.Len should have comment or be unexported/
func (vv Vv) Less(i int, j int) (result bool) { return w[i] < w[j] } // MATCH /exported method Vv.Less should have comment or be unexported/