Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the length of a socket address when calling recvmsg on Linux #2041

Merged
merged 11 commits into from
Jul 17, 2023
101 changes: 98 additions & 3 deletions src/sys/socket/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,22 @@ impl SockaddrLike for UnixAddr {
{
mem::size_of::<libc::sockaddr_un>() as libc::socklen_t
}

unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
self.sun_len = new_length as u8;
} else {
self.sun.sun_len = new_length as u8;
JarredAllen marked this conversation as resolved.
Show resolved Hide resolved
}
};
Ok(())
}
}

impl AsRef<libc::sockaddr_un> for UnixAddr {
Expand Down Expand Up @@ -912,8 +928,30 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
{
mem::size_of::<Self>() as libc::socklen_t
}

/// Set the length of this socket address
///
/// This method may only be called on socket addresses whose lenghts are dynamic, and it
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// This method may only be called on socket addresses whose lenghts are dynamic, and it
/// This method may only be called on socket addresses whose lengths are dynamic, and it

/// returns an error if called on a type whose length is static.
///
/// # Safety
///
/// `new_length` must be a valid length for this type of address. Specifically, reads of that
/// length from `self` must be valid.
unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add a default implementation here, then you can avoid a lot of boilerplate elsewhere.

Suggested change
unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic>;
unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}

}

/// The error returned by [`SockaddrLike::set_length`] on an address whose length is statically
/// fixed.
#[derive(Copy, Clone, Debug)]
pub struct SocketAddressLengthNotDynamic;
impl fmt::Display for SocketAddressLengthNotDynamic {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Attempted to set length on socket whose length is statically fixed")
}
}
impl std::error::Error for SocketAddressLengthNotDynamic {}

impl private::SockaddrLikePriv for () {
fn as_mut_ptr(&mut self) -> *mut libc::sockaddr {
ptr::null_mut()
Expand Down Expand Up @@ -946,6 +984,10 @@ impl SockaddrLike for () {
fn len(&self) -> libc::socklen_t {
0
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

/// An IPv4 socket address
Expand Down Expand Up @@ -1015,6 +1057,10 @@ impl SockaddrLike for SockaddrIn {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

#[cfg(feature = "net")]
Expand Down Expand Up @@ -1134,6 +1180,10 @@ impl SockaddrLike for SockaddrIn6 {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

#[cfg(feature = "net")]
Expand Down Expand Up @@ -1361,6 +1411,27 @@ impl SockaddrLike for SockaddrStorage {
None => mem::size_of_val(self) as libc::socklen_t,
}
}

unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
match self.as_unix_addr_mut() {
Some(addr) => {
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
addr.sun_len = new_length as u8;
} else {
addr.sun.sun_len = new_length as u8;
}
}
Ok(())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
addr.sun_len = new_length as u8;
} else {
addr.sun.sun_len = new_length as u8;
}
}
Ok(())
addr.set_length(new_length)

},
None => Err(SocketAddressLengthNotDynamic),
}
}
}

