Skip to content

Commit

Permalink
add connect timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
lazytiger committed Nov 20, 2024
1 parent 13242d4 commit b4468b7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 42 deletions.
19 changes: 9 additions & 10 deletions src/aproxy/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use bytes::BytesMut;
use rustls_pki_types::ServerName;
use std::{
net::{IpAddr, SocketAddr},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};

use bytes::BytesMut;
use rustls_pki_types::ServerName;
use tokio::{
io::{split, AsyncWriteExt, WriteHalf},
net::{tcp::OwnedReadHalf, TcpListener, TcpStream},
Expand Down Expand Up @@ -41,12 +41,7 @@ pub async fn run_tcp(
let server_name_clone = server_name.clone();
let connector_clone = connector.clone();
spawn(async move {
let ret = start_tcp_proxy(
client,
server_name_clone,
connector_clone,
dst_addr,
).await;
let ret = start_tcp_proxy(client, server_name_clone, connector_clone, dst_addr).await;
if let Err(err) = ret {
log::error!("tcp proxy routine exit with:{:?}", err);
}
Expand All @@ -60,7 +55,11 @@ async fn start_tcp_proxy(
connector: TlsConnector,
dst_addr: SocketAddr,
) -> Result<()> {
let mut remote = init_tls_conn(connector, server_name).await?;
let mut remote = tokio::time::timeout(
Duration::from_secs(3),
init_tls_conn(connector, server_name),
)
.await??;
let mut request = BytesMut::new();
TrojanRequest::generate(&mut request, CONNECT, &dst_addr);
if let Err(err) = remote.write_all(request.as_ref()).await {
Expand Down
73 changes: 41 additions & 32 deletions src/aproxy/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
aproxy::{init_tls_conn, new_socket, wait_until_stop},
config::OPTIONS,
proto::{TrojanRequest, UdpAssociate, UdpParseResult, UDP_ASSOCIATE},
sys,
sys, types,
types::Result,
};

Expand Down Expand Up @@ -93,15 +93,26 @@ pub async fn run_udp(
});
let (req_sender, req_receiver) = channel(1024);
remotes.insert(src_addr, req_sender);
spawn(local_to_remote(
req_receiver,
local.clone(),
server_name.clone(),
connector.clone(),
request.clone(),
src_addr,
sender.clone(),
));
let local_clone = local.clone();
let server_name_clone = server_name.clone();
let request_clone = request.clone();
let sender_clone = sender.clone();
let connector_clone = connector.clone();
spawn(async move {
if let Err(err) = local_to_remote(
req_receiver,
local_clone,
server_name_clone,
connector_clone,
request_clone,
src_addr,
sender_clone,
)
.await
{
log::error!("udp local to remote failed:{:?}", err);
}
});
remotes.get(&src_addr).unwrap()
}
};
Expand Down Expand Up @@ -140,27 +151,24 @@ async fn local_to_remote(
request: Arc<BytesMut>,
src_addr: SocketAddr,
sender: Sender<SocketAddr>,
) {
let mut remote = match init_tls_conn(connector, server_name).await {
Ok(mut remote) => {
if let Err(err) = remote.write_all(request.as_ref()).await {
let _ = remote.shutdown().await;
let _ = sender.send(src_addr).await;
log::error!("send handshake to remote failed:{}", err);
return;
}
let (read_half, write_half) = split(remote);
spawn(remote_to_local_with_wait(
read_half, socket, src_addr, sender,
));
write_half
}
Err(err) => {
log::error!("connect to remote server failed:{:?}", err);
let _ = sender.send(src_addr).await;
return;
}
};
) -> types::Result<()> {
let mut remote = tokio::time::timeout(
Duration::from_secs(3),
init_tls_conn(connector, server_name),
)
.await??;

if let Err(err) = remote.write_all(request.as_ref()).await {
let _ = remote.shutdown().await;
let _ = sender.send(src_addr).await;
log::error!("send handshake to remote failed:{}", err);
return Ok(());
}
let (read_half, write_half) = split(remote);
spawn(remote_to_local_with_wait(
read_half, socket, src_addr, sender,
));
let mut remote = write_half;

let mut header = BytesMut::new();
while let Some((target, data)) = local.recv().await {
Expand All @@ -179,6 +187,7 @@ async fn local_to_remote(
}
local.close();
let _ = remote.shutdown().await;
Ok(())
}

async fn remote_to_local(
Expand All @@ -194,7 +203,7 @@ async fn remote_to_local(
Duration::from_secs(OPTIONS.udp_idle_timeout),
remote.read_buf(&mut buffer),
)
.await
.await
{
Ok(Ok(n)) if n > 0 => loop {
match UdpAssociate::parse(buffer.as_ref()) {
Expand Down

0 comments on commit b4468b7

Please sign in to comment.