diff --git a/src/host.rs b/src/host.rs index 7fbf57b..78a4905 100644 --- a/src/host.rs +++ b/src/host.rs @@ -417,6 +417,12 @@ pub fn matches(bind: SocketAddr, dst: SocketAddr) -> bool { bind == dst } +/// Returns true if loopback is supported between two addresses, or +/// if the IPs are the same (in which case turmoil treats it like loopback) +pub(crate) fn is_same(src: SocketAddr, dst: SocketAddr) -> bool { + dst.ip().is_loopback() || src.ip() == dst.ip() +} + #[cfg(test)] mod test { use crate::{Host, Result}; diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index ea32623..76b9f53 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -17,6 +17,7 @@ use tokio::{ use crate::{ envelope::{Envelope, Protocol, Segment, Syn}, + host::is_same, host::SequencedSegment, net::SocketPair, world::World, @@ -74,7 +75,7 @@ impl TcpStream { let rx = host.tcp.new_stream(pair); let syn = Protocol::Tcp(Segment::Syn(Syn { ack })); - if !dst.ip().is_loopback() { + if !is_same(local_addr, dst) { world.send_message(local_addr, dst, syn)?; } else { send_loopback(local_addr, dst, syn); @@ -270,7 +271,7 @@ impl WriteHalf { fn send(&self, world: &mut World, segment: Segment) -> Result<()> { let message = Protocol::Tcp(segment); - if self.pair.remote.ip().is_loopback() { + if is_same(self.pair.local, self.pair.remote) { send_loopback(self.pair.local, self.pair.remote, message); } else { world.send_message(self.pair.local, self.pair.remote, message)?; diff --git a/src/net/udp.rs b/src/net/udp.rs index 9a22b06..96b2bd2 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -6,6 +6,7 @@ use tokio::{ use crate::{ envelope::{Datagram, Envelope, Protocol}, + host::is_same, ToSocketAddrs, World, TRACING_TARGET, }; @@ -291,7 +292,7 @@ impl UdpSocket { src.set_ip(world.current_host_mut().addr); } - if dst.ip().is_loopback() { + if is_same(src, dst) { send_loopback(src, dst, msg); } else { world.send_message(src, dst, msg)?; diff --git a/tests/tcp.rs b/tests/tcp.rs index a32ee93..25825b3 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -736,7 +736,9 @@ fn non_zero_bind() -> Result { sim.client("client", async move { let sock = TcpListener::bind("1.1.1.1:1").await; - let Err(err) = sock else { panic!("bind should have failed") }; + let Err(err) = sock else { + panic!("bind should have failed") + }; assert_eq!(err.to_string(), "1.1.1.1:1 is not supported"); Ok(()) }); @@ -830,6 +832,38 @@ fn loopback_to_localhost_v4() -> Result { run_localhost_test(IpVersion::V4, bind_addr, connect_addr) } +#[test] +fn loopback_wildcard_public_v4() -> Result { + let bind_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 1234); + let connect_addr = SocketAddr::from((Ipv4Addr::new(192, 168, 0, 1), 1234)); + run_localhost_test(IpVersion::V4, bind_addr, connect_addr) +} + +#[test] +fn loopback_localhost_public_v4() -> Result { + let bind_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 1234); + let connect_addr = SocketAddr::from((Ipv4Addr::new(192, 168, 0, 1), 1234)); + let mut sim = Builder::new().ip_version(IpVersion::V4).build(); + sim.client("client", async move { + let listener = TcpListener::bind(bind_addr).await?; + + tokio::spawn(async move { + let (mut socket, socket_addr) = listener.accept().await.unwrap(); + socket.write_all(&[0, 1, 3, 7, 8]).await.unwrap(); + + assert_eq!(socket_addr.ip(), connect_addr.ip()); + assert_eq!(socket.local_addr().unwrap().ip(), connect_addr.ip()); + assert_eq!(socket.peer_addr().unwrap().ip(), connect_addr.ip()); + }); + + let res = TcpStream::connect(connect_addr).await; + assert_error_kind(res, io::ErrorKind::ConnectionRefused); + + Ok(()) + }); + sim.run() +} + #[test] fn loopback_to_wildcard_v6() -> Result { let bind_addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 1234); @@ -844,6 +878,38 @@ fn loopback_to_localhost_v6() -> Result { run_localhost_test(IpVersion::V6, bind_addr, connect_addr) } +#[test] +fn loopback_wildcard_public_v6() -> Result { + let bind_addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 1234); + let connect_addr = SocketAddr::from((Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 1234)); + run_localhost_test(IpVersion::V6, bind_addr, connect_addr) +} + +#[test] +fn loopback_localhost_public_v6() -> Result { + let bind_addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 1234); + let connect_addr = SocketAddr::from((Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 1234)); + let mut sim = Builder::new().ip_version(IpVersion::V6).build(); + sim.client("client", async move { + let listener = TcpListener::bind(bind_addr).await?; + + tokio::spawn(async move { + let (mut socket, socket_addr) = listener.accept().await.unwrap(); + socket.write_all(&[0, 1, 3, 7, 8]).await.unwrap(); + + assert_eq!(socket_addr.ip(), connect_addr.ip()); + assert_eq!(socket.local_addr().unwrap().ip(), connect_addr.ip()); + assert_eq!(socket.peer_addr().unwrap().ip(), connect_addr.ip()); + }); + + let res = TcpStream::connect(connect_addr).await; + assert_error_kind(res, io::ErrorKind::ConnectionRefused); + + Ok(()) + }); + sim.run() +} + #[test] fn remote_to_localhost_refused() -> Result { let mut sim = Builder::new().build(); diff --git a/tests/udp.rs b/tests/udp.rs index a8e8dc7..8942a81 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -1,6 +1,5 @@ use std::{ io::{self, ErrorKind}, - matches, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, rc::Rc, sync::{atomic::AtomicUsize, atomic::Ordering}, @@ -210,7 +209,7 @@ fn hold_and_release() -> Result { send_ping(&sock).await?; let res = timeout(Duration::from_secs(1), recv_pong(&sock)).await; - assert!(matches!(res, Err(_))); + assert!(res.is_err()); // resume the network. note that the client ping does not have to be // resent. @@ -406,7 +405,9 @@ fn non_zero_bind() -> Result { sim.client("client", async move { let sock = UdpSocket::bind("1.1.1.1:1").await; - let Err(err) = sock else { panic!("socket creation should have failed") }; + let Err(err) = sock else { + panic!("socket creation should have failed") + }; assert_eq!(err.to_string(), "1.1.1.1:1 is not supported"); Ok(()) }); @@ -506,6 +507,43 @@ fn loopback_to_localhost_v4() -> Result { run_localhost_test(IpVersion::V4, bind_addr, connect_addr) } +#[test] +fn loopback_wildcard_public_v4() -> Result { + let bind_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 1234); + let connect_addr = SocketAddr::from((Ipv4Addr::new(192, 168, 0, 1), 1234)); + run_localhost_test(IpVersion::V4, bind_addr, connect_addr) +} + +#[test] +fn loopback_localhost_public_v4() -> Result { + let bind_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 1234); + let connect_addr = SocketAddr::from((Ipv4Addr::new(192, 168, 0, 1), 1234)); + let mut sim = Builder::new().ip_version(IpVersion::V4).build(); + let expected = [0, 1, 7, 3, 8]; + sim.client("client", async move { + let socket = UdpSocket::bind(bind_addr).await?; + + tokio::spawn(async move { + let mut buf = [0; 5]; + let (_, peer) = socket.recv_from(&mut buf).await.unwrap(); + + assert_eq!(expected, buf); + assert_eq!(peer.ip(), connect_addr.ip()); + assert_eq!(socket.local_addr().unwrap().ip(), bind_addr.ip()); + + socket.send_to(&expected, peer).await.unwrap(); + }); + + let bind_addr = SocketAddr::new(bind_addr.ip(), 0); + let socket = UdpSocket::bind(bind_addr).await?; + let res = socket.send_to(&expected, connect_addr).await; + assert_error_kind(res, io::ErrorKind::ConnectionRefused); + + Ok(()) + }); + sim.run() +} + #[test] fn loopback_to_wildcard_v6() -> Result { let bind_addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 1234); @@ -520,6 +558,43 @@ fn loopback_to_localhost_v6() -> Result { run_localhost_test(IpVersion::V6, bind_addr, connect_addr) } +#[test] +fn loopback_wildcard_public_v6() -> Result { + let bind_addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 1234); + let connect_addr = SocketAddr::from((Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 1234)); + run_localhost_test(IpVersion::V6, bind_addr, connect_addr) +} + +#[test] +fn loopback_localhost_public_v6() -> Result { + let bind_addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 1234); + let connect_addr = SocketAddr::from((Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 1234)); + let mut sim = Builder::new().ip_version(IpVersion::V6).build(); + let expected = [0, 1, 7, 3, 8]; + sim.client("client", async move { + let socket = UdpSocket::bind(bind_addr).await?; + + tokio::spawn(async move { + let mut buf = [0; 5]; + let (_, peer) = socket.recv_from(&mut buf).await.unwrap(); + + assert_eq!(expected, buf); + assert_eq!(peer.ip(), connect_addr.ip()); + assert_eq!(socket.local_addr().unwrap().ip(), bind_addr.ip()); + + socket.send_to(&expected, peer).await.unwrap(); + }); + + let bind_addr = SocketAddr::new(bind_addr.ip(), 0); + let socket = UdpSocket::bind(bind_addr).await?; + let res = socket.send_to(&expected, connect_addr).await; + assert_error_kind(res, io::ErrorKind::ConnectionRefused); + + Ok(()) + }); + sim.run() +} + #[test] fn remote_to_localhost_dropped() -> Result { let mut sim = Builder::new().build();