Skip to content

Commit

Permalink
Refactor Socket::write to prevent infinite loop on long writes (esp-r…
Browse files Browse the repository at this point in the history
  • Loading branch information
teotwaki authored and bjoernQ committed May 24, 2024
1 parent fdddaf5 commit 10e8fab
Showing 1 changed file with 20 additions and 43 deletions.
63 changes: 20 additions & 43 deletions esp-wifi/src/wifi_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,16 +447,14 @@ impl<'s, 'n: 's> Read for Socket<'s, 'n> {
impl<'s, 'n: 's> Write for Socket<'s, 'n> {
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
loop {
self.network.with_mut(|interface, device, sockets| {
interface.poll(
Instant::from_millis((self.network.current_millis_fn)() as i64),
device,
sockets,
)
});

let (may_send, is_open, can_send) =
self.network.with_mut(|_interface, _device, sockets| {
self.network.with_mut(|interface, device, sockets| {
interface.poll(
Instant::from_millis((self.network.current_millis_fn)() as i64),
device,
sockets,
);

let socket = sockets.get_mut::<TcpSocket>(self.socket_handle);

(socket.may_send(), socket.is_open(), socket.can_send())
Expand All @@ -466,50 +464,29 @@ impl<'s, 'n: 's> Write for Socket<'s, 'n> {
break;
}

if !is_open {
return Err(IoError::SocketClosed);
}

if !can_send {
if !is_open || !can_send {
return Err(IoError::SocketClosed);
}
}

let mut written = 0;
loop {
let res = self.network.with_mut(|interface, device, sockets| {
interface.poll(
Instant::from_millis((self.network.current_millis_fn)() as i64),
device,
sockets,
)
});
self.flush()?;

if let false = res {
self.network.with_mut(|_interface, _device, sockets| {
sockets
.get_mut::<TcpSocket>(self.socket_handle)
.send_slice(&buf[written..])
.map(|len| written += len)
.map_err(IoError::TcpSendError)
})?;

if written >= buf.len() {
break;
}
}

let res = self.network.with_mut(|_interface, _device, sockets| {
let socket = sockets.get_mut::<TcpSocket>(self.socket_handle);

let mut written = 0;
loop {
match socket.send_slice(&buf[written..]) {
Ok(len) => {
written += len;

if written >= buf.len() {
break Ok(written);
}

log::info!("not fully written: {}", len);
}
Err(err) => break Err(IoError::TcpSendError(err)),
}
}
});

res
Ok(written)
}

fn flush(&mut self) -> Result<(), Self::Error> {
Expand Down

0 comments on commit 10e8fab

Please sign in to comment.