diff --git a/src/client/message.rs b/src/client/message.rs index 61639e1..1e3f48a 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -1,9 +1,6 @@ use std::fmt::Display; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, -}; +use tokio::{io::AsyncReadExt, net::TcpStream}; pub enum MessageId { Choke = 0, @@ -78,6 +75,40 @@ pub struct SendMessageError { error: String, } +#[derive(Debug)] +pub struct ReceiveMessageError { + error: String, +} + +#[derive(Debug)] +pub enum ReceiveError { + ReceiveError(ReceiveMessageError), + WouldBlock, +} + +pub enum SendError { + SendError(SendMessageError), + WouldBlock, +} + +impl Display for ReceiveError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ReceiveError::ReceiveError(e) => write!(f, "Failed to receive message: {}", e.error), + ReceiveError::WouldBlock => write!(f, "Would block"), + } + } +} + +impl Display for SendError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SendError::SendError(e) => write!(f, "Failed to send message: {}", e.error), + SendError::WouldBlock => write!(f, "Would block"), + } + } +} + impl Display for SendMessageError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -141,32 +172,59 @@ impl Display for Message { } } -pub async fn send_message( - stream: &mut TcpStream, - message: Message, -) -> Result<(), SendMessageError> { - stream.writable().await.unwrap(); - stream - .write_all(&message.serialize()) - .await - .map_err(|e| SendMessageError { - message: message.clone(), - error: format!("Failed to send message: {}", e), - })?; - +pub async fn send_message(stream: &mut TcpStream, message: &Message) -> Result<(), SendError> { + let mut bytes_written = 0; + while bytes_written < message.serialize().len() { + stream.writable().await.unwrap(); + match stream.try_write(&message.serialize()) { + Ok(0) => { + return Err(SendError::SendError(SendMessageError { + message: message.clone(), + error: "EOF".to_string(), + })) + } + Ok(n) => { + bytes_written += n; + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Err(SendError::WouldBlock); + } + Err(e) => { + return Err(SendError::SendError(SendMessageError { + message: message.clone(), + error: format!("Failed to send message: {}", e), + })); + } + }; + } Ok(()) } -pub async fn receive_message(stream: &mut TcpStream) -> Result { +pub async fn receive_message(stream: &mut TcpStream) -> Result { let mut len = [0u8; 4]; - stream - .read_exact(&mut len) - .await - .map_err(|e| SendMessageError { - message: Message::new(MessageId::Choke, &Vec::new()), - error: format!("Failed to read message length: {}", e), - })?; + let mut bytes_read = 0; + while bytes_read < 4 { + stream.readable().await.unwrap(); + match stream.try_read(&mut len[bytes_read..]) { + Ok(0) => { + return Err(ReceiveError::ReceiveError(ReceiveMessageError { + error: "stream was closed".to_string(), + })) + } + Ok(n) => { + bytes_read += n; + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Err(ReceiveError::WouldBlock); + } + Err(e) => { + return Err(ReceiveError::ReceiveError(ReceiveMessageError { + error: format!("Failed to read message length: {}", e), + })); + } + } + } let len = u32::from_be_bytes(len); if len == 0 { return Ok(Message { @@ -178,13 +236,11 @@ pub async fn receive_message(stream: &mut TcpStream) -> Result, stream: Arc>, bitfield: Option, - send_queue: Arc>>, - receive_queue: Arc>>, last_sent: DateTime, am_choking: bool, @@ -118,8 +115,6 @@ impl PeerState { Self { peer_id: peer_id.clone(), stream: Arc::new(Mutex::new(stream)), - send_queue: Arc::new(Mutex::new(VecDeque::new())), - receive_queue: Arc::new(Mutex::new(VecDeque::new())), last_sent: Utc::now(), bitfield: None, @@ -135,6 +130,8 @@ pub struct Client { tracker: Tracker, peers: Arc, PeerState>>>, bitfield: Bitfield, + send_queue: Arc, Message)>>>, + receive_queue: Arc, Message)>>>, } impl Client { @@ -144,96 +141,155 @@ impl Client { tracker, peers: Arc::new(RwLock::new(HashMap::new())), bitfield, + send_queue: Arc::new(Mutex::new(VecDeque::new())), + receive_queue: Arc::new(Mutex::new(VecDeque::new())), } } pub async fn download(&mut self) -> Result<(), ClientError> { - let peer_ids = self.connect_to_peers(30).await?; + self.connect_to_peers(30).await?; - let mut handles = Vec::new(); - for id in peer_ids { - handles.push(self.send_messages(&id).await); - handles.push(self.retrieve_messages(&id)); - } + self.send_messages(); + self.retrieve_messages(); - join_all(handles).await; + let _ = tokio::join!( + self.send_messages(), + self.retrieve_messages(), + self.keep_alive(), + // self.process_messages(), + ); Ok(()) } - fn retrieve_messages(&self, peer_id: &Vec) -> JoinHandle> { + fn keep_alive(&self) -> JoinHandle<()> { let peers = Arc::clone(&self.peers); - let id = peer_id.clone(); - + let send_queue = Arc::clone(&self.send_queue); tokio::spawn(async move { loop { - let id_to_peer = peers.read().await; - let Some(peer) = id_to_peer.get(&id) else { - break; - }; - let mut stream = peer.stream.lock().await; - let Ok(message) = receive_message(&mut stream).await else { - #[cfg(debug_assertions)] - eprintln!( - "Failed to receive message from peer: {}", - String::from_utf8_lossy(&id) - ); - break; - }; - - println!( - "Received \"{}\" message from peer: {}", - &message.get_id(), - String::from_utf8_lossy(&id) - ); + for (peer_id, peer) in peers.read().await.iter() { + if (Utc::now() - peer.last_sent).num_seconds() > 120 { + println!( + "Sending keep alive to peer: {:?}", + String::from_utf8_lossy(peer_id) + ); + send_queue.lock().await.push_back(( + peer_id.clone(), + Message::new(MessageId::KeepAlive, &Vec::new()), + )); + } + } - peer.receive_queue.lock().await.push_back(message); + yield_now().await; } - // peers.write().await.remove(&id); - id }) } - async fn send_messages(&self, peer_id: &Vec) -> JoinHandle> { + fn retrieve_messages(&self) -> JoinHandle<()> { let peers = Arc::clone(&self.peers); - let id = peer_id.clone(); + let receive_queue = Arc::clone(&self.receive_queue); tokio::spawn(async move { + let mut peers_to_remove = Vec::new(); loop { - let id_to_peer = peers.read().await; - let Some(peer) = id_to_peer.get(&id) else { - break; - }; - let message = match peer.send_queue.lock().await.pop_front() { - Some(m) => m, - None => { - if (peer.last_sent - Utc::now()).num_seconds() > 120 { - Message::new(MessageId::KeepAlive, &Vec::new()) - } else { - yield_now().await; + for (peer_id, peer) in peers.read().await.iter() { + let mut stream = peer.stream.lock().await; + // println!( + // "Receiving message from peer: {:?}", + // String::from_utf8_lossy(peer_id) + // ); + + match receive_message(&mut *stream).await { + Ok(message) => { + println!( + "Received \"{}\" message from {}", + message.get_id(), + String::from_utf8_lossy(peer_id) + ); + receive_queue + .lock() + .await + .push_back((peer_id.clone(), message)); + } + Err(ReceiveError::WouldBlock) => { continue; } + Err(e) => { + println!( + "Failed to receive message from peer {:?}: {}", + String::from_utf8_lossy(peer_id), + e.to_string() + ); + peers_to_remove.push(peer_id.clone()); + } + } + } + + for peer_id in &peers_to_remove { + match peers.write().await.remove(peer_id) { + Some(peer) => { + println!( + "Disconnected from peer: {:?}", + String::from_utf8_lossy(&peer.peer_id) + ); + } + None => continue, } + } + } + }) + } + + fn send_messages(&self) -> JoinHandle<()> { + let peers = Arc::clone(&self.peers); + let send_queue = Arc::clone(&self.send_queue); + tokio::spawn(async move { + loop { + let Some((peer_id, message)) = send_queue.lock().await.pop_front() else { + yield_now().await; + continue; }; - println!( - "Sending message {} to peer: {}", - message.get_id(), - String::from_utf8_lossy(&id) - ); - - let mut stream = peer.stream.lock().await; - if let Err(e) = send_message(&mut stream, message).await { - #[cfg(debug_assertions)] - eprintln!( - "Failed to send message to peer {:?}: {}", - String::from_utf8_lossy(&id), - e.to_string() + let send_result = { + let id_to_peer = peers.read().await; + let Some(peer) = id_to_peer.get(&peer_id) else { + // if peer is not found, discard the message + continue; + }; + + let mut stream = peer.stream.lock().await; + println!( + "Sending \"{}\" message from {}", + message.get_id(), + String::from_utf8_lossy(&peer_id) ); - break; + send_message(&mut stream, &message).await + }; + + match send_result { + Ok(()) => { + let mut id_to_peer = peers.write().await; + let peer = id_to_peer.get_mut(&peer_id).unwrap(); + peer.last_sent = Utc::now(); + } + Err(SendError::WouldBlock) => { + send_queue.lock().await.push_back((peer_id, message)); + } + Err(_) => { + println!( + "Failed to send message to peer: {:?}", + String::from_utf8_lossy(&peer_id) + ); + if let Some(peer) = peers.write().await.remove(&peer_id) { + println!( + "Disconnected from peer: {:?}", + String::from_utf8_lossy(&peer.peer_id) + ); + } + } } + + // yield_now().await; } - // peers.write().await.remove(&id); - id }) } @@ -318,12 +374,8 @@ impl Client { Self::validate_handshake(&response, info_hash) } - async fn connect_to_peers( - &mut self, - min_connections: usize, - ) -> Result>, ClientError> { + async fn connect_to_peers(&mut self, min_connections: usize) -> Result<(), ClientError> { println!("Connecting to peers..."); - let mut new_peers = Vec::new(); while self.peers.read().await.len() < min_connections { let mut handles = JoinSet::new(); for peer in @@ -331,9 +383,6 @@ impl Client { ClientError::GetPeersError(format!("Failed to get peers: {}", e)) })? { - if self.peers.read().await.len() >= min_connections { - break; - } let handshake = self.get_handshake()?; let info_hash = self.tracker.get_metainfo().get_info_hash().map_err(|_| { ClientError::GetPeersError(String::from("Failed to get info hash")) @@ -341,6 +390,7 @@ impl Client { let bitfield = self.bitfield.to_bytes(); let peers = Arc::clone(&mut self.peers); + let send_queue = Arc::clone(&self.send_queue); handles.spawn(async move { let mut stream = match timeout( @@ -368,19 +418,20 @@ impl Client { Self::initiate_handshake(&mut stream, &handshake, &info_hash, &peer) .await?; - if peers.read().await.len() >= min_connections { - return Err(ClientError::GetPeersError(String::from( - "Already connected to minimum number of peers", - ))); - } - - let peer_state = PeerState::new(&peer_id, stream); - peer_state - .send_queue - .lock() + // if peers.read().await.len() >= min_connections { + // return Err(ClientError::GetPeersError(String::from( + // "Already connected to minimum number of peers", + // ))); + // } + + send_queue.lock().await.push_back(( + peer_id.clone(), + Message::new(MessageId::Bitfield, &bitfield), + )); + peers + .write() .await - .push_back(Message::new(MessageId::Bitfield, &bitfield)); - peers.write().await.insert(peer_id.clone(), peer_state); + .insert(peer_id.clone(), PeerState::new(&peer_id, stream)); println!("Connected to peer: {:?}", peer.addr); @@ -392,19 +443,15 @@ impl Client { let conection_result = handle.map_err(|e| ClientError::GetPeersError(format!("{}", e)))?; - match conection_result { - Ok(peer_id) => { - new_peers.push(peer_id); - } - Err(e) => { - // #[cfg(debug_assertions)] - // eprintln!("{}", e); - } + if let Err(e) = conection_result { + // #[cfg(debug_assertions)] + // eprintln!("{}", e); } + // println!("{}", handles.len()) } } - println!("Connected to {} new peers", new_peers.len()); - Ok(new_peers) + println!("Connected to {} new peers", self.peers.read().await.len()); + Ok(()) } }