Skip to content

Commit

Permalink
Improve type safety, extract identical code
Browse files Browse the repository at this point in the history
Avoid fragility of tracking objects and their FDs separately.
  • Loading branch information
tamird committed Dec 16, 2024
1 parent 9759785 commit f2cd7e2
Showing 1 changed file with 71 additions and 47 deletions.
118 changes: 71 additions & 47 deletions src/unix_term.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::env;
use std::fmt::Display;
use std::fs;
use std::io;
use std::io::{BufRead, BufReader};
use std::io::{self, BufRead, BufReader};
use std::mem;
use std::os::unix::io::AsRawFd;
use std::os::fd::{AsRawFd, RawFd};
use std::str;

#[cfg(not(target_os = "macos"))]
Expand All @@ -18,7 +17,7 @@ pub(crate) use crate::common_term::*;
pub(crate) const DEFAULT_WIDTH: u16 = 80;

#[inline]
pub(crate) fn is_a_terminal(out: &Term) -> bool {
pub(crate) fn is_a_terminal(out: &impl AsRawFd) -> bool {
unsafe { libc::isatty(out.as_raw_fd()) != 0 }
}

Expand Down Expand Up @@ -66,41 +65,76 @@ pub(crate) fn terminal_size(out: &Term) -> Option<(u16, u16)> {
}
}

pub(crate) fn read_secure() -> io::Result<String> {
let mut f_tty;
let fd = unsafe {
if libc::isatty(libc::STDIN_FILENO) == 1 {
f_tty = None;
libc::STDIN_FILENO
enum Input<T> {
Stdin(io::Stdin),
File(T),
}

impl Input<fs::File> {
fn new() -> io::Result<Self> {
let stdin = io::stdin();
if is_a_terminal(&stdin) {
Ok(Input::Stdin(stdin))
} else {
let f = fs::OpenOptions::new()
.read(true)
.write(true)
.open("/dev/tty")?;
let fd = f.as_raw_fd();
f_tty = Some(BufReader::new(f));
fd
Ok(Input::File(f))
}
};
}
}

impl Input<BufReader<fs::File>> {
fn new() -> io::Result<Self> {
Ok(match Input::<fs::File>::new()? {
Input::Stdin(s) => Self::Stdin(s),
Input::File(f) => Self::File(BufReader::new(f)),
})
}
}

impl<T: BufRead> Input<T> {
fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
match self {
Self::Stdin(s) => s.read_line(buf),
Self::File(f) => f.read_line(buf),
}
}
}

impl AsRawFd for Input<fs::File> {
fn as_raw_fd(&self) -> RawFd {
match self {
Self::Stdin(s) => s.as_raw_fd(),
Self::File(f) => f.as_raw_fd(),
}
}
}

impl AsRawFd for Input<BufReader<fs::File>> {
fn as_raw_fd(&self) -> RawFd {
match self {
Self::Stdin(s) => s.as_raw_fd(),
Self::File(f) => f.get_ref().as_raw_fd(),
}
}
}

pub(crate) fn read_secure() -> io::Result<String> {
let mut input = Input::<BufReader<fs::File>>::new()?;

let mut termios = mem::MaybeUninit::uninit();
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
let mut termios = unsafe { termios.assume_init() };
let original = termios;
termios.c_lflag &= !libc::ECHO;
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &termios) })?;
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &termios) })?;
let mut rv = String::new();

let read_rv = if let Some(f) = &mut f_tty {
f.read_line(&mut rv)
} else {
io::stdin().read_line(&mut rv)
};

c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &original) })?;
let read_rv = input.read_line(&mut rv);

// Ensure the fd is only closed after everything has been restored.
drop(f_tty);
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &original) })?;

read_rv.map(|_| {
let len = rv.trim_end_matches(&['\r', '\n'][..]).len();
Expand All @@ -109,7 +143,7 @@ pub(crate) fn read_secure() -> io::Result<String> {
})
}

fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
fn poll_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
let mut pollfd = libc::pollfd {
fd,
events: libc::POLLIN,
Expand All @@ -124,7 +158,7 @@ fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
}

#[cfg(target_os = "macos")]
fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
fn select_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
unsafe {
let mut read_fd_set: libc::fd_set = mem::zeroed();

Expand Down Expand Up @@ -156,7 +190,7 @@ fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
}
}

fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
// There is a bug on macos that ttys cannot be polled, only select()
// works. However given how problematic select is in general, we
// normally want to use poll there too.
Expand All @@ -169,7 +203,7 @@ fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
poll_fd(fd, timeout)
}

fn read_single_char(fd: i32) -> io::Result<Option<char>> {
fn read_single_char(fd: RawFd) -> io::Result<Option<char>> {
// timeout of zero means that it will not block
let is_ready = select_or_poll_term_fd(fd, 0)?;

Expand All @@ -188,7 +222,7 @@ fn read_single_char(fd: i32) -> io::Result<Option<char>> {
// Similar to libc::read. Read count bytes into slice buf from descriptor fd.
// If successful, return the number of bytes read.
// Will return an error if nothing was read, i.e when called at end of file.
fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result<u8> {
let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) };
if read < 0 {
Err(io::Error::last_os_error())
Expand All @@ -207,7 +241,7 @@ fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
}
}

fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
fn read_single_key_impl(fd: RawFd) -> Result<Key, io::Error> {
loop {
match read_single_char(fd)? {
Some('\x1b') => {
Expand Down Expand Up @@ -301,27 +335,17 @@ fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
}

pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result<Key> {
let tty_f;
let fd = unsafe {
if libc::isatty(libc::STDIN_FILENO) == 1 {
libc::STDIN_FILENO
} else {
tty_f = fs::OpenOptions::new()
.read(true)
.write(true)
.open("/dev/tty")?;
tty_f.as_raw_fd()
}
};
let input = Input::<fs::File>::new()?;

let mut termios = core::mem::MaybeUninit::uninit();
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
let mut termios = unsafe { termios.assume_init() };
let original = termios;
unsafe { libc::cfmakeraw(&mut termios) };
termios.c_oflag = original.c_oflag;
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &termios) })?;
let rv: io::Result<Key> = read_single_key_impl(fd);
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &original) })?;
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?;
let rv: io::Result<Key> = read_single_key_impl(input.as_raw_fd());
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?;

// if the user hit ^C we want to signal SIGINT to outselves.
if let Err(ref err) = rv {
Expand Down

0 comments on commit f2cd7e2

Please sign in to comment.