From 3f0d90a2a939afa04a784b810dbdb035a0dff669 Mon Sep 17 00:00:00 2001 From: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> Date: Mon, 8 Apr 2024 10:11:09 +0200 Subject: [PATCH] feat: global signal handling with context cancellation Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> --- cli/command/container/attach.go | 7 ++- cli/command/container/client_test.go | 8 +++ cli/command/container/run.go | 7 ++- cli/command/container/run_test.go | 69 +++++++++++++++++++++ cli/command/container/start.go | 3 +- cli/command/utils.go | 10 +-- cli/command/utils_test.go | 6 +- cmd/docker/docker.go | 91 ++++++++++++++++++---------- internal/test/cmd.go | 8 ++- 9 files changed, 160 insertions(+), 49 deletions(-) diff --git a/cli/command/container/attach.go b/cli/command/container/attach.go index bf8341af5077..7f1cb258240a 100644 --- a/cli/command/container/attach.go +++ b/cli/command/container/attach.go @@ -105,7 +105,12 @@ func RunAttach(ctx context.Context, dockerCLI command.Cli, containerID string, o if opts.Proxy && !c.Config.Tty { sigc := notifyAllSignals() - go ForwardAllSignals(ctx, apiClient, containerID, sigc) + // since we're explicitly setting up signal handling here, and the daemon will + // get notified independently of the clients ctx cancellation, we use this context + // but without cancellation to avoid ForwardAllSignals from returning + // before all signals are forwarded. + bgCtx := context.WithoutCancel(ctx) + go ForwardAllSignals(bgCtx, apiClient, containerID, sigc) defer signal.StopCatch(sigc) } diff --git a/cli/command/container/client_test.go b/cli/command/container/client_test.go index c45f34040f43..985fd1dd2cce 100644 --- a/cli/command/container/client_test.go +++ b/cli/command/container/client_test.go @@ -37,6 +37,7 @@ type fakeClient struct { containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error containerKillFunc func(ctx context.Context, containerID, signal string) error containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) + containerAttachFunc func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) Version string } @@ -173,3 +174,10 @@ func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.A } return types.ContainersPruneReport{}, nil } + +func (f *fakeClient) ContainerAttach(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) { + if f.containerAttachFunc != nil { + return f.containerAttachFunc(ctx, containerID, options) + } + return types.HijackedResponse{}, nil +} diff --git a/cli/command/container/run.go b/cli/command/container/run.go index 0da127c46379..749071b17332 100644 --- a/cli/command/container/run.go +++ b/cli/command/container/run.go @@ -150,7 +150,12 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption } if runOpts.sigProxy { sigc := notifyAllSignals() - go ForwardAllSignals(ctx, apiClient, containerID, sigc) + // since we're explicitly setting up signal handling here, and the daemon will + // get notified independently of the clients ctx cancellation, we use this context + // but without cancellation to avoid ForwardAllSignals from returning + // before all signals are forwarded. + bgCtx := context.WithoutCancel(ctx) + go ForwardAllSignals(bgCtx, apiClient, containerID, sigc) defer signal.StopCatch(sigc) } diff --git a/cli/command/container/run_test.go b/cli/command/container/run_test.go index f2416f7d5a03..007fcb222c5c 100644 --- a/cli/command/container/run_test.go +++ b/cli/command/container/run_test.go @@ -5,11 +5,18 @@ import ( "errors" "fmt" "io" + "net" + "os/signal" + "syscall" "testing" + "time" + "github.com/creack/pty" "github.com/docker/cli/cli" + "github.com/docker/cli/cli/streams" "github.com/docker/cli/internal/test" "github.com/docker/cli/internal/test/notary" + "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/network" specs "github.com/opencontainers/image-spec/specs-go/v1" @@ -32,6 +39,68 @@ func TestRunLabel(t *testing.T) { assert.NilError(t, cmd.Execute()) } +func TestRunAttachTermination(t *testing.T) { + p, tty, err := pty.Open() + assert.NilError(t, err) + + defer func() { + _ = tty.Close() + _ = p.Close() + }() + + killCh := make(chan struct{}) + attachCh := make(chan struct{}) + fakeCLI := test.NewFakeCli(&fakeClient{ + createContainerFunc: func(_ *container.Config, _ *container.HostConfig, _ *network.NetworkingConfig, _ *specs.Platform, _ string) (container.CreateResponse, error) { + return container.CreateResponse{ + ID: "id", + }, nil + }, + containerKillFunc: func(ctx context.Context, containerID, signal string) error { + killCh <- struct{}{} + return nil + }, + containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) { + server, client := net.Pipe() + t.Cleanup(func() { + _ = server.Close() + }) + attachCh <- struct{}{} + return types.NewHijackedResponse(client, types.MediaTypeRawStream), nil + }, + Version: "1.36", + }, func(fc *test.FakeCli) { + fc.SetOut(streams.NewOut(tty)) + fc.SetIn(streams.NewIn(tty)) + }) + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM) + defer cancel() + + assert.Equal(t, fakeCLI.In().IsTerminal(), true) + assert.Equal(t, fakeCLI.Out().IsTerminal(), true) + + cmd := NewRunCommand(fakeCLI) + cmd.SetArgs([]string{"-it", "busybox"}) + cmd.SilenceUsage = true + go func() { + assert.ErrorIs(t, cmd.ExecuteContext(ctx), context.Canceled) + }() + + select { + case <-time.After(5 * time.Second): + t.Fatal("containerAttachFunc was not called before the 5 second timeout") + case <-attachCh: + } + + assert.NilError(t, syscall.Kill(syscall.Getpid(), syscall.SIGTERM)) + select { + case <-time.After(5 * time.Second): + cancel() + t.Fatal("containerKillFunc was not called before the 5 second timeout") + case <-killCh: + } +} + func TestRunCommandWithContentTrustErrors(t *testing.T) { testCases := []struct { name string diff --git a/cli/command/container/start.go b/cli/command/container/start.go index 5cc019a73e8a..ba04e34435af 100644 --- a/cli/command/container/start.go +++ b/cli/command/container/start.go @@ -87,7 +87,8 @@ func RunStart(ctx context.Context, dockerCli command.Cli, opts *StartOptions) er // We always use c.ID instead of container to maintain consistency during `docker start` if !c.Config.Tty { sigc := notifyAllSignals() - go ForwardAllSignals(ctx, dockerCli.Client(), c.ID, sigc) + bgCtx := context.WithoutCancel(ctx) + go ForwardAllSignals(bgCtx, dockerCli.Client(), c.ID, sigc) defer signal.StopCatch(sigc) } diff --git a/cli/command/utils.go b/cli/command/utils.go index d7184f8c4e6a..5b2cb9721569 100644 --- a/cli/command/utils.go +++ b/cli/command/utils.go @@ -9,11 +9,9 @@ import ( "fmt" "io" "os" - "os/signal" "path/filepath" "runtime" "strings" - "syscall" "github.com/docker/cli/cli/streams" "github.com/docker/docker/api/types/filters" @@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m result := make(chan bool) - // Catch the termination signal and exit the prompt gracefully. - // The caller is responsible for properly handling the termination. - notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) - defer notifyCancel() - go func() { var res bool scanner := bufio.NewScanner(ins) @@ -121,8 +114,7 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m }() select { - case <-notifyCtx.Done(): - // print a newline on termination + case <-ctx.Done(): _, _ = fmt.Fprintln(outs, "") return false, ErrPromptTerminated case r := <-result: diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go index b1ea2dd74c51..1566067f3b38 100644 --- a/cli/command/utils_test.go +++ b/cli/command/utils_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "os/signal" "path/filepath" "strings" "syscall" @@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) { }, promptResult{false, nil}}, } { t.Run("case="+tc.desc, func(t *testing.T) { + notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + t.Cleanup(notifyCancel) + buf.Reset() promptReader, promptWriter = io.Pipe() @@ -145,7 +149,7 @@ func TestPromptForConfirmation(t *testing.T) { result := make(chan promptResult, 1) go func() { - r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "") + r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "") result <- promptResult{r, err} }() diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go index 27eb0263bd61..f5dad2ec79fc 100644 --- a/cmd/docker/docker.go +++ b/cmd/docker/docker.go @@ -28,12 +28,20 @@ import ( ) func main() { - ctx := context.Background() + statusCode := dockerMain() + if statusCode != 0 { + os.Exit(statusCode) + } +} + +func dockerMain() int { + ctx, cancelNotify := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...) + defer cancelNotify() dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx)) if err != nil { fmt.Fprintln(os.Stderr, err) - os.Exit(1) + return 1 } logrus.SetOutput(dockerCli.Err()) otel.SetErrorHandler(debug.OTELErrorHandler) @@ -46,16 +54,17 @@ func main() { // StatusError should only be used for errors, and all errors should // have a non-zero exit status, so never exit with 0 if sterr.StatusCode == 0 { - os.Exit(1) + return 1 } - os.Exit(sterr.StatusCode) + return sterr.StatusCode } if errdefs.IsCancelled(err) { - os.Exit(0) + return 0 } fmt.Fprintln(dockerCli.Err(), err) - os.Exit(1) + return 1 } + return 0 } func newDockerCommand(dockerCli *command.DockerCli) *cli.TopLevelCommand { @@ -224,7 +233,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) { }) } -func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error { +func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error { plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd) if err != nil { return err @@ -242,40 +251,56 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, // Background signal handling logic: block on the signals channel, and // notify the plugin via the PluginServer (or signal) as appropriate. - const exitLimit = 3 - signals := make(chan os.Signal, exitLimit) - signal.Notify(signals, platformsignals.TerminationSignals...) + const exitLimit = 2 + + tryTerminatePlugin := func(force bool) { + // If stdin is a TTY, the kernel will forward + // signals to the subprocess because the shared + // pgid makes the TTY a controlling terminal. + // + // The plugin should have it's own copy of this + // termination logic, and exit after 3 retries + // on it's own. + if dockerCli.Out().IsTerminal() { + return + } + + // Terminate the plugin server, which will + // close all connections with plugin + // subprocesses, and signal them to exit. + // + // Repeated invocations will result in EINVAL, + // or EBADF; but that is fine for our purposes. + _ = srv.Close() + + // force the process to terminate if it hasn't already + if force { + _ = plugincmd.Process.Kill() + _, _ = fmt.Fprint(dockerCli.Err(), "got 3 SIGTERM/SIGINTs, forcefully exiting\n") + os.Exit(1) + } + } + go func() { retries := 0 - for range signals { - // If stdin is a TTY, the kernel will forward - // signals to the subprocess because the shared - // pgid makes the TTY a controlling terminal. - // - // The plugin should have it's own copy of this - // termination logic, and exit after 3 retries - // on it's own. - if dockerCli.Out().IsTerminal() { - continue - } + force := false + // catch the first signal through context cancellation + <-ctx.Done() + tryTerminatePlugin(force) - // Terminate the plugin server, which will - // close all connections with plugin - // subprocesses, and signal them to exit. - // - // Repeated invocations will result in EINVAL, - // or EBADF; but that is fine for our purposes. - _ = srv.Close() + // register subsequent signals + signals := make(chan os.Signal, exitLimit) + signal.Notify(signals, platformsignals.TerminationSignals...) + for range signals { + retries++ // If we're still running after 3 interruptions // (SIGINT/SIGTERM), send a SIGKILL to the plugin as a // final attempt to terminate, and exit. - retries++ if retries >= exitLimit { - _, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries) - _ = plugincmd.Process.Kill() - os.Exit(1) + force = true } + tryTerminatePlugin(force) } }() @@ -338,7 +363,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error { ccmd, _, err := cmd.Find(args) subCommand = ccmd if err != nil || pluginmanager.IsPluginCommand(ccmd) { - err := tryPluginRun(dockerCli, cmd, args[0], envs) + err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs) if err == nil { if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil { pluginmanager.RunPluginHooks(ctx, dockerCli, cmd, ccmd, args) diff --git a/internal/test/cmd.go b/internal/test/cmd.go index 04df496a833b..52b44d66f1d4 100644 --- a/internal/test/cmd.go +++ b/internal/test/cmd.go @@ -3,7 +3,6 @@ package test import ( "context" "os" - "syscall" "testing" "time" @@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli assert.NilError(t, err) cli.SetIn(streams.NewIn(r)) + notifyCtx, notifyCancel := context.WithCancel(ctx) + t.Cleanup(notifyCancel) + go func() { - errChan <- cmd.ExecuteContext(ctx) + errChan <- cmd.ExecuteContext(notifyCtx) }() writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond) @@ -66,7 +68,7 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli // sigint and sigterm are caught by the prompt // this allows us to gracefully exit the prompt with a 0 exit code - syscall.Kill(syscall.Getpid(), syscall.SIGINT) + notifyCancel() select { case <-errCtx.Done():