From 3f0d90a2a939afa04a784b810dbdb035a0dff669 Mon Sep 17 00:00:00 2001
From: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>
Date: Mon, 8 Apr 2024 10:11:09 +0200
Subject: [PATCH] feat: global signal handling with context cancellation

Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>
---
 cli/command/container/attach.go      |  7 ++-
 cli/command/container/client_test.go |  8 +++
 cli/command/container/run.go         |  7 ++-
 cli/command/container/run_test.go    | 69 +++++++++++++++++++++
 cli/command/container/start.go       |  3 +-
 cli/command/utils.go                 | 10 +--
 cli/command/utils_test.go            |  6 +-
 cmd/docker/docker.go                 | 91 ++++++++++++++++++----------
 internal/test/cmd.go                 |  8 ++-
 9 files changed, 160 insertions(+), 49 deletions(-)

diff --git a/cli/command/container/attach.go b/cli/command/container/attach.go
index bf8341af5077..7f1cb258240a 100644
--- a/cli/command/container/attach.go
+++ b/cli/command/container/attach.go
@@ -105,7 +105,12 @@ func RunAttach(ctx context.Context, dockerCLI command.Cli, containerID string, o
 
 	if opts.Proxy && !c.Config.Tty {
 		sigc := notifyAllSignals()
-		go ForwardAllSignals(ctx, apiClient, containerID, sigc)
+		// since we're explicitly setting up signal handling here, and the daemon will
+		// get notified independently of the client's ctx cancellation, we use this context
+		// but without cancellation to prevent ForwardAllSignals from returning
+		// before all signals are forwarded.
+		bgCtx := context.WithoutCancel(ctx)
+		go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
 		defer signal.StopCatch(sigc)
 	}
 
diff --git a/cli/command/container/client_test.go b/cli/command/container/client_test.go
index c45f34040f43..985fd1dd2cce 100644
--- a/cli/command/container/client_test.go
+++ b/cli/command/container/client_test.go
@@ -37,6 +37,7 @@ type fakeClient struct {
 	containerRemoveFunc     func(ctx context.Context, containerID string, options container.RemoveOptions) error
 	containerKillFunc       func(ctx context.Context, containerID, signal string) error
 	containerPruneFunc      func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
+	containerAttachFunc     func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error)
 	Version                 string
 }
 
@@ -173,3 +174,10 @@ func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.A
 	}
 	return types.ContainersPruneReport{}, nil
 }
+
+func (f *fakeClient) ContainerAttach(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
+	if f.containerAttachFunc != nil {
+		return f.containerAttachFunc(ctx, containerID, options)
+	}
+	return types.HijackedResponse{}, nil
+}
diff --git a/cli/command/container/run.go b/cli/command/container/run.go
index 0da127c46379..749071b17332 100644
--- a/cli/command/container/run.go
+++ b/cli/command/container/run.go
@@ -150,7 +150,12 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption
 	}
 	if runOpts.sigProxy {
 		sigc := notifyAllSignals()
-		go ForwardAllSignals(ctx, apiClient, containerID, sigc)
+		// since we're explicitly setting up signal handling here, and the daemon will
+		// get notified independently of the client's ctx cancellation, we use this context
+		// but without cancellation to prevent ForwardAllSignals from returning
+		// before all signals are forwarded.
+		bgCtx := context.WithoutCancel(ctx)
+		go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
 		defer signal.StopCatch(sigc)
 	}
 
diff --git a/cli/command/container/run_test.go b/cli/command/container/run_test.go
index f2416f7d5a03..007fcb222c5c 100644
--- a/cli/command/container/run_test.go
+++ b/cli/command/container/run_test.go
@@ -5,11 +5,18 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
+	"os/signal"
+	"syscall"
 	"testing"
+	"time"
 
+	"github.com/creack/pty"
 	"github.com/docker/cli/cli"
+	"github.com/docker/cli/cli/streams"
 	"github.com/docker/cli/internal/test"
 	"github.com/docker/cli/internal/test/notary"
+	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/network"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
@@ -32,6 +39,68 @@ func TestRunLabel(t *testing.T) {
 	assert.NilError(t, cmd.Execute())
 }
 
