Make the Sorter able to sort in parallel #42

Draft: wants to merge 4 commits into main
Changes from 2 commits
4 changes: 4 additions & 0 deletions .github/workflows/rust.yml
@@ -34,6 +34,10 @@ jobs:
command: test
args: --all-features

- uses: actions-rs/cargo@v1
with:
command: test

lint:
runs-on: ubuntu-latest
steps:
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -11,8 +11,10 @@ license = "MIT"
[dependencies]
bytemuck = { version = "1.7.0", features = ["derive"] }
byteorder = "1.3.4"
crossbeam-channel = "0.5.8"
flate2 = { version = "1.0", optional = true }
lz4_flex = { version = "0.9.2", optional = true }
rayon = { version = "1.7.0", optional = true }
snap = { version = "1.0.5", optional = true }
tempfile = { version = "3.2.0", optional = true }
zstd = { version = "0.10.0", optional = true }
20 changes: 17 additions & 3 deletions src/lib.rs
@@ -2,7 +2,21 @@
//! The entries in the grenad files are _immutable_ and the only way to modify them is by _creating
//! a new file_ with the changes.
//!
//! # Example: Use the `Writer` and `Reader` structs
//! # Features
//!
//! You can choose which compression schemes to support; there are currently a few
//! available choices, and these determine which types will be available inside the above modules:
//!
//! - _Snappy_ with the [`snap`](https://crates.io/crates/snap) crate.
//! - _Zlib_ with the [`flate2`](https://crates.io/crates/flate2) crate.
//! - _Lz4_ with the [`lz4_flex`](https://crates.io/crates/lz4_flex) crate.
//!
//! If you need more performance you can enable the `rayon` feature, which unlocks
//! additional settings such as making the `Sorter` sort in parallel.
//!
//! # Examples
//!
//! ## Use the `Writer` and `Reader` structs
//!
//! You can use the [`Writer`] struct to store key-value pairs into the specified
//! [`std::io::Write`] type. The [`Reader`] type can then be used to read the entries.
@@ -37,7 +51,7 @@
//! # Ok(()) }
//! ```
//!
//! # Example: Use the `Merger` struct
//! ## Use the `Merger` struct
//!
//! In this example we show how you can merge multiple [`Reader`]s
//! by using a _merge function_ when a conflict is encountered.
@@ -107,7 +121,7 @@
//! # Ok(()) }
//! ```
//!
//! # Example: Use the `Sorter` struct
//! ## Use the `Sorter` struct
//!
//! In this example we show how by defining a _merge function_, we can insert
//! multiple entries with the same key and output them in lexicographic order.
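A downstream crate would opt into the features discussed above from its own manifest; a minimal sketch (the version number is illustrative only, not taken from this PR):

```toml
[dependencies]
# "rayon" enables the parallel-sorting settings added by this PR;
# "tempfile" enables the default on-disk chunk creator.
grenad = { version = "0.4", features = ["tempfile", "rayon"] }
```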
166 changes: 164 additions & 2 deletions src/sorter.rs
@@ -1,14 +1,19 @@
use std::alloc::{alloc, dealloc, Layout};
use std::borrow::Cow;
use std::collections::hash_map::DefaultHasher;
use std::convert::Infallible;
#[cfg(feature = "tempfile")]
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use std::iter::repeat_with;
use std::mem::{align_of, size_of};
use std::num::NonZeroUsize;
use std::thread::{self, JoinHandle};
use std::{cmp, io, ops, slice};

use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
use crossbeam_channel::{unbounded, Sender};

use crate::count_write::CountWrite;

@@ -47,6 +52,7 @@ pub struct SorterBuilder<MF, CC> {
index_levels: Option<u8>,
chunk_creator: CC,
sort_algorithm: SortAlgorithm,
sort_in_parallel: bool,
merge: MF,
}

@@ -65,6 +71,7 @@ impl<MF> SorterBuilder<MF, DefaultChunkCreator> {
index_levels: None,
chunk_creator: DefaultChunkCreator::default(),
sort_algorithm: SortAlgorithm::Stable,
sort_in_parallel: false,
merge,
}
}
@@ -142,6 +149,15 @@ impl<MF, CC> SorterBuilder<MF, CC> {
self
}

/// Whether we use [rayon to sort](https://docs.rs/rayon/latest/rayon/slice/trait.ParallelSliceMut.html#method.par_sort_by_key) the entries.
///
/// By default we do not sort in parallel; the value is `false`.
#[cfg(feature = "rayon")]
pub fn sort_in_parallel(&mut self, value: bool) -> &mut Self {
self.sort_in_parallel = value;
self
}

/// The [`ChunkCreator`] struct used to generate the chunks used
/// by the [`Sorter`] to bufferize when required.
pub fn chunk_creator<CC2>(self, creation: CC2) -> SorterBuilder<MF, CC2> {
@@ -156,6 +172,7 @@ impl<MF, CC> SorterBuilder<MF, CC> {
index_levels: self.index_levels,
chunk_creator: creation,
sort_algorithm: self.sort_algorithm,
sort_in_parallel: self.sort_in_parallel,
merge: self.merge,
}
}
@@ -181,9 +198,48 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
index_levels: self.index_levels,
chunk_creator: self.chunk_creator,
sort_algorithm: self.sort_algorithm,
sort_in_parallel: self.sort_in_parallel,
merge: self.merge,
}
}

/// Creates a [`ParallelSorter`] configured by this builder.
///
/// Indicates the number of different [`Sorter`]s you want to use
/// to balance the sorting load.
pub fn build_in_parallel<U>(self, number: NonZeroUsize) -> ParallelSorter<MF, U, CC>
where
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>
+ Clone
+ Send
+ 'static,
U: Send + 'static,
CC: Clone + Send + 'static,
CC::Chunk: Send + 'static,
{
match number.get() {
1 => ParallelSorter::Single(self.build()),
number => {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We must do number - 1 as if we want only 2 threads we will result in the main sending thread and only one sorter which is useless, lets us keep only the single sorter when the number is equal to 2.

let (senders, receivers): (Vec<Sender<(usize, Vec<u8>)>>, Vec<_>) =
repeat_with(unbounded).take(number).unzip();

let mut handles = Vec::new();
for receiver in receivers {
let sorter_builder = self.clone();
handles.push(thread::spawn(move || {
let mut sorter = sorter_builder.build();
for (key_length, data) in receiver {
let (key, val) = data.split_at(key_length);
sorter.insert(key, val)?;
}
sorter.into_reader_cursors().map_err(Into::into)
}));
}

ParallelSorter::Multi { senders, handles, merge_function: self.merge }
}
}
}
}
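The fan-out in `build_in_parallel` can be sketched in isolation. This is a hypothetical stand-in that uses `std::sync::mpsc` in place of crossbeam-channel so it stays self-contained, and whose workers simply collect the pairs instead of feeding a real `Sorter`:

```rust
use std::sync::mpsc::{channel, Sender};
use std::thread;

// One channel per worker thread; each worker drains its own receiver.
// The real workers insert into a `Sorter`; here they just collect.
fn spawn_workers(
    n: usize,
) -> (Vec<Sender<(usize, Vec<u8>)>>, Vec<thread::JoinHandle<Vec<(Vec<u8>, Vec<u8>)>>>) {
    let mut senders = Vec::new();
    let mut handles = Vec::new();
    for _ in 0..n {
        let (tx, rx) = channel::<(usize, Vec<u8>)>();
        senders.push(tx);
        handles.push(thread::spawn(move || {
            rx.into_iter()
                .map(|(key_len, data)| {
                    // Same recovery as in the PR: split the shared buffer
                    // back into key and value using the transmitted length.
                    let (key, val) = data.split_at(key_len);
                    (key.to_vec(), val.to_vec())
                })
                .collect()
        }));
    }
    (senders, handles)
}
```

Dropping all the senders ends each worker's iteration, mirroring how the real implementation shuts down.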

/// Stores entries memory efficiently in a buffer.
@@ -281,6 +337,27 @@ impl Entries {
sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
}

/// Sorts the entry bounds by the entries' keys **in parallel**;
/// after a sort the `iter` method will yield the entries sorted.
#[cfg(feature = "rayon")]
pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
use rayon::slice::ParallelSliceMut;

let bounds_end = self.bounds_count * size_of::<EntryBound>();
let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
let bounds = cast_slice_mut::<_, EntryBound>(bounds);
let sort = match algorithm {
SortAlgorithm::Stable => <[EntryBound]>::par_sort_by_key,
SortAlgorithm::Unstable => <[EntryBound]>::par_sort_unstable_by_key,
};
sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
}

/// Without the `rayon` feature this falls back to the sequential sort.
#[cfg(not(feature = "rayon"))]
pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
self.sort_by_key(algorithm);
}
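The function-pointer dispatch shared by `sort_by_key` and `par_sort_by_key` can be illustrated without rayon. A minimal sketch using std's sequential sorts and a stand-in `(key, value)` bound type (both are assumptions, not this PR's types):

```rust
// Both match arms coerce to one fn-pointer type, so the algorithm is
// chosen once and the sorting call is written a single time.
#[derive(Clone, Copy)]
enum SortAlgorithm {
    Stable,
    Unstable,
}

fn sort_entries(bounds: &mut [(u32, u32)], algorithm: SortAlgorithm) {
    let sort: fn(&mut [(u32, u32)], fn(&(u32, u32)) -> u32) = match algorithm {
        // Non-capturing closures coerce to plain function pointers.
        SortAlgorithm::Stable => |slice, key_fn| slice.sort_by_key(key_fn),
        SortAlgorithm::Unstable => |slice, key_fn| slice.sort_unstable_by_key(key_fn),
    };
    // Sort the bounds by key; the payload never moves, as in `Entries`.
    sort(bounds, |&(key, _)| key);
}
```

With the `rayon` feature the same pattern swaps in `par_sort_by_key` and `par_sort_unstable_by_key` from `rayon::slice::ParallelSliceMut`.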

/// Returns an iterator over the keys and datas.
pub fn iter(&self) -> impl Iterator<Item = (&[u8], &[u8])> + '_ {
let bounds_end = self.bounds_count * size_of::<EntryBound>();
@@ -399,6 +476,7 @@ pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
index_levels: Option<u8>,
chunk_creator: CC,
sort_algorithm: SortAlgorithm,
sort_in_parallel: bool,
merge: MF,
}

@@ -489,7 +567,11 @@ where
}
let mut writer = writer_builder.build(count_write_chunk);

self.entries.sort_by_key(self.sort_algorithm);
if self.sort_in_parallel {
self.entries.par_sort_by_key(self.sort_algorithm);
} else {
self.entries.sort_by_key(self.sort_algorithm);
}

let mut current = None;
for (key, value) in self.entries.iter() {
@@ -509,7 +591,7 @@

if let Some((key, vals)) = current.take() {
let merged_val = (self.merge)(key, &vals).map_err(Error::Merge)?;
writer.insert(&key, &merged_val)?;
writer.insert(key, &merged_val)?;
}

// We retrieve the wrapped CountWrite and extract
@@ -630,6 +712,86 @@ where
}
}

// TODO Make this private by wrapping it
pub enum ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>
where
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
{
Single(Sorter<MF, CC>),
Multi {
// Holds the length of the key and a buffer containing the key followed by the data.
senders: Vec<Sender<(usize, Vec<u8>)>>,
handles: Vec<JoinHandle<Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>>>>,
merge_function: MF,
},
}

impl<MF, U, CC> ParallelSorter<MF, U, CC>
where
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
CC: ChunkCreator,
{
/// Inserts an entry into the sorter, making sure that conflicts
/// are resolved by the provided merge function.
pub fn insert<K, V>(&mut self, key: K, val: V) -> Result<(), Error<U>>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let key = key.as_ref();
let val = val.as_ref();
match self {
ParallelSorter::Single(sorter) => sorter.insert(key, val),
ParallelSorter::Multi { senders, .. } => {
let key_length = key.len();
let key_hash = compute_hash(key);

// We put the key and val into the same allocation to speed things up
// by reducing the amount of calls to the allocator.
//
// TODO test that it works for real because having a bigger allocation
// can make it harder to find the space.
let mut data = Vec::with_capacity(key.len() + val.len());
data.extend_from_slice(key);
data.extend_from_slice(val);

let index = (key_hash % senders.len() as u64) as usize;
// TODO remove unwraps
senders[index].send((key_length, data)).unwrap();

Ok(())
}
}
}

/// Consumes this [`ParallelSorter`] and outputs a stream of the merged entries in key-order.
pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
match self {
ParallelSorter::Single(sorter) => sorter.into_stream_merger_iter(),
ParallelSorter::Multi { senders, handles, merge_function } => {
drop(senders);

let mut sources = Vec::new();
for handle in handles {
// TODO remove unwraps
sources.extend(handle.join().unwrap()?);
}

let mut builder = Merger::builder(merge_function);
builder.extend(sources);
builder.build().into_stream_merger_iter().map_err(Error::convert_merge_error)
}
}
}
}
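The shutdown sequence above relies on channel disconnection. A minimal sketch with `std::sync::mpsc` (crossbeam-channel behaves the same way once every `Sender` is dropped):

```rust
use std::sync::mpsc::channel;
use std::thread;

// Dropping every sender disconnects the channel, which terminates the
// worker's `for … in receiver` loop so `join` can return its result.
fn drain_worker() -> u64 {
    let (sender, receiver) = channel::<u64>();
    let handle = thread::spawn(move || {
        let mut sum = 0;
        for value in receiver {
            sum += value;
        }
        sum
    });
    sender.send(1).unwrap();
    sender.send(2).unwrap();
    drop(sender); // no senders left: the receiving loop ends
    handle.join().unwrap()
}
```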

/// Computes the hash of a key.
fn compute_hash(key: &[u8]) -> u64 {
let mut state = DefaultHasher::new();
key.hash(&mut state);
state.finish()
}
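Routing by hash works because `DefaultHasher::new()` is deterministic within a process, so a given key always lands on the same worker and all conflicting entries for a key meet in the same `Sorter`. A small sketch with a hypothetical `worker_index` helper mirroring the modulo in `insert`:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Same scheme as `compute_hash` above.
fn compute_hash(key: &[u8]) -> u64 {
    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    state.finish()
}

// Maps a key to one of `workers` sorter threads, as `insert` does with
// `key_hash % senders.len()`.
fn worker_index(key: &[u8], workers: usize) -> usize {
    (compute_hash(key) % workers as u64) as usize
}
```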

/// A trait that represent a `ChunkCreator`.
pub trait ChunkCreator {
/// The generated chunk by this `ChunkCreator`.