Skip to content

Commit

Permalink
add: quota service
Browse files Browse the repository at this point in the history
add: token bucket
remove: token bucket from retry
  • Loading branch information
Velfi committed Jan 11, 2023
1 parent bcea15a commit c6244ec
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 115 deletions.
1 change: 1 addition & 0 deletions rust-runtime/aws-smithy-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ aws-smithy-protocol-test = { path = "../aws-smithy-protocol-test", optional = tr
aws-smithy-types = { path = "../aws-smithy-types" }
bytes = "1"
fastrand = "1.4.0"
futures-util = "0.3.25"
http = "0.2.3"
http-body = "0.4.4"
hyper = { version = "0.14.12", features = ["client", "http2", "http1", "tcp"], optional = true }
Expand Down
24 changes: 18 additions & 6 deletions rust-runtime/aws-smithy-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

pub mod bounds;
pub mod erase;
pub mod quota;
pub mod retry;
pub mod token_bucket;

// https://github.com/rust-lang/rust/issues/72081
#[allow(rustdoc::private_doc_tests)]
Expand Down Expand Up @@ -97,10 +99,14 @@ pub use aws_smithy_http::result::{SdkError, SdkSuccess};
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_http_tower::dispatch::DispatchLayer;
use aws_smithy_http_tower::parse_response::ParseResponseLayer;
use aws_smithy_http_tower::SendOperationError;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_smithy_types::retry::ProvideErrorKind;
use aws_smithy_types::timeout::OperationTimeoutConfig;
use quota::QuotaLayer;
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use timeout::ClientTimeoutParams;
use tower::{Layer, Service, ServiceBuilder, ServiceExt};
Expand Down Expand Up @@ -215,19 +221,25 @@ where
ClientTimeoutParams::new(&self.operation_timeout_config, self.sleep_impl.clone());

let svc = ServiceBuilder::new()
.layer(TimeoutLayer::new(timeout_params.operation_timeout))
.retry(
self.retry_policy
.new_request_policy(self.sleep_impl.clone()),
)
.layer(TimeoutLayer::new(timeout_params.operation_attempt_timeout))
// .layer(TimeoutLayer::new(timeout_params.operation_timeout))
.layer(QuotaLayer::new(|| {
token_bucket::standard::TokenBucket::builder().build()
}))
// .retry(
// self.retry_policy
// .new_request_policy(self.sleep_impl.clone()),
// )
// .layer(TimeoutLayer::new(timeout_params.operation_attempt_timeout))
.layer(ParseResponseLayer::<O, Retry>::new())
// These layers can be considered as occurring in order. That is, first invoke the
// customer-provided middleware, then dispatch dispatch over the wire.
.layer(&self.middleware)
.layer(DispatchLayer::new())
.service(connector);

// let c: Box<dyn tower::Service<_, Response = _, Error = _, Future = _>> =
// Box::new(svc.clone());

// send_operation records the full request-response lifecycle.
// NOTE: For operations that stream output, only the setup is captured in this span.
let span = debug_span!(
Expand Down
235 changes: 235 additions & 0 deletions rust-runtime/aws-smithy-client/src/quota.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

//! Token bucket management

use crate::token_bucket::TokenBucket;
use aws_smithy_http::result::SdkError;
use futures_util::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::Layer;

/// A service that wraps another service, adding the ability to set a quota for requests
/// handled by the inner service.
#[derive(Clone)]
pub struct QuotaService<S, Tb> {
inner: S,
token_bucket: Arc<Tb>,
}

impl<S, Tb> QuotaService<S, Tb>
where
Tb: TokenBucket,
{
/// Create a new `QuotaService`
pub fn new(inner: S, token_bucket: Tb) -> Self {
Self {
inner,
token_bucket: Arc::new(token_bucket),
}
}
}

type BoxedResultFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>>>>;

impl<InnerService, Req, E, Tb> tower::Service<Req> for QuotaService<InnerService, Tb>
where
InnerService: tower::Service<Req, Error = SdkError<E>>,
InnerService::Response: Send + 'static,
InnerService::Future: Send + 'static,
E: Send + 'static,
Tb: TokenBucket,
Tb::Token: Send + 'static,
{
type Response = InnerService::Response;
type Error = SdkError<E>;
type Future = BoxedResultFuture<Self::Response, Self::Error>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// Check the inner service to see if it's ready yet. If no tokens are available, requests
// should fail with an error instead of waiting for the next token.
self.inner.poll_ready(cx).map_err(|err| {
SdkError::construction_failure(format!("inner service failed to become ready"))
})
}

fn call(&mut self, mut req: Req) -> Self::Future {
match self.token_bucket.try_acquire(None) {
Ok(token) => {
// req.properties_mut().insert(token);
let fut = self.inner.call(req);

Box::pin(fut)
}
Err(err) => {
let fut = futures_util::future::err::<_, SdkError<E>>(
SdkError::construction_failure(err),
);

Box::pin(fut)
}
}
}
}

/// A layer that wraps services in a quota service
#[non_exhaustive]
#[derive(Debug)]
pub struct QuotaLayer<Tbb> {
token_bucket_builder: Tbb,
}

impl<Tb, TbBuilder> QuotaLayer<TbBuilder>
where
Tb: TokenBucket,
TbBuilder: Fn() -> Tb,
{
/// Create a new `QuotaLayer`
pub fn new(token_bucket_builder: TbBuilder) -> Self {
QuotaLayer {
token_bucket_builder,
}
}
}

impl<S, Tb, TbBuilder> Layer<S> for QuotaLayer<TbBuilder>
where
Tb: TokenBucket,
TbBuilder: Fn() -> Tb,
{
type Service = QuotaService<S, Tb>;

fn layer(&self, inner: S) -> Self::Service {
QuotaService {
inner,
token_bucket: Arc::new((self.token_bucket_builder)()),
}
}
}

#[cfg(test)]
mod tests {
use super::QuotaService;
use crate::token_bucket::standard;
use crate::token_bucket::TokenBucket;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation::Operation;
use aws_smithy_http::result::SdkError;
use aws_smithy_types::retry::ErrorKind;
use futures_util::future::TryFutureExt;
use http::{Request, Response, StatusCode};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::{Service, ServiceExt};

#[derive(Clone)]
struct TestService<H, R> {
handler: PhantomData<H>,
retry: PhantomData<R>,
}

impl<H, R> TestService<H, R> {
pub fn new() -> Self {
Self {
handler: PhantomData::default(),
retry: PhantomData::default(),
}
}
}

impl<H, R> Service<Operation<H, R>> for TestService<H, R> {
type Response = Response<&'static str>;
type Error = SdkError<()>;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: Operation<H, R>) -> Self::Future {
let fut = async {
Ok(Response::builder()
.status(StatusCode::OK)
.body("Hello, world!")
.unwrap())
};

Box::pin(fut)
}
}

#[tokio::test]
async fn quota_service_has_ready_trait_method() {
let mut svc = QuotaService::new(
TestService::<(), ()>::new(),
standard::TokenBucket::builder().build(),
);

let _mut_ref = svc.ready().await.unwrap();
}

#[tokio::test]
async fn quota_service_is_send_sync() {
fn check_send_sync<T: Send + Sync>(t: T) -> T {
t
}

let svc = QuotaService::new(
TestService::<(), ()>::new(),
standard::TokenBucket::builder().build(),
);

let _mut_ref = check_send_sync(svc).ready().await.unwrap();
}

#[tokio::test]
async fn quota_layer_keeps_working_after_getting_emptied_and_then_refilled() {
let quota_state = standard::TokenBucket::builder()
.max_tokens(500)
.retryable_error_cost(5)
.timeout_error_cost(10)
.starting_tokens(10)
.build();
assert_eq!(quota_state.available(), 10);
// Remove the only token in the bucket, from the bucket
let the_only_token_in_the_bucket = quota_state
.try_acquire(Some(ErrorKind::TransientError))
.unwrap();
assert_eq!(quota_state.available(), 0);

let mut svc = QuotaService::new(TestService::new(), quota_state);

let req = Request::builder()
.body(SdkBody::empty())
.expect("failed to construct empty request");
let req = aws_smithy_http::operation::Request::new(req);
let op = Operation::new(req, ());

let op_clone = op.try_clone().unwrap();
let svc_clone = svc.clone();
let handle_a = tokio::task::spawn(async move {
let mut svc = svc_clone;
let _ = svc.ready().await;
svc.call(op_clone).await
});

// We need to make sure that the task has time to check readiness and find that the token
// bucket is empty.
tokio::time::sleep(Duration::from_secs(1)).await;

// Relinquish the semaphore token we held, enabling future requests to succeed.
drop(the_only_token_in_the_bucket);
let res_a = handle_a.await.expect("join handle is valid");
let res_b = svc.ready().and_then(|f| f.call(op)).await;

println!("{res_a:#?}, {res_b:#?}");
}
}
Loading

0 comments on commit c6244ec

Please sign in to comment.