Skip to content

Commit

Permalink
chore: Improve sortables detection
Browse files Browse the repository at this point in the history
  • Loading branch information
denisvmedia committed Nov 30, 2024
1 parent 3bc2404 commit fa1efd1
Showing 1 changed file with 57 additions and 12 deletions.
69 changes: 57 additions & 12 deletions lint/package.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lint

import (
"errors"
"go/ast"
"go/importer"
"go/token"
Expand Down Expand Up @@ -31,7 +32,6 @@ type Package struct {
var (
trueValue = 1
falseValue = 2
notSet = 3

go121 = goversion.Must(goversion.NewVersion("1.21"))
go122 = goversion.Must(goversion.NewVersion("1.22"))
Expand Down Expand Up @@ -111,6 +111,11 @@ func (p *Package) TypeCheck() error {
astFiles = append(astFiles, f.AST)
}

if anyFile == nil {
// this is unlikely to happen, but technically guarantees anyFile to not be nil
return errors.New("no ast.File found")
}

typesPkg, err := check(config, anyFile.AST.Name.Name, p.fset, astFiles, info)

// Remember the typechecking info, even if config.Check failed,
Expand All @@ -135,7 +140,7 @@ func check(config *types.Config, n string, fset *token.FileSet, astFiles []*ast.
return config.Check(n, fset, astFiles, info)
}

// TypeOf returns the type of an expression.
// TypeOf returns the type of expression.
func (p *Package) TypeOf(expr ast.Expr) types.Type {
if p.typesInfo == nil {
return nil
Expand All @@ -148,32 +153,72 @@ type walker struct {
has map[string]int
}

// bitfield for which methods exist on each type.
const (
bfLen = 1 << iota
bfLess
bfSwap
)

func (w *walker) Visit(n ast.Node) ast.Visitor {
fn, ok := n.(*ast.FuncDecl)
if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 {
return w
}
// TODO(dsymonds): We could check the signature to be more precise.

recv := typeparams.ReceiverType(fn)
if i, ok := w.nmap[fn.Name.Name]; ok {
w.has[recv] |= i

// Ensure the method signature matches expectations.
switch fn.Name.Name {
case "Len":
if fn.Type.Params.NumFields() == 0 && fn.Type.Results.NumFields() == 1 {
resultType := fn.Type.Results.List[0].Type
if _, ok := resultType.(*ast.Ident); ok && resultType.(*ast.Ident).Name == "int" {
w.has[recv] |= bfLen
}
}
case "Less":
if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 1 {
param1 := fn.Type.Params.List[0].Type
var param2 ast.Expr
if len(fn.Type.Params.List) == 2 {
param2 = fn.Type.Params.List[1].Type
} else {
param2 = param1
}
resultType := fn.Type.Results.List[0].Type

// Ensure parameters have the same type and the result is a bool.
if typesEqual(param1, param2) && isBool(resultType) {
w.has[recv] |= bfLess
}
}
case "Swap":
if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 0 {
w.has[recv] |= bfSwap
}
}
return w
}

func typesEqual(a, b ast.Expr) bool {
identA, okA := a.(*ast.Ident)
identB, okB := b.(*ast.Ident)
return okA && okB && identA.Name == identB.Name
}

func isBool(t ast.Expr) bool {
ident, ok := t.(*ast.Ident)
return ok && ident.Name == "bool"
}

func (p *Package) scanSortable() {
p.sortable = map[string]bool{}

// bitfield for which methods exist on each type.
const (
bfLen = 1 << iota
bfLess
bfSwap
)
nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap}
has := map[string]int{}
for _, f := range p.files {
ast.Walk(&walker{nmap, has}, f.AST)
ast.Walk(&walker{nmap: nmap, has: has}, f.AST)
}
for typ, ms := range has {
if ms == bfLen|bfLess|bfSwap {
Expand Down

0 comments on commit fa1efd1

Please sign in to comment.