From f2cd7e2c664cfac57fa6b9af1b5091ab66ce236a Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 27 Aug 2024 12:57:59 -0400 Subject: [PATCH] Improve type safety, extract identical code Avoid fragility of tracking objects and their FDs separately. --- src/unix_term.rs | 118 ++++++++++++++++++++++++++++------------------- 1 file changed, 71 insertions(+), 47 deletions(-) diff --git a/src/unix_term.rs b/src/unix_term.rs index b8e0db2e..a9035c02 100644 --- a/src/unix_term.rs +++ b/src/unix_term.rs @@ -1,10 +1,9 @@ use std::env; use std::fmt::Display; use std::fs; -use std::io; -use std::io::{BufRead, BufReader}; +use std::io::{self, BufRead, BufReader}; use std::mem; -use std::os::unix::io::AsRawFd; +use std::os::fd::{AsRawFd, RawFd}; use std::str; #[cfg(not(target_os = "macos"))] @@ -18,7 +17,7 @@ pub(crate) use crate::common_term::*; pub(crate) const DEFAULT_WIDTH: u16 = 80; #[inline] -pub(crate) fn is_a_terminal(out: &Term) -> bool { +pub(crate) fn is_a_terminal(out: &impl AsRawFd) -> bool { unsafe { libc::isatty(out.as_raw_fd()) != 0 } } @@ -66,41 +65,76 @@ pub(crate) fn terminal_size(out: &Term) -> Option<(u16, u16)> { } } -pub(crate) fn read_secure() -> io::Result { - let mut f_tty; - let fd = unsafe { - if libc::isatty(libc::STDIN_FILENO) == 1 { - f_tty = None; - libc::STDIN_FILENO +enum Input { + Stdin(io::Stdin), + File(T), +} + +impl Input { + fn new() -> io::Result { + let stdin = io::stdin(); + if is_a_terminal(&stdin) { + Ok(Input::Stdin(stdin)) } else { let f = fs::OpenOptions::new() .read(true) .write(true) .open("/dev/tty")?; - let fd = f.as_raw_fd(); - f_tty = Some(BufReader::new(f)); - fd + Ok(Input::File(f)) } - }; + } +} + +impl Input> { + fn new() -> io::Result { + Ok(match Input::::new()? { + Input::Stdin(s) => Self::Stdin(s), + Input::File(f) => Self::File(BufReader::new(f)), + }) + } +} + +impl Input { + fn read_line(&mut self, buf: &mut String) -> io::Result { + match self { + Self::Stdin(s) => s.read_line(buf), + Self::File(f) => f.read_line(buf), + } + } +} + +impl AsRawFd for Input { + fn as_raw_fd(&self) -> RawFd { + match self { + Self::Stdin(s) => s.as_raw_fd(), + Self::File(f) => f.as_raw_fd(), + } + } +} + +impl AsRawFd for Input> { + fn as_raw_fd(&self) -> RawFd { + match self { + Self::Stdin(s) => s.as_raw_fd(), + Self::File(f) => f.get_ref().as_raw_fd(), + } + } +} + +pub(crate) fn read_secure() -> io::Result { + let mut input = Input::>::new()?; let mut termios = mem::MaybeUninit::uninit(); - c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?; + c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?; let mut termios = unsafe { termios.assume_init() }; let original = termios; termios.c_lflag &= !libc::ECHO; - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &termios) })?; + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &termios) })?; let mut rv = String::new(); - let read_rv = if let Some(f) = &mut f_tty { - f.read_line(&mut rv) - } else { - io::stdin().read_line(&mut rv) - }; - - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &original) })?; + let read_rv = input.read_line(&mut rv); - // Ensure the fd is only closed after everything has been restored. - drop(f_tty); + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &original) })?; read_rv.map(|_| { let len = rv.trim_end_matches(&['\r', '\n'][..]).len(); @@ -109,7 +143,7 @@ pub(crate) fn read_secure() -> io::Result { }) } -fn poll_fd(fd: i32, timeout: i32) -> io::Result { +fn poll_fd(fd: RawFd, timeout: i32) -> io::Result { let mut pollfd = libc::pollfd { fd, events: libc::POLLIN, @@ -124,7 +158,7 @@ fn poll_fd(fd: i32, timeout: i32) -> io::Result { } #[cfg(target_os = "macos")] -fn select_fd(fd: i32, timeout: i32) -> io::Result { +fn select_fd(fd: RawFd, timeout: i32) -> io::Result { unsafe { let mut read_fd_set: libc::fd_set = mem::zeroed(); @@ -156,7 +190,7 @@ fn select_fd(fd: i32, timeout: i32) -> io::Result { } } -fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result { +fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result { // There is a bug on macos that ttys cannot be polled, only select() // works. However given how problematic select is in general, we // normally want to use poll there too. @@ -169,7 +203,7 @@ fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result { poll_fd(fd, timeout) } -fn read_single_char(fd: i32) -> io::Result> { +fn read_single_char(fd: RawFd) -> io::Result> { // timeout of zero means that it will not block let is_ready = select_or_poll_term_fd(fd, 0)?; @@ -188,7 +222,7 @@ fn read_single_char(fd: i32) -> io::Result> { // Similar to libc::read. Read count bytes into slice buf from descriptor fd. // If successful, return the number of bytes read. // Will return an error if nothing was read, i.e when called at end of file. -fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result { +fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result { let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) }; if read < 0 { Err(io::Error::last_os_error()) @@ -207,7 +241,7 @@ fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result { } } -fn read_single_key_impl(fd: i32) -> Result { +fn read_single_key_impl(fd: RawFd) -> Result { loop { match read_single_char(fd)? { Some('\x1b') => { @@ -301,27 +335,17 @@ fn read_single_key_impl(fd: i32) -> Result { } pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result { - let tty_f; - let fd = unsafe { - if libc::isatty(libc::STDIN_FILENO) == 1 { - libc::STDIN_FILENO - } else { - tty_f = fs::OpenOptions::new() - .read(true) - .write(true) - .open("/dev/tty")?; - tty_f.as_raw_fd() - } - }; + let input = Input::::new()?; + let mut termios = core::mem::MaybeUninit::uninit(); - c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?; + c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?; let mut termios = unsafe { termios.assume_init() }; let original = termios; unsafe { libc::cfmakeraw(&mut termios) }; termios.c_oflag = original.c_oflag; - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &termios) })?; - let rv: io::Result = read_single_key_impl(fd); - c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &original) })?; + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?; + let rv: io::Result = read_single_key_impl(input.as_raw_fd()); + c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?; // if the user hit ^C we want to signal SIGINT to outselves. if let Err(ref err) = rv {