Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clients: request ID as RAII guard #543

Merged
merged 6 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 18 additions & 33 deletions http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::transport::HttpTransportClient;
use crate::types::{
traits::Client,
v2::{Id, NotificationSer, ParamsSer, RequestSer, Response, RpcError},
CertificateStore, Error, RequestIdGuard, TEN_MB_SIZE_BYTES,
CertificateStore, Error, RequestIdManager, TEN_MB_SIZE_BYTES,
};
use async_trait::async_trait;
use fnv::FnvHashMap;
Expand Down Expand Up @@ -75,7 +75,7 @@ impl HttpClientBuilder {
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
transport,
id_guard: RequestIdGuard::new(self.max_concurrent_requests),
id_manager: RequestIdManager::new(self.max_concurrent_requests),
request_timeout: self.request_timeout,
})
}
Expand All @@ -100,7 +100,7 @@ pub struct HttpClient {
/// Request timeout. Defaults to 60sec.
request_timeout: Duration,
/// Request ID manager.
id_guard: RequestIdGuard,
id_manager: RequestIdManager,
}

#[async_trait]
Expand All @@ -120,27 +120,20 @@ impl Client for HttpClient {
where
R: DeserializeOwned,
{
// NOTE: the IDs wrap on overflow which is intended.
let id = self.id_guard.next_request_id()?;
let request = RequestSer::new(Id::Number(id), method, params);

let fut = self.transport.send_and_read_body(serde_json::to_string(&request).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?);
let id = self.id_manager.next_request_id()?;
let request = RequestSer::new(Id::Number(*id.inner()), method, params);

let fut = self.transport.send_and_read_body(serde_json::to_string(&request).map_err(Error::ParseError)?);
let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
Err(_e) => {
self.id_guard.reclaim_request_id();
return Err(Error::RequestTimeout);
}
Ok(Err(e)) => {
self.id_guard.reclaim_request_id();
return Err(Error::Transport(e.into()));
}
};

self.id_guard.reclaim_request_id();
let response: Response<_> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(_) => {
Expand All @@ -151,7 +144,7 @@ impl Client for HttpClient {

let response_id = response.id.as_number().copied().ok_or(Error::InvalidRequestId)?;

if response_id == id {
if response_id == *id.inner() {
Ok(response.result)
} else {
Err(Error::InvalidRequestId)
Expand All @@ -167,34 +160,26 @@ impl Client for HttpClient {
let mut ordered_requests = Vec::with_capacity(batch.len());
let mut request_set = FnvHashMap::with_capacity_and_hasher(batch.len(), Default::default());

let ids = self.id_guard.next_request_ids(batch.len())?;
let ids = self.id_manager.next_request_ids(batch.len())?;
for (pos, (method, params)) in batch.into_iter().enumerate() {
batch_request.push(RequestSer::new(Id::Number(ids[pos]), method, params));
ordered_requests.push(ids[pos]);
request_set.insert(ids[pos], pos);
batch_request.push(RequestSer::new(Id::Number(ids.inner()[pos]), method, params));
ordered_requests.push(ids.inner()[pos]);
request_set.insert(ids.inner()[pos], pos);
}

let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?);
let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);

let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
Err(_e) => return Err(Error::RequestTimeout),
Ok(Err(e)) => return Err(Error::Transport(e.into())),
};

let rps: Vec<Response<_>> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(_) => {
let err: RpcError = serde_json::from_slice(&body).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?;
return Err(Error::Request(err.to_string()));
}
};
let rps: Vec<Response<_>> =
serde_json::from_slice(&body).map_err(|_| match serde_json::from_slice::<RpcError>(&body) {
Ok(e) => Error::Request(e.to_string()),
Err(e) => Error::ParseError(e),
})?;

// NOTE: `R::default` is placeholder and will be replaced in loop below.
let mut responses = vec![R::default(); ordered_requests.len()];
Expand Down
87 changes: 53 additions & 34 deletions types/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ use futures_channel::{mpsc, oneshot};
use futures_util::{future::FutureExt, sink::SinkExt, stream::StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

/// Subscription kind
#[derive(Debug)]
Expand Down Expand Up @@ -188,65 +189,83 @@ impl<Notif> Drop for Subscription<Notif> {

#[derive(Debug)]
/// Keep track of request IDs.
pub struct RequestIdGuard {
pub struct RequestIdManager {
// Current pending requests.
current_pending: AtomicUsize,
current_pending: Arc<()>,
/// Max concurrent pending requests allowed.
max_concurrent_requests: usize,
/// Get the next request ID.
current_id: AtomicU64,
}

impl RequestIdGuard {
impl RequestIdManager {
/// Create a new `RequestIdGuard` with the provided concurrency limit.
pub fn new(limit: usize) -> Self {
Self { current_pending: AtomicUsize::new(0), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
Self { current_pending: Arc::new(()), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
}

fn get_slot(&self) -> Result<(), Error> {
self.current_pending
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
if val >= self.max_concurrent_requests {
None
} else {
Some(val + 1)
}
})
.map(|_| ())
.map_err(|_| Error::MaxSlotsExceeded)
fn get_slot(&self) -> Result<Arc<()>, Error> {
// Strong count is 1 at start, so that's why we use `>` and not `>=`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Useful comment, thank you!

if Arc::strong_count(&self.current_pending) > self.max_concurrent_requests {
Err(Error::MaxSlotsExceeded)
} else {
Ok(self.current_pending.clone())
}
}

/// Attempts to get the next request ID.
///
/// Fails if request limit has been exceeded.
pub fn next_request_id(&self) -> Result<u64, Error> {
self.get_slot()?;
pub fn next_request_id(&self) -> Result<RequestIdGuard<u64>, Error> {
let rc = self.get_slot()?;
let id = self.current_id.fetch_add(1, Ordering::SeqCst);
Ok(id)
Ok(RequestIdGuard { _rc: rc, id })
}

/// Attempts to get the `n` number next IDs that only counts as one request.
///
/// Fails if request limit has been exceeded.
pub fn next_request_ids(&self, len: usize) -> Result<Vec<u64>, Error> {
self.get_slot()?;
let mut batch = Vec::with_capacity(len);
pub fn next_request_ids(&self, len: usize) -> Result<RequestIdGuard<Vec<u64>>, Error> {
let rc = self.get_slot()?;
let mut ids = Vec::with_capacity(len);
for _ in 0..len {
batch.push(self.current_id.fetch_add(1, Ordering::SeqCst));
ids.push(self.current_id.fetch_add(1, Ordering::SeqCst));
}
Ok(batch)
Ok(RequestIdGuard { _rc: rc, id: ids })
}
}

/// Reference counted request ID.
#[derive(Debug)]
pub struct RequestIdGuard<T> {
id: T,
/// Reference count decreased when dropped.
_rc: Arc<()>,
}

impl<T> RequestIdGuard<T> {
/// Get the actual ID.
pub fn inner(&self) -> &T {
&self.id
}
}

#[cfg(test)]
mod tests {
use super::RequestIdManager;

#[test]
fn request_id_guard_works() {
let manager = RequestIdManager::new(2);
let _first = manager.next_request_id().unwrap();

{
let _second = manager.next_request_ids(13).unwrap();
assert!(manager.next_request_id().is_err());
// second dropped here.
}

/// Decrease the currently pending counter by one (saturated at 0).
pub fn reclaim_request_id(&self) {
// NOTE we ignore the error here, since we are simply saturating at 0
let _ = self.current_pending.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
if val > 0 {
Some(val - 1)
} else {
None
}
});
assert!(manager.next_request_id().is_ok());
}
}

Expand Down
Loading