Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use reference type in DataFrame #740

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions crates/curp/src/server/storage/wal/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ trait FrameEncoder {
#[derive(Debug)]
pub(super) struct WAL<C, H = Sha256> {
/// Frames stored in decoding
frames: Vec<DataFrame<C>>,
frames: Vec<DataFrameOwned<C>>,
/// The hasher state for decoding
hasher: H,
}
Expand All @@ -48,7 +48,7 @@ pub(super) struct WAL<C, H = Sha256> {
#[derive(Debug)]
enum WALFrame<C> {
/// Data frame type
Data(DataFrame<C>),
Data(DataFrameOwned<C>),
/// Commit frame type
Commit(CommitFrame),
}
Expand All @@ -58,13 +58,25 @@ enum WALFrame<C> {
/// Contains either a log entry or a seal index
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DataFrame<C> {
pub(crate) enum DataFrameOwned<C> {
/// A Frame containing a log entry
Entry(LogEntry<C>),
/// A Frame containing the sealed index
SealIndex(LogIndex),
}

/// The data frame
///
/// Contains either a log entry or a seal index
#[derive(Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) enum DataFrame<'a, C> {
/// A Frame containing a log entry
Entry(&'a LogEntry<C>),
/// A Frame containing the sealed index
SealIndex(LogIndex),
}

/// The commit frame
///
/// This frames contains a SHA256 checksum of all previous frames since last commit
Expand Down Expand Up @@ -98,14 +110,14 @@ impl<C> WAL<C> {
}
}

impl<C> Encoder<Vec<DataFrame<C>>> for WAL<C>
impl<C> Encoder<Vec<DataFrame<'_, C>>> for WAL<C>
where
C: Serialize,
{
type Error = io::Error;

/// Encodes a frame
fn encode(&mut self, frames: Vec<DataFrame<C>>) -> Result<Vec<u8>, Self::Error> {
fn encode(&mut self, frames: Vec<DataFrame<'_, C>>) -> Result<Vec<u8>, Self::Error> {
let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frame_data);
frame_data.extend_from_slice(&commit_frame.encode());
Expand All @@ -118,7 +130,7 @@ impl<C> Decoder for WAL<C>
where
C: Serialize + DeserializeOwned,
{
type Item = Vec<DataFrame<C>>;
type Item = Vec<DataFrameOwned<C>>;

type Error = WALError;

Expand Down Expand Up @@ -208,14 +220,14 @@ where
let entry: LogEntry<C> = bincode::deserialize(payload)
.map_err(|e| WALError::Corrupted(CorruptType::Codec(e.to_string())))?;

Ok(Some((Self::Data(DataFrame::Entry(entry)), 8 + len)))
Ok(Some((Self::Data(DataFrameOwned::Entry(entry)), 8 + len)))
}

/// Decodes an seal index frame from source
fn decode_seal_index(header: [u8; 8]) -> Result<Option<(Self, usize)>, WALError> {
let index = Self::decode_u64_from_header(header);

Ok(Some((Self::Data(DataFrame::SealIndex(index)), 8)))
Ok(Some((Self::Data(DataFrameOwned::SealIndex(index)), 8)))
}

/// Decodes a commit frame from source
Expand All @@ -239,7 +251,17 @@ where
}
}

impl<C> FrameType for DataFrame<C> {
impl<C> DataFrameOwned<C> {
/// Converts `DataFrameOwned` to `DataFrame`
pub(super) fn get_ref(&self) -> DataFrame<'_, C> {
match *self {
DataFrameOwned::Entry(ref entry) => DataFrame::Entry(entry),
DataFrameOwned::SealIndex(index) => DataFrame::SealIndex(index),
}
}
}

impl<C> FrameType for DataFrame<'_, C> {
fn frame_type(&self) -> u8 {
match *self {
DataFrame::Entry(_) => ENTRY,
Expand All @@ -248,7 +270,7 @@ impl<C> FrameType for DataFrame<C> {
}
}

impl<C> FrameEncoder for DataFrame<C>
impl<C> FrameEncoder for DataFrame<'_, C>
where
C: Serialize,
{
Expand Down Expand Up @@ -322,31 +344,31 @@ mod tests {
async fn frame_encode_decode_is_ok() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap());
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame.get_ref()]).unwrap());

let (data_frame_get, len) = codec.decode(&encoded).unwrap();
let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap();
let DataFrame::Entry(ref entry_get) = data_frame_get[0] else {
let DataFrameOwned::Entry(ref entry_get) = data_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};
let DataFrame::SealIndex(ref index) = seal_frame_get[0] else {
let DataFrameOwned::SealIndex(index) = seal_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};

assert_eq!(*entry_get, entry);
assert_eq!(*index, 1);
assert_eq!(index, 1);
}

#[tokio::test]
async fn frame_zero_write_will_be_detected() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded[0] = 0;

let err = codec.decode(&encoded).unwrap_err();
Expand All @@ -357,9 +379,9 @@ mod tests {
async fn frame_corrupt_will_be_detected() {
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
let data_frame = DataFrameOwned::Entry(entry.clone());
let seal_frame = DataFrameOwned::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame.get_ref()]).unwrap();
encoded[1] = 0;

let err = codec.decode(&encoded).unwrap_err();
Expand Down
119 changes: 73 additions & 46 deletions crates/curp/src/server/storage/wal/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
Arc,
},
task::Poll,
thread::JoinHandle,
};

