Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the length of a socket address when calling recvmsg on Linux #2041

Merged
merged 11 commits into from
Jul 17, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ This project adheres to [Semantic Versioning](https://semver.org/).
### Fixed
- Fix: send `ETH_P_ALL` in htons format
([#1925](https://github.com/nix-rust/nix/pull/1925))
- Fix: `recvmsg` now sets the length of the received `sockaddr_un` field
correctly on Linux platforms. ([#2041](https://github.com/nix-rust/nix/pull/2041))
- Fix potentially invalid conversions in
`SockaddrIn::from<std::net::SocketAddrV4>`,
`SockaddrIn6::from<std::net::SockaddrV6>`, `IpMembershipRequest::new`, and
Expand Down
53 changes: 51 additions & 2 deletions src/sys/socket/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,22 @@ impl SockaddrLike for UnixAddr {
{
mem::size_of::<libc::sockaddr_un>() as libc::socklen_t
}

unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
// `new_length` is only used on some platforms, so it must be provided even when not used
#![allow(unused_variables)]
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
self.sun_len = new_length as u8;
}
};
Ok(())
}
}

impl AsRef<libc::sockaddr_un> for UnixAddr {
Expand Down Expand Up @@ -914,7 +930,32 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
{
mem::size_of::<Self>() as libc::socklen_t
}

/// Set the length of this socket address
///
/// This method may only be called on socket addresses whose lengths are dynamic, and it
/// returns an error if called on a type whose length is static.
///
/// # Safety
///
/// `new_length` must be a valid length for this type of address. Specifically, reads of that
/// length from `self` must be valid.
#[doc(hidden)]
unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

/// The error returned by [`SockaddrLike::set_length`] on an address whose length is statically
/// fixed.
#[derive(Copy, Clone, Debug)]
pub struct SocketAddressLengthNotDynamic;
impl fmt::Display for SocketAddressLengthNotDynamic {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Attempted to set length on socket whose length is statically fixed")
}
}
impl std::error::Error for SocketAddressLengthNotDynamic {}

impl private::SockaddrLikePriv for () {
fn as_mut_ptr(&mut self) -> *mut libc::sockaddr {
Expand Down Expand Up @@ -1360,6 +1401,15 @@ impl SockaddrLike for SockaddrStorage {
None => mem::size_of_val(self) as libc::socklen_t,
}
}

unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
match self.as_unix_addr_mut() {
Some(addr) => {
addr.set_length(new_length)
},
None => Err(SocketAddressLengthNotDynamic),
}
}
}

macro_rules! accessors {
Expand Down Expand Up @@ -1678,7 +1728,7 @@ impl PartialEq for SockaddrStorage {
}
}

mod private {
pub(super) mod private {
pub trait SockaddrLikePriv {
/// Returns a mutable raw pointer to the inner structure.
///
Expand Down Expand Up @@ -2215,7 +2265,6 @@ mod datalink {
&self.0
}
}

}
}

Expand Down
11 changes: 8 additions & 3 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1613,7 +1613,7 @@ impl<S> MultiHeaders<S> {
{
// we will be storing pointers to addresses inside mhdr - convert it into boxed
// slice so it can'be changed later by pushing anything into self.addresses
let mut addresses = vec![std::mem::MaybeUninit::uninit(); num_slices].into_boxed_slice();
let mut addresses = vec![std::mem::MaybeUninit::<S>::uninit(); num_slices].into_boxed_slice();

let msg_controllen = cmsg_buffer.as_ref().map_or(0, |v| v.capacity());

Expand Down Expand Up @@ -1918,7 +1918,7 @@ unsafe fn read_mhdr<'a, 'i, S>(
mhdr: msghdr,
r: isize,
msg_controllen: usize,
address: S,
mut address: S,
) -> RecvMsg<'a, 'i, S>
where S: SockaddrLike
{
Expand All @@ -1934,6 +1934,11 @@ unsafe fn read_mhdr<'a, 'i, S>(
}.as_ref()
};

// Ignore errors if this socket address has statically-known length
//
// This is to ensure that unix socket addresses have their length set appropriately.
let _ = address.set_length(mhdr.msg_namelen as usize);

RecvMsg {
bytes: r as usize,
cmsghdr,
Expand Down Expand Up @@ -1969,7 +1974,7 @@ unsafe fn pack_mhdr_to_receive<S>(
// initialize it.
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
let p = mhdr.as_mut_ptr();
(*p).msg_name = (*address).as_mut_ptr() as *mut c_void;
(*p).msg_name = address as *mut c_void;
(*p).msg_namelen = S::size();
(*p).msg_iov = iov_buffer as *mut iovec;
(*p).msg_iovlen = iov_buffer_len as _;
Expand Down
43 changes: 43 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,49 @@ pub fn test_socketpair() {
assert_eq!(&buf[..], b"hello");
}

#[test]
pub fn test_recvmsg_sockaddr_un() {
use nix::sys::socket::{
self, bind, socket, AddressFamily, MsgFlags, SockFlag, SockType,
};

let tempdir = tempfile::tempdir().unwrap();
let sockname = tempdir.path().join("sock");
let sock = socket(
AddressFamily::Unix,
SockType::Datagram,
SockFlag::empty(),
None,
)
.expect("socket failed");
let sockaddr = UnixAddr::new(&sockname).unwrap();
bind(sock, &sockaddr).expect("bind failed");

// Send a message
let send_buffer = "hello".as_bytes();
if let Err(e) = socket::sendmsg(
sock,
&[std::io::IoSlice::new(send_buffer)],
&[],
MsgFlags::empty(),
Some(&sockaddr),
) {
crate::skip!("Couldn't send ({e:?}), so skipping test");
}

// Receive the message
let mut recv_buffer = [0u8; 32];
let received = socket::recvmsg(
sock,
&mut [std::io::IoSliceMut::new(&mut recv_buffer)],
None,
MsgFlags::empty(),
)
.unwrap();
// Check the address in the received message
assert_eq!(sockaddr, received.address.unwrap());
}

#[test]
pub fn test_std_conversions() {
use nix::sys::socket::*;
Expand Down