Skip to content

Commit

Permalink
use finer grain locks
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed Jun 3, 2024
1 parent bf841e3 commit 2271365
Showing 1 changed file with 27 additions and 33 deletions.
60 changes: 27 additions & 33 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl Display for ClientError {

struct PeerState {
peer_id: Vec<u8>,
stream: Arc<Mutex<TcpStream>>,
stream: TcpStream,
bitfield: Option<Bitfield>,
last_sent: DateTime<Utc>,

Expand All @@ -114,7 +114,7 @@ impl PeerState {
pub fn new(peer_id: &Vec<u8>, stream: TcpStream) -> Self {
Self {
peer_id: peer_id.clone(),
stream: Arc::new(Mutex::new(stream)),
stream,
last_sent: Utc::now(),

bitfield: None,
Expand All @@ -128,7 +128,7 @@ impl PeerState {

pub struct Client {
tracker: Tracker,
peers: Arc<RwLock<HashMap<Vec<u8>, PeerState>>>,
peers: Arc<RwLock<HashMap<Vec<u8>, Arc<RwLock<PeerState>>>>>,
bitfield: Bitfield,
send_queue: Arc<Mutex<VecDeque<(Vec<u8>, Message)>>>,
receive_queue: Arc<Mutex<VecDeque<(Vec<u8>, Message)>>>,
Expand All @@ -147,10 +147,7 @@ impl Client {
}

pub async fn download(&mut self) -> Result<(), ClientError> {
self.connect_to_peers(30).await?;

self.send_messages();
self.retrieve_messages();
self.connect_to_peers(10).await?;

let _ = tokio::join!(
self.send_messages(),
Expand All @@ -168,7 +165,7 @@ impl Client {
tokio::spawn(async move {
loop {
for (peer_id, peer) in peers.read().await.iter() {
if (Utc::now() - peer.last_sent).num_seconds() > 120 {
if (Utc::now() - peer.read().await.last_sent).num_seconds() > 40 {
println!(
"Sending keep alive to peer: {:?}",
String::from_utf8_lossy(peer_id)
Expand All @@ -192,13 +189,13 @@ impl Client {
let mut peers_to_remove = Vec::new();
loop {
for (peer_id, peer) in peers.read().await.iter() {
let mut stream = peer.stream.lock().await;
let stream = &mut peer.write().await.stream;
// println!(
// "Receiving message from peer: {:?}",
// String::from_utf8_lossy(peer_id)
// );

match receive_message(&mut *stream).await {
match receive_message(stream).await {
Ok(message) => {
println!(
"Received \"{}\" message from {}",
Expand All @@ -225,14 +222,11 @@ impl Client {
}

for peer_id in &peers_to_remove {
match peers.write().await.remove(peer_id) {
Some(peer) => {
println!(
"Disconnected from peer: {:?}",
String::from_utf8_lossy(&peer.peer_id)
);
}
None => continue,
if peers.write().await.remove(peer_id).is_some() {
println!(
"Disconnected from peer: {:?}",
String::from_utf8_lossy(&peer_id)
);
}
}
}
Expand All @@ -256,19 +250,19 @@ impl Client {
continue;
};

let mut stream = peer.stream.lock().await;
let stream = &mut peer.write().await.stream;
println!(
"Sending \"{}\" message from {}",
message.get_id(),
String::from_utf8_lossy(&peer_id)
);
send_message(&mut stream, &message).await
send_message(stream, &message).await
};

match send_result {
Ok(()) => {
let mut id_to_peer = peers.write().await;
let peer = id_to_peer.get_mut(&peer_id).unwrap();
let id_to_peer = peers.read().await;
let mut peer = id_to_peer.get(&peer_id).unwrap().write().await;
peer.last_sent = Utc::now();
}
Err(SendError::WouldBlock) => {
Expand All @@ -279,10 +273,10 @@ impl Client {
"Failed to send message to peer: {:?}",
String::from_utf8_lossy(&peer_id)
);
if let Some(peer) = peers.write().await.remove(&peer_id) {
if peers.write().await.remove(&peer_id).is_some() {
println!(
"Disconnected from peer: {:?}",
String::from_utf8_lossy(&peer.peer_id)
String::from_utf8_lossy(&peer_id)
);
}
}
Expand Down Expand Up @@ -418,20 +412,20 @@ impl Client {
Self::initiate_handshake(&mut stream, &handshake, &info_hash, &peer)
.await?;

// if peers.read().await.len() >= min_connections {
// return Err(ClientError::GetPeersError(String::from(
// "Already connected to minimum number of peers",
// )));
// }
if peers.read().await.len() >= min_connections {
return Err(ClientError::GetPeersError(String::from(
"Already connected to minimum number of peers",
)));
}

send_queue.lock().await.push_back((
peer_id.clone(),
Message::new(MessageId::Bitfield, &bitfield),
));
peers
.write()
.await
.insert(peer_id.clone(), PeerState::new(&peer_id, stream));
peers.write().await.insert(
peer_id.clone(),
Arc::new(RwLock::new(PeerState::new(&peer_id, stream))),
);

println!("Connected to peer: {:?}", peer.addr);

Expand Down

0 comments on commit 2271365

Please sign in to comment.