From a7aac307a2821b51de6c822dfb130d9e9d3dcd3d Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Mon, 16 Dec 2024 10:33:47 -0500 Subject: [PATCH] Avoid manual polling Use blocking and non-blocking reads instead. --- src/unix_term.rs | 176 ++++++++++++++++------------------------------- 1 file changed, 61 insertions(+), 115 deletions(-) diff --git a/src/unix_term.rs b/src/unix_term.rs index a9035c02..1f37cdf8 100644 --- a/src/unix_term.rs +++ b/src/unix_term.rs @@ -1,7 +1,7 @@ use std::env; use std::fmt::Display; use std::fs; -use std::io::{self, BufRead, BufReader}; +use std::io::{self, BufRead, BufReader, Read}; use std::mem; use std::os::fd::{AsRawFd, RawFd}; use std::str; @@ -94,6 +94,15 @@ impl Input> { } } +impl Read for Input { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Self::Stdin(s) => s.read(buf), + Self::File(f) => f.read(buf), + } + } +} + impl Input { fn read_line(&mut self, buf: &mut String) -> io::Result { match self { @@ -143,112 +152,55 @@ pub(crate) fn read_secure() -> io::Result { }) } -fn poll_fd(fd: RawFd, timeout: i32) -> io::Result { - let mut pollfd = libc::pollfd { - fd, - events: libc::POLLIN, - revents: 0, - }; - let ret = unsafe { libc::poll(&mut pollfd as *mut _, 1, timeout) }; - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(pollfd.revents & libc::POLLIN != 0) - } -} - -#[cfg(target_os = "macos")] -fn select_fd(fd: RawFd, timeout: i32) -> io::Result { - unsafe { - let mut read_fd_set: libc::fd_set = mem::zeroed(); - - let mut timeout_val; - let timeout = if timeout < 0 { - std::ptr::null_mut() - } else { - timeout_val = libc::timeval { - tv_sec: (timeout / 1000) as _, - tv_usec: (timeout * 1000) as _, - }; - &mut timeout_val - }; - - libc::FD_ZERO(&mut read_fd_set); - libc::FD_SET(fd, &mut read_fd_set); - let ret = libc::select( - fd + 1, - &mut read_fd_set, - std::ptr::null_mut(), - std::ptr::null_mut(), - timeout, - ); - if ret < 0 { - Err(io::Error::last_os_error()) - } else { - Ok(libc::FD_ISSET(fd, &read_fd_set)) +fn read_single_char(input: &mut T) -> io::Result> { + let original = unsafe { libc::fcntl(input.as_raw_fd(), libc::F_GETFL) }; + c_result(|| unsafe { + libc::fcntl( + input.as_raw_fd(), + libc::F_SETFL, + original | libc::O_NONBLOCK, + ) + })?; + let mut buf = [0u8; 1]; + let result = input.read_exact(&mut buf); + c_result(|| unsafe { libc::fcntl(input.as_raw_fd(), libc::F_SETFL, original) })?; + match result { + Ok(()) => { + let [c] = buf; + Ok(Some(c as char)) } - } -} - -fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result { - // There is a bug on macos that ttys cannot be polled, only select() - // works. However given how problematic select is in general, we - // normally want to use poll there too. - #[cfg(target_os = "macos")] - { - if unsafe { libc::isatty(fd) == 1 } { - return select_fd(fd, timeout); + Err(err) => { + if err.kind() == io::ErrorKind::WouldBlock { + Ok(None) + } else { + Err(err) + } } } - poll_fd(fd, timeout) } -fn read_single_char(fd: RawFd) -> io::Result> { - // timeout of zero means that it will not block - let is_ready = select_or_poll_term_fd(fd, 0)?; - - if is_ready { - // if there is something to be read, take 1 byte from it - let mut buf: [u8; 1] = [0]; - - read_bytes(fd, &mut buf, 1)?; - Ok(Some(buf[0] as char)) - } else { - //there is nothing to be read - Ok(None) - } -} - -// Similar to libc::read. Read count bytes into slice buf from descriptor fd. -// If successful, return the number of bytes read. -// Will return an error if nothing was read, i.e when called at end of file. -fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result { - let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) }; - if read < 0 { - Err(io::Error::last_os_error()) - } else if read == 0 { - Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "Reached end of file", - )) - } else if buf[0] == b'\x03' { - Err(io::Error::new( +fn read_bytes(input: &mut impl Read, buf: &mut [u8]) -> io::Result<()> { + input.read_exact(buf)?; + match buf { + [b'\x03', ..] => Err(io::Error::new( io::ErrorKind::Interrupted, "read interrupted", - )) - } else { - Ok(read as u8) + )), + _ => Ok(()), } } -fn read_single_key_impl(fd: RawFd) -> Result { +fn read_single_key_impl(input: &mut T) -> Result { loop { - match read_single_char(fd)? { - Some('\x1b') => { + let mut buf = [0u8; 1]; + input.read_exact(&mut buf)?; + let [c] = buf; + match c { + b'\x1b' => { // Escape was read, keep reading in case we find a familiar key - break if let Some(c1) = read_single_char(fd)? { + break if let Some(c1) = read_single_char(input)? { if c1 == '[' { - if let Some(c2) = read_single_char(fd)? { + if let Some(c2) = read_single_char(input)? { match c2 { 'A' => Ok(Key::ArrowUp), 'B' => Ok(Key::ArrowDown), @@ -258,7 +210,7 @@ fn read_single_key_impl(fd: RawFd) -> Result { 'F' => Ok(Key::End), 'Z' => Ok(Key::BackTab), _ => { - let c3 = read_single_char(fd)?; + let c3 = read_single_char(input)?; if let Some(c3) = c3 { if c3 == '~' { match c2 { @@ -294,48 +246,42 @@ fn read_single_key_impl(fd: RawFd) -> Result { Ok(Key::Escape) }; } - Some(c) => { + c => { let byte = c as u8; - let mut buf: [u8; 4] = [byte, 0, 0, 0]; break if byte & 224u8 == 192u8 { // a two byte unicode character - read_bytes(fd, &mut buf[1..], 1)?; - Ok(key_from_utf8(&buf[..2])) + let mut buf: [u8; 2] = [byte, 0]; + read_bytes(input, &mut buf[1..])?; + Ok(key_from_utf8(&buf)) } else if byte & 240u8 == 224u8 { // a three byte unicode character - read_bytes(fd, &mut buf[1..], 2)?; - Ok(key_from_utf8(&buf[..3])) + let mut buf: [u8; 3] = [byte, 0, 0]; + read_bytes(input, &mut buf[1..])?; + Ok(key_from_utf8(&buf)) } else if byte & 248u8 == 240u8 { // a four byte unicode character - read_bytes(fd, &mut buf[1..], 3)?; - Ok(key_from_utf8(&buf[..4])) + let mut buf: [u8; 4] = [byte, 0, 0, 0]; + read_bytes(input, &mut buf[1..])?; + Ok(key_from_utf8(&buf)) } else { - Ok(match c { + Ok(match c as char { '\n' | '\r' => Key::Enter, '\x7f' => Key::Backspace, '\t' => Key::Tab, '\x01' => Key::Home, // Control-A (home) '\x05' => Key::End, // Control-E (end) '\x08' => Key::Backspace, // Control-H (8) (Identical to '\b') - _ => Key::Char(c), + c => Key::Char(c), }) }; } - None => { - // there is no subsequent byte ready to be read, block and wait for input - // negative timeout means that it will block indefinitely - match select_or_poll_term_fd(fd, -1) { - Ok(_) => continue, - Err(_) => break Err(io::Error::last_os_error()), - } - } } } } pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result { - let input = Input::::new()?; + let mut input = Input::::new()?; let mut termios = core::mem::MaybeUninit::uninit(); c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?; @@ -344,7 +290,7 @@ pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result { unsafe { libc::cfmakeraw(&mut termios) }; termios.c_oflag = original.c_oflag; c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?; - let rv: io::Result = read_single_key_impl(input.as_raw_fd()); + let rv: io::Result = read_single_key_impl(&mut input); c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?; // if the user hit ^C we want to signal SIGINT to outselves.