Skip to content

Commit

Permalink
fix(pagination): prevent infinite loop caused by clone
Browse files Browse the repository at this point in the history
  • Loading branch information
skytz authored and arlyon committed Jan 24, 2024
1 parent 67a8fd4 commit bc20bd4
Showing 1 changed file with 14 additions and 37 deletions.
51 changes: 14 additions & 37 deletions src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ pub trait Paginable {
pub trait PaginableList {
type O: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug;
fn new(data: Vec<Self::O>, url: String, has_more: bool, total_count: Option<u64>) -> Self;
fn get_data(&self) -> Vec<Self::O>;
fn get_data(&mut self) -> &mut Vec<Self::O>;
fn get_url(&self) -> String;
fn get_total_count(&self) -> Option<u64>;
fn has_more(&self) -> bool;
Expand All @@ -191,7 +191,7 @@ pub trait PaginableList {
/// A single page of a cursor-paginated list of a search object.
///
/// For more details, see <https://stripe.com/docs/api/pagination/search>
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SearchList<T> {
pub object: String,
pub url: String,
Expand All @@ -214,19 +214,6 @@ impl<T> Default for SearchList<T> {
}
}

impl<T: Clone> Clone for SearchList<T> {
fn clone(&self) -> Self {
SearchList {
object: self.object.clone(),
data: self.data.clone(),
has_more: self.has_more,
total_count: self.total_count,
url: self.url.clone(),
next_page: self.next_page.clone(),
}
}
}

impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
for SearchList<T>
{
Expand All @@ -241,8 +228,8 @@ impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::
Self { object: "".to_string(), url, has_more, data, next_page: None, total_count }
}

fn get_data(&self) -> Vec<Self::O> {
self.data.clone()
fn get_data(&mut self) -> &mut Vec<Self::O> {
&mut self.data
}
fn get_url(&self) -> String {
self.url.clone()
Expand All @@ -264,8 +251,8 @@ impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::
Self { url, has_more, data, total_count }
}

fn get_data(&self) -> Vec<Self::O> {
self.data.clone()
fn get_data(&mut self) -> &mut Vec<Self::O> {
&mut self.data
}
fn get_url(&self) -> String {
self.url.clone()
Expand All @@ -287,7 +274,7 @@ impl<T> SearchList<T> {
/// A single page of a cursor-paginated list of an object.
///
/// For more details, see <https://stripe.com/docs/api/pagination>
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct List<T> {
pub data: Vec<T>,
pub has_more: bool,
Expand All @@ -301,17 +288,6 @@ impl<T> Default for List<T> {
}
}

impl<T: Clone> Clone for List<T> {
fn clone(&self) -> Self {
List {
data: self.data.clone(),
has_more: self.has_more,
total_count: self.total_count,
url: self.url.clone(),
}
}
}

impl<T> List<T> {
pub fn paginate<P>(self, params: P) -> ListPaginator<List<T>, P> {
ListPaginator { page: self, params }
Expand Down Expand Up @@ -381,7 +357,7 @@ where
/// Requires `feature = ["async", "stream"]`.
#[cfg(all(feature = "async", feature = "stream"))]
pub fn stream(
self,
mut self,
client: &Client,
) -> impl futures_util::Stream<Item = Result<T::O, StripeError>> + Unpin {
// We are going to be popping items off the end of the list, so we need to reverse it.
Expand All @@ -395,7 +371,7 @@ where
async fn unfold_stream(
state: Option<(Self, Client)>,
) -> Option<(Result<T::O, StripeError>, Option<(Self, Client)>)> {
let (paginator, client) = state?; // If none, we sent the last item in the last iteration
let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration

if paginator.page.get_data().len() > 1 {
return Some((Ok(paginator.page.get_data().pop()?), Some((paginator, client))));
Expand All @@ -407,7 +383,7 @@ where
}

match paginator.next(&client).await {
Ok(next_paginator) => {
Ok(mut next_paginator) => {
let data = paginator.page.get_data().pop()?;
next_paginator.page.get_data().reverse();

Expand All @@ -419,10 +395,11 @@ where
}

/// Fetch an additional page of data from stripe.
pub fn next(&self, client: &Client) -> Response<Self> {
pub fn next(&mut self, client: &Client) -> Response<Self> {
let page_url = self.page.get_url();
if let Some(last) = self.page.get_data().last() {
if self.page.get_url().starts_with("/v1/") {
let path = self.page.get_url().trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
if page_url.starts_with("/v1/") {
let path = page_url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed

// clone the params and set the cursor
let params_next = {
Expand Down

0 comments on commit bc20bd4

Please sign in to comment.