+func TestRunAttachTermination(t *testing.T) {
+	p, tty, err := pty.Open()
+	assert.NilError(t, err)
+
+	defer func() {
+		_ = tty.Close()
+		_ = p.Close()
+	}()
+
+	killCh := make(chan struct{})
+	attachCh := make(chan struct{})
+	fakeCLI := test.NewFakeCli(&fakeClient{
+		createContainerFunc: func(_ *container.Config, _ *container.HostConfig, _ *network.NetworkingConfig, _ *specs.Platform, _ string) (container.CreateResponse, error) {
+			return container.CreateResponse{
+				ID: "id",
+			}, nil
+		},
+		containerKillFunc: func(ctx context.Context, containerID, signal string) error {
+			killCh <- struct{}{}
+			return nil
+		},
+		containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
+			server, client := net.Pipe()
+			t.Cleanup(func() {
+				_ = server.Close()
+			})
+			attachCh <- struct{}{}
+			return types.NewHijackedResponse(client, types.MediaTypeRawStream), nil
+		},
+		Version: "1.36",
+	}, func(fc *test.FakeCli) {
+		fc.SetOut(streams.NewOut(tty))
+		fc.SetIn(streams.NewIn(tty))
+	})
+	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
+	defer cancel()
+
+	assert.Equal(t, fakeCLI.In().IsTerminal(), true)
+	assert.Equal(t, fakeCLI.Out().IsTerminal(), true)
+
+	cmd := NewRunCommand(fakeCLI)
+	cmd.SetArgs([]string{"-it", "busybox"})
+	cmd.SilenceUsage = true
+	go func() {
+		assert.ErrorIs(t, cmd.ExecuteContext(ctx), context.Canceled)
+	}()
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("containerAttachFunc was not called before the 5 second timeout")
+	case <-attachCh:
+	}
+
+	assert.NilError(t, syscall.Kill(syscall.Getpid(), syscall.SIGTERM))
+	select {
+	case <-time.After(5 * time.Second):
+		cancel()
+		t.Fatal("containerKillFunc was not called before the 5 second timeout")
+	case <-killCh:
+	}
+}
+
 func TestRunCommandWithContentTrustErrors(t *testing.T) {
 	testCases := []struct {
 		name          string
diff --git a/cli/command/container/start.go b/cli/command/container/start.go
index 5cc019a73e8a..ba04e34435af 100644
--- a/cli/command/container/start.go
+++ b/cli/command/container/start.go
@@ -87,7 +87,8 @@ func RunStart(ctx context.Context, dockerCli command.Cli, opts *StartOptions) er
 		// We always use c.ID instead of container to maintain consistency during `docker start`
 		if !c.Config.Tty {
 			sigc := notifyAllSignals()
-			go ForwardAllSignals(ctx, dockerCli.Client(), c.ID, sigc)
+			bgCtx := context.WithoutCancel(ctx)
+			go ForwardAllSignals(bgCtx, dockerCli.Client(), c.ID, sigc)
 			defer signal.StopCatch(sigc)
 		}
 
diff --git a/cli/command/utils.go b/cli/command/utils.go
index d7184f8c4e6a..5b2cb9721569 100644
--- a/cli/command/utils.go
+++ b/cli/command/utils.go
@@ -9,11 +9,9 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"os/signal"
 	"path/filepath"
 	"runtime"
 	"strings"
-	"syscall"
 
 	"github.com/docker/cli/cli/streams"
 	"github.com/docker/docker/api/types/filters"
@@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
 
 	result := make(chan bool)
 
-	// Catch the termination signal and exit the prompt gracefully.
-	// The caller is responsible for properly handling the termination.
-	notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
-	defer notifyCancel()
-
 	go func() {
 		var res bool
 		scanner := bufio.NewScanner(ins)
@@ -121,8 +114,7 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
 	}()
 
 	select {
-	case <-notifyCtx.Done():
-		// print a newline on termination
+	case <-ctx.Done():
 		_, _ = fmt.Fprintln(outs, "")
 		return false, ErrPromptTerminated
 	case r := <-result:
diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go
index b1ea2dd74c51..1566067f3b38 100644
--- a/cli/command/utils_test.go
+++ b/cli/command/utils_test.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"os/signal"
 	"path/filepath"
 	"strings"
 	"syscall"
@@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) {
 		}, promptResult{false, nil}},
 	} {
 		t.Run("case="+tc.desc, func(t *testing.T) {
+			notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
+			t.Cleanup(notifyCancel)
+
 			buf.Reset()
 			promptReader, promptWriter = io.Pipe()
 
@@ -145,7 +149,7 @@ func TestPromptForConfirmation(t *testing.T) {
 
 			result := make(chan promptResult, 1)
 			go func() {
-				r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
+				r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "")
 				result <- promptResult{r, err}
 			}()
 
diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go
index 27eb0263bd61..f5dad2ec79fc 100644
--- a/cmd/docker/docker.go
+++ b/cmd/docker/docker.go
@@ -28,12 +28,20 @@ import (
 )
 
 func main() {
-	ctx := context.Background()
+	statusCode := dockerMain()
+	if statusCode != 0 {
+		os.Exit(statusCode)
+	}
+}
+
+func dockerMain() int {
+	ctx, cancelNotify := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...)
+	defer cancelNotify()
 
 	dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx))
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
-		os.Exit(1)
+		return 1
 	}
 	logrus.SetOutput(dockerCli.Err())
 	otel.SetErrorHandler(debug.OTELErrorHandler)
