Skip to content

Commit

Permalink
add more infra
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe committed Dec 13, 2023
1 parent 85d2c3d commit 1db8145
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 2 deletions.
72 changes: 70 additions & 2 deletions stripe/_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def retrieve(cls, id, api_key=None, **params) -> T:
def refresh(self) -> Self:
return self._request_and_refresh("get", self.instance_url())

async def refresh_async(self) -> Self:
return await self._request_and_refresh_async(
"get", self.instance_url()
)

@classmethod
def class_url(cls) -> str:
if cls == APIResource:
Expand Down Expand Up @@ -122,8 +127,32 @@ def _request_and_refresh(
self.refresh_from(obj)
return self

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
async def _request_and_refresh_async(
self,
method_: Literal["get", "post", "delete"],
url_: str,
api_key: Optional[str] = None,
idempotency_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[Mapping[str, Any]] = None,
) -> Self:
obj = await StripeObject._request_async(
self,
method_,
url_,
api_key,
idempotency_key,
stripe_version,
stripe_account,
headers,
params,
)

self.refresh_from(obj)
return self

@classmethod
def _static_request(
cls,
Expand Down Expand Up @@ -161,6 +190,45 @@ def _static_request(
response, api_key, stripe_version, stripe_account, params
)

@classmethod
async def _static_request_async(
cls,
method_,
url_,
api_key=None,
idempotency_key=None,
stripe_version=None,
stripe_account=None,
params=None,
):
params = None if params is None else params.copy()
api_key = read_special_variable(params, "api_key", api_key)
idempotency_key = read_special_variable(
params, "idempotency_key", idempotency_key
)
stripe_version = read_special_variable(
params, "stripe_version", stripe_version
)
stripe_account = read_special_variable(
params, "stripe_account", stripe_account
)
headers = read_special_variable(params, "headers", None)

requestor = APIRequestor(
api_key, api_version=stripe_version, account=stripe_account
)

if idempotency_key is not None:
headers = {} if headers is None else headers.copy()
headers.update(populate_headers(idempotency_key))

response, api_key = await requestor.request_async(
method_, url_, params, headers
)
return convert_to_stripe_object(
response, api_key, stripe_version, stripe_account, params
)

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
@classmethod
Expand Down
25 changes: 25 additions & 0 deletions stripe/_searchable_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,31 @@ def _search(

return ret

@classmethod
async def _search_async(
cls,
search_url,
api_key=None,
stripe_version=None,
stripe_account=None,
**params
):
ret = await cls._static_request_async(
"get",
search_url,
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
)
if not isinstance(ret, SearchResultObject):
raise TypeError(
"Expected search result from API, got %s"
% (type(ret).__name__,)
)

return ret

@classmethod
def search(cls, *args, **kwargs):
raise NotImplementedError
Expand Down
48 changes: 48 additions & 0 deletions stripe/_stripe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,54 @@ def _request(
response, api_key, stripe_version, stripe_account, params
)

async def _request_async(
self,
method_: Literal["get", "post", "delete"],
url_: str,
api_key: Optional[str] = None,
idempotency_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
params: Optional[Mapping[str, Any]] = None,
) -> "StripeObject":
params = None if params is None else dict(params)
api_key = _util.read_special_variable(params, "api_key", api_key)
idempotency_key = _util.read_special_variable(
params, "idempotency_key", idempotency_key
)
stripe_version = _util.read_special_variable(
params, "stripe_version", stripe_version
)
stripe_account = _util.read_special_variable(
params, "stripe_account", stripe_account
)
headers = _util.read_special_variable(params, "headers", headers)

stripe_account = stripe_account or self.stripe_account
stripe_version = stripe_version or self.stripe_version
api_key = api_key or self.api_key
params = params or self._retrieve_params

requestor = stripe.APIRequestor(
key=api_key,
api_base=self.api_base(),
api_version=stripe_version,
account=stripe_account,
)

if idempotency_key is not None:
headers = {} if headers is None else headers.copy()
headers.update(_util.populate_headers(idempotency_key))

response, api_key = await requestor.request_async(
method_, url_, params, headers
)

return _util.convert_to_stripe_object(
response, api_key, stripe_version, stripe_account, params
)

def request_stream(
self,
method: str,
Expand Down
4 changes: 4 additions & 0 deletions stripe/_test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def __init__(self, resource):
def _static_request(cls, *args, **kwargs):
return cls._resource_cls._static_request(*args, **kwargs)

@classmethod
async def _static_request_async(cls, *args, **kwargs):
return cls._resource_cls._static_request_async(*args, **kwargs)

@classmethod
def _static_request_stream(cls, *args, **kwargs):
return cls._resource_cls._static_request_stream(*args, **kwargs)
Expand Down

0 comments on commit 1db8145

Please sign in to comment.