Skip to content

Commit

Permalink
fix(emissary): signal SIGINT/SIGTERM in windows correctly (#13693)
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Weibel <michael@helio.exchange>
  • Loading branch information
mweibel authored Oct 9, 2024
1 parent ceaabf1 commit 7a33720
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 5 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ require (
go.opencensus.io v0.24.0 // indirect
go.starlark.net v0.0.0-20230525235612-a134d8f9ddca // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/sys v0.21.0
golang.org/x/term v0.21.0
golang.org/x/text v0.16.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
Expand Down
101 changes: 97 additions & 4 deletions workflow/executor/os-specific/signal_windows.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package os_specific

import (
"fmt"
"os"
"syscall"
"unsafe"

"github.com/argoproj/argo-workflows/v3/util/errors"
"golang.org/x/sys/windows"
)

var (
Term = os.Interrupt

modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procCreateRemoteThread = modkernel32.NewProc("CreateRemoteThread")
procCtrlRoutine = modkernel32.NewProc("CtrlRoutine")
)

func CanIgnoreSignal(s os.Signal) bool {
Expand All @@ -19,11 +26,23 @@ func Kill(pid int, s syscall.Signal) error {
if pid < 0 {
pid = -pid // // we cannot kill a negative process on windows
}
p, err := os.FindProcess(pid)
if err != nil {
return err

winSignal := -1
switch s {
case syscall.SIGTERM:
winSignal = windows.CTRL_SHUTDOWN_EVENT
case syscall.SIGINT:
winSignal = windows.CTRL_C_EVENT
}
return p.Signal(s)

if winSignal == -1 {
p, err := os.FindProcess(pid)
if err != nil {
return err
}
return p.Signal(s)
}
return signalProcess(uint32(pid), winSignal)
}

func Setpgid(a *syscall.SysProcAttr) {
Expand All @@ -37,3 +56,77 @@ func Wait(process *os.Process) error {
}
return err
}

// signalProcess sends the specified signal to a process.
//
// Code +/- copied from: https://github.com/microsoft/hcsshim/blob/1d69a9c658655b77dd4e5275bff99caad6b38416/internal/jobcontainers/process.go#L251
// License: MIT
// Author: Microsoft
func signalProcess(pid uint32, signal int) error {
hProc, err := windows.OpenProcess(windows.PROCESS_TERMINATE, true, pid)
if err != nil {
return fmt.Errorf("failed to open process: %w", err)
}
defer func() {
_ = windows.Close(hProc)
}()

if err := procCtrlRoutine.Find(); err != nil {
return fmt.Errorf("failed to load CtrlRoutine: %w", err)
}

threadHandle, err := createRemoteThread(hProc, nil, 0, procCtrlRoutine.Addr(), uintptr(signal), 0, nil)
if err != nil {
return fmt.Errorf("failed to open remote thread in target process %d: %w", pid, err)
}
defer func() {
_ = windows.Close(windows.Handle(threadHandle))
}()
return nil
}

// Following code has been generated using github.com/Microsoft/go-winio/tools/mkwinsyscall and inlined
// for easier usage

// HANDLE CreateRemoteThread(
//
// HANDLE hProcess,
// LPSECURITY_ATTRIBUTES lpThreadAttributes,
// SIZE_T dwStackSize,
// LPTHREAD_START_ROUTINE lpStartAddress,
// LPVOID lpParameter,
// DWORD dwCreationFlags,
// LPDWORD lpThreadId
//
// );
func createRemoteThread(process windows.Handle, sa *windows.SecurityAttributes, stackSize uint32, startAddr uintptr, parameter uintptr, creationFlags uint32, threadID *uint32) (handle windows.Handle, err error) {
r0, _, e1 := syscall.SyscallN(procCreateRemoteThread.Addr(), uintptr(process), uintptr(unsafe.Pointer(sa)), uintptr(stackSize), uintptr(startAddr), uintptr(parameter), uintptr(creationFlags), uintptr(unsafe.Pointer(threadID)))
handle = windows.Handle(r0)
if handle == 0 {
err = errnoErr(e1)
}
return
}

// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)

var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
return e
}
36 changes: 36 additions & 0 deletions workflow/executor/os-specific/signal_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//go:build windows

package os_specific

import (
"os/exec"
"sync"
"syscall"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestKill(t *testing.T) {
shell := "pwsh.exe"
cmd := exec.Command(shell, "-c", `while(1) { sleep 600000 }`)

_, err := StartCommand(cmd)
require.NoError(t, err)

var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()

err = cmd.Wait()
// we'll get an exit code
assert.Error(t, err)
}()

err = Kill(cmd.Process.Pid, syscall.SIGTERM)
require.NoError(t, err)

wg.Wait()
}

0 comments on commit 7a33720

Please sign in to comment.