Skip to content

Commit

Permalink
Merge pull request #210 from johnrichardrinehart/fix/detect-death-by-…
Browse files Browse the repository at this point in the history
…sigint

fix(cmd/main.go): add a channel to proc{} for detecting SIGINT
  • Loading branch information
dnephin authored Aug 23, 2021
2 parents 1a94380 + 643063a commit 37116ff
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 26 deletions.
44 changes: 35 additions & 9 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"os/exec"
"os/signal"
"strings"
"sync/atomic"
"syscall"

"github.com/dnephin/pflag"
"github.com/fatih/color"
Expand Down Expand Up @@ -209,6 +211,9 @@ func run(opts *options) error {
return finishRun(opts, exec, err)
}
exitErr := goTestProc.cmd.Wait()
if signum := atomic.LoadInt32(&goTestProc.signal); signum != 0 {
return finishRun(opts, exec, exitError{num: signalExitCode + int(signum)})
}
if exitErr == nil || opts.rerunFailsMaxAttempts == 0 {
return finishRun(opts, exec, exitErr)
}
Expand Down Expand Up @@ -335,15 +340,18 @@ type proc struct {
cmd waiter
stdout io.Reader
stderr io.Reader
// signal is atomically set to the signal value when a signal is received
// by newSignalHandler.
signal int32
}

type waiter interface {
Wait() error
}

func startGoTest(ctx context.Context, args []string) (proc, error) {
func startGoTest(ctx context.Context, args []string) (*proc, error) {
if len(args) == 0 {
return proc{}, errors.New("missing command to run")
return nil, errors.New("missing command to run")
}

cmd := exec.CommandContext(ctx, args[0], args[1:]...)
Expand All @@ -352,21 +360,21 @@ func startGoTest(ctx context.Context, args []string) (proc, error) {
var err error
p.stdout, err = cmd.StdoutPipe()
if err != nil {
return p, err
return nil, err
}
p.stderr, err = cmd.StderrPipe()
if err != nil {
return p, err
return nil, err
}
if err := cmd.Start(); err != nil {
return p, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " "))
return nil, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " "))
}
log.Debugf("go test pid: %d", cmd.Process.Pid)

ctx, cancel := context.WithCancel(ctx)
newSignalHandler(ctx, cmd.Process.Pid)
newSignalHandler(ctx, cmd.Process.Pid, &p)
p.cmd = &cancelWaiter{cancel: cancel, wrapped: p.cmd}
return p, nil
return &p, nil
}

// ExitCodeWithDefault returns the ExitStatus of a process from the error returned by
Expand All @@ -387,12 +395,28 @@ type exitCoder interface {
ExitCode() int
}

func isExitCoder(err error) bool {
func IsExitCoder(err error) bool {
_, ok := err.(exitCoder)
return ok
}

func newSignalHandler(ctx context.Context, pid int) {
type exitError struct {
num int
}

func (e exitError) Error() string {
return fmt.Sprintf("exit code %d", e.num)
}

func (e exitError) ExitCode() int {
return e.num
}

// signalExitCode is the base value added to a signal number to produce the
// exit code value. This matches the behaviour of bash.
const signalExitCode = 128

func newSignalHandler(ctx context.Context, pid int, p *proc) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)

Expand All @@ -403,6 +427,8 @@ func newSignalHandler(ctx context.Context, pid int) {
case <-ctx.Done():
return
case s := <-c:
atomic.StoreInt32(&p.signal, int32(s.(syscall.Signal)))

proc, err := os.FindProcess(pid)
if err != nil {
log.Errorf("failed to find pid of 'go test': %v", err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/main_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func TestE2E_SignalHandler(t *testing.T) {
assert.NilError(t, result.Cmd.Process.Signal(os.Interrupt))
icmd.WaitOnCmd(2*time.Second, result)

result.Assert(t, icmd.Expected{ExitCode: 102})
result.Assert(t, icmd.Expected{ExitCode: 130})
}

func TestE2E_MaxFails_EndTestRun(t *testing.T) {
Expand Down
12 changes: 6 additions & 6 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: bytes.NewReader(nil),
Expand Down Expand Up @@ -339,8 +339,8 @@ func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: strings.NewReader("anything here is an error\n"),
Expand Down Expand Up @@ -375,8 +375,8 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: bytes.NewReader(nil),
Expand Down
8 changes: 4 additions & 4 deletions cmd/rerunfails_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) {
},
}

fn := func(args []string) proc {
fn := func(args []string) *proc {
next := events[0]
events = events[1:]
return proc{
return &proc{
cmd: fakeWaiter{result: next.err},
stdout: strings.NewReader(next.out),
stderr: bytes.NewReader(nil),
Expand All @@ -136,9 +136,9 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) {
assert.Error(t, err, "run-failed-3")
}

func patchStartGoTestFn(f func(args []string) proc) func() {
func patchStartGoTestFn(f func(args []string) *proc) func() {
orig := startGoTestFn
startGoTestFn = func(ctx context.Context, args []string) (proc, error) {
startGoTestFn = func(ctx context.Context, args []string) (*proc, error) {
return f(args), nil
}
return func() {
Expand Down
4 changes: 2 additions & 2 deletions cmd/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (w *watchRuns) run(event filewatcher.Event) error {
args: w.opts.args,
initFilePath: path,
}
if err := runDelve(o); !isExitCoder(err) {
if err := runDelve(o); !IsExitCoder(err) {
return fmt.Errorf("delve failed: %w", err)
}
return nil
Expand All @@ -43,7 +43,7 @@ func (w *watchRuns) run(event filewatcher.Event) error {
opts := w.opts
opts.packages = []string{event.PkgPath}
var err error
if w.prevExec, err = runSingle(&opts); !isExitCoder(err) {
if w.prevExec, err = runSingle(&opts); !IsExitCoder(err) {
return err
}
return nil
Expand Down
7 changes: 3 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"os"
"os/exec"

"gotest.tools/gotestsum/cmd"
"gotest.tools/gotestsum/cmd/tool"
Expand All @@ -11,10 +10,10 @@ import (

func main() {
err := route(os.Args)
switch err.(type) {
case nil:
switch {
case err == nil:
return
case *exec.ExitError:
case cmd.IsExitCoder(err):
// go test should already report the error to stderr, exit with
// the same status code
os.Exit(cmd.ExitCodeWithDefault(err))
Expand Down

0 comments on commit 37116ff

Please sign in to comment.