From 290c43a96662d54ab7c4b8814e5a9f9a9e523fda Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 11 Nov 2020 02:13:29 -0800 Subject: [PATCH] net: add keepalive support to TcpSocket Signed-off-by: Eliza Weisman --- src/net/mod.rs | 2 +- src/net/tcp/mod.rs | 2 +- src/net/tcp/socket.rs | 239 ++++++++++++++++++++++++++++++++++++++++- src/sys/shell/tcp.rs | 54 ++++++++++ src/sys/unix/tcp.rs | 234 ++++++++++++++++++++++++++++++++++++++-- src/sys/windows/tcp.rs | 87 ++++++++++++++- tests/tcp_socket.rs | 85 ++++++++++++++- 7 files changed, 681 insertions(+), 22 deletions(-) diff --git a/src/net/mod.rs b/src/net/mod.rs index de1be0d0e..4df701d45 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -8,7 +8,7 @@ //! [portability guidelines]: ../struct.Poll.html#portability mod tcp; -pub use self::tcp::{TcpListener, TcpSocket, TcpStream}; +pub use self::tcp::{TcpListener, TcpSocket, TcpStream, TcpKeepalive}; mod udp; pub use self::udp::UdpSocket; diff --git a/src/net/tcp/mod.rs b/src/net/tcp/mod.rs index b39b909b7..4e47aeed0 100644 --- a/src/net/tcp/mod.rs +++ b/src/net/tcp/mod.rs @@ -2,7 +2,7 @@ mod listener; pub use self::listener::TcpListener; mod socket; -pub use self::socket::TcpSocket; +pub use self::socket::{TcpSocket, TcpKeepalive}; mod stream; pub use self::stream::TcpStream; diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index 249505a99..f1f3e7c5a 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -22,13 +22,36 @@ pub struct TcpSocket { sys: sys::tcp::TcpSocket, } +/// Configures a socket's TCP keepalive parameters. +#[derive(Debug, Default, Clone)] +pub struct TcpKeepalive { + pub(crate) time: Option, + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows", + ))] + pub(crate) interval: Option, + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))] + pub(crate) retries: Option, +} + impl TcpSocket { /// Create a new IPv4 TCP socket. /// /// This calls `socket(2)`. pub fn new_v4() -> io::Result { - sys::tcp::new_v4_socket().map(|sys| TcpSocket { - sys + sys::tcp::new_v4_socket().map(|sys| { + TcpSocket { sys } }) } @@ -36,8 +59,8 @@ impl TcpSocket { /// /// This calls `socket(2)`. pub fn new_v6() -> io::Result { - sys::tcp::new_v6_socket().map(|sys| TcpSocket { - sys + sys::tcp::new_v6_socket().map(|sys| { + TcpSocket { sys } }) } @@ -168,7 +191,133 @@ impl TcpSocket { pub fn get_send_buffer_size(&self) -> io::Result { sys::tcp::get_send_buffer_size(self.sys) } - + + /// Sets whether keepalive messages are enabled to be sent on this socket. + /// + /// This will set the `SO_KEEPALIVE` option on this socket. + pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> { + sys::tcp::set_keepalive(self.sys, keepalive) + } + + /// Returns whether or not TCP keepalive probes will be sent by this socket. + pub fn get_keepalive(&self) -> io::Result { + sys::tcp::get_keepalive(self.sys) + } + + /// Sets parameters configuring TCP keepalive probes for this socket. + /// + /// The supported parameters depend on the operating system, and are + /// configured using the [`TcpKeepalive`] struct. At a minimum, all systems + /// support configuring the [keepalive time]: the time after which the OS + /// will start sending keepalive messages on an idle connection. + /// + /// # Notes + /// + /// * This will enable TCP keepalive on this socket, if it is not already + /// enabled. + /// * On some platforms, such as Windows, any keepalive parameters *not* + /// configured by the `TcpKeepalive` struct passed to this function may be + /// overwritten with their default values. Therefore, this function should + /// either only be called once per socket, or the same parameters should + /// be passed every time it is called. + /// + /// # Examples + /// ``` + /// use mio::net::{TcpSocket, TcpKeepalive}; + /// use std::time::Duration; + /// + /// # fn main() -> Result<(), std::io::Error> { + /// let socket = TcpSocket::new_v6()?; + /// let keepalive = TcpKeepalive::default() + /// .with_time(Duration::from_secs(4)); + /// // Depending on the target operating system, we may also be able to + /// // configure the keepalive probe interval and/or the number of retries + /// // here as well. + /// + /// socket.set_keepalive_params(keepalive)?; + /// # Ok(()) } + /// ``` + /// + /// [`TcpKeepalive`]: ../struct.TcpKeepalive.html + /// [keepalive time]: ../struct.TcpKeepalive.html#method.with_time + pub fn set_keepalive_params(&self, keepalive: TcpKeepalive) -> io::Result<()> { + self.set_keepalive(true)?; + sys::tcp::set_keepalive_params(self.sys, keepalive) + } + + /// Returns the amount of time after which TCP keepalive probes will be sent + /// on idle connections. + /// + /// If `None`, then keepalive messages are disabled. + /// + /// This returns the value of `SO_KEEPALIVE` + `IPPROTO_TCP` on OpenBSD, + /// NetBSD, and Haiku, `TCP_KEEPALIVE` on macOS and iOS, and `TCP_KEEPIDLE` + /// on all other Unix operating systems. On Windows, it is not possible to + /// access the value of TCP keepalive parameters after they have been set. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + #[cfg_attr(docsrs, doc(cfg(not(target_os = "windows"))))] + #[cfg(not(target_os = "windows"))] + pub fn get_keepalive_time(&self) -> io::Result> { + sys::tcp::get_keepalive_time(self.sys) + } + + /// Returns the time interval between TCP keepalive probes, if TCP keepalive is + /// enabled on this socket. + /// + /// If `None`, then keepalive messages are disabled. + /// + /// This returns the value of `TCP_KEEPINTVL` on supported Unix operating + /// systems. On Windows, it is not possible to access the value of TCP + /// keepalive parameters after they have been set.. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + #[cfg_attr(docsrs, doc(cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))))] + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))] + pub fn get_keepalive_interval(&self) -> io::Result> { + sys::tcp::get_keepalive_interval(self.sys) + } + + /// Returns the maximum number of TCP keepalive probes that will be sent before + /// dropping a connection, if TCP keepalive is enabled on this socket. + /// + /// If `None`, then keepalive messages are disabled. + /// + /// This returns the value of `TCP_KEEPCNT` on Unix operating systems that + /// support this option. On Windows, it is not possible to access the value + /// of TCP keepalive parameters after they have been set. + #[cfg_attr(docsrs, doc(cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))))] + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))] + pub fn get_keepalive_retries(&self) -> io::Result> { + sys::tcp::get_keepalive_retries(self.sys) + } + /// Returns the local address of this socket /// /// Will return `Err` result in windows if called before calling `bind` @@ -238,3 +387,83 @@ impl FromRawSocket for TcpSocket { TcpSocket { sys: socket as sys::tcp::TcpSocket } } } + +impl TcpKeepalive { + // Sets the amount of time after which TCP keepalive probes will be sent + /// on idle connections. + /// + /// This will set the value of `SO_KEEPALIVE` + `IPPROTO_TCP` on OpenBSD, + /// NetBSD, and Haiku, `TCP_KEEPALIVE` on macOS and iOS, and `TCP_KEEPIDLE` + /// on all other Unix operating systems. On Windows, this sets the value of + /// the `tcp_keepalive` struct's `keepalivetime` field. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + pub fn with_time(self, time: Duration) -> Self { + Self { + time: Some(time), + ..self + } + } + + /// Sets the time interval between TCP keepalive probes. + /// This sets the value of `TCP_KEEPINTVL` on supported Unix operating + /// systems. On Windows, this sets the value of the `tcp_keepalive` struct's + /// `keepaliveinterval` field. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + #[cfg_attr(docsrs, doc(cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" + ))))] + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" + ))] + pub fn with_interval(self, interval: Duration) -> Self { + Self { + interval: Some(interval), + ..self + } + } + + /// Sets the maximum number of TCP keepalive probes that will be sent before + /// dropping a connection, if TCP keepalive is enabled on this socket. + /// + /// This will set the value of `TCP_KEEPCNT` on Unix operating systems that + /// support this option. + #[cfg_attr(docsrs, doc(cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))))] + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))] + pub fn with_retries(self, retries: u32) -> Self { + Self { + retries: Some(retries), + ..self + } + } + + /// Returns a new, empty set of TCP keepalive parameters. + pub fn new() -> Self { + Self::default() + } +} \ No newline at end of file diff --git a/src/sys/shell/tcp.rs b/src/sys/shell/tcp.rs index f51d6ca04..2017bda30 100644 --- a/src/sys/shell/tcp.rs +++ b/src/sys/shell/tcp.rs @@ -1,6 +1,7 @@ use std::io; use std::net::{self, SocketAddr}; use std::time::Duration; +use crate::net::TcpKeepalive; pub(crate) type TcpSocket = i32; @@ -70,6 +71,59 @@ pub(crate) fn get_send_buffer_size(_: TcpSocket) -> io::Result { os_required!(); } +pub(crate) fn set_keepalive(_: TcpSocket, _: bool) -> io::Result<()> { + os_required!(); +} + +pub(crate) fn get_keepalive(_: TcpSocket) -> io::Result { + os_required!(); +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows", +))] +pub(crate) fn set_keepalive_params(_: TcpSocket, _: TcpKeepalive) -> io::Result<()> { + os_required!() +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +pub(crate) fn get_keepalive_time(_: TcpSocket) -> io::Result> { + os_required!() +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +pub(crate) fn get_keepalive_interval(_: TcpSocket) -> io::Result> { + os_required!() +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +pub(crate) fn get_keepalive_retries(_: TcpSocket) -> io::Result> { + os_required!() +} + pub fn accept(_: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { os_required!(); } diff --git a/src/sys/unix/tcp.rs b/src/sys/unix/tcp.rs index b2333b504..9e1d70069 100644 --- a/src/sys/unix/tcp.rs +++ b/src/sys/unix/tcp.rs @@ -1,13 +1,26 @@ -use std::io; use std::convert::TryInto; +use std::io; use std::mem; use std::mem::{size_of, MaybeUninit}; use std::net::{self, SocketAddr}; -use std::time::Duration; use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::time::Duration; use crate::sys::unix::net::{new_socket, socket_addr, to_socket_addr}; - +use crate::net::TcpKeepalive; + +#[cfg(any(target_os = "openbsd", target_os = "netbsd", target_os = "haiku"))] +use libc::SO_KEEPALIVE as KEEPALIVE_TIME; +#[cfg(any(target_os = "macos", target_os = "ios"))] +use libc::TCP_KEEPALIVE as KEEPALIVE_TIME; +#[cfg(not(any( + target_os = "macos", + target_os = "ios", + target_os = "openbsd", + target_os = "netbsd", + target_os = "haiku" +)))] +use libc::TCP_KEEPIDLE as KEEPALIVE_TIME; pub type TcpSocket = libc::c_int; pub(crate) fn new_v4_socket() -> io::Result { @@ -55,7 +68,8 @@ pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<() libc::SO_REUSEADDR, &val as *const libc::c_int as *const libc::c_void, size_of::() as libc::socklen_t, - )).map(|_| ()) + )) + .map(|_| ()) } pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result { @@ -83,7 +97,8 @@ pub(crate) fn set_reuseport(socket: TcpSocket, reuseport: bool) -> io::Result<() libc::SO_REUSEPORT, &val as *const libc::c_int as *const libc::c_void, size_of::() as libc::socklen_t, - )).map(|_| ()) + )) + .map(|_| ()) } #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] @@ -118,7 +133,9 @@ pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { let val: libc::linger = libc::linger { l_onoff: if dur.is_some() { 1 } else { 0 }, - l_linger: dur.map(|dur| dur.as_secs() as libc::c_int).unwrap_or_default(), + l_linger: dur + .map(|dur| dur.as_secs() as libc::c_int) + .unwrap_or_default(), }; syscall!(setsockopt( socket, @@ -126,7 +143,8 @@ pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result libc::SO_LINGER, &val as *const libc::linger as *const libc::c_void, size_of::() as libc::socklen_t, - )).map(|_| ()) + )) + .map(|_| ()) } pub(crate) fn get_linger(socket: TcpSocket) -> io::Result> { @@ -160,10 +178,9 @@ pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<( .map(|_| ()) } -pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { +pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { let mut optval: libc::c_int = 0; let mut optlen = size_of::() as libc::socklen_t; - syscall!(getsockopt( socket, libc::SOL_SOCKET, @@ -187,7 +204,7 @@ pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<( .map(|_| ()) } -pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { +pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { let mut optval: libc::c_int = 0; let mut optlen = size_of::() as libc::socklen_t; @@ -202,6 +219,203 @@ pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { Ok(optval as u32) } +pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> { + let val: libc::c_int = if keepalive { 1 } else { 0 }; + syscall!(setsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_KEEPALIVE, + &val as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result { + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; + + syscall!(getsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_KEEPALIVE, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(optval != 0) +} + +pub(crate) fn set_keepalive_params(socket: TcpSocket, keepalive: TcpKeepalive) -> io::Result<()> { + if let Some(dur) = keepalive.time { + set_keepalive_time(socket, dur)?; + } + + #[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + ))] + { + if let Some(dur) = keepalive.interval { + set_keepalive_interval(socket, dur)?; + } + + if let Some(retries) = keepalive.retries { + set_keepalive_retries(socket, retries)?; + } + } + + + Ok(()) +} + +fn set_keepalive_time(socket: TcpSocket, time: Duration) -> io::Result<()> { + let time_secs = time + .as_secs() + .try_into() + .ok() + .unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::IPPROTO_TCP, + KEEPALIVE_TIME, + &(time_secs as libc::c_int) as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +pub(crate) fn get_keepalive_time(socket: TcpSocket) -> io::Result> { + if !get_keepalive(socket)? { + return Ok(None); + } + + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + socket, + libc::IPPROTO_TCP, + KEEPALIVE_TIME, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(Some(Duration::from_secs(optval as u64))) +} + +/// Linux, FreeBSD, and NetBSD support setting the keepalive interval via +/// `TCP_KEEPINTVL`. +/// See: +/// - https://man7.org/linux/man-pages/man7/tcp.7.html +/// - https://www.freebsd.org/cgi/man.cgi?query=tcp#end +/// - http://man.netbsd.org/tcp.4#DESCRIPTION +/// +/// OpenBSD does not: +/// https://man.openbsd.org/tcp +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +fn set_keepalive_interval(socket: TcpSocket, interval: Duration) -> io::Result<()> { + let interval_secs = interval + .as_secs() + .try_into() + .ok() + .unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPINTVL, + &(interval_secs as libc::c_int) as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +pub(crate) fn get_keepalive_interval(socket: TcpSocket) -> io::Result> { + if !get_keepalive(socket)? { + return Ok(None); + } + + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPINTVL, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(Some(Duration::from_secs(optval as u64))) +} + +/// Linux, macOS/iOS, FreeBSD, and NetBSD support setting the number of TCP +/// keepalive retries via `TCP_KEEPCNT`. +/// See: +/// - https://man7.org/linux/man-pages/man7/tcp.7.html +/// - https://www.freebsd.org/cgi/man.cgi?query=tcp#end +/// - http://man.netbsd.org/tcp.4#DESCRIPTION +/// +/// OpenBSD does not: +/// https://man.openbsd.org/tcp +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +fn set_keepalive_retries(socket: TcpSocket, retries: u32) -> io::Result<()> { + let retries = retries.try_into().ok().unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPCNT, + &(retries as libc::c_int) as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +pub(crate) fn get_keepalive_retries(socket: TcpSocket) -> io::Result> { + if !get_keepalive(socket)? { + return Ok(None); + } + + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPCNT, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(Some(optval as u32)) +} + pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { let mut addr: MaybeUninit = MaybeUninit::uninit(); let mut length = size_of::() as libc::socklen_t; diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 5ac0c3213..6757b4476 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -1,22 +1,25 @@ use std::io; +use std::convert::TryInto; use std::mem::size_of; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; -use std::convert::TryInto; +use std::ptr; use std::os::windows::io::FromRawSocket; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. -use winapi::ctypes::{c_char, c_int, c_ushort}; +use winapi::ctypes::{c_char, c_int, c_ushort, c_ulong}; use winapi::shared::ws2def::{SOCKADDR_STORAGE, AF_INET, AF_INET6, SOCKADDR_IN}; use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH; +use winapi::shared::mstcpip; -use winapi::shared::minwindef::{BOOL, TRUE, FALSE}; +use winapi::shared::minwindef::{BOOL, TRUE, FALSE, DWORD, LPVOID, LPDWORD}; use winapi::um::winsock2::{ self, closesocket, linger, setsockopt, getsockopt, getsockname, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, - SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, SO_RCVBUF, SO_SNDBUF, + SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, SO_RCVBUF, SO_SNDBUF, SO_KEEPALIVE, WSAIoctl, LPWSAOVERLAPPED, }; use crate::sys::windows::net::{init, new_socket, socket_addr}; +use crate::net::TcpKeepalive; pub(crate) type TcpSocket = SOCKET; @@ -238,6 +241,82 @@ pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { } } +pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> { + let val: BOOL = if keepalive { TRUE } else { FALSE }; + match unsafe { setsockopt( + socket, + SOL_SOCKET, + SO_KEEPALIVE, + &val as *const _ as *const c_char, + size_of::() as c_int + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(()), + } +} + +pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result { + let mut optval: c_char = 0; + let mut optlen = size_of::() as c_int; + + match unsafe { getsockopt( + socket, + SOL_SOCKET, + SO_KEEPALIVE, + &mut optval as *mut _ as *mut _, + &mut optlen, + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(optval != FALSE as c_char), + } +} + +pub(crate) fn set_keepalive_params(socket: TcpSocket, keepalive: TcpKeepalive) -> io::Result<()> { + /// Windows configures keepalive time/interval in a u32 of milliseconds. + fn dur_to_ulong_ms(dur: Duration) -> c_ulong { + dur.as_millis().try_into().ok().unwrap_or_else(u32::max_value) + } + + // If any of the fields on the `tcp_keepalive` struct were not provided by + // the user, just leaving them zero will clobber any existing value. + // Unfortunately, we can't access the current value, so we will use the + // defaults if a value for the time or interval was not not provided. + let time = keepalive.time.unwrap_or_else(|| { + // The default value is two hours, as per + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals + let two_hours = 2 * 60 * 60; + Duration::from_secs(two_hours) + }); + + let interval = keepalive.interval.unwrap_or_else(|| { + // The default value is one second, as per + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-keepalive-vals + Duration::from_secs(1) + }); + + let mut keepalive = mstcpip::tcp_keepalive { + // Enable keepalive + onoff: 1, + keepalivetime: dur_to_ulong_ms(time), + keepaliveinterval: dur_to_ulong_ms(interval), + }; + + let mut out = 0; + match unsafe { WSAIoctl( + socket, + mstcpip::SIO_KEEPALIVE_VALS, + &mut keepalive as *mut _ as LPVOID, + size_of::() as DWORD, + ptr::null_mut() as LPVOID, + 0 as DWORD, + &mut out as *mut _ as LPDWORD, + 0 as LPWSAOVERLAPPED, + None, + ) } { + 0 => Ok(()), + _ => Err(io::Error::last_os_error()) + } +} pub(crate) fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { // The non-blocking state of `listener` is inherited. See diff --git a/tests/tcp_socket.rs b/tests/tcp_socket.rs index e5fc537b1..4dac73f01 100644 --- a/tests/tcp_socket.rs +++ b/tests/tcp_socket.rs @@ -1,6 +1,6 @@ #![cfg(all(feature = "os-poll", feature = "tcp"))] -use mio::net::TcpSocket; +use mio::net::{TcpKeepalive, TcpSocket}; use std::io; use std::time::Duration; @@ -40,6 +40,89 @@ fn set_reuseport() { let _ = socket.listen(128).unwrap(); } +#[test] +fn set_keepalive() { + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket.set_keepalive(false).unwrap(); + assert_eq!(false, socket.get_keepalive().unwrap()); + + socket.set_keepalive(true).unwrap(); + assert_eq!(true, socket.get_keepalive().unwrap()); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + +#[test] +fn set_keepalive_time() { + let dur = Duration::from_secs(4); // Chosen by fair dice roll, guaranteed to be random + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket + .set_keepalive_params(TcpKeepalive::default().with_time(dur)) + .unwrap(); + + // It's not possible to access keepalive parameters on Windows... + #[cfg(not(target_os = "windows"))] + assert_eq!(Some(dur), socket.get_keepalive_time().unwrap()); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" +))] +#[test] +fn set_keepalive_interval() { + let dur = Duration::from_secs(4); // Chosen by fair dice roll, guaranteed to be random + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket + .set_keepalive_params(TcpKeepalive::default().with_interval(dur)) + .unwrap(); + // It's not possible to access keepalive parameters on Windows... + #[cfg(not(target_os = "windows"))] + assert_eq!(Some(dur), socket.get_keepalive_interval().unwrap()); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + +#[cfg(any( + target_os = "linux", + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", +))] +#[test] +fn set_keepalive_retries() { + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket + .set_keepalive_params(TcpKeepalive::default().with_retries(16)) + .unwrap(); + assert_eq!(Some(16), socket.get_keepalive_retries().unwrap()); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + #[test] fn get_localaddr() { let expected_addr = "127.0.0.1:0".parse().unwrap();