Skip to content

Commit

Permalink
feat(output): add unsafe option
Browse files Browse the repository at this point in the history
This disables any TTY checks including IOCTL and unix calls.
Useful for mocking console output and querying the terminal over SSH.
  • Loading branch information
aymanbagabas authored and muesli committed Feb 3, 2023
1 parent b8d620b commit 3582eeb
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 27 deletions.
12 changes: 12 additions & 0 deletions output.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Output struct {
tty io.Writer
environ Environ

unsafe bool
cache bool
fgSync *sync.Once
fgColor Color
Expand Down Expand Up @@ -105,6 +106,17 @@ func WithColorCache(v bool) OutputOption {
}
}

// WithUnsafe returns a new Output with unsafe mode enabled. Unsafe mode doesn't
// check whether or not the terminal is a TTY.
//
// This is useful when mocking console output and enforcing ANSI escape output
// e.g. on SSH sessions.
func WithUnsafe() OutputOption {
return func(o *Output) {
o.unsafe = true
}
}

// ForegroundColor returns the terminal's default foreground color.
func (o *Output) ForegroundColor() Color {
f := func() {
Expand Down
3 changes: 3 additions & 0 deletions termenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ const (
)

func (o *Output) isTTY() bool {
if o.unsafe {
return true
}
if len(o.environ.Getenv("CI")) > 0 {
return false
}
Expand Down
59 changes: 32 additions & 27 deletions termenv_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ func (o Output) backgroundColor() Color {
return ANSIColor(0)
}

func waitForData(fd uintptr, timeout time.Duration) error {
func (o *Output) waitForData(timeout time.Duration) error {
fd := o.TTY().Fd()
tv := unix.NsecToTimeval(int64(timeout))
var readfds unix.FdSet
readfds.Set(int(fd))
Expand All @@ -132,13 +133,15 @@ func waitForData(fd uintptr, timeout time.Duration) error {
return nil
}

func readNextByte(f File) (byte, error) {
if err := waitForData(f.Fd(), OSCTimeout); err != nil {
return 0, err
func (o *Output) readNextByte() (byte, error) {
if !o.unsafe {
if err := o.waitForData(OSCTimeout); err != nil {
return 0, err
}
}

var b [1]byte
n, err := f.Read(b[:])
n, err := o.TTY().Read(b[:])
if err != nil {
return 0, err
}
Expand All @@ -153,15 +156,15 @@ func readNextByte(f File) (byte, error) {
// readNextResponse reads either an OSC response or a cursor position response:
// - OSC response: "\x1b]11;rgb:1111/1111/1111\x1b\\"
// - cursor position response: "\x1b[42;1R"
func readNextResponse(fd File) (response string, isOSC bool, err error) {
start, err := readNextByte(fd)
func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
start, err := o.readNextByte()
if err != nil {
return "", false, err
}

// first byte must be ESC
for start != ESC {
start, err = readNextByte(fd)
start, err = o.readNextByte()
if err != nil {
return "", false, err
}
Expand All @@ -170,7 +173,7 @@ func readNextResponse(fd File) (response string, isOSC bool, err error) {
response += string(start)

// next byte is either '[' (cursor position response) or ']' (OSC response)
tpe, err := readNextByte(fd)
tpe, err := o.readNextByte()
if err != nil {
return "", false, err
}
Expand All @@ -188,7 +191,7 @@ func readNextResponse(fd File) (response string, isOSC bool, err error) {
}

for {
b, err := readNextByte(fd)
b, err := o.readNextByte()
if err != nil {
return "", false, err
}
Expand Down Expand Up @@ -229,23 +232,25 @@ func (o Output) termStatusReport(sequence int) (string, error) {
return "", ErrStatusReport
}

fd := int(tty.Fd())
// if in background, we can't control the terminal
if !isForeground(fd) {
return "", ErrStatusReport
}
if !o.unsafe {
fd := int(tty.Fd())
// if in background, we can't control the terminal
if !isForeground(fd) {
return "", ErrStatusReport
}

t, err := unix.IoctlGetTermios(fd, tcgetattr)
if err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
}
defer unix.IoctlSetTermios(fd, tcsetattr, t) //nolint:errcheck
t, err := unix.IoctlGetTermios(fd, tcgetattr)
if err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
}
defer unix.IoctlSetTermios(fd, tcsetattr, t) //nolint:errcheck

noecho := *t
noecho.Lflag = noecho.Lflag &^ unix.ECHO
noecho.Lflag = noecho.Lflag &^ unix.ICANON
if err := unix.IoctlSetTermios(fd, tcsetattr, &noecho); err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
noecho := *t
noecho.Lflag = noecho.Lflag &^ unix.ECHO
noecho.Lflag = noecho.Lflag &^ unix.ICANON
if err := unix.IoctlSetTermios(fd, tcsetattr, &noecho); err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
}
}

// first, send OSC query, which is ignored by terminal which do not support it
Expand All @@ -255,7 +260,7 @@ func (o Output) termStatusReport(sequence int) (string, error) {
fmt.Fprintf(tty, CSI+"6n")

// read the next response
res, isOSC, err := readNextResponse(tty)
res, isOSC, err := o.readNextResponse()
if err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
}
Expand All @@ -266,7 +271,7 @@ func (o Output) termStatusReport(sequence int) (string, error) {
}

// read the cursor query response next and discard the result
_, _, err = readNextResponse(tty)
_, _, err = o.readNextResponse()
if err != nil {
return "", err
}
Expand Down

0 comments on commit 3582eeb

Please sign in to comment.