From b9c2a1b926c5cd680e7a8d14f2ba1370ba48962a Mon Sep 17 00:00:00 2001 From: Damien Deville Date: Fri, 3 Nov 2023 14:54:38 +0100 Subject: [PATCH] udp: add support for ECN on Windows --- quinn-proto/src/connection/mod.rs | 1 + quinn-udp/Cargo.toml | 3 +- quinn-udp/src/cmsg.rs | 113 --------- quinn-udp/src/cmsg/mod.rs | 142 +++++++++++ quinn-udp/src/cmsg/unix.rs | 53 +++++ quinn-udp/src/cmsg/windows.rs | 112 +++++++++ quinn-udp/src/lib.rs | 3 +- quinn-udp/src/unix.rs | 18 +- quinn-udp/src/windows.rs | 382 ++++++++++++++++++++++++------ quinn/src/connection.rs | 1 + quinn/src/tests.rs | 5 +- 11 files changed, 639 insertions(+), 194 deletions(-) delete mode 100644 quinn-udp/src/cmsg.rs create mode 100644 quinn-udp/src/cmsg/mod.rs create mode 100644 quinn-udp/src/cmsg/unix.rs create mode 100644 quinn-udp/src/cmsg/windows.rs mode change 100644 => 100755 quinn/src/tests.rs diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 4798a12111..0179afe53c 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -1209,6 +1209,7 @@ impl Connection { /// Retrieving the local IP address is currently supported on the following /// platforms: /// - Linux + /// - Windows /// /// On all non-supported platforms the local IP address will not be available, /// and the method will return `None`. 
diff --git a/quinn-udp/Cargo.toml b/quinn-udp/Cargo.toml index 0730a43115..68f387618e 100644 --- a/quinn-udp/Cargo.toml +++ b/quinn-udp/Cargo.toml @@ -28,4 +28,5 @@ socket2 = "0.5" tracing = "0.1.10" [target.'cfg(windows)'.dependencies] -windows-sys = { version = "0.52.0", features = ["Win32_Networking_WinSock"] } +windows-sys = { version = "0.52.0", features = ["Win32_Foundation", "Win32_System_IO", "Win32_Networking_WinSock"] } +once_cell = "1.19.0" diff --git a/quinn-udp/src/cmsg.rs b/quinn-udp/src/cmsg.rs deleted file mode 100644 index 01d1997265..0000000000 --- a/quinn-udp/src/cmsg.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{mem, ptr}; - -#[derive(Copy, Clone)] -#[repr(align(8))] // Conservative bound for align_of -pub(crate) struct Aligned(pub(crate) T); - -/// Helper to encode a series of control messages ("cmsgs") to a buffer for use in `sendmsg`. -/// -/// The operation must be "finished" for the msghdr to be usable, either by calling `finish` -/// explicitly or by dropping the `Encoder`. -pub(crate) struct Encoder<'a> { - hdr: &'a mut libc::msghdr, - cmsg: Option<&'a mut libc::cmsghdr>, - len: usize, -} - -impl<'a> Encoder<'a> { - /// # Safety - /// - `hdr.msg_control` must be a suitably aligned pointer to `hdr.msg_controllen` bytes that - /// can be safely written - /// - The `Encoder` must be dropped before `hdr` is passed to a system call, and must not be leaked. - pub(crate) unsafe fn new(hdr: &'a mut libc::msghdr) -> Self { - Self { - cmsg: libc::CMSG_FIRSTHDR(hdr).as_mut(), - hdr, - len: 0, - } - } - - /// Append a control message to the buffer. - /// - /// # Panics - /// - If insufficient buffer space remains. 
- /// - If `T` has stricter alignment requirements than `cmsghdr` - pub(crate) fn push(&mut self, level: libc::c_int, ty: libc::c_int, value: T) { - assert!(mem::align_of::() <= mem::align_of::()); - let space = unsafe { libc::CMSG_SPACE(mem::size_of_val(&value) as _) as usize }; - #[allow(clippy::unnecessary_cast)] // hdr.msg_controllen defined as size_t - { - assert!( - self.hdr.msg_controllen as usize >= self.len + space, - "control message buffer too small. Required: {}, Available: {}", - self.len + space, - self.hdr.msg_controllen - ); - } - let cmsg = self.cmsg.take().expect("no control buffer space remaining"); - cmsg.cmsg_level = level; - cmsg.cmsg_type = ty; - cmsg.cmsg_len = unsafe { libc::CMSG_LEN(mem::size_of_val(&value) as _) } as _; - unsafe { - ptr::write(libc::CMSG_DATA(cmsg) as *const T as *mut T, value); - } - self.len += space; - self.cmsg = unsafe { libc::CMSG_NXTHDR(self.hdr, cmsg).as_mut() }; - } - - /// Finishes appending control messages to the buffer - pub(crate) fn finish(self) { - // Delegates to the `Drop` impl - } -} - -// Statically guarantees that the encoding operation is "finished" before the control buffer is read -// by `sendmsg`. 
-impl<'a> Drop for Encoder<'a> { - fn drop(&mut self) { - self.hdr.msg_controllen = self.len as _; - } -} - -/// # Safety -/// -/// `cmsg` must refer to a cmsg containing a payload of type `T` -pub(crate) unsafe fn decode(cmsg: &libc::cmsghdr) -> T { - assert!(mem::align_of::() <= mem::align_of::()); - #[allow(clippy::unnecessary_cast)] // cmsg.cmsg_len defined as size_t - { - debug_assert_eq!( - cmsg.cmsg_len as usize, - libc::CMSG_LEN(mem::size_of::() as _) as usize - ); - } - ptr::read(libc::CMSG_DATA(cmsg) as *const T) -} - -pub(crate) struct Iter<'a> { - hdr: &'a libc::msghdr, - cmsg: Option<&'a libc::cmsghdr>, -} - -impl<'a> Iter<'a> { - /// # Safety - /// - /// `hdr.msg_control` must point to memory outliving `'a` which can be soundly read for the - /// lifetime of the constructed `Iter` and contains a buffer of cmsgs, i.e. is aligned for - /// `cmsghdr`, is fully initialized, and has correct internal links. - pub(crate) unsafe fn new(hdr: &'a libc::msghdr) -> Self { - Self { - hdr, - cmsg: libc::CMSG_FIRSTHDR(hdr).as_ref(), - } - } -} - -impl<'a> Iterator for Iter<'a> { - type Item = &'a libc::cmsghdr; - fn next(&mut self) -> Option<&'a libc::cmsghdr> { - let current = self.cmsg.take()?; - self.cmsg = unsafe { libc::CMSG_NXTHDR(self.hdr, current).as_ref() }; - Some(current) - } -} diff --git a/quinn-udp/src/cmsg/mod.rs b/quinn-udp/src/cmsg/mod.rs new file mode 100644 index 0000000000..cc5ecdc7f7 --- /dev/null +++ b/quinn-udp/src/cmsg/mod.rs @@ -0,0 +1,142 @@ +use std::{ + ffi::{c_int, c_uchar}, + mem, ptr, +}; + +#[cfg(unix)] +#[path = "unix.rs"] +mod imp; + +#[cfg(windows)] +#[path = "windows.rs"] +mod imp; + +pub(crate) use imp::Aligned; + +// Helper traits for native types for control messages +pub(crate) trait MsgHdr { + type ControlMessage: CMsgHdr; + + fn control_len(&self) -> usize; + + fn set_control_len(&mut self, len: usize); + + fn cmsg_firsthdr(&self) -> *mut Self::ControlMessage; + + fn cmsg_nxthdr(&self, cmsg: &Self::ControlMessage) -> *mut 
Self::ControlMessage; +} + +pub(crate) trait CMsgHdr { + fn set(&mut self, level: c_int, ty: c_int, len: usize); + + fn len(&self) -> usize; + + fn cmsg_space(length: usize) -> usize; + + fn cmsg_len(length: usize) -> usize; + + fn cmsg_data(&self) -> *mut c_uchar; +} + +/// Helper to encode a series of control messages (native "cmsgs") to a buffer for use in `sendmsg` +/// like API. +/// +/// The operation must be "finished" for the native msghdr to be usable, either by calling `finish` +/// explicitly or by dropping the `Encoder`. +pub(crate) struct Encoder<'a, M: MsgHdr> { + hdr: &'a mut M, + cmsg: Option<&'a mut M::ControlMessage>, + len: usize, +} + +impl<'a, M: MsgHdr> Encoder<'a, M> { + /// # Safety + /// - `hdr` must contain a suitably aligned pointer to a buffer big enough to hold the control message bytes + /// that can be safely written + /// - The `Encoder` must be dropped before `hdr` is passed to a system call, and must not be leaked. + pub(crate) unsafe fn new(hdr: &'a mut M) -> Self { + Self { + cmsg: hdr.cmsg_firsthdr().as_mut(), + hdr, + len: 0, + } + } + + /// Append a control message to the buffer. + /// + /// # Panics + /// - If insufficient buffer space remains. + /// - If `T` has stricter alignment requirements than native `cmsghdr` + pub(crate) fn push(&mut self, level: c_int, ty: c_int, value: T) { + assert!(mem::align_of::() <= mem::align_of::()); + let space = M::ControlMessage::cmsg_space(mem::size_of_val(&value)); + assert!( + self.hdr.control_len() >= self.len + space, + "control message buffer too small. 
Required: {}, Available: {}", + self.len + space, + self.hdr.control_len() + ); + let cmsg = self.cmsg.take().expect("no control buffer space remaining"); + cmsg.set( + level, + ty, + M::ControlMessage::cmsg_len(mem::size_of_val(&value)), + ); + unsafe { + ptr::write(cmsg.cmsg_data() as *const T as *mut T, value); + } + self.len += space; + self.cmsg = unsafe { self.hdr.cmsg_nxthdr(cmsg).as_mut() }; + } + + /// Finishes appending control messages to the buffer + pub(crate) fn finish(self) { + // Delegates to the `Drop` impl + } +} + +// Statically guarantees that the encoding operation is "finished" before the control buffer is read +// by `sendmsg` like API. +impl<'a, M: MsgHdr> Drop for Encoder<'a, M> { + fn drop(&mut self) { + self.hdr.set_control_len(self.len as _); + } +} + +/// # Safety +/// +/// `cmsg` must refer to a native cmsg containing a payload of type `T` +pub(crate) unsafe fn decode(cmsg: &C) -> T { + assert!(mem::align_of::() <= mem::align_of::()); + debug_assert_eq!(cmsg.len(), C::cmsg_len(mem::size_of::())); + ptr::read(cmsg.cmsg_data() as *const T) +} + +pub(crate) struct Iter<'a, M: MsgHdr> { + hdr: &'a M, + cmsg: Option<&'a M::ControlMessage>, +} + +impl<'a, M: MsgHdr> Iter<'a, M> { + /// # Safety + /// + /// `hdr` must hold a pointer to memory outliving `'a` which can be soundly read for the + /// lifetime of the constructed `Iter` and contains a buffer of native cmsgs, i.e. is aligned + /// for native `cmsghdr`, is fully initialized, and has correct internal links. 
+ pub(crate) unsafe fn new(hdr: &'a M) -> Self { + Self { + hdr, + cmsg: hdr.cmsg_firsthdr().as_ref(), + } + } +} + +impl<'a, M: MsgHdr> Iterator for Iter<'a, M> { + type Item = &'a M::ControlMessage; + + fn next(&mut self) -> Option { + let current = self.cmsg.take()?; + self.cmsg = unsafe { self.hdr.cmsg_nxthdr(current).as_ref() }; + Some(current) + } +} diff --git a/quinn-udp/src/cmsg/unix.rs b/quinn-udp/src/cmsg/unix.rs new file mode 100644 index 0000000000..bed8a293f2 --- /dev/null +++ b/quinn-udp/src/cmsg/unix.rs @@ -0,0 +1,53 @@ +use std::ffi::{c_int, c_uchar}; + +use super::{CMsgHdr, MsgHdr}; + +#[derive(Copy, Clone)] +#[repr(align(8))] // Conservative bound for align_of +pub(crate) struct Aligned(pub(crate) T); + +/// Helpers for [`libc::msghdr`] +impl MsgHdr for libc::msghdr { + type ControlMessage = libc::cmsghdr; + + fn control_len(&self) -> usize { + self.msg_controllen as _ + } + + fn set_control_len(&mut self, len: usize) { + self.msg_controllen = len as _; + } + + fn cmsg_firsthdr(&self) -> *mut Self::ControlMessage { + unsafe { libc::CMSG_FIRSTHDR(self) } + } + + fn cmsg_nxthdr(&self, cmsg: &Self::ControlMessage) -> *mut Self::ControlMessage { + unsafe { libc::CMSG_NXTHDR(self, cmsg) } + } +} + +/// Helpers for [`libc::cmsghdr`] +impl CMsgHdr for libc::cmsghdr { + fn set(&mut self, level: c_int, ty: c_int, len: usize) { + self.cmsg_level = level as _; + self.cmsg_type = ty as _; + self.cmsg_len = len as _; + } + + fn len(&self) -> usize { + self.cmsg_len as _ + } + + fn cmsg_space(length: usize) -> usize { + unsafe { libc::CMSG_SPACE(length as _) as usize } + } + + fn cmsg_len(length: usize) -> usize { + unsafe { libc::CMSG_LEN(length as _) as usize } + } + + fn cmsg_data(&self) -> *mut c_uchar { + unsafe { libc::CMSG_DATA(self) } + } +} diff --git a/quinn-udp/src/cmsg/windows.rs b/quinn-udp/src/cmsg/windows.rs new file mode 100644 index 0000000000..63adcf9415 --- /dev/null +++ b/quinn-udp/src/cmsg/windows.rs @@ -0,0 +1,112 @@ +use 
std::ffi::{c_int, c_uchar}; + +use windows_sys::Win32::Networking::WinSock; + +use super::{CMsgHdr, MsgHdr}; + +#[derive(Copy, Clone)] +#[repr(align(8))] // Conservative bound for align_of +pub(crate) struct Aligned(pub(crate) T); + +/// Helpers for [`WinSock::WSAMSG`] +// https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-wsamsg +// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/struct.WSAMSG.html +impl MsgHdr for WinSock::WSAMSG { + type ControlMessage = WinSock::CMSGHDR; + + fn control_len(&self) -> usize { + self.Control.len as _ + } + + fn set_control_len(&mut self, len: usize) { + self.Control.len = len as _; + } + + fn cmsg_firsthdr(&self) -> *mut Self::ControlMessage { + unsafe { self::wsa2::cmsg_firsthdr(self) } + } + + fn cmsg_nxthdr(&self, cmsg: &Self::ControlMessage) -> *mut Self::ControlMessage { + unsafe { self::wsa2::cmsg_nxthdr(self, cmsg) } + } +} + +/// Helpers for [`WinSock::CMSGHDR`] +// https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-wsacmsghdr +// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/struct.CMSGHDR.html +impl CMsgHdr for WinSock::CMSGHDR { + fn set(&mut self, level: c_int, ty: c_int, len: usize) { + self.cmsg_level = level as _; + self.cmsg_type = ty as _; + self.cmsg_len = len as _; + } + + fn len(&self) -> usize { + self.cmsg_len as _ + } + + fn cmsg_space(length: usize) -> usize { + self::wsa2::cmsg_space(length) + } + + fn cmsg_len(length: usize) -> usize { + self::wsa2::cmsg_len(length) + } + + fn cmsg_data(&self) -> *mut c_uchar { + unsafe { self::wsa2::cmsg_data(self) } + } +} + +mod wsa2 { + use std::{mem, ptr}; + + use windows_sys::Win32::Networking::WinSock; + + // Helpers functions based on C macros from + // https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741 + fn cmsghdr_align(length: usize) -> usize { + (length + mem::align_of::() - 1) + & 
!(mem::align_of::() - 1) + } + + fn cmsgdata_align(length: usize) -> usize { + (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) + } + + pub(crate) unsafe fn cmsg_firsthdr(msg: *const WinSock::WSAMSG) -> *mut WinSock::CMSGHDR { + if (*msg).Control.len as usize >= mem::size_of::() { + (*msg).Control.buf as *mut WinSock::CMSGHDR + } else { + ptr::null_mut::() + } + } + + pub(crate) unsafe fn cmsg_nxthdr( + hdr: &WinSock::WSAMSG, + cmsg: *const WinSock::CMSGHDR, + ) -> *mut WinSock::CMSGHDR { + if cmsg.is_null() { + return cmsg_firsthdr(hdr); + } + let next = (cmsg as usize + cmsghdr_align((*cmsg).cmsg_len)) as *mut WinSock::CMSGHDR; + let max = hdr.Control.buf as usize + hdr.Control.len as usize; + if (next.offset(1)) as usize > max { + ptr::null_mut() + } else { + next + } + } + + pub(crate) unsafe fn cmsg_data(cmsg: *const WinSock::CMSGHDR) -> *mut u8 { + (cmsg as usize + cmsgdata_align(mem::size_of::())) as *mut u8 + } + + pub(crate) fn cmsg_space(length: usize) -> usize { + cmsgdata_align(mem::size_of::() + cmsghdr_align(length)) + } + + pub(crate) fn cmsg_len(length: usize) -> usize { + cmsgdata_align(mem::size_of::()) + length + } +} diff --git a/quinn-udp/src/lib.rs b/quinn-udp/src/lib.rs index 5adf622fde..c6c9a12dfb 100644 --- a/quinn-udp/src/lib.rs +++ b/quinn-udp/src/lib.rs @@ -15,8 +15,9 @@ use std::{ use bytes::Bytes; use tracing::warn; -#[cfg(unix)] +#[cfg(any(unix, windows))] mod cmsg; + #[cfg(unix)] #[path = "unix.rs"] mod imp; diff --git a/quinn-udp/src/unix.rs b/quinn-udp/src/unix.rs index c408161f4e..b59f0fdb37 100644 --- a/quinn-udp/src/unix.rs +++ b/quinn-udp/src/unix.rs @@ -668,7 +668,7 @@ fn decode_recv( match (cmsg.cmsg_level, cmsg.cmsg_type) { // FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in. 
(libc::IPPROTO_IP, libc::IP_TOS) | (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe { - ecn_bits = cmsg::decode::(cmsg); + ecn_bits = cmsg::decode::(cmsg); }, (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe { // Temporary hack around broken macos ABI. Remove once upstream fixes it. @@ -677,30 +677,30 @@ fn decode_recv( if cfg!(target_os = "macos") && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::() as _) as usize { - ecn_bits = cmsg::decode::(cmsg); + ecn_bits = cmsg::decode::(cmsg); } else { - ecn_bits = cmsg::decode::(cmsg) as u8; + ecn_bits = cmsg::decode::(cmsg) as u8; } }, #[cfg(target_os = "linux")] (libc::IPPROTO_IP, libc::IP_PKTINFO) => { - let pktinfo = unsafe { cmsg::decode::(cmsg) }; + let pktinfo = unsafe { cmsg::decode::(cmsg) }; dst_ip = Some(IpAddr::V4(Ipv4Addr::from( pktinfo.ipi_addr.s_addr.to_ne_bytes(), ))); } #[cfg(any(target_os = "freebsd", target_os = "macos"))] (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { - let in_addr = unsafe { cmsg::decode::(cmsg) }; + let in_addr = unsafe { cmsg::decode::(cmsg) }; dst_ip = Some(IpAddr::V4(Ipv4Addr::from(in_addr.s_addr.to_ne_bytes()))); } (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { - let pktinfo = unsafe { cmsg::decode::(cmsg) }; + let pktinfo = unsafe { cmsg::decode::(cmsg) }; dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr))); } #[cfg(target_os = "linux")] (libc::SOL_UDP, libc::UDP_GRO) => unsafe { - stride = cmsg::decode::(cmsg) as usize; + stride = cmsg::decode::(cmsg) as usize; }, _ => {} } @@ -770,7 +770,7 @@ mod gso { } } - pub(crate) fn set_segment_size(encoder: &mut cmsg::Encoder, segment_size: u16) { + pub(crate) fn set_segment_size(encoder: &mut cmsg::Encoder, segment_size: u16) { encoder.push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size); } } @@ -783,7 +783,7 @@ mod gso { 1 } - pub(super) fn set_segment_size(_encoder: &mut cmsg::Encoder, _segment_size: u16) { + pub(super) fn set_segment_size(_encoder: &mut cmsg::Encoder, _segment_size: u16) { panic!("Setting a 
segment size is not supported on current platform"); } } diff --git a/quinn-udp/src/windows.rs b/quinn-udp/src/windows.rs index 4680d1a453..ecb8a046d3 100644 --- a/quinn-udp/src/windows.rs +++ b/quinn-udp/src/windows.rs @@ -1,14 +1,25 @@ use std::{ io::{self, IoSliceMut}, mem, + net::{IpAddr, Ipv4Addr}, os::windows::io::AsRawSocket, + ptr, sync::Mutex, time::Instant, }; +use once_cell::sync::OnceCell; use windows_sys::Win32::Networking::WinSock; -use super::{log_sendmsg_error, RecvMeta, Transmit, UdpSockRef, IO_ERROR_LOG_INTERVAL}; +use crate::EcnCodepoint; + +use super::{cmsg, log_sendmsg_error, RecvMeta, Transmit, UdpSockRef, IO_ERROR_LOG_INTERVAL}; + +// Enough to store max(IP_PKTINFO + IP_ECN, IPV6_PKTINFO + IPV6_ECN) bytes (header + data) and some extra margin +const CMSG_LEN: usize = 128; + +// FIXME this could use [`std::sync::OnceLock`] once the MSRV is bumped to 1.70 and upper +static WSARECVMSG_PTR: OnceCell = OnceCell::new(); /// QUIC-friendly UDP interface for Windows #[derive(Debug)] @@ -38,36 +49,53 @@ impl UdpSocketState { }; let is_ipv4 = addr.as_socket_ipv4().is_some() || !v6only; - let sock_true: u32 = 1; + let wsa_recvmsg_ptr = WSARECVMSG_PTR.get_or_init(|| get_wsarecvmsg_fn(&*socket.0)); + + // We do not support anymore old version of windows that do not give access to WSARecvMsg() function + if wsa_recvmsg_ptr.is_none() { + tracing::error!("Network stack does not support WSARecvMsg function"); + + return Err(io::Error::from(io::ErrorKind::Unsupported)); + } if is_ipv4 { - let rc = unsafe { - WinSock::setsockopt( - socket.0.as_raw_socket() as _, - WinSock::IPPROTO_IP as _, - WinSock::IP_DONTFRAGMENT as _, - &sock_true as *const _ as _, - mem::size_of_val(&sock_true) as _, - ) - }; - if rc == -1 { - return Err(io::Error::last_os_error()); - } + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IP, + WinSock::IP_DONTFRAGMENT, + OPTION_ON, + )?; + + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IP, + WinSock::IP_PKTINFO, + OPTION_ON, + )?; 
+ set_socket_option(&*socket.0, WinSock::IPPROTO_IP, WinSock::IP_ECN, OPTION_ON)?; } if is_ipv6 { - let rc = unsafe { - WinSock::setsockopt( - socket.0.as_raw_socket() as _, - WinSock::IPPROTO_IPV6 as _, - WinSock::IPV6_DONTFRAG as _, - &sock_true as *const _ as _, - mem::size_of_val(&sock_true) as _, - ) - }; - if rc == -1 { - return Err(io::Error::last_os_error()); - } + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IPV6, + WinSock::IPV6_DONTFRAG, + OPTION_ON, + )?; + + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IPV6, + WinSock::IPV6_PKTINFO, + OPTION_ON, + )?; + + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IPV6, + WinSock::IPV6_ECN, + OPTION_ON, + )?; } let now = Instant::now(); @@ -76,33 +104,8 @@ impl UdpSocketState { }) } - pub fn send(&self, socket: UdpSockRef<'_>, transmits: &[Transmit]) -> Result { - let mut sent = 0; - for transmit in transmits { - match socket.0.send_to( - &transmit.contents, - &socket2::SockAddr::from(transmit.destination), - ) { - Ok(_) => { - sent += 1; - } - // We need to report that some packets were sent in this case, so we rely on - // errors being either harmlessly transient (in the case of WouldBlock) or - // recurring on the next call. - Err(_) if sent != 0 => return Ok(sent), - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - return Err(e); - } - - // Other errors are ignored, since they will usually be handled - // by higher level retransmits and timeouts. - log_sendmsg_error(&self.last_send_error, e, transmit); - sent += 1; - } - } - } - Ok(sent) + pub fn send(&self, socket: UdpSockRef<'_>, transmits: &[Transmit]) -> io::Result { + send(self, socket, transmits) } pub fn recv( @@ -111,22 +114,7 @@ impl UdpSocketState { bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta], ) -> io::Result { - // Safety: both `IoSliceMut` and `MaybeUninitSlice` promise to have the - // same layout, that of `iovec`/`WSABUF`. 
Furthermore `recv_vectored` - // promises to not write unitialised bytes to the `bufs` and pass it - // directly to the `recvmsg` system call, so this is safe. - let bufs = unsafe { - &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [socket2::MaybeUninitSlice<'_>]) - }; - let (len, _flags, addr) = socket.0.recv_from_vectored(bufs)?; - meta[0] = RecvMeta { - len, - stride: len, - addr: addr.as_socket().unwrap(), - ecn: None, - dst_ip: None, - }; - Ok(1) + recv(socket, bufs, meta) } /// The maximum amount of segments which can be transmitted if a platform @@ -154,4 +142,260 @@ impl UdpSocketState { } } +fn send( + state: &UdpSocketState, + socket: UdpSockRef<'_>, + transmits: &[Transmit], +) -> io::Result { + let mut sent = 0; + for transmit in transmits { + // we cannot use [`socket2::sendmsg()`] and [`socket2::MsgHdr`] as we do not have access + // to the inner field which holds the WSAMSG + let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]); + let daddr = socket2::SockAddr::from(transmit.destination); + + let mut data = WinSock::WSABUF { + buf: transmit.contents.as_ptr() as *mut _, + len: transmit.contents.len() as _, + }; + + let ctrl = WinSock::WSABUF { + buf: ctrl_buf.0.as_mut_ptr(), + len: ctrl_buf.0.len() as _, + }; + + let mut wsa_msg = WinSock::WSAMSG { + name: daddr.as_ptr() as *mut _, + namelen: daddr.len(), + lpBuffers: &mut data, + Control: ctrl, + dwBufferCount: 1, + dwFlags: 0, + }; + + // Add control messages (ECN and PKTINFO) + let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) }; + + if let Some(ip) = transmit.src_ip { + let ip = std::net::SocketAddr::new(ip, 0); + let ip = socket2::SockAddr::from(ip); + match ip.family() { + WinSock::AF_INET => { + let src_ip: WinSock::SOCKADDR_IN = unsafe { ptr::read(ip.as_ptr() as _) }; + let pktinfo = WinSock::IN_PKTINFO { + ipi_addr: src_ip.sin_addr, + ipi_ifindex: 0, + }; + encoder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo); + } + WinSock::AF_INET6 => { + let src_ip: WinSock::SOCKADDR_IN6 = 
unsafe { ptr::read(ip.as_ptr() as _) }; + let pktinfo = WinSock::IN6_PKTINFO { + ipi6_addr: src_ip.sin6_addr, + ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id }, + }; + encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo); + } + _ => { + return Err(io::Error::from(io::ErrorKind::InvalidInput)); + } + } + } + + // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn + // It is sent through the IP_ECN / IPV6_ECN control messages, i.e. the same options that are + // enabled on the socket and decoded on the receive path + let ecn = transmit.ecn.map_or(0, |x| x as i32); + if transmit.destination.is_ipv4() { + encoder.push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn); + } else { + encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn); + } + + encoder.finish(); + + let mut len = 0; + let rc = unsafe { + WinSock::WSASendMsg( + socket.0.as_raw_socket() as usize, + &wsa_msg, + 0, + &mut len, + ptr::null_mut(), + None, + ) + }; + + if rc == 0 { + sent += 1; + } else if sent != 0 { + // We need to report that some packets were sent in this case, so we rely on + // errors being either harmlessly transient (in the case of WouldBlock) or + // recurring on the next call. + return Ok(sent); + } else if io::Error::last_os_error().kind() == io::ErrorKind::WouldBlock { + // WSASendMsg only returns 0 or SOCKET_ERROR, so the failure reason has to be read + // from the thread's last error code rather than compared against `rc` + return Err(io::Error::last_os_error()); + } else { + // Other errors are ignored, since they will usually be handled + // by higher level retransmits and timeouts. 
+ log_sendmsg_error(&state.last_send_error, io::Error::last_os_error(), transmit); + sent += 1; + } + } + Ok(sent) +} + +fn recv( + socket: UdpSockRef<'_>, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], +) -> io::Result { + let wsa_recvmsg_ptr = WSARECVMSG_PTR + .get_or_init(|| get_wsarecvmsg_fn(&*socket.0)) + .expect("Valid function pointer for WSARecvMsg"); + + // we cannot use [`socket2::MsgHdrMut`] as we do not have access to inner field which holds the WSAMSG + let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]); + let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() }; + let mut data = WinSock::WSABUF { + buf: bufs[0].as_mut_ptr(), + len: bufs[0].len() as _, + }; + + let ctrl = WinSock::WSABUF { + buf: ctrl_buf.0.as_mut_ptr(), + len: ctrl_buf.0.len() as _, + }; + + let mut wsa_msg = WinSock::WSAMSG { + name: &mut source as *mut _ as *mut _, + namelen: mem::size_of_val(&source) as _, + lpBuffers: &mut data, + Control: ctrl, + dwBufferCount: 1, + dwFlags: 0, + }; + + // FIXME add Safety: ? + let mut len = 0; + unsafe { + let rc = (wsa_recvmsg_ptr)( + socket.0.as_raw_socket() as usize, + &mut wsa_msg, + &mut len, + ptr::null_mut(), + None, + ); + if rc == -1 { + return Err(io::Error::last_os_error()); + } + } + + // FIXME add Safety: ? 
+ let addr = unsafe { + let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| { + *len = mem::size_of_val(&source) as _; + ptr::copy_nonoverlapping(&source, addr_storage as _, 1); + Ok(()) + })?; + addr.as_socket() + }; + + // Decode control messages (PKTINFO and ECN) + let mut ecn_bits = 0; + let mut dst_ip = None; + + let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) }; + for cmsg in cmsg_iter { + // [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding] + match (cmsg.cmsg_level, cmsg.cmsg_type) { + (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { + let pktinfo = + unsafe { cmsg::decode::(cmsg) }; + // Addr is stored in big endian format + let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr })); + dst_ip = Some(ip4.into()); + } + (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { + let pktinfo = + unsafe { cmsg::decode::(cmsg) }; + // Addr is stored in big endian format + dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte })); + } + (WinSock::IPPROTO_IP, WinSock::IP_ECN) => { + // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn + ecn_bits = unsafe { cmsg::decode::(cmsg) }; + } + (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { + // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn + ecn_bits = unsafe { cmsg::decode::(cmsg) }; + } + _ => {} + } + } + + meta[0] = RecvMeta { + len: len as usize, + stride: len as usize, + addr: addr.unwrap(), + ecn: EcnCodepoint::from_bits(ecn_bits as u8), + dst_ip, + }; + Ok(1) +} + pub(crate) const BATCH_SIZE: usize = 1; + +fn get_wsarecvmsg_fn(socket: &impl AsRawSocket) -> WinSock::LPFN_WSARECVMSG { + // Detect if OS expose WSARecvMsg API based on + // https://github.com/Azure/mio-uds-windows/blob/a3c97df82018086add96d8821edb4aa85ec1b42b/src/stdnet/ext.rs#L601 + let guid = WinSock::WSAID_WSARECVMSG; + let mut wsa_recvmsg_ptr = None; + let mut len = 0; + + // Safety: Option handles the NULL 
pointer with a None value + let rc = unsafe { + WinSock::WSAIoctl( + socket.as_raw_socket() as _, + WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER, + &guid as *const _ as *const _, + mem::size_of_val(&guid) as u32, + &mut wsa_recvmsg_ptr as *mut _ as *mut _, + mem::size_of_val(&wsa_recvmsg_ptr) as u32, + &mut len, + ptr::null_mut(), + None, + ) + }; + + if rc == -1 { + tracing::debug!("Ignoring WSARecvMsg function pointer due to ioctl error"); + } else if len as usize != mem::size_of::() { + tracing::debug!("Ignoring WSARecvMsg function pointer due to pointer size mismatch"); + wsa_recvmsg_ptr = None; + } + + wsa_recvmsg_ptr +} + +fn set_socket_option( + socket: &impl AsRawSocket, + level: i32, + name: i32, + value: u32, +) -> Result<(), io::Error> { + let rc = unsafe { + WinSock::setsockopt( + socket.as_raw_socket() as usize, + level, + name, + &value as *const _ as _, + mem::size_of_val(&value) as _, + ) + }; + + match rc == 0 { + true => Ok(()), + false => Err(io::Error::last_os_error()), + } +} + +const OPTION_ON: u32 = 1; diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 5bed610cb5..f2c60a65a6 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -149,6 +149,7 @@ impl Connecting { /// - Linux /// - FreeBSD /// - macOS + /// - Windows /// /// On all non-supported platforms the local IP address will not be available, /// and the method will return `None`. diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs old mode 100644 new mode 100755 index 3436cb68d1..05c0ebf5e8 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -479,7 +479,10 @@ fn run_echo(args: EchoArgs) { // If `local_ip` gets available on additional platforms - which // requires modifying this test - please update the list of supported // platforms in the doc comments of the various `local_ip` functions. 
- if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") || cfg!(target_os = "macos") + if cfg!(target_os = "linux") + || cfg!(target_os = "freebsd") + || cfg!(target_os = "macos") + || cfg!(target_os = "windows") { let local_ip = incoming.local_ip().expect("Local IP must be available"); assert!(local_ip.is_loopback());