Commit 07805b9

feat: add read mode support

HoKim98 committed Sep 12, 2024
1 parent b7b188f commit 07805b9

Showing 2 changed files with 178 additions and 55 deletions.
35 changes: 34 additions & 1 deletion src/args.rs
@@ -1,7 +1,7 @@
use std::fmt;

use byte_unit::Byte;
use clap::{ArgAction, Parser};
use clap::{ArgAction, Parser, ValueEnum};
use duration_string::DurationString;
use s3::{creds::Credentials, Region};
use serde::{Deserialize, Serialize};
@@ -206,6 +206,16 @@ pub struct LoadTesterJobArgs {
#[serde(default)]
pub duration: Option<DurationString>,

#[arg(
long,
env = "SOS_MODE",
value_name = "MODE",
value_enum,
default_value_t = Mode::default(),
)]
#[serde(default)]
pub mode: Mode,

#[arg(
long,
env = "SOS_NO_PROGRESS_BAR",
@@ -229,6 +239,7 @@ impl Default for LoadTesterJobArgs {
fn default() -> Self {
Self {
duration: None,
mode: Mode::default(),
no_progress_bar: Self::default_no_progress_bar(),
threads_max: Self::default_threads_max(),
}
@@ -247,6 +258,7 @@ impl LoadTesterJobArgs {
fn print(&self) {
let Self {
duration,
mode,
no_progress_bar,
threads_max,
} = self;
@@ -258,11 +270,32 @@ impl LoadTesterJobArgs {
.map(ToString::to_string)
.unwrap_or_else(|| "None".into(),)
);
info!("mode: {mode:?}");
info!("no_progress_bar: {no_progress_bar}");
info!("threads_max: {threads_max}");
}
}

#[derive(
Copy,
Clone,
Debug,
Default,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Serialize,
Deserialize,
ValueEnum,
)]
pub enum Mode {
Read,
#[default]
Write,
}

#[derive(Clone, Debug, PartialEq, Parser, Serialize, Deserialize)]
#[clap(rename_all = "kebab-case")]
#[serde(rename_all = "camelCase")]
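For context on the new flag: --mode (env: SOS_MODE) is a clap ValueEnum, so it accepts the kebab-case values "read" and "write" on the command line and falls back to write when unset. A minimal, self-contained sketch of the parsing behavior, assuming clap with the derive and env features enabled; the Args struct and binary name are hypothetical stand-ins, not the crate's actual LoadTesterJobArgs:

// Sketch only: assumes clap with the "derive" and "env" features.
use clap::{Parser, ValueEnum};

#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, ValueEnum)]
enum Mode {
    Read,
    #[default]
    Write,
}

// Hypothetical stand-in for the crate's LoadTesterJobArgs.
#[derive(Parser)]
struct Args {
    // ValueEnum maps variants to kebab-case values, so this parses "read" or "write".
    #[arg(long, env = "SOS_MODE", value_name = "MODE", value_enum, default_value_t = Mode::default())]
    mode: Mode,
}

fn main() {
    // e.g. `app --mode read`, `SOS_MODE=read app`, or no flag at all (defaults to Write).
    let args = Args::parse();
    println!("mode: {:?}", args.mode);
}

Note that default_value_t works here without a Display impl because clap renders value-enum defaults through ValueEnum::to_possible_value.
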
198 changes: 144 additions & 54 deletions src/session.rs
@@ -2,7 +2,7 @@ use std::{
convert::identity,
fmt::Write,
sync::{
atomic::{AtomicU64, Ordering},
atomic::{AtomicU64, AtomicU8, Ordering},
Arc,
},
time::{Duration, Instant},
@@ -19,7 +19,7 @@ use s3::{serde_types::InitiateMultipartUploadResponse, Bucket, BucketConfigurati
use tokio::{spawn, task::JoinHandle, time::sleep};
use tracing::{error, info};

use crate::args::{Args, LoadTesterArgs, LoadTesterJobArgs};
use crate::args::{Args, LoadTesterArgs, LoadTesterJobArgs, Mode};

pub struct ObjectStorageSession {
bucket: Bucket,
@@ -99,13 +99,15 @@ impl ObjectStorageSession {
load_tester_job:
LoadTesterJobArgs {
duration,
mode,
no_progress_bar,
threads_max,
},
} = self;

let duration = duration.map(Into::into);
let counter = Arc::<AtomicU64>::default();
let state = Arc::<AtomicU8>::default();

let task_handler = (0..threads_max)
.map(|id| SessionTask {
@@ -114,7 +116,9 @@
counter: counter.clone(),
duration,
id,
mode,
signal: signal.clone(),
state: state.clone(),
total_tasks: threads_max,
})
.map(|task| {
@@ -161,6 +165,9 @@
.progress_chars("#>-");
pb.set_style(style);

while state.load(Ordering::SeqCst) != SessionTask::STATE_READY {
sleep(Duration::from_millis(10)).await;
}
loop {
let progressed = counter.load(Ordering::SeqCst);
pb.set_position(progressed * size.as_u64());
@@ -189,37 +196,41 @@ struct SessionTask {
counter: Arc<AtomicU64>,
duration: Option<Duration>,
id: usize,
mode: Mode,
signal: FunctionSignal,
state: Arc<AtomicU8>,
total_tasks: usize,
}

impl SessionTask {
const CONTENT_TYPE: &'static str = "application/octet-stream";

const STATE_PENDING: u8 = 0;
const STATE_INIT: u8 = 1;
const STATE_READY: u8 = 2;

async fn try_loop_forever(self) -> Result<()> {
let Self {
args:
LoadTesterArgs {
count,
multipart_threshold,
multipart_threshold: _,
size,
step,
},
bucket,
counter,
bucket: _,
counter: _,
duration,
id,
mode,
signal,
state,
total_tasks,
} = self;

let instant = Instant::now();
} = &self;

let content_type = "application/octet-stream";
let count = count.map(|count| count.as_u64() as usize);
let multipart_minimal = LoadTesterArgs::minimal_multipart_threshold().as_u64() as usize;
let multipart_threshold = multipart_threshold.as_u64() as usize;
let size = size.as_u64() as usize;
let step = step.as_u64() as usize;
let use_multipart = size > multipart_threshold;

info!("Creating buffer map: {id}/{total_tasks}");
let mut buf = vec![0; size + step];
@@ -228,9 +239,32 @@ impl SessionTask {
rng.fill_bytes(&mut buf);
}

let mut index = id;
if state
.compare_exchange(
Self::STATE_PENDING,
Self::STATE_INIT,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
info!("Initializing mode: {mode:?}");
match mode {
Mode::Read => self.init_read(&buf).await?,
Mode::Write => self.init_write().await?,
}
state.store(Self::STATE_READY, Ordering::SeqCst);
} else {
while state.load(Ordering::SeqCst) != Self::STATE_READY {
sleep(Duration::from_millis(10)).await;
}
}

info!("Starting task: {id}/{total_tasks}");

let mut index = *id;
let instant = Instant::now();

loop {
if signal.is_terminating() {
break;
@@ -240,66 +274,118 @@
break;
}
}
if let Some(duration) = duration {
if let Some(duration) = *duration {
if instant.elapsed() >= duration {
break;
}
}

let index = {
let ret = index;
index += total_tasks;
index += *total_tasks;
ret % step
};
let path = format!("/sample/{index:06}.bin");
match mode {
Mode::Read => self.read(index, true).await?,
Mode::Write => self.write(index, &buf, true).await?,
}
}

let data = &buf[index..index + size];
if use_multipart {
let InitiateMultipartUploadResponse { upload_id, .. } = bucket
.initiate_multipart_upload(&path, content_type)
.await?;
info!("Stopped task: {id}/{total_tasks}");
Ok(())
}

let mut chunks = vec![];
{
let mut pos = 0;
let len = data.len();
while pos < len {
let pos_next = pos + multipart_threshold;
let remaining = len - pos_next;

let pos_next = if remaining >= multipart_minimal {
pos_next
} else {
len
};

let chunk = &data[pos..pos_next];
chunks.push(chunk);
pos = pos_next;
}
}
async fn init_read(&self, buf: &[u8]) -> Result<()> {
let step = self.args.step.as_u64() as usize;

let mut parts = vec![];
for (part_number, reader) in chunks.iter_mut().enumerate() {
let part_number = (part_number + 1).try_into()?;
for index in 0..step {
self.write(index, buf, false).await?;
}
Ok(())
}

async fn init_write(&self) -> Result<()> {
Ok(())
}

let part = bucket
.put_multipart_stream(reader, &path, part_number, &upload_id, content_type)
.await?;
parts.push(part);
async fn read(&self, index: usize, add_counter: bool) -> Result<()> {
let size = self.args.size.as_u64() as usize;

let path = get_s3_path(index);

let response = self.bucket.get_object(&path).await?;
// assert_eq!(response.bytes().len(), size);
drop(response);

if add_counter {
self.counter.fetch_add(1, Ordering::SeqCst);
}
Ok(())
}

async fn write(&self, index: usize, buf: &[u8], add_counter: bool) -> Result<()> {
let multipart_minimal = LoadTesterArgs::minimal_multipart_threshold().as_u64() as usize;
let multipart_threshold = self.args.multipart_threshold.as_u64() as usize;
let size = self.args.size.as_u64() as usize;
let use_multipart = size > multipart_threshold;

let path = get_s3_path(index);

let data = &buf[index..index + size];
if use_multipart {
let InitiateMultipartUploadResponse { upload_id, .. } = self
.bucket
.initiate_multipart_upload(&path, Self::CONTENT_TYPE)
.await?;

let mut chunks = vec![];
{
let mut pos = 0;
let len = data.len();
while pos < len {
let pos_next = pos + multipart_threshold;
let remaining = len - pos_next;

let pos_next = if remaining >= multipart_minimal {
pos_next
} else {
len
};

let chunk = &data[pos..pos_next];
chunks.push(chunk);
pos = pos_next;
}
}

bucket
.complete_multipart_upload(&path, &upload_id, parts)
let mut parts = vec![];
for (part_number, reader) in chunks.iter_mut().enumerate() {
let part_number = (part_number + 1).try_into()?;

let part = self
.bucket
.put_multipart_stream(
reader,
&path,
part_number,
&upload_id,
Self::CONTENT_TYPE,
)
.await?;
} else {
let mut reader = data;
bucket.put_object_stream(&mut reader, &path).await?;
parts.push(part);
}
counter.fetch_add(1, Ordering::SeqCst);

self.bucket
.complete_multipart_upload(&path, &upload_id, parts)
.await?;
} else {
let mut reader = data;
self.bucket.put_object_stream(&mut reader, &path).await?;
}

info!("Stopped task: {id}/{total_tasks}");
if add_counter {
self.counter.fetch_add(1, Ordering::SeqCst);
}
Ok(())
}
}
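
A note on the multipart split that write() carries over from the old upload path: it walks the buffer in multipart_threshold-sized steps and folds any tail shorter than multipart_minimal (the floor S3 imposes on non-final parts) into the last chunk. The same layout logic as a standalone sketch; the one deliberate change is saturating_sub in the remaining computation, since a plain len - pos_next can underflow usize once pos_next overshoots len:

fn split_chunks(data: &[u8], threshold: usize, minimal: usize) -> Vec<&[u8]> {
    let mut chunks = Vec::new();
    let len = data.len();
    let mut pos = 0;
    while pos < len {
        let mut pos_next = pos + threshold;
        // saturating_sub avoids the usize underflow a plain `len - pos_next`
        // hits when pos_next overshoots len.
        let remaining = len.saturating_sub(pos_next);
        if remaining < minimal {
            pos_next = len;
        }
        chunks.push(&data[pos..pos_next]);
        pos = pos_next;
    }
    chunks
}

fn main() {
    // threshold = 10, minimal = 5 on 23 bytes: expect one chunk of 10,
    // then the 3-byte tail folded into a final chunk of 13.
    let data = vec![0u8; 23];
    for chunk in split_chunks(&data, 10, 5) {
        println!("chunk of {} bytes", chunk.len());
    }
}
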
@@ -341,3 +427,7 @@ async fn cleanup(bucket: Bucket) -> Result<()> {
// .map_err(|error| anyhow!("failed to cleanup bucket: {error}"))
// .await
}

fn get_s3_path(index: usize) -> String {
format!("/sample/{index:06}.bin")
}
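
The three-state handshake (STATE_PENDING -> STATE_INIT -> STATE_READY) is the heart of the commit: exactly one worker wins the compare_exchange and seeds the bucket (init_read pre-writes the step objects that read mode will later fetch), every other worker spins until the state flips to READY, and only then does each task start its clock. The same coordination pattern reduced to a std-only sketch, with threads standing in for the tokio tasks:

use std::{
    sync::{
        atomic::{AtomicU8, Ordering},
        Arc,
    },
    thread,
    time::Duration,
};

const STATE_PENDING: u8 = 0;
const STATE_INIT: u8 = 1;
const STATE_READY: u8 = 2;

fn main() {
    let state = Arc::new(AtomicU8::new(STATE_PENDING));

    let handles: Vec<_> = (0..4)
        .map(|id| {
            let state = Arc::clone(&state);
            thread::spawn(move || {
                // Exactly one task wins the PENDING -> INIT race and runs setup.
                if state
                    .compare_exchange(
                        STATE_PENDING,
                        STATE_INIT,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
                {
                    println!("task {id}: initializing");
                    thread::sleep(Duration::from_millis(50)); // stand-in for init_read/init_write
                    state.store(STATE_READY, Ordering::SeqCst);
                } else {
                    // Everyone else polls until the winner flips the state to READY.
                    while state.load(Ordering::SeqCst) != STATE_READY {
                        thread::sleep(Duration::from_millis(10));
                    }
                }
                println!("task {id}: running");
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
}
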
