diff --git a/builtin/bins/dkron-executor-shell/shell.go b/builtin/bins/dkron-executor-shell/shell.go index 09e939f37..5b96aac2d 100644 --- a/builtin/bins/dkron-executor-shell/shell.go +++ b/builtin/bins/dkron-executor-shell/shell.go @@ -1,17 +1,10 @@ package main import ( - "encoding/base64" "errors" - "fmt" - "log" "os" "os/exec" "runtime" - "strconv" - "strings" - "syscall" - "time" "github.com/armon/circbuf" dkplugin "github.com/distribworks/dkron/v3/plugin" @@ -53,111 +46,6 @@ func (s *Shell) Execute(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) return resp, nil } -// ExecuteImpl do execute command -func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) { - output, _ := circbuf.NewBuffer(maxBufSize) - - shell, err := strconv.ParseBool(args.Config["shell"]) - if err != nil { - shell = false - } - command := args.Config["command"] - env := strings.Split(args.Config["env"], ",") - cwd := args.Config["cwd"] - - executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",") - env = append(env, executionInfo...) - - cmd, err := buildCmd(command, shell, env, cwd) - if err != nil { - return nil, err - } - err = setCmdAttr(cmd, args.Config) - if err != nil { - return nil, err - } - // use same buffer for both channels, for the full return at the end - cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true} - cmd.Stdout = reportingWriter{buffer: output, cb: cb} - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, err - } - - defer stdin.Close() - - payload, err := base64.StdEncoding.DecodeString(args.Config["payload"]) - if err != nil { - return nil, err - } - - stdin.Write(payload) - stdin.Close() - - jobTimeout := args.Config["timeout"] - var jt time.Duration - - if jobTimeout != "" { - jt, err = time.ParseDuration(jobTimeout) - if err != nil { - return nil, errors.New("shell: Error parsing job timeout") - } - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - } - - log.Printf("shell: going to run %s", command) - - err = cmd.Start() - if err != nil { - return nil, err - } - - var jobTimeoutMessage string - var jobTimedOut bool - - if jt != 0 { - slowTimer := time.AfterFunc(jt, func() { - // Kill child process to avoid cmd.Wait() - err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) // note the minus sign - if err != nil { - jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt) - } else { - jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt) - } - - jobTimedOut = true - }) - - defer slowTimer.Stop() - } - - quit := make(chan int) - - go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit) - - err = cmd.Wait() - quit <- cmd.ProcessState.ExitCode() - close(quit) // exit metric refresh goroutine after job is finished - - if jobTimedOut { - _, err := output.Write([]byte(jobTimeoutMessage)) - if err != nil { - log.Printf("Error writing output on timeout event: %v", err) - } - } - - // Warn if buffer is overwritten - if output.TotalWritten() > output.Size() { - log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size()) - } - - // Always log output - log.Printf("shell: Command output %s", output) - - return output.Bytes(), err -} - // Determine the shell invocation based on OS func buildCmd(command string, useShell bool, env []string, cwd string) (cmd *exec.Cmd, err error) { var shell, flag string diff --git a/builtin/bins/dkron-executor-shell/shell_unix.go b/builtin/bins/dkron-executor-shell/shell_unix.go index c3731163a..9e87027b3 100644 --- a/builtin/bins/dkron-executor-shell/shell_unix.go +++ b/builtin/bins/dkron-executor-shell/shell_unix.go @@ -1,13 +1,23 @@ +//go:build !windows // +build !windows package main import ( + "encoding/base64" + "errors" + "fmt" + "log" "os/exec" "os/user" "strconv" "strings" "syscall" + "time" + + "github.com/armon/circbuf" + dkplugin "github.com/distribworks/dkron/v3/plugin" + dktypes "github.com/distribworks/dkron/v3/plugin/types" ) func setCmdAttr(cmd *exec.Cmd, config map[string]string) error { @@ -37,3 +47,108 @@ func setCmdAttr(cmd *exec.Cmd, config map[string]string) error { } return nil } + +// ExecuteImpl do execute command +func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) { + output, _ := circbuf.NewBuffer(maxBufSize) + + shell, err := strconv.ParseBool(args.Config["shell"]) + if err != nil { + shell = false + } + command := args.Config["command"] + env := strings.Split(args.Config["env"], ",") + cwd := args.Config["cwd"] + + executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",") + env = append(env, executionInfo...) + + cmd, err := buildCmd(command, shell, env, cwd) + if err != nil { + return nil, err + } + err = setCmdAttr(cmd, args.Config) + if err != nil { + return nil, err + } + // use same buffer for both channels, for the full return at the end + cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true} + cmd.Stdout = reportingWriter{buffer: output, cb: cb} + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + + defer stdin.Close() + + payload, err := base64.StdEncoding.DecodeString(args.Config["payload"]) + if err != nil { + return nil, err + } + + stdin.Write(payload) + stdin.Close() + + jobTimeout := args.Config["timeout"] + var jt time.Duration + + if jobTimeout != "" { + jt, err = time.ParseDuration(jobTimeout) + if err != nil { + return nil, errors.New("shell: Error parsing job timeout") + } + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + } + + log.Printf("shell: going to run %s", command) + + err = cmd.Start() + if err != nil { + return nil, err + } + + var jobTimeoutMessage string + var jobTimedOut bool + + if jt != 0 { + slowTimer := time.AfterFunc(jt, func() { + // Kill child process to avoid cmd.Wait() + err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) // note the minus sign + if err != nil { + jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt) + } else { + jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt) + } + + jobTimedOut = true + }) + + defer slowTimer.Stop() + } + + quit := make(chan int) + + go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit) + + err = cmd.Wait() + quit <- cmd.ProcessState.ExitCode() + close(quit) // exit metric refresh goroutine after job is finished + + if jobTimedOut { + _, err := output.Write([]byte(jobTimeoutMessage)) + if err != nil { + log.Printf("Error writing output on timeout event: %v", err) + } + } + + // Warn if buffer is overwritten + if output.TotalWritten() > output.Size() { + log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size()) + } + + // Always log output + log.Printf("shell: Command output %s", output) + + return output.Bytes(), err +} diff --git a/builtin/bins/dkron-executor-shell/shell_windows.go b/builtin/bins/dkron-executor-shell/shell_windows.go index 417de4030..1c895605a 100644 --- a/builtin/bins/dkron-executor-shell/shell_windows.go +++ b/builtin/bins/dkron-executor-shell/shell_windows.go @@ -1,11 +1,129 @@ +//go:build windows // +build windows package main import ( + "encoding/base64" + "errors" + "fmt" + "log" "os/exec" + "strconv" + "strings" + "syscall" + "time" + + "github.com/armon/circbuf" + dkplugin "github.com/distribworks/dkron/v3/plugin" + dktypes "github.com/distribworks/dkron/v3/plugin/types" ) func setCmdAttr(cmd *exec.Cmd, config map[string]string) error { return nil } + +// ExecuteImpl do execute command +func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) { + output, _ := circbuf.NewBuffer(maxBufSize) + + shell, err := strconv.ParseBool(args.Config["shell"]) + if err != nil { + shell = false + } + command := args.Config["command"] + env := strings.Split(args.Config["env"], ",") + cwd := args.Config["cwd"] + + executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",") + env = append(env, executionInfo...) + + cmd, err := buildCmd(command, shell, env, cwd) + if err != nil { + return nil, err + } + err = setCmdAttr(cmd, args.Config) + if err != nil { + return nil, err + } + // use same buffer for both channels, for the full return at the end + cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true} + cmd.Stdout = reportingWriter{buffer: output, cb: cb} + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + + defer stdin.Close() + + payload, err := base64.StdEncoding.DecodeString(args.Config["payload"]) + if err != nil { + return nil, err + } + + stdin.Write(payload) + stdin.Close() + + jobTimeout := args.Config["timeout"] + var jt time.Duration + + if jobTimeout != "" { + jt, err = time.ParseDuration(jobTimeout) + if err != nil { + return nil, errors.New("shell: Error parsing job timeout") + } + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + } + + log.Printf("shell: going to run %s", command) + + err = cmd.Start() + if err != nil { + return nil, err + } + + var jobTimeoutMessage string + var jobTimedOut bool + + if jt != 0 { + slowTimer := time.AfterFunc(jt, func() { + // Kill child process to avoid cmd.Wait() + err = cmd.Process.Kill() + if err != nil { + jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt) + } else { + jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt) + } + + jobTimedOut = true + }) + + defer slowTimer.Stop() + } + + quit := make(chan int) + + go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit) + + err = cmd.Wait() + quit <- cmd.ProcessState.ExitCode() + close(quit) // exit metric refresh goroutine after job is finished + + if jobTimedOut { + _, err := output.Write([]byte(jobTimeoutMessage)) + if err != nil { + log.Printf("Error writing output on timeout event: %v", err) + } + } + + // Warn if buffer is overwritten + if output.TotalWritten() > output.Size() { + log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size()) + } + + // Always log output + log.Printf("shell: Command output %s", output) + + return output.Bytes(), err +}