Skip to content

Commit

Permalink
recvmsg: Check if CMSG buffer was too small and return an error
Browse files Browse the repository at this point in the history
If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they
could have been truncated. Change RecvMsg::cmsgs() to return a Result,
and to check for this flag (an API change).

Update tests for API change. Add test for too-small buffer.
  • Loading branch information
agrover committed May 22, 2024
1 parent 663506a commit d3d131a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
1 change: 1 addition & 0 deletions changelog/2413.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated.
19 changes: 13 additions & 6 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t};
#[cfg(all(feature = "uio", not(target_os = "redox")))]
use libc::{
c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE,
MSG_CTRUNC,
};
#[cfg(not(target_os = "redox"))]
use std::io::{IoSlice, IoSliceMut};
Expand Down Expand Up @@ -599,13 +600,19 @@ pub struct RecvMsg<'a, 's, S> {
}

impl<'a, S> RecvMsg<'a, '_, S> {
/// Iterate over the valid control messages pointed to by this
/// msghdr.
pub fn cmsgs(&self) -> CmsgIterator {
CmsgIterator {
/// Iterate over the valid control messages pointed to by this msghdr. If
/// allocated space for CMSGs was too small it is not safe to iterate,
/// instead return an `Error::ENOBUFS` error.
pub fn cmsgs(&self) -> Result<CmsgIterator> {

if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC {
return Err(Errno::ENOBUFS);
}

Ok(CmsgIterator {
cmsghdr: self.cmsghdr,
mhdr: &self.mhdr
}
})
}
}

