diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index b43690b248..249505a992 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -106,6 +106,11 @@ impl TcpSocket { sys::tcp::set_linger(self.sys, dur) } + /// Gets the value of `SO_LINGER` on this socket + pub fn get_linger(&self) -> io::Result> { + sys::tcp::get_linger(self.sys) + } + /// Sets the value of `SO_RCVBUF` on this socket. pub fn set_recv_buffer_size(&self, size: u32) -> io::Result<()> { sys::tcp::set_recv_buffer_size(self.sys, size) diff --git a/src/sys/shell/tcp.rs b/src/sys/shell/tcp.rs index b67e33db3c..f51d6ca04a 100644 --- a/src/sys/shell/tcp.rs +++ b/src/sys/shell/tcp.rs @@ -50,6 +50,10 @@ pub(crate) fn set_linger(_: TcpSocket, _: Option) -> io::Result<()> { os_required!(); } +pub(crate) fn get_linger(_: TcpSocket) -> io::Result> { + os_required!(); +} + pub(crate) fn set_recv_buffer_size(_: TcpSocket, _: u32) -> io::Result<()> { os_required!(); } diff --git a/src/sys/unix/tcp.rs b/src/sys/unix/tcp.rs index 8d5f5aa5c1..86d2734794 100644 --- a/src/sys/unix/tcp.rs +++ b/src/sys/unix/tcp.rs @@ -129,6 +129,25 @@ pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result )).map(|_| ()) } +pub(crate) fn get_linger(socket: TcpSocket) -> io::Result> { + let mut val: libc::linger = unsafe { std::mem::zeroed() }; + let mut len = mem::size_of::() as libc::socklen_t; + + syscall!(getsockopt( + socket, + libc::SOL_SOCKET, + libc::SO_LINGER, + &mut val as *mut _ as *mut _, + &mut len, + ))?; + + if val.l_onoff == 0 { + Ok(None) + } else { + Ok(Some(Duration::from_secs(val.l_linger as u64))) + } +} + pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { let size = size.try_into().ok().unwrap_or_else(i32::max_value); syscall!(setsockopt( diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index e14f1c8bd9..4baf5562a2 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -128,8 +128,6 @@ pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { } }, } - - } pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { @@ -150,6 +148,28 @@ pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result } } +pub(crate) fn get_linger(socket: TcpSocket) -> io::Result> { + let mut val: linger = unsafe { std::mem::zeroed() }; + let mut len = size_of::() as c_int; + + match unsafe { getsockopt( + socket, + SOL_SOCKET, + SO_LINGER, + &mut val as *mut _ as *mut _, + &mut len, + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => { + if val.l_onoff == 0 { + Ok(None) + } else { + Ok(Some(Duration::from_secs(val.l_linger as u64))) + } + }, + } +} + pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { let size = size.try_into().ok().unwrap_or_else(i32::max_value); diff --git a/tests/tcp_socket.rs b/tests/tcp_socket.rs index bb57ade71a..4b0a2a93f4 100644 --- a/tests/tcp_socket.rs +++ b/tests/tcp_socket.rs @@ -1,7 +1,7 @@ #![cfg(all(feature = "os-poll", feature = "tcp"))] use mio::net::TcpSocket; -use std::io; +use std::time::Duration; #[test] fn is_send_and_sync() { @@ -58,6 +58,22 @@ fn get_localaddr() { let _ = socket.listen(128).unwrap(); } +#[test] +fn set_linger() { + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket.set_linger(Some(Duration::from_secs(1))).unwrap(); + assert_eq!(socket.get_linger().unwrap().unwrap().as_secs(), 1); + + let _ = socket.set_linger(None); + assert_eq!(socket.get_linger().unwrap(), None); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + #[test] fn send_buffer_size_roundtrips() { test_buffer_sizes(