diff --git a/Cargo.lock b/Cargo.lock
index c3457ee91..a64785fbd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1945,6 +1945,7 @@ dependencies = [
  "assert_fs",
  "async-channel",
  "async-lock",
+ "async-trait",
  "aws-config",
  "aws-sdk-s3",
  "aws-sdk-sts",
diff --git a/mountpoint-s3/Cargo.toml b/mountpoint-s3/Cargo.toml
index dae63c594..8be5593e4 100644
--- a/mountpoint-s3/Cargo.toml
+++ b/mountpoint-s3/Cargo.toml
@@ -13,6 +13,7 @@ mountpoint-s3-crt = { path = "../mountpoint-s3-crt", version = "0.3.0" }
 anyhow = { version = "1.0.64", features = ["backtrace"] }
 async-channel = "1.8.0"
 async-lock = "2.6.0"
+async-trait = "0.1.57"
 bytes = "1.2.1"
 clap = { version = "4.1.9", features = ["derive"] }
 crc32c = "0.6.3"
diff --git a/mountpoint-s3/src/main.rs b/mountpoint-s3/src/main.rs
index 937dc3550..86416b343 100644
--- a/mountpoint-s3/src/main.rs
+++ b/mountpoint-s3/src/main.rs
@@ -461,7 +461,6 @@ fn mount(args: CliArgs) -> anyhow::Result<FuseSession> {
         filesystem_config.file_mode = file_mode;
     }
     filesystem_config.storage_class = args.storage_class;
-    filesystem_config.prefetcher_config.part_alignment = args.part_size as usize;
     filesystem_config.allow_delete = args.allow_delete;
 
     let fs = S3FuseFilesystem::new(client, runtime, &args.bucket_name, &prefix, filesystem_config);
diff --git a/mountpoint-s3/src/prefetch.rs b/mountpoint-s3/src/prefetch.rs
index fae1ef1fe..310487802 100644
--- a/mountpoint-s3/src/prefetch.rs
+++ b/mountpoint-s3/src/prefetch.rs
@@ -8,6 +8,7 @@
 //! non-sequential read, we abandon the prefetching and start again with the minimum request size.
 
 pub mod checksummed_bytes;
+mod feed;
 mod part;
 mod part_queue;
 
@@ -15,20 +16,17 @@ use std::collections::VecDeque;
 use std::fmt::Debug;
 use std::time::Duration;
 
-use bytes::Bytes;
 use futures::future::RemoteHandle;
-use futures::pin_mut;
-use futures::stream::StreamExt;
 use futures::task::{Spawn, SpawnExt};
 use metrics::counter;
 use mountpoint_s3_client::error::{GetObjectError, ObjectClientError};
 use mountpoint_s3_client::types::ETag;
 use mountpoint_s3_client::ObjectClient;
-use mountpoint_s3_crt::checksums::crc32c;
 use thiserror::Error;
 use tracing::{debug_span, error, trace, Instrument};
 
 use crate::prefetch::checksummed_bytes::{ChecksummedBytes, IntegrityError};
+use crate::prefetch::feed::{ClientPartFeed, ObjectPartFeed};
 use crate::prefetch::part::Part;
 use crate::prefetch::part_queue::{unbounded_part_queue, PartQueue};
 use crate::sync::{Arc, RwLock};
@@ -45,8 +43,6 @@ pub struct PrefetcherConfig {
     pub sequential_prefetch_multiplier: usize,
     /// Timeout to wait for a part to become available
     pub read_timeout: Duration,
-    /// The size of the parts that the prefetcher is trying to align with
-    pub part_alignment: usize,
 }
 
 impl Default for PrefetcherConfig {
@@ -65,7 +61,6 @@ impl Default for PrefetcherConfig {
             max_request_size: 2 * 1024 * 1024 * 1024,
             sequential_prefetch_multiplier: 8,
             read_timeout: Duration::from_secs(60),
-            part_alignment: 8 * 1024 * 1024,
         }
     }
 }
@@ -76,13 +71,18 @@ pub struct Prefetcher<Client, Runtime> {
     inner: Arc<PrefetcherInner<Client, Runtime>>,
 }
 
-#[derive(Debug)]
 struct PrefetcherInner<Client, Runtime> {
-    client: Arc<Client>,
+    part_feed: Arc<dyn ObjectPartFeed<Client> + Send + Sync>,
     config: PrefetcherConfig,
     runtime: Runtime,
 }
 
+impl<Client, Runtime> Debug for PrefetcherInner<Client, Runtime> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PrefetcherInner").field("config", &self.config).finish()
+    }
+}
+
 impl<Client, Runtime> Prefetcher<Client, Runtime>
 where
     Client: ObjectClient + Send + Sync + 'static,
@@ -90,8 +90,9 @@ where
 {
     /// Create a new [Prefetcher] that will make requests to the given client.
     pub fn new(client: Arc<Client>, runtime: Runtime, config: PrefetcherConfig) -> Self {
+        let part_feed = Arc::new(ClientPartFeed::new(client));
         let inner = PrefetcherInner {
-            client,
+            part_feed,
             config,
             runtime,
         };
@@ -101,7 +102,7 @@ where
 
     /// Start a new get request to the specified object.
     pub fn get(&self, bucket: &str, key: &str, size: u64, etag: ETag) -> PrefetchGetObject<Client, Runtime> {
-        PrefetchGetObject::new(Arc::clone(&self.inner), bucket, key, size, etag)
+        PrefetchGetObject::new(self.inner.clone(), bucket, key, size, etag)
     }
 }
@@ -276,7 +277,7 @@ where
         trace!(?range, size, "spawning request");
 
         let request_task = {
-            let client = Arc::clone(&self.inner.client);
+            let feed = self.inner.part_feed.clone();
             let preferred_part_size = self.preferred_part_size;
             let bucket = self.bucket.to_owned();
             let key = self.key.to_owned();
@@ -284,47 +285,8 @@ where
             let span = debug_span!("prefetch", range=?range);
 
             async move {
-                match client.get_object(&bucket, &key, Some(range.clone()), Some(etag)).await {
-                    Err(e) => {
-                        error!(error=?e, "RequestTask get object failed");
-                        part_queue_producer.push(Err(e));
-                    }
-                    Ok(request) => {
-                        pin_mut!(request);
-                        loop {
-                            match request.next().await {
-                                Some(Ok((offset, body))) => {
-                                    // pre-split the body into multiple parts as suggested by preferred part size
-                                    // in order to avoid validating checksum on large parts at read.
-                                    assert!(preferred_part_size > 0);
-                                    let mut body: Bytes = body.into();
-                                    let mut curr_offset = offset;
-                                    loop {
-                                        let chunk_size = preferred_part_size.min(body.len());
-                                        if chunk_size == 0 {
-                                            break;
-                                        }
-                                        let chunk = body.split_to(chunk_size);
-                                        // S3 doesn't provide checksum for us if the request range is not aligned to object part boundaries,
-                                        // so we're computing our own checksum here.
-                                        let checksum = crc32c::checksum(&chunk);
-                                        let checksum_bytes = ChecksummedBytes::new(chunk, checksum);
-                                        let part = Part::new(&key, curr_offset, checksum_bytes);
-                                        curr_offset += part.len() as u64;
-                                        part_queue_producer.push(Ok(part));
-                                    }
-                                }
-                                Some(Err(e)) => {
-                                    error!(error=?e, "RequestTask body part failed");
-                                    part_queue_producer.push(Err(e));
-                                    break;
-                                }
-                                None => break,
-                            }
-                        }
-                        trace!("request finished");
-                    }
-                }
+                feed.get_object_parts(&bucket, &key, range, etag, preferred_part_size, part_queue_producer)
+                    .await
             }
             .instrument(span)
         };
@@ -344,29 +306,14 @@ where
     }
 
     /// Suggest next request size.
-    /// Normally, next request size is current request size multiply by sequential prefetch multiplier,
-    /// but if the request size is getting bigger than a part size we will try to align it to part boundaries.
+    /// The next request size is the current request size multiplied by the sequential prefetch multiplier.
     fn get_next_request_size(&self) -> usize {
         // calculate next request size
         let next_request_size = (self.next_request_size * self.inner.config.sequential_prefetch_multiplier)
             .min(self.inner.config.max_request_size);
-
-        let offset_in_part = (self.next_request_offset % self.inner.config.part_alignment as u64) as usize;
-        // if the offset is not at the start of the part we will drain all the bytes from that part first
-        if offset_in_part != 0 {
-            let remaining_in_part = self.inner.config.part_alignment - offset_in_part;
-            next_request_size.min(remaining_in_part)
-        } else {
-            // if the next request size is smaller than the part size, just return that value
-            if next_request_size < self.inner.config.part_alignment {
-                return next_request_size;
-            }
-
-            // if it exceeds part boundaries, trim it to the part boundaries
-            let next_request_boundary = self.next_request_offset + next_request_size as u64;
-            let remainder = (next_request_boundary % self.inner.config.part_alignment as u64) as usize;
-            next_request_size - remainder
-        }
+        self.inner
+            .part_feed
+            .get_aligned_request_size(self.next_request_offset, next_request_size)
     }
 
     /// Reset this prefetch request to a new offset, clearing any existing tasks queued.
@@ -457,7 +404,6 @@ mod tests {
             max_request_size: test_config.max_request_size,
             sequential_prefetch_multiplier: test_config.sequential_prefetch_multiplier,
             read_timeout: Duration::from_secs(5),
-            part_alignment: test_config.client_part_size,
         };
         let runtime = ThreadPool::builder().pool_size(1).create().unwrap();
         let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config);
@@ -608,7 +554,6 @@ mod tests {
             sequential_prefetch_multiplier: prefetch_multiplier,
             max_request_size,
             read_timeout: Duration::from_secs(60),
-            part_alignment: part_size,
         };
         let runtime = ThreadPool::builder().pool_size(1).create().unwrap();
         let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config);
diff --git a/mountpoint-s3/src/prefetch/feed.rs b/mountpoint-s3/src/prefetch/feed.rs
new file mode 100644
index 000000000..0a7cf2819
--- /dev/null
+++ b/mountpoint-s3/src/prefetch/feed.rs
@@ -0,0 +1,127 @@
+use std::{fmt::Debug, ops::Range, sync::Arc};
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use futures::{pin_mut, StreamExt};
+use mountpoint_s3_client::{
+    error::{GetObjectError, ObjectClientError},
+    types::ETag,
+    ObjectClient,
+};
+use mountpoint_s3_crt::checksums::crc32c;
+use tracing::{error, trace};
+
+use crate::prefetch::{checksummed_bytes::ChecksummedBytes, part::Part, part_queue::PartQueueProducer};
+
+/// A generic interface to retrieve data from objects in an S3-like store.
+#[async_trait]
+pub trait ObjectPartFeed<Client: ObjectClient> {
+    /// Get the content of an object in fixed-size parts. The parts are pushed to the provided `part_sink`
+    /// and are guaranteed to be contiguous and in the correct order. Callers need to specify a preferred
+    /// size for the parts, but implementations are allowed to ignore it.
+    async fn get_object_parts(
+        &self,
+        bucket: &str,
+        key: &str,
+        range: Range<u64>,
+        if_match: ETag,
+        preferred_part_size: usize,
+        part_sink: PartQueueProducer<ObjectClientError<GetObjectError, Client::ClientError>>,
+    );
+
+    /// Adjust the size of a request to align to optimal part boundaries for this client.
+    fn get_aligned_request_size(&self, offset: u64, preferred_size: usize) -> usize;
+}
+
+/// [ObjectPartFeed] implementation which delegates retrieving object data to a [Client].
+#[derive(Debug)]
+pub struct ClientPartFeed<Client> {
+    client: Arc<Client>,
+}
+
+impl<Client> ClientPartFeed<Client> {
+    pub fn new(client: Arc<Client>) -> Self {
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl<Client> ObjectPartFeed<Client> for ClientPartFeed<Client>
+where
+    Client: ObjectClient + Send + Sync + 'static,
+{
+    async fn get_object_parts(
+        &self,
+        bucket: &str,
+        key: &str,
+        range: Range<u64>,
+        if_match: ETag,
+        preferred_part_size: usize,
+        part_queue_producer: PartQueueProducer<ObjectClientError<GetObjectError, Client::ClientError>>,
+    ) {
+        assert!(preferred_part_size > 0);
+        let get_object_result = match self.client.get_object(bucket, key, Some(range), Some(if_match)).await {
+            Ok(get_object_result) => get_object_result,
+            Err(e) => {
+                error!(error=?e, "GetObject request failed");
+                part_queue_producer.push(Err(e));
+                return;
+            }
+        };
+
+        pin_mut!(get_object_result);
+        loop {
+            match get_object_result.next().await {
+                Some(Ok((offset, body))) => {
+                    trace!(offset, length = body.len(), "received GetObject part");
+                    // pre-split the body into multiple parts as suggested by the preferred part size
+                    // in order to avoid validating checksums on large parts at read time.
+                    let mut body: Bytes = body.into();
+                    let mut curr_offset = offset;
+                    loop {
+                        let chunk_size = preferred_part_size.min(body.len());
+                        if chunk_size == 0 {
+                            break;
+                        }
+                        let chunk = body.split_to(chunk_size);
+                        // S3 doesn't provide checksums for us if the request range is not aligned to
+                        // object part boundaries, so we're computing our own checksum here.
+                        let checksum = crc32c::checksum(&chunk);
+                        let checksum_bytes = ChecksummedBytes::new(chunk, checksum);
+                        let part = Part::new(key, curr_offset, checksum_bytes);
+                        curr_offset += part.len() as u64;
+                        part_queue_producer.push(Ok(part));
+                    }
+                }
+                Some(Err(e)) => {
+                    error!(error=?e, "GetObject body part failed");
+                    part_queue_producer.push(Err(e));
+                    break;
+                }
+                None => break,
+            }
+        }
+        trace!("request finished");
+    }
+
+    fn get_aligned_request_size(&self, offset: u64, preferred_length: usize) -> usize {
+        // If the request size is bigger than a part size, we will try to align it to part boundaries.
+        let part_alignment = self.client.part_size().unwrap_or(8 * 1024 * 1024);
+        let offset_in_part = (offset % part_alignment as u64) as usize;
+        if offset_in_part != 0 {
+            // if the offset is not at the start of a part, drain all the bytes from that part first
+            let remaining_in_part = part_alignment - offset_in_part;
+            preferred_length.min(remaining_in_part)
+        } else {
+            // if the request size is smaller than the part size, just return that value
+            if preferred_length < part_alignment {
+                preferred_length
+            } else {
+                // if it exceeds part boundaries, trim it to the part boundaries
+                let request_boundary = offset + preferred_length as u64;
+                let remainder = (request_boundary % part_alignment as u64) as usize;
+                preferred_length - remainder
+            }
+        }
+    }
+}
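
Note for reviewers: the alignment arithmetic moved into `ClientPartFeed::get_aligned_request_size` is easy to sanity-check in isolation, since it no longer depends on `PrefetcherConfig::part_alignment`. The sketch below is illustrative only: the free function `aligned_request_size`, the `MIB` constant, and the fixed 8 MiB alignment are assumptions for the example (the real implementation derives alignment from `client.part_size()`), not part of this change.

```rust
// Standalone sketch of the trimming logic above, assuming an 8 MiB part size.
fn aligned_request_size(offset: u64, preferred_length: usize, part_alignment: usize) -> usize {
    let offset_in_part = (offset % part_alignment as u64) as usize;
    if offset_in_part != 0 {
        // Mid-part offset: only read to the end of the current part.
        preferred_length.min(part_alignment - offset_in_part)
    } else if preferred_length < part_alignment {
        // Small request at an aligned offset: pass through unchanged.
        preferred_length
    } else {
        // Large request at an aligned offset: trim it to end on a part boundary.
        let request_boundary = offset + preferred_length as u64;
        preferred_length - (request_boundary % part_alignment as u64) as usize
    }
}

fn main() {
    const MIB: usize = 1024 * 1024;
    // Offset 5 MiB into an 8 MiB part: drain the remaining 3 MiB first.
    assert_eq!(aligned_request_size(5 * MIB as u64, 16 * MIB, 8 * MIB), 3 * MIB);
    // Aligned offset, request smaller than a part: unchanged.
    assert_eq!(aligned_request_size(8 * MIB as u64, 4 * MIB, 8 * MIB), 4 * MIB);
    // Aligned offset, 20 MiB request: trimmed down to 16 MiB (a part boundary).
    assert_eq!(aligned_request_size(8 * MIB as u64, 20 * MIB, 8 * MIB), 16 * MIB);
    println!("alignment examples hold");
}
```

This is also why `part_alignment` can disappear from `PrefetcherConfig` and from `main.rs`: the part feed now owns the alignment decision, so the CLI no longer needs to thread `--part-size` through the filesystem config.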