Expand Down Expand Up @@ -700,7 +707,7 @@ pub enum ControlMessageOwned {
/// let mut iov = [IoSliceMut::new(&mut buffer)];
/// let r = recvmsg::<SockaddrIn>(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags)
/// .unwrap();
/// let rtime = match r.cmsgs().next() {
/// let rtime = match r.cmsgs().unwrap().next() {
/// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime,
/// Some(_) => panic!("Unexpected control message"),
/// None => panic!("No control message")
Expand Down
61 changes: 36 additions & 25 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub fn test_timestamping() {
.unwrap();

let mut ts = None;
for c in recv.cmsgs() {
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::ScmTimestampsns(timestamps) = c {
ts = Some(timestamps.system);
}
Expand Down Expand Up @@ -117,7 +117,7 @@ pub fn test_timestamping_realtime() {
.unwrap();

let mut ts = None;
for c in recv.cmsgs() {
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::ScmRealtime(timeval) = c {
ts = Some(timeval);
}
Expand Down Expand Up @@ -179,7 +179,7 @@ pub fn test_timestamping_monotonic() {
.unwrap();

let mut ts = None;
for c in recv.cmsgs() {
for c in recv.cmsgs().unwrap() {
if let ControlMessageOwned::ScmMonotonic(timeval) = c {
ts = Some(timeval);
}
Expand Down Expand Up @@ -889,7 +889,7 @@ pub fn test_scm_rights() {
)
.unwrap();

for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
if let ControlMessageOwned::ScmRights(fd) = cmsg {
assert_eq!(received_r, None);
assert_eq!(fd.len(), 1);
Expand Down Expand Up @@ -1330,7 +1330,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() {
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
let mut cmsgs = msg.cmsgs().unwrap();
match cmsgs.next() {
Some(ControlMessageOwned::ScmRights(fds)) => {
assert_eq!(
Expand Down Expand Up @@ -1399,7 +1399,7 @@ pub fn test_sendmsg_empty_cmsgs() {
)
.unwrap();

if msg.cmsgs().next().is_some() {
if msg.cmsgs().unwrap().next().is_some() {
panic!("unexpected cmsg");
}
assert!(!msg
Expand Down Expand Up @@ -1466,7 +1466,7 @@ fn test_scm_credentials() {
.unwrap();
let mut received_cred = None;

for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
let cred = match cmsg {
#[cfg(linux_android)]
ControlMessageOwned::ScmCredentials(cred) => cred,
Expand Down Expand Up @@ -1497,7 +1497,7 @@ fn test_scm_credentials() {
#[test]
fn test_scm_credentials_and_rights() {
let space = cmsg_space!(libc::ucred, RawFd);
test_impl_scm_credentials_and_rights(space);
test_impl_scm_credentials_and_rights(space).unwrap();
}

/// Ensure that passing a an oversized control message buffer to recvmsg
Expand All @@ -1509,11 +1509,20 @@ fn test_scm_credentials_and_rights() {
#[test]
fn test_too_large_cmsgspace() {
let space = vec![0u8; 1024];
test_impl_scm_credentials_and_rights(space);
test_impl_scm_credentials_and_rights(space).unwrap();
}

#[cfg(linux_android)]
fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
#[test]
fn test_too_small_cmsgspace() {
let space = vec![0u8; 4];
assert_eq!(test_impl_scm_credentials_and_rights(space), Err(nix::errno::Errno::ENOBUF));
}

#[cfg(linux_android)]
fn test_impl_scm_credentials_and_rights(
mut space: Vec<u8>,
) -> Result<(), std::io::Error> {
use libc::ucred;
use nix::sys::socket::sockopt::PassCred;
use nix::sys::socket::{
Expand Down Expand Up @@ -1573,9 +1582,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
.unwrap();
let mut received_cred = None;

assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs");

for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs()? {
match cmsg {
ControlMessageOwned::ScmRights(fds) => {
assert_eq!(received_r, None, "already received fd");
Expand Down Expand Up @@ -1606,6 +1615,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
read(received_r.as_raw_fd(), &mut buf).unwrap();
assert_eq!(&buf[..], b"world");
close(received_r).unwrap();

Ok(())
}

// Test creating and using named unix domain sockets
Expand Down Expand Up @@ -1837,7 +1848,7 @@ pub fn test_recv_ipv4pktinfo() {
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
let mut cmsgs = msg.cmsgs().unwrap();
if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next()
{
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
Expand Down Expand Up @@ -1929,11 +1940,11 @@ pub fn test_recvif() {
assert!(!msg
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs");
assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs");

let mut rx_recvif = false;
let mut rx_recvdstaddr = false;
for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
match cmsg {
ControlMessageOwned::Ipv4RecvIf(dl) => {
rx_recvif = true;
Expand Down Expand Up @@ -2027,10 +2038,10 @@ pub fn test_recvif_ipv4() {
assert!(!msg
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");

let mut rx_recvorigdstaddr = false;
for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
match cmsg {
ControlMessageOwned::Ipv4OrigDstAddr(addr) => {
rx_recvorigdstaddr = true;
Expand Down Expand Up @@ -2113,10 +2124,10 @@ pub fn test_recvif_ipv6() {
assert!(!msg
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));
assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs");
assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs");

let mut rx_recvorigdstaddr = false;
for cmsg in msg.cmsgs() {
for cmsg in msg.cmsgs().unwrap() {
match cmsg {
ControlMessageOwned::Ipv6OrigDstAddr(addr) => {
rx_recvorigdstaddr = true;
Expand Down Expand Up @@ -2214,7 +2225,7 @@ pub fn test_recv_ipv6pktinfo() {
.flags
.intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC));

let mut cmsgs = msg.cmsgs();
let mut cmsgs = msg.cmsgs().unwrap();
if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next()
{
let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex");
Expand Down Expand Up @@ -2357,7 +2368,7 @@ fn test_recvmsg_timestampns() {
flags,
)
.unwrap();
let rtime = match r.cmsgs().next() {
let rtime = match r.cmsgs().unwrap().next() {
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
Some(_) => panic!("Unexpected control message"),
None => panic!("No control message"),
Expand Down Expand Up @@ -2418,7 +2429,7 @@ fn test_recvmmsg_timestampns() {
)
.unwrap()
.collect();
let rtime = match r[0].cmsgs().next() {
let rtime = match r[0].cmsgs().unwrap().next() {
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
Some(_) => panic!("Unexpected control message"),
None => panic!("No control message"),
Expand Down Expand Up @@ -2508,7 +2519,7 @@ fn test_recvmsg_rxq_ovfl() {
MsgFlags::MSG_DONTWAIT,
) {
Ok(r) => {
drop_counter = match r.cmsgs().next() {
drop_counter = match r.cmsgs().unwrap().next() {
Some(ControlMessageOwned::RxqOvfl(drop_counter)) => {
drop_counter
}
Expand Down Expand Up @@ -2687,7 +2698,7 @@ mod linux_errqueue {
assert_eq!(msg.address, Some(sock_addr));

// Check for expected control message.
let ext_err = match msg.cmsgs().next() {
let ext_err = match msg.cmsgs().unwrap().next() {
Some(cmsg) => testf(&cmsg),
None => panic!("No control message"),
};
Expand Down Expand Up @@ -2878,7 +2889,7 @@ fn test_recvmm2() -> nix::Result<()> {
#[cfg(not(any(qemu, target_arch = "aarch64")))]
let mut saw_time = false;
let mut recvd = 0;
for cmsg in rmsg.cmsgs() {
for cmsg in rmsg.cmsgs().unwrap() {
if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg {
let ts = timestamps.system;

Expand Down

0 comments on commit d3d131a

Please sign in to comment.