Skip to content

Commit

Permalink
feat: global signal handling with context cancellation
Browse files Browse the repository at this point in the history
Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>
  • Loading branch information
Benehiko committed Jun 7, 2024
1 parent 8b924a5 commit 3f0d90a
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 49 deletions.
7 changes: 6 additions & 1 deletion cli/command/container/attach.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ func RunAttach(ctx context.Context, dockerCLI command.Cli, containerID string, o

if opts.Proxy && !c.Config.Tty {
sigc := notifyAllSignals()
go ForwardAllSignals(ctx, apiClient, containerID, sigc)
// since we're explicitly setting up signal handling here, and the daemon will
// get notified independently of the clients ctx cancellation, we use this context
// but without cancellation to avoid ForwardAllSignals from returning
// before all signals are forwarded.
bgCtx := context.WithoutCancel(ctx)
go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
defer signal.StopCatch(sigc)
}

Expand Down
8 changes: 8 additions & 0 deletions cli/command/container/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type fakeClient struct {
containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error
containerKillFunc func(ctx context.Context, containerID, signal string) error
containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
containerAttachFunc func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error)
Version string
}

Expand Down Expand Up @@ -173,3 +174,10 @@ func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.A
}
return types.ContainersPruneReport{}, nil
}

func (f *fakeClient) ContainerAttach(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
if f.containerAttachFunc != nil {
return f.containerAttachFunc(ctx, containerID, options)
}
return types.HijackedResponse{}, nil
}
7 changes: 6 additions & 1 deletion cli/command/container/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,12 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption
}
if runOpts.sigProxy {
sigc := notifyAllSignals()
go ForwardAllSignals(ctx, apiClient, containerID, sigc)
// since we're explicitly setting up signal handling here, and the daemon will
// get notified independently of the clients ctx cancellation, we use this context
// but without cancellation to avoid ForwardAllSignals from returning
// before all signals are forwarded.
bgCtx := context.WithoutCancel(ctx)
go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
defer signal.StopCatch(sigc)
}

Expand Down
69 changes: 69 additions & 0 deletions cli/command/container/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@ import (
"errors"
"fmt"
"io"
"net"
"os/signal"
"syscall"
"testing"
"time"

"github.com/creack/pty"
"github.com/docker/cli/cli"
"github.com/docker/cli/cli/streams"
"github.com/docker/cli/internal/test"
"github.com/docker/cli/internal/test/notary"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
specs "github.com/opencontainers/image-spec/specs-go/v1"
Expand All @@ -32,6 +39,68 @@ func TestRunLabel(t *testing.T) {
assert.NilError(t, cmd.Execute())
}

func TestRunAttachTermination(t *testing.T) {
p, tty, err := pty.Open()
assert.NilError(t, err)

defer func() {
_ = tty.Close()
_ = p.Close()
}()

killCh := make(chan struct{})
attachCh := make(chan struct{})
fakeCLI := test.NewFakeCli(&fakeClient{
createContainerFunc: func(_ *container.Config, _ *container.HostConfig, _ *network.NetworkingConfig, _ *specs.Platform, _ string) (container.CreateResponse, error) {
return container.CreateResponse{
ID: "id",
}, nil
},
containerKillFunc: func(ctx context.Context, containerID, signal string) error {
killCh <- struct{}{}
return nil
},
containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
server, client := net.Pipe()
t.Cleanup(func() {
_ = server.Close()
})
attachCh <- struct{}{}
return types.NewHijackedResponse(client, types.MediaTypeRawStream), nil
},
Version: "1.36",
}, func(fc *test.FakeCli) {
fc.SetOut(streams.NewOut(tty))
fc.SetIn(streams.NewIn(tty))
})
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer cancel()

assert.Equal(t, fakeCLI.In().IsTerminal(), true)
assert.Equal(t, fakeCLI.Out().IsTerminal(), true)

cmd := NewRunCommand(fakeCLI)
cmd.SetArgs([]string{"-it", "busybox"})
cmd.SilenceUsage = true
go func() {
assert.ErrorIs(t, cmd.ExecuteContext(ctx), context.Canceled)
}()

select {
case <-time.After(5 * time.Second):
t.Fatal("containerAttachFunc was not called before the 5 second timeout")
case <-attachCh:
}

assert.NilError(t, syscall.Kill(syscall.Getpid(), syscall.SIGTERM))
select {
case <-time.After(5 * time.Second):
cancel()
t.Fatal("containerKillFunc was not called before the 5 second timeout")
case <-killCh:
}
}

