From fa1efd11307f6d21903cf6a2064de44c249d2272 Mon Sep 17 00:00:00 2001 From: Denis Voytyuk <5462781+denisvmedia@users.noreply.github.com> Date: Sat, 30 Nov 2024 10:32:30 +0100 Subject: [PATCH] chore: Improve sortables detection --- lint/package.go | 69 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/lint/package.go b/lint/package.go index 873f8a002..e212205ee 100644 --- a/lint/package.go +++ b/lint/package.go @@ -1,6 +1,7 @@ package lint import ( + "errors" "go/ast" "go/importer" "go/token" @@ -31,7 +32,6 @@ type Package struct { var ( trueValue = 1 falseValue = 2 - notSet = 3 go121 = goversion.Must(goversion.NewVersion("1.21")) go122 = goversion.Must(goversion.NewVersion("1.22")) @@ -111,6 +111,11 @@ func (p *Package) TypeCheck() error { astFiles = append(astFiles, f.AST) } + if anyFile == nil { + // this is unlikely to happen, but technically guarantees anyFile to not be nil + return errors.New("no ast.File found") + } + typesPkg, err := check(config, anyFile.AST.Name.Name, p.fset, astFiles, info) // Remember the typechecking info, even if config.Check failed, @@ -135,7 +140,7 @@ func check(config *types.Config, n string, fset *token.FileSet, astFiles []*ast. return config.Check(n, fset, astFiles, info) } -// TypeOf returns the type of an expression. +// TypeOf returns the type of expression. func (p *Package) TypeOf(expr ast.Expr) types.Type { if p.typesInfo == nil { return nil @@ -148,32 +153,72 @@ type walker struct { has map[string]int } +// bitfield for which methods exist on each type. +const ( + bfLen = 1 << iota + bfLess + bfSwap +) + func (w *walker) Visit(n ast.Node) ast.Visitor { fn, ok := n.(*ast.FuncDecl) if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 { return w } - // TODO(dsymonds): We could check the signature to be more precise. + recv := typeparams.ReceiverType(fn) - if i, ok := w.nmap[fn.Name.Name]; ok { - w.has[recv] |= i + + // Ensure the method signature matches expectations. + switch fn.Name.Name { + case "Len": + if fn.Type.Params.NumFields() == 0 && fn.Type.Results.NumFields() == 1 { + resultType := fn.Type.Results.List[0].Type + if _, ok := resultType.(*ast.Ident); ok && resultType.(*ast.Ident).Name == "int" { + w.has[recv] |= bfLen + } + } + case "Less": + if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 1 { + param1 := fn.Type.Params.List[0].Type + var param2 ast.Expr + if len(fn.Type.Params.List) == 2 { + param2 = fn.Type.Params.List[1].Type + } else { + param2 = param1 + } + resultType := fn.Type.Results.List[0].Type + + // Ensure parameters have the same type and the result is a bool. + if typesEqual(param1, param2) && isBool(resultType) { + w.has[recv] |= bfLess + } + } + case "Swap": + if fn.Type.Params.NumFields() == 2 && fn.Type.Results.NumFields() == 0 { + w.has[recv] |= bfSwap + } } return w } +func typesEqual(a, b ast.Expr) bool { + identA, okA := a.(*ast.Ident) + identB, okB := b.(*ast.Ident) + return okA && okB && identA.Name == identB.Name +} + +func isBool(t ast.Expr) bool { + ident, ok := t.(*ast.Ident) + return ok && ident.Name == "bool" +} + func (p *Package) scanSortable() { p.sortable = map[string]bool{} - // bitfield for which methods exist on each type. - const ( - bfLen = 1 << iota - bfLess - bfSwap - ) nmap := map[string]int{"Len": bfLen, "Less": bfLess, "Swap": bfSwap} has := map[string]int{} for _, f := range p.files { - ast.Walk(&walker{nmap, has}, f.AST) + ast.Walk(&walker{nmap: nmap, has: has}, f.AST) } for typ, ms := range has { if ms == bfLen|bfLess|bfSwap {