Skip to content

Commit

Permalink
command: fix signalling to command process group (#637)
Browse files Browse the repository at this point in the history
This change also fixes a regression that broken running interactive
commands, like runme in runme, run from the CLI in
da2c38c.

It also disables FIFO which should mitigate
#635 but a proper fix is needed.

Relates to #636
  • Loading branch information
adambabik authored Jul 29, 2024
1 parent 9658f77 commit 19af228
Show file tree
Hide file tree
Showing 20 changed files with 223 additions and 1,713 deletions.
2 changes: 1 addition & 1 deletion experimental/runme.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ server:

log:
enabled: true
path: "/var/log/runme.log"
path: "/tmp/runme.log"
verbose: true
4 changes: 4 additions & 0 deletions internal/cmd/beta/run_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/stateful/runme/v3/internal/command"
"github.com/stateful/runme/v3/internal/config/autoconfig"
runnerv2alpha1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2alpha1"
"github.com/stateful/runme/v3/pkg/document"
"github.com/stateful/runme/v3/pkg/project"
)
Expand Down Expand Up @@ -115,6 +116,9 @@ func runCodeBlock(
if err != nil {
return err
}

cfg.Mode = runnerv2alpha1.CommandMode_COMMAND_MODE_CLI

cmd, err := factory.Build(cfg, options)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions internal/command/command_inline_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ func (c *inlineShellCommand) Wait() error {
err := c.internalCommand.Wait()

if c.envCollector != nil {
if cErr := c.collectEnv(); cErr != nil {
c.logger.Info("failed to collect the environment", zap.Error(cErr))
if err == nil {
err = cErr
}
c.logger.Info("collecting the environment after the script execution")
cErr := c.collectEnv()
c.logger.Info("collected the environment after the script execution", zap.Error(cErr))
if cErr != nil && err == nil {
err = cErr
}
}

Expand Down
48 changes: 20 additions & 28 deletions internal/command/command_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@ import (
"go.uber.org/zap"
)

// SignalToProcessGroup is used in tests to disable sending signals to a process group.
var SignalToProcessGroup = true

type nativeCommand struct {
*base

logger *zap.Logger
disableNewProcessID bool
logger *zap.Logger

cmd *exec.Cmd
}
Expand All @@ -34,21 +32,14 @@ func (c *nativeCommand) Pid() int {
func (c *nativeCommand) Start(ctx context.Context) (err error) {
stdin := c.Stdin()

// TODO(adamb): include explanation why it is needed.
if f, ok := stdin.(*os.File); ok && f != nil {
// Duplicate /dev/stdin.
newStdinFd, err := dup(f.Fd())
if err != nil {
return errors.Wrap(err, "failed to dup stdin")
}
closeOnExec(newStdinFd)

// Setting stdin to the non-block mode fails on the simple "read" command.
// On the other hand, it allows to use SetReadDeadline().
// It turned out it's not needed, but keeping the code here for now.
// if err := syscall.SetNonblock(newStdinFd, true); err != nil {
// return nil, errors.Wrap(err, "failed to set new stdin fd in non-blocking mode")
// }

stdin = os.NewFile(uintptr(newStdinFd), "")
}

Expand All @@ -69,46 +60,47 @@ func (c *nativeCommand) Start(ctx context.Context) (err error) {
c.cmd.Stdout = c.Stdout()
c.cmd.Stderr = c.Stderr()

// Set the process group ID of the program.
// It is helpful to stop the program and its
// children.
// Note that Setsid set in setSysProcAttrCtty()
// already starts a new process group.
// Warning: it does not work with interactive programs
// like "python", hence, it's commented out.
// setSysProcAttrPgid(c.cmd)
if !c.disableNewProcessID {
// Creating a new process group is required to properly replicate a behaviour
// similar to CTRL-C in the terminal, which sends a SIGINT to the whole group.
setSysProcAttrPgid(c.cmd)
}

c.logger.Info("starting a native command", zap.Any("config", redactConfig(c.ProgramConfig())))
c.logger.Info("starting", zap.Any("config", redactConfig(c.ProgramConfig())))
if err := c.cmd.Start(); err != nil {
return errors.WithStack(err)
}
c.logger.Info("a native command started")
c.logger.Info("started")

return nil
}

func (c *nativeCommand) Signal(sig os.Signal) error {
c.logger.Info("stopping the native command with a signal", zap.Stringer("signal", sig))
c.logger.Info("stopping with signal", zap.Stringer("signal", sig))

if SignalToProcessGroup {
if !c.disableNewProcessID {
c.logger.Info("signaling to the process group", zap.Stringer("signal", sig))
// Try to terminate the whole process group. If it fails, fall back to stdlib methods.
err := signalPgid(c.cmd.Process.Pid, sig)
if err == nil {
return nil
}
c.logger.Info("failed to terminate process group; trying Process.Signal()", zap.Error(err))
c.logger.Info("failed to signal the process group; trying regular signaling", zap.Error(err))
}

if err := c.cmd.Process.Signal(sig); err != nil {
c.logger.Info("failed to signal process; trying Process.Kill()", zap.Error(err))
if sig == os.Kill {
return errors.WithStack(err)
}
c.logger.Info("failed to signal the process; trying kill signal", zap.Error(err))
return errors.WithStack(c.cmd.Process.Kill())
}

return nil
}

func (c *nativeCommand) Wait() (err error) {
c.logger.Info("waiting for the native command to finish")
c.logger.Info("waiting for finish")

var stderr []byte
err = errors.WithStack(c.cmd.Wait())
Expand All @@ -119,7 +111,7 @@ func (c *nativeCommand) Wait() (err error) {
}
}

c.logger.Info("the native command finished", zap.Error(err), zap.ByteString("stderr", stderr))
c.logger.Info("finished", zap.Error(err), zap.ByteString("stderr", stderr))

return
}
12 changes: 10 additions & 2 deletions internal/command/command_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ import (
"golang.org/x/sys/unix"
)

func setSysProcAttrCtty(cmd *exec.Cmd) {
func setSysProcAttrCtty(cmd *exec.Cmd, tty int) {
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
cmd.SysProcAttr.Setsid = true
cmd.SysProcAttr.Ctty = tty
cmd.SysProcAttr.Setctty = true
cmd.SysProcAttr.Setsid = true
}

func setSysProcAttrPgid(cmd *exec.Cmd) {
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
cmd.SysProcAttr.Setpgid = true
}

func disableEcho(fd uintptr) error {
Expand Down
6 changes: 0 additions & 6 deletions internal/command/command_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ import (
"github.com/stateful/runme/v3/pkg/document/identity"
)

func init() {
// Set to false to disable sending signals to process groups in tests.
// This can be turned on if setSysProcAttrPgid() is called in Start().
SignalToProcessGroup = false
}

func TestCommand(t *testing.T) {
testCases := []struct {
name string
Expand Down
59 changes: 34 additions & 25 deletions internal/command/command_virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
}

if !c.isEchoEnabled {
c.logger.Info("disabling echo")
if err := disableEcho(c.tty.Fd()); err != nil {
return err
}
Expand All @@ -69,6 +70,7 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
if err != nil {
return err
}
c.logger.Info("detected program path and arguments", zap.String("program", program), zap.Strings("args", args))

c.cmd = exec.CommandContext(
ctx,
Expand All @@ -81,22 +83,28 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
c.cmd.Stdout = c.tty
c.cmd.Stderr = c.tty

setSysProcAttrCtty(c.cmd)
// Create a new session and set the controlling terminal to tty.
// The new process group is created automatically so that sending
// a signal to the command will affect the whole group.
// 3 is because stdin, stdout, stderr + i-th element in ExtraFiles.
setSysProcAttrCtty(c.cmd, 3)
c.cmd.ExtraFiles = []*os.File{c.tty}

c.logger.Info("starting a virtual command", zap.Any("config", redactConfig(c.ProgramConfig())))
c.logger.Info("starting", zap.Any("config", redactConfig(c.ProgramConfig())))
if err := c.cmd.Start(); err != nil {
return errors.WithStack(err)
}
c.logger.Info("started")

if !isNil(c.stdin) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
n, err := io.Copy(c.pty, c.stdin)
c.logger.Info("finished copying from stdin to pty", zap.Error(err), zap.Int64("count", n))
if err != nil {
c.setErr(errors.WithStack(err))
}
c.logger.Info("copied from stdin to pty", zap.Error(err), zap.Int64("count", n))
}()
}

Expand All @@ -112,54 +120,59 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
// a master pseudo-terminal which no longer has an open slave.
// See https://github.com/creack/pty/issues/21.
if errors.Is(err, syscall.EIO) {
c.logger.Debug("failed to copy from pty to stdout; handled EIO")
c.logger.Info("failed to copy from pty to stdout; handled EIO")
return
}
if errors.Is(err, os.ErrClosed) {
c.logger.Debug("failed to copy from pty to stdout; handled ErrClosed")
c.logger.Info("failed to copy from pty to stdout; handled ErrClosed")
return
}

c.logger.Info("failed to copy from pty to stdout", zap.Error(err))

c.setErr(errors.WithStack(err))
} else {
c.logger.Debug("finished copying from pty to stdout", zap.Int64("count", n))
}

c.logger.Info("copied from pty to stdout", zap.Int64("count", n))
}()
}

c.logger.Info("a virtual command started")

return nil
}

func (c *virtualCommand) Signal(sig os.Signal) error {
c.logger.Info("stopping the virtual command with signal", zap.String("signal", sig.String()))
c.logger.Info("stopping with signal", zap.String("signal", sig.String()))

// Try to terminate the whole process group. If it fails, fall back to stdlib methods.
if err := signalPgid(c.cmd.Process.Pid, sig); err != nil {
c.logger.Info("failed to terminate process group; trying Process.Signal()", zap.Error(err))
if err := c.cmd.Process.Signal(sig); err != nil {
c.logger.Info("failed to signal process; trying Process.Kill()", zap.Error(err))
return errors.WithStack(c.cmd.Process.Kill())
err := signalPgid(c.cmd.Process.Pid, sig)
if err == nil {
return nil
}

c.logger.Info("failed to signal the process group; trying regular signaling", zap.Error(err))

if err := c.cmd.Process.Signal(sig); err != nil {
if sig == os.Kill {
return errors.WithStack(err)
}
c.logger.Info("failed to signal the process; trying kill signal", zap.Error(err))
return errors.WithStack(c.cmd.Process.Kill())
}

return nil
}

func (c *virtualCommand) Wait() (err error) {
c.logger.Info("waiting for the virtual command to finish")
c.logger.Info("waiting for finish")
err = errors.WithStack(c.cmd.Wait())
c.logger.Info("the virtual command finished", zap.Error(err))
c.logger.Info("finished", zap.Error(err))

errIO := c.closeIO()
c.logger.Info("closed IO of the virtual command", zap.Error(errIO))
c.logger.Info("closed IO", zap.Error(errIO))
if err == nil && errIO != nil {
err = errIO
}

c.logger.Info("waiting IO goroutines")
c.wg.Wait()
c.logger.Info("finished waiting for IO goroutines")

c.mu.Lock()
if err == nil && c.err != nil {
Expand Down Expand Up @@ -192,10 +205,6 @@ func (c *virtualCommand) closeIO() (err error) {
err = multierr.Append(err, errors.WithMessage(errClose, "failed to close tty"))
}

// if err := c.pty.Close(); err != nil {
// return errors.WithMessage(err, "failed to close pty")
// }

return
}

Expand Down
2 changes: 1 addition & 1 deletion internal/command/command_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/pkg/errors"
)

func setSysProcAttrCtty(cmd *exec.Cmd) {}
func setSysProcAttrCtty(cmd *exec.Cmd, tty int) {}

func setSysProcAttrPgid(cmd *exec.Cmd) {}

Expand Down
Loading

0 comments on commit 19af228

Please sign in to comment.