func TestRunCommandWithContentTrustErrors(t *testing.T) {
testCases := []struct {
name string
Expand Down
3 changes: 2 additions & 1 deletion cli/command/container/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func RunStart(ctx context.Context, dockerCli command.Cli, opts *StartOptions) er
// We always use c.ID instead of container to maintain consistency during `docker start`
if !c.Config.Tty {
sigc := notifyAllSignals()
go ForwardAllSignals(ctx, dockerCli.Client(), c.ID, sigc)
bgCtx := context.WithoutCancel(ctx)
go ForwardAllSignals(bgCtx, dockerCli.Client(), c.ID, sigc)
defer signal.StopCatch(sigc)
}

Expand Down
10 changes: 1 addition & 9 deletions cli/command/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"

"github.com/docker/cli/cli/streams"
"github.com/docker/docker/api/types/filters"
Expand Down Expand Up @@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m

result := make(chan bool)

// Catch the termination signal and exit the prompt gracefully.
// The caller is responsible for properly handling the termination.
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer notifyCancel()

go func() {
var res bool
scanner := bufio.NewScanner(ins)
Expand All @@ -121,8 +114,7 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
}()

select {
case <-notifyCtx.Done():
// print a newline on termination
case <-ctx.Done():
_, _ = fmt.Fprintln(outs, "")
return false, ErrPromptTerminated
case r := <-result:
Expand Down
6 changes: 5 additions & 1 deletion cli/command/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
Expand Down Expand Up @@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) {
}, promptResult{false, nil}},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
t.Cleanup(notifyCancel)

buf.Reset()
promptReader, promptWriter = io.Pipe()

Expand All @@ -145,7 +149,7 @@ func TestPromptForConfirmation(t *testing.T) {

result := make(chan promptResult, 1)
go func() {
r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "")
result <- promptResult{r, err}
}()

Expand Down
91 changes: 58 additions & 33 deletions cmd/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@ import (
)

func main() {
ctx := context.Background()
statusCode := dockerMain()
if statusCode != 0 {
os.Exit(statusCode)
}
}

func dockerMain() int {
ctx, cancelNotify := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...)
defer cancelNotify()

dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx))
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
return 1
}
logrus.SetOutput(dockerCli.Err())
otel.SetErrorHandler(debug.OTELErrorHandler)
Expand All @@ -46,16 +54,17 @@ func main() {
// StatusError should only be used for errors, and all errors should
// have a non-zero exit status, so never exit with 0
if sterr.StatusCode == 0 {
os.Exit(1)
return 1
}
os.Exit(sterr.StatusCode)
return sterr.StatusCode
}
if errdefs.IsCancelled(err) {
os.Exit(0)
return 0
}
fmt.Fprintln(dockerCli.Err(), err)
os.Exit(1)
return 1
}
return 0
}

func newDockerCommand(dockerCli *command.DockerCli) *cli.TopLevelCommand {
Expand Down Expand Up @@ -224,7 +233,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) {
})
}

func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd)
if err != nil {
return err
Expand All @@ -242,40 +251,56 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,

// Background signal handling logic: block on the signals channel, and
// notify the plugin via the PluginServer (or signal) as appropriate.
const exitLimit = 3
signals := make(chan os.Signal, exitLimit)
signal.Notify(signals, platformsignals.TerminationSignals...)
const exitLimit = 2

tryTerminatePlugin := func(force bool) {
// If stdin is a TTY, the kernel will forward
// signals to the subprocess because the shared
// pgid makes the TTY a controlling terminal.
//
// The plugin should have it's own copy of this
// termination logic, and exit after 3 retries
// on it's own.
if dockerCli.Out().IsTerminal() {
return
}

// Terminate the plugin server, which will
// close all connections with plugin
// subprocesses, and signal them to exit.
//
// Repeated invocations will result in EINVAL,
// or EBADF; but that is fine for our purposes.
_ = srv.Close()

// force the process to terminate if it hasn't already
if force {
_ = plugincmd.Process.Kill()
_, _ = fmt.Fprint(dockerCli.Err(), "got 3 SIGTERM/SIGINTs, forcefully exiting\n")
os.Exit(1)
}
}

go func() {
retries := 0
for range signals {
// If stdin is a TTY, the kernel will forward
// signals to the subprocess because the shared
// pgid makes the TTY a controlling terminal.
//
// The plugin should have it's own copy of this
// termination logic, and exit after 3 retries
// on it's own.
if dockerCli.Out().IsTerminal() {
continue
}
force := false
// catch the first signal through context cancellation
<-ctx.Done()
tryTerminatePlugin(force)

// Terminate the plugin server, which will
// close all connections with plugin
// subprocesses, and signal them to exit.
//
// Repeated invocations will result in EINVAL,
// or EBADF; but that is fine for our purposes.
_ = srv.Close()
// register subsequent signals
signals := make(chan os.Signal, exitLimit)
signal.Notify(signals, platformsignals.TerminationSignals...)

for range signals {
retries++
// If we're still running after 3 interruptions
// (SIGINT/SIGTERM), send a SIGKILL to the plugin as a
// final attempt to terminate, and exit.
retries++
if retries >= exitLimit {
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)
_ = plugincmd.Process.Kill()
os.Exit(1)
force = true
}
tryTerminatePlugin(force)
}
}()

Expand Down Expand Up @@ -338,7 +363,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
ccmd, _, err := cmd.Find(args)
subCommand = ccmd
if err != nil || pluginmanager.IsPluginCommand(ccmd) {
err := tryPluginRun(dockerCli, cmd, args[0], envs)
err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs)
if err == nil {
if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil {
pluginmanager.RunPluginHooks(ctx, dockerCli, cmd, ccmd, args)
Expand Down
8 changes: 5 additions & 3 deletions internal/test/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package test
import (
"context"
"os"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
assert.NilError(t, err)
cli.SetIn(streams.NewIn(r))

notifyCtx, notifyCancel := context.WithCancel(ctx)
t.Cleanup(notifyCancel)

go func() {
errChan <- cmd.ExecuteContext(ctx)
errChan <- cmd.ExecuteContext(notifyCtx)
}()

writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
Expand Down Expand Up @@ -66,7 +68,7 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli

// sigint and sigterm are caught by the prompt
// this allows us to gracefully exit the prompt with a 0 exit code
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
notifyCancel()

select {
case <-errCtx.Done():
Expand Down

0 comments on commit 3f0d90a

Please sign in to comment.