@@ -46,16 +54,17 @@ func main() {
 			// StatusError should only be used for errors, and all errors should
 			// have a non-zero exit status, so never exit with 0
 			if sterr.StatusCode == 0 {
-				os.Exit(1)
+				return 1
 			}
-			os.Exit(sterr.StatusCode)
+			return sterr.StatusCode
 		}
 		if errdefs.IsCancelled(err) {
-			os.Exit(0)
+			return 0
 		}
 		fmt.Fprintln(dockerCli.Err(), err)
-		os.Exit(1)
+		return 1
 	}
+	return 0
 }
 
 func newDockerCommand(dockerCli *command.DockerCli) *cli.TopLevelCommand {
@@ -224,7 +233,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) {
 	})
 }
 
-func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
+func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
 	plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd)
 	if err != nil {
 		return err
@@ -242,40 +251,56 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
 
 	// Background signal handling logic: block on the signals channel, and
 	// notify the plugin via the PluginServer (or signal) as appropriate.
-	const exitLimit = 3
-	signals := make(chan os.Signal, exitLimit)
-	signal.Notify(signals, platformsignals.TerminationSignals...)
+	const exitLimit = 2
+
+	tryTerminatePlugin := func(force bool) {
+		// If stdin is a TTY, the kernel will forward
+		// signals to the subprocess because the shared
+		// pgid makes the TTY a controlling terminal.
+		//
+		// The plugin should have its own copy of this
+		// termination logic, and exit after 3 retries
+		// on its own.
+		if dockerCli.Out().IsTerminal() {
+			return
+		}
+
+		// Terminate the plugin server, which will
+		// close all connections with plugin
+		// subprocesses, and signal them to exit.
+		//
+		// Repeated invocations will result in EINVAL,
+		// or EBADF; but that is fine for our purposes.
+		_ = srv.Close()
+
+		// force the process to terminate if it hasn't already
+		if force {
+			_ = plugincmd.Process.Kill()
+			_, _ = fmt.Fprint(dockerCli.Err(), "got 3 SIGTERM/SIGINTs, forcefully exiting\n")
+			os.Exit(1)
+		}
+	}
+
 	go func() {
 		retries := 0
-		for range signals {
-			// If stdin is a TTY, the kernel will forward
-			// signals to the subprocess because the shared
-			// pgid makes the TTY a controlling terminal.
-			//
-			// The plugin should have it's own copy of this
-			// termination logic, and exit after 3 retries
-			// on it's own.
-			if dockerCli.Out().IsTerminal() {
-				continue
-			}
+		force := false
+		// catch the first signal through context cancellation
+		<-ctx.Done()
+		tryTerminatePlugin(force)
 
-			// Terminate the plugin server, which will
-			// close all connections with plugin
-			// subprocesses, and signal them to exit.
-			//
-			// Repeated invocations will result in EINVAL,
-			// or EBADF; but that is fine for our purposes.
-			_ = srv.Close()
+		// register subsequent signals
+		signals := make(chan os.Signal, exitLimit)
+		signal.Notify(signals, platformsignals.TerminationSignals...)
 
+		for range signals {
+			retries++
 			// If we're still running after 3 interruptions
 			// (SIGINT/SIGTERM), send a SIGKILL to the plugin as a
 			// final attempt to terminate, and exit.
-			retries++
 			if retries >= exitLimit {
-				_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)
-				_ = plugincmd.Process.Kill()
-				os.Exit(1)
+				force = true
 			}
+			tryTerminatePlugin(force)
 		}
 	}()
 
@@ -338,7 +363,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
 		ccmd, _, err := cmd.Find(args)
 		subCommand = ccmd
 		if err != nil || pluginmanager.IsPluginCommand(ccmd) {
-			err := tryPluginRun(dockerCli, cmd, args[0], envs)
+			err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs)
 			if err == nil {
 				if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil {
 					pluginmanager.RunPluginHooks(ctx, dockerCli, cmd, ccmd, args)
diff --git a/internal/test/cmd.go b/internal/test/cmd.go
index 04df496a833b..52b44d66f1d4 100644
--- a/internal/test/cmd.go
+++ b/internal/test/cmd.go
@@ -3,7 +3,6 @@ package test
 import (
 	"context"
 	"os"
-	"syscall"
 	"testing"
 	"time"
 
@@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
 	assert.NilError(t, err)
 	cli.SetIn(streams.NewIn(r))
 
+	notifyCtx, notifyCancel := context.WithCancel(ctx)
+	t.Cleanup(notifyCancel)
+
 	go func() {
-		errChan <- cmd.ExecuteContext(ctx)
+		errChan <- cmd.ExecuteContext(notifyCtx)
 	}()
 
 	writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
@@ -66,7 +68,7 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
 
 	// sigint and sigterm are caught by the prompt
 	// this allows us to gracefully exit the prompt with a 0 exit code
-	syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+	notifyCancel()
 
 	select {
 	case <-errCtx.Done():