diff --git a/src/aproxy/tcp.rs b/src/aproxy/tcp.rs index e715e8d..93164d9 100644 --- a/src/aproxy/tcp.rs +++ b/src/aproxy/tcp.rs @@ -1,13 +1,13 @@ +use bytes::BytesMut; +use rustls_pki_types::ServerName; use std::{ net::{IpAddr, SocketAddr}, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::Duration, }; - -use bytes::BytesMut; -use rustls_pki_types::ServerName; use tokio::{ io::{split, AsyncWriteExt, WriteHalf}, net::{tcp::OwnedReadHalf, TcpListener, TcpStream}, @@ -41,12 +41,7 @@ pub async fn run_tcp( let server_name_clone = server_name.clone(); let connector_clone = connector.clone(); spawn(async move { - let ret = start_tcp_proxy( - client, - server_name_clone, - connector_clone, - dst_addr, - ).await; + let ret = start_tcp_proxy(client, server_name_clone, connector_clone, dst_addr).await; if let Err(err) = ret { log::error!("tcp proxy routine exit with:{:?}", err); } @@ -60,7 +55,11 @@ async fn start_tcp_proxy( connector: TlsConnector, dst_addr: SocketAddr, ) -> Result<()> { - let mut remote = init_tls_conn(connector, server_name).await?; + let mut remote = tokio::time::timeout( + Duration::from_secs(3), + init_tls_conn(connector, server_name), + ) + .await??; let mut request = BytesMut::new(); TrojanRequest::generate(&mut request, CONNECT, &dst_addr); if let Err(err) = remote.write_all(request.as_ref()).await { diff --git a/src/aproxy/udp.rs b/src/aproxy/udp.rs index dbdc4de..8005872 100644 --- a/src/aproxy/udp.rs +++ b/src/aproxy/udp.rs @@ -24,7 +24,7 @@ use crate::{ aproxy::{init_tls_conn, new_socket, wait_until_stop}, config::OPTIONS, proto::{TrojanRequest, UdpAssociate, UdpParseResult, UDP_ASSOCIATE}, - sys, + sys, types, types::Result, }; @@ -93,15 +93,26 @@ pub async fn run_udp( }); let (req_sender, req_receiver) = channel(1024); remotes.insert(src_addr, req_sender); - spawn(local_to_remote( - req_receiver, - local.clone(), - server_name.clone(), - connector.clone(), - request.clone(), - src_addr, - sender.clone(), - )); + let local_clone = local.clone(); + let server_name_clone = server_name.clone(); + let request_clone = request.clone(); + let sender_clone = sender.clone(); + let connector_clone = connector.clone(); + spawn(async move { + if let Err(err) = local_to_remote( + req_receiver, + local_clone, + server_name_clone, + connector_clone, + request_clone, + src_addr, + sender_clone, + ) + .await + { + log::error!("udp local to remote failed:{:?}", err); + } + }); remotes.get(&src_addr).unwrap() } }; @@ -140,27 +151,24 @@ async fn local_to_remote( request: Arc, src_addr: SocketAddr, sender: Sender, -) { - let mut remote = match init_tls_conn(connector, server_name).await { - Ok(mut remote) => { - if let Err(err) = remote.write_all(request.as_ref()).await { - let _ = remote.shutdown().await; - let _ = sender.send(src_addr).await; - log::error!("send handshake to remote failed:{}", err); - return; - } - let (read_half, write_half) = split(remote); - spawn(remote_to_local_with_wait( - read_half, socket, src_addr, sender, - )); - write_half - } - Err(err) => { - log::error!("connect to remote server failed:{:?}", err); - let _ = sender.send(src_addr).await; - return; - } - }; +) -> types::Result<()> { + let mut remote = tokio::time::timeout( + Duration::from_secs(3), + init_tls_conn(connector, server_name), + ) + .await??; + + if let Err(err) = remote.write_all(request.as_ref()).await { + let _ = remote.shutdown().await; + let _ = sender.send(src_addr).await; + log::error!("send handshake to remote failed:{}", err); + return Ok(()); + } + let (read_half, write_half) = split(remote); + spawn(remote_to_local_with_wait( + read_half, socket, src_addr, sender, + )); + let mut remote = write_half; let mut header = BytesMut::new(); while let Some((target, data)) = local.recv().await { @@ -179,6 +187,7 @@ async fn local_to_remote( } local.close(); let _ = remote.shutdown().await; + Ok(()) } async fn remote_to_local( @@ -194,7 +203,7 @@ async fn remote_to_local( Duration::from_secs(OPTIONS.udp_idle_timeout), remote.read_buf(&mut buffer), ) - .await + .await { Ok(Ok(n)) if n > 0 => loop { match UdpAssociate::parse(buffer.as_ref()) {