Skip to content

Commit

Permalink
attempt to fix contention issues
Browse files Browse the repository at this point in the history
  • Loading branch information
menghaoyu2002 committed Jun 3, 2024
1 parent aedbbe4 commit bf841e3
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 128 deletions.
118 changes: 87 additions & 31 deletions src/client/message.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use std::fmt::Display;

use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tokio::{io::AsyncReadExt, net::TcpStream};

pub enum MessageId {
Choke = 0,
Expand Down Expand Up @@ -78,6 +75,40 @@ pub struct SendMessageError {
error: String,
}

#[derive(Debug)]
pub struct ReceiveMessageError {
error: String,
}

#[derive(Debug)]
pub enum ReceiveError {
ReceiveError(ReceiveMessageError),
WouldBlock,
}

pub enum SendError {
SendError(SendMessageError),
WouldBlock,
}

impl Display for ReceiveError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReceiveError::ReceiveError(e) => write!(f, "Failed to receive message: {}", e.error),
ReceiveError::WouldBlock => write!(f, "Would block"),
}
}
}

impl Display for SendError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SendError::SendError(e) => write!(f, "Failed to send message: {}", e.error),
SendError::WouldBlock => write!(f, "Would block"),
}
}
}

impl Display for SendMessageError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down Expand Up @@ -141,32 +172,59 @@ impl Display for Message {
}
}

pub async fn send_message(
stream: &mut TcpStream,
message: Message,
) -> Result<(), SendMessageError> {
stream.writable().await.unwrap();
stream
.write_all(&message.serialize())
.await
.map_err(|e| SendMessageError {
message: message.clone(),
error: format!("Failed to send message: {}", e),
})?;

pub async fn send_message(stream: &mut TcpStream, message: &Message) -> Result<(), SendError> {
let mut bytes_written = 0;
while bytes_written < message.serialize().len() {
stream.writable().await.unwrap();
match stream.try_write(&message.serialize()) {
Ok(0) => {
return Err(SendError::SendError(SendMessageError {
message: message.clone(),
error: "EOF".to_string(),
}))
}
Ok(n) => {
bytes_written += n;
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
return Err(SendError::WouldBlock);
}
Err(e) => {
return Err(SendError::SendError(SendMessageError {
message: message.clone(),
error: format!("Failed to send message: {}", e),
}));
}
};
}
Ok(())
}

pub async fn receive_message(stream: &mut TcpStream) -> Result<Message, SendMessageError> {
pub async fn receive_message(stream: &mut TcpStream) -> Result<Message, ReceiveError> {
let mut len = [0u8; 4];
stream
.read_exact(&mut len)
.await
.map_err(|e| SendMessageError {
message: Message::new(MessageId::Choke, &Vec::new()),
error: format!("Failed to read message length: {}", e),
})?;

let mut bytes_read = 0;
while bytes_read < 4 {
stream.readable().await.unwrap();
match stream.try_read(&mut len[bytes_read..]) {
Ok(0) => {
return Err(ReceiveError::ReceiveError(ReceiveMessageError {
error: "stream was closed".to_string(),
}))
}
Ok(n) => {
bytes_read += n;
}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
return Err(ReceiveError::WouldBlock);
}
Err(e) => {
return Err(ReceiveError::ReceiveError(ReceiveMessageError {
error: format!("Failed to read message length: {}", e),
}));
}
}
}
let len = u32::from_be_bytes(len);
if len == 0 {
return Ok(Message {
Expand All @@ -178,13 +236,11 @@ pub async fn receive_message(stream: &mut TcpStream) -> Result<Message, SendMess

let mut message = vec![0u8; len as usize];
stream.readable().await.unwrap();
stream
.read_exact(&mut message)
.await
.map_err(|e| SendMessageError {
message: Message::new(MessageId::Choke, &Vec::new()),
stream.read_exact(&mut message).await.map_err(|e| {
ReceiveError::ReceiveError(ReceiveMessageError {
error: format!("Failed to read message: {}", e),
})?;
})
})?;

let id = message[0];
let payload = message[1..].to_vec();
Expand Down
Loading

0 comments on commit bf841e3

Please sign in to comment.