Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shell command #1434

Merged
merged 2 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 0 additions & 112 deletions builtin/bins/dkron-executor-shell/shell.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
package main

import (
"encoding/base64"
"errors"
"fmt"
"log"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/circbuf"
dkplugin "github.com/distribworks/dkron/v3/plugin"
Expand Down Expand Up @@ -53,111 +46,6 @@ func (s *Shell) Execute(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper)
return resp, nil
}

// ExecuteImpl do execute command
func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) {
output, _ := circbuf.NewBuffer(maxBufSize)

shell, err := strconv.ParseBool(args.Config["shell"])
if err != nil {
shell = false
}
command := args.Config["command"]
env := strings.Split(args.Config["env"], ",")
cwd := args.Config["cwd"]

executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",")
env = append(env, executionInfo...)

cmd, err := buildCmd(command, shell, env, cwd)
if err != nil {
return nil, err
}
err = setCmdAttr(cmd, args.Config)
if err != nil {
return nil, err
}
// use same buffer for both channels, for the full return at the end
cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true}
cmd.Stdout = reportingWriter{buffer: output, cb: cb}

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}

defer stdin.Close()

payload, err := base64.StdEncoding.DecodeString(args.Config["payload"])
if err != nil {
return nil, err
}

stdin.Write(payload)
stdin.Close()

jobTimeout := args.Config["timeout"]
var jt time.Duration

if jobTimeout != "" {
jt, err = time.ParseDuration(jobTimeout)
if err != nil {
return nil, errors.New("shell: Error parsing job timeout")
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

log.Printf("shell: going to run %s", command)

err = cmd.Start()
if err != nil {
return nil, err
}

var jobTimeoutMessage string
var jobTimedOut bool

if jt != 0 {
slowTimer := time.AfterFunc(jt, func() {
// Kill child process to avoid cmd.Wait()
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) // note the minus sign
if err != nil {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt)
} else {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt)
}

jobTimedOut = true
})

defer slowTimer.Stop()
}

quit := make(chan int)

go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit)

err = cmd.Wait()
quit <- cmd.ProcessState.ExitCode()
close(quit) // exit metric refresh goroutine after job is finished

if jobTimedOut {
_, err := output.Write([]byte(jobTimeoutMessage))
if err != nil {
log.Printf("Error writing output on timeout event: %v", err)
}
}

// Warn if buffer is overwritten
if output.TotalWritten() > output.Size() {
log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size())
}

// Always log output
log.Printf("shell: Command output %s", output)

return output.Bytes(), err
}

// Determine the shell invocation based on OS
func buildCmd(command string, useShell bool, env []string, cwd string) (cmd *exec.Cmd, err error) {
var shell, flag string
Expand Down
115 changes: 115 additions & 0 deletions builtin/bins/dkron-executor-shell/shell_unix.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
//go:build !windows
// +build !windows

package main

import (
"encoding/base64"
"errors"
"fmt"
"log"
"os/exec"
"os/user"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/circbuf"
dkplugin "github.com/distribworks/dkron/v3/plugin"
dktypes "github.com/distribworks/dkron/v3/plugin/types"
)

func setCmdAttr(cmd *exec.Cmd, config map[string]string) error {
Expand Down Expand Up @@ -37,3 +47,108 @@ func setCmdAttr(cmd *exec.Cmd, config map[string]string) error {
}
return nil
}

// ExecuteImpl do execute command
func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) {
output, _ := circbuf.NewBuffer(maxBufSize)

shell, err := strconv.ParseBool(args.Config["shell"])
if err != nil {
shell = false
}
command := args.Config["command"]
env := strings.Split(args.Config["env"], ",")
cwd := args.Config["cwd"]

executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",")
env = append(env, executionInfo...)

cmd, err := buildCmd(command, shell, env, cwd)
if err != nil {
return nil, err
}
err = setCmdAttr(cmd, args.Config)
if err != nil {
return nil, err
}
// use same buffer for both channels, for the full return at the end
cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true}
cmd.Stdout = reportingWriter{buffer: output, cb: cb}

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}

