Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Search API functionality #414

Merged
merged 16 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 135 additions & 35 deletions src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,109 @@ pub trait Paginable {
fn set_last(&mut self, item: Self::O);
}

#[derive(Debug)]
pub struct ListPaginator<T, P> {
pub page: List<T>,
pub params: P,
pub trait PaginableList {
type O: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug;
fn new(data: Vec<Self::O>, url: String, has_more: bool, total_count: Option<u64>) -> Self;
fn get_data(&self) -> Vec<Self::O>;
fn get_url(&self) -> String;
fn get_total_count(&self) -> Option<u64>;
fn has_more(&self) -> bool;
}

/// A single page of a cursor-paginated list of a search object.
///
/// For more details, see <https://stripe.com/docs/api/pagination/search>
#[derive(Debug, Deserialize, Serialize)]
pub struct SearchList<T> {
pub object: String,
pub url: String,
pub has_more: bool,
pub data: Vec<T>,
pub next_page: Option<String>,
pub total_count: Option<u64>,
}

impl<T> Default for SearchList<T> {
fn default() -> Self {
SearchList {
object: String::new(),
data: Vec::new(),
has_more: false,
total_count: None,
url: String::new(),
next_page: None,
}
}
}

impl<T: Clone> Clone for SearchList<T> {
fn clone(&self) -> Self {
SearchList {
object: self.object.clone(),
data: self.data.clone(),
has_more: self.has_more,
total_count: self.total_count,
url: self.url.clone(),
next_page: self.next_page.clone(),
}
}
}

impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
for SearchList<T>
{
type O = T;

fn new(
data: Vec<Self::O>,
url: String,
has_more: bool,
total_count: Option<u64>,
) -> SearchList<T> {
Self { object: "".to_string(), url, has_more, data, next_page: None, total_count }
}

fn get_data(&self) -> Vec<Self::O> {
self.data.clone()
}
fn get_url(&self) -> String {
self.url.clone()
}
fn get_total_count(&self) -> Option<u64> {
self.total_count
}
fn has_more(&self) -> bool {
self.has_more
}
}

impl<T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug> PaginableList
for List<T>
{
type O = T;

fn new(data: Vec<Self::O>, url: String, has_more: bool, total_count: Option<u64>) -> List<T> {
Self { url, has_more, data, total_count }
}

fn get_data(&self) -> Vec<Self::O> {
self.data.clone()
}
fn get_url(&self) -> String {
self.url.clone()
}
fn get_total_count(&self) -> Option<u64> {
self.total_count
}
fn has_more(&self) -> bool {
self.has_more
}
}

impl<T> SearchList<T> {
pub fn paginate<P>(self, params: P) -> ListPaginator<SearchList<T>, P> {
ListPaginator { page: self, params }
}
}

/// A single page of a cursor-paginated list of an object.
Expand Down Expand Up @@ -214,33 +313,39 @@ impl<T: Clone> Clone for List<T> {
}

impl<T> List<T> {
pub fn paginate<P>(self, params: P) -> ListPaginator<T, P> {
pub fn paginate<P>(self, params: P) -> ListPaginator<List<T>, P> {
ListPaginator { page: self, params }
}
}

#[derive(Debug)]
pub struct ListPaginator<T, P> {
pub page: T,
pub params: P,
}

impl<
T: Paginate + DeserializeOwned + Send + Sync + 'static + Clone + std::fmt::Debug,
T: PaginableList + Send + DeserializeOwned + 'static,
P: Clone + Serialize + Send + 'static + std::fmt::Debug,
> ListPaginator<T, P>
where
P: Paginable<O = T>,
P: Paginable<O = T::O>,
{
/// Repeatedly queries Stripe for more data until all elements in list are fetched, using
/// Stripe's default page size.
///
/// Requires `feature = "blocking"`.
#[cfg(feature = "blocking")]
pub fn get_all(self, client: &Client) -> Response<Vec<T>> {
let mut data = Vec::with_capacity(self.page.total_count.unwrap_or(0) as usize);
pub fn get_all(self, client: &Client) -> Response<Vec<T::O>> {
let mut data = Vec::with_capacity(self.page.get_total_count().unwrap_or(0) as usize);
let mut paginator = self;
loop {
if !paginator.page.has_more {
data.extend(paginator.page.data.into_iter());
if !paginator.page.has_more() {
data.extend(paginator.page.get_data().into_iter());
break;
}
let next_paginator = paginator.next(client)?;
data.extend(paginator.page.data.into_iter());
data.extend(paginator.page.get_data().into_iter());
paginator = next_paginator
}
Ok(data)
Expand Down Expand Up @@ -276,11 +381,11 @@ where
/// Requires `feature = ["async", "stream"]`.
#[cfg(all(feature = "async", feature = "stream"))]
pub fn stream(
mut self,
self,
client: &Client,
) -> impl futures_util::Stream<Item = Result<T, StripeError>> + Unpin {
) -> impl futures_util::Stream<Item = Result<T::O, StripeError>> + Unpin {
// We are going to be popping items off the end of the list, so we need to reverse it.
self.page.data.reverse();
self.page.get_data().reverse();

Box::pin(futures_util::stream::unfold(Some((self, client.clone())), Self::unfold_stream))
}
Expand All @@ -289,22 +394,22 @@ where
#[cfg(all(feature = "async", feature = "stream"))]
async fn unfold_stream(
state: Option<(Self, Client)>,
) -> Option<(Result<T, StripeError>, Option<(Self, Client)>)> {
let (mut paginator, client) = state?; // If none, we sent the last item in the last iteration
) -> Option<(Result<T::O, StripeError>, Option<(Self, Client)>)> {
let (paginator, client) = state?; // If none, we sent the last item in the last iteration

if paginator.page.data.len() > 1 {
return Some((Ok(paginator.page.data.pop()?), Some((paginator, client))));
if paginator.page.get_data().len() > 1 {
return Some((Ok(paginator.page.get_data().pop()?), Some((paginator, client))));
// We have more data on this page
}

if !paginator.page.has_more {
return Some((Ok(paginator.page.data.pop()?), None)); // Final value of the stream, no errors
if !paginator.page.has_more() {
return Some((Ok(paginator.page.get_data().pop()?), None)); // Final value of the stream, no errors
}

match paginator.next(&client).await {
Ok(mut next_paginator) => {
let data = paginator.page.data.pop()?;
next_paginator.page.data.reverse();
Ok(next_paginator) => {
let data = paginator.page.get_data().pop()?;
next_paginator.page.get_data().reverse();

// Yield last value of thimuts page, the next page (and client) becomes the state
Some((Ok(data), Some((next_paginator, client))))
Expand All @@ -315,9 +420,9 @@ where

/// Fetch an additional page of data from stripe.
pub fn next(&self, client: &Client) -> Response<Self> {
if let Some(last) = self.page.data.last() {
if self.page.url.starts_with("/v1/") {
let path = self.page.url.trim_start_matches("/v1/").to_string(); // the url we get back is prefixed
if let Some(last) = self.page.get_data().last() {
if self.page.get_url().starts_with("/v1/") {
let path = self.page.get_url().trim_start_matches("/v1/").to_string(); // the url we get back is prefixed

// clone the params and set the cursor
let params_next = {
Expand All @@ -334,12 +439,7 @@ where
}
} else {
ok(ListPaginator {
page: List {
data: Vec::new(),
has_more: false,
total_count: self.page.total_count,
url: self.page.url.clone(),
},
page: T::new(Vec::new(), self.page.get_url(), false, self.page.get_total_count()),
params: self.params.clone(),
})
}
Expand All @@ -348,13 +448,13 @@ where
/// Pin a new future which maps the result inside the page future into
/// a ListPaginator
#[cfg(feature = "async")]
fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
fn create_paginator(page: Response<T>, params: P) -> Response<Self> {
use futures_util::FutureExt;
Box::pin(page.map(|page| page.map(|page| ListPaginator { page, params })))
}

#[cfg(feature = "blocking")]
fn create_paginator(page: Response<List<T>>, params: P) -> Response<Self> {
fn create_paginator(page: Response<T>, params: P) -> Response<Self> {
page.map(|page| ListPaginator { page, params })
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/resources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ mod billing {
pub mod usage_record_ext;
}

#[path = "resources"]
#[cfg(feature = "products")]
mod products {
pub mod price_ext;
pub mod product_ext;
}

#[path = "resources"]
#[cfg(feature = "checkout")]
mod checkout {
Expand Down
25 changes: 24 additions & 1 deletion src/resources/charge_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};

use crate::client::{Client, Response};
use crate::ids::{AccountId, BankAccountId, CardId, ChargeId, SourceId, TokenId};
use crate::params::Object;
use crate::params::{Object, SearchList};
use crate::resources::{Charge, Rule};

/// The set of PaymentSource parameters that can be used to create a charge.
Expand Down Expand Up @@ -44,6 +44,13 @@ impl Charge {
) -> Response<Charge> {
client.post_form(&format!("/charges/{}/capture", charge_id), params)
}

/// Searches for a charge.
///
/// For more details see <https://stripe.com/docs/api/charges/search>.
pub fn search(client: &Client, params: ChargeSearchParams) -> Response<SearchList<Charge>> {
client.get_query("/charges/search", params)
}
}

impl Object for Rule {
Expand All @@ -55,3 +62,19 @@ impl Object for Rule {
""
}
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct ChargeSearchParams<'a> {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
pub expand: &'a [&'a str],
}

impl<'a> ChargeSearchParams<'a> {
pub fn new() -> ChargeSearchParams<'a> {
ChargeSearchParams { query: String::new(), limit: None, page: None, expand: &[] }
}
}
25 changes: 24 additions & 1 deletion src/resources/customer_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};

use crate::client::{Client, Response};
use crate::ids::{BankAccountId, CardId, CustomerId, PaymentSourceId};
use crate::params::{Deleted, Expand, List};
use crate::params::{Deleted, Expand, List, SearchList};
use crate::resources::{
BankAccount, Customer, PaymentMethod, PaymentSource, PaymentSourceParams, Source,
};
Expand Down Expand Up @@ -70,6 +70,22 @@ pub enum CustomerPaymentMethodRetrievalType {
WechatPay,
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct CustomerSearchParams<'a> {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
pub expand: &'a [&'a str],
}

impl<'a> CustomerSearchParams<'a> {
pub fn new() -> CustomerSearchParams<'a> {
CustomerSearchParams { query: String::new(), limit: None, page: None, expand: &[] }
}
}

impl Customer {
/// Attaches a source to a customer, does not change default Source for the Customer
///
Expand Down Expand Up @@ -132,6 +148,13 @@ impl Customer {
) -> Response<List<PaymentMethod>> {
client.get_query(&format!("/customers/{}/payment_methods", customer_id), &params)
}

/// Searches for a customer.
///
/// For more details see <https://stripe.com/docs/api/customers/search>.
pub fn search(client: &Client, params: CustomerSearchParams) -> Response<SearchList<Customer>> {
client.get_query("/customers/search", params)
}
}

/// The set of parameters that can be used when verifying a Bank Account.
Expand Down
25 changes: 24 additions & 1 deletion src/resources/invoice_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::Serialize;

use crate::client::{Client, Response};
use crate::ids::{CouponId, CustomerId, InvoiceId, PlanId, SubscriptionId, SubscriptionItemId};
use crate::params::{Metadata, Timestamp};
use crate::params::{Metadata, SearchList, Timestamp};
use crate::resources::{CollectionMethod, Invoice};

#[deprecated(since = "0.12.0")]
Expand All @@ -22,6 +22,13 @@ impl Invoice {
pub fn pay(client: &Client, invoice_id: &InvoiceId) -> Response<Invoice> {
client.post(&format!("/invoices/{}/pay", invoice_id))
}

/// Searches for an invoice.
///
/// For more details see <https://stripe.com/docs/api/invoices/search>.
pub fn search(client: &Client, params: InvoiceSearchParams) -> Response<SearchList<Invoice>> {
client.get_query("/invoices/search", params)
}
}

#[derive(Clone, Debug, Serialize)]
Expand Down Expand Up @@ -71,3 +78,19 @@ pub struct SubscriptionItemFilter {
#[serde(skip_serializing_if = "Option::is_none")]
pub quantity: Option<u64>,
}

#[derive(Clone, Debug, Default, Serialize)]
pub struct InvoiceSearchParams<'a> {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
pub expand: &'a [&'a str],
}

impl<'a> InvoiceSearchParams<'a> {
pub fn new() -> InvoiceSearchParams<'a> {
InvoiceSearchParams { query: String::new(), limit: None, page: None, expand: &[] }
}
}
Loading