Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CookieStore trait and default Jar #1203

Merged
merged 1 commit into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use std::any::Any;
use std::convert::TryInto;
use std::net::IpAddr;
use std::sync::Arc;
#[cfg(feature = "cookies")]
use std::sync::RwLock;
use std::time::Duration;
use std::{fmt, str};

Expand Down Expand Up @@ -105,7 +103,7 @@ struct Config {
local_address: Option<IpAddr>,
nodelay: bool,
#[cfg(feature = "cookies")]
cookie_store: Option<cookie::CookieStore>,
cookie_store: Option<Arc<dyn cookie::CookieStore>>,
trust_dns: bool,
error: Option<crate::Error>,
https_only: bool,
Expand Down Expand Up @@ -350,7 +348,7 @@ impl ClientBuilder {
inner: Arc::new(ClientRef {
accepts: config.accepts,
#[cfg(feature = "cookies")]
cookie_store: config.cookie_store.map(RwLock::new),
cookie_store: config.cookie_store,
hyper: hyper_client,
headers: config.headers,
redirect_policy: config.redirect_policy,
Expand Down Expand Up @@ -464,11 +462,31 @@ impl ClientBuilder {
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookie_store(mut self, enable: bool) -> ClientBuilder {
self.config.cookie_store = if enable {
Some(cookie::CookieStore::default())
if enable {
self.cookie_provider(Arc::new(cookie::Jar::default()))
} else {
None
};
self.config.cookie_store = None;
self
}
}

/// Set the persistent cookie store for the client.
///
/// Cookies received in responses will be passed to this store, and
/// additional requests will query this store for cookies.
///
/// By default, no cookie store is used.
///
/// # Optional
///
/// This requires the optional `cookies` feature to be enabled.
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookie_provider<C: cookie::CookieStore + 'static>(
mut self,
cookie_store: Arc<C>,
) -> ClientBuilder {
self.config.cookie_store = Some(cookie_store as _);
self
}

Expand Down Expand Up @@ -1109,10 +1127,9 @@ impl Client {
// Add cookies from the cookie store.
#[cfg(feature = "cookies")]
{
if let Some(cookie_store_wrapper) = self.inner.cookie_store.as_ref() {
if let Some(cookie_store) = self.inner.cookie_store.as_ref() {
if headers.get(crate::header::COOKIE).is_none() {
let cookie_store = cookie_store_wrapper.read().unwrap();
add_cookie_header(&mut headers, &cookie_store, &url);
add_cookie_header(&mut headers, &**cookie_store, &url);
}
}
}
Expand Down Expand Up @@ -1289,7 +1306,7 @@ impl Config {
struct ClientRef {
accepts: Accepts,
#[cfg(feature = "cookies")]
cookie_store: Option<RwLock<cookie::CookieStore>>,
cookie_store: Option<Arc<dyn cookie::CookieStore>>,
headers: HeaderMap,
hyper: HyperClient,
redirect_policy: redirect::Policy,
Expand Down Expand Up @@ -1431,14 +1448,11 @@ impl Future for PendingRequest {

#[cfg(feature = "cookies")]
{
if let Some(store_wrapper) = self.client.cookie_store.as_ref() {
let mut cookies = cookie::extract_response_cookies(&res.headers())
.filter_map(|res| res.ok())
.map(|cookie| cookie.into_inner().into_owned())
.peekable();
if let Some(ref cookie_store) = self.client.cookie_store {
let mut cookies =
cookie::extract_response_cookie_headers(&res.headers()).peekable();
if cookies.peek().is_some() {
let mut store = store_wrapper.write().unwrap();
store.0.store_response_cookies(cookies, &self.url);
cookie_store.set_cookies(&mut cookies, &self.url);
}
}
}
Expand Down Expand Up @@ -1531,11 +1545,8 @@ impl Future for PendingRequest {
// Add cookies from the cookie store.
#[cfg(feature = "cookies")]
{
if let Some(cookie_store_wrapper) =
self.client.cookie_store.as_ref()
{
let cookie_store = cookie_store_wrapper.read().unwrap();
add_cookie_header(&mut headers, &cookie_store, &self.url);
if let Some(ref cookie_store) = self.client.cookie_store {
add_cookie_header(&mut headers, &**cookie_store, &self.url);
}
}

Expand Down Expand Up @@ -1592,18 +1603,9 @@ fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
}

#[cfg(feature = "cookies")]
fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &cookie::CookieStore, url: &Url) {
let header = cookie_store
.0
.get_request_cookies(url)
.map(|c| format!("{}={}", c.name(), c.value()))
.collect::<Vec<_>>()
.join("; ");
if !header.is_empty() {
headers.insert(
crate::header::COOKIE,
HeaderValue::from_bytes(header.as_bytes()).unwrap(),
);
fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &dyn cookie::CookieStore, url: &Url) {
if let Some(header) = cookie_store.cookies(url) {
headers.insert(crate::header::COOKIE, header);
}
}

Expand Down
19 changes: 19 additions & 0 deletions src/blocking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ impl ClientBuilder {
self.with_inner(|inner| inner.cookie_store(enable))
}

/// Set the persistent cookie store for the client.
///
/// Cookies received in responses will be passed to this store, and
/// additional requests will query this store for cookies.
///
/// By default, no cookie store is used.
///
/// # Optional
///
/// This requires the optional `cookies` feature to be enabled.
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookie_provider<C: crate::cookie::CookieStore + 'static>(
self,
cookie_store: Arc<C>,
) -> ClientBuilder {
self.with_inner(|inner| inner.cookie_provider(cookie_store))
}

/// Enable auto gzip decompression by checking the `Content-Encoding` response header.
///
/// If auto gzip decompresson is turned on:
Expand Down
101 changes: 83 additions & 18 deletions src/cookie.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,43 @@
//! HTTP Cookies

use std::convert::TryInto;

use crate::header;
use std::fmt;
use std::sync::RwLock;
use std::time::SystemTime;

use crate::header::{HeaderValue, SET_COOKIE};
use bytes::Bytes;

/// Actions for a persistent cookie store providing session supprt.
pub trait CookieStore: Send + Sync {
/// Store a set of Set-Cookie header values recevied from `url`
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, url: &url::Url);
/// Get any Cookie values in the store for `url`
fn cookies(&self, url: &url::Url) -> Option<HeaderValue>;
}

/// A single HTTP cookie.
pub struct Cookie<'a>(cookie_crate::Cookie<'a>);

/// A good default `CookieStore` implementation.
///
/// This is the implementation used when simply calling `cookie_store(true)`.
/// This type is exposed to allow creating one and filling it with some
/// existing cookies more easily, before creating a `Client`.
#[derive(Debug, Default)]
pub struct Jar(RwLock<cookie_store::CookieStore>);

// ===== impl Cookie =====

impl<'a> Cookie<'a> {
fn parse(value: &'a crate::header::HeaderValue) -> Result<Cookie<'a>, CookieParseError> {
fn parse(value: &'a HeaderValue) -> Result<Cookie<'a>, CookieParseError> {
std::str::from_utf8(value.as_bytes())
.map_err(cookie_crate::ParseError::from)
.and_then(cookie_crate::Cookie::parse)
.map_err(CookieParseError)
.map(Cookie)
}

pub(crate) fn into_inner(self) -> cookie_crate::Cookie<'a> {
self.0
}

/// The name of the cookie.
pub fn name(&self) -> &str {
self.0.name()
Expand Down Expand Up @@ -82,25 +98,21 @@ impl<'a> fmt::Debug for Cookie<'a> {
}
}

pub(crate) fn extract_response_cookie_headers<'a>(
headers: &'a hyper::HeaderMap,
) -> impl Iterator<Item = &'a HeaderValue> + 'a {
headers.get_all(SET_COOKIE).iter()
}

pub(crate) fn extract_response_cookies<'a>(
headers: &'a hyper::HeaderMap,
) -> impl Iterator<Item = Result<Cookie<'a>, CookieParseError>> + 'a {
headers
.get_all(header::SET_COOKIE)
.get_all(SET_COOKIE)
.iter()
.map(|value| Cookie::parse(value))
}

/// A persistent cookie store that provides session support.
#[derive(Default)]
pub(crate) struct CookieStore(pub(crate) cookie_store::CookieStore);

impl<'a> fmt::Debug for CookieStore {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

/// Error representing a parse failure of a 'Set-Cookie' header.
pub(crate) struct CookieParseError(cookie_crate::ParseError);

Expand All @@ -117,3 +129,56 @@ impl<'a> fmt::Display for CookieParseError {
}

impl std::error::Error for CookieParseError {}

// ===== impl Jar =====

impl Jar {
/// Add a cookie to this jar.
///
/// # Example
///
/// ```
/// use reqwest::{cookie::Jar, Url};
///
/// let cookie = "foo=bar; Domain=yolo.local";
/// let url = "https://yolo.local".parse::<Url>().unwrap();
///
/// let jar = Jar::default();
/// jar.add_cookie_str(cookie, &url);
///
/// // and now add to a `ClientBuilder`?
/// ```
pub fn add_cookie_str(&self, cookie: &str, url: &url::Url) {
let cookies = cookie_crate::Cookie::parse(cookie)
.ok()
.map(|c| c.into_owned())
.into_iter();
self.0.write().unwrap().store_response_cookies(cookies, url);
}
}

impl CookieStore for Jar {
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, url: &url::Url) {
let iter =
cookie_headers.filter_map(|val| Cookie::parse(val).map(|c| c.0.into_owned()).ok());

self.0.write().unwrap().store_response_cookies(iter, url);
}

fn cookies(&self, url: &url::Url) -> Option<HeaderValue> {
let s = self
.0
.read()
.unwrap()
.get_request_cookies(url)
.map(|c| format!("{}={}", c.name(), c.value()))
.collect::<Vec<_>>()
.join("; ");

if s.is_empty() {
return None;
}

HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
}
}