diff --git a/cli/command/container/attach.go b/cli/command/container/attach.go index 854aa7395224..8470855992c8 100644 --- a/cli/command/container/attach.go +++ b/cli/command/container/attach.go @@ -97,7 +97,8 @@ func runAttach(dockerCli command.Cli, opts *attachOptions) error { } if opts.proxy && !c.Config.Tty { - sigc := ForwardAllSignals(ctx, dockerCli, opts.container) + sigc := notfiyAllSignals() + go ForwardAllSignals(ctx, dockerCli, opts.container, sigc) defer signal.StopCatch(sigc) } diff --git a/cli/command/container/client_test.go b/cli/command/container/client_test.go index 3643550b92c8..6aa7d9991e34 100644 --- a/cli/command/container/client_test.go +++ b/cli/command/container/client_test.go @@ -32,6 +32,7 @@ type fakeClient struct { containerExportFunc func(string) (io.ReadCloser, error) containerExecResizeFunc func(id string, options types.ResizeOptions) error containerRemoveFunc func(ctx context.Context, container string, options types.ContainerRemoveOptions) error + containerKillFunc func(ctx context.Context, container, signal string) error Version string } @@ -154,3 +155,10 @@ func (f *fakeClient) ContainerExecResize(_ context.Context, id string, options t } return nil } + +func (f *fakeClient) ContainerKill(ctx context.Context, container, signal string) error { + if f.containerKillFunc != nil { + return f.containerKillFunc(ctx, container, signal) + } + return nil +} diff --git a/cli/command/container/run.go b/cli/command/container/run.go index 74bee8cd12e1..fec9995841f8 100644 --- a/cli/command/container/run.go +++ b/cli/command/container/run.go @@ -131,7 +131,8 @@ func runContainer(dockerCli command.Cli, opts *runOptions, copts *containerOptio return runStartContainerErr(err) } if opts.sigProxy { - sigc := ForwardAllSignals(ctx, dockerCli, createResponse.ID) + sigc := notfiyAllSignals() + go ForwardAllSignals(ctx, dockerCli, createResponse.ID, sigc) defer signal.StopCatch(sigc) } diff --git a/cli/command/container/signals.go b/cli/command/container/signals.go new file mode 100644 index 000000000000..06e4d9eb6624 --- /dev/null +++ b/cli/command/container/signals.go @@ -0,0 +1,57 @@ +package container + +import ( + "context" + "fmt" + "os" + gosignal "os/signal" + + "github.com/docker/cli/cli/command" + "github.com/docker/docker/pkg/signal" + "github.com/sirupsen/logrus" +) + +// ForwardAllSignals forwards signals to the container +// +// The channel you pass in must already be setup to receive any signals you want to forward. +func ForwardAllSignals(ctx context.Context, cli command.Cli, cid string, sigc <-chan os.Signal) { + var s os.Signal + for { + select { + case s = <-sigc: + case <-ctx.Done(): + return + } + + if s == signal.SIGCHLD || s == signal.SIGPIPE { + continue + } + + // In go1.14+, the go runtime issues SIGURG as an interrupt to support pre-emptable system calls on Linux. + // Since we can't forward that along we'll check that here. + if isRuntimeSig(s) { + continue + } + var sig string + for sigStr, sigN := range signal.SignalMap { + if sigN == s { + sig = sigStr + break + } + } + if sig == "" { + fmt.Fprintf(cli.Err(), "Unsupported signal: %v. Discarding.\n", s) + continue + } + + if err := cli.Client().ContainerKill(ctx, cid, sig); err != nil { + logrus.Debugf("Error sending signal: %s", err) + } + } +} + +func notfiyAllSignals() chan os.Signal { + sigc := make(chan os.Signal, 128) + gosignal.Notify(sigc) + return sigc +} diff --git a/cli/command/container/signals_linux.go b/cli/command/container/signals_linux.go new file mode 100644 index 000000000000..7eeb91985023 --- /dev/null +++ b/cli/command/container/signals_linux.go @@ -0,0 +1,11 @@ +package container + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func isRuntimeSig(s os.Signal) bool { + return s == unix.SIGURG +} diff --git a/cli/command/container/signals_linux_test.go b/cli/command/container/signals_linux_test.go new file mode 100644 index 000000000000..1b2eaff9b3e8 --- /dev/null +++ b/cli/command/container/signals_linux_test.go @@ -0,0 +1,57 @@ +package container + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/docker/cli/internal/test" + "golang.org/x/sys/unix" + "gotest.tools/v3/assert" +) + +func TestIgnoredSignals(t *testing.T) { + ignoredSignals := []syscall.Signal{unix.SIGPIPE, unix.SIGCHLD, unix.SIGURG} + + for _, s := range ignoredSignals { + t.Run(unix.SignalName(s), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var called bool + client := &fakeClient{containerKillFunc: func(ctx context.Context, container, signal string) error { + called = true + return nil + }} + + cli := test.NewFakeCli(client) + sigc := make(chan os.Signal) + defer close(sigc) + + done := make(chan struct{}) + go func() { + ForwardAllSignals(ctx, cli, t.Name(), sigc) + close(done) + }() + + timer := time.NewTimer(30 * time.Second) + defer timer.Stop() + + select { + case <-timer.C: + t.Fatal("timeout waiting to send signal") + case sigc <- s: + case <-done: + } + + // cancel the context so ForwardAllSignals will exit after it has processed the signal we sent. + // This is how we know the signal was actually processed and are not introducing a flakey test. + cancel() + <-done + + assert.Assert(t, !called, "kill was called") + }) + } +} diff --git a/cli/command/container/signals_notlinux.go b/cli/command/container/signals_notlinux.go new file mode 100644 index 000000000000..9e8412f66ae6 --- /dev/null +++ b/cli/command/container/signals_notlinux.go @@ -0,0 +1,9 @@ +// +build !linux + +package container + +import "os" + +func isRuntimeSig(_ os.Signal) bool { + return false +} diff --git a/cli/command/container/signals_test.go b/cli/command/container/signals_test.go new file mode 100644 index 000000000000..39eaab1c5da8 --- /dev/null +++ b/cli/command/container/signals_test.go @@ -0,0 +1,48 @@ +package container + +import ( + "context" + "os" + "testing" + "time" + + "github.com/docker/cli/internal/test" + "github.com/docker/docker/pkg/signal" +) + +func TestForwardSignals(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + called := make(chan struct{}) + client := &fakeClient{containerKillFunc: func(ctx context.Context, container, signal string) error { + close(called) + return nil + }} + + cli := test.NewFakeCli(client) + sigc := make(chan os.Signal) + defer close(sigc) + + go ForwardAllSignals(ctx, cli, t.Name(), sigc) + + timer := time.NewTimer(30 * time.Second) + defer timer.Stop() + + select { + case <-timer.C: + t.Fatal("timeout waiting to send signal") + case sigc <- signal.SignalMap["TERM"]: + } + if !timer.Stop() { + <-timer.C + } + timer.Reset(30 * time.Second) + + select { + case <-called: + case <-timer.C: + t.Fatal("timeout waiting for signal to be processed") + } + +} diff --git a/cli/command/container/start.go b/cli/command/container/start.go index 5ab05ae68b2e..9029061e09de 100644 --- a/cli/command/container/start.go +++ b/cli/command/container/start.go @@ -74,7 +74,8 @@ func runStart(dockerCli command.Cli, opts *startOptions) error { // We always use c.ID instead of container to maintain consistency during `docker start` if !c.Config.Tty { - sigc := ForwardAllSignals(ctx, dockerCli, c.ID) + sigc := notfiyAllSignals() + ForwardAllSignals(ctx, dockerCli, c.ID, sigc) defer signal.StopCatch(sigc) } diff --git a/cli/command/container/tty.go b/cli/command/container/tty.go index b7003f1a04df..4060c0a19f9a 100644 --- a/cli/command/container/tty.go +++ b/cli/command/container/tty.go @@ -95,32 +95,3 @@ func MonitorTtySize(ctx context.Context, cli command.Cli, id string, isExec bool } return nil } - -// ForwardAllSignals forwards signals to the container -func ForwardAllSignals(ctx context.Context, cli command.Cli, cid string) chan os.Signal { - sigc := make(chan os.Signal, 128) - signal.CatchAll(sigc) - go func() { - for s := range sigc { - if s == signal.SIGCHLD || s == signal.SIGPIPE { - continue - } - var sig string - for sigStr, sigN := range signal.SignalMap { - if sigN == s { - sig = sigStr - break - } - } - if sig == "" { - fmt.Fprintf(cli.Err(), "Unsupported signal: %v. Discarding.\n", s) - continue - } - - if err := cli.Client().ContainerKill(ctx, cid, sig); err != nil { - logrus.Debugf("Error sending signal: %s", err) - } - } - }() - return sigc -}