Skip to content

Commit

Permalink
Merge pull request #414 from hzargar2/searchable
Browse files Browse the repository at this point in the history
Added Search API functionality
  • Loading branch information
arlyon authored Oct 31, 2023
2 parents c568330 + 0298de5 commit d9d4ecd
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 41 deletions.
170 changes: 135 additions & 35 deletions src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,109 @@ pub trait Paginable {
fn set_last(&mut self, item: Self::O);
}

#[derive(Debug)]
pub struct ListPaginator<T, P> {
pub page: List<T>,
pub params: P,
pub trait PaginableList {
type O: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug;
fn new(data: Vec<Self::O>, url: String, has_more: bool, total_count: Option<u64>) -> Self;
fn get_data(&self) -> Vec<Self::O>;
fn get_url(&self) -> String;
fn get_total_count(&self) -> Option<u64>;
fn has_more(&self) -> bool;
}

/// A single page of a cursor-paginated list of a search object.
///
/// For more details, see <https://stripe.com/docs/api/pagination/search>
#[derive(Debug, Deserialize, Serialize)]
pub struct SearchList<T> {
pub object: String,
pub url: String,
pub has_more: bool,
pub data: Vec<T>,
pub next_page: Option<String>,
pub total_count: Option<u64>,
}

impl<T> Default for SearchList<T> {
fn default() -> Self {
SearchList {
object: String::new(),
data: Vec::new(),
has_more: false,
total_count: None,
url: String::new(),
next_page: None,
}
}
}

impl<T: Clone> Clone for SearchList<T> {
fn clone(&self) -> Self {
SearchList {
object: self.object.clone(),
data: self.data.clone(),
has_more: self.has_more,
total_count: self.total_count,
url: self.url.clone(),
next_page: self.next_page.clone(),
}
}
}

impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
for SearchList<T>
{
type O = T;

fn new(
data: Vec<Self::O>,
url: String,
has_more: bool,
total_count: Option<u64>,
) -> SearchList<T> {
Self { object: "".to_string(), url, has_more, data, next_page: None, total_count }
}

fn get_data(&self) -> Vec<Self::O> {
self.data.clone()
}
fn get_url(&self) -> String {
self.url.clone()
}
fn get_total_count(&self) -> Option<u64> {
self.total_count
}
fn has_more(&self) -> bool {
self.has_more
}
}

impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
for List<T>
{
type O = T;

fn new(data: Vec<Self::O>, url: String, has_more: bool, total_count: Option<u64>) -> List<T> {
Self { url, has_more, data, total_count }
}

fn get_data(&self) -> Vec<Self::O> {
self.data.clone()
}
fn get_url(&self) -> String {
self.url.clone()
}
fn get_total_count(&self) -> Option<u64> {
self.total_count
}
fn has_more(&self) -> bool {
self.has_more
}
}

impl<T> SearchList<T> {
pub fn paginate<P>(self, params: P) -> ListPaginator<SearchList<T>, P> {
ListPaginator { page: self, params }
}
}

/// A single page of a cursor-paginated list of an object.
Expand Down Expand Up @@ -214,33 +313,39 @@ impl<T: Clone> Clone for List<T> {
}

impl<T> List<T> {
pub fn paginate<P>(self, params: P) -> ListPaginator<T, P> {
pub fn paginate<P>(self, params: P) -> ListPaginator<List<T>, P> {
ListPaginator { page: self, params }
}
}

#[derive(Debug)]
pub struct ListPaginator<T, P> {
pub page: T,
pub params: P,
}

impl<
T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug,
T: PaginableList + Send + DeserializeOwned + 'static,
P: Clone + Serialize + Send + 'static + std::fmt::Debug,
> ListPaginator<T, P>
where
P: Paginable<O = T>,
P: Paginable<O = T::O>,
{
/// Repeatedly queries Stripe for more data until all elements in list are fetched, using
/// Stripe's default page size.
///
/// Requires `feature = "blocking"`.
#[cfg(feature = "blocking")]
pub fn get_all(self, client: &Client) -> Response<Vec<T>> {
let mut data = Vec::with_capacity(self.page.total_count.unwrap_or(0) as usize);
pub fn get_all(self, client: &Client) -> Response<Vec<T::O>> {
let mut data = Vec::with_capacity(self.page.get_total_count().unwrap_or(0) as usize);
let mut paginator = self;
loop {
if !paginator.page.has_more {
data.extend(paginator.page.data.into_iter());
if !paginator.page.has_more() {
data.extend(paginator.page.get_data().into_iter());
break;
}
let next_paginator = paginator.next(client)?;
data.extend(paginator.page.data.into_iter());
data.extend(paginator.page.get_data().into_iter());
paginator = next_paginator
}
Ok(data)
Expand Down Expand Up @@ -276,11 +381,11 @@ where
/// Requires `feature = ["async", "stream"]`.
#[cfg(all(feature = "async", feature = "stream"))]
pub fn stream(
mut self,
self,
client: &Client,
) -> impl futures_util::Stream<Item = Result<T, StripeError>> + Unpin {
) -> impl futures_util::Stream<Item = Result<T::O, StripeError>> + Unpin {
// We are going to be popping items off the end of the list, so we need to reverse it.
self.page.data.reverse();
self.page.get_data().reverse();

Box::pin(futures_util::stream::unfold(Some((self, client.clone())), Self::unfold_stream))
}
Expand All @@ -289,22 +394,22 @@ where
#[cfg(all(feature = "async", feature = "stream"))]
async fn unfold_stream(
state: Option<(Self, Client)>,
) -> Option<(Result<T, StripeError>, Option<(Self, Client)>)> {
let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration
) -> Option<(Result<T::O, StripeError>, Option<(Self, Client)>)> {
let (paginator, client) = state?; // If none, we sent the last item in the last iteration

if paginator.page.data.len() > 1 {
return Some((Ok(paginator.page.data.pop()?), Some((paginator, client))));
if paginator.page.get_data().len() > 1 {
return Some((Ok(paginator.page.get_data().pop()?), Some((paginator, client))));
// We have more data on this page
}

if !paginator.page.has_more {
return Some((Ok(paginator.page.data.pop()?), None)); // Final value of the stream, no errors
if !paginator.page.has_more() {
return Some((Ok(paginator.page.get_data().pop()?), None)); // Final value of the stream, no errors
}

match paginator.next(&client).await {
Ok(mut next_paginator) => {
let data = paginator.page.data.pop()?;
next_paginator.page.data.reverse();
Ok(next_paginator) => {
let data = paginator.page.get_data().pop()?;
next_paginator.page.get_data().reverse();

// Yield last value of thimuts page, the next page (and client) becomes the state
Some((Ok(data), Some((next_paginator, client))))
Expand All @@ -315,9 +420,9 @@ where

/// Fetch an additional page of data from stripe.
pub fn next(&self, client: &Client) -> Response<Self> {
if let Some(last) = self.page.data.last() {
if self.page.url.starts_with("/v1/") {
let path = self.page.url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
if let Some(last) = self.page.get_data().last() {
if self.page.get_url().starts_with("/v1/") {
let path = self.page.get_url().trim_start_matches("/v1/").to_string(); // the url we get back is prefixed

// clone the params and set the cursor
let params_next = {
Expand All @@ -334,12 +439,7 @@ where
}
} else {
ok(ListPaginator {
page: List {
data: Vec::new(),
has_more: false,
total_count: self.page.total_count,
url: self.page.url.clone(),
},
page: T::new(Vec::new(), self.page.get_url(), false, self.page.get_total_count()),
params: self.params.clone(),
})
}
Expand All @@ -348,13 +448,13 @@ where
/// Pin a new future which maps the result inside the page future into
/// a ListPaginator
#[cfg(feature = "async")]
fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
fn create_paginator(page: Response<T>, params: P) -> Response<Self> {
use futures_util::FutureExt;
Box::pin(page.map(|page| page.map(|page| ListPaginator { page, params })))
}

#[cfg(feature = "blocking")]
fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
fn create_paginator(page: Response<T>, params: P) -> Response<Self> {
page.map(|page| ListPaginator { page, params })
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/resources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ mod billing {
pub mod usage_record_ext;
}

#[path = "resources"]
#[cfg(feature = "products")]
mod products {
pub mod price_ext;
pub mod product_ext;
}

#[path = "resources"]
#[cfg(feature = "checkout")]
mod checkout {
Expand Down
25 changes: 24 additions & 1 deletion src/resources/charge_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};

use crate::client::{Client, Response};
use crate::ids::{AccountId, BankAccountId, CardId, ChargeId, SourceId, TokenId};
use crate::params::Object;
use crate::params::{Object, SearchList};
use crate::resources::{Charge, Rule};

/// The set of PaymentSource parameters that can be used to create a charge.
Expand Down Expand Up @@ -44,6 +44,13 @@ impl Charge {
) -> Response<Charge> {
client.post_form(&format!("/charges/{}/capture", charge_id), params)
}

/// Searches for a charge.
///
/// For more details see <https://stripe.com/docs/api/charges/search>.
pub fn search(client: &Client, params: ChargeSearchParams) -> Response<SearchList<Charge>> {
client.get_query("/charges/search", params)
}
}

impl Object for Rule {
Expand All @@ -55,3 +62,19 @@ impl Object for Rule {
""
}
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct ChargeSearchParams<'a> {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
pub expand: &'a [&'a str],
}

impl<'a> ChargeSearchParams<'a> {
pub fn new() -> ChargeSearchParams<'a> {
ChargeSearchParams { query: String::new(), limit: None, page: None, expand: &[] }
}
}
25 changes: 24 additions & 1 deletion src/resources/customer_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};

use crate::client::{Client, Response};
use crate::ids::{BankAccountId, CardId, CustomerId, PaymentSourceId};
use crate::params::{Deleted, Expand, List};
use crate::params::{Deleted, Expand, List, SearchList};
use crate::resources::{
BankAccount, Customer, PaymentMethod, PaymentSource, PaymentSourceParams, Source,
};
Expand Down Expand Up @@ -70,6 +70,22 @@ pub enum CustomerPaymentMethodRetrievalType {
WechatPay,
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct CustomerSearchParams<'a> {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
pub expand: &'a [&'a str],
}

impl<'a> CustomerSearchParams<'a> {
pub fn new() -> CustomerSearchParams<'a> {
CustomerSearchParams { query: String::new(), limit: None, page: None, expand: &[] }
}
}

impl Customer {
/// Attaches a source to a customer, does not change default Source for the Customer
///
Expand Down Expand Up @@ -132,6 +148,13 @@ impl Customer {
) -> Response<List<PaymentMethod>> {
client.get_query(&format!("/customers/{}/payment_methods", customer_id), &params)
}

/// Searches for a customer.
///
/// For more details see <https://stripe.com/docs/api/customers/search>.
pub fn search(client: &Client, params: CustomerSearchParams) -> Response<SearchList<Customer>> {
client.get_query("/customers/search", params)
}
}

/// The set of parameters that can be used when verifying a Bank Account.
Expand Down
25 changes: 24 additions & 1 deletion src/resources/invoice_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::Serialize;

use crate::client::{Client, Response};
use crate::ids::{CouponId, CustomerId, InvoiceId, PlanId, SubscriptionId, SubscriptionItemId};
use crate::params::{Metadata, Timestamp};
use crate::params::{Metadata, SearchList, Timestamp};
use crate::resources::{CollectionMethod, Invoice};

#[deprecated(since = "0.12.0")]
Expand All @@ -22,6 +22,13 @@ impl Invoice {
pub fn pay(client: &Client, invoice_id: &InvoiceId) -> Response<Invoice> {
client.post(&format!("/invoices/{}/pay", invoice_id))
}

/// Searches for an invoice.
///
/// For more details see <https://stripe.com/docs/api/invoices/search>.
pub fn search(client: &Client, params: InvoiceSearchParams) -> Response<SearchList<Invoice>> {
client.get_query("/invoices/search", params)
}
}

#[derive(Clone, Debug, Serialize)]
Expand Down Expand Up @@ -71,3 +78,19 @@ pub struct SubscriptionItemFilter {
#[serde(skip_serializing_if = "Option::is_none")]
pub quantity: Option<u64>,
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct InvoiceSearchParams<'a> {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
pub expand: &'a [&'a str],
}

impl<'a> InvoiceSearchParams<'a> {
pub fn new() -> InvoiceSearchParams<'a> {
InvoiceSearchParams { query: String::new(), limit: None, page: None, expand: &[] }
}
}
Loading

0 comments on commit d9d4ecd

Please sign in to comment.