use clippy_utilities::OverflowArithmetic;
Expand All @@ -28,77 +29,94 @@ pub(super) struct FilePipeline {
///
/// As tokio::fs is generally slower than std::fs, we use synchronous file allocation.
/// Please also refer to the issue discussed on the tokio repo: https://github.com/tokio-rs/tokio/issues/3664
file_iter: flume::IntoIter<LockedFile>,
file_iter: Option<flume::IntoIter<LockedFile>>,
/// Stopped flag
stopped: Arc<AtomicBool>,
/// Join handle of the allocation task
file_alloc_task_handle: Option<JoinHandle<()>>,
}

impl FilePipeline {
/// Creates a new `FilePipeline`
pub(super) fn new(dir: PathBuf, file_size: u64) -> io::Result<Self> {
Self::clean_up(&dir)?;
pub(super) fn new(dir: PathBuf, file_size: u64) -> Self {
if let Err(e) = Self::clean_up(&dir) {
error!("Failed to clean up tmp files: {e}");
}

let (file_tx, file_rx) = flume::bounded(1);
let dir_c = dir.clone();
let stopped = Arc::new(AtomicBool::new(false));
let stopped_c = Arc::clone(&stopped);

#[cfg(not(madsim))]
let _ignore = std::thread::spawn(move || {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send(file).is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
{
let file_alloc_task_handle = std::thread::spawn(move || {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send(file).is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
}
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
});

Self {
dir,
file_size,
file_iter: Some(file_rx.into_iter()),
stopped,
file_alloc_task_handle: Some(file_alloc_task_handle),
}
});
}

#[cfg(madsim)]
let _ignore = tokio::spawn(async move {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send_async(file).await.is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
{
let _ignore = tokio::spawn(async move {
let mut file_count = 0;
loop {
match Self::alloc(&dir_c, file_size, &mut file_count) {
Ok(file) => {
if file_tx.send_async(file).await.is_err() {
// The receiver is already dropped, stop this task
break;
}
if stopped_c.load(Ordering::Relaxed) {
if let Err(e) = Self::clean_up(&dir_c) {
error!("failed to clean up pipeline temp files: {e}");
}
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
Err(e) => {
error!("failed to allocate file: {e}");
break;
}
}
});

Self {
dir,
file_size,
file_iter: Some(file_rx.into_iter()),
stopped,
file_alloc_task_handle: None,
}
});

Ok(Self {
dir,
file_size,
file_iter: file_rx.into_iter(),
stopped,
})
}
}

/// Stops the pipeline
Expand Down Expand Up @@ -132,6 +150,11 @@ impl FilePipeline {
impl Drop for FilePipeline {
fn drop(&mut self) {
self.stop();
// Drops the file rx so that the allocation task could exit
drop(self.file_iter.take());
if let Some(Err(e)) = self.file_alloc_task_handle.take().map(JoinHandle::join) {
error!("failed to join file allocation task: {e:?}");
}
}
}

Expand All @@ -142,7 +165,11 @@ impl Iterator for FilePipeline {
if self.stopped.load(Ordering::Relaxed) {
return None;
}
self.file_iter.next().map(Ok)
self.file_iter
.as_mut()
.unwrap_or_else(|| unreachable!("Option is always `Some`"))
.next()
.map(Ok)
}
}

Expand All @@ -164,7 +191,7 @@ mod tests {
async fn file_pipeline_is_ok() {
let file_size = 1024;
let dir = tempfile::tempdir().unwrap();
let mut pipeline = FilePipeline::new(dir.as_ref().into(), file_size).unwrap();
let mut pipeline = FilePipeline::new(dir.as_ref().into(), file_size);

let check_size = |mut file: LockedFile| {
let file = file.into_std();
Expand Down
Loading
Loading