Skip to content

Commit

Permalink
feat: optional extension to early-return rule (#1133) (#1138)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdelah authored Nov 28, 2024
1 parent 777abc9 commit 7e1d35d
Show file tree
Hide file tree
Showing 14 changed files with 308 additions and 104 deletions.
3 changes: 2 additions & 1 deletion RULES_DESCRIPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,13 @@ if !cond {
_Configuration_: ([]string) rule flags. Available flags are:

* _preserveScope_: do not suggest refactorings that would increase variable scope
* _allowJump_: suggest a new jump (`return`, `continue` or `break` statement) if it could unnest multiple statements. By default, only relocation of _existing_ jumps (i.e. from the `else` clause) are suggested.

Example:

```toml
[rule.early-return]
arguments = ["preserveScope"]
arguments = ["preserveScope", "allowJump"]
```

## empty-block
Expand Down
9 changes: 8 additions & 1 deletion internal/ifelse/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@ package ifelse
// that would enlarge variable scope
const PreserveScope = "preserveScope"

// AllowJump is a configuration argument that permits early-return to
// suggest introducing a new jump (return, continue, etc) statement
// to reduce nesting. By default, suggestions only bring existing jumps
// earlier.
const AllowJump = "allowJump"

// Args contains arguments common to the early-return, indent-error-flow
// and superfluous-else rules (currently just preserveScope)
// and superfluous-else rules
type Args struct {
PreserveScope bool
AllowJump bool
}
42 changes: 32 additions & 10 deletions internal/ifelse/branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
// Branch contains information about a branch within an if-else chain.
type Branch struct {
BranchKind
Call // The function called at the end for kind Panic or Exit.
HasDecls bool // The branch has one or more declarations (at the top level block)
Call // The function called at the end for kind Panic or Exit.
block []ast.Stmt
}

// BlockBranch gets the Branch of an ast.BlockStmt.
Expand All @@ -21,7 +21,7 @@ func BlockBranch(block *ast.BlockStmt) Branch {
}

branch := StmtBranch(block.List[blockLen-1])
branch.HasDecls = hasDecls(block)
branch.block = block.List
return branch
}

Expand Down Expand Up @@ -61,25 +61,28 @@ func StmtBranch(stmt ast.Stmt) Branch {
// String returns a brief string representation
func (b Branch) String() string {
switch b.BranchKind {
case Empty:
return "{ }"
case Regular:
return "{ ... }"
case Panic, Exit:
return fmt.Sprintf("... %v()", b.Call)
default:
return b.BranchKind.String()
return fmt.Sprintf("{ ... %v() }", b.Call)
}
return fmt.Sprintf("{ ... %v }", b.BranchKind)
}

// LongString returns a longer form string representation
func (b Branch) LongString() string {
switch b.BranchKind {
case Panic, Exit:
return fmt.Sprintf("call to %v function", b.Call)
default:
return b.BranchKind.LongString()
}
return b.BranchKind.LongString()
}

func hasDecls(block *ast.BlockStmt) bool {
for _, stmt := range block.List {
// HasDecls returns whether the branch has any top-level declarations
func (b Branch) HasDecls() bool {
for _, stmt := range b.block {
switch stmt := stmt.(type) {
case *ast.DeclStmt:
return true
Expand All @@ -91,3 +94,22 @@ func hasDecls(block *ast.BlockStmt) bool {
}
return false
}

// IsShort returns whether the branch is empty or consists of a single statement
func (b Branch) IsShort() bool {
switch len(b.block) {
case 0:
return true
case 1:
return isShortStmt(b.block[0])
}
return false
}

func isShortStmt(stmt ast.Stmt) bool {
switch stmt.(type) {
case *ast.BlockStmt, *ast.IfStmt, *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.ForStmt, *ast.RangeStmt:
return false
}
return true
}
23 changes: 10 additions & 13 deletions internal/ifelse/branch_kind.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ func (k BranchKind) Deviates() bool {
return false
case Return, Continue, Break, Goto, Panic, Exit:
return true
default:
panic("invalid kind")
}
panic("invalid kind")
}

// Branch returns a Branch with the given kind
Expand All @@ -58,22 +57,21 @@ func (k BranchKind) String() string {
case Empty:
return ""
case Regular:
return "..."
return ""
case Return:
return "... return"
return "return"
case Continue:
return "... continue"
return "continue"
case Break:
return "... break"
return "break"
case Goto:
return "... goto"
return "goto"
case Panic:
return "... panic()"
return "panic()"
case Exit:
return "... os.Exit()"
default:
panic("invalid kind")
return "os.Exit()"
}
panic("invalid kind")
}

// LongString returns a longer form string representation
Expand All @@ -95,7 +93,6 @@ func (k BranchKind) LongString() string {
return "a function call that panics"
case Exit:
return "a function call that exits the program"
default:
panic("invalid kind")
}
panic("invalid kind")
}
12 changes: 7 additions & 5 deletions internal/ifelse/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package ifelse

// Chain contains information about an if-else chain.
type Chain struct {
If Branch // what happens at the end of the "if" block
Else Branch // what happens at the end of the "else" block
HasInitializer bool // is there an "if"-initializer somewhere in the chain?
HasPriorNonDeviating bool // is there a prior "if" block that does NOT deviate control flow?
AtBlockEnd bool // whether the chain is placed at the end of the surrounding block
If Branch // what happens at the end of the "if" block
HasElse bool // is there an "else" block?
Else Branch // what happens at the end of the "else" block
HasInitializer bool // is there an "if"-initializer somewhere in the chain?
HasPriorNonDeviating bool // is there a prior "if" block that does NOT deviate control flow?
AtBlockEnd bool // whether the chain is placed at the end of the surrounding block
BlockEndKind BranchKind // control flow at end of surrounding block (e.g. "return" for function body)
}
6 changes: 2 additions & 4 deletions internal/ifelse/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ func ExprCall(expr *ast.ExprStmt) (Call, bool) {

// String returns the function name with package qualifier (if any)
func (f Call) String() string {
switch {
case f.Pkg != "":
if f.Pkg != "" {
return fmt.Sprintf("%s.%s", f.Pkg, f.Name)
default:
return f.Name
}
return f.Name
}
112 changes: 75 additions & 37 deletions internal/ifelse/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"github.com/mgechev/revive/lint"
)

// Rule is an interface for linters operating on if-else chains
type Rule interface {
CheckIfElse(chain Chain, args Args) (failMsg string)
}
// CheckFunc evaluates a rule against the given if-else chain and returns a message
// describing the proposed refactor, along with a indicator of whether such a refactor
// could be found.
type CheckFunc func(Chain, Args) (string, bool)

// Apply evaluates the given Rule on if-else chains found within the given AST,
// and returns the failures.
Expand All @@ -28,11 +28,14 @@ type Rule interface {
//
// Only the block following "bar" is linted. This is because the rules that use this function
// do not presently have anything to say about earlier blocks in the chain.
func Apply(rule Rule, node ast.Node, target Target, args lint.Arguments) []lint.Failure {
v := &visitor{rule: rule, target: target}
func Apply(check CheckFunc, node ast.Node, target Target, args lint.Arguments) []lint.Failure {
v := &visitor{check: check, target: target}
for _, arg := range args {
if arg == PreserveScope {
switch arg {
case PreserveScope:
v.args.PreserveScope = true
case AllowJump:
v.args.AllowJump = true
}
}
ast.Walk(v, node)
Expand All @@ -42,64 +45,99 @@ func Apply(rule Rule, node ast.Node, target Target, args lint.Arguments) []lint.
type visitor struct {
failures []lint.Failure
target Target
rule Rule
check CheckFunc
args Args
}

func (v *visitor) Visit(node ast.Node) ast.Visitor {
block, ok := node.(*ast.BlockStmt)
if !ok {
switch stmt := node.(type) {
case *ast.FuncDecl:
v.visitBody(stmt.Body, Return)
case *ast.FuncLit:
v.visitBody(stmt.Body, Return)
case *ast.ForStmt:
v.visitBody(stmt.Body, Continue)
case *ast.RangeStmt:
v.visitBody(stmt.Body, Continue)
case *ast.CaseClause:
v.visitBlock(stmt.Body, Break)
case *ast.BlockStmt:
v.visitBlock(stmt.List, Regular)
default:
return v
}
return nil
}

func (v *visitor) visitBody(body *ast.BlockStmt, endKind BranchKind) {
if body != nil {
v.visitBlock(body.List, endKind)
}
}

for i, stmt := range block.List {
if ifStmt, ok := stmt.(*ast.IfStmt); ok {
v.visitChain(ifStmt, Chain{AtBlockEnd: i == len(block.List)-1})
func (v *visitor) visitBlock(stmts []ast.Stmt, endKind BranchKind) {
for i, stmt := range stmts {
ifStmt, ok := stmt.(*ast.IfStmt)
if !ok {
ast.Walk(v, stmt)
continue
}
ast.Walk(v, stmt)
var chain Chain
if i == len(stmts)-1 {
chain.AtBlockEnd = true
chain.BlockEndKind = endKind
}
v.visitIf(ifStmt, chain)
}
return nil
}

func (v *visitor) visitChain(ifStmt *ast.IfStmt, chain Chain) {
func (v *visitor) visitIf(ifStmt *ast.IfStmt, chain Chain) {
// look for other if-else chains nested inside this if { } block
ast.Walk(v, ifStmt.Body)

if ifStmt.Else == nil {
// no else branch
return
}
v.visitBlock(ifStmt.Body.List, chain.BlockEndKind)

if as, ok := ifStmt.Init.(*ast.AssignStmt); ok && as.Tok == token.DEFINE {
chain.HasInitializer = true
}
chain.If = BlockBranch(ifStmt.Body)

if ifStmt.Else == nil {
if v.args.AllowJump {
v.checkRule(ifStmt, chain)
}
return
}

switch elseBlock := ifStmt.Else.(type) {
case *ast.IfStmt:
if !chain.If.Deviates() {
chain.HasPriorNonDeviating = true
}
v.visitChain(elseBlock, chain)
v.visitIf(elseBlock, chain)
case *ast.BlockStmt:
// look for other if-else chains nested inside this else { } block
ast.Walk(v, elseBlock)
v.visitBlock(elseBlock.List, chain.BlockEndKind)

chain.HasElse = true
chain.Else = BlockBranch(elseBlock)
if failMsg := v.rule.CheckIfElse(chain, v.args); failMsg != "" {
if chain.HasInitializer {
// if statement has a := initializer, so we might need to move the assignment
// onto its own line in case the body references it
failMsg += " (move short variable declaration to its own line if necessary)"
}
v.failures = append(v.failures, lint.Failure{
Confidence: 1,
Node: v.target.node(ifStmt),
Failure: failMsg,
})
}
v.checkRule(ifStmt, chain)
default:
panic("invalid node type for else")
panic("unexpected node type for else")
}
}

func (v *visitor) checkRule(ifStmt *ast.IfStmt, chain Chain) {
msg, found := v.check(chain, v.args)
if !found {
return // passed the check
}
if chain.HasInitializer {
// if statement has a := initializer, so we might need to move the assignment
// onto its own line in case the body references it
msg += " (move short variable declaration to its own line if necessary)"
}
v.failures = append(v.failures, lint.Failure{
Confidence: 1,
Node: v.target.node(ifStmt),
Failure: msg,
})
}
3 changes: 1 addition & 2 deletions internal/ifelse/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ func (t Target) node(ifStmt *ast.IfStmt) ast.Node {
return ifStmt
case TargetElse:
return ifStmt.Else
default:
panic("bad target")
}
panic("bad target")
}
Loading

0 comments on commit 7e1d35d

Please sign in to comment.