defer stdin.Close()

payload, err := base64.StdEncoding.DecodeString(args.Config["payload"])
if err != nil {
return nil, err
}

stdin.Write(payload)
stdin.Close()

jobTimeout := args.Config["timeout"]
var jt time.Duration

if jobTimeout != "" {
jt, err = time.ParseDuration(jobTimeout)
if err != nil {
return nil, errors.New("shell: Error parsing job timeout")
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

log.Printf("shell: going to run %s", command)

err = cmd.Start()
if err != nil {
return nil, err
}

var jobTimeoutMessage string
var jobTimedOut bool

if jt != 0 {
slowTimer := time.AfterFunc(jt, func() {
// Kill child process to avoid cmd.Wait()
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) // note the minus sign
if err != nil {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt)
} else {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt)
}

jobTimedOut = true
})

defer slowTimer.Stop()
}

quit := make(chan int)

go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit)

err = cmd.Wait()
quit <- cmd.ProcessState.ExitCode()
close(quit) // exit metric refresh goroutine after job is finished

if jobTimedOut {
_, err := output.Write([]byte(jobTimeoutMessage))
if err != nil {
log.Printf("Error writing output on timeout event: %v", err)
}
}

// Warn if buffer is overwritten
if output.TotalWritten() > output.Size() {
log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size())
}

// Always log output
log.Printf("shell: Command output %s", output)

return output.Bytes(), err
}
118 changes: 118 additions & 0 deletions builtin/bins/dkron-executor-shell/shell_windows.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,129 @@
//go:build windows
// +build windows

package main

import (
"encoding/base64"
"errors"
"fmt"
"log"
"os/exec"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/circbuf"
dkplugin "github.com/distribworks/dkron/v3/plugin"
dktypes "github.com/distribworks/dkron/v3/plugin/types"
)

func setCmdAttr(cmd *exec.Cmd, config map[string]string) error {
return nil
}

// ExecuteImpl do execute command
func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) {
output, _ := circbuf.NewBuffer(maxBufSize)

shell, err := strconv.ParseBool(args.Config["shell"])
if err != nil {
shell = false
}
command := args.Config["command"]
env := strings.Split(args.Config["env"], ",")
cwd := args.Config["cwd"]

executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",")
env = append(env, executionInfo...)

cmd, err := buildCmd(command, shell, env, cwd)
if err != nil {
return nil, err
}
err = setCmdAttr(cmd, args.Config)
if err != nil {
return nil, err
}
// use same buffer for both channels, for the full return at the end
cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true}
cmd.Stdout = reportingWriter{buffer: output, cb: cb}

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}

defer stdin.Close()

payload, err := base64.StdEncoding.DecodeString(args.Config["payload"])
if err != nil {
return nil, err
}

stdin.Write(payload)
stdin.Close()

jobTimeout := args.Config["timeout"]
var jt time.Duration

if jobTimeout != "" {
jt, err = time.ParseDuration(jobTimeout)
if err != nil {
return nil, errors.New("shell: Error parsing job timeout")
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

log.Printf("shell: going to run %s", command)

err = cmd.Start()
if err != nil {
return nil, err
}

var jobTimeoutMessage string
var jobTimedOut bool

if jt != 0 {
slowTimer := time.AfterFunc(jt, func() {
// Kill child process to avoid cmd.Wait()
err = cmd.Process.Kill()
if err != nil {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt)
} else {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt)
}

jobTimedOut = true
})

defer slowTimer.Stop()
}

quit := make(chan int)

go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit)

err = cmd.Wait()
quit <- cmd.ProcessState.ExitCode()
close(quit) // exit metric refresh goroutine after job is finished

if jobTimedOut {
_, err := output.Write([]byte(jobTimeoutMessage))
if err != nil {
log.Printf("Error writing output on timeout event: %v", err)
}
}

// Warn if buffer is overwritten
if output.TotalWritten() > output.Size() {
log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size())
}

// Always log output
log.Printf("shell: Command output %s", output)

return output.Bytes(), err
}
Loading