diff --git a/Cargo.toml b/Cargo.toml index dda46fde7..4d8d8ed00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ serde = "0.9" serde_json = "0.9" serde_urlencoded = "0.4" url = "1.2" +libflate = "0.1.3" [dev-dependencies] env_logger = "0.3" diff --git a/src/client.rs b/src/client.rs index e5fea00f5..42b994eaf 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,10 @@ use std::fmt; use std::io::{self, Read}; use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, Ordering}; use hyper::client::IntoUrl; -use hyper::header::{Headers, ContentType, Location, Referer, UserAgent, Accept}; +use hyper::header::{Headers, ContentType, Location, Referer, UserAgent, Accept, ContentEncoding, Encoding, ContentLength}; use hyper::method::Method; use hyper::status::StatusCode; use hyper::version::HttpVersion; @@ -38,10 +39,16 @@ impl Client { inner: Arc::new(ClientRef { hyper: client, redirect_policy: Mutex::new(RedirectPolicy::default()), + auto_ungzip: AtomicBool::new(true), }), }) } + /// Enable auto gzip decompression by checking the ContentEncoding response header. + pub fn gzip(&mut self, enable: bool) { + self.inner.auto_ungzip.store(enable, Ordering::Relaxed); + } + /// Set a `RedirectPolicy` for this client. pub fn redirect(&mut self, policy: RedirectPolicy) { *self.inner.redirect_policy.lock().unwrap() = policy; @@ -94,6 +101,7 @@ impl fmt::Debug for Client { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Client") .field("redirect_policy", &self.inner.redirect_policy) + .field("auto_ungzip", &self.inner.auto_ungzip) .finish() } } @@ -101,6 +109,7 @@ impl fmt::Debug for Client { struct ClientRef { hyper: ::hyper::Client, redirect_policy: Mutex<RedirectPolicy>, + auto_ungzip: AtomicBool, } fn new_hyper_client() -> ::Result<::hyper::Client> { @@ -268,7 +277,7 @@ impl RequestBuilder { loc } else { return Ok(Response { - inner: res + inner: Decoder::from_hyper_response(res, client.auto_ungzip.load(Ordering::Relaxed)) }); } }; @@ -282,14 +291,14 @@ impl RequestBuilder { } else { debug!("redirect_policy disallowed redirection to '{}'", loc); return Ok(Response { - inner: res + inner: Decoder::from_hyper_response(res, client.auto_ungzip.load(Ordering::Relaxed)) }) } }, Err(e) => { debug!("Location header had invalid URI: {:?}", e); return Ok(Response { - inner: res + inner: Decoder::from_hyper_response(res, client.auto_ungzip.load(Ordering::Relaxed)) }) } }; @@ -299,7 +308,7 @@ impl RequestBuilder { //TODO: removeSensitiveHeaders(&mut headers, &url); } else { return Ok(Response { - inner: res + inner: Decoder::from_hyper_response(res, client.auto_ungzip.load(Ordering::Relaxed)) }); } } @@ -318,26 +327,56 @@ impl fmt::Debug for RequestBuilder { /// A Response to a submitted `Request`. pub struct Response { - inner: ::hyper::client::Response, + inner: Decoder, +} + +impl fmt::Debug for Response { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + return match &self.inner { + &Decoder::PlainText(ref hyper_response) => { + f.debug_struct("Response") + .field("status", &hyper_response.status) + .field("headers", &hyper_response.headers) + .field("version", &hyper_response.version) + .finish() + }, + &Decoder::Gzip{ref status, ref version, ref headers, ..} => { + f.debug_struct("Response") + .field("status", &status) + .field("headers", &headers) + .field("version", &version) + .finish() + } + } + } } impl Response { /// Get the `StatusCode`. #[inline] pub fn status(&self) -> &StatusCode { - &self.inner.status + match &self.inner { + &Decoder::PlainText(ref hyper_response) => &hyper_response.status, + &Decoder::Gzip{ref status, ..} => status + } } /// Get the `Headers`. #[inline] pub fn headers(&self) -> &Headers { - &self.inner.headers + match &self.inner { + &Decoder::PlainText(ref hyper_response) => &hyper_response.headers, + &Decoder::Gzip{ref headers, ..} => headers + } } /// Get the `HttpVersion`. #[inline] pub fn version(&self) -> &HttpVersion { - &self.inner.version + match &self.inner { + &Decoder::PlainText(ref hyper_response) => &hyper_response.version, + &Decoder::Gzip{ref version, ..} => version + } } /// Try and deserialize the response body as JSON. @@ -347,6 +386,72 @@ impl Response { } } +enum Decoder { + /// A `PlainText` decoder just returns the response content as is. + PlainText(::hyper::client::Response), + /// A `Gzip` decoder will uncompress the gziped response content before returning it. + Gzip { + decoder: ::libflate::gzip::Decoder<::hyper::client::Response>, + headers: ::hyper::header::Headers, + version: ::hyper::version::HttpVersion, + status: ::hyper::status::StatusCode, + } +} + +impl Decoder { + /// Constructs a Decoder from a hyper request. + /// + /// A decoder is just a wrapper around the hyper request that knows + /// how to decode the content body of the request. + /// + /// Uses the correct variant by inspecting the Content-Encoding header. + fn from_hyper_response(res: ::hyper::client::Response, check_gzip: bool) -> Self { + if !check_gzip { + return Decoder::PlainText(res); + } + + let mut is_gzip = false; + match res.headers.get::<ContentEncoding>() { + Some(encoding_types) => { + if encoding_types.contains(&Encoding::Gzip) { + is_gzip = true; + } + if let Some(content_length) = res.headers.get::<ContentLength>() { + if content_length.0 == 0 { + warn!("GZipped response with content-length of 0"); + is_gzip = false; + } + } + } + _ => {} + } + + if is_gzip { + return Decoder::Gzip { + status: res.status.clone(), + version: res.version.clone(), + headers: res.headers.clone(), + decoder: ::libflate::gzip::Decoder::new(res).unwrap(), + }; + } else { + return Decoder::PlainText(res); + } + } +} + +impl Read for Decoder { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + match self { + &mut Decoder::PlainText(ref mut hyper_response) => { + hyper_response.read(buf) + }, + &mut Decoder::Gzip{ref mut decoder, ..} => { + decoder.read(buf) + } + } + } +} + /// Read the body of the Response. impl Read for Response { #[inline] @@ -355,16 +460,6 @@ impl Read for Response { } } -impl fmt::Debug for Response { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Response") - .field("status", self.status()) - .field("headers", self.headers()) - .field("version", self.version()) - .finish() - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 89a5bcb74..72c065a46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,7 @@ extern crate hyper; #[macro_use] extern crate log; +extern crate libflate; extern crate hyper_native_tls; extern crate serde; extern crate serde_json; diff --git a/tests/client.rs b/tests/client.rs index 875a4c352..15d53a6b6 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,8 +1,10 @@ extern crate reqwest; +extern crate libflate; #[macro_use] mod server; use std::io::Read; +use std::io::prelude::*; #[test] fn test_get() { @@ -248,3 +250,44 @@ fn test_accept_header_is_not_changed_if_set() { assert_eq!(res.status(), &reqwest::StatusCode::Ok); } + +#[test] +fn test_gzip_response() { + let mut encoder = ::libflate::gzip::Encoder::new(Vec::new()).unwrap(); + match encoder.write(b"test request") { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to gzip encode string.") + }; + + let gzipped_content = encoder.finish().into_result().unwrap(); + + let mut response = format!("\ + HTTP/1.1 200 OK\r\n\ + Server: test-accept\r\n\ + Content-Encoding: gzip\r\n\ + Content-Length: {}\r\n\ + \r\n", &gzipped_content.len()) + .into_bytes(); + response.extend(&gzipped_content); + + let server = server! { + request: b"\ + GET /gzip HTTP/1.1\r\n\ + Host: $HOST\r\n\ + User-Agent: $USERAGENT\r\n\ + Accept: */*\r\n\ + \r\n\ + ", + response: response + }; + let mut res = reqwest::get(&format!("http://{}/gzip", server.addr())) + .unwrap(); + + let mut body = ::std::string::String::new(); + match res.read_to_string(&mut body) { + Ok(n) => assert!(n > 0, "Failed to write to buffer."), + _ => panic!("Failed to write to buffer.") + }; + + assert_eq!(body, "test request"); +} \ No newline at end of file