Skip to content

Commit

Permalink
Implement cancellation via context
Browse files Browse the repository at this point in the history
ExecuteWithContext allows a user to pass in a context with
a timeout, or a cancellation that can be used to terminate
the process.

Fixes: #10 #9

Signed-off-by: Alex Ellis (OpenFaaS Ltd) <alexellis2@gmail.com>
  • Loading branch information
alexellis committed Apr 17, 2022
1 parent 02af333 commit f1786aa
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 0 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/alexellis/go-execute

go 1.16

require golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
138 changes: 138 additions & 0 deletions pkg/v1/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package execute

import (
"bytes"
"context"
"fmt"
"io"
"log"
"os"
"os/exec"
"strings"

"golang.org/x/sync/errgroup"
)

type ExecTask struct {
Expand Down Expand Up @@ -34,6 +38,140 @@ type ExecResult struct {
ExitCode int
}

func (et ExecTask) ExecuteWithContext(ctx context.Context) (ExecResult, error) {

child, cancel := context.WithCancel(ctx)
defer cancel()

argsSt := ""
if len(et.Args) > 0 {
argsSt = strings.Join(et.Args, " ")
}

if et.PrintCommand {
fmt.Println("exec: ", et.Command, argsSt)
}

var cmd *exec.Cmd

if et.Shell {
var args []string
if len(et.Args) == 0 {
startArgs := strings.Split(et.Command, " ")
script := strings.Join(startArgs, " ")
args = append([]string{"-c"}, fmt.Sprintf("%s", script))

} else {
script := strings.Join(et.Args, " ")
args = append([]string{"-c"}, fmt.Sprintf("%s %s", et.Command, script))

}

cmd = exec.Command("/bin/bash", args...)
} else {
if strings.Index(et.Command, " ") > 0 {
parts := strings.Split(et.Command, " ")
command := parts[0]
args := parts[1:]
cmd = exec.Command(command, args...)

} else {
cmd = exec.Command(et.Command, et.Args...)
}
}

cmd.Dir = et.Cwd

if len(et.Env) > 0 {
overrides := map[string]bool{}
for _, env := range et.Env {
key := strings.Split(env, "=")[0]
overrides[key] = true
cmd.Env = append(cmd.Env, env)
}

for _, env := range os.Environ() {
key := strings.Split(env, "=")[0]

if _, ok := overrides[key]; !ok {
cmd.Env = append(cmd.Env, env)
}
}
}
if et.Stdin != nil {
cmd.Stdin = et.Stdin
}

stdoutBuff := bytes.Buffer{}
stderrBuff := bytes.Buffer{}

var stdoutWriters io.Writer
var stderrWriters io.Writer

if et.StreamStdio {
stdoutWriters = io.MultiWriter(os.Stdout, &stdoutBuff)
stderrWriters = io.MultiWriter(os.Stderr, &stderrBuff)
} else {
stdoutWriters = &stdoutBuff
stderrWriters = &stderrBuff
}

cmd.Stdout = stdoutWriters
cmd.Stderr = stderrWriters

g := errgroup.Group{}

success := false
exitCode := 0
g.Go(func() error {
if err := cmd.Start(); err != nil {
return err
}

if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
}
return err
}

success = true
cancel()
return nil
})

g.Go(func() error {
<-child.Done()

ctxErr := child.Err()

if ctxErr != nil {
if success {
log.Printf("Cancellation expected")
return nil
}
log.Printf("Context is done with error")
if err := cmd.Process.Kill(); err != nil {
log.Printf("failed to kill process: %v", err)
}
exitCode = 1
return ctxErr
}
return nil
})

if err := g.Wait(); err != nil {
log.Printf("First error returned from errgroup: %v", err)
return ExecResult{}, err
}

return ExecResult{
Stdout: string(stdoutBuff.Bytes()),
Stderr: string(stderrBuff.Bytes()),
ExitCode: exitCode,
}, nil
}

func (et ExecTask) Execute() (ExecResult, error) {
argsSt := ""
if len(et.Args) > 0 {
Expand Down
65 changes: 65 additions & 0 deletions pkg/v1/exec_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package execute

import (
"context"
"testing"
"time"
)

func TestExecuteWithContext_SleepInterruptedByTimeout(t *testing.T) {
task := ExecTask{Command: "/bin/sleep 1", Shell: true}
timeout := time.Millisecond * 200
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

start := time.Now()
res, err := task.ExecuteWithContext(ctx)
if err != nil {
t.Errorf(err.Error())
t.Fail()
}

duration := time.Since(start)
if duration > timeout*2 {
t.Fatalf("Cancellation failed, took %s, max timeout was: %s", duration, timeout*2)
}

if len(res.Stdout) != 0 {
t.Errorf("want stdout to be empty, but got: %s", res.Stdout)
t.Fail()
}

if len(res.Stderr) != 0 {
t.Errorf("want empty on stderr, but got: %s", res.Stderr)
t.Fail()
}
}

func TestExecuteWithContext_SleepWithinTimeout(t *testing.T) {
task := ExecTask{Command: "/bin/sleep 0.1", Shell: true}
timeout := time.Millisecond * 500
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

start := time.Now()
res, err := task.ExecuteWithContext(ctx)
if err != nil {
t.Errorf(err.Error())
t.Fail()
}

duration := time.Since(start)
if duration > timeout*2 {
t.Fatalf("Cancellation failed, took %s, max timeout was: %s", duration, timeout*2)
}

if len(res.Stdout) != 0 {
t.Errorf("want stdout to be empty, but got: %s", res.Stdout)
t.Fail()
}

if len(res.Stderr) != 0 {
t.Errorf("want empty on stderr, but got: %s", res.Stderr)
t.Fail()
}
}

0 comments on commit f1786aa

Please sign in to comment.