diff --git a/src/simple_http.rs b/src/simple_http.rs index fada349d..6ec5f4ff 100644 --- a/src/simple_http.rs +++ b/src/simple_http.rs @@ -5,7 +5,7 @@ #[cfg(feature = "proxy")] use socks::Socks5Stream; -use std::io::{BufRead, BufReader, BufWriter, Read, Write}; +use std::io::{BufRead, BufReader, Read, Write}; #[cfg(not(fuzzing))] use std::net::TcpStream; use std::net::{SocketAddr, ToSocketAddrs}; @@ -175,45 +175,45 @@ impl SimpleHttpTransport { // Serialize the body first so we can set the Content-Length header. let body = serde_json::to_vec(&req)?; - // Send HTTP request - { - let mut write_sock = BufWriter::new(sock.get_mut()); - // When we write to a socket, it may have died but we do not detect it. In this case we - // want to detect this ASAP and reconnect. We do this by writing the literal text POST - // in two pieces and checking for error returns on either one, and retrying in this - // case. - // - // From http://www.softlab.ntua.gr/facilities/documentation/unix/unix-socket-faq/unix-socket-faq-2.html - // "If the peer calls close() or exits...I would expect EPIPE, not on the next call, - // but the one after." - if write_sock.write_all(b"PO").is_err() || write_sock.write_all(b"ST ").is_err() { - **write_sock.get_mut() = self.fresh_socket()?; - write_sock.write_all(b"POST ")?; - } - write_sock.write_all(self.path.as_bytes())?; - write_sock.write_all(b" HTTP/1.1\r\n")?; - // Write headers - write_sock.write_all(b"host: ")?; - write_sock.write_all(self.addr.to_string().as_bytes())?; - write_sock.write_all(b"\r\n")?; - write_sock.write_all(b"Content-Type: application/json\r\n")?; - write_sock.write_all(b"Content-Length: ")?; - write_sock.write_all(body.len().to_string().as_bytes())?; - write_sock.write_all(b"\r\n")?; - if let Some(ref auth) = self.basic_auth { - write_sock.write_all(b"Authorization: ")?; - write_sock.write_all(auth.as_ref())?; - write_sock.write_all(b"\r\n")?; - } - // Write body - write_sock.write_all(b"\r\n")?; - write_sock.write_all(&body)?; - write_sock.flush()?; + let mut request_bytes = Vec::new(); + + request_bytes.write_all(b"POST ")?; + request_bytes.write_all(self.path.as_bytes())?; + request_bytes.write_all(b" HTTP/1.1\r\n")?; + // Write headers + request_bytes.write_all(b"host: ")?; + request_bytes.write_all(self.addr.to_string().as_bytes())?; + request_bytes.write_all(b"\r\n")?; + request_bytes.write_all(b"Content-Type: application/json\r\n")?; + request_bytes.write_all(b"Content-Length: ")?; + request_bytes.write_all(body.len().to_string().as_bytes())?; + request_bytes.write_all(b"\r\n")?; + if let Some(ref auth) = self.basic_auth { + request_bytes.write_all(b"Authorization: ")?; + request_bytes.write_all(auth.as_ref())?; + request_bytes.write_all(b"\r\n")?; } + // Write body + request_bytes.write_all(b"\r\n")?; + request_bytes.write_all(&body)?; + + // Send HTTP request + sock.get_mut().write_all(request_bytes.as_slice())?; + sock.get_mut().flush()?; // Parse first HTTP response header line let mut header_buf = String::new(); sock.read_line(&mut header_buf)?; + + // This indicates the socket is broken so lets retry the send once with a fresh socket + if header_buf.is_empty() { + *sock.get_mut() = self.fresh_socket()?; + sock.get_mut().write_all(request_bytes.as_slice())?; + sock.get_mut().flush()?; + + sock.read_line(&mut header_buf)?; + } + if header_buf.len() < 12 { return Err(Error::HttpResponseTooShort { actual: header_buf.len(), needed: 12 }); } @@ -622,9 +622,11 @@ impl crate::Client { #[cfg(test)] mod tests { - use std::net; + use serde_json::{Number, Value}; + use std::net::{Shutdown, TcpListener}; #[cfg(feature = "proxy")] use std::str::FromStr; + use std::{net, thread}; use super::*; use crate::Client; @@ -725,4 +727,57 @@ mod tests { ) .unwrap(); } + + /// Test that the client will detect that a socket is closed and open a fresh one before sending + /// the request + #[cfg(not(feature = "proxy"))] + #[test] + fn request_to_closed_socket() { + thread::spawn(move || { + let server = TcpListener::bind("localhost:2222").expect("Binding a Tcp Listener"); + + for (request_id, stream) in server.incoming().enumerate() { + let mut stream = stream.unwrap(); + + let buf_reader = BufReader::new(&mut stream); + + let _http_request: Vec<_> = buf_reader + .lines() + .map(|result| result.unwrap()) + .take_while(|line| !line.is_empty()) + .collect(); + + let response = Response { + result: None, + error: None, + id: Value::Number(Number::from(request_id)), + jsonrpc: Some(String::from("2.0")), + }; + let response_str = serde_json::to_string(&response).unwrap(); + + stream.write_all(b"HTTP/1.1 200\r\n").unwrap(); + stream.write_all(b"Content-Length: ").unwrap(); + stream.write_all(response_str.len().to_string().as_bytes()).unwrap(); + stream.write_all(b"\r\n").unwrap(); + stream.write_all(b"\r\n").unwrap(); + stream.write_all(response_str.as_bytes()).unwrap(); + stream.flush().unwrap(); + + stream.shutdown(Shutdown::Both).unwrap(); + } + }); + + // Give the server thread a second to start up and listen + thread::sleep(Duration::from_secs(1)); + + let client = Client::simple_http("localhost:2222", None, None).unwrap(); + let request = client.build_request("test_request", &[]); + let result = client.send_request(request).unwrap(); + assert_eq!(result.id, Value::Number(Number::from(0))); + thread::sleep(Duration::from_secs(1)); + let request = client.build_request("test_request2", &[]); + let result2 = client.send_request(request) + .expect("This second request should not be an Err like `Err(Transport(HttpResponseTooShort { actual: 0, needed: 12 }))`"); + assert_eq!(result2.id, Value::Number(Number::from(1))); + } }