Skip to content

Commit

Permalink
cmd: update signal handling
Browse files Browse the repository at this point in the history
Use an atomic int to track the signal code.
Add 128 to the signal to match bash.
Return a pointer to proc, so that the caller is able to load the signal
number.
  • Loading branch information
dnephin committed Aug 14, 2021
1 parent 232870c commit 643063a
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 33 deletions.
48 changes: 32 additions & 16 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os/exec"
"os/signal"
"strings"
"sync/atomic"
"syscall"

"github.com/dnephin/pflag"
Expand Down Expand Up @@ -210,11 +211,8 @@ func run(opts *options) error {
return finishRun(opts, exec, err)
}
exitErr := goTestProc.cmd.Wait()
siggedOut := <-goTestProc.signal // check if we received a SIGINT

if siggedOut != nil {
n, _ := (siggedOut).(syscall.Signal)
return finishRun(opts, exec, fmt.Errorf("syscall.Signal==%d", int(n)))
if signum := atomic.LoadInt32(&goTestProc.signal); signum != 0 {
return finishRun(opts, exec, exitError{num: signalExitCode + int(signum)})
}
if exitErr == nil || opts.rerunFailsMaxAttempts == 0 {
return finishRun(opts, exec, exitErr)
Expand Down Expand Up @@ -342,39 +340,41 @@ type proc struct {
cmd waiter
stdout io.Reader
stderr io.Reader
signal chan os.Signal
// signal is atomically set to the signal value when a signal is received
// by newSignalHandler.
signal int32
}

type waiter interface {
Wait() error
}

func startGoTest(ctx context.Context, args []string) (proc, error) {
func startGoTest(ctx context.Context, args []string) (*proc, error) {
if len(args) == 0 {
return proc{}, errors.New("missing command to run")
return nil, errors.New("missing command to run")
}

cmd := exec.CommandContext(ctx, args[0], args[1:]...)
p := proc{cmd: cmd, signal: make(chan os.Signal, 1)}
p := proc{cmd: cmd}
log.Debugf("exec: %s", cmd.Args)
var err error
p.stdout, err = cmd.StdoutPipe()
if err != nil {
return p, err
return nil, err
}
p.stderr, err = cmd.StderrPipe()
if err != nil {
return p, err
return nil, err
}
if err := cmd.Start(); err != nil {
return p, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " "))
return nil, errors.Wrapf(err, "failed to run %s", strings.Join(cmd.Args, " "))
}
log.Debugf("go test pid: %d", cmd.Process.Pid)

ctx, cancel := context.WithCancel(ctx)
newSignalHandler(ctx, cmd.Process.Pid, &p)
p.cmd = &cancelWaiter{cancel: cancel, wrapped: p.cmd}
return p, nil
return &p, nil
}

// ExitCodeWithDefault returns the ExitStatus of a process from the error returned by
Expand All @@ -395,11 +395,27 @@ type exitCoder interface {
ExitCode() int
}

func isExitCoder(err error) bool {
func IsExitCoder(err error) bool {
_, ok := err.(exitCoder)
return ok
}

type exitError struct {
num int
}

func (e exitError) Error() string {
return fmt.Sprintf("exit code %d", e.num)
}

func (e exitError) ExitCode() int {
return e.num
}

// signalExitCode is the base value added to a signal number to produce the
// exit code value. This matches the behaviour of bash.
const signalExitCode = 128

func newSignalHandler(ctx context.Context, pid int, p *proc) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
Expand All @@ -409,11 +425,11 @@ func newSignalHandler(ctx context.Context, pid int, p *proc) {

select {
case <-ctx.Done():
close(p.signal)
return
case s := <-c:
atomic.StoreInt32(&p.signal, int32(s.(syscall.Signal)))

proc, err := os.FindProcess(pid)
p.signal <- s
if err != nil {
log.Errorf("failed to find pid of 'go test': %v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion cmd/main_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func TestE2E_SignalHandler(t *testing.T) {
assert.NilError(t, result.Cmd.Process.Signal(os.Interrupt))
icmd.WaitOnCmd(2*time.Second, result)

result.Assert(t, icmd.Expected{ExitCode: 102})
result.Assert(t, icmd.Expected{ExitCode: 130})
}

func TestE2E_MaxFails_EndTestRun(t *testing.T) {
Expand Down
12 changes: 6 additions & 6 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: bytes.NewReader(nil),
Expand Down Expand Up @@ -339,8 +339,8 @@ func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: strings.NewReader("anything here is an error\n"),
Expand Down Expand Up @@ -375,8 +375,8 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) {
{"Package": "pkg", "Action": "fail"}
`

fn := func(args []string) proc {
return proc{
fn := func(args []string) *proc {
return &proc{
cmd: fakeWaiter{result: newExitCode("failed", 1)},
stdout: strings.NewReader(jsonFailed),
stderr: bytes.NewReader(nil),
Expand Down
8 changes: 4 additions & 4 deletions cmd/rerunfails_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) {
},
}

fn := func(args []string) proc {
fn := func(args []string) *proc {
next := events[0]
events = events[1:]
return proc{
return &proc{
cmd: fakeWaiter{result: next.err},
stdout: strings.NewReader(next.out),
stderr: bytes.NewReader(nil),
Expand All @@ -136,9 +136,9 @@ func TestRerunFailed_ReturnsAnErrorWhenTheLastTestIsSuccessful(t *testing.T) {
assert.Error(t, err, "run-failed-3")
}

func patchStartGoTestFn(f func(args []string) proc) func() {
func patchStartGoTestFn(f func(args []string) *proc) func() {
orig := startGoTestFn
startGoTestFn = func(ctx context.Context, args []string) (proc, error) {
startGoTestFn = func(ctx context.Context, args []string) (*proc, error) {
return f(args), nil
}
return func() {
Expand Down
4 changes: 2 additions & 2 deletions cmd/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (w *watchRuns) run(event filewatcher.Event) error {
args: w.opts.args,
initFilePath: path,
}
if err := runDelve(o); !isExitCoder(err) {
if err := runDelve(o); !IsExitCoder(err) {
return fmt.Errorf("delve failed: %w", err)
}
return nil
Expand All @@ -43,7 +43,7 @@ func (w *watchRuns) run(event filewatcher.Event) error {
opts := w.opts
opts.packages = []string{event.PkgPath}
var err error
if w.prevExec, err = runSingle(&opts); !isExitCoder(err) {
if w.prevExec, err = runSingle(&opts); !IsExitCoder(err) {
return err
}
return nil
Expand Down
7 changes: 3 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"os"
"os/exec"

"gotest.tools/gotestsum/cmd"
"gotest.tools/gotestsum/cmd/tool"
Expand All @@ -11,10 +10,10 @@ import (

func main() {
err := route(os.Args)
switch err.(type) {
case nil:
switch {
case err == nil:
return
case *exec.ExitError:
case cmd.IsExitCoder(err):
// go test should already report the error to stderr, exit with
// the same status code
os.Exit(cmd.ExitCodeWithDefault(err))
Expand Down

0 comments on commit 643063a

Please sign in to comment.