From 1574cec099396b6b41e0a86d90fbba1eb2ee28e6 Mon Sep 17 00:00:00 2001 From: FujiApple Date: Wed, 26 Oct 2022 18:28:09 +0800 Subject: [PATCH] refactor: move platform specific code to platform specific modules --- Cargo.toml | 5 +- src/tracing/net/channel.rs | 55 ++----- src/tracing/net/ipv4.rs | 50 +------ src/tracing/net/ipv6.rs | 45 +----- src/tracing/net/platform.rs | 117 --------------- src/tracing/net/platform/byte_order.rs | 60 ++++++++ src/tracing/net/platform/mod.rs | 14 ++ src/tracing/net/platform/unix.rs | 200 +++++++++++++++++++++++++ src/tracing/net/platform/windows.rs | 81 ++++++++++ 9 files changed, 380 insertions(+), 247 deletions(-) delete mode 100644 src/tracing/net/platform.rs create mode 100644 src/tracing/net/platform/byte_order.rs create mode 100644 src/tracing/net/platform/mod.rs create mode 100644 src/tracing/net/platform/unix.rs create mode 100644 src/tracing/net/platform/windows.rs diff --git a/Cargo.toml b/Cargo.toml index 893f0b20..0687c26f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ name = "trip" # Library dependencies socket2 = { version = "0.4.7", features = [ "all" ] } -nix = { version = "0.25.0", default-features = false, features = [ "user", "poll", "net" ] } thiserror = "1.0.32" derive_more = "0.99.17" arrayvec = "0.7.2" @@ -46,5 +45,9 @@ comfy-table = "6.1.1" [target.'cfg(target_os = "linux")'.dependencies] caps = "0.5.4" +# Library dependancies (Unix) +[target.'cfg(target_family = "unix")'.dependencies] +nix = { version = "0.25.0", default-features = false, features = [ "user", "poll", "net" ] } + [dev-dependencies] rand = "0.8.5" \ No newline at end of file diff --git a/src/tracing/net/channel.rs b/src/tracing/net/channel.rs index 8f7db8f6..837f3f80 100644 --- a/src/tracing/net/channel.rs +++ b/src/tracing/net/channel.rs @@ -1,7 +1,7 @@ use crate::tracing::error::TracerError::InvalidSourceAddr; use crate::tracing::error::{TraceResult, TracerError}; use crate::tracing::net::platform::PlatformIpv4FieldByteOrder; -use crate::tracing::net::{ipv4, ipv6, Network}; +use crate::tracing::net::{ipv4, ipv6, platform, Network}; use crate::tracing::probe::ProbeResponse; use crate::tracing::types::{PacketSize, PayloadPattern, Port, Sequence, TraceId, TypeOfService}; use crate::tracing::util::Required; @@ -10,11 +10,8 @@ use crate::tracing::{ }; use arrayvec::ArrayVec; use itertools::Itertools; -use nix::sys::select::FdSet; -use nix::sys::time::{TimeVal, TimeValLike}; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::net::{IpAddr, SocketAddr}; -use std::os::unix::io::AsRawFd; use std::time::{Duration, SystemTime}; /// The maximum size of the IP packet we allow. @@ -196,7 +193,7 @@ impl TracerChannel { /// Generate a `ProbeResponse` for the next available ICMP packet, if any fn recv_icmp_probe(&mut self) -> TraceResult> { - if is_readable(&self.recv_socket, self.read_timeout)? { + if platform::is_readable(&self.recv_socket, self.read_timeout)? { match self.addr_family { TracerAddrFamily::Ipv4 => ipv4::recv_icmp_probe( &mut self.recv_socket, @@ -222,7 +219,7 @@ impl TracerChannel { let found_index = self .tcp_probes .iter() - .find_position(|&probe| is_writable(&probe.socket).unwrap_or_default()) + .find_position(|&probe| platform::is_writable(&probe.socket).unwrap_or_default()) .map(|(i, _)| i); if let Some(i) = found_index { let probe = self.tcp_probes.remove(i); @@ -249,36 +246,6 @@ impl TcpProbe { } } -/// Returns true if the socket becomes readable before the timeout, false otherwise. -fn is_readable(sock: &Socket, timeout: Duration) -> TraceResult { - let mut read = FdSet::new(); - read.insert(sock.as_raw_fd()); - let readable = nix::sys::select::select( - None, - Some(&mut read), - None, - None, - Some(&mut TimeVal::milliseconds(timeout.as_millis() as i64)), - ) - .map_err(|err| TracerError::IoError(std::io::Error::from(err)))?; - Ok(readable == 1) -} - -/// Returns true if the socket is currently writeable, false otherwise. -fn is_writable(sock: &Socket) -> TraceResult { - let mut write = FdSet::new(); - write.insert(sock.as_raw_fd()); - let writable = nix::sys::select::select( - None, - None, - Some(&mut write), - None, - Some(&mut TimeVal::zero()), - ) - .map_err(|err| TracerError::IoError(std::io::Error::from(err)))?; - Ok(writable == 1) -} - /// Validate, Lookup or discover the source `IpAddr`. fn make_src_addr( source_addr: Option, @@ -302,8 +269,8 @@ fn make_src_addr( /// Lookup the address for a named interface. fn lookup_interface_addr(addr_family: TracerAddrFamily, name: &str) -> TraceResult { match addr_family { - TracerAddrFamily::Ipv4 => ipv4::lookup_interface_addr(name), - TracerAddrFamily::Ipv6 => ipv6::lookup_interface_addr(name), + TracerAddrFamily::Ipv4 => platform::lookup_interface_addr_ipv4(name), + TracerAddrFamily::Ipv6 => platform::lookup_interface_addr_ipv6(name), } } @@ -341,23 +308,23 @@ fn udp_socket_for_addr_family(addr_family: TracerAddrFamily) -> TraceResult TraceResult { match addr_family { - TracerAddrFamily::Ipv4 => ipv4::make_icmp_send_socket(), - TracerAddrFamily::Ipv6 => ipv6::make_icmp_send_socket(), + TracerAddrFamily::Ipv4 => platform::make_icmp_send_socket_ipv4(), + TracerAddrFamily::Ipv6 => platform::make_icmp_send_socket_ipv6(), } } /// Make a socket for sending `UDP` packets. fn make_udp_send_socket(addr_family: TracerAddrFamily) -> TraceResult { match addr_family { - TracerAddrFamily::Ipv4 => ipv4::make_udp_send_socket(), - TracerAddrFamily::Ipv6 => ipv6::make_udp_send_socket(), + TracerAddrFamily::Ipv4 => platform::make_udp_send_socket_ipv4(), + TracerAddrFamily::Ipv6 => platform::make_udp_send_socket_ipv6(), } } /// Make a socket for receiving raw `ICMP` packets. fn make_recv_socket(addr_family: TracerAddrFamily) -> TraceResult { match addr_family { - TracerAddrFamily::Ipv4 => ipv4::make_recv_socket(), - TracerAddrFamily::Ipv6 => ipv6::make_recv_socket(), + TracerAddrFamily::Ipv4 => platform::make_recv_socket_ipv4(), + TracerAddrFamily::Ipv6 => platform::make_recv_socket_ipv6(), } } diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 14520577..80955a22 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -1,6 +1,7 @@ use crate::tracing::error::TracerError::AddressNotAvailable; use crate::tracing::error::{TraceResult, TracerError}; use crate::tracing::net::channel::MAX_PACKET_SIZE; +use crate::tracing::net::platform; use crate::tracing::net::platform::PlatformIpv4FieldByteOrder; use crate::tracing::packet::checksum::{icmp_ipv4_checksum, udp_ipv4_checksum}; use crate::tracing::packet::icmpv4::destination_unreachable::DestinationUnreachablePacket; @@ -16,9 +17,7 @@ use crate::tracing::probe::{ProbeResponse, ProbeResponseData, TcpProbeResponseDa use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId, TypeOfService}; use crate::tracing::util::Required; use crate::tracing::{MultipathStrategy, PortDirection, Probe, TracerProtocol}; -use nix::libc::IPPROTO_RAW; -use nix::sys::socket::{AddressFamily, SockaddrLike}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use socket2::{SockAddr, Socket}; use std::io::{ErrorKind, Read}; use std::net::{IpAddr, Ipv4Addr, Shutdown, SocketAddr}; use std::time::SystemTime; @@ -40,36 +39,6 @@ const MAX_ICMP_PAYLOAD_BUF: usize = MAX_ICMP_PACKET_BUF - IcmpPacket::minimum_pa /// 0100 0000 0000 0000 const DONT_FRAGMENT: u16 = 0x4000; -pub fn lookup_interface_addr(name: &str) -> TraceResult { - nix::ifaddrs::getifaddrs() - .map_err(|_| TracerError::UnknownInterface(name.to_string()))? - .into_iter() - .find_map(|ia| { - ia.address.and_then(|addr| match addr.family() { - Some(AddressFamily::Inet) if ia.interface_name == name => addr - .as_sockaddr_in() - .map(|sock_addr| IpAddr::V4(Ipv4Addr::from(sock_addr.ip()))), - _ => None, - }) - }) - .ok_or_else(|| TracerError::UnknownInterface(name.to_string())) -} - -pub fn make_icmp_send_socket() -> TraceResult { - make_raw_socket() -} - -pub fn make_udp_send_socket() -> TraceResult { - make_raw_socket() -} - -pub fn make_recv_socket() -> TraceResult { - let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?; - socket.set_nonblocking(true)?; - socket.set_header_included(true)?; - Ok(socket) -} - #[allow(clippy::too_many_arguments)] pub fn dispatch_icmp_probe( icmp_send_socket: &mut Socket, @@ -186,9 +155,7 @@ pub fn dispatch_tcp_probe( PortDirection::FixedDest(dest_port) => (probe.sequence.0, dest_port.0), PortDirection::FixedBoth(_, _) | PortDirection::None => unimplemented!(), }; - let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; - socket.set_nonblocking(true)?; - socket.set_reuse_port(true)?; + let socket = platform::make_stream_socket_ipv4()?; let local_addr = SocketAddr::new(IpAddr::V4(src_addr), src_port); socket.bind(&SockAddr::from(local_addr))?; socket.set_ttl(u32::from(probe.ttl.0))?; @@ -198,7 +165,7 @@ pub fn dispatch_tcp_probe( Ok(_) => {} Err(err) => { if let Some(code) = err.raw_os_error() { - if nix::Error::from_i32(code) != nix::Error::EINPROGRESS { + if platform::is_in_progress_error(code) { return match err.kind() { ErrorKind::AddrInUse | ErrorKind::AddrNotAvailable => { Err(AddressNotAvailable(local_addr)) @@ -255,7 +222,7 @@ pub fn recv_tcp_socket( } Some(err) => { if let Some(code) = err.raw_os_error() { - if nix::Error::from_i32(code) == nix::Error::ECONNREFUSED { + if platform::is_conn_refused_error(code) { return Ok(Some(ProbeResponse::TcpRefused(TcpProbeResponseData::new( SystemTime::now(), dest_addr, @@ -268,13 +235,6 @@ pub fn recv_tcp_socket( Ok(None) } -fn make_raw_socket() -> TraceResult { - let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?; - socket.set_nonblocking(true)?; - socket.set_header_included(true)?; - Ok(socket) -} - /// Create an ICMP `EchoRequest` packet. fn make_echo_request_icmp_packet( icmp_buf: &mut [u8], diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index 4a19a378..2c785abb 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -1,6 +1,7 @@ use crate::tracing::error::TracerError::AddressNotAvailable; use crate::tracing::error::{TraceResult, TracerError}; use crate::tracing::net::channel::MAX_PACKET_SIZE; +use crate::tracing::net::platform; use crate::tracing::packet::checksum::{icmp_ipv6_checksum, udp_ipv6_checksum}; use crate::tracing::packet::icmpv6::destination_unreachable::DestinationUnreachablePacket; use crate::tracing::packet::icmpv6::echo_reply::EchoReplyPacket; @@ -14,8 +15,7 @@ use crate::tracing::probe::{ProbeResponse, ProbeResponseData, TcpProbeResponseDa use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId}; use crate::tracing::util::Required; use crate::tracing::{PortDirection, Probe, TracerProtocol}; -use nix::sys::socket::{AddressFamily, SockaddrLike}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use socket2::{SockAddr, Socket}; use std::io::ErrorKind; use std::net::{IpAddr, Ipv6Addr, Shutdown, SocketAddr}; use std::time::SystemTime; @@ -32,39 +32,6 @@ const MAX_ICMP_PACKET_BUF: usize = MAX_PACKET_SIZE - Ipv6Packet::minimum_packet_ /// The maximum size of ICMP payload we allow. const MAX_ICMP_PAYLOAD_BUF: usize = MAX_ICMP_PACKET_BUF - IcmpPacket::minimum_packet_size(); -pub fn lookup_interface_addr(name: &str) -> TraceResult { - nix::ifaddrs::getifaddrs() - .map_err(|_| TracerError::UnknownInterface(name.to_string()))? - .into_iter() - .find_map(|ia| { - ia.address.and_then(|addr| match addr.family() { - Some(AddressFamily::Inet6) if ia.interface_name == name => addr - .as_sockaddr_in6() - .map(|sock_addr| IpAddr::V6(sock_addr.ip())), - _ => None, - }) - }) - .ok_or_else(|| TracerError::UnknownInterface(name.to_string())) -} - -pub fn make_icmp_send_socket() -> TraceResult { - let socket = Socket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))?; - socket.set_nonblocking(true)?; - Ok(socket) -} - -pub fn make_udp_send_socket() -> TraceResult { - let socket = Socket::new(Domain::IPV6, Type::RAW, Some(Protocol::UDP))?; - socket.set_nonblocking(true)?; - Ok(socket) -} - -pub fn make_recv_socket() -> TraceResult { - let socket = Socket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))?; - socket.set_nonblocking(true)?; - Ok(socket) -} - pub fn dispatch_icmp_probe( icmp_send_socket: &mut Socket, probe: Probe, @@ -147,9 +114,7 @@ pub fn dispatch_tcp_probe( PortDirection::FixedDest(dest_port) => (probe.sequence.0, dest_port.0), PortDirection::FixedBoth(_, _) | PortDirection::None => unimplemented!(), }; - let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?; - socket.set_nonblocking(true)?; - socket.set_reuse_port(true)?; + let socket = platform::make_stream_socket_ipv6()?; let local_addr = SocketAddr::new(IpAddr::V6(src_addr), src_port); socket.bind(&SockAddr::from(local_addr))?; socket.set_unicast_hops_v6(u32::from(probe.ttl.0))?; @@ -158,7 +123,7 @@ pub fn dispatch_tcp_probe( Ok(_) => {} Err(err) => { if let Some(code) = err.raw_os_error() { - if nix::Error::from_i32(code) != nix::Error::EINPROGRESS { + if platform::is_in_progress_error(code) { return match err.kind() { ErrorKind::AddrInUse | ErrorKind::AddrNotAvailable => { Err(AddressNotAvailable(local_addr)) @@ -210,7 +175,7 @@ pub fn recv_tcp_socket( } Some(err) => { if let Some(code) = err.raw_os_error() { - if nix::Error::from_i32(code) == nix::Error::ECONNREFUSED { + if platform::is_conn_refused_error(code) { return Ok(Some(ProbeResponse::TcpRefused(TcpProbeResponseData::new( SystemTime::now(), dest_addr, diff --git a/src/tracing/net/platform.rs b/src/tracing/net/platform.rs deleted file mode 100644 index 2e5b863a..00000000 --- a/src/tracing/net/platform.rs +++ /dev/null @@ -1,117 +0,0 @@ -use crate::tracing::error::TraceResult; -use std::net::IpAddr; - -/// The size of the test packet to use for discovering the `total_length` byte order. -#[cfg(all(unix, not(target_os = "linux")))] -const TEST_PACKET_LENGTH: u16 = 256; - -/// The byte order to encode the `total_length`, `flags` and `fragment_offset` fields of the IPv4 header. -/// -/// To quote directly from the `mtr` source code (from `check_length_order` in `probe_unix.c`): -/// -/// "Nearly all fields in the IP header should be encoded in network byte -/// order prior to passing to send(). However, the required byte order of -/// the length field of the IP header is inconsistent between operating -/// systems and operating system versions. FreeBSD 11 requires the length -/// field in network byte order, but some older versions of FreeBSD -/// require host byte order. OS X requires the length field in host -/// byte order. Linux will accept either byte order." -#[derive(Debug, Copy, Clone)] -pub enum PlatformIpv4FieldByteOrder { - #[cfg(all(unix, not(target_os = "linux")))] - Host, - Network, -} - -impl PlatformIpv4FieldByteOrder { - /// Discover the required byte ordering for the IPv4 header fields `total_length`, `flags` and `fragment_offset`. - /// - /// This is achieved by creating a raw socket and attempting to send an `IPv4` packet to localhost with the - /// `total_length` set in either host byte order or network byte order. The OS will return an `InvalidInput` error - /// if the buffer provided is smaller than the `total_length` indicated, which will be the case when the byte order - /// is set incorrectly. - /// - /// This is a little confusing as `Ipv4Packet::set_total_length` method will _always_ convert from host byte order - /// to network byte order (which will be a no-op on big-endian system) and so to test the host byte order case - /// we must try both the normal and the swapped byte order. - /// - /// For example, for a packet of length 4660 bytes (dec): - /// - /// For a little-endian architecture: - /// - /// Try Host (LE) Wire (BE) Order (if succeeds) - /// normal 34 12 12 34 `PlatformIpv4FieldByteOrder::Network` - /// swapped 12 34 34 12 `PlatformIpv4FieldByteOrder::Host` - /// - /// For a big-endian architecture: - /// - /// Try Host (BE) Wire (BE) Order (if succeeds) - /// normal 12 34 12 34 `Ipv4TotalLengthByteOrder::Host` - /// swapped 34 12 34 12 `Ipv4TotalLengthByteOrder::Network` - #[cfg(all(unix, not(target_os = "linux")))] - pub fn for_address(addr: IpAddr) -> TraceResult { - use crate::tracing::error::TracerError; - let addr = match addr { - IpAddr::V4(addr) => addr, - IpAddr::V6(_) => return Ok(Self::Network), - }; - match test_send_local_ip4_packet(addr, TEST_PACKET_LENGTH) { - Ok(_) => Ok(Self::Network), - Err(TracerError::IoError(io)) if io.kind() == std::io::ErrorKind::InvalidInput => { - match test_send_local_ip4_packet(addr, TEST_PACKET_LENGTH.swap_bytes()) { - Ok(_) => Ok(Self::Host), - Err(err) => Err(err), - } - } - Err(err) => Err(err), - } - } - - /// Discover the required byte ordering for the IPv4 header fields `total_length`, `flags` and `fragment_offset`. - /// - /// Linux accepts either network byte order or host byte order for the `total_length` field and so we skip the - /// check and return network bye order unconditionally. - #[cfg(target_os = "linux")] - #[allow(clippy::unnecessary_wraps)] - pub fn for_address(_src_addr: IpAddr) -> TraceResult { - Ok(Self::Network) - } - - /// Adjust the IPv4 `total_length` header. - pub fn adjust_length(self, ipv4_total_length: u16) -> u16 { - match self { - #[cfg(all(unix, not(target_os = "linux")))] - Self::Host => ipv4_total_length.swap_bytes(), - Self::Network => ipv4_total_length, - } - } -} - -/// Open a raw socket and attempt to send an `ICMP` packet to a local address. -/// -/// The packet is actually of length `256` bytes but we set the `total_length` based on the input provided so as to -/// test if the OS rejects the attempt. -#[cfg(all(unix, not(target_os = "linux")))] -fn test_send_local_ip4_packet( - src_addr: std::net::Ipv4Addr, - total_length: u16, -) -> TraceResult { - use crate::tracing::util::Required; - let mut buf = [0_u8; TEST_PACKET_LENGTH as usize]; - let mut ipv4 = crate::tracing::packet::ipv4::Ipv4Packet::new(&mut buf).req()?; - ipv4.set_version(4); - ipv4.set_header_length(5); - ipv4.set_protocol(crate::tracing::packet::IpProtocol::Icmp); - ipv4.set_ttl(255); - ipv4.set_source(src_addr); - ipv4.set_destination(std::net::Ipv4Addr::LOCALHOST); - ipv4.set_total_length(total_length); - let probe_socket = socket2::Socket::new( - socket2::Domain::IPV4, - socket2::Type::RAW, - Some(socket2::Protocol::from(nix::libc::IPPROTO_RAW)), - )?; - probe_socket.set_header_included(true)?; - let remote_addr = std::net::SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 0); - Ok(probe_socket.send_to(ipv4.packet(), &socket2::SockAddr::from(remote_addr))?) -} diff --git a/src/tracing/net/platform/byte_order.rs b/src/tracing/net/platform/byte_order.rs new file mode 100644 index 00000000..4faae605 --- /dev/null +++ b/src/tracing/net/platform/byte_order.rs @@ -0,0 +1,60 @@ +use crate::tracing::error::TraceResult; +use crate::tracing::net::platform::for_address; +use std::net::IpAddr; + +/// The byte order to encode the `total_length`, `flags` and `fragment_offset` fields of the IPv4 header. +/// +/// To quote directly from the `mtr` source code (from `check_length_order` in `probe_unix.c`): +/// +/// "Nearly all fields in the IP header should be encoded in network byte +/// order prior to passing to send(). However, the required byte order of +/// the length field of the IP header is inconsistent between operating +/// systems and operating system versions. FreeBSD 11 requires the length +/// field in network byte order, but some older versions of FreeBSD +/// require host byte order. OS X requires the length field in host +/// byte order. Linux will accept either byte order." +#[derive(Debug, Copy, Clone)] +pub enum PlatformIpv4FieldByteOrder { + #[cfg(all(unix, not(target_os = "linux"), not(target_os = "windows")))] + Host, + Network, +} + +impl PlatformIpv4FieldByteOrder { + /// Discover the required byte ordering for the IPv4 header fields `total_length`, `flags` and `fragment_offset`. + /// + /// This is achieved by creating a raw socket and attempting to send an `IPv4` packet to localhost with the + /// `total_length` set in either host byte order or network byte order. The OS will return an `InvalidInput` error + /// if the buffer provided is smaller than the `total_length` indicated, which will be the case when the byte order + /// is set incorrectly. + /// + /// This is a little confusing as `Ipv4Packet::set_total_length` method will _always_ convert from host byte order + /// to network byte order (which will be a no-op on big-endian system) and so to test the host byte order case + /// we must try both the normal and the swapped byte order. + /// + /// For example, for a packet of length 4660 bytes (dec): + /// + /// For a little-endian architecture: + /// + /// Try Host (LE) Wire (BE) Order (if succeeds) + /// normal 34 12 12 34 `PlatformIpv4FieldByteOrder::Network` + /// swapped 12 34 34 12 `PlatformIpv4FieldByteOrder::Host` + /// + /// For a big-endian architecture: + /// + /// Try Host (BE) Wire (BE) Order (if succeeds) + /// normal 12 34 12 34 `Ipv4TotalLengthByteOrder::Host` + /// swapped 34 12 34 12 `Ipv4TotalLengthByteOrder::Network` + pub fn for_address(addr: IpAddr) -> TraceResult { + for_address(addr) + } + + /// Adjust the IPv4 `total_length` header. + pub fn adjust_length(self, ipv4_total_length: u16) -> u16 { + match self { + #[cfg(all(unix, not(target_os = "linux"), not(target_os = "windows")))] + Self::Host => ipv4_total_length.swap_bytes(), + Self::Network => ipv4_total_length, + } + } +} diff --git a/src/tracing/net/platform/mod.rs b/src/tracing/net/platform/mod.rs new file mode 100644 index 00000000..f06782cf --- /dev/null +++ b/src/tracing/net/platform/mod.rs @@ -0,0 +1,14 @@ +pub mod byte_order; +pub use byte_order::PlatformIpv4FieldByteOrder; + +#[cfg(unix)] +pub mod unix; + +#[cfg(unix)] +pub use unix::*; + +#[cfg(windows)] +pub mod windows; + +#[cfg(windows)] +pub use windows::*; diff --git a/src/tracing/net/platform/unix.rs b/src/tracing/net/platform/unix.rs new file mode 100644 index 00000000..451542ea --- /dev/null +++ b/src/tracing/net/platform/unix.rs @@ -0,0 +1,200 @@ +use super::byte_order::PlatformIpv4FieldByteOrder; +use crate::tracing::error::{TraceResult, TracerError}; +use nix::{ + sys::select::FdSet, + sys::socket::{AddressFamily, SockaddrLike}, + sys::time::{TimeVal, TimeValLike}, +}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::{IpAddr, Ipv4Addr}; +use std::os::unix::io::AsRawFd; +use std::time::Duration; + +/// The size of the test packet to use for discovering the `total_length` byte order. +#[cfg(not(target_os = "linux"))] +const TEST_PACKET_LENGTH: u16 = 256; + +/// Discover the required byte ordering for the IPv4 header fields `total_length`, `flags` and `fragment_offset`. +/// +/// Linux accepts either network byte order or host byte order for the `total_length` field and so we skip the +/// check and return network byte order unconditionally. +#[cfg(target_os = "linux")] +#[allow(clippy::unnecessary_wraps)] +pub fn for_address(_src_addr: IpAddr) -> TraceResult { + Ok(PlatformIpv4FieldByteOrder::Network) +} + +#[cfg(not(target_os = "linux"))] +pub fn for_address(addr: IpAddr) -> TraceResult { + let addr = match addr { + IpAddr::V4(addr) => addr, + IpAddr::V6(_) => return Ok(PlatformIpv4FieldByteOrder::Network), + }; + match test_send_local_ip4_packet(addr, TEST_PACKET_LENGTH) { + Ok(_) => Ok(PlatformIpv4FieldByteOrder::Network), + Err(TracerError::IoError(io)) if io.kind() == std::io::ErrorKind::InvalidInput => { + match test_send_local_ip4_packet(addr, TEST_PACKET_LENGTH.swap_bytes()) { + Ok(_) => Ok(PlatformIpv4FieldByteOrder::Host), + Err(err) => Err(err), + } + } + Err(err) => Err(err), + } +} + +/// Open a raw socket and attempt to send an `ICMP` packet to a local address. +/// +/// The packet is actually of length `256` bytes but we set the `total_length` based on the input provided so as to +/// test if the OS rejects the attempt. +#[cfg(not(target_os = "linux"))] +fn test_send_local_ip4_packet(src_addr: Ipv4Addr, total_length: u16) -> TraceResult { + use crate::tracing::util::Required; + let mut buf = [0_u8; TEST_PACKET_LENGTH as usize]; + let mut ipv4 = crate::tracing::packet::ipv4::Ipv4Packet::new(&mut buf).req()?; + ipv4.set_version(4); + ipv4.set_header_length(5); + ipv4.set_protocol(crate::tracing::packet::IpProtocol::Icmp); + ipv4.set_ttl(255); + ipv4.set_source(src_addr); + ipv4.set_destination(Ipv4Addr::LOCALHOST); + ipv4.set_total_length(total_length); + let probe_socket = Socket::new( + socket2::Domain::IPV4, + socket2::Type::RAW, + Some(socket2::Protocol::from(nix::libc::IPPROTO_RAW)), + )?; + probe_socket.set_header_included(true)?; + let remote_addr = std::net::SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0); + Ok(probe_socket.send_to(ipv4.packet(), &socket2::SockAddr::from(remote_addr))?) +} + +pub fn lookup_interface_addr_ipv4(name: &str) -> TraceResult { + nix::ifaddrs::getifaddrs() + .map_err(|_| TracerError::UnknownInterface(name.to_string()))? + .into_iter() + .find_map(|ia| { + ia.address.and_then(|addr| match addr.family() { + Some(AddressFamily::Inet) if ia.interface_name == name => addr + .as_sockaddr_in() + .map(|sock_addr| IpAddr::V4(Ipv4Addr::from(sock_addr.ip()))), + _ => None, + }) + }) + .ok_or_else(|| TracerError::UnknownInterface(name.to_string())) +} + +pub fn lookup_interface_addr_ipv6(name: &str) -> TraceResult { + nix::ifaddrs::getifaddrs() + .map_err(|_| TracerError::UnknownInterface(name.to_string()))? + .into_iter() + .find_map(|ia| { + ia.address.and_then(|addr| match addr.family() { + Some(AddressFamily::Inet6) if ia.interface_name == name => addr + .as_sockaddr_in6() + .map(|sock_addr| IpAddr::V6(sock_addr.ip())), + _ => None, + }) + }) + .ok_or_else(|| TracerError::UnknownInterface(name.to_string())) +} + +pub fn make_icmp_send_socket_ipv4() -> TraceResult { + let socket = Socket::new( + Domain::IPV4, + Type::RAW, + Some(Protocol::from(nix::libc::IPPROTO_RAW)), + )?; + socket.set_nonblocking(true)?; + socket.set_header_included(true)?; + Ok(socket) +} + +pub fn make_udp_send_socket_ipv4() -> TraceResult { + let socket = Socket::new( + Domain::IPV4, + Type::RAW, + Some(Protocol::from(nix::libc::IPPROTO_RAW)), + )?; + socket.set_nonblocking(true)?; + socket.set_header_included(true)?; + Ok(socket) +} + +pub fn make_recv_socket_ipv4() -> TraceResult { + let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::ICMPV4))?; + socket.set_nonblocking(true)?; + socket.set_header_included(true)?; + Ok(socket) +} + +pub fn make_icmp_send_socket_ipv6() -> TraceResult { + let socket = Socket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))?; + socket.set_nonblocking(true)?; + Ok(socket) +} + +pub fn make_udp_send_socket_ipv6() -> TraceResult { + let socket = Socket::new(Domain::IPV6, Type::RAW, Some(Protocol::UDP))?; + socket.set_nonblocking(true)?; + Ok(socket) +} + +pub fn make_recv_socket_ipv6() -> TraceResult { + let socket = Socket::new(Domain::IPV6, Type::RAW, Some(Protocol::ICMPV6))?; + socket.set_nonblocking(true)?; + Ok(socket) +} + +/// Create a IPv4/TCP socket. +pub fn make_stream_socket_ipv4() -> TraceResult { + let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; + socket.set_nonblocking(true)?; + socket.set_reuse_port(true)?; + Ok(socket) +} + +/// Create a IPv6/TCP socket. +pub fn make_stream_socket_ipv6() -> TraceResult { + let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?; + socket.set_nonblocking(true)?; + socket.set_reuse_port(true)?; + Ok(socket) +} + +/// Returns true if the socket becomes readable before the timeout, false otherwise. +pub fn is_readable(sock: &Socket, timeout: Duration) -> TraceResult { + let mut read = FdSet::new(); + read.insert(sock.as_raw_fd()); + let readable = nix::sys::select::select( + None, + Some(&mut read), + None, + None, + Some(&mut TimeVal::milliseconds(timeout.as_millis() as i64)), + ) + .map_err(|err| TracerError::IoError(std::io::Error::from(err)))?; + Ok(readable == 1) +} + +/// Returns true if the socket is currently writeable, false otherwise. +pub fn is_writable(sock: &Socket) -> TraceResult { + let mut write = FdSet::new(); + write.insert(sock.as_raw_fd()); + let writable = nix::sys::select::select( + None, + None, + Some(&mut write), + None, + Some(&mut TimeVal::zero()), + ) + .map_err(|err| TracerError::IoError(std::io::Error::from(err)))?; + Ok(writable == 1) +} + +pub fn is_in_progress_error(code: i32) -> bool { + nix::Error::from_i32(code) != nix::Error::EINPROGRESS +} + +pub fn is_conn_refused_error(code: i32) -> bool { + nix::Error::from_i32(code) == nix::Error::ECONNREFUSED +} diff --git a/src/tracing/net/platform/windows.rs b/src/tracing/net/platform/windows.rs new file mode 100644 index 00000000..b28a3687 --- /dev/null +++ b/src/tracing/net/platform/windows.rs @@ -0,0 +1,81 @@ +use super::byte_order::PlatformIpv4FieldByteOrder; +use crate::tracing::error::TraceResult; +use socket2::Socket; +use std::net::IpAddr; +use std::time::Duration; + +/// TODO +#[allow(clippy::unnecessary_wraps)] +pub fn for_address(_src_addr: IpAddr) -> TraceResult { + Ok(PlatformIpv4FieldByteOrder::Network) +} + +/// TODO +pub fn lookup_interface_addr_ipv4(_name: &str) -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn lookup_interface_addr_ipv6(_name: &str) -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_icmp_send_socket_ipv4() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_udp_send_socket_ipv4() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_recv_socket_ipv4() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_icmp_send_socket_ipv6() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_udp_send_socket_ipv6() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_recv_socket_ipv6() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_stream_socket_ipv4() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn make_stream_socket_ipv6() -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn is_readable(_sock: &Socket, _timeout: Duration) -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn is_writable(_sock: &Socket) -> TraceResult { + unimplemented!() +} + +/// TODO +pub fn is_in_progress_error(_code: i32) -> bool { + unimplemented!() +} + +/// TODO +pub fn is_conn_refused_error(_code: i32) -> bool { + unimplemented!() +}