From b41a022b2242eef1969c70c8ba93e04c528dba47 Mon Sep 17 00:00:00 2001 From: Zeki Sherif <9832640+zekisherif@users.noreply.github.com> Date: Thu, 29 Oct 2020 10:35:40 -0700 Subject: [PATCH] Add getter for local_addr on TcpSocket (#1379) --- src/net/tcp/socket.rs | 7 +++++++ src/sys/shell/tcp.rs | 4 ++++ src/sys/unix/tcp.rs | 13 +++++++++++++ src/sys/windows/tcp.rs | 32 ++++++++++++++++++++++++++++++-- tests/tcp_socket.rs | 19 +++++++++++++++++++ 5 files changed, 73 insertions(+), 2 deletions(-) diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index efc23490a..f3e27c370 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -105,6 +105,13 @@ impl TcpSocket { pub fn set_linger(&self, dur: Option) -> io::Result<()> { sys::tcp::set_linger(self.sys, dur) } + + /// Returns the local address of this socket + /// + /// Will return `Err` result in windows if called before calling `bind` + pub fn get_localaddr(&self) -> io::Result { + sys::tcp::get_localaddr(self.sys) + } } impl Drop for TcpSocket { diff --git a/src/sys/shell/tcp.rs b/src/sys/shell/tcp.rs index b3c4e2999..3073d42f7 100644 --- a/src/sys/shell/tcp.rs +++ b/src/sys/shell/tcp.rs @@ -53,3 +53,7 @@ pub(crate) fn set_linger(_: TcpSocket, _: Option) -> io::Result<()> { pub fn accept(_: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { os_required!(); } + +pub(crate) fn get_localaddr(_: TcpSocket) -> io::Result { + os_required!(); +} diff --git a/src/sys/unix/tcp.rs b/src/sys/unix/tcp.rs index a623cf391..65b7400e9 100644 --- a/src/sys/unix/tcp.rs +++ b/src/sys/unix/tcp.rs @@ -103,6 +103,19 @@ pub(crate) fn get_reuseport(socket: TcpSocket) -> io::Result { Ok(optval != 0) } +pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { + let mut addr: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + let mut length = size_of::() as libc::socklen_t; + + syscall!(getsockname( + socket, + &mut addr as *mut _ as *mut _, + &mut length + ))?; + + unsafe { to_socket_addr(&addr) } +} + pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { let val: libc::linger = libc::linger { l_onoff: if dur.is_some() { 1 } else { 0 }, diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 42443f2af..b78d86479 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -1,14 +1,17 @@ use std::io; use std::mem::size_of; -use std::net::{self, SocketAddr}; +use std::net::{self, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; use std::os::windows::io::FromRawSocket; use std::os::windows::raw::SOCKET as StdSocket; // winapi uses usize, stdlib uses u32/u64. use winapi::ctypes::{c_char, c_int, c_ushort}; +use winapi::shared::ws2def::{SOCKADDR_STORAGE, AF_INET, SOCKADDR_IN}; +use winapi::shared::ws2ipdef::SOCKADDR_IN6_LH; + use winapi::shared::minwindef::{BOOL, TRUE, FALSE}; use winapi::um::winsock2::{ - self, closesocket, linger, setsockopt, getsockopt, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, + self, closesocket, linger, setsockopt, getsockopt, getsockname, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, }; @@ -103,6 +106,31 @@ pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result { } } +pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { + let mut addr: SOCKADDR_STORAGE = unsafe { std::mem::zeroed() }; + let mut length = std::mem::size_of_val(&addr) as c_int; + + match unsafe { getsockname( + socket, + &mut addr as *mut _ as *mut _, + &mut length + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => { + let storage: *const SOCKADDR_STORAGE = (&addr) as *const _; + if addr.ss_family as c_int == AF_INET { + let sock_addr : SocketAddrV4 = unsafe { *(storage as *const SOCKADDR_IN as *const _) }; + Ok(sock_addr.into()) + } else { + let sock_addr : SocketAddrV6 = unsafe { *(storage as *const SOCKADDR_IN6_LH as *const _) }; + Ok(sock_addr.into()) + } + }, + } + + +} + pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { let val: linger = linger { l_onoff: if dur.is_some() { 1 } else { 0 }, diff --git a/tests/tcp_socket.rs b/tests/tcp_socket.rs index 536e84b42..0ad2c7ba5 100644 --- a/tests/tcp_socket.rs +++ b/tests/tcp_socket.rs @@ -37,3 +37,22 @@ fn set_reuseport() { let _ = socket.listen(128).unwrap(); } + +#[test] +fn get_localaddr() { + let expected_addr = "127.0.0.1:0".parse().unwrap(); + let socket = TcpSocket::new_v4().unwrap(); + + //Windows doesn't support calling getsockname before calling `bind` + #[cfg(not(windows))] + assert_eq!("0.0.0.0:0", socket.get_localaddr().unwrap().to_string()); + + socket.bind(expected_addr).unwrap(); + + let actual_addr = socket.get_localaddr().unwrap(); + + assert_eq!(expected_addr.ip(), actual_addr.ip()); + assert!(actual_addr.port() > 0); + + let _ = socket.listen(128).unwrap(); +}