From b2ea40bb543a5116109b37e1fd64713c116e4312 Mon Sep 17 00:00:00 2001 From: Motoyuki Kimura Date: Fri, 16 Aug 2024 23:50:34 +0900 Subject: [PATCH] net: add handling for abstract socket name (#6772) --- tokio/src/net/unix/listener.rs | 23 +++++++++++++++++++++-- tokio/src/net/unix/stream.rs | 23 +++++++++++++++++++++-- tokio/tests/uds_stream.rs | 13 +++++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/tokio/src/net/unix/listener.rs b/tokio/src/net/unix/listener.rs index 79b554ee1ab..5b28dc03f8f 100644 --- a/tokio/src/net/unix/listener.rs +++ b/tokio/src/net/unix/listener.rs @@ -3,8 +3,14 @@ use crate::net::unix::{SocketAddr, UnixStream}; use std::fmt; use std::io; +#[cfg(target_os = "android")] +use std::os::android::net::SocketAddrExt; +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt; +#[cfg(any(target_os = "linux", target_os = "android"))] +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}; -use std::os::unix::net; +use std::os::unix::net::{self, SocketAddr as StdSocketAddr}; use std::path::Path; use std::task::{Context, Poll}; @@ -70,7 +76,20 @@ impl UnixListener { where P: AsRef, { - let listener = mio::net::UnixListener::bind(path)?; + // For now, we handle abstract socket paths on linux here. + #[cfg(any(target_os = "linux", target_os = "android"))] + let addr = { + let os_str_bytes = path.as_ref().as_os_str().as_bytes(); + if os_str_bytes.starts_with(b"\0") { + StdSocketAddr::from_abstract_name(os_str_bytes)? + } else { + StdSocketAddr::from_pathname(path)? + } + }; + #[cfg(not(any(target_os = "linux", target_os = "android")))] + let addr = StdSocketAddr::from_pathname(path)?; + + let listener = mio::net::UnixListener::bind_addr(&addr)?; let io = PollEvented::new(listener)?; Ok(UnixListener { io }) } diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 60d58139699..a8b6479f1f8 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -8,8 +8,14 @@ use crate::net::unix::SocketAddr; use std::fmt; use std::io::{self, Read, Write}; use std::net::Shutdown; +#[cfg(target_os = "android")] +use std::os::android::net::SocketAddrExt; +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt; +#[cfg(any(target_os = "linux", target_os = "android"))] +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}; -use std::os::unix::net; +use std::os::unix::net::{self, SocketAddr as StdSocketAddr}; use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; @@ -66,7 +72,20 @@ impl UnixStream { where P: AsRef, { - let stream = mio::net::UnixStream::connect(path)?; + // On linux, abstract socket paths need to be considered. + #[cfg(any(target_os = "linux", target_os = "android"))] + let addr = { + let os_str_bytes = path.as_ref().as_os_str().as_bytes(); + if os_str_bytes.starts_with(b"\0") { + StdSocketAddr::from_abstract_name(os_str_bytes)? + } else { + StdSocketAddr::from_pathname(path)? + } + }; + #[cfg(not(any(target_os = "linux", target_os = "android")))] + let addr = StdSocketAddr::from_pathname(path)?; + + let stream = mio::net::UnixStream::connect_addr(&addr)?; let stream = UnixStream::new(stream)?; poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; diff --git a/tokio/tests/uds_stream.rs b/tokio/tests/uds_stream.rs index b8c4e6a8eed..28a836eb76f 100644 --- a/tokio/tests/uds_stream.rs +++ b/tokio/tests/uds_stream.rs @@ -409,3 +409,16 @@ async fn epollhup() -> io::Result<()> { assert_eq!(err.kind(), io::ErrorKind::ConnectionReset); Ok(()) } + +// test for https://github.com/tokio-rs/tokio/issues/6767 +#[tokio::test] +#[cfg(any(target_os = "linux", target_os = "android"))] +async fn abstract_socket_name() { + let socket_path = "\0aaa"; + let listener = UnixListener::bind(socket_path).unwrap(); + + let accept = listener.accept(); + let connect = UnixStream::connect(&socket_path); + + try_join(accept, connect).await.unwrap(); +}