macro_rules! accessors {
Expand Down Expand Up @@ -1754,6 +1825,10 @@ pub mod netlink {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_nl> for NetlinkAddr {
Expand Down Expand Up @@ -1803,6 +1878,10 @@ pub mod alg {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_alg> for AlgAddr {
Expand Down Expand Up @@ -1902,7 +1981,7 @@ pub mod sys_control {
use std::{fmt, mem, ptr};
use std::os::unix::io::RawFd;
use crate::{Errno, Result};
use super::{private, SockaddrLike};
use super::{private, SockaddrLike, SocketAddressLengthNotDynamic};

// FIXME: Move type into `libc`
#[repr(C)]
Expand Down Expand Up @@ -1943,6 +2022,10 @@ pub mod sys_control {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_ctl> for SysControlAddr {
Expand Down Expand Up @@ -2007,7 +2090,7 @@ pub mod sys_control {
mod datalink {
feature! {
#![feature = "net"]
use super::{fmt, mem, private, ptr, SockaddrLike};
use super::{fmt, mem, private, ptr, SockaddrLike, SocketAddressLengthNotDynamic};

/// Hardware Address
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down Expand Up @@ -2085,6 +2168,10 @@ mod datalink {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_ll> for LinkAddr {
Expand All @@ -2110,7 +2197,7 @@ mod datalink {
mod datalink {
feature! {
#![feature = "net"]
use super::{fmt, mem, private, ptr, SockaddrLike};
use super::{fmt, mem, private, ptr, SockaddrLike, SocketAddressLengthNotDynamic};

/// Hardware Address
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down Expand Up @@ -2209,6 +2296,10 @@ mod datalink {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_dl> for LinkAddr {
Expand Down Expand Up @@ -2257,6 +2348,10 @@ pub mod vsock {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_vm> for VsockAddr {
Expand Down
35 changes: 19 additions & 16 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1626,9 +1626,7 @@ impl<S> MultiHeaders<S> {
Some(v) => ((&v[ix * msg_controllen] as *const u8), msg_controllen),
None => (std::ptr::null(), 0),
};
let msg_hdr = unsafe {
pack_mhdr_to_receive(std::ptr::null(), 0, ptr, cap, <S as addr::private::SockaddrLikePriv>::as_mut_ptr(address.assume_init_mut()).cast())
};
let msg_hdr = unsafe { pack_mhdr_to_receive(std::ptr::null(), 0, ptr, cap, address.as_mut_ptr()) };
libc::mmsghdr {
msg_hdr,
msg_len: 0,
Expand Down Expand Up @@ -1763,7 +1761,7 @@ where
mmsghdr.msg_hdr,
mmsghdr.msg_len as isize,
self.rmm.msg_controllen,
Some(address),
address,
)
})
}
Expand Down Expand Up @@ -1916,7 +1914,7 @@ unsafe fn read_mhdr<'a, 'i, S>(
mhdr: msghdr,
r: isize,
msg_controllen: usize,
address: Option<S>,
mut address: S,
) -> RecvMsg<'a, 'i, S>
where S: SockaddrLike
{
Expand All @@ -1932,10 +1930,15 @@ unsafe fn read_mhdr<'a, 'i, S>(
}.as_ref()
};

// Ignore errors if this socket address has statically-known length
//
// This is to ensure that unix socket addresses have their length set appropriately.
let _ = address.set_length(mhdr.msg_namelen as usize);

RecvMsg {
bytes: r as usize,
cmsghdr,
address,
address: Some(address),
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
mhdr,
iobufs: std::marker::PhantomData,
Expand All @@ -1953,19 +1956,22 @@ unsafe fn read_mhdr<'a, 'i, S>(
/// headers are not used
///
/// Buffers must remain valid for the whole lifetime of msghdr
unsafe fn pack_mhdr_to_receive(
unsafe fn pack_mhdr_to_receive<S>(
iov_buffer: *const IoSliceMut,
iov_buffer_len: usize,
cmsg_buffer: *const u8,
cmsg_capacity: usize,
address: *mut libc::sockaddr_storage,
) -> msghdr {
address: *mut S,
) -> msghdr
where
S: SockaddrLike
{
// Musl's msghdr has private fields, so this is the only way to
// initialize it.
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
let p = mhdr.as_mut_ptr();
(*p).msg_name = address as *mut c_void;
(*p).msg_namelen = mem::size_of::<libc::sockaddr_storage>() as u32;
(*p).msg_namelen = S::size();
(*p).msg_iov = iov_buffer as *mut iovec;
(*p).msg_iovlen = iov_buffer_len as _;
(*p).msg_control = cmsg_buffer as *mut c_void;
Expand Down Expand Up @@ -2047,23 +2053,20 @@ pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'i
where S: SockaddrLike + 'a,
'inner: 'outer
{
let mut address: libc::sockaddr_storage = unsafe { mem::MaybeUninit::zeroed().assume_init() };
let address_ptr: *mut libc::sockaddr_storage = &mut address as *mut libc::sockaddr_storage;
let mut address = mem::MaybeUninit::zeroed();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you think it's necessary to zero the address here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally thought it might be necessary when I first started working on this (since the length was being left uninitialized initially), but looking back at it now I think everything is initialized now so I put it back to uninit()


let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
.map(|v| (v.as_mut_ptr(), v.capacity()))
.unwrap_or((ptr::null_mut(), 0));
let mut mhdr = unsafe {
pack_mhdr_to_receive(iov.as_ref().as_ptr(), iov.len(), msg_control, msg_controllen, address_ptr)
pack_mhdr_to_receive(iov.as_ref().as_ptr(), iov.len(), msg_control, msg_controllen, address.as_mut_ptr())
};

let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };

let r = Errno::result(ret)?;

let address = unsafe { S::from_raw(address_ptr.cast::<libc::sockaddr>(), Some(mhdr.msg_namelen)) };

Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address) })
Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.assume_init()) })
}
}

Expand Down
3 changes: 1 addition & 2 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,7 @@ pub fn test_recvmsg_sockaddr_un() {
MsgFlags::empty(),
Some(&sockaddr),
) {
print!("Couldn't send ({e:?}), so skipping test");
return;
crate::skip!("Couldn't send ({e:?}), so skipping test");
}

// Receive the message
Expand Down