Skip to content

Commit

Permalink
recvmsg: Check if CMSG buffer was too small and return an error
Browse files Browse the repository at this point in the history
If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they
could have been truncated. Change RecvMsg::cmsgs() to return a Result,
and to check for this flag (an API change).

Update tests for API change. Add test for too-small buffer.
  • Loading branch information
agrover committed May 21, 2024
1 parent 663506a commit 53b4487
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 30 deletions.
1 change: 1 addition & 0 deletions changelog/2413.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated.
14 changes: 10 additions & 4 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t};
#[cfg(all(feature = "uio", not(target_os = "redox")))]
use libc::{
c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE,
MSG_CTRUNC,
};
#[cfg(not(target_os = "redox"))]
use std::io::{IoSlice, IoSliceMut};
Expand Down Expand Up @@ -601,11 +602,16 @@ pub struct RecvMsg<'a, 's, S> {
impl<'a, S> RecvMsg<'a, '_, S> {
/// Iterate over the valid control messages pointed to by this
/// msghdr.
pub fn cmsgs(&self) -> CmsgIterator {
CmsgIterator {
pub fn cmsgs(&self) -> Result<CmsgIterator> {

if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC {
return Err(Errno::ENOBUFS);
}

Ok(CmsgIterator {
cmsghdr: self.cmsghdr,
mhdr: &self.mhdr
}
})
}
}

Expand Down Expand Up @@ -700,7 +706,7 @@ pub enum ControlMessageOwned {
/// let mut iov = [IoSliceMut::new(&mut buffer)];
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
/// .unwrap();
/// let rtime = match r.cmsgs().next() {
/// let rtime = match r.cmsgs().unwrap().next() {
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
/// Some(_) => panic!("Unexpected control message"),
/// None => panic!("No control message")
Expand Down
63 changes: 37 additions & 26 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use libc::c_char;
use nix::sys::socket::{getsockname, AddressFamily, UnixAddr};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::io;
use std::net::{SocketAddrV4, SocketAddrV6};
use std::os::unix::io::{AsRawFd, RawFd};
use std::path::Path;
Expand Down Expand Up @@ -55,7 +56,7 @@ pub fn test_timestamping() {
.unwrap();

let mut ts = None;
for c in recv.cmsgs() {
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::ScmTimestampsns(timestamps) = c {
ts = Some(timestamps.system);
}
Expand Down Expand Up @@ -117,7 +118,7 @@ pub fn test_timestamping_realtime() {
.unwrap();

let mut ts = None;
for c in recv.cmsgs() {
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::ScmRealtime(timeval) = c {
ts = Some(timeval);
}
Expand Down Expand Up @@ -179,7 +180,7 @@ pub fn test_timestamping_monotonic() {
.unwrap();

let mut ts = None;
for c in recv.cmsgs() {
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::ScmMonotonic(timeval) = c {
ts = Some(timeval);
}
Expand Down Expand Up @@ -889,7 +890,7 @@ pub fn test_scm_rights() {
)
.unwrap();

for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
if let ControlMessageOwned::ScmRights(fd) = cmsg {
assert_eq!(received_r, None);
assert_eq!(fd.len(), 1);
Expand Down Expand Up @@ -1330,7 +1331,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() {
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
let mut cmsgs = msg.cmsgs().unwrap();
match cmsgs.next() {
Some(ControlMessageOwned::ScmRights(fds)) => {
assert_eq!(
Expand Down Expand Up @@ -1399,7 +1400,7 @@ pub fn test_sendmsg_empty_cmsgs() {
)
.unwrap();

if msg.cmsgs().next().is_some() {
if msg.cmsgs().unwrap().next().is_some() {
panic!("unexpected cmsg");
}
assert!(!msg
Expand Down Expand Up @@ -1466,7 +1467,7 @@ fn test_scm_credentials() {
.unwrap();
let mut received_cred = None;

for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
let cred = match cmsg {
#[cfg(linux_android)]
ControlMessageOwned::ScmCredentials(cred) => cred,
Expand Down Expand Up @@ -1497,7 +1498,7 @@ fn test_scm_credentials() {
#[test]
fn test_scm_credentials_and_rights() {
let space = cmsg_space!(libc::ucred, RawFd);
test_impl_scm_credentials_and_rights(space);
test_impl_scm_credentials_and_rights(space).unwrap();
}

/// Ensure that passing a an oversized control message buffer to recvmsg
Expand All @@ -1509,11 +1510,20 @@ fn test_scm_credentials_and_rights() {
#[test]
fn test_too_large_cmsgspace() {
let space = vec![0u8; 1024];
test_impl_scm_credentials_and_rights(space);
test_impl_scm_credentials_and_rights(space).unwrap();
}

#[cfg(linux_android)]
fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
#[test]
fn test_too_small_cmsgspace() {
let space = vec![0u8; 4];
assert!(test_impl_scm_credentials_and_rights(space).is_err());
}

#[cfg(linux_android)]
fn test_impl_scm_credentials_and_rights(
mut space: Vec<u8>,
) -> Result<(), io::Error> {
use libc::ucred;
use nix::sys::socket::sockopt::PassCred;
use nix::sys::socket::{
Expand Down Expand Up @@ -1573,9 +1583,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
.unwrap();
let mut received_cred = None;

assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs");

for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs()? {
match cmsg {
ControlMessageOwned::ScmRights(fds) => {
assert_eq!(received_r, None, "already received fd");
Expand Down Expand Up @@ -1606,6 +1616,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
read(received_r.as_raw_fd(), &mut buf).unwrap();
assert_eq!(&buf[..], b"world");
close(received_r).unwrap();

Ok(())
}

// Test creating and using named unix domain sockets
Expand Down Expand Up @@ -1742,7 +1754,6 @@ fn loopback_address(
use nix::ifaddrs::getifaddrs;
use nix::net::if_::*;
use nix::sys::socket::SockaddrLike;
use std::io;
use std::io::Write;

let mut addrs = match getifaddrs() {
Expand Down Expand Up @@ -1837,7 +1848,7 @@ pub fn test_recv_ipv4pktinfo() {
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
let mut cmsgs = msg.cmsgs().unwrap();
if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next()
{
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
Expand Down Expand Up @@ -1929,11 +1940,11 @@ pub fn test_recvif() {
assert!(!msg
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs");

let mut rx_recvif = false;
let mut rx_recvdstaddr = false;
for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
match cmsg {
ControlMessageOwned::Ipv4RecvIf(dl) => {
rx_recvif = true;
Expand Down Expand Up @@ -2027,10 +2038,10 @@ pub fn test_recvif_ipv4() {
assert!(!msg
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");

let mut rx_recvorigdstaddr = false;
for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
match cmsg {
ControlMessageOwned::Ipv4OrigDstAddr(addr) => {
rx_recvorigdstaddr = true;
Expand Down Expand Up @@ -2113,10 +2124,10 @@ pub fn test_recvif_ipv6() {
assert!(!msg
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");

let mut rx_recvorigdstaddr = false;
for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
match cmsg {
ControlMessageOwned::Ipv6OrigDstAddr(addr) => {
rx_recvorigdstaddr = true;
Expand Down Expand Up @@ -2214,7 +2225,7 @@ pub fn test_recv_ipv6pktinfo() {
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
let mut cmsgs = msg.cmsgs().unwrap();
if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next()
{
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
Expand Down Expand Up @@ -2357,7 +2368,7 @@ fn test_recvmsg_timestampns() {
flags,
)
.unwrap();
let rtime = match r.cmsgs().next() {
let rtime = match r.cmsgs().unwrap().next() {
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
Some(_) => panic!("Unexpected control message"),
None => panic!("No control message"),
Expand Down Expand Up @@ -2418,7 +2429,7 @@ fn test_recvmmsg_timestampns() {
)
.unwrap()
.collect();
let rtime = match r[0].cmsgs().next() {
let rtime = match r[0].cmsgs().unwrap().next() {
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
Some(_) => panic!("Unexpected control message"),
None => panic!("No control message"),
Expand Down Expand Up @@ -2508,7 +2519,7 @@ fn test_recvmsg_rxq_ovfl() {
MsgFlags::MSG_DONTWAIT,
) {
Ok(r) => {
drop_counter = match r.cmsgs().next() {
drop_counter = match r.cmsgs().unwrap().next() {
Some(ControlMessageOwned::RxqOvfl(drop_counter)) => {
drop_counter
}
Expand Down Expand Up @@ -2687,7 +2698,7 @@ mod linux_errqueue {
assert_eq!(msg.address, Some(sock_addr));

// Check for expected control message.
let ext_err = match msg.cmsgs().next() {
let ext_err = match msg.cmsgs().unwrap().next() {
Some(cmsg) => testf(&cmsg),
None => panic!("No control message"),
};
Expand Down Expand Up @@ -2878,7 +2889,7 @@ fn test_recvmm2() -> nix::Result<()> {
#[cfg(not(any(qemu, target_arch = "aarch64")))]
let mut saw_time = false;
let mut recvd = 0;
for cmsg in rmsg.cmsgs() {
for cmsg in rmsg.cmsgs().unwrap() {
if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg {
let ts = timestamps.system;

Expand Down

0 comments on commit 53b4487

Please sign in to comment.