diff --git a/Cargo.lock b/Cargo.lock index b55d42fa..be8d3ba6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1253,7 +1253,7 @@ dependencies = [ [[package]] name = "librqbit" -version = "4.0.0" +version = "4.1.0" dependencies = [ "anyhow", "axum 0.7.1", @@ -1343,7 +1343,7 @@ dependencies = [ [[package]] name = "librqbit-dht" -version = "4.0.0" +version = "4.1.0" dependencies = [ "anyhow", "backoff", @@ -1368,7 +1368,7 @@ dependencies = [ [[package]] name = "librqbit-peer-protocol" -version = "3.2.1" +version = "3.3.0" dependencies = [ "anyhow", "bincode", @@ -2002,7 +2002,7 @@ dependencies = [ [[package]] name = "rqbit" -version = "4.0.0" +version = "4.1.0" dependencies = [ "anyhow", "clap", diff --git a/crates/dht/Cargo.toml b/crates/dht/Cargo.toml index 41418d10..f5da4e5c 100644 --- a/crates/dht/Cargo.toml +++ b/crates/dht/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "librqbit-dht" -version = "4.0.0" +version = "4.1.0" edition = "2021" description = "DHT implementation, used in rqbit torrent client." license = "Apache-2.0" diff --git a/crates/dht/src/dht.rs b/crates/dht/src/dht.rs index 1dbc0b1f..3581bfe2 100644 --- a/crates/dht/src/dht.rs +++ b/crates/dht/src/dht.rs @@ -1,6 +1,6 @@ use std::{ cmp::Reverse, - net::{SocketAddr, SocketAddrV4}, + net::SocketAddr, sync::{ atomic::{AtomicU16, Ordering}, Arc, diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 87f940cf..3f4afa8a 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "librqbit" -version = "4.0.0" +version = "4.1.0" authors = ["Igor Katson "] edition = "2021" description = "The main library used by rqbit torrent client. The binary is just a small wrapper on top of it." @@ -26,7 +26,7 @@ bencode = {path = "../bencode", default-features=false, package="librqbit-bencod buffers = {path = "../buffers", package="librqbit-buffers", version = "2.2.1"} librqbit-core = {path = "../librqbit_core", version = "3.2.1"} clone_to_owned = {path = "../clone_to_owned", package="librqbit-clone-to-owned", version = "2.2.1"} -peer_binary_protocol = {path = "../peer_binary_protocol", package="librqbit-peer-protocol", version = "3.2.1"} +peer_binary_protocol = {path = "../peer_binary_protocol", package="librqbit-peer-protocol", version = "3.3.0"} sha1w = {path = "../sha1w", default-features=false, package="librqbit-sha1-wrapper", version="2.2.1"} dht = {path = "../dht", package="librqbit-dht", version="4.0.0"} librqbit-upnp = {path = "../upnp", version = "0.1.0"} diff --git a/crates/librqbit/src/peer_connection.rs b/crates/librqbit/src/peer_connection.rs index efef6ab5..5fe22d42 100644 --- a/crates/librqbit/src/peer_connection.rs +++ b/crates/librqbit/src/peer_connection.rs @@ -62,7 +62,7 @@ pub struct PeerConnection { spawner: BlockingSpawner, } -async fn with_timeout( +pub(crate) async fn with_timeout( timeout_value: Duration, fut: impl std::future::Future>, ) -> anyhow::Result @@ -120,18 +120,57 @@ impl PeerConnection { } } + // By the time this is called: + // read_buf should start with valuable data. The handshake should be removed from it. pub async fn manage_peer_incoming( &self, - mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, + outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, + // How many bytes into read buffer have we read already. + read_so_far: usize, + read_buf: Vec, handshake: Handshake, - socket: tokio::net::TcpSocket, + mut conn: tokio::net::TcpStream, ) -> anyhow::Result<()> { - todo!() + use tokio::io::AsyncWriteExt; + + let rwtimeout = self + .options + .read_write_timeout + .unwrap_or_else(|| Duration::from_secs(10)); + + if handshake.info_hash != self.info_hash.0 { + anyhow::bail!("wrong info hash"); + } + + trace!( + "incoming connection: id={:?}", + try_decode_peer_id(Id20(handshake.peer_id)) + ); + + let mut write_buf = Vec::::with_capacity(PIECE_MESSAGE_DEFAULT_LEN); + let handshake = Handshake::new(self.info_hash, self.peer_id); + handshake.serialize(&mut write_buf); + with_timeout(rwtimeout, conn.write_all(&write_buf)) + .await + .context("error writing handshake")?; + write_buf.clear(); + + let h_supports_extended = handshake.supports_extended(); + + self.manage_peer( + h_supports_extended, + read_so_far, + read_buf, + write_buf, + conn, + outgoing_chan, + ) + .await } pub async fn manage_peer_outgoing( &self, - mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, + outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; @@ -170,20 +209,51 @@ impl PeerConnection { let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; + let h_supports_extended = h.supports_extended(); trace!("connected: id={:?}", try_decode_peer_id(Id20(h.peer_id))); if h.info_hash != self.info_hash.0 { anyhow::bail!("info hash does not match"); } - let mut extended_handshake: Option> = None; - let supports_extended = h.supports_extended(); - self.handler.on_handshake(h)?; + if read_so_far > size { read_buf.copy_within(size..read_so_far, 0); } read_so_far -= size; + self.manage_peer( + h_supports_extended, + read_so_far, + read_buf, + write_buf, + conn, + outgoing_chan, + ) + .await + } + + async fn manage_peer( + &self, + handshake_supports_extended: bool, + // How many bytes into read_buf is there of peer-sent-data. + mut read_so_far: usize, + mut read_buf: Vec, + mut write_buf: Vec, + mut conn: tokio::net::TcpStream, + mut outgoing_chan: tokio::sync::mpsc::UnboundedReceiver, + ) -> anyhow::Result<()> { + use tokio::io::AsyncReadExt; + use tokio::io::AsyncWriteExt; + + let rwtimeout = self + .options + .read_write_timeout + .unwrap_or_else(|| Duration::from_secs(10)); + + let mut extended_handshake: Option> = None; + let supports_extended = handshake_supports_extended; + if supports_extended { let my_extended = Message::Extended(ExtendedMessage::Handshake(ExtendedHandshake::new())); diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 7b64e854..dd65ad78 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -12,9 +12,11 @@ use std::{ use anyhow::{bail, Context}; use bencode::{bencode_serialize_to_writer, BencodeDeserializer}; use buffers::{ByteBufT, ByteString}; +use clone_to_owned::CloneToOwned; use dht::{ Dht, DhtBuilder, DhtConfig, Id20, PersistentDht, PersistentDhtConfig, RequestPeersStream, }; +use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt}; use librqbit_core::{ directories::get_configuration_directory, magnet::Magnet, @@ -22,17 +24,23 @@ use librqbit_core::{ torrent_metainfo::{torrent_from_bytes, TorrentMetaV1Info, TorrentMetaV1Owned}, }; use parking_lot::RwLock; +use peer_binary_protocol::{Handshake, PIECE_MESSAGE_DEFAULT_LEN}; use reqwest::Url; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_with::serde_as; -use tokio::net::TcpListener; -use tracing::{debug, error, error_span, info, warn}; +use tokio::{ + io::AsyncReadExt, + net::{TcpListener, TcpStream}, +}; +use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ dht_utils::{read_metainfo_from_peer_receiver, ReadMetainfoResult}, - peer_connection::PeerConnectionOptions, + peer_connection::{with_timeout, PeerConnectionOptions}, spawn_utils::{spawn, BlockingSpawner}, - torrent_state::{ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState}, + torrent_state::{ + ManagedTorrentBuilder, ManagedTorrentHandle, ManagedTorrentState, TorrentStateLive, + }, }; pub const SUPPORTED_SCHEMES: [&str; 3] = ["http:", "https:", "magnet:"]; @@ -375,6 +383,14 @@ async fn get_public_announce_addr(port: u16) -> anyhow::Result { Ok(addr) } +pub(crate) struct CheckedIncomingConnection { + pub addr: SocketAddr, + pub stream: tokio::net::TcpStream, + pub read_buf: Vec, + pub handshake: Handshake, + pub read_so_far: usize, +} + impl Session { /// Create a new session. The passed in folder will be used as a default unless overriden per torrent. pub async fn new(output_folder: PathBuf) -> anyhow::Result> { @@ -509,14 +525,103 @@ impl Session { Ok(()) } + async fn check_incoming_connection( + &self, + addr: SocketAddr, + mut stream: TcpStream, + ) -> anyhow::Result<(Arc, CheckedIncomingConnection)> { + // TODO: move buffer handling to peer_connection + + let rwtimeout = self + .peer_opts + .read_write_timeout + .unwrap_or_else(|| Duration::from_secs(10)); + + let mut read_buf = vec![0u8; PIECE_MESSAGE_DEFAULT_LEN * 2]; + let mut read_so_far = with_timeout(rwtimeout, stream.read(&mut read_buf)) + .await + .context("error reading handshake")?; + if read_so_far == 0 { + anyhow::bail!("bad handshake"); + } + let (h, size) = Handshake::deserialize(&read_buf[..read_so_far]) + .map_err(|e| anyhow::anyhow!("error deserializing handshake: {:?}", e))?; + + trace!("received handshake from {addr}: {:?}", h); + + for (id, torrent) in self.db.read().torrents.iter() { + if torrent.info_hash().0 != h.info_hash { + continue; + } + + let live = match torrent.live() { + Some(live) => live, + None => { + bail!("torrent {id} is not live, ignoring connection"); + } + }; + + let handshake = h.clone_to_owned(); + + if read_so_far > size { + read_buf.copy_within(size..read_so_far, 0); + } + read_so_far -= size; + + return Ok(( + live, + CheckedIncomingConnection { + addr, + stream, + handshake, + read_buf, + read_so_far, + }, + )); + } + + bail!("didn't find a matching torrent for {:?}", h.info_hash) + } + + fn handover_checked_connection( + &self, + live: Arc, + checked: CheckedIncomingConnection, + ) -> anyhow::Result<()> { + live.add_incoming_peer(checked) + } + async fn task_tcp_listener(self: Arc, l: TcpListener) -> anyhow::Result<()> { - let mut buf = vec![0u8; 4096]; + let mut futs = FuturesUnordered::new(); loop { - let (stream, addr) = l.accept().await.context("error accepting")?; - info!("accepted connection from {addr}"); + tokio::select! { + r = l.accept() => { + match r { + Ok((stream, addr)) => { + trace!("accepted connection from {addr}"); + futs.push( + self.check_incoming_connection(addr, stream) + .map_err(|e| { + error!("error checking incoming connection: {e:#}"); + e + }) + .instrument(error_span!("incoming", addr=%addr)) + ); + } + Err(e) => { + error!("error accepting: {e:#}"); + continue; + } + } + }, + Some(Ok((live, checked))) = futs.next(), if !futs.is_empty() => { + if let Err(e) = self.handover_checked_connection(live, checked) { + warn!("error handing over incoming connection: {e:#}"); + } + }, + } } - Ok(()) } async fn task_upnp_port_forwarder(self: Arc, port: u16) -> anyhow::Result<()> { @@ -562,7 +667,7 @@ impl Session { }); } - fn stop(&self) { + pub fn stop(&self) { let _ = self.cancel_tx.send(()); } diff --git a/crates/librqbit/src/torrent_state/live/mod.rs b/crates/librqbit/src/torrent_state/live/mod.rs index 0276a081..a39f035b 100644 --- a/crates/librqbit/src/torrent_state/live/mod.rs +++ b/crates/librqbit/src/torrent_state/live/mod.rs @@ -89,7 +89,9 @@ use crate::{ peer_connection::{ PeerConnection, PeerConnectionHandler, PeerConnectionOptions, WriterRequest, }, + session::CheckedIncomingConnection, spawn_utils::spawn, + torrent_state::peer::Peer, tracker_comms::{TrackerError, TrackerRequest, TrackerRequestEvent, TrackerResponse}, type_aliases::{PeerHandle, BF}, }; @@ -100,7 +102,7 @@ use self::{ atomic::PeerCountersAtomic as AtomicPeerCounters, snapshot::{PeerStatsFilter, PeerStatsSnapshot}, }, - InflightRequest, PeerState, PeerTx, SendMany, + InflightRequest, PeerRx, PeerState, PeerTx, SendMany, }, peers::PeerStates, stats::{atomic::AtomicStats, snapshot::StatsSnapshot}, @@ -361,7 +363,99 @@ impl TorrentStateLive { } } - async fn task_manage_peer(self: Arc, addr: SocketAddr) -> anyhow::Result<()> { + pub(crate) fn add_incoming_peer( + self: &Arc, + checked_peer: CheckedIncomingConnection, + ) -> anyhow::Result<()> { + use dashmap::mapref::entry::Entry; + let (tx, rx) = unbounded_channel(); + + let counters = match self.peers.states.entry(checked_peer.addr) { + Entry::Occupied(_) => bail!("we are already managing peer {}", checked_peer.addr), + Entry::Vacant(vac) => { + let peer = Peer::new_live_for_incoming_connection( + Id20(checked_peer.handshake.peer_id), + tx.clone(), + ); + let counters = peer.stats.counters.clone(); + vac.insert(peer); + counters + } + }; + + self.spawn( + "incoming peer", + error_span!("manage_incoming_peer", addr = %checked_peer.addr), + self.clone() + .task_manage_incoming_peer(checked_peer, counters, tx, rx), + ); + Ok(()) + } + + async fn task_manage_incoming_peer( + self: Arc, + checked_peer: CheckedIncomingConnection, + counters: Arc, + tx: PeerTx, + rx: PeerRx, + ) -> anyhow::Result<()> { + // TODO: bump counters for incoming + + let handler = PeerHandler { + addr: checked_peer.addr, + on_bitfield_notify: Default::default(), + unchoke_notify: Default::default(), + locked: RwLock::new(PeerHandlerLocked { + i_am_choked: true, + previously_requested_pieces: BF::new(), + }), + requests_sem: Semaphore::new(0), + state: self.clone(), + tx, + counters, + }; + let options = PeerConnectionOptions { + connect_timeout: self.meta.options.peer_connect_timeout, + read_write_timeout: self.meta.options.peer_read_write_timeout, + ..Default::default() + }; + let peer_connection = PeerConnection::new( + checked_peer.addr, + self.meta.info_hash, + self.meta.peer_id, + &handler, + Some(options), + self.meta.spawner, + ); + let requester = handler.task_peer_chunk_requester(checked_peer.addr); + + let res = tokio::select! { + r = requester => {r} + r = peer_connection.manage_peer_incoming( + rx, + checked_peer.read_so_far, + checked_peer.read_buf, + checked_peer.handshake, + checked_peer.stream + ) => {r} + }; + + handler.state.peer_semaphore.add_permits(1); + + match res { + // We disconnected the peer ourselves as we don't need it + Ok(()) => { + handler.on_peer_died(None)?; + } + Err(e) => { + debug!("error managing peer: {:#}", e); + handler.on_peer_died(Some(e))?; + } + }; + Ok(()) + } + + async fn task_manage_outgoing_peer(self: Arc, addr: SocketAddr) -> anyhow::Result<()> { let state = self; let (rx, tx) = state.peers.mark_peer_connecting(addr)?; @@ -440,7 +534,7 @@ impl TorrentStateLive { state.spawn( "manage_peer", error_span!(parent: state.meta.span.clone(), "manage_peer", peer = addr.to_string()), - state.clone().task_manage_peer(addr), + state.clone().task_manage_outgoing_peer(addr), ); } } diff --git a/crates/librqbit/src/torrent_state/live/peer/mod.rs b/crates/librqbit/src/torrent_state/live/peer/mod.rs index 675d7623..b0eee035 100644 --- a/crates/librqbit/src/torrent_state/live/peer/mod.rs +++ b/crates/librqbit/src/torrent_state/live/peer/mod.rs @@ -52,6 +52,15 @@ pub(crate) struct Peer { pub stats: stats::atomic::PeerStats, } +impl Peer { + pub fn new_live_for_incoming_connection(peer_id: Id20, tx: PeerTx) -> Self { + Self { + state: PeerStateNoMut(PeerState::Live(LivePeerState::new(peer_id, tx))), + stats: Default::default(), + } + } +} + #[derive(Debug, Default)] pub(crate) enum PeerState { #[default] diff --git a/crates/peer_binary_protocol/Cargo.toml b/crates/peer_binary_protocol/Cargo.toml index 03dc948b..8261b541 100644 --- a/crates/peer_binary_protocol/Cargo.toml +++ b/crates/peer_binary_protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "librqbit-peer-protocol" -version = "3.2.1" +version = "3.3.0" edition = "2021" description = "Protocol for working with torrent peers. Used in rqbit torrent client." license = "Apache-2.0" diff --git a/crates/peer_binary_protocol/src/lib.rs b/crates/peer_binary_protocol/src/lib.rs index f448b87e..11171f7f 100644 --- a/crates/peer_binary_protocol/src/lib.rs +++ b/crates/peer_binary_protocol/src/lib.rs @@ -5,7 +5,7 @@ pub mod extended; use bincode::Options; -use buffers::{ByteBuf, ByteBufT, ByteString}; +use buffers::{ByteBuf, ByteString}; use byteorder::{ByteOrder, BE}; use clone_to_owned::CloneToOwned; use librqbit_core::{constants::CHUNK_SIZE, id20::Id20, lengths::ChunkInfo}; diff --git a/crates/rqbit/Cargo.toml b/crates/rqbit/Cargo.toml index 1a684579..741a1da7 100644 --- a/crates/rqbit/Cargo.toml +++ b/crates/rqbit/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rqbit" -version = "4.0.0" +version = "4.1.0" authors = ["Igor Katson "] edition = "2021" description = "A bittorrent command line client and server." @@ -23,7 +23,7 @@ default-tls = ["librqbit/default-tls"] rust-tls = ["librqbit/rust-tls"] [dependencies] -librqbit = {path="../librqbit", default-features=false, version = "4.0.0"} +librqbit = {path="../librqbit", default-features=false, version = "4.1.0"} tokio = {version = "1", features = ["macros", "rt-multi-thread"]} console-subscriber = {version = "0.2", optional = true} anyhow = "1"