diff --git a/cmd.go b/cmd.go index bc08dd0..7893c72 100644 --- a/cmd.go +++ b/cmd.go @@ -1,65 +1,65 @@ -package pty - -import ( - "context" - "os" - "syscall" -) - -// Cmd is a command that can be started attached to a pseudo-terminal. -// This is similar to the API of exec.Cmd. The main difference is that -// the command is started attached to a pseudo-terminal. -// This is required as we cannot use exec.Cmd directly on Windows due to -// limitation of starting a process attached to a pseudo-terminal. -// See: https://github.com/golang/go/issues/62708 -type Cmd struct { - ctx context.Context - pty Pty - sys interface{} - - // Path is the path of the command to run. - Path string - - // Args holds command line arguments, including the command as Args[0]. - Args []string - - // Env specifies the environment of the process. - // If Env is nil, the new process uses the current process's environment. - Env []string - - // Dir specifies the working directory of the command. - // If Dir is the empty string, the current directory is used. - Dir string - - // SysProcAttr holds optional, operating system-specific attributes. - SysProcAttr *syscall.SysProcAttr - - // Process is the underlying process, once started. - Process *os.Process - - // ProcessState contains information about an exited process. - // If the process was started successfully, Wait or Run will populate this - // field when the command completes. - ProcessState *os.ProcessState - - // Cancel is called when the command is canceled. - Cancel func() error -} - -// Start starts the specified command attached to the pseudo-terminal. -func (c *Cmd) Start() error { - return c.start() -} - -// Wait waits for the command to exit. -func (c *Cmd) Wait() error { - return c.wait() -} - -// Run runs the command and waits for it to complete. -func (c *Cmd) Run() error { - if err := c.Start(); err != nil { - return err - } - return c.Wait() -} +package pty + +import ( + "context" + "os" + "syscall" +) + +// Cmd is a command that can be started attached to a pseudo-terminal. +// This is similar to the API of exec.Cmd. The main difference is that +// the command is started attached to a pseudo-terminal. +// This is required as we cannot use exec.Cmd directly on Windows due to +// limitation of starting a process attached to a pseudo-terminal. +// See: https://github.com/golang/go/issues/62708 +type Cmd struct { + ctx context.Context + pty Pty + sys interface{} + + // Path is the path of the command to run. + Path string + + // Args holds command line arguments, including the command as Args[0]. + Args []string + + // Env specifies the environment of the process. + // If Env is nil, the new process uses the current process's environment. + Env []string + + // Dir specifies the working directory of the command. + // If Dir is the empty string, the current directory is used. + Dir string + + // SysProcAttr holds optional, operating system-specific attributes. + SysProcAttr *syscall.SysProcAttr + + // Process is the underlying process, once started. + Process *os.Process + + // ProcessState contains information about an exited process. + // If the process was started successfully, Wait or Run will populate this + // field when the command completes. + ProcessState *os.ProcessState + + // Cancel is called when the command is canceled. + Cancel func() error +} + +// Start starts the specified command attached to the pseudo-terminal. +func (c *Cmd) Start() error { + return c.start() +} + +// Wait waits for the command to exit. +func (c *Cmd) Wait() error { + return c.wait() +} + +// Run runs the command and waits for it to complete. +func (c *Cmd) Run() error { + if err := c.Start(); err != nil { + return err + } + return c.Wait() +} diff --git a/cmd_windows.go b/cmd_windows.go index 97f52ea..e41ac4a 100644 --- a/cmd_windows.go +++ b/cmd_windows.go @@ -1,413 +1,413 @@ -//go:build windows -// +build windows - -package pty - -import ( - "errors" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "syscall" - "unicode/utf16" - "unsafe" - - "golang.org/x/sys/windows" -) - -type conPtySys struct { - attrs *windows.ProcThreadAttributeListContainer - done chan error - cmdErr error -} - -func (c *Cmd) start() error { - pty, ok := c.pty.(*ConPty) - if !ok { - return ErrInvalidCommand - } - - if c.SysProcAttr == nil { - c.SysProcAttr = &syscall.SysProcAttr{} - } - - argv0, err := lookExtensions(c.Path, c.Dir) - if err != nil { - return err - } - if len(c.Dir) != 0 { - // Windows CreateProcess looks for argv0 relative to the current - // directory, and, only once the new process is started, it does - // Chdir(attr.Dir). We are adjusting for that difference here by - // making argv0 absolute. - var err error - argv0, err = joinExeDirAndFName(c.Dir, c.Path) - if err != nil { - return err - } - } - - argv0p, err := windows.UTF16PtrFromString(argv0) - if err != nil { - return err - } - - var cmdline string - if c.SysProcAttr.CmdLine != "" { - cmdline = c.SysProcAttr.CmdLine - } else { - cmdline = windows.ComposeCommandLine(c.Args) - } - argvp, err := windows.UTF16PtrFromString(cmdline) - if err != nil { - return err - } - - var dirp *uint16 - if len(c.Dir) != 0 { - dirp, err = windows.UTF16PtrFromString(c.Dir) - if err != nil { - return err - } - } - - if c.Env == nil { - c.Env, err = execEnvDefault(c.SysProcAttr) - if err != nil { - return err - } - } - - siEx := new(windows.StartupInfoEx) - siEx.Flags = windows.STARTF_USESTDHANDLES - pi := new(windows.ProcessInformation) - - // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field. - flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags - - // Allocate an attribute list that's large enough to do the operations we care about - // 2. Pseudo console setup if one was requested. - // Therefore we need a list of size 3. - attrs, err := windows.NewProcThreadAttributeList(1) - if err != nil { - return fmt.Errorf("failed to initialize process thread attribute list: %w", err) - } - - c.sys = &conPtySys{ - attrs: attrs, - done: make(chan error, 1), - } - - if err := pty.updateProcThreadAttribute(attrs); err != nil { - return err - } - - var zeroSec windows.SecurityAttributes - pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} - if c.SysProcAttr.ProcessAttributes != nil { - pSec = &windows.SecurityAttributes{ - Length: c.SysProcAttr.ProcessAttributes.Length, - InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle, - } - } - tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} - if c.SysProcAttr.ThreadAttributes != nil { - tSec = &windows.SecurityAttributes{ - Length: c.SysProcAttr.ThreadAttributes.Length, - InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle, - } - } - - siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall - siEx.Cb = uint32(unsafe.Sizeof(*siEx)) - if c.SysProcAttr.Token != 0 { - err = windows.CreateProcessAsUser( - windows.Token(c.SysProcAttr.Token), - argv0p, - argvp, - pSec, - tSec, - false, - flags, - createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), - dirp, - &siEx.StartupInfo, - pi, - ) - } else { - err = windows.CreateProcess( - argv0p, - argvp, - pSec, - tSec, - false, - flags, - createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), - dirp, - &siEx.StartupInfo, - pi, - ) - } - if err != nil { - return fmt.Errorf("failed to create process: %w", err) - } - // Don't need the thread handle for anything. - defer func() { - _ = windows.CloseHandle(pi.Thread) - }() - - // Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a - // process has been launched. - c.Process, err = os.FindProcess(int(pi.ProcessId)) - if err != nil { - // If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the - // object. - if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil { - return fmt.Errorf("failed to terminate process after process not found: %w", tErr) - } - return fmt.Errorf("failed to find process after starting: %w", err) - } - - if c.ctx != nil { - go c.waitOnContext() - } - - return nil -} - -func (c *Cmd) waitOnContext() { - sys := c.sys.(*conPtySys) - select { - case <-c.ctx.Done(): - _ = c.Cancel() - sys.cmdErr = c.ctx.Err() - case err := <-sys.done: - sys.cmdErr = err - } -} - -func (c *Cmd) wait() (retErr error) { - if c.Process == nil { - return errNotStarted - } - if c.ProcessState != nil { - return errors.New("process already waited on") - } - defer func() { - sys := c.sys.(*conPtySys) - sys.attrs.Delete() - sys.done <- nil - if retErr == nil { - retErr = sys.cmdErr - } - }() - c.ProcessState, retErr = c.Process.Wait() - if retErr != nil { - return retErr - } - return -} - -// -// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities -// found in the go stdlib. -// - -func lookExtensions(path, dir string) (string, error) { - if filepath.Base(path) == path { - path = filepath.Join(".", path) - } - - if dir == "" { - return exec.LookPath(path) - } - - if filepath.VolumeName(path) != "" { - return exec.LookPath(path) - } - - if len(path) > 1 && os.IsPathSeparator(path[0]) { - return exec.LookPath(path) - } - - dirandpath := filepath.Join(dir, path) - - // We assume that LookPath will only add file extension. - lp, err := exec.LookPath(dirandpath) - if err != nil { - return "", err - } - - ext := strings.TrimPrefix(lp, dirandpath) - - return path + ext, nil -} - -func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) { - if sys == nil || sys.Token == 0 { - return syscall.Environ(), nil - } - - var block *uint16 - err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false) - if err != nil { - return nil, err - } - - defer windows.DestroyEnvironmentBlock(block) - blockp := uintptr(unsafe.Pointer(block)) - - for { - // find NUL terminator - end := unsafe.Pointer(blockp) - for *(*uint16)(end) != 0 { - end = unsafe.Pointer(uintptr(end) + 2) - } - - n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2 - if n == 0 { - // environment block ends with empty string - break - } - - entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n] - env = append(env, string(utf16.Decode(entry))) - blockp += 2 * (uintptr(len(entry)) + 1) - } - return -} - -func isSlash(c uint8) bool { - return c == '\\' || c == '/' -} - -func normalizeDir(dir string) (name string, err error) { - ndir, err := syscall.FullPath(dir) - if err != nil { - return "", err - } - if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { - // dir cannot have \\server\share\path form - return "", syscall.EINVAL - } - return ndir, nil -} - -func volToUpper(ch int) int { - if 'a' <= ch && ch <= 'z' { - ch += 'A' - 'a' - } - return ch -} - -func joinExeDirAndFName(dir, p string) (name string, err error) { - if len(p) == 0 { - return "", syscall.EINVAL - } - if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { - // \\server\share\path form - return p, nil - } - if len(p) > 1 && p[1] == ':' { - // has drive letter - if len(p) == 2 { - return "", syscall.EINVAL - } - if isSlash(p[2]) { - return p, nil - } else { - d, err := normalizeDir(dir) - if err != nil { - return "", err - } - if volToUpper(int(p[0])) == volToUpper(int(d[0])) { - return syscall.FullPath(d + "\\" + p[2:]) - } else { - return syscall.FullPath(p) - } - } - } else { - // no drive letter - d, err := normalizeDir(dir) - if err != nil { - return "", err - } - if isSlash(p[0]) { - return windows.FullPath(d[:2] + p) - } else { - return windows.FullPath(d + "\\" + p) - } - } -} - -// createEnvBlock converts an array of environment strings into -// the representation required by CreateProcess: a sequence of NUL -// terminated strings followed by a nil. -// Last bytes are two UCS-2 NULs, or four NUL bytes. -func createEnvBlock(envv []string) *uint16 { - if len(envv) == 0 { - return &utf16.Encode([]rune("\x00\x00"))[0] - } - length := 0 - for _, s := range envv { - length += len(s) + 1 - } - length++ - - b := make([]byte, length) - i := 0 - for _, s := range envv { - l := len(s) - copy(b[i:i+l], []byte(s)) - copy(b[i+l:i+l+1], []byte{0}) - i = i + l + 1 - } - copy(b[i:i+1], []byte{0}) - - return &utf16.Encode([]rune(string(b)))[0] -} - -// dedupEnvCase is dedupEnv with a case option for testing. -// If caseInsensitive is true, the case of keys is ignored. -func dedupEnvCase(caseInsensitive bool, env []string) []string { - out := make([]string, 0, len(env)) - saw := make(map[string]int, len(env)) // key => index into out - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - out = append(out, kv) - continue - } - k := kv[:eq] - if caseInsensitive { - k = strings.ToLower(k) - } - if dupIdx, isDup := saw[k]; isDup { - out[dupIdx] = kv - continue - } - saw[k] = len(out) - out = append(out, kv) - } - return out -} - -// addCriticalEnv adds any critical environment variables that are required -// (or at least almost always required) on the operating system. -// Currently this is only used for Windows. -func addCriticalEnv(env []string) []string { - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - continue - } - k := kv[:eq] - if strings.EqualFold(k, "SYSTEMROOT") { - // We already have it. - return env - } - } - return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) -} +//go:build windows +// +build windows + +package pty + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +type conPtySys struct { + attrs *windows.ProcThreadAttributeListContainer + done chan error + cmdErr error +} + +func (c *Cmd) start() error { + pty, ok := c.pty.(*ConPty) + if !ok { + return ErrInvalidCommand + } + + if c.SysProcAttr == nil { + c.SysProcAttr = &syscall.SysProcAttr{} + } + + argv0, err := lookExtensions(c.Path, c.Dir) + if err != nil { + return err + } + if len(c.Dir) != 0 { + // Windows CreateProcess looks for argv0 relative to the current + // directory, and, only once the new process is started, it does + // Chdir(attr.Dir). We are adjusting for that difference here by + // making argv0 absolute. + var err error + argv0, err = joinExeDirAndFName(c.Dir, c.Path) + if err != nil { + return err + } + } + + argv0p, err := windows.UTF16PtrFromString(argv0) + if err != nil { + return err + } + + var cmdline string + if c.SysProcAttr.CmdLine != "" { + cmdline = c.SysProcAttr.CmdLine + } else { + cmdline = windows.ComposeCommandLine(c.Args) + } + argvp, err := windows.UTF16PtrFromString(cmdline) + if err != nil { + return err + } + + var dirp *uint16 + if len(c.Dir) != 0 { + dirp, err = windows.UTF16PtrFromString(c.Dir) + if err != nil { + return err + } + } + + if c.Env == nil { + c.Env, err = execEnvDefault(c.SysProcAttr) + if err != nil { + return err + } + } + + siEx := new(windows.StartupInfoEx) + siEx.Flags = windows.STARTF_USESTDHANDLES + pi := new(windows.ProcessInformation) + + // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field. + flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags + + // Allocate an attribute list that's large enough to do the operations we care about + // 2. Pseudo console setup if one was requested. + // Therefore we need a list of size 3. + attrs, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return fmt.Errorf("failed to initialize process thread attribute list: %w", err) + } + + c.sys = &conPtySys{ + attrs: attrs, + done: make(chan error, 1), + } + + if err := pty.updateProcThreadAttribute(attrs); err != nil { + return err + } + + var zeroSec windows.SecurityAttributes + pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if c.SysProcAttr.ProcessAttributes != nil { + pSec = &windows.SecurityAttributes{ + Length: c.SysProcAttr.ProcessAttributes.Length, + InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle, + } + } + tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if c.SysProcAttr.ThreadAttributes != nil { + tSec = &windows.SecurityAttributes{ + Length: c.SysProcAttr.ThreadAttributes.Length, + InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle, + } + } + + siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall + siEx.Cb = uint32(unsafe.Sizeof(*siEx)) + if c.SysProcAttr.Token != 0 { + err = windows.CreateProcessAsUser( + windows.Token(c.SysProcAttr.Token), + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } else { + err = windows.CreateProcess( + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } + if err != nil { + return fmt.Errorf("failed to create process: %w", err) + } + // Don't need the thread handle for anything. + defer func() { + _ = windows.CloseHandle(pi.Thread) + }() + + // Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a + // process has been launched. + c.Process, err = os.FindProcess(int(pi.ProcessId)) + if err != nil { + // If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the + // object. + if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil { + return fmt.Errorf("failed to terminate process after process not found: %w", tErr) + } + return fmt.Errorf("failed to find process after starting: %w", err) + } + + if c.ctx != nil { + go c.waitOnContext() + } + + return nil +} + +func (c *Cmd) waitOnContext() { + sys := c.sys.(*conPtySys) + select { + case <-c.ctx.Done(): + _ = c.Cancel() + sys.cmdErr = c.ctx.Err() + case err := <-sys.done: + sys.cmdErr = err + } +} + +func (c *Cmd) wait() (retErr error) { + if c.Process == nil { + return errNotStarted + } + if c.ProcessState != nil { + return errors.New("process already waited on") + } + defer func() { + sys := c.sys.(*conPtySys) + sys.attrs.Delete() + sys.done <- nil + if retErr == nil { + retErr = sys.cmdErr + } + }() + c.ProcessState, retErr = c.Process.Wait() + if retErr != nil { + return retErr + } + return +} + +// +// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities +// found in the go stdlib. +// + +func lookExtensions(path, dir string) (string, error) { + if filepath.Base(path) == path { + path = filepath.Join(".", path) + } + + if dir == "" { + return exec.LookPath(path) + } + + if filepath.VolumeName(path) != "" { + return exec.LookPath(path) + } + + if len(path) > 1 && os.IsPathSeparator(path[0]) { + return exec.LookPath(path) + } + + dirandpath := filepath.Join(dir, path) + + // We assume that LookPath will only add file extension. + lp, err := exec.LookPath(dirandpath) + if err != nil { + return "", err + } + + ext := strings.TrimPrefix(lp, dirandpath) + + return path + ext, nil +} + +func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) { + if sys == nil || sys.Token == 0 { + return syscall.Environ(), nil + } + + var block *uint16 + err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false) + if err != nil { + return nil, err + } + + defer windows.DestroyEnvironmentBlock(block) + blockp := uintptr(unsafe.Pointer(block)) + + for { + // find NUL terminator + end := unsafe.Pointer(blockp) + for *(*uint16)(end) != 0 { + end = unsafe.Pointer(uintptr(end) + 2) + } + + n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2 + if n == 0 { + // environment block ends with empty string + break + } + + entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n] + env = append(env, string(utf16.Decode(entry))) + blockp += 2 * (uintptr(len(entry)) + 1) + } + return +} + +func isSlash(c uint8) bool { + return c == '\\' || c == '/' +} + +func normalizeDir(dir string) (name string, err error) { + ndir, err := syscall.FullPath(dir) + if err != nil { + return "", err + } + if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { + // dir cannot have \\server\share\path form + return "", syscall.EINVAL + } + return ndir, nil +} + +func volToUpper(ch int) int { + if 'a' <= ch && ch <= 'z' { + ch += 'A' - 'a' + } + return ch +} + +func joinExeDirAndFName(dir, p string) (name string, err error) { + if len(p) == 0 { + return "", syscall.EINVAL + } + if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { + // \\server\share\path form + return p, nil + } + if len(p) > 1 && p[1] == ':' { + // has drive letter + if len(p) == 2 { + return "", syscall.EINVAL + } + if isSlash(p[2]) { + return p, nil + } else { + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if volToUpper(int(p[0])) == volToUpper(int(d[0])) { + return syscall.FullPath(d + "\\" + p[2:]) + } else { + return syscall.FullPath(p) + } + } + } else { + // no drive letter + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if isSlash(p[0]) { + return windows.FullPath(d[:2] + p) + } else { + return windows.FullPath(d + "\\" + p) + } + } +} + +// createEnvBlock converts an array of environment strings into +// the representation required by CreateProcess: a sequence of NUL +// terminated strings followed by a nil. +// Last bytes are two UCS-2 NULs, or four NUL bytes. +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length++ + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/examples/command/main.go b/examples/command/main.go index a3db709..97f07d4 100644 --- a/examples/command/main.go +++ b/examples/command/main.go @@ -1,34 +1,34 @@ -package main - -import ( - "io" - "log" - "os" - - "github.com/aymanbagabas/go-pty" -) - -func main() { - pty, err := pty.New() - if err != nil { - log.Fatalf("failed to open pty: %s", err) - } - - defer pty.Close() - c := pty.Command("grep", "--color=auto", "bar") - if err := c.Start(); err != nil { - log.Fatalf("failed to start: %s", err) - } - - go func() { - pty.Write([]byte("foo\n")) - pty.Write([]byte("bar\n")) - pty.Write([]byte("baz\n")) - pty.Write([]byte{4}) // EOT - }() - go io.Copy(os.Stdout, pty) - - if err := c.Wait(); err != nil { - panic(err) - } -} +package main + +import ( + "io" + "log" + "os" + + "github.com/aymanbagabas/go-pty" +) + +func main() { + pty, err := pty.New() + if err != nil { + log.Fatalf("failed to open pty: %s", err) + } + + defer pty.Close() + c := pty.Command("grep", "--color=auto", "bar") + if err := c.Start(); err != nil { + log.Fatalf("failed to start: %s", err) + } + + go func() { + pty.Write([]byte("foo\n")) + pty.Write([]byte("bar\n")) + pty.Write([]byte("baz\n")) + pty.Write([]byte{4}) // EOT + }() + go io.Copy(os.Stdout, pty) + + if err := c.Wait(); err != nil { + panic(err) + } +} diff --git a/examples/shell/main.go b/examples/shell/main.go index bda0fb1..1b54464 100644 --- a/examples/shell/main.go +++ b/examples/shell/main.go @@ -1,56 +1,56 @@ -package main - -import ( - "io" - "log" - "os" - "os/signal" - - "github.com/aymanbagabas/go-pty" - "golang.org/x/term" -) - -type PTY interface { - Resize(w, h int) error -} - -func test() error { - ptmx, err := pty.New() - if err != nil { - return err - } - - defer ptmx.Close() - - c := ptmx.Command(`bash`) - if err := c.Start(); err != nil { - return err - } - - // Handle pty size. - ch := make(chan os.Signal, 1) - notifySizeChanges(ch) - go handlePtySize(ptmx, ch) - initSizeChange(ch) - defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done. - - // Set stdin in raw mode. - oldState, err := term.MakeRaw(int(os.Stdin.Fd())) - if err != nil { - panic(err) - } - defer func() { _ = term.Restore(int(os.Stdin.Fd()), oldState) }() // Best effort. - - // Copy stdin to the pty and the pty to stdout. - // NOTE: The goroutine will keep reading until the next keystroke before returning. - go io.Copy(ptmx, os.Stdin) - go io.Copy(os.Stdout, ptmx) - - return c.Wait() -} - -func main() { - if err := test(); err != nil { - log.Fatal(err) - } -} +package main + +import ( + "io" + "log" + "os" + "os/signal" + + "github.com/aymanbagabas/go-pty" + "golang.org/x/term" +) + +type PTY interface { + Resize(w, h int) error +} + +func test() error { + ptmx, err := pty.New() + if err != nil { + return err + } + + defer ptmx.Close() + + c := ptmx.Command(`bash`) + if err := c.Start(); err != nil { + return err + } + + // Handle pty size. + ch := make(chan os.Signal, 1) + notifySizeChanges(ch) + go handlePtySize(ptmx, ch) + initSizeChange(ch) + defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done. + + // Set stdin in raw mode. + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + panic(err) + } + defer func() { _ = term.Restore(int(os.Stdin.Fd()), oldState) }() // Best effort. + + // Copy stdin to the pty and the pty to stdout. + // NOTE: The goroutine will keep reading until the next keystroke before returning. + go io.Copy(ptmx, os.Stdin) + go io.Copy(os.Stdout, ptmx) + + return c.Wait() +} + +func main() { + if err := test(); err != nil { + log.Fatal(err) + } +} diff --git a/examples/shell/size_other.go b/examples/shell/size_other.go index 64b811b..b983bcf 100644 --- a/examples/shell/size_other.go +++ b/examples/shell/size_other.go @@ -1,35 +1,35 @@ -//go:build !windows -// +build !windows - -package main - -import ( - "log" - "os" - "os/signal" - "syscall" - - "github.com/aymanbagabas/go-pty" - "golang.org/x/term" -) - -func notifySizeChanges(ch chan os.Signal) { - signal.Notify(ch, syscall.SIGWINCH) -} - -func handlePtySize(p pty.Pty, ch chan os.Signal) { - for range ch { - w, h, err := term.GetSize(int(os.Stdin.Fd())) - if err != nil { - log.Printf("error resizing pty: %s", err) - continue - } - if err := p.Resize(w, h); err != nil { - log.Printf("error resizing pty: %s", err) - } - } -} - -func initSizeChange(ch chan os.Signal) { - ch <- syscall.SIGWINCH -} +//go:build !windows +// +build !windows + +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + + "github.com/aymanbagabas/go-pty" + "golang.org/x/term" +) + +func notifySizeChanges(ch chan os.Signal) { + signal.Notify(ch, syscall.SIGWINCH) +} + +func handlePtySize(p pty.Pty, ch chan os.Signal) { + for range ch { + w, h, err := term.GetSize(int(os.Stdin.Fd())) + if err != nil { + log.Printf("error resizing pty: %s", err) + continue + } + if err := p.Resize(w, h); err != nil { + log.Printf("error resizing pty: %s", err) + } + } +} + +func initSizeChange(ch chan os.Signal) { + ch <- syscall.SIGWINCH +} diff --git a/examples/shell/size_windows.go b/examples/shell/size_windows.go index a2c8bb7..62c5a31 100644 --- a/examples/shell/size_windows.go +++ b/examples/shell/size_windows.go @@ -1,18 +1,18 @@ -//go:build windows -// +build windows - -package main - -import ( - "os" - - "github.com/aymanbagabas/go-pty" -) - -func notifySizeChanges(chan os.Signal) {} - -func handlePtySize(p pty.Pty, _ chan os.Signal) { - // TODO -} - -func initSizeChange(chan os.Signal) {} +//go:build windows +// +build windows + +package main + +import ( + "os" + + "github.com/aymanbagabas/go-pty" +) + +func notifySizeChanges(chan os.Signal) {} + +func handlePtySize(p pty.Pty, _ chan os.Signal) { + // TODO +} + +func initSizeChange(chan os.Signal) {} diff --git a/examples/ssh/main.go b/examples/ssh/main.go index 357bac1..1894e64 100644 --- a/examples/ssh/main.go +++ b/examples/ssh/main.go @@ -1,63 +1,63 @@ -package main - -import ( - "io" - "log" - - "github.com/aymanbagabas/go-pty" - "github.com/charmbracelet/ssh" -) - -func main() { - ssh.Handle(func(s ssh.Session) { - ptyReq, winCh, isPty := s.Pty() - if isPty { - pseudo, err := pty.New() - if err != nil { - log.Println(err) - return - } - - defer pseudo.Close() - w, h := ptyReq.Window.Width, ptyReq.Window.Height - if ptyReq.Modes != nil { - if err := pty.ApplyTerminalModes(int(pseudo.Fd()), w, h, ptyReq.Modes); err != nil { - log.Println(err) - return - } - } - if err := pseudo.Resize(w, h); err != nil { - log.Println(err) - return - } - - cmd := pseudo.Command("bash") - cmd.Env = append(cmd.Env, "TERM="+ptyReq.Term, "SSH_TTY="+pseudo.Name()) - - if err := cmd.Start(); err != nil { - log.Print(err) - return - } - - go func() { - for win := range winCh { - pseudo.Resize(win.Height, win.Width) - } - }() - - go io.Copy(pseudo, s) // stdin - go io.Copy(s, pseudo) // stdout - - if err := cmd.Wait(); err != nil { - log.Print(err) - return - } - } else { - io.WriteString(s, "No PTY requested.\n") - s.Exit(1) - } - }) - - log.Println("starting ssh server on port 2222...") - log.Fatal(ssh.ListenAndServe(":2222", nil)) -} +package main + +import ( + "io" + "log" + + "github.com/aymanbagabas/go-pty" + "github.com/charmbracelet/ssh" +) + +func main() { + ssh.Handle(func(s ssh.Session) { + ptyReq, winCh, isPty := s.Pty() + if isPty { + pseudo, err := pty.New() + if err != nil { + log.Println(err) + return + } + + defer pseudo.Close() + w, h := ptyReq.Window.Width, ptyReq.Window.Height + if ptyReq.Modes != nil { + if err := pty.ApplyTerminalModes(int(pseudo.Fd()), w, h, ptyReq.Modes); err != nil { + log.Println(err) + return + } + } + if err := pseudo.Resize(w, h); err != nil { + log.Println(err) + return + } + + cmd := pseudo.Command("bash") + cmd.Env = append(cmd.Env, "TERM="+ptyReq.Term, "SSH_TTY="+pseudo.Name()) + + if err := cmd.Start(); err != nil { + log.Print(err) + return + } + + go func() { + for win := range winCh { + pseudo.Resize(win.Height, win.Width) + } + }() + + go io.Copy(pseudo, s) // stdin + go io.Copy(s, pseudo) // stdout + + if err := cmd.Wait(); err != nil { + log.Print(err) + return + } + } else { + io.WriteString(s, "No PTY requested.\n") + s.Exit(1) + } + }) + + log.Println("starting ssh server on port 2222...") + log.Fatal(ssh.ListenAndServe(":2222", nil)) +} diff --git a/examples/ssh/modes_other.go b/examples/ssh/modes_other.go index 6ba2668..361a04d 100644 --- a/examples/ssh/modes_other.go +++ b/examples/ssh/modes_other.go @@ -1,126 +1,126 @@ -//go:build !windows -// +build !windows - -package main - -import ( - "fmt" - "log" - - "github.com/u-root/u-root/pkg/termios" - "golang.org/x/crypto/ssh" -) - -// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic -// names used by the termios package. -var terminalModeFlagNames = map[uint8]string{ - ssh.VINTR: "intr", - ssh.VQUIT: "quit", - ssh.VERASE: "erase", - ssh.VKILL: "kill", - ssh.VEOF: "eof", - ssh.VEOL: "eol", - ssh.VEOL2: "eol2", - ssh.VSTART: "start", - ssh.VSTOP: "stop", - ssh.VSUSP: "susp", - ssh.VDSUSP: "dsusp", - ssh.VREPRINT: "rprnt", - ssh.VWERASE: "werase", - ssh.VLNEXT: "lnext", - ssh.VFLUSH: "flush", - ssh.VSWTCH: "swtch", - ssh.VSTATUS: "status", - ssh.VDISCARD: "discard", - ssh.IGNPAR: "ignpar", - ssh.PARMRK: "parmrk", - ssh.INPCK: "inpck", - ssh.ISTRIP: "istrip", - ssh.INLCR: "inlcr", - ssh.IGNCR: "igncr", - ssh.ICRNL: "icrnl", - ssh.IUCLC: "iuclc", - ssh.IXON: "ixon", - ssh.IXANY: "ixany", - ssh.IXOFF: "ixoff", - ssh.IMAXBEL: "imaxbel", - ssh.IUTF8: "iutf8", - ssh.ISIG: "isig", - ssh.ICANON: "icanon", - ssh.XCASE: "xcase", - ssh.ECHO: "echo", - ssh.ECHOE: "echoe", - ssh.ECHOK: "echok", - ssh.ECHONL: "echonl", - ssh.NOFLSH: "noflsh", - ssh.TOSTOP: "tostop", - ssh.IEXTEN: "iexten", - ssh.ECHOCTL: "echoctl", - ssh.ECHOKE: "echoke", - ssh.PENDIN: "pendin", - ssh.OPOST: "opost", - ssh.OLCUC: "olcuc", - ssh.ONLCR: "onlcr", - ssh.OCRNL: "ocrnl", - ssh.ONOCR: "onocr", - ssh.ONLRET: "onlret", - ssh.CS7: "cs7", - ssh.CS8: "cs8", - ssh.PARENB: "parenb", - ssh.PARODD: "parodd", - ssh.TTY_OP_ISPEED: "tty_op_ispeed", - ssh.TTY_OP_OSPEED: "tty_op_ospeed", -} - -func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { - if modes == nil { - modes = ssh.TerminalModes{} - } - - // Get the current TTY configuration. - tios, err := termios.GTTY(int(fd)) - if err != nil { - return fmt.Errorf("GTTY: %w", err) - } - - // Apply the modes from the SSH request. - tios.Row = height - tios.Col = width - - for c, v := range modes { - if c == ssh.TTY_OP_ISPEED { - tios.Ispeed = int(v) - continue - } - if c == ssh.TTY_OP_OSPEED { - tios.Ospeed = int(v) - continue - } - k, ok := terminalModeFlagNames[c] - if !ok { - if logger != nil { - logger.Printf("unknown terminal mode: %d", c) - } - continue - } - if _, ok := tios.CC[k]; ok { - tios.CC[k] = uint8(v) - continue - } - if _, ok := tios.Opts[k]; ok { - tios.Opts[k] = v > 0 - continue - } - - if logger != nil { - logger.Printf("unsupported terminal mode: k=%s, c=%d, v=%d", k, c, v) - } - } - - // Save the new TTY configuration. - if _, err := tios.STTY(int(fd)); err != nil { - return fmt.Errorf("STTY: %w", err) - } - - return nil -} +//go:build !windows +// +build !windows + +package main + +import ( + "fmt" + "log" + + "github.com/u-root/u-root/pkg/termios" + "golang.org/x/crypto/ssh" +) + +// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic +// names used by the termios package. +var terminalModeFlagNames = map[uint8]string{ + ssh.VINTR: "intr", + ssh.VQUIT: "quit", + ssh.VERASE: "erase", + ssh.VKILL: "kill", + ssh.VEOF: "eof", + ssh.VEOL: "eol", + ssh.VEOL2: "eol2", + ssh.VSTART: "start", + ssh.VSTOP: "stop", + ssh.VSUSP: "susp", + ssh.VDSUSP: "dsusp", + ssh.VREPRINT: "rprnt", + ssh.VWERASE: "werase", + ssh.VLNEXT: "lnext", + ssh.VFLUSH: "flush", + ssh.VSWTCH: "swtch", + ssh.VSTATUS: "status", + ssh.VDISCARD: "discard", + ssh.IGNPAR: "ignpar", + ssh.PARMRK: "parmrk", + ssh.INPCK: "inpck", + ssh.ISTRIP: "istrip", + ssh.INLCR: "inlcr", + ssh.IGNCR: "igncr", + ssh.ICRNL: "icrnl", + ssh.IUCLC: "iuclc", + ssh.IXON: "ixon", + ssh.IXANY: "ixany", + ssh.IXOFF: "ixoff", + ssh.IMAXBEL: "imaxbel", + ssh.IUTF8: "iutf8", + ssh.ISIG: "isig", + ssh.ICANON: "icanon", + ssh.XCASE: "xcase", + ssh.ECHO: "echo", + ssh.ECHOE: "echoe", + ssh.ECHOK: "echok", + ssh.ECHONL: "echonl", + ssh.NOFLSH: "noflsh", + ssh.TOSTOP: "tostop", + ssh.IEXTEN: "iexten", + ssh.ECHOCTL: "echoctl", + ssh.ECHOKE: "echoke", + ssh.PENDIN: "pendin", + ssh.OPOST: "opost", + ssh.OLCUC: "olcuc", + ssh.ONLCR: "onlcr", + ssh.OCRNL: "ocrnl", + ssh.ONOCR: "onocr", + ssh.ONLRET: "onlret", + ssh.CS7: "cs7", + ssh.CS8: "cs8", + ssh.PARENB: "parenb", + ssh.PARODD: "parodd", + ssh.TTY_OP_ISPEED: "tty_op_ispeed", + ssh.TTY_OP_OSPEED: "tty_op_ospeed", +} + +func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { + if modes == nil { + modes = ssh.TerminalModes{} + } + + // Get the current TTY configuration. + tios, err := termios.GTTY(int(fd)) + if err != nil { + return fmt.Errorf("GTTY: %w", err) + } + + // Apply the modes from the SSH request. + tios.Row = height + tios.Col = width + + for c, v := range modes { + if c == ssh.TTY_OP_ISPEED { + tios.Ispeed = int(v) + continue + } + if c == ssh.TTY_OP_OSPEED { + tios.Ospeed = int(v) + continue + } + k, ok := terminalModeFlagNames[c] + if !ok { + if logger != nil { + logger.Printf("unknown terminal mode: %d", c) + } + continue + } + if _, ok := tios.CC[k]; ok { + tios.CC[k] = uint8(v) + continue + } + if _, ok := tios.Opts[k]; ok { + tios.Opts[k] = v > 0 + continue + } + + if logger != nil { + logger.Printf("unsupported terminal mode: k=%s, c=%d, v=%d", k, c, v) + } + } + + // Save the new TTY configuration. + if _, err := tios.STTY(int(fd)); err != nil { + return fmt.Errorf("STTY: %w", err) + } + + return nil +} diff --git a/examples/ssh/modes_windows.go b/examples/ssh/modes_windows.go index 5e0fe43..157f81c 100644 --- a/examples/ssh/modes_windows.go +++ b/examples/ssh/modes_windows.go @@ -1,14 +1,14 @@ -//go:build windows -// +build windows - -package main - -import ( - "log" - - "golang.org/x/crypto/ssh" -) - -func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { - return nil -} +//go:build windows +// +build windows + +package main + +import ( + "log" + + "golang.org/x/crypto/ssh" +) + +func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { + return nil +} diff --git a/ssh_other.go b/ssh_other.go index 8defcde..84792dc 100644 --- a/ssh_other.go +++ b/ssh_other.go @@ -1,13 +1,13 @@ -//go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris -// +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris - -package pty - -import ( - "golang.org/x/crypto/ssh" -) - -func applyTerminalModesToFd(fd int, width int, height int, modes ssh.TerminalModes) error { - // TODO - return nil -} +//go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris +// +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris + +package pty + +import ( + "golang.org/x/crypto/ssh" +) + +func applyTerminalModesToFd(fd int, width int, height int, modes ssh.TerminalModes) error { + // TODO + return nil +} diff --git a/ssh_unix.go b/ssh_unix.go index 0905b8e..eb7a3e2 100644 --- a/ssh_unix.go +++ b/ssh_unix.go @@ -1,53 +1,53 @@ -//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris -// +build darwin dragonfly freebsd linux netbsd openbsd solaris - -package pty - -import ( - "fmt" - - "github.com/u-root/u-root/pkg/termios" - "golang.org/x/crypto/ssh" -) - -func applyTerminalModesToFd(fd int, width int, height int, modes ssh.TerminalModes) error { - // Get the current TTY configuration. - tios, err := termios.GTTY(int(fd)) - if err != nil { - return fmt.Errorf("GTTY: %w", err) - } - - // Apply the modes from the SSH request. - tios.Row = height - tios.Col = width - - for c, v := range modes { - if c == ssh.TTY_OP_ISPEED { - tios.Ispeed = int(v) - continue - } - if c == ssh.TTY_OP_OSPEED { - tios.Ospeed = int(v) - continue - } - k, ok := terminalModeFlagNames[c] - if !ok { - continue - } - if _, ok := tios.CC[k]; ok { - tios.CC[k] = uint8(v) - continue - } - if _, ok := tios.Opts[k]; ok { - tios.Opts[k] = v > 0 - continue - } - } - - // Save the new TTY configuration. - if _, err := tios.STTY(int(fd)); err != nil { - return fmt.Errorf("STTY: %w", err) - } - - return nil -} +//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris +// +build darwin dragonfly freebsd linux netbsd openbsd solaris + +package pty + +import ( + "fmt" + + "github.com/u-root/u-root/pkg/termios" + "golang.org/x/crypto/ssh" +) + +func applyTerminalModesToFd(fd int, width int, height int, modes ssh.TerminalModes) error { + // Get the current TTY configuration. + tios, err := termios.GTTY(int(fd)) + if err != nil { + return fmt.Errorf("GTTY: %w", err) + } + + // Apply the modes from the SSH request. + tios.Row = height + tios.Col = width + + for c, v := range modes { + if c == ssh.TTY_OP_ISPEED { + tios.Ispeed = int(v) + continue + } + if c == ssh.TTY_OP_OSPEED { + tios.Ospeed = int(v) + continue + } + k, ok := terminalModeFlagNames[c] + if !ok { + continue + } + if _, ok := tios.CC[k]; ok { + tios.CC[k] = uint8(v) + continue + } + if _, ok := tios.Opts[k]; ok { + tios.Opts[k] = v > 0 + continue + } + } + + // Save the new TTY configuration. + if _, err := tios.STTY(int(fd)); err != nil { + return fmt.Errorf("STTY: %w", err) + } + + return nil +}