Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: switch wal to sync implementation #705

Merged
merged 7 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 46 additions & 74 deletions crates/curp/src/server/storage/wal/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use curp_external_api::LogIndex;
use serde::{de::DeserializeOwned, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
use tokio_util::codec::{Decoder, Encoder};

use super::{
error::{CorruptType, WALError},
framed::{Decoder, Encoder},
util::{get_checksum, validate_data},
};
use crate::log_entry::LogEntry;
Expand Down Expand Up @@ -104,18 +104,13 @@ where
{
type Error = io::Error;

fn encode(
&mut self,
frames: Vec<DataFrame<C>>,
dst: &mut bytes::BytesMut,
) -> Result<(), Self::Error> {
let frames_bytes: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frames_bytes);
/// Encodes a frame
fn encode(&mut self, frames: Vec<DataFrame<C>>) -> Result<Vec<u8>, Self::Error> {
let mut frame_data: Vec<_> = frames.into_iter().flat_map(|f| f.encode()).collect();
let commit_frame = CommitFrame::new_from_data(&frame_data);
frame_data.extend_from_slice(&commit_frame.encode());

dst.extend(frames_bytes);
dst.extend(commit_frame.encode());

Ok(())
Ok(frame_data)
}
}

Expand All @@ -127,30 +122,32 @@ where

type Error = WALError;

fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
if let Some((frame, len)) = WALFrame::<C>::decode(src)? {
let decoded_bytes = src.split_to(len);
match frame {
WALFrame::Data(data) => {
self.frames.push(data);
self.hasher.update(decoded_bytes);
}
WALFrame::Commit(commit) => {
let frames_bytes: Vec<_> =
self.frames.iter().flat_map(DataFrame::encode).collect();
let checksum = self.hasher.clone().finalize();
self.hasher.reset();
if commit.validate(&checksum) {
return Ok(Some(self.frames.drain(..).collect()));
}
return Err(WALError::Corrupted(CorruptType::Checksum));
#[allow(clippy::arithmetic_side_effects)] // the arithmetic only used as slice indices
fn decode(&mut self, src: &[u8]) -> Result<(Self::Item, usize), Self::Error> {
let mut cursor = 0;
while cursor < src.len() {
let next = src.get(cursor..).ok_or(WALError::MaybeEnded)?;
let Some((frame, len)) = WALFrame::<C>::decode(next)? else {
return Err(WALError::MaybeEnded);
};
let decoded_bytes = src.get(cursor..cursor + len).ok_or(WALError::MaybeEnded)?;
cursor += len;
match frame {
WALFrame::Data(data) => {
self.frames.push(data);
self.hasher.update(decoded_bytes);
}
WALFrame::Commit(commit) => {
let checksum = self.hasher.clone().finalize();
self.hasher.reset();
if commit.validate(&checksum) {
return Ok((self.frames.drain(..).collect(), cursor));
}
return Err(WALError::Corrupted(CorruptType::Checksum));
}
} else {
return Ok(None);
}
}
Err(WALError::MaybeEnded)
}
}

Expand Down Expand Up @@ -323,25 +320,19 @@ mod tests {

#[tokio::test]
async fn frame_encode_decode_is_ok() {
let file = TokioFile::from(tempfile().unwrap());
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
framed.send(vec![data_frame]).await.unwrap();
framed.send(vec![seal_frame]).await.unwrap();
framed.get_mut().flush().await;
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded.extend_from_slice(&codec.encode(vec![seal_frame]).unwrap());

let mut file = framed.into_inner();
file.seek(io::SeekFrom::Start(0)).await.unwrap();
let mut framed = Framed::new(file, WAL::<TestCommand>::new());

let data_frame_get = &framed.next().await.unwrap().unwrap()[0];
let seal_frame_get = &framed.next().await.unwrap().unwrap()[0];
let DataFrame::Entry(ref entry_get) = *data_frame_get else {
let (data_frame_get, len) = codec.decode(&encoded).unwrap();
let (seal_frame_get, _) = codec.decode(&encoded[len..]).unwrap();
let DataFrame::Entry(ref entry_get) = data_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};
let DataFrame::SealIndex(ref index) = *seal_frame_get else {
let DataFrame::SealIndex(ref index) = seal_frame_get[0] else {
panic!("frame should be type: DataFrame::Entry");
};

Expand All @@ -351,46 +342,27 @@ mod tests {

#[tokio::test]
async fn frame_zero_write_will_be_detected() {
let file = TokioFile::from(tempfile().unwrap());
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
framed.send(vec![data_frame]).await.unwrap();
framed.get_mut().flush().await;

let mut file = framed.into_inner();
/// zero the first byte, it will reach a success state,
/// all following data will be truncated
file.seek(io::SeekFrom::Start(0)).await.unwrap();
file.write_u8(0).await;

file.seek(io::SeekFrom::Start(0)).await.unwrap();

let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded[0] = 0;

let err = framed.next().await.unwrap().unwrap_err();
let err = codec.decode(&encoded).unwrap_err();
assert!(matches!(err, WALError::MaybeEnded), "error {err} not match");
}

#[tokio::test]
async fn frame_corrupt_will_be_detected() {
let file = TokioFile::from(tempfile().unwrap());
let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let mut codec = WAL::<TestCommand>::new();
let entry = LogEntry::<TestCommand>::new(1, 1, ProposeId(1, 2), EntryData::Empty);
let data_frame = DataFrame::Entry(entry.clone());
framed.send(vec![data_frame]).await.unwrap();
framed.get_mut().flush().await;

let mut file = framed.into_inner();
/// This will cause a failure state
file.seek(io::SeekFrom::Start(1)).await.unwrap();
file.write_u8(0).await;

file.seek(io::SeekFrom::Start(0)).await.unwrap();

let mut framed = Framed::new(file, WAL::<TestCommand>::new());
let seal_frame = DataFrame::<TestCommand>::SealIndex(1);
let mut encoded = codec.encode(vec![data_frame]).unwrap();
encoded[1] = 0;

let err = framed.next().await.unwrap().unwrap_err();
let err = codec.decode(&encoded).unwrap_err();
assert!(
matches!(err, WALError::Corrupted(_)),
"error {err} not match"
Expand Down
22 changes: 22 additions & 0 deletions crates/curp/src/server/storage/wal/framed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::io;

/// Decoding of frames via buffers.
pub(super) trait Decoder {
/// The type of decoded frames.
type Item;

/// The type of unrecoverable frame decoding errors.
type Error: From<io::Error>;

/// Attempts to decode a frame from the provided buffer of bytes.
fn decode(&mut self, src: &[u8]) -> Result<(Self::Item, usize), Self::Error>;
}

/// Trait of helper objects to write out messages as bytes
pub(super) trait Encoder<Item> {
/// The type of encoding errors.
type Error: From<io::Error>;

/// Encodes a frame
fn encode(&mut self, item: Item) -> Result<Vec<u8>, Self::Error>;
}
3 changes: 3 additions & 0 deletions crates/curp/src/server/storage/wal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ mod segment;
/// File utils
mod util;

/// Framed traits
mod framed;

/// The magic of the WAL file
const WAL_MAGIC: u32 = 0xd86e_0be2;

Expand Down
31 changes: 13 additions & 18 deletions crates/curp/src/server/storage/wal/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@ use std::{

use clippy_utilities::OverflowArithmetic;
use event_listener::Event;
use flume::r#async::RecvStream;
use futures::{FutureExt, StreamExt};
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio_stream::Stream;
use tracing::error;

use super::util::LockedFile;
Expand All @@ -28,8 +24,11 @@ pub(super) struct FilePipeline {
dir: PathBuf,
/// The size of the temp file
file_size: u64,
/// The file receive stream
file_stream: RecvStream<'static, LockedFile>,
/// The file receive iterator
///
/// As tokio::fs is generally slower than std::fs, we use synchronous file allocation.
/// Please also refer to the issue discussed on the tokio repo: https://github.com/tokio-rs/tokio/issues/3664
file_iter: flume::IntoIter<LockedFile>,
/// Stopped flag
stopped: Arc<AtomicBool>,
}
Expand Down Expand Up @@ -97,7 +96,7 @@ impl FilePipeline {
Ok(Self {
dir,
file_size,
file_stream: file_rx.into_stream(),
file_iter: file_rx.into_iter(),
stopped,
})
}
Expand Down Expand Up @@ -136,18 +135,14 @@ impl Drop for FilePipeline {
}
}

impl Stream for FilePipeline {
impl Iterator for FilePipeline {
type Item = io::Result<LockedFile>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
fn next(&mut self) -> Option<Self::Item> {
if self.stopped.load(Ordering::Relaxed) {
return Poll::Ready(None);
return None;
}

self.file_stream.poll_next_unpin(cx).map(|opt| opt.map(Ok))
self.file_iter.next().map(Ok)
}
}

Expand Down Expand Up @@ -175,11 +170,11 @@ mod tests {
let file = file.into_std();
assert_eq!(file.metadata().unwrap().len(), file_size,);
};
let file0 = pipeline.next().await.unwrap().unwrap();
let file0 = pipeline.next().unwrap().unwrap();
check_size(file0);
let file1 = pipeline.next().await.unwrap().unwrap();
let file1 = pipeline.next().unwrap().unwrap();
check_size(file1);
pipeline.stop();
assert!(pipeline.next().await.is_none());
assert!(pipeline.next().is_none());
}
}
Loading
Loading