From 2f46f5126f8ca1a1abf25d80f3533a2fa70a915c Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 14 Sep 2023 15:08:58 -0400 Subject: [PATCH] feat: initial implementation Support Windows ConPty and *nix through creack/pty --- cmd.go | 65 +++ cmd_other.go | 45 ++ cmd_windows.go | 413 ++++++++++++++ examples/command/main.go | 34 ++ examples/go.mod | 21 + examples/go.sum | 25 + examples/shell/main.go | 56 ++ examples/shell/size_other.go | 35 ++ examples/shell/size_windows.go | 18 + examples/ssh/main.go | 63 +++ .../ssh/modes_other.go | 251 ++++----- .../ssh/modes_windows.go | 21 +- go.mod | 21 +- go.sum | 39 +- pty.go | 143 +---- pty_linux.go | 24 - pty_other.go | 236 +++----- pty_windows.go | 312 ++++------- ptytest/ptytest.go | 524 ------------------ ptytest/ptytest_internal_test.go | 37 -- ptytest/ptytest_test.go | 77 --- ssh.go | 75 +++ ssh/ssh.go | 16 - ssh/ssh_test.go | 45 -- ssh_other.go | 52 ++ ssh_windows.go | 13 + start.go | 73 --- start_other.go | 83 --- start_other_test.go | 75 --- start_test.go | 176 ------ start_windows.go | 204 ------- start_windows_test.go | 75 --- zsyscall_windows.go | 75 +++ 33 files changed, 1357 insertions(+), 2065 deletions(-) create mode 100644 cmd.go create mode 100644 cmd_other.go create mode 100644 cmd_windows.go create mode 100644 examples/command/main.go create mode 100644 examples/go.mod create mode 100644 examples/go.sum create mode 100644 examples/shell/main.go create mode 100644 examples/shell/size_other.go create mode 100644 examples/shell/size_windows.go create mode 100644 examples/ssh/main.go rename ssh/ssh_other.go => examples/ssh/modes_other.go (92%) rename ssh/ssh_windows.go => examples/ssh/modes_windows.go (57%) delete mode 100644 pty_linux.go delete mode 100644 ptytest/ptytest.go delete mode 100644 ptytest/ptytest_internal_test.go delete mode 100644 ptytest/ptytest_test.go create mode 100644 ssh.go delete mode 100644 ssh/ssh.go delete mode 100644 ssh/ssh_test.go create mode 100644 ssh_other.go create mode 100644 ssh_windows.go delete mode 100644 start.go delete mode 100644 start_other.go delete mode 100644 start_other_test.go delete mode 100644 start_test.go delete mode 100644 start_windows.go delete mode 100644 start_windows_test.go create mode 100644 zsyscall_windows.go diff --git a/cmd.go b/cmd.go new file mode 100644 index 0000000..bc08dd0 --- /dev/null +++ b/cmd.go @@ -0,0 +1,65 @@ +package pty + +import ( + "context" + "os" + "syscall" +) + +// Cmd is a command that can be started attached to a pseudo-terminal. +// This is similar to the API of exec.Cmd. The main difference is that +// the command is started attached to a pseudo-terminal. +// This is required as we cannot use exec.Cmd directly on Windows due to +// limitation of starting a process attached to a pseudo-terminal. +// See: https://github.com/golang/go/issues/62708 +type Cmd struct { + ctx context.Context + pty Pty + sys interface{} + + // Path is the path of the command to run. + Path string + + // Args holds command line arguments, including the command as Args[0]. + Args []string + + // Env specifies the environment of the process. + // If Env is nil, the new process uses the current process's environment. + Env []string + + // Dir specifies the working directory of the command. + // If Dir is the empty string, the current directory is used. + Dir string + + // SysProcAttr holds optional, operating system-specific attributes. + SysProcAttr *syscall.SysProcAttr + + // Process is the underlying process, once started. + Process *os.Process + + // ProcessState contains information about an exited process. + // If the process was started successfully, Wait or Run will populate this + // field when the command completes. + ProcessState *os.ProcessState + + // Cancel is called when the command is canceled. + Cancel func() error +} + +// Start starts the specified command attached to the pseudo-terminal. +func (c *Cmd) Start() error { + return c.start() +} + +// Wait waits for the command to exit. +func (c *Cmd) Wait() error { + return c.wait() +} + +// Run runs the command and waits for it to complete. +func (c *Cmd) Run() error { + if err := c.Start(); err != nil { + return err + } + return c.Wait() +} diff --git a/cmd_other.go b/cmd_other.go new file mode 100644 index 0000000..638fe64 --- /dev/null +++ b/cmd_other.go @@ -0,0 +1,45 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "os/exec" + + "golang.org/x/sys/unix" +) + +func (c *Cmd) start() error { + cmd, ok := c.sys.(*exec.Cmd) + if !ok { + return ErrInvalidCommand + } + pty, ok := c.pty.(*PosixPty) + if !ok { + return ErrInvalidCommand + } + + cmd.Stdin = pty.slave + cmd.Stdout = pty.slave + cmd.Stderr = pty.slave + cmd.SysProcAttr = &unix.SysProcAttr{ + Setsid: true, + Setctty: true, + } + if err := cmd.Start(); err != nil { + return err + } + + c.Process = cmd.Process + return nil +} + +func (c *Cmd) wait() error { + cmd, ok := c.sys.(*exec.Cmd) + if !ok { + return ErrInvalidCommand + } + err := cmd.Wait() + c.ProcessState = cmd.ProcessState + return err +} diff --git a/cmd_windows.go b/cmd_windows.go new file mode 100644 index 0000000..97f52ea --- /dev/null +++ b/cmd_windows.go @@ -0,0 +1,413 @@ +//go:build windows +// +build windows + +package pty + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +type conPtySys struct { + attrs *windows.ProcThreadAttributeListContainer + done chan error + cmdErr error +} + +func (c *Cmd) start() error { + pty, ok := c.pty.(*ConPty) + if !ok { + return ErrInvalidCommand + } + + if c.SysProcAttr == nil { + c.SysProcAttr = &syscall.SysProcAttr{} + } + + argv0, err := lookExtensions(c.Path, c.Dir) + if err != nil { + return err + } + if len(c.Dir) != 0 { + // Windows CreateProcess looks for argv0 relative to the current + // directory, and, only once the new process is started, it does + // Chdir(attr.Dir). We are adjusting for that difference here by + // making argv0 absolute. + var err error + argv0, err = joinExeDirAndFName(c.Dir, c.Path) + if err != nil { + return err + } + } + + argv0p, err := windows.UTF16PtrFromString(argv0) + if err != nil { + return err + } + + var cmdline string + if c.SysProcAttr.CmdLine != "" { + cmdline = c.SysProcAttr.CmdLine + } else { + cmdline = windows.ComposeCommandLine(c.Args) + } + argvp, err := windows.UTF16PtrFromString(cmdline) + if err != nil { + return err + } + + var dirp *uint16 + if len(c.Dir) != 0 { + dirp, err = windows.UTF16PtrFromString(c.Dir) + if err != nil { + return err + } + } + + if c.Env == nil { + c.Env, err = execEnvDefault(c.SysProcAttr) + if err != nil { + return err + } + } + + siEx := new(windows.StartupInfoEx) + siEx.Flags = windows.STARTF_USESTDHANDLES + pi := new(windows.ProcessInformation) + + // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field. + flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags + + // Allocate an attribute list that's large enough to do the operations we care about + // 2. Pseudo console setup if one was requested. + // Therefore we need a list of size 3. + attrs, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return fmt.Errorf("failed to initialize process thread attribute list: %w", err) + } + + c.sys = &conPtySys{ + attrs: attrs, + done: make(chan error, 1), + } + + if err := pty.updateProcThreadAttribute(attrs); err != nil { + return err + } + + var zeroSec windows.SecurityAttributes + pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if c.SysProcAttr.ProcessAttributes != nil { + pSec = &windows.SecurityAttributes{ + Length: c.SysProcAttr.ProcessAttributes.Length, + InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle, + } + } + tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if c.SysProcAttr.ThreadAttributes != nil { + tSec = &windows.SecurityAttributes{ + Length: c.SysProcAttr.ThreadAttributes.Length, + InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle, + } + } + + siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall + siEx.Cb = uint32(unsafe.Sizeof(*siEx)) + if c.SysProcAttr.Token != 0 { + err = windows.CreateProcessAsUser( + windows.Token(c.SysProcAttr.Token), + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } else { + err = windows.CreateProcess( + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } + if err != nil { + return fmt.Errorf("failed to create process: %w", err) + } + // Don't need the thread handle for anything. + defer func() { + _ = windows.CloseHandle(pi.Thread) + }() + + // Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a + // process has been launched. + c.Process, err = os.FindProcess(int(pi.ProcessId)) + if err != nil { + // If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the + // object. + if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil { + return fmt.Errorf("failed to terminate process after process not found: %w", tErr) + } + return fmt.Errorf("failed to find process after starting: %w", err) + } + + if c.ctx != nil { + go c.waitOnContext() + } + + return nil +} + +func (c *Cmd) waitOnContext() { + sys := c.sys.(*conPtySys) + select { + case <-c.ctx.Done(): + _ = c.Cancel() + sys.cmdErr = c.ctx.Err() + case err := <-sys.done: + sys.cmdErr = err + } +} + +func (c *Cmd) wait() (retErr error) { + if c.Process == nil { + return errNotStarted + } + if c.ProcessState != nil { + return errors.New("process already waited on") + } + defer func() { + sys := c.sys.(*conPtySys) + sys.attrs.Delete() + sys.done <- nil + if retErr == nil { + retErr = sys.cmdErr + } + }() + c.ProcessState, retErr = c.Process.Wait() + if retErr != nil { + return retErr + } + return +} + +// +// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities +// found in the go stdlib. +// + +func lookExtensions(path, dir string) (string, error) { + if filepath.Base(path) == path { + path = filepath.Join(".", path) + } + + if dir == "" { + return exec.LookPath(path) + } + + if filepath.VolumeName(path) != "" { + return exec.LookPath(path) + } + + if len(path) > 1 && os.IsPathSeparator(path[0]) { + return exec.LookPath(path) + } + + dirandpath := filepath.Join(dir, path) + + // We assume that LookPath will only add file extension. + lp, err := exec.LookPath(dirandpath) + if err != nil { + return "", err + } + + ext := strings.TrimPrefix(lp, dirandpath) + + return path + ext, nil +} + +func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) { + if sys == nil || sys.Token == 0 { + return syscall.Environ(), nil + } + + var block *uint16 + err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false) + if err != nil { + return nil, err + } + + defer windows.DestroyEnvironmentBlock(block) + blockp := uintptr(unsafe.Pointer(block)) + + for { + // find NUL terminator + end := unsafe.Pointer(blockp) + for *(*uint16)(end) != 0 { + end = unsafe.Pointer(uintptr(end) + 2) + } + + n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2 + if n == 0 { + // environment block ends with empty string + break + } + + entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n] + env = append(env, string(utf16.Decode(entry))) + blockp += 2 * (uintptr(len(entry)) + 1) + } + return +} + +func isSlash(c uint8) bool { + return c == '\\' || c == '/' +} + +func normalizeDir(dir string) (name string, err error) { + ndir, err := syscall.FullPath(dir) + if err != nil { + return "", err + } + if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { + // dir cannot have \\server\share\path form + return "", syscall.EINVAL + } + return ndir, nil +} + +func volToUpper(ch int) int { + if 'a' <= ch && ch <= 'z' { + ch += 'A' - 'a' + } + return ch +} + +func joinExeDirAndFName(dir, p string) (name string, err error) { + if len(p) == 0 { + return "", syscall.EINVAL + } + if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { + // \\server\share\path form + return p, nil + } + if len(p) > 1 && p[1] == ':' { + // has drive letter + if len(p) == 2 { + return "", syscall.EINVAL + } + if isSlash(p[2]) { + return p, nil + } else { + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if volToUpper(int(p[0])) == volToUpper(int(d[0])) { + return syscall.FullPath(d + "\\" + p[2:]) + } else { + return syscall.FullPath(p) + } + } + } else { + // no drive letter + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if isSlash(p[0]) { + return windows.FullPath(d[:2] + p) + } else { + return windows.FullPath(d + "\\" + p) + } + } +} + +// createEnvBlock converts an array of environment strings into +// the representation required by CreateProcess: a sequence of NUL +// terminated strings followed by a nil. +// Last bytes are two UCS-2 NULs, or four NUL bytes. +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length++ + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/examples/command/main.go b/examples/command/main.go new file mode 100644 index 0000000..a3db709 --- /dev/null +++ b/examples/command/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "io" + "log" + "os" + + "github.com/aymanbagabas/go-pty" +) + +func main() { + pty, err := pty.New() + if err != nil { + log.Fatalf("failed to open pty: %s", err) + } + + defer pty.Close() + c := pty.Command("grep", "--color=auto", "bar") + if err := c.Start(); err != nil { + log.Fatalf("failed to start: %s", err) + } + + go func() { + pty.Write([]byte("foo\n")) + pty.Write([]byte("bar\n")) + pty.Write([]byte("baz\n")) + pty.Write([]byte{4}) // EOT + }() + go io.Copy(os.Stdout, pty) + + if err := c.Wait(); err != nil { + panic(err) + } +} diff --git a/examples/go.mod b/examples/go.mod new file mode 100644 index 0000000..5627785 --- /dev/null +++ b/examples/go.mod @@ -0,0 +1,21 @@ +module examples + +go 1.20 + +replace github.com/aymanbagabas/go-pty => ../ + +replace github.com/creack/pty => github.com/aymanbagabas/pty v1.1.19-0.20230922024246-7bc6991e768a + +require ( + github.com/aymanbagabas/go-pty v0.0.0-00010101000000-000000000000 + github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09 + github.com/u-root/u-root v0.11.0 + golang.org/x/crypto v0.13.0 + golang.org/x/term v0.12.0 +) + +require ( + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/creack/pty v1.1.15 // indirect + golang.org/x/sys v0.12.0 // indirect +) diff --git a/examples/go.sum b/examples/go.sum new file mode 100644 index 0000000..ba8ec3c --- /dev/null +++ b/examples/go.sum @@ -0,0 +1,25 @@ +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/aymanbagabas/pty v1.1.19-0.20230922024246-7bc6991e768a h1:E1T8MDAqrtsQHu5WVPiAK3zoe1haQWe+SsSkAD0/SIw= +github.com/aymanbagabas/pty v1.1.19-0.20230922024246-7bc6991e768a/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09 h1:ZDIQmTtohv0S/AAYE//w8mYTxCzqphhF1+4ACPDMiLU= +github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg= +github.com/u-root/gobusybox/src v0.0.0-20221229083637-46b2883a7f90 h1:zTk5683I9K62wtZ6eUa6vu6IWwVHXPnoKK5n2unAwv0= +github.com/u-root/u-root v0.11.0 h1:6gCZLOeRyevw7gbTwMj3fKxnr9+yHFlgF3N7udUVNO8= +github.com/u-root/u-root v0.11.0/go.mod h1:DBkDtiZyONk9hzVEdB/PWI9B4TxDkElWlVTHseglrZY= +golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/examples/shell/main.go b/examples/shell/main.go new file mode 100644 index 0000000..bda0fb1 --- /dev/null +++ b/examples/shell/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "io" + "log" + "os" + "os/signal" + + "github.com/aymanbagabas/go-pty" + "golang.org/x/term" +) + +type PTY interface { + Resize(w, h int) error +} + +func test() error { + ptmx, err := pty.New() + if err != nil { + return err + } + + defer ptmx.Close() + + c := ptmx.Command(`bash`) + if err := c.Start(); err != nil { + return err + } + + // Handle pty size. + ch := make(chan os.Signal, 1) + notifySizeChanges(ch) + go handlePtySize(ptmx, ch) + initSizeChange(ch) + defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done. + + // Set stdin in raw mode. + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + panic(err) + } + defer func() { _ = term.Restore(int(os.Stdin.Fd()), oldState) }() // Best effort. + + // Copy stdin to the pty and the pty to stdout. + // NOTE: The goroutine will keep reading until the next keystroke before returning. + go io.Copy(ptmx, os.Stdin) + go io.Copy(os.Stdout, ptmx) + + return c.Wait() +} + +func main() { + if err := test(); err != nil { + log.Fatal(err) + } +} diff --git a/examples/shell/size_other.go b/examples/shell/size_other.go new file mode 100644 index 0000000..64b811b --- /dev/null +++ b/examples/shell/size_other.go @@ -0,0 +1,35 @@ +//go:build !windows +// +build !windows + +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + + "github.com/aymanbagabas/go-pty" + "golang.org/x/term" +) + +func notifySizeChanges(ch chan os.Signal) { + signal.Notify(ch, syscall.SIGWINCH) +} + +func handlePtySize(p pty.Pty, ch chan os.Signal) { + for range ch { + w, h, err := term.GetSize(int(os.Stdin.Fd())) + if err != nil { + log.Printf("error resizing pty: %s", err) + continue + } + if err := p.Resize(w, h); err != nil { + log.Printf("error resizing pty: %s", err) + } + } +} + +func initSizeChange(ch chan os.Signal) { + ch <- syscall.SIGWINCH +} diff --git a/examples/shell/size_windows.go b/examples/shell/size_windows.go new file mode 100644 index 0000000..a2c8bb7 --- /dev/null +++ b/examples/shell/size_windows.go @@ -0,0 +1,18 @@ +//go:build windows +// +build windows + +package main + +import ( + "os" + + "github.com/aymanbagabas/go-pty" +) + +func notifySizeChanges(chan os.Signal) {} + +func handlePtySize(p pty.Pty, _ chan os.Signal) { + // TODO +} + +func initSizeChange(chan os.Signal) {} diff --git a/examples/ssh/main.go b/examples/ssh/main.go new file mode 100644 index 0000000..357bac1 --- /dev/null +++ b/examples/ssh/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "io" + "log" + + "github.com/aymanbagabas/go-pty" + "github.com/charmbracelet/ssh" +) + +func main() { + ssh.Handle(func(s ssh.Session) { + ptyReq, winCh, isPty := s.Pty() + if isPty { + pseudo, err := pty.New() + if err != nil { + log.Println(err) + return + } + + defer pseudo.Close() + w, h := ptyReq.Window.Width, ptyReq.Window.Height + if ptyReq.Modes != nil { + if err := pty.ApplyTerminalModes(int(pseudo.Fd()), w, h, ptyReq.Modes); err != nil { + log.Println(err) + return + } + } + if err := pseudo.Resize(w, h); err != nil { + log.Println(err) + return + } + + cmd := pseudo.Command("bash") + cmd.Env = append(cmd.Env, "TERM="+ptyReq.Term, "SSH_TTY="+pseudo.Name()) + + if err := cmd.Start(); err != nil { + log.Print(err) + return + } + + go func() { + for win := range winCh { + pseudo.Resize(win.Height, win.Width) + } + }() + + go io.Copy(pseudo, s) // stdin + go io.Copy(s, pseudo) // stdout + + if err := cmd.Wait(); err != nil { + log.Print(err) + return + } + } else { + io.WriteString(s, "No PTY requested.\n") + s.Exit(1) + } + }) + + log.Println("starting ssh server on port 2222...") + log.Fatal(ssh.ListenAndServe(":2222", nil)) +} diff --git a/ssh/ssh_other.go b/examples/ssh/modes_other.go similarity index 92% rename from ssh/ssh_other.go rename to examples/ssh/modes_other.go index 654c570..6ba2668 100644 --- a/ssh/ssh_other.go +++ b/examples/ssh/modes_other.go @@ -1,125 +1,126 @@ -//go:build !windows - -package ssh - -import ( - "log" - - "github.com/u-root/u-root/pkg/termios" - "golang.org/x/crypto/ssh" - "golang.org/x/xerrors" -) - -// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic -// names used by the termios package. -var terminalModeFlagNames = map[uint8]string{ - ssh.VINTR: "intr", - ssh.VQUIT: "quit", - ssh.VERASE: "erase", - ssh.VKILL: "kill", - ssh.VEOF: "eof", - ssh.VEOL: "eol", - ssh.VEOL2: "eol2", - ssh.VSTART: "start", - ssh.VSTOP: "stop", - ssh.VSUSP: "susp", - ssh.VDSUSP: "dsusp", - ssh.VREPRINT: "rprnt", - ssh.VWERASE: "werase", - ssh.VLNEXT: "lnext", - ssh.VFLUSH: "flush", - ssh.VSWTCH: "swtch", - ssh.VSTATUS: "status", - ssh.VDISCARD: "discard", - ssh.IGNPAR: "ignpar", - ssh.PARMRK: "parmrk", - ssh.INPCK: "inpck", - ssh.ISTRIP: "istrip", - ssh.INLCR: "inlcr", - ssh.IGNCR: "igncr", - ssh.ICRNL: "icrnl", - ssh.IUCLC: "iuclc", - ssh.IXON: "ixon", - ssh.IXANY: "ixany", - ssh.IXOFF: "ixoff", - ssh.IMAXBEL: "imaxbel", - ssh.IUTF8: "iutf8", - ssh.ISIG: "isig", - ssh.ICANON: "icanon", - ssh.XCASE: "xcase", - ssh.ECHO: "echo", - ssh.ECHOE: "echoe", - ssh.ECHOK: "echok", - ssh.ECHONL: "echonl", - ssh.NOFLSH: "noflsh", - ssh.TOSTOP: "tostop", - ssh.IEXTEN: "iexten", - ssh.ECHOCTL: "echoctl", - ssh.ECHOKE: "echoke", - ssh.PENDIN: "pendin", - ssh.OPOST: "opost", - ssh.OLCUC: "olcuc", - ssh.ONLCR: "onlcr", - ssh.OCRNL: "ocrnl", - ssh.ONOCR: "onocr", - ssh.ONLRET: "onlret", - ssh.CS7: "cs7", - ssh.CS8: "cs8", - ssh.PARENB: "parenb", - ssh.PARODD: "parodd", - ssh.TTY_OP_ISPEED: "tty_op_ispeed", - ssh.TTY_OP_OSPEED: "tty_op_ospeed", -} - -func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { - if modes == nil { - modes = ssh.TerminalModes{} - } - - // Get the current TTY configuration. - tios, err := termios.GTTY(int(fd)) - if err != nil { - return xerrors.Errorf("GTTY: %w", err) - } - - // Apply the modes from the SSH request. - tios.Row = height - tios.Col = width - - for c, v := range modes { - if c == ssh.TTY_OP_ISPEED { - tios.Ispeed = int(v) - continue - } - if c == ssh.TTY_OP_OSPEED { - tios.Ospeed = int(v) - continue - } - k, ok := terminalModeFlagNames[c] - if !ok { - if logger != nil { - logger.Printf("unknown terminal mode: %d", c) - } - continue - } - if _, ok := tios.CC[k]; ok { - tios.CC[k] = uint8(v) - continue - } - if _, ok := tios.Opts[k]; ok { - tios.Opts[k] = v > 0 - continue - } - - if logger != nil { - logger.Printf("unsupported terminal mode: k=%s, c=%d, v=%d", k, c, v) - } - } - - // Save the new TTY configuration. - if _, err := tios.STTY(int(fd)); err != nil { - return xerrors.Errorf("STTY: %w", err) - } - - return nil -} +//go:build !windows +// +build !windows + +package main + +import ( + "fmt" + "log" + + "github.com/u-root/u-root/pkg/termios" + "golang.org/x/crypto/ssh" +) + +// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic +// names used by the termios package. +var terminalModeFlagNames = map[uint8]string{ + ssh.VINTR: "intr", + ssh.VQUIT: "quit", + ssh.VERASE: "erase", + ssh.VKILL: "kill", + ssh.VEOF: "eof", + ssh.VEOL: "eol", + ssh.VEOL2: "eol2", + ssh.VSTART: "start", + ssh.VSTOP: "stop", + ssh.VSUSP: "susp", + ssh.VDSUSP: "dsusp", + ssh.VREPRINT: "rprnt", + ssh.VWERASE: "werase", + ssh.VLNEXT: "lnext", + ssh.VFLUSH: "flush", + ssh.VSWTCH: "swtch", + ssh.VSTATUS: "status", + ssh.VDISCARD: "discard", + ssh.IGNPAR: "ignpar", + ssh.PARMRK: "parmrk", + ssh.INPCK: "inpck", + ssh.ISTRIP: "istrip", + ssh.INLCR: "inlcr", + ssh.IGNCR: "igncr", + ssh.ICRNL: "icrnl", + ssh.IUCLC: "iuclc", + ssh.IXON: "ixon", + ssh.IXANY: "ixany", + ssh.IXOFF: "ixoff", + ssh.IMAXBEL: "imaxbel", + ssh.IUTF8: "iutf8", + ssh.ISIG: "isig", + ssh.ICANON: "icanon", + ssh.XCASE: "xcase", + ssh.ECHO: "echo", + ssh.ECHOE: "echoe", + ssh.ECHOK: "echok", + ssh.ECHONL: "echonl", + ssh.NOFLSH: "noflsh", + ssh.TOSTOP: "tostop", + ssh.IEXTEN: "iexten", + ssh.ECHOCTL: "echoctl", + ssh.ECHOKE: "echoke", + ssh.PENDIN: "pendin", + ssh.OPOST: "opost", + ssh.OLCUC: "olcuc", + ssh.ONLCR: "onlcr", + ssh.OCRNL: "ocrnl", + ssh.ONOCR: "onocr", + ssh.ONLRET: "onlret", + ssh.CS7: "cs7", + ssh.CS8: "cs8", + ssh.PARENB: "parenb", + ssh.PARODD: "parodd", + ssh.TTY_OP_ISPEED: "tty_op_ispeed", + ssh.TTY_OP_OSPEED: "tty_op_ospeed", +} + +func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { + if modes == nil { + modes = ssh.TerminalModes{} + } + + // Get the current TTY configuration. + tios, err := termios.GTTY(int(fd)) + if err != nil { + return fmt.Errorf("GTTY: %w", err) + } + + // Apply the modes from the SSH request. + tios.Row = height + tios.Col = width + + for c, v := range modes { + if c == ssh.TTY_OP_ISPEED { + tios.Ispeed = int(v) + continue + } + if c == ssh.TTY_OP_OSPEED { + tios.Ospeed = int(v) + continue + } + k, ok := terminalModeFlagNames[c] + if !ok { + if logger != nil { + logger.Printf("unknown terminal mode: %d", c) + } + continue + } + if _, ok := tios.CC[k]; ok { + tios.CC[k] = uint8(v) + continue + } + if _, ok := tios.Opts[k]; ok { + tios.Opts[k] = v > 0 + continue + } + + if logger != nil { + logger.Printf("unsupported terminal mode: k=%s, c=%d, v=%d", k, c, v) + } + } + + // Save the new TTY configuration. + if _, err := tios.STTY(int(fd)); err != nil { + return fmt.Errorf("STTY: %w", err) + } + + return nil +} diff --git a/ssh/ssh_windows.go b/examples/ssh/modes_windows.go similarity index 57% rename from ssh/ssh_windows.go rename to examples/ssh/modes_windows.go index f7ad382..5e0fe43 100644 --- a/ssh/ssh_windows.go +++ b/examples/ssh/modes_windows.go @@ -1,7 +1,14 @@ -//go:build windows - -package ssh - -func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { - return xerrors.Errorf("not implemented") -} +//go:build windows +// +build windows + +package main + +import ( + "log" + + "golang.org/x/crypto/ssh" +) + +func applyTerminalModesToFd(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { + return nil +} diff --git a/go.mod b/go.mod index ef77224..715427f 100644 --- a/go.mod +++ b/go.mod @@ -2,25 +2,10 @@ module github.com/aymanbagabas/go-pty go 1.20 +replace github.com/creack/pty => github.com/aymanbagabas/pty v1.1.19-0.20230922024246-7bc6991e768a + require ( - github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d - github.com/creack/pty v1.1.18 - github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 - github.com/spf13/cobra v1.7.0 - github.com/stretchr/testify v1.8.4 - github.com/u-root/u-root v0.11.0 - go.uber.org/goleak v1.2.1 + github.com/creack/pty v1.1.15 golang.org/x/crypto v0.12.0 - golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 golang.org/x/sys v0.12.0 - golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 -) - -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/kr/text v0.2.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c269034..c683503 100644 --- a/go.sum +++ b/go.sum @@ -1,42 +1,7 @@ -github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= -github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= -github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY= -github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/u-root/gobusybox/src v0.0.0-20221229083637-46b2883a7f90 h1:zTk5683I9K62wtZ6eUa6vu6IWwVHXPnoKK5n2unAwv0= -github.com/u-root/u-root v0.11.0 h1:6gCZLOeRyevw7gbTwMj3fKxnr9+yHFlgF3N7udUVNO8= -github.com/u-root/u-root v0.11.0/go.mod h1:DBkDtiZyONk9hzVEdB/PWI9B4TxDkElWlVTHseglrZY= -go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= -go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= +github.com/aymanbagabas/pty v1.1.19-0.20230922024246-7bc6991e768a h1:E1T8MDAqrtsQHu5WVPiAK3zoe1haQWe+SsSkAD0/SIw= +github.com/aymanbagabas/pty v1.1.19-0.20230922024246-7bc6991e768a/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 h1:/yRP+0AN7mf5DkD3BAI6TOFnd51gEoDEb8o35jIFtgw= -golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pty.go b/pty.go index f5d1641..583897c 100644 --- a/pty.go +++ b/pty.go @@ -1,131 +1,44 @@ package pty import ( + "context" + "errors" "io" - - "golang.org/x/xerrors" ) -// ErrClosed is returned when a PTY is used after it has been closed. -var ErrClosed = xerrors.New("pty: closed") - -// PTYCmd is an interface for interacting with a pseudo-TTY where we control -// only one end, and the other end has been passed to a running os.Process. -// nolint:revive -type PTYCmd interface { - io.Closer - - // Resize sets the size of the PTY. - Resize(width int, height int) error - - // OutputReader returns an io.Reader for reading the output from the process - // controlled by the pseudo-TTY - OutputReader() io.Reader +var ( + // ErrInvalidCommand is returned when the command is invalid. + ErrInvalidCommand = errors.New("pty: invalid command") +) - // InputWriter returns an io.Writer for writing into to the process - // controlled by the pseudo-TTY - InputWriter() io.Writer +// New returns a new pseudo-terminal. +func New() (Pty, error) { + return newPty() } -// PTY is a minimal interface for interacting with pseudo-TTY where this -// process retains access to _both_ ends of the pseudo-TTY (i.e. `ptm` & `pts` -// on Linux). -type PTY interface { - io.Closer - - // Resize sets the size of the PTY. - Resize(width int, height int) error +// Pty is a pseudo-terminal interface. +type Pty interface { + io.ReadWriteCloser - // Name of the TTY. Example on Linux would be "/dev/pts/1". + // Name returns the name of the pseudo-terminal. + // On Windows, this will always be "windows-pty". + // On Unix, this will return the name of the slave end of the + // pseudo-terminal TTY. Name() string - // Output handles TTY output. - // - // cmd.SetOutput(pty.Output()) would be used to specify a command - // uses the output stream for writing. - // - // The same stream could be read to validate output. - Output() io.ReadWriter - - // Input handles TTY input. - // - // cmd.SetInput(pty.Input()) would be used to specify a command - // uses the PTY input for reading. - // - // The same stream would be used to provide user input: pty.Input().Write(...) - Input() io.ReadWriter -} - -// Process represents a process running in a PTY. We need to trigger special processing on the PTY -// on process completion, meaning that we will have goroutines calling Wait() on the process. Since -// the caller will also typically wait for the process, and it is not safe for multiple goroutines -// to Wait() on a process, this abstraction provides a goroutine-safe interface for interacting with -// the process. -type Process interface { - // Wait for the command to complete. Returned error is as for exec.Cmd.Wait() - Wait() error - - // Kill the command process. Returned error is as for os.Process.Kill() - Kill() error -} - -// WithFlags represents a PTY whose flags can be inspected, in particular -// to determine whether local echo is enabled. -type WithFlags interface { - PTY - - // EchoEnabled determines whether local echo is currently enabled for this terminal. - EchoEnabled() (bool, error) -} + // Command returns a command that can be used to start a process + // attached to the pseudo-terminal. + Command(name string, args ...string) *Cmd -// Controllable represents a PTY that can be controlled via the syscall.RawConn -// interface. -type Controllable interface { - PTY + // CommandContext returns a command that can be used to start a process + // attached to the pseudo-terminal. + CommandContext(ctx context.Context, name string, args ...string) *Cmd - // ControlPTY allows the caller to control the PTY via the syscall.RawConn interface. - ControlPTY(func(uintptr) error) error - - // ControlTTY allows the caller to control the TTY via the syscall.RawConn interface. - ControlTTY(func(uintptr) error) error -} - -// Options represents a an option for a PTY. -type Option func(*ptyOptions) - -type ptyOptions struct { - setSize bool - - height int - width int -} - -// WithSize sets the size of the PTY. -func WithSize(width int, height int) Option { - return func(opts *ptyOptions) { - opts.setSize = true - opts.height = height - opts.width = width - } -} - -// New constructs a new Pty. -func New(opts ...Option) (PTY, error) { - return newPty(opts...) -} - -// readWriter is an implementation of io.ReadWriter that wraps two separate -// underlying file descriptors, one for reading and one for writing, and allows -// them to be accessed separately. -type readWriter struct { - Reader io.Reader - Writer io.Writer -} - -func (rw readWriter) Read(p []byte) (int, error) { - return rw.Reader.Read(p) -} + // Resize resizes the pseudo-terminal. + Resize(width int, height int) error -func (rw readWriter) Write(p []byte) (int, error) { - return rw.Writer.Write(p) + // Fd returns the file descriptor of the pseudo-terminal. + // On Unix, this will return the file descriptor of the master end. + // On Windows, this will return the handle of the console. + Fd() uintptr } diff --git a/pty_linux.go b/pty_linux.go deleted file mode 100644 index c0a5d31..0000000 --- a/pty_linux.go +++ /dev/null @@ -1,24 +0,0 @@ -// go:build linux - -package pty - -import ( - "github.com/u-root/u-root/pkg/termios" - "golang.org/x/sys/unix" -) - -func (p *otherPty) EchoEnabled() (echo bool, err error) { - err = p.control(p.pty, func(fd uintptr) error { - t, err := termios.GetTermios(fd) - if err != nil { - return err - } - - echo = (t.Lflag & unix.ECHO) != 0 - return nil - }) - if err != nil { - return false, err - } - return echo, nil -} diff --git a/pty_other.go b/pty_other.go index 7029599..c0ac82b 100644 --- a/pty_other.go +++ b/pty_other.go @@ -1,203 +1,131 @@ //go:build !windows +// +build !windows package pty import ( - "io" - "io/fs" + "context" + "errors" "os" "os/exec" - "runtime" - "sync" "github.com/creack/pty" - "github.com/u-root/u-root/pkg/termios" "golang.org/x/sys/unix" - "golang.org/x/xerrors" ) -func newPty(opt ...Option) (retPTY *otherPty, err error) { - var opts ptyOptions - for _, o := range opt { - o(&opts) - } +// PosixPty is a POSIX compliant pseudo-terminal. +// See: https://pubs.opengroup.org/onlinepubs/9699919799/ +type PosixPty struct { + master, slave *os.File + closed bool +} - ptyFile, ttyFile, err := pty.Open() - if err != nil { - return nil, err - } - opty := &otherPty{ - pty: ptyFile, - tty: ttyFile, - opts: opts, - name: ttyFile.Name(), +var _ Pty = &PosixPty{} + +// Close implements Pty. +func (p *PosixPty) Close() error { + if p.closed { + return nil } defer func() { - if err != nil { - _ = opty.Close() - } + p.closed = true }() + return errors.Join(p.master.Close(), p.slave.Close()) +} - if opts.setSize { - if err := opty.Resize(opts.width, opts.height); err != nil { - return nil, err - } +// Command implements Pty. +func (p *PosixPty) Command(name string, args ...string) *Cmd { + cmd := exec.Command(name, args...) + c := &Cmd{ + pty: p, + sys: cmd, + Path: name, + Args: append([]string{name}, args...), } - - return opty, err + c.sys = cmd + return c } -type otherPty struct { - mutex sync.Mutex - closed bool - err error - pty, tty *os.File - opts ptyOptions - name string +// CommandContext implements Pty. +func (p *PosixPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd { + cmd := exec.CommandContext(ctx, name, args...) + c := p.Command(name, args...) + c.ctx = ctx + c.Cancel = func() error { + return cmd.Cancel() + } + return c } -func (p *otherPty) ControlPTY(fn func(fd uintptr) error) error { - return p.control(p.pty, fn) +// Name implements Pty. +func (p *PosixPty) Name() string { + return p.slave.Name() } -func (p *otherPty) ControlTTY(fn func(fd uintptr) error) error { - return p.control(p.tty, fn) +// Read implements Pty. +func (p *PosixPty) Read(b []byte) (n int, err error) { + return p.master.Read(b) } -func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) { - defer func() { - // Always echo the close error for closed ptys. - p.mutex.Lock() - defer p.mutex.Unlock() - if p.closed { - err = p.err - } - }() - - rawConn, err := tty.SyscallConn() +func (p *PosixPty) Control(f func(fd uintptr)) error { + conn, err := p.master.SyscallConn() if err != nil { return err } - - var ctlErr error - err = rawConn.Control(func(fd uintptr) { - ctlErr = fn(fd) - }) - switch { - case err != nil: - return err - case ctlErr != nil: - return ctlErr - default: - return nil - } + return conn.Control(f) } -func (p *otherPty) Name() string { - return p.name +// Master returns the pseudo-terminal master end (pty). +func (p *PosixPty) Master() *os.File { + return p.master } -func (p *otherPty) Input() io.ReadWriter { - return readWriter{ - Reader: p.tty, - Writer: p.pty, - } +// Slave returns the pseudo-terminal slave end (tty). +func (p *PosixPty) Slave() *os.File { + return p.slave } -func (p *otherPty) InputWriter() io.Writer { - return p.pty -} +// Winsize represents the terminal window size. +type Winsize = unix.Winsize -func (p *otherPty) Output() io.ReadWriter { - return readWriter{ - Reader: &ptmReader{p.pty}, - Writer: p.tty, +// SetWinsize sets the pseudo-terminal window size. +func (p *PosixPty) SetWinsize(ws *Winsize) error { + var ctrlErr error + if err := p.Control(func(fd uintptr) { + ctrlErr = unix.IoctlSetWinsize(int(fd), unix.TIOCSWINSZ, ws) + }); err != nil { + return err } -} -func (p *otherPty) OutputReader() io.Reader { - return &ptmReader{p.pty} + return ctrlErr } -func (p *otherPty) Resize(width int, height int) error { - return p.control(p.pty, func(fd uintptr) error { - return termios.SetWinSize(fd, &termios.Winsize{ - Winsize: unix.Winsize{ - Row: uint16(height), - Col: uint16(width), - }, - }) +// Resize implements Pty. +func (p *PosixPty) Resize(width int, height int) error { + return p.SetWinsize(&Winsize{ + Row: uint16(height), + Col: uint16(width), }) } -func (p *otherPty) Close() error { - p.mutex.Lock() - defer p.mutex.Unlock() - - if p.closed { - return p.err - } - p.closed = true - - err := p.pty.Close() - // tty is closed & unset if we Start() a new process - if p.tty != nil { - err2 := p.tty.Close() - if err == nil { - err = err2 - } - } - - if err != nil { - p.err = err - } else { - p.err = ErrClosed - } - - return err +// Write implements Pty. +func (p *PosixPty) Write(b []byte) (n int, err error) { + return p.master.Write(b) } -type otherProcess struct { - pty *os.File - cmd *exec.Cmd - - // cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first. - cmdDone chan any - cmdErr error +// Fd implements Pty. +func (p *PosixPty) Fd() uintptr { + return p.master.Fd() } -func (p *otherProcess) Wait() error { - <-p.cmdDone - return p.cmdErr -} - -func (p *otherProcess) Kill() error { - return p.cmd.Process.Kill() -} - -func (p *otherProcess) waitInternal() { - // The GC can garbage collect the TTY FD before the command - // has finished running. See: - // https://github.com/creack/pty/issues/127#issuecomment-932764012 - p.cmdErr = p.cmd.Wait() - runtime.KeepAlive(p.pty) - close(p.cmdDone) -} - -// ptmReader wraps a reference to the ptm side of a pseudo-TTY for portability -type ptmReader struct { - ptm io.Reader -} - -func (r *ptmReader) Read(p []byte) (n int, err error) { - n, err = r.ptm.Read(p) - // output from the ptm will hit a PathErr when the process hangs up the - // other side (typically when the process exits, but could be earlier). For - // portability, and to fit with our use of io.Copy() to copy from the PTY, - // we want to translate this error into io.EOF - pathErr := &fs.PathError{} - if xerrors.As(err, &pathErr) { - return n, io.EOF +func newPty() (Pty, error) { + master, slave, err := pty.Open() + if err != nil { + return nil, err } - return n, err + + return &PosixPty{ + master: master, + slave: slave, + }, nil } diff --git a/pty_windows.go b/pty_windows.go index cc4c67f..4d9820b 100644 --- a/pty_windows.go +++ b/pty_windows.go @@ -1,262 +1,174 @@ //go:build windows +// +build windows package pty import ( "context" - "io" + "errors" + "fmt" "os" - "os/exec" "sync" "unsafe" "golang.org/x/sys/windows" +) - "golang.org/x/xerrors" +const ( + _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016 // nolint:revive ) var ( - kernel32 = windows.NewLazySystemDLL("kernel32.dll") - procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") - procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") - procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") + errClosedConPty = errors.New("pseudo console is closed") + errNotStarted = errors.New("process not started") ) +// Install this from github.com/Microsoft/go-winio +// go install github.com/Microsoft/go-winio/tools/mkwinsyscall@latest +//go:generate mkwinsyscall -output zsyscall_windows.go ./*.go + +// ConPty is a Windows console pseudo-terminal. +// It uses Windows pseudo console API to create a console that can be used to +// start processes attached to it. +// // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session -func newPty(opt ...Option) (*ptyWindows, error) { - var opts ptyOptions - for _, o := range opt { - o(&opts) - } +type ConPty struct { + handle windows.Handle + inPipe, outPipe *os.File + mtx sync.RWMutex +} - // We use the CreatePseudoConsole API which was introduced in build 17763 - vsn := windows.RtlGetVersion() - if vsn.MajorVersion < 10 || - vsn.BuildNumber < 17763 { - // If the CreatePseudoConsole API is not available, we fall back to a simpler - // implementation that doesn't create an actual PTY - just uses os.Pipe - return nil, xerrors.Errorf("pty not supported") - } +var _ Pty = &ConPty{} - pty := &ptyWindows{ - opts: opts, +func newPty() (Pty, error) { + ptyIn, inPipeOurs, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) } - var err error - pty.inputRead, pty.inputWrite, err = os.Pipe() + outPipeOurs, ptyOut, err := os.Pipe() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) } - pty.outputRead, pty.outputWrite, err = os.Pipe() + + var hpc windows.Handle + coord := windows.Coord{X: 80, Y: 25} + err = createPseudoConsole(coord, windows.Handle(ptyIn.Fd()), windows.Handle(ptyOut.Fd()), 0, &hpc) if err != nil { - _ = pty.inputRead.Close() - _ = pty.inputWrite.Close() - return nil, err + return nil, fmt.Errorf("failed to create pseudo console: %w", err) } - consoleSize := uintptr(80) + (uintptr(80) << 16) - if opts.setSize { - consoleSize = uintptr(opts.width) + (uintptr(opts.height) << 16) + if err := ptyOut.Close(); err != nil { + return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) } - ret, _, err := procCreatePseudoConsole.Call( - consoleSize, - uintptr(pty.inputRead.Fd()), - uintptr(pty.outputWrite.Fd()), - 0, - uintptr(unsafe.Pointer(&pty.console)), - ) - // CreatePseudoConsole returns S_OK on success, as per: - // https://learn.microsoft.com/en-us/windows/console/createpseudoconsole - if windows.Handle(ret) != windows.S_OK { - _ = pty.Close() - return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err) + if err := ptyIn.Close(); err != nil { + return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) } - return pty, nil + return &ConPty{ + handle: hpc, + inPipe: inPipeOurs, + outPipe: outPipeOurs, + }, nil } -type ptyWindows struct { - opts ptyOptions - console windows.Handle +// Close implements Pty. +func (p *ConPty) Close() error { + p.mtx.Lock() + defer p.mtx.Unlock() - outputWrite *os.File - outputRead *os.File - inputWrite *os.File - inputRead *os.File - - closeMutex sync.Mutex - closed bool + closePseudoConsole(p.handle) + return errors.Join(p.inPipe.Close(), p.outPipe.Close()) } -type windowsProcess struct { - // cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first. - cmdDone chan any - cmdErr error - proc *os.Process - pw *ptyWindows -} - -// Name returns the TTY name on Windows. -// -// Not implemented. -func (p *ptyWindows) Name() string { - return "" -} - -func (p *ptyWindows) Output() io.ReadWriter { - return readWriter{ - Reader: p.outputRead, - Writer: p.outputWrite, +// Command implements Pty. +func (p *ConPty) Command(name string, args ...string) *Cmd { + c := &Cmd{ + pty: p, + Path: name, + Args: append([]string{name}, args...), } + return c } -func (p *ptyWindows) OutputReader() io.Reader { - return p.outputRead -} - -func (p *ptyWindows) Input() io.ReadWriter { - return readWriter{ - Reader: p.inputRead, - Writer: p.inputWrite, +// CommandContext implements Pty. +func (p *ConPty) CommandContext(ctx context.Context, name string, args ...string) *Cmd { + if ctx == nil { + panic("nil context") + } + c := p.Command(name, args...) + c.ctx = ctx + c.Cancel = func() error { + return c.Process.Kill() } + return c } -func (p *ptyWindows) InputWriter() io.Writer { - return p.inputWrite +// Name implements Pty. +func (*ConPty) Name() string { + return "windows-pty" } -func (p *ptyWindows) Resize(width int, height int) error { - // hold the lock, so we don't race with anyone trying to close the console - p.closeMutex.Lock() - defer p.closeMutex.Unlock() - if p.closed || p.console == windows.InvalidHandle { - return ErrClosed - } - // Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144 - ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{ - Y: int16(height), - X: int16(width), - }))))) - if windows.Handle(ret) != windows.S_OK { - return err - } - return nil +// Read implements Pty. +func (p *ConPty) Read(b []byte) (n int, err error) { + return p.outPipe.Read(b) } -// closeConsoleNoLock closes the console handle, and sets it to -// windows.InvalidHandle. It must be called with p.closeMutex held. -func (p *ptyWindows) closeConsoleNoLock() error { - // if we are running a command in the PTY, the corresponding *windowsProcess - // may have already closed the PseudoConsole when the command exited, so that - // output reads can get to EOF. In that case, we don't need to close it - // again here. - if p.console != windows.InvalidHandle { - // ClosePseudoConsole has no return value and typically the syscall - // returns S_FALSE (a success value). We could ignore the return value - // and error here but we handle anyway, it just in case. - // - // Note that ClosePseudoConsole is a blocking system call and may write - // a final frame to the output buffer (p.outputWrite), so there must be - // a consumer (p.outputRead) to ensure we don't block here indefinitely. - // - // https://docs.microsoft.com/en-us/windows/console/closepseudoconsole - ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) - if winerrorFailed(ret) { - return xerrors.Errorf("close pseudo console (%d): %w", ret, err) - } - p.console = windows.InvalidHandle +// Resize implements Pty. +func (p *ConPty) Resize(width int, height int) error { + p.mtx.RLock() + defer p.mtx.RUnlock() + if err := resizePseudoConsole(p.handle, windows.Coord{X: int16(height), Y: int16(width)}); err != nil { + return fmt.Errorf("failed to resize pseudo console: %w", err) } - return nil } -func (p *ptyWindows) Close() error { - p.closeMutex.Lock() - defer p.closeMutex.Unlock() - if p.closed { - return nil - } - - // Close the pseudo console, this will also terminate the process attached - // to this pty. If it was created via Start(), this also unblocks close of - // the readers below. - err := p.closeConsoleNoLock() - if err != nil { - return err - } - - // Only set closed after the console has been successfully closed. - p.closed = true - - // Close the pipes ensuring that the writer is closed before the respective - // reader, otherwise closing the reader may block indefinitely. Note that - // outputWrite and inputRead are unset when we Start() a new process. - if p.outputWrite != nil { - _ = p.outputWrite.Close() - } - _ = p.outputRead.Close() - _ = p.inputWrite.Close() - if p.inputRead != nil { - _ = p.inputRead.Close() - } - return nil +// Write implements Pty. +func (p *ConPty) Write(b []byte) (n int, err error) { + return p.inPipe.Write(b) } -func (p *windowsProcess) waitInternal() { - // put this on the bottom of the defer stack since the next defer can write to p.cmdErr - defer close(p.cmdDone) - defer func() { - // close the pseudoconsole handle when the process exits, if it hasn't already been closed. - // this is important because the PseudoConsole (conhost.exe) holds the write-end - // of the output pipe. If it is not closed, reads on that pipe will block, even though - // the command has exited. - // c.f. https://devblogs.microsoft.com/commandline/windows-command-line-introducing-the-windows-pseudo-console-conpty/ - p.pw.closeMutex.Lock() - defer p.pw.closeMutex.Unlock() +// Fd implements Pty. +func (p *ConPty) Fd() uintptr { + p.mtx.RLock() + defer p.mtx.RUnlock() + return uintptr(p.handle) +} - err := p.pw.closeConsoleNoLock() - // if we already have an error from the command, prefer that error - // but if the command succeeded and closing the PseudoConsole fails - // then record that error so that we have a chance to see it - if err != nil && p.cmdErr == nil { - p.cmdErr = err - } - }() +// updateProcThreadAttribute updates the passed in attribute list to contain the entry necessary for use with +// CreateProcess. +func (p *ConPty) updateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error { + p.mtx.RLock() + defer p.mtx.RUnlock() - state, err := p.proc.Wait() - if err != nil { - p.cmdErr = err - return + if p.handle == 0 { + return errClosedConPty } - if !state.Success() { - p.cmdErr = &exec.ExitError{ProcessState: state} - return + + if err := attrList.Update( + _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, + unsafe.Pointer(p.handle), + unsafe.Sizeof(p.handle), + ); err != nil { + return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err) } -} -func (p *windowsProcess) Wait() error { - <-p.cmdDone - return p.cmdErr + return nil } -func (p *windowsProcess) Kill() error { - return p.proc.Kill() +// createPseudoConsole creates a windows pseudo console. +func createPseudoConsole(size windows.Coord, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) error { + // We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand. + return _createPseudoConsole(*((*uint32)(unsafe.Pointer(&size))), hInput, hOutput, dwFlags, hpcon) } -// killOnContext waits for the context to be done and kills the process, unless it exits on its own first. -func (p *windowsProcess) killOnContext(ctx context.Context) { - select { - case <-p.cmdDone: - return - case <-ctx.Done(): - p.Kill() - } +// resizePseudoConsole resizes the internal buffers of the pseudo console to the width and height specified in `size`. +func resizePseudoConsole(hpcon windows.Handle, size windows.Coord) error { + // We need this wrapper as the function takes a COORD struct and not a pointer to one, so we need to cast to something beforehand. + return _resizePseudoConsole(hpcon, *((*uint32)(unsafe.Pointer(&size)))) } -// winerrorFailed returns true if the syscall failed, this function -// assumes the return value is a 32-bit integer, like HRESULT. -// -// https://learn.microsoft.com/en-us/windows/win32/api/winerror/nf-winerror-failed -func winerrorFailed(r1 uintptr) bool { - return int32(r1) < 0 -} +//sys _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) = kernel32.CreatePseudoConsole +//sys _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) = kernel32.ResizePseudoConsole +//sys closePseudoConsole(hpc windows.Handle) = kernel32.ClosePseudoConsole diff --git a/ptytest/ptytest.go b/ptytest/ptytest.go deleted file mode 100644 index 1548771..0000000 --- a/ptytest/ptytest.go +++ /dev/null @@ -1,524 +0,0 @@ -package ptytest - -import ( - "bufio" - "bytes" - "context" - "fmt" - "io" - "runtime" - "strings" - "sync" - "testing" - "time" - "unicode/utf8" - - "github.com/acarl005/stripansi" - "github.com/spf13/cobra" - "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" - "golang.org/x/xerrors" - - "github.com/aymanbagabas/go-pty" -) - -const ( - WaitShort = time.Second * 10 - WaitMedium = time.Second * 15 - WaitLong = time.Second * 25 - WaitSuperLong = time.Minute - IntervalFast = time.Millisecond * 25 -) - -func New(t *testing.T, opts ...pty.Option) *PTY { - t.Helper() - - ptty, err := pty.New(opts...) - require.NoError(t, err) - - e := newExpecter(t, ptty.Output(), "cmd") - r := &PTY{ - outExpecter: e, - PTY: ptty, - } - // Ensure pty is cleaned up at the end of test. - t.Cleanup(func() { - _ = r.Close() - }) - return r -} - -// Start starts a new process asynchronously and returns a PTYCmd and Process. -// It kills the process and PTYCmd upon cleanup -func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) { - t.Helper() - - ptty, ps, err := pty.Start(cmd, opts...) - require.NoError(t, err) - t.Cleanup(func() { - _ = ps.Kill() - _ = ps.Wait() - }) - ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0]) - - r := &PTYCmd{ - outExpecter: ex, - PTYCmd: ptty, - } - t.Cleanup(func() { - _ = r.Close() - }) - return r, ps -} - -func newExpecter(t *testing.T, r io.Reader, name string) outExpecter { - // Use pipe for logging. - logDone := make(chan struct{}) - logr, logw := io.Pipe() - - // Write to log and output buffer. - copyDone := make(chan struct{}) - out := newStdbuf() - w := io.MultiWriter(logw, out) - - ex := outExpecter{ - t: t, - out: out, - name: name, - - runeReader: bufio.NewReaderSize(out, utf8.UTFMax), - } - - logClose := func(name string, c io.Closer) { - ex.logf("closing %s", name) - err := c.Close() - ex.logf("closed %s: %v", name, err) - } - // Set the actual close function for the outExpecter. - ex.close = func(reason string) error { - ctx, cancel := context.WithTimeout(context.Background(), WaitShort) - defer cancel() - - ex.logf("closing expecter: %s", reason) - - // Caller needs to have closed the PTY so that copying can complete - select { - case <-ctx.Done(): - ex.fatalf("close", "copy did not close in time") - case <-copyDone: - } - - logClose("logw", logw) - logClose("logr", logr) - select { - case <-ctx.Done(): - ex.fatalf("close", "log pipe did not close in time") - case <-logDone: - } - - ex.logf("closed expecter") - - return nil - } - - go func() { - defer close(copyDone) - _, err := io.Copy(w, r) - ex.logf("copy done: %v", err) - ex.logf("closing out") - err = out.closeErr(err) - ex.logf("closed out: %v", err) - }() - - // Log all output as part of test for easier debugging on errors. - go func() { - defer close(logDone) - s := bufio.NewScanner(logr) - for s.Scan() { - ex.logf("%q", stripansi.Strip(s.Text())) - } - }() - - return ex -} - -type outExpecter struct { - t *testing.T - close func(reason string) error - out *stdbuf - name string - - runeReader *bufio.Reader -} - -func (e *outExpecter) ExpectMatch(str string) string { - e.t.Helper() - - timeout, cancel := context.WithTimeout(context.Background(), WaitMedium) - defer cancel() - - return e.ExpectMatchContext(timeout, str) -} - -// TODO(mafredri): Rename this to ExpectMatch when refactoring. -func (e *outExpecter) ExpectMatchContext(ctx context.Context, str string) string { - e.t.Helper() - - var buffer bytes.Buffer - err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func(rd *bufio.Reader) error { - for { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - if strings.Contains(buffer.String(), str) { - return nil - } - } - }) - if err != nil { - e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) - return "" - } - e.logf("matched %q = %q", str, stripansi.Strip(buffer.String())) - return buffer.String() -} - -// ExpectNoMatchBefore validates that `match` does not occur before `before`. -func (e *outExpecter) ExpectNoMatchBefore(ctx context.Context, match, before string) string { - e.t.Helper() - - var buffer bytes.Buffer - err := e.doMatchWithDeadline(ctx, "ExpectNoMatchBefore", func(rd *bufio.Reader) error { - for { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - - if strings.Contains(buffer.String(), match) { - return xerrors.Errorf("found %q before %q", match, before) - } - - if strings.Contains(buffer.String(), before) { - return nil - } - } - }) - if err != nil { - e.fatalf("read error", "%v (wanted no %q before %q; got %q)", err, match, before, buffer.String()) - return "" - } - e.logf("matched %q = %q", before, stripansi.Strip(buffer.String())) - return buffer.String() -} - -func (e *outExpecter) Peek(ctx context.Context, n int) []byte { - e.t.Helper() - - var out []byte - err := e.doMatchWithDeadline(ctx, "Peek", func(rd *bufio.Reader) error { - var err error - out, err = rd.Peek(n) - return err - }) - if err != nil { - e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) - return nil - } - e.logf("peeked %d/%d bytes = %q", len(out), n, out) - return slices.Clone(out) -} - -func (e *outExpecter) ReadRune(ctx context.Context) rune { - e.t.Helper() - - var r rune - err := e.doMatchWithDeadline(ctx, "ReadRune", func(rd *bufio.Reader) error { - var err error - r, _, err = rd.ReadRune() - return err - }) - if err != nil { - e.fatalf("read error", "%v (wanted rune; got %q)", err, r) - return 0 - } - e.logf("matched rune = %q", r) - return r -} - -func (e *outExpecter) ReadLine(ctx context.Context) string { - e.t.Helper() - - var buffer bytes.Buffer - err := e.doMatchWithDeadline(ctx, "ReadLine", func(rd *bufio.Reader) error { - for { - r, _, err := rd.ReadRune() - if err != nil { - return err - } - if r == '\n' { - return nil - } - if r == '\r' { - // Peek the next rune to see if it's an LF and then consume - // it. - - // Unicode code points can be up to 4 bytes, but the - // ones we're looking for are only 1 byte. - b, _ := rd.Peek(1) - if len(b) == 0 { - return nil - } - - r, _ = utf8.DecodeRune(b) - if r == '\n' { - _, _, err = rd.ReadRune() - if err != nil { - return err - } - } - - return nil - } - - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - } - }) - if err != nil { - e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) - return "" - } - e.logf("matched newline = %q", buffer.String()) - return buffer.String() -} - -func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func(*bufio.Reader) error) error { - e.t.Helper() - - // A timeout is mandatory, caller can decide by passing a context - // that times out. - if _, ok := ctx.Deadline(); !ok { - timeout := WaitMedium - e.logf("%s ctx has no deadline, using %s", name, timeout) - var cancel context.CancelFunc - //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - - match := make(chan error, 1) - go func() { - defer close(match) - match <- fn(e.runeReader) - }() - select { - case err := <-match: - return err - case <-ctx.Done(): - // Ensure goroutine is cleaned up before test exit, do not call - // (*outExpecter).close here to let the caller decide. - _ = e.out.Close() - <-match - - return xerrors.Errorf("match deadline exceeded: %w", ctx.Err()) - } -} - -func (e *outExpecter) logf(format string, args ...interface{}) { - e.t.Helper() - - // Match regular logger timestamp format, we seem to be logging in - // UTC in other places as well, so match here. - e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...)) -} - -func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) { - e.t.Helper() - - // Ensure the message is part of the normal log stream before - // failing the test. - e.logf("%s: %s", reason, fmt.Sprintf(format, args...)) - - require.FailNowf(e.t, reason, format, args...) -} - -type PTY struct { - outExpecter - pty.PTY -} - -func (p *PTY) Close() error { - p.t.Helper() - pErr := p.PTY.Close() - if pErr != nil { - p.logf("PTY: Close failed: %v", pErr) - } - eErr := p.outExpecter.close("PTY close") - if eErr != nil { - p.logf("PTY: close expecter failed: %v", eErr) - } - if pErr != nil { - return pErr - } - return eErr -} - -func (p *PTY) Attach(inv *cobra.Command) *PTY { - p.t.Helper() - - inv.SetOut(p.Output()) - inv.SetErr(p.Output()) - inv.SetIn(p.Input()) - return p -} - -func (p *PTY) Write(r rune) { - p.t.Helper() - - p.logf("stdin: %q", r) - _, err := p.Input().Write([]byte{byte(r)}) - require.NoError(p.t, err, "write failed") -} - -func (p *PTY) WriteLine(str string) { - p.t.Helper() - - newline := []byte{'\r'} - if runtime.GOOS == "windows" { - newline = append(newline, '\n') - } - p.logf("stdin: %q", str+string(newline)) - _, err := p.Input().Write(append([]byte(str), newline...)) - require.NoError(p.t, err, "write line failed") -} - -type PTYCmd struct { - outExpecter - pty.PTYCmd -} - -func (p *PTYCmd) Close() error { - p.t.Helper() - pErr := p.PTYCmd.Close() - if pErr != nil { - p.logf("PTYCmd: Close failed: %v", pErr) - } - eErr := p.outExpecter.close("PTYCmd close") - if eErr != nil { - p.logf("PTYCmd: close expecter failed: %v", eErr) - } - if pErr != nil { - return pErr - } - return eErr -} - -// stdbuf is like a buffered stdout, it buffers writes until read. -type stdbuf struct { - r io.Reader - - mu sync.Mutex // Protects following. - b []byte - more chan struct{} - err error -} - -func newStdbuf() *stdbuf { - return &stdbuf{more: make(chan struct{}, 1)} -} - -func (b *stdbuf) Read(p []byte) (int, error) { - if b.r == nil { - return b.readOrWaitForMore(p) - } - - n, err := b.r.Read(p) - if xerrors.Is(err, io.EOF) { - b.r = nil - err = nil - if n == 0 { - return b.readOrWaitForMore(p) - } - } - return n, err -} - -func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) { - b.mu.Lock() - defer b.mu.Unlock() - - // Deplete channel so that more check - // is for future input into buffer. - select { - case <-b.more: - default: - } - - if len(b.b) == 0 { - if b.err != nil { - return 0, b.err - } - - b.mu.Unlock() - <-b.more - b.mu.Lock() - } - - b.r = bytes.NewReader(b.b) - b.b = b.b[len(b.b):] - - return b.r.Read(p) -} - -func (b *stdbuf) Write(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - - b.mu.Lock() - defer b.mu.Unlock() - - if b.err != nil { - return 0, b.err - } - - b.b = append(b.b, p...) - - select { - case b.more <- struct{}{}: - default: - } - - return len(p), nil -} - -func (b *stdbuf) Close() error { - return b.closeErr(nil) -} - -func (b *stdbuf) closeErr(err error) error { - b.mu.Lock() - defer b.mu.Unlock() - if b.err != nil { - return err - } - if err == nil { - b.err = io.EOF - } else { - b.err = err - } - close(b.more) - return err -} diff --git a/ptytest/ptytest_internal_test.go b/ptytest/ptytest_internal_test.go deleted file mode 100644 index 2915417..0000000 --- a/ptytest/ptytest_internal_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package ptytest - -import ( - "bytes" - "io" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStdbuf(t *testing.T) { - t.Parallel() - - var got bytes.Buffer - - b := newStdbuf() - done := make(chan struct{}) - go func() { - defer close(done) - _, err := io.Copy(&got, b) - assert.NoError(t, err) - }() - - _, err := b.Write([]byte("hello ")) - require.NoError(t, err) - _, err = b.Write([]byte("world\n")) - require.NoError(t, err) - _, err = b.Write([]byte("bye\n")) - require.NoError(t, err) - - err = b.Close() - require.NoError(t, err) - <-done - - assert.Equal(t, "hello world\nbye\n", got.String()) -} diff --git a/ptytest/ptytest_test.go b/ptytest/ptytest_test.go deleted file mode 100644 index bf05e3c..0000000 --- a/ptytest/ptytest_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package ptytest_test - -import ( - "context" - "fmt" - "runtime" - "strings" - "testing" - - "github.com/spf13/cobra" - "github.com/stretchr/testify/require" - - "github.com/aymanbagabas/go-pty/ptytest" -) - -func TestPtytest(t *testing.T) { - t.Parallel() - t.Run("Echo", func(t *testing.T) { - t.Parallel() - pty := ptytest.New(t) - pty.Output().Write([]byte("write")) - pty.ExpectMatch("write") - pty.WriteLine("read") - }) - - t.Run("ReadLine", func(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems") - } - - ctx, cancel := context.WithTimeout(context.Background(), ptytest.WaitLong) - t.Cleanup(cancel) - pty := ptytest.New(t) - - // The PTY expands these to \r\n (even on linux). - pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5")) - require.Equal(t, "line 1", pty.ReadLine(ctx)) - require.Equal(t, "line 2", pty.ReadLine(ctx)) - require.Equal(t, "line 3", pty.ReadLine(ctx)) - require.Equal(t, "line 4", pty.ReadLine(ctx)) - require.Equal(t, "line 5", pty.ExpectMatch("5")) - }) - - // See https://github.com/coder/coder/issues/2122 for the motivation - // behind this test. - t.Run("Ptytest should not hang when output is not consumed", func(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - output string - isPlatformBug bool - }{ - {name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)}, - {name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)}, - {name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1 - } - for _, tt := range tests { - tt := tt - // nolint:paralleltest // Avoid parallel test to more easily identify the issue. - t.Run(tt.name, func(t *testing.T) { - inv := &cobra.Command{ - Use: "test", - RunE: func(cmd *cobra.Command, _ []string) error { - _, err := fmt.Fprint(cmd.OutOrStdout(), tt.output) - return err - }, - } - pty := ptytest.New(t) - pty.Attach(inv) - err := inv.Execute() - require.NoError(t, err) - }) - } - }) -} diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..1ca0bed --- /dev/null +++ b/ssh.go @@ -0,0 +1,75 @@ +package pty + +import ( + "golang.org/x/crypto/ssh" +) + +// ApplyTerminalModes applies the given ssh terminal modes to the given file +// descriptor. +func ApplyTerminalModes(fd int, width int, height int, modes ssh.TerminalModes) error { + if modes == nil { + return nil + } + return applyTerminalModesToFd(fd, width, height, modes) +} + +// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic +// names used by the termios package. +var terminalModeFlagNames = map[uint8]string{ + ssh.VINTR: "intr", + ssh.VQUIT: "quit", + ssh.VERASE: "erase", + ssh.VKILL: "kill", + ssh.VEOF: "eof", + ssh.VEOL: "eol", + ssh.VEOL2: "eol2", + ssh.VSTART: "start", + ssh.VSTOP: "stop", + ssh.VSUSP: "susp", + ssh.VDSUSP: "dsusp", + ssh.VREPRINT: "rprnt", + ssh.VWERASE: "werase", + ssh.VLNEXT: "lnext", + ssh.VFLUSH: "flush", + ssh.VSWTCH: "swtch", + ssh.VSTATUS: "status", + ssh.VDISCARD: "discard", + ssh.IGNPAR: "ignpar", + ssh.PARMRK: "parmrk", + ssh.INPCK: "inpck", + ssh.ISTRIP: "istrip", + ssh.INLCR: "inlcr", + ssh.IGNCR: "igncr", + ssh.ICRNL: "icrnl", + ssh.IUCLC: "iuclc", + ssh.IXON: "ixon", + ssh.IXANY: "ixany", + ssh.IXOFF: "ixoff", + ssh.IMAXBEL: "imaxbel", + ssh.IUTF8: "iutf8", + ssh.ISIG: "isig", + ssh.ICANON: "icanon", + ssh.XCASE: "xcase", + ssh.ECHO: "echo", + ssh.ECHOE: "echoe", + ssh.ECHOK: "echok", + ssh.ECHONL: "echonl", + ssh.NOFLSH: "noflsh", + ssh.TOSTOP: "tostop", + ssh.IEXTEN: "iexten", + ssh.ECHOCTL: "echoctl", + ssh.ECHOKE: "echoke", + ssh.PENDIN: "pendin", + ssh.OPOST: "opost", + ssh.OLCUC: "olcuc", + ssh.ONLCR: "onlcr", + ssh.OCRNL: "ocrnl", + ssh.ONOCR: "onocr", + ssh.ONLRET: "onlret", + ssh.CS7: "cs7", + ssh.CS8: "cs8", + ssh.PARENB: "parenb", + ssh.PARODD: "parodd", + ssh.TTY_OP_ISPEED: "tty_op_ispeed", + ssh.TTY_OP_OSPEED: "tty_op_ospeed", +} diff --git a/ssh/ssh.go b/ssh/ssh.go deleted file mode 100644 index e6c72e5..0000000 --- a/ssh/ssh.go +++ /dev/null @@ -1,16 +0,0 @@ -package ssh - -import ( - "log" - - "golang.org/x/crypto/ssh" -) - -// ApplyTerminalModes -// request to the given fd. -// -// This is based on code from Tailscale's tailssh package: -// https://github.com/tailscale/tailscale/blob/main/ssh/tailssh/incubator.go -func ApplyTerminalModes(fd uintptr, width int, height int, modes ssh.TerminalModes, logger *log.Logger) error { - return applyTerminalModesToFd(fd, height, width, modes, logger) -} diff --git a/ssh/ssh_test.go b/ssh/ssh_test.go deleted file mode 100644 index 08eb5e5..0000000 --- a/ssh/ssh_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package ssh_test - -import ( - "io" - "log" - "testing" - - "github.com/aymanbagabas/go-pty" - "github.com/aymanbagabas/go-pty/ptytest" - sshPty "github.com/aymanbagabas/go-pty/ssh" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" -) - -func TestSSH(t *testing.T) { - t.Run("SSH_TTY", func(t *testing.T) { - t.Parallel() - h, w := 24, 80 - modes := ssh.TerminalModes{ - ssh.ECHO: 1, - ssh.ICANON: 1, - } - logger := log.New(io.Discard, "", 0) - pty, ps := ptytest.Start(t, - pty.Command("env"), - pty.WithPTYOption(pty.WithSize(w, h)), - pty.WithPTYCallback(func(p pty.PTY, c *pty.Cmd) error { - c.Env = append(c.Env, "SSH_TTY="+p.Name()) - if c, ok := p.(pty.Controllable); ok { - if err := c.ControlTTY(func(fd uintptr) error { - return sshPty.ApplyTerminalModes(fd, w, h, modes, logger) - }); err != nil { - return err - } - } - return nil - }), - ) - pty.ExpectMatch("SSH_TTY=/dev/") - err := ps.Wait() - require.NoError(t, err) - err = pty.Close() - require.NoError(t, err) - }) -} diff --git a/ssh_other.go b/ssh_other.go new file mode 100644 index 0000000..0803c61 --- /dev/null +++ b/ssh_other.go @@ -0,0 +1,52 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "fmt" + + "golang.org/x/crypto/ssh" +) + +func applyTerminalModesToFd(fd int, width int, height int, modes ssh.TerminalModes) error { + // Get the current TTY configuration. + tios, err := termios.GTTY(int(fd)) + if err != nil { + return fmt.Errorf("GTTY: %w", err) + } + + // Apply the modes from the SSH request. + tios.Row = height + tios.Col = width + + for c, v := range modes { + if c == ssh.TTY_OP_ISPEED { + tios.Ispeed = int(v) + continue + } + if c == ssh.TTY_OP_OSPEED { + tios.Ospeed = int(v) + continue + } + k, ok := terminalModeFlagNames[c] + if !ok { + continue + } + if _, ok := tios.CC[k]; ok { + tios.CC[k] = uint8(v) + continue + } + if _, ok := tios.Opts[k]; ok { + tios.Opts[k] = v > 0 + continue + } + } + + // Save the new TTY configuration. + if _, err := tios.STTY(int(fd)); err != nil { + return fmt.Errorf("STTY: %w", err) + } + + return nil +} diff --git a/ssh_windows.go b/ssh_windows.go new file mode 100644 index 0000000..79a8ccf --- /dev/null +++ b/ssh_windows.go @@ -0,0 +1,13 @@ +//go:build windows +// +build windows + +package pty + +import ( + "golang.org/x/crypto/ssh" +) + +func applyTerminalModesToFd(fd int, width int, height int, modes ssh.TerminalModes) error { + // TODO + return nil +} diff --git a/start.go b/start.go deleted file mode 100644 index 46103f6..0000000 --- a/start.go +++ /dev/null @@ -1,73 +0,0 @@ -package pty - -import ( - "context" - "os/exec" -) - -// StartOption represents a configuration option passed to Start. -type StartOption func(*startOptions) - -type startOptions struct { - ptyOpts []Option - ptyCb PTYCallback -} - -// WithPTYOption applies the given options to the underlying PTY. -func WithPTYOption(opts ...Option) StartOption { - return func(o *startOptions) { - o.ptyOpts = append(o.ptyOpts, opts...) - } -} - -// PTYCallback is a function that is called with the Cmd and the PTY before the -// command is started. This allows the caller to modify the PTY and Cmd before -// the command is started. -type PTYCallback func(PTY, *Cmd) error - -// WithPTYCallback allows the caller to modify the Cmd before it is started. -func WithPTYCallback(fn PTYCallback) StartOption { - return func(o *startOptions) { - o.ptyCb = fn - } -} - -// Cmd is a drop-in replacement for exec.Cmd with most of the same API, but -// it exposes the context.Context to our PTY code so that we can still kill the -// process when the Context expires. This is required because on Windows, we don't -// start the command using the `exec` library, so we have to manage the context -// ourselves. -type Cmd struct { - Context context.Context - Path string - Args []string - Env []string - Dir string -} - -func CommandContext(ctx context.Context, name string, arg ...string) *Cmd { - return &Cmd{ - Context: ctx, - Path: name, - Args: append([]string{name}, arg...), - Env: make([]string, 0), - } -} - -func Command(name string, arg ...string) *Cmd { - return CommandContext(context.Background(), name, arg...) -} - -func (c *Cmd) AsExec() *exec.Cmd { - //nolint: gosec - execCmd := exec.CommandContext(c.Context, c.Path, c.Args[1:]...) - execCmd.Dir = c.Dir - execCmd.Env = c.Env - return execCmd -} - -// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and -// instead rely on the returned Process to manage the command/process. -func Start(cmd *Cmd, opt ...StartOption) (PTYCmd, Process, error) { - return startPty(cmd, opt...) -} diff --git a/start_other.go b/start_other.go deleted file mode 100644 index 323b62b..0000000 --- a/start_other.go +++ /dev/null @@ -1,83 +0,0 @@ -//go:build !windows - -package pty - -import ( - "context" - "runtime" - "strings" - "syscall" - - "golang.org/x/xerrors" -) - -func startPty(cmdPty *Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) { - var opts startOptions - for _, o := range opt { - o(&opts) - } - - opty, err := newPty(opts.ptyOpts...) - if err != nil { - return nil, nil, xerrors.Errorf("newPty failed: %w", err) - } - - if opts.ptyCb != nil { - if err := opts.ptyCb(opty, cmdPty); err != nil { - _ = opty.Close() - return nil, nil, xerrors.Errorf("pty callback failed: %w", err) - } - } - - origEnv := cmdPty.Env - if cmdPty.Context == nil { - cmdPty.Context = context.Background() - } - cmdExec := cmdPty.AsExec() - - cmdExec.SysProcAttr = &syscall.SysProcAttr{ - Setsid: true, - Setctty: true, - } - cmdExec.Stdout = opty.tty - cmdExec.Stderr = opty.tty - cmdExec.Stdin = opty.tty - err = cmdExec.Start() - if err != nil { - _ = opty.Close() - if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") { - // macOS has an obscure issue where the PTY occasionally closes - // before it's used. It's unknown why this is, but creating a new - // TTY resolves it. - cmdPty.Env = origEnv - return startPty(cmdPty, opt...) - } - return nil, nil, xerrors.Errorf("start: %w", err) - } - if runtime.GOOS == "linux" { - // Now that we've started the command, and passed the TTY to it, close - // our file so that the other process has the only open file to the TTY. - // Once the process closes the TTY (usually on exit), there will be no - // open references and the OS kernel returns an error when trying to - // read or write to our PTY end. Without this (on Linux), reading from - // the process output will block until we close our TTY. - // - // Note that on darwin, reads on the PTY don't block even if we keep the - // TTY file open, and keeping it open seems to prevent race conditions - // where we lose output. Couldn't find official documentation - // confirming this, but I did find a thread of someone else's - // observations: https://developer.apple.com/forums/thread/663632 - if err := opty.tty.Close(); err != nil { - _ = cmdExec.Process.Kill() - return nil, nil, xerrors.Errorf("close tty: %w", err) - } - opty.tty = nil // remove so we don't attempt to close it again. - } - oProcess := &otherProcess{ - pty: opty.pty, - cmd: cmdExec, - cmdDone: make(chan any), - } - go oProcess.waitInternal() - return opty, oProcess, nil -} diff --git a/start_other_test.go b/start_other_test.go deleted file mode 100644 index 51dbef3..0000000 --- a/start_other_test.go +++ /dev/null @@ -1,75 +0,0 @@ -//go:build !windows - -package pty_test - -import ( - "os/exec" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "golang.org/x/xerrors" - - "github.com/aymanbagabas/go-pty" - "github.com/aymanbagabas/go-pty/ptytest" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestStart(t *testing.T) { - t.Parallel() - t.Run("Echo", func(t *testing.T) { - t.Parallel() - pty, ps := ptytest.Start(t, pty.Command("echo", "test")) - - pty.ExpectMatch("test") - err := ps.Wait() - require.NoError(t, err) - err = pty.Close() - require.NoError(t, err) - }) - - t.Run("Kill", func(t *testing.T) { - t.Parallel() - pty, ps := ptytest.Start(t, pty.Command("sleep", "30")) - err := ps.Kill() - assert.NoError(t, err) - err = ps.Wait() - var exitErr *exec.ExitError - require.True(t, xerrors.As(err, &exitErr)) - assert.NotEqual(t, 0, exitErr.ExitCode()) - err = pty.Close() - require.NoError(t, err) - }) -} - -// these constants/vars are used by Test_Start_copy - -const cmdEcho = "echo" - -var argEcho = []string{"test"} - -// these constants/vars are used by Test_Start_truncate - -const ( - countEnd = 1000 - cmdCount = "sh" -) - -var argCount = []string{"-c", ` -i=0 -while [ $i -ne 1000 ] -do - i=$(($i+1)) - echo "$i" -done -`} - -// these constants/vars are used by Test_Start_cancel_context - -const cmdSleep = "sleep" - -var argSleep = []string{"30"} diff --git a/start_test.go b/start_test.go deleted file mode 100644 index f008db6..0000000 --- a/start_test.go +++ /dev/null @@ -1,176 +0,0 @@ -package pty_test - -import ( - "bytes" - "context" - "fmt" - "io" - "strings" - "testing" - "time" - - "github.com/hinshun/vt10x" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/aymanbagabas/go-pty" - "github.com/aymanbagabas/go-pty/ptytest" -) - -// Test_Start_copy tests that we can use io.Copy() on command output -// without deadlocking. -func Test_Start_copy(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), ptytest.WaitShort) - defer cancel() - - pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdEcho, argEcho...)) - require.NoError(t, err) - b := &bytes.Buffer{} - readDone := make(chan error, 1) - go func() { - _, err := io.Copy(b, pc.OutputReader()) - readDone <- err - }() - - select { - case err := <-readDone: - require.NoError(t, err) - case <-ctx.Done(): - t.Error("read timed out") - } - assert.Contains(t, b.String(), "test") - - cmdDone := make(chan error, 1) - go func() { - cmdDone <- cmd.Wait() - }() - - select { - case err := <-cmdDone: - require.NoError(t, err) - case <-ctx.Done(): - t.Error("cmd.Wait() timed out") - } -} - -// Test_Start_truncation tests that we can read command output without truncation -// even after the command has exited. -func Test_Start_truncation(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), ptytest.WaitSuperLong) - defer cancel() - - pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdCount, argCount...)) - - require.NoError(t, err) - readDone := make(chan struct{}) - go func() { - defer close(readDone) - // avoid buffered IO so that we can precisely control how many bytes to read. - n := 1 - for n <= countEnd { - want := fmt.Sprintf("%d", n) - err := readUntil(ctx, t, want, pc.OutputReader()) - assert.NoError(t, err, "want: %s", want) - if err != nil { - return - } - n++ - if (countEnd - n) < 100 { - // If the OS buffers the output, the process can exit even if - // we're not done reading. We want to slow our reads so that - // if there is a race between reading the data and it being - // truncated, we will lose and fail the test. - time.Sleep(ptytest.IntervalFast) - } - } - // ensure we still get to EOF - endB := &bytes.Buffer{} - _, err := io.Copy(endB, pc.OutputReader()) - assert.NoError(t, err) - }() - - cmdDone := make(chan error, 1) - go func() { - cmdDone <- cmd.Wait() - }() - - select { - case err := <-cmdDone: - require.NoError(t, err) - case <-ctx.Done(): - t.Fatal("cmd.Wait() timed out") - } - - select { - case <-readDone: - // OK! - case <-ctx.Done(): - t.Fatal("read timed out") - } -} - -// Test_Start_cancel_context tests that we can cancel the command context and kill the process. -func Test_Start_cancel_context(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), ptytest.WaitMedium) - defer cancel() - cmdCtx, cmdCancel := context.WithCancel(ctx) - - pc, cmd, err := pty.Start(pty.CommandContext(cmdCtx, cmdSleep, argSleep...)) - require.NoError(t, err) - defer func() { - _ = pc.Close() - }() - cmdCancel() - - cmdDone := make(chan struct{}) - go func() { - defer close(cmdDone) - _ = cmd.Wait() - }() - - select { - case <-cmdDone: - // OK! - case <-ctx.Done(): - t.Error("cmd.Wait() timed out") - } -} - -// readUntil reads one byte at a time until we either see the string we want, or the context expires -func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error { - // output can contain virtual terminal sequences, so we need to parse these - // to correctly interpret getting what we want. - term := vt10x.New(vt10x.WithSize(80, 80)) - readErrs := make(chan error, 1) - for { - b := make([]byte, 1) - go func() { - _, err := r.Read(b) - readErrs <- err - }() - select { - case err := <-readErrs: - if err != nil { - t.Logf("err: %v\ngot: %v", err, term) - return err - } - term.Write(b) - case <-ctx.Done(): - return ctx.Err() - } - got := term.String() - lines := strings.Split(got, "\n") - for _, line := range lines { - if strings.TrimSpace(line) == want { - t.Logf("want: %v\n got:%v", want, line) - return nil - } - } - } -} diff --git a/start_windows.go b/start_windows.go deleted file mode 100644 index 5ce62bd..0000000 --- a/start_windows.go +++ /dev/null @@ -1,204 +0,0 @@ -//go:build windows -// +build windows - -package pty - -import ( - "os" - "os/exec" - "strings" - "unicode/utf16" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/xerrors" -) - -// Allocates a PTY and starts the specified command attached to it. -// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) { - var opts startOptions - for _, o := range opt { - o(&opts) - } - - fullPath, err := exec.LookPath(cmd.Path) - if err != nil { - return nil, nil, err - } - pathPtr, err := windows.UTF16PtrFromString(fullPath) - if err != nil { - return nil, nil, err - } - argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args)) - if err != nil { - return nil, nil, err - } - if cmd.Dir == "" { - cmd.Dir, err = os.Getwd() - if err != nil { - return nil, nil, err - } - } - dirPtr, err := windows.UTF16PtrFromString(cmd.Dir) - if err != nil { - return nil, nil, err - } - - winPty, err := newPty(opts.ptyOpts...) - if err != nil { - return nil, nil, err - } - defer func() { - if retErr != nil { - // we hit some error finishing setup; close pty, so - // we don't leak the kernel resources associated with it - _ = winPty.Close() - } - }() - if opts.ptyCb != nil { - if err := opts.ptyCb(winPty, cmd); err != nil { - return nil, nil, xerrors.Errorf("pty callback failed: %w", err) - } - } - - attrs, err := windows.NewProcThreadAttributeList(1) - if err != nil { - return nil, nil, err - } - // Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6 - err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) - if err != nil { - return nil, nil, err - } - - startupInfo := &windows.StartupInfoEx{} - startupInfo.ProcThreadAttributeList = attrs.List() - startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES - startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo)) - var processInfo windows.ProcessInformation - err = windows.CreateProcess( - pathPtr, - argsPtr, - nil, - nil, - false, - // https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment - windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, - createEnvBlock(addCriticalEnv(dedupEnvCase(true, cmd.Env))), - dirPtr, - &startupInfo.StartupInfo, - &processInfo, - ) - if err != nil { - return nil, nil, err - } - defer windows.CloseHandle(processInfo.Thread) - defer windows.CloseHandle(processInfo.Process) - - process, err := os.FindProcess(int(processInfo.ProcessId)) - if err != nil { - return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err) - } - wp := &windowsProcess{ - cmdDone: make(chan any), - proc: process, - pw: winPty, - } - defer func() { - if retErr != nil { - // if we later error out, kill the process since - // the caller will have no way to interact with it - _ = process.Kill() - } - }() - - // Now that we've started the command, and passed the pseudoconsole to it, - // close the output write and input read files, so that the other process - // has the only handles to them. Once the process closes the console, there - // will be no open references and the OS kernel returns an error when trying - // to read or write to our end. Without this, reading from the process - // output will block until they are closed. - errO := winPty.outputWrite.Close() - winPty.outputWrite = nil - errI := winPty.inputRead.Close() - winPty.inputRead = nil - if errO != nil { - return nil, nil, errO - } - if errI != nil { - return nil, nil, errI - } - go wp.waitInternal() - if cmd.Context != nil { - go wp.killOnContext(cmd.Context) - } - return winPty, wp, nil -} - -// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 -func createEnvBlock(envv []string) *uint16 { - if len(envv) == 0 { - return &utf16.Encode([]rune("\x00\x00"))[0] - } - length := 0 - for _, s := range envv { - length += len(s) + 1 - } - length += 1 - - b := make([]byte, length) - i := 0 - for _, s := range envv { - l := len(s) - copy(b[i:i+l], []byte(s)) - copy(b[i+l:i+l+1], []byte{0}) - i = i + l + 1 - } - copy(b[i:i+1], []byte{0}) - - return &utf16.Encode([]rune(string(b)))[0] -} - -// dedupEnvCase is dedupEnv with a case option for testing. -// If caseInsensitive is true, the case of keys is ignored. -func dedupEnvCase(caseInsensitive bool, env []string) []string { - out := make([]string, 0, len(env)) - saw := make(map[string]int, len(env)) // key => index into out - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - out = append(out, kv) - continue - } - k := kv[:eq] - if caseInsensitive { - k = strings.ToLower(k) - } - if dupIdx, isDup := saw[k]; isDup { - out[dupIdx] = kv - continue - } - saw[k] = len(out) - out = append(out, kv) - } - return out -} - -// addCriticalEnv adds any critical environment variables that are required -// (or at least almost always required) on the operating system. -// Currently this is only used for Windows. -func addCriticalEnv(env []string) []string { - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - continue - } - k := kv[:eq] - if strings.EqualFold(k, "SYSTEMROOT") { - // We already have it. - return env - } - } - return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) -} diff --git a/start_windows_test.go b/start_windows_test.go deleted file mode 100644 index 78804d6..0000000 --- a/start_windows_test.go +++ /dev/null @@ -1,75 +0,0 @@ -//go:build windows -// +build windows - -package pty_test - -import ( - "fmt" - "os/exec" - "testing" - - "github.com/aymanbagabas/go-pty" - "github.com/aymanbagabas/go-pty/ptytest" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "golang.org/x/xerrors" -) - -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - -func TestStart(t *testing.T) { - t.Parallel() - t.Run("Echo", func(t *testing.T) { - t.Parallel() - ptty, ps := ptytest.Start(t, pty.Command("cmd.exe", "/c", "echo", "test")) - ptty.ExpectMatch("test") - err := ps.Wait() - require.NoError(t, err) - err = ptty.Close() - require.NoError(t, err) - }) - t.Run("Resize", func(t *testing.T) { - t.Parallel() - ptty, _ := ptytest.Start(t, pty.Command("cmd.exe")) - err := ptty.Resize(100, 50) - require.NoError(t, err) - err = ptty.Close() - require.NoError(t, err) - }) - t.Run("Kill", func(t *testing.T) { - t.Parallel() - ptty, ps := ptytest.Start(t, pty.Command("cmd.exe")) - err := ps.Kill() - assert.NoError(t, err) - err = ps.Wait() - var exitErr *exec.ExitError - require.True(t, xerrors.As(err, &exitErr)) - assert.NotEqual(t, 0, exitErr.ExitCode()) - err = ptty.Close() - require.NoError(t, err) - }) -} - -// these constants/vars are used by Test_Start_copy - -const cmdEcho = "cmd.exe" - -var argEcho = []string{"/c", "echo", "test"} - -// these constants/vars are used by Test_Start_truncate - -const ( - countEnd = 1000 - cmdCount = "cmd.exe" -) - -var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)} - -// these constants/vars are used by Test_Start_cancel_context - -const cmdSleep = "cmd.exe" - -var argSleep = []string{"/c", "timeout", "/t", "30"} diff --git a/zsyscall_windows.go b/zsyscall_windows.go new file mode 100644 index 0000000..d2a766a --- /dev/null +++ b/zsyscall_windows.go @@ -0,0 +1,75 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package pty + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole") + procCreatePseudoConsole = modkernel32.NewProc("CreatePseudoConsole") + procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole") +) + +func closePseudoConsole(hpc windows.Handle) { + syscall.Syscall(procClosePseudoConsole.Addr(), 1, uintptr(hpc), 0, 0) + return +} + +func _createPseudoConsole(size uint32, hInput windows.Handle, hOutput windows.Handle, dwFlags uint32, hpcon *windows.Handle) (hr error) { + r0, _, _ := syscall.Syscall6(procCreatePseudoConsole.Addr(), 5, uintptr(size), uintptr(hInput), uintptr(hOutput), uintptr(dwFlags), uintptr(unsafe.Pointer(hpcon)), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func _resizePseudoConsole(hPc windows.Handle, size uint32) (hr error) { + r0, _, _ := syscall.Syscall(procResizePseudoConsole.Addr(), 2, uintptr(hPc), uintptr(size), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +}