Skip to content

Commit

Permalink
Merge pull request nix-rust#1915 from asomers/sockopt-iosafety
Browse files Browse the repository at this point in the history
Add I/O safety to sockopt and some socket functions
  • Loading branch information
asomers authored Aug 7, 2023
2 parents 783e38d + c1317e4 commit ee91423
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 299 deletions.
8 changes: 5 additions & 3 deletions src/sys/socket/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,9 +875,10 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
/// One common use is to match on the family of a union type, like this:
/// ```
/// # use nix::sys::socket::*;
/// # use std::os::unix::io::AsRawFd;
/// let fd = socket(AddressFamily::Inet, SockType::Stream,
/// SockFlag::empty(), None).unwrap();
/// let ss: SockaddrStorage = getsockname(fd).unwrap();
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).unwrap();
/// match ss.family().unwrap() {
/// AddressFamily::Inet => println!("{}", ss.as_sockaddr_in().unwrap()),
/// AddressFamily::Inet6 => println!("{}", ss.as_sockaddr_in6().unwrap()),
Expand Down Expand Up @@ -1261,11 +1262,12 @@ impl std::str::FromStr for SockaddrIn6 {
/// ```
/// # use nix::sys::socket::*;
/// # use std::str::FromStr;
/// # use std::os::unix::io::AsRawFd;
/// let localhost = SockaddrIn::from_str("127.0.0.1:8081").unwrap();
/// let fd = socket(AddressFamily::Inet, SockType::Stream, SockFlag::empty(),
/// None).unwrap();
/// bind(fd, &localhost).expect("bind");
/// let ss: SockaddrStorage = getsockname(fd).expect("getsockname");
/// bind(fd.as_raw_fd(), &localhost).expect("bind");
/// let ss: SockaddrStorage = getsockname(fd.as_raw_fd()).expect("getsockname");
/// assert_eq!(&localhost, ss.as_sockaddr_in().unwrap());
/// ```
#[derive(Clone, Copy, Eq)]
Expand Down
66 changes: 39 additions & 27 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use libc::{
use std::io::{IoSlice, IoSliceMut};
#[cfg(feature = "net")]
use std::net;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, RawFd, OwnedFd};
use std::{mem, ptr};

#[deny(missing_docs)]
Expand Down Expand Up @@ -693,6 +693,7 @@ pub enum ControlMessageOwned {
/// # use std::io::{IoSlice, IoSliceMut};
/// # use std::time::*;
/// # use std::str::FromStr;
/// # use std::os::unix::io::AsRawFd;
/// # fn main() {
/// // Set up
/// let message = "Ohayō!".as_bytes();
Expand All @@ -701,22 +702,22 @@ pub enum ControlMessageOwned {
/// SockType::Datagram,
/// SockFlag::empty(),
/// None).unwrap();
/// setsockopt(in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
/// setsockopt(&in_socket, sockopt::ReceiveTimestamp, &true).unwrap();
/// let localhost = SockaddrIn::from_str("127.0.0.1:0").unwrap();
/// bind(in_socket, &localhost).unwrap();
/// let address: SockaddrIn = getsockname(in_socket).unwrap();
/// bind(in_socket.as_raw_fd(), &localhost).unwrap();
/// let address: SockaddrIn = getsockname(in_socket.as_raw_fd()).unwrap();
/// // Get initial time
/// let time0 = SystemTime::now();
/// // Send the message
/// let iov = [IoSlice::new(message)];
/// let flags = MsgFlags::empty();
/// let l = sendmsg(in_socket, &iov, &[], flags, Some(&address)).unwrap();
/// let l = sendmsg(in_socket.as_raw_fd(), &iov, &[], flags, Some(&address)).unwrap();
/// assert_eq!(message.len(), l);
/// // Receive the message
/// let mut buffer = vec![0u8; message.len()];
/// let mut cmsgspace = cmsg_space!(TimeVal);
/// let mut iov = [IoSliceMut::new(&mut buffer)];
/// let r = recvmsg::<SockaddrIn>(in_socket, &mut iov, Some(&mut cmsgspace), flags)
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
/// .unwrap();
/// let rtime = match r.cmsgs().next() {
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
Expand All @@ -732,7 +733,6 @@ pub enum ControlMessageOwned {
/// assert!(time0.duration_since(UNIX_EPOCH).unwrap() <= rduration);
/// assert!(rduration <= time1.duration_since(UNIX_EPOCH).unwrap());
/// // Close socket
/// nix::unistd::close(in_socket).unwrap();
/// # }
/// ```
ScmTimestamp(TimeVal),
Expand Down Expand Up @@ -1451,6 +1451,7 @@ impl<'a> ControlMessage<'a> {
/// # use nix::sys::socket::*;
/// # use nix::unistd::pipe;
/// # use std::io::IoSlice;
/// # use std::os::unix::io::AsRawFd;
/// let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, None,
/// SockFlag::empty())
/// .unwrap();
Expand All @@ -1459,14 +1460,15 @@ impl<'a> ControlMessage<'a> {
/// let iov = [IoSlice::new(b"hello")];
/// let fds = [r];
/// let cmsg = ControlMessage::ScmRights(&fds);
/// sendmsg::<()>(fd1, &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
/// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
/// ```
/// When directing to a specific address, the generic type will be inferred.
/// ```
/// # use nix::sys::socket::*;
/// # use nix::unistd::pipe;
/// # use std::io::IoSlice;
/// # use std::str::FromStr;
/// # use std::os::unix::io::AsRawFd;
/// let localhost = SockaddrIn::from_str("1.2.3.4:8080").unwrap();
/// let fd = socket(AddressFamily::Inet, SockType::Datagram, SockFlag::empty(),
/// None).unwrap();
Expand All @@ -1475,7 +1477,7 @@ impl<'a> ControlMessage<'a> {
/// let iov = [IoSlice::new(b"hello")];
/// let fds = [r];
/// let cmsg = ControlMessage::ScmRights(&fds);
/// sendmsg(fd, &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
/// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
/// ```
pub fn sendmsg<S>(fd: RawFd, iov: &[IoSlice<'_>], cmsgs: &[ControlMessage],
flags: MsgFlags, addr: Option<&S>) -> Result<usize>
Expand Down Expand Up @@ -1823,6 +1825,7 @@ mod test {
use crate::sys::socket::{AddressFamily, ControlMessageOwned};
use crate::*;
use std::str::FromStr;
use std::os::unix::io::AsRawFd;

#[cfg_attr(qemu, ignore)]
#[test]
Expand All @@ -1849,9 +1852,9 @@ mod test {
None,
)?;

crate::sys::socket::bind(rsock, &sock_addr)?;
crate::sys::socket::bind(rsock.as_raw_fd(), &sock_addr)?;

setsockopt(rsock, Timestamping, &TimestampingFlag::all())?;
setsockopt(&rsock, Timestamping, &TimestampingFlag::all())?;

let sbuf = (0..400).map(|i| i as u8).collect::<Vec<_>>();

Expand All @@ -1873,13 +1876,13 @@ mod test {
let iov1 = [IoSlice::new(&sbuf)];

let cmsg = cmsg_space!(crate::sys::socket::Timestamps);
sendmsg(ssock, &iov1, &[], flags, Some(&sock_addr)).unwrap();
sendmsg(ssock.as_raw_fd(), &iov1, &[], flags, Some(&sock_addr)).unwrap();

let mut data = super::MultiHeaders::<()>::preallocate(recv_iovs.len(), Some(cmsg));

let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));

let recv = super::recvmmsg(rsock, &mut data, recv_iovs.iter(), flags, Some(t))?;
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;

for rmsg in recv {
#[cfg(not(any(qemu, target_arch = "aarch64")))]
Expand Down Expand Up @@ -2091,7 +2094,7 @@ pub fn socket<T: Into<Option<SockProtocol>>>(
ty: SockType,
flags: SockFlag,
protocol: T,
) -> Result<RawFd> {
) -> Result<OwnedFd> {
let protocol = match protocol.into() {
None => 0,
Some(p) => p as c_int,
Expand All @@ -2105,7 +2108,13 @@ pub fn socket<T: Into<Option<SockProtocol>>>(

let res = unsafe { libc::socket(domain as c_int, ty, protocol) };

Errno::result(res)
match res {
-1 => Err(Errno::last()),
fd => {
// Safe because libc::socket returned success
unsafe { Ok(OwnedFd::from_raw_fd(fd)) }
}
}
}

/// Create a pair of connected sockets
Expand All @@ -2116,7 +2125,7 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
ty: SockType,
protocol: T,
flags: SockFlag,
) -> Result<(RawFd, RawFd)> {
) -> Result<(OwnedFd, OwnedFd)> {
let protocol = match protocol.into() {
None => 0,
Some(p) => p as c_int,
Expand All @@ -2135,14 +2144,18 @@ pub fn socketpair<T: Into<Option<SockProtocol>>>(
};
Errno::result(res)?;

Ok((fds[0], fds[1]))
// Safe because socketpair returned success.
unsafe {
Ok((OwnedFd::from_raw_fd(fds[0]), OwnedFd::from_raw_fd(fds[1])))
}
}

/// Listen for connections on a socket
///
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html)
pub fn listen(sockfd: RawFd, backlog: usize) -> Result<()> {
let res = unsafe { libc::listen(sockfd, backlog as c_int) };
pub fn listen<F: AsFd>(sock: &F, backlog: usize) -> Result<()> {
let fd = sock.as_fd().as_raw_fd();
let res = unsafe { libc::listen(fd, backlog as c_int) };

Errno::result(res).map(drop)
}
Expand Down Expand Up @@ -2302,21 +2315,21 @@ pub trait GetSockOpt: Copy {
type Val;

/// Look up the value of this socket option on the given socket.
fn get(&self, fd: RawFd) -> Result<Self::Val>;
fn get<F: AsFd>(&self, fd: &F) -> Result<Self::Val>;
}

/// Represents a socket option that can be set.
pub trait SetSockOpt: Clone {
type Val;

/// Set the value of this socket option on the given socket.
fn set(&self, fd: RawFd, val: &Self::Val) -> Result<()>;
fn set<F: AsFd>(&self, fd: &F, val: &Self::Val) -> Result<()>;
}

/// Get the current value for the requested socket option
///
/// [Further reading](https://pubs.opengroup.org/onlinepubs/9699919799/functions/getsockopt.html)
pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
pub fn getsockopt<F: AsFd, O: GetSockOpt>(fd: &F, opt: O) -> Result<O::Val> {
opt.get(fd)
}

Expand All @@ -2330,15 +2343,14 @@ pub fn getsockopt<O: GetSockOpt>(fd: RawFd, opt: O) -> Result<O::Val> {
/// use nix::sys::socket::setsockopt;
/// use nix::sys::socket::sockopt::KeepAlive;
/// use std::net::TcpListener;
/// use std::os::unix::io::AsRawFd;
///
/// let listener = TcpListener::bind("0.0.0.0:0").unwrap();
/// let fd = listener.as_raw_fd();
/// let res = setsockopt(fd, KeepAlive, &true);
/// let fd = listener;
/// let res = setsockopt(&fd, KeepAlive, &true);
/// assert!(res.is_ok());
/// ```
pub fn setsockopt<O: SetSockOpt>(
fd: RawFd,
pub fn setsockopt<F: AsFd, O: SetSockOpt>(
fd: &F,
opt: O,
val: &O::Val,
) -> Result<()> {
Expand Down
41 changes: 17 additions & 24 deletions src/sys/socket/sockopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::ffi::{OsStr, OsString};
use std::mem::{self, MaybeUninit};
#[cfg(target_family = "unix")]
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd};

// Constants
// TCP_CA_NAME_MAX isn't defined in user space include files
Expand Down Expand Up @@ -44,12 +44,12 @@ macro_rules! setsockopt_impl {
impl SetSockOpt for $name {
type Val = $ty;

fn set(&self, fd: RawFd, val: &$ty) -> Result<()> {
fn set<F: AsFd>(&self, fd: &F, val: &$ty) -> Result<()> {
unsafe {
let setter: $setter = Set::new(val);

let res = libc::setsockopt(
fd,
fd.as_fd().as_raw_fd(),
$level,
$flag,
setter.ffi_ptr(),
Expand Down Expand Up @@ -89,12 +89,12 @@ macro_rules! getsockopt_impl {
impl GetSockOpt for $name {
type Val = $ty;

fn get(&self, fd: RawFd) -> Result<$ty> {
fn get<F: AsFd>(&self, fd: &F) -> Result<$ty> {
unsafe {
let mut getter: $getter = Get::uninit();

let res = libc::getsockopt(
fd,
fd.as_fd().as_raw_fd(),
$level,
$flag,
getter.ffi_ptr(),
Expand Down Expand Up @@ -1053,10 +1053,10 @@ pub struct AlgSetAeadAuthSize;
impl SetSockOpt for AlgSetAeadAuthSize {
type Val = usize;

fn set(&self, fd: RawFd, val: &usize) -> Result<()> {
fn set<F: AsFd>(&self, fd: &F, val: &usize) -> Result<()> {
unsafe {
let res = libc::setsockopt(
fd,
fd.as_fd().as_raw_fd(),
libc::SOL_ALG,
libc::ALG_SET_AEAD_AUTHSIZE,
::std::ptr::null(),
Expand Down Expand Up @@ -1087,10 +1087,10 @@ where
{
type Val = T;

fn set(&self, fd: RawFd, val: &T) -> Result<()> {
fn set<F: AsFd>(&self, fd: &F, val: &T) -> Result<()> {
unsafe {
let res = libc::setsockopt(
fd,
fd.as_fd().as_raw_fd(),
libc::SOL_ALG,
libc::ALG_SET_KEY,
val.as_ref().as_ptr() as *const _,
Expand Down Expand Up @@ -1403,34 +1403,30 @@ mod test {
SockFlag::empty(),
)
.unwrap();
let a_cred = getsockopt(a, super::PeerCredentials).unwrap();
let b_cred = getsockopt(b, super::PeerCredentials).unwrap();
let a_cred = getsockopt(&a, super::PeerCredentials).unwrap();
let b_cred = getsockopt(&b, super::PeerCredentials).unwrap();
assert_eq!(a_cred, b_cred);
assert_ne!(a_cred.pid(), 0);
}

#[test]
fn is_socket_type_unix() {
use super::super::*;
use crate::unistd::close;

let (a, b) = socketpair(
let (a, _b) = socketpair(
AddressFamily::Unix,
SockType::Stream,
None,
SockFlag::empty(),
)
.unwrap();
let a_type = getsockopt(a, super::SockType).unwrap();
let a_type = getsockopt(&a, super::SockType).unwrap();
assert_eq!(a_type, SockType::Stream);
close(a).unwrap();
close(b).unwrap();
}

#[test]
fn is_socket_type_dgram() {
use super::super::*;
use crate::unistd::close;

let s = socket(
AddressFamily::Inet,
Expand All @@ -1439,16 +1435,14 @@ mod test {
None,
)
.unwrap();
let s_type = getsockopt(s, super::SockType).unwrap();
let s_type = getsockopt(&s, super::SockType).unwrap();
assert_eq!(s_type, SockType::Datagram);
close(s).unwrap();
}

#[cfg(any(target_os = "freebsd", target_os = "linux"))]
#[test]
fn can_get_listen_on_tcp_socket() {
use super::super::*;
use crate::unistd::close;

let s = socket(
AddressFamily::Inet,
Expand All @@ -1457,11 +1451,10 @@ mod test {
None,
)
.unwrap();
let s_listening = getsockopt(s, super::AcceptConn).unwrap();
let s_listening = getsockopt(&s, super::AcceptConn).unwrap();
assert!(!s_listening);
listen(s, 10).unwrap();
let s_listening2 = getsockopt(s, super::AcceptConn).unwrap();
listen(&s, 10).unwrap();
let s_listening2 = getsockopt(&s, super::AcceptConn).unwrap();
assert!(s_listening2);
close(s).unwrap();
}
}
Loading

0 comments on commit ee91423

Please sign in to comment.