From 9d12e2b8b7e6ee3a523d2950a23044f5b34c85c0 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Tue, 17 Oct 2023 17:43:10 -0400 Subject: [PATCH] Update Cmd IO handling Update `Cmd.Wait` to return a known error value if it times out waiting on IO copy after the command exits (and update `TestCmdStuckIo` to check for that error). Prior, the test checked for an `io.ErrClosedPipe`, which: 1. is not the best indicator that IO is stuck; and 2. is now ignored as an error value raised during IO relay. Update `stuckIOProcess` logic in `cmd_test.go` to mirror logic in `interal/exec.Exec`, using `os.Pipe` for std io that returns an `io.EOF` (instead of `io.Pipe`, which does not). Signed-off-by: Hamza El-Saawy --- internal/cmd/cmd.go | 53 +++++++++++++++++++++++++++++-------- internal/cmd/cmd_test.go | 56 ++++++++++++++++++++++++++++------------ internal/cmd/io.go | 4 +-- 3 files changed, 84 insertions(+), 29 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index d7228619eb..aaaf7fb37a 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -5,6 +5,7 @@ package cmd import ( "bytes" "context" + "errors" "fmt" "io" "strings" @@ -20,6 +21,8 @@ import ( "golang.org/x/sys/windows" ) +var errIOTimeOut = errors.New("timed out waiting for stdio relay") + // CmdProcessRequest stores information on command requests made through this package. type CmdProcessRequest struct { Args []string @@ -62,7 +65,7 @@ type Cmd struct { // ExitState is filled out after Wait() (or Run() or Output()) completes. ExitState *ExitState - iogrp errgroup.Group + ioGrp errgroup.Group stdinErr atomic.Value allDoneCh chan struct{} } @@ -90,13 +93,13 @@ func (err *ExitError) Error() string { return fmt.Sprintf("process exited with exit code %d", err.ExitCode()) } -// Additional fields to hcsschema.ProcessParameters used by LCOW +// Additional fields to hcsschema.ProcessParameters used by LCOW. type lcowProcessParameters struct { hcsschema.ProcessParameters OCIProcess *specs.Process `json:"OciProcess,omitempty"` } -// escapeArgs makes a Windows-style escaped command line from a set of arguments +// escapeArgs makes a Windows-style escaped command line from a set of arguments. func escapeArgs(args []string) string { escapedArgs := make([]string, len(args)) for i, a := range args { @@ -136,9 +139,19 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg // Start starts a command. The caller must ensure that if Start succeeds, // Wait is eventually called to clean up resources. func (c *Cmd) Start() error { + if c.Host == nil { + return errors.New("empty ProcessHost") + } + + // closed in (*Cmd).Wait; signals command execution is done c.allDoneCh = make(chan struct{}) + var x interface{} if !c.Host.IsOCI() { + if c.Spec == nil { + return errors.New("process spec is required for non-OCI ProcessHost") + } + wpp := &hcsschema.ProcessParameters{ CommandLine: c.Spec.CommandLine, User: c.Spec.User.Username, @@ -199,6 +212,7 @@ func (c *Cmd) Start() error { // Start relaying process IO. stdin, stdout, stderr := p.Stdio() if c.Stdin != nil { + c.Log.Info("coping stdin") // Do not make stdin part of the error group because there is no way for // us or the caller to reliably unblock the c.Stdin read when the // process exits. @@ -218,20 +232,20 @@ func (c *Cmd) Start() error { } if c.Stdout != nil { - c.iogrp.Go(func() error { + c.ioGrp.Go(func() error { _, err := relayIO(c.Stdout, stdout, c.Log, "stdout") - if err := p.CloseStdout(context.TODO()); err != nil { - c.Log.WithError(err).Warn("failed to close Cmd stdout") + if cErr := p.CloseStdout(context.TODO()); cErr != nil && c.Log != nil { + c.Log.WithError(cErr).Warn("failed to close Cmd stdout") } return err }) } if c.Stderr != nil { - c.iogrp.Go(func() error { + c.ioGrp.Go(func() error { _, err := relayIO(c.Stderr, stderr, c.Log, "stderr") - if err := p.CloseStderr(context.TODO()); err != nil { - c.Log.WithError(err).Warn("failed to close Cmd stderr") + if cErr := p.CloseStderr(context.TODO()); cErr != nil && c.Log != nil { + c.Log.WithError(cErr).Warn("failed to close Cmd stderr") } return err }) @@ -270,9 +284,16 @@ func (c *Cmd) Wait() error { state.exited = true state.code = code } + // Terminate the IO if the copy does not complete in the requested time. + // Closing the process should (eventually) lead to unblocking `ioGrp`, but we still need + // `timeoutErrCh` to: + // 1. communitate that the IO copy timed out; and + // 2. prevent a race condition between setting the timeout err in the goroutine and setting it for `ioErr`. + timeoutErrCh := make(chan error) if c.CopyAfterExitTimeout != 0 { go func() { + defer close(timeoutErrCh) t := time.NewTimer(c.CopyAfterExitTimeout) defer t.Stop() select { @@ -280,17 +301,27 @@ func (c *Cmd) Wait() error { case <-t.C: // Close the process to cancel any reads to stdout or stderr. c.Process.Close() + err := errIOTimeOut + // log the timeout, since we may not return it to the caller if c.Log != nil { - c.Log.Warn("timed out waiting for stdio relay") + c.Log.WithField("timeout", c.CopyAfterExitTimeout).Warn(err.Error()) } + timeoutErrCh <- err } }() + } else { + close(timeoutErrCh) } - ioErr := c.iogrp.Wait() + + // TODO (go1.20): use multierror for these + ioErr := c.ioGrp.Wait() if ioErr == nil { ioErr, _ = c.stdinErr.Load().(error) } close(c.allDoneCh) + if tErr := <-timeoutErrCh; ioErr == nil { + ioErr = tErr + } c.Process.Close() c.ExitState = state if exitErr != nil { diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go index 00eb86c7b0..ada98e7719 100644 --- a/internal/cmd/cmd_test.go +++ b/internal/cmd/cmd_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "os" "os/exec" @@ -213,46 +214,69 @@ func TestCmdStdinBlocked(t *testing.T) { } } -type stuckIoProcessHost struct { +type stuckIOProcessHost struct { cow.ProcessHost } -type stuckIoProcess struct { +type stuckIOProcess struct { cow.Process - stdin, pstdout, pstderr *io.PipeWriter - pstdin, stdout, stderr *io.PipeReader + + // don't initialize p.stdin, since it complicates the logic + pstdout, pstderr *os.File + stdout, stderr *os.File } -func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) { +func (h *stuckIOProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) { p, err := h.ProcessHost.CreateProcess(ctx, cfg) if err != nil { return nil, err } - sp := &stuckIoProcess{ + sp := &stuckIOProcess{ Process: p, } - sp.pstdin, sp.stdin = io.Pipe() - sp.stdout, sp.pstdout = io.Pipe() - sp.stderr, sp.pstderr = io.Pipe() + + if sp.stdout, sp.pstdout, err = os.Pipe(); err != nil { + return nil, fmt.Errorf("create stdout pipe: %w", err) + } + if sp.stderr, sp.pstderr, err = os.Pipe(); err != nil { + return nil, fmt.Errorf("create stderr pipe: %w", err) + } return sp, nil } -func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) { - return p.stdin, p.stdout, p.stderr +func (p *stuckIOProcess) Stdio() (io.Writer, io.Reader, io.Reader) { + return nil, p.stdout, p.stderr } -func (p *stuckIoProcess) Close() error { - p.stdin.Close() +func (*stuckIOProcess) CloseStdin(context.Context) error { + return nil +} + +func (p *stuckIOProcess) CloseStdout(context.Context) error { + _ = p.pstdout.Close() + return p.stdout.Close() +} + +func (p *stuckIOProcess) CloseStderr(context.Context) error { + _ = p.pstderr.Close() + return p.stderr.Close() +} + +func (p *stuckIOProcess) Close() error { + p.pstdout.Close() + p.pstderr.Close() + p.stdout.Close() p.stderr.Close() + return p.Process.Close() } func TestCmdStuckIo(t *testing.T) { - cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello") + cmd := Command(&stuckIOProcessHost{&localProcessHost{}}, "cmd", "/c", "(exit 0)") cmd.CopyAfterExitTimeout = time.Millisecond * 200 _, err := cmd.Output() - if err != io.ErrClosedPipe { //nolint:errorlint - t.Fatal(err) + if !errors.Is(err, errIOTimeOut) { + t.Fatalf("expected: %v; got: %v", errIOTimeOut, err) } } diff --git a/internal/cmd/io.go b/internal/cmd/io.go index 75ddd1f355..d3663e67c0 100644 --- a/internal/cmd/io.go +++ b/internal/cmd/io.go @@ -4,11 +4,11 @@ package cmd import ( "context" + "fmt" "io" "net/url" "time" - "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -57,7 +57,7 @@ func NewUpstreamIO(ctx context.Context, id, stdout, stderr, stdin string, termin // Create IO for binary logging driver. if u.Scheme != "binary" { - return nil, errors.Errorf("scheme must be 'binary', got: '%s'", u.Scheme) + return nil, fmt.Errorf("scheme must be 'binary', got: '%s'", u.Scheme) } return NewBinaryIO(ctx, id, u)