diff --git a/src/compression/lz4.rs b/src/compression/lz4.rs index b4c01fe..a9abc12 100644 --- a/src/compression/lz4.rs +++ b/src/compression/lz4.rs @@ -4,7 +4,7 @@ use std::{ task::{Context, Poll}, }; -use bytes::{Buf, BufMut, Bytes}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::{ready, stream::Stream}; use lz4::liblz4::LZ4_decompress_safe; @@ -71,33 +71,23 @@ where // - [ 1b] magic number (0x82) // - [ 4b] compressed size // - [ 4b] uncompressed size -const LZ4_META_SIZE: usize = 25; +const LZ4_CHECKSUM_SIZE: usize = 16; const LZ4_HEADER_SIZE: usize = 9; +const LZ4_META_SIZE: usize = LZ4_CHECKSUM_SIZE + LZ4_HEADER_SIZE; const LZ4_MAGIC: u8 = 0x82; struct Lz4Meta { checksum: u128, - compressed_size: u32, // TODO: use NonZeroU32? + compressed_size: u32, uncompressed_size: u32, } -impl Lz4Decoder { - pub(crate) fn new(stream: S) -> Self { - Self { - inner: stream, - chunks: BufList::default(), - meta: None, - buffer: Vec::new(), - } - } - - fn read_meta(&mut self) -> Result { - assert!(self.chunks.remaining() >= LZ4_META_SIZE); - - let checksum = self.chunks.get_u128_le(); - let magic = self.chunks.get_u8(); - let compressed_size = self.chunks.get_u32_le(); - let uncompressed_size = self.chunks.get_u32_le(); +impl Lz4Meta { + fn read(mut buffer: impl Buf) -> Result { + let checksum = buffer.get_u128_le(); + let magic = buffer.get_u8(); + let compressed_size = buffer.get_u32_le(); + let uncompressed_size = buffer.get_u32_le(); if magic != LZ4_MAGIC { return Err(Error::Decompression("incorrect magic number".into())); @@ -114,26 +104,48 @@ impl Lz4Decoder { }) } - fn read_data(&mut self, header: Lz4Meta) -> Result { - assert!(self.chunks.remaining() >= header.compressed_size as usize - LZ4_HEADER_SIZE); + fn write_checksum(&self, mut buffer: impl BufMut) { + buffer.put_u128_le(self.checksum); + } - self.buffer.resize(header.compressed_size as usize, 0); + fn write_header(&self, mut buffer: impl BufMut) { + buffer.put_u8(LZ4_MAGIC); + buffer.put_u32_le(self.compressed_size); + buffer.put_u32_le(self.uncompressed_size); + } +} - let compressed = &mut self.buffer[..]; - (&mut compressed[0..]).put_u8(LZ4_MAGIC); - (&mut compressed[1..]).put_u32_le(header.compressed_size); - (&mut compressed[5..]).put_u32_le(header.uncompressed_size); +impl Lz4Decoder { + pub(crate) fn new(stream: S) -> Self { + Self { + inner: stream, + chunks: BufList::default(), + meta: None, + buffer: Vec::new(), + } + } + + fn read_meta(&mut self) -> Result { + assert!(self.chunks.remaining() >= LZ4_META_SIZE); + Lz4Meta::read(&mut self.chunks) + } + + fn read_data(&mut self, meta: Lz4Meta) -> Result { + assert!(self.chunks.remaining() >= meta.compressed_size as usize - LZ4_HEADER_SIZE); + + self.buffer.resize(meta.compressed_size as usize, 0); + meta.write_header(&mut self.buffer[..]); self.chunks - .copy_to_slice(&mut compressed[LZ4_HEADER_SIZE..]); + .copy_to_slice(&mut self.buffer[LZ4_HEADER_SIZE..]); - let actual_checksum = calc_checksum(compressed); - if actual_checksum != header.checksum { + let actual_checksum = calc_checksum(&self.buffer); + if actual_checksum != meta.checksum { return Err(Error::Decompression("checksum mismatch".into())); } - let mut uncompressed = vec![0u8; header.uncompressed_size as usize]; - decompress(&compressed[LZ4_HEADER_SIZE..], &mut uncompressed)?; + let mut uncompressed = vec![0u8; meta.uncompressed_size as usize]; + decompress(&self.buffer[LZ4_HEADER_SIZE..], &mut uncompressed)?; Ok(uncompressed.into()) } } @@ -160,8 +172,37 @@ fn decompress(compressed: &[u8], uncompressed: &mut [u8]) -> Result<()> { Ok(()) } +pub(crate) fn compress(uncompressed: &[u8]) -> Result { + do_compress(uncompressed).map_err(|err| Error::Decompression(err.into())) +} + +fn do_compress(uncompressed: &[u8]) -> std::io::Result { + let max_compressed_size = lz4::block::compress_bound(uncompressed.len())?; + + let mut buffer = BytesMut::new(); + buffer.resize(LZ4_META_SIZE + max_compressed_size, 0); + + // TODO: pass settings. + let compressed_data_size = + lz4::block::compress_to_buffer(uncompressed, None, false, &mut buffer[LZ4_META_SIZE..])?; + + buffer.truncate(LZ4_META_SIZE + compressed_data_size); + + let mut meta = Lz4Meta { + checksum: 0, // will be calculated below. + compressed_size: (LZ4_HEADER_SIZE + compressed_data_size) as u32, + uncompressed_size: uncompressed.len() as u32, + }; + + meta.write_header(&mut buffer[LZ4_CHECKSUM_SIZE..]); + meta.checksum = calc_checksum(&buffer[LZ4_CHECKSUM_SIZE..]); + meta.write_checksum(&mut buffer[..]); + + Ok(buffer.freeze()) +} + #[tokio::test] -async fn it_decompress() { +async fn it_decompresses() { use futures::stream::{self, TryStreamExt}; let expected = vec![ @@ -185,7 +226,7 @@ async fn it_decompress() { ); let mut decoder = Lz4Decoder::new(stream); let actual = decoder.try_next().await.unwrap(); - assert_eq!(actual, Some(Bytes::copy_from_slice(expected))); + assert_eq!(actual.as_deref(), Some(expected)); } // 1 chunk. @@ -203,3 +244,20 @@ async fn it_decompress() { } } } + +#[test] +fn it_compresses() { + let source = vec![ + 1u8, 0, 2, 255, 255, 255, 255, 0, 1, 1, 1, 115, 6, 83, 116, 114, 105, 110, 103, 3, 97, 98, + 99, + ]; + + let expected = vec![ + 245_u8, 5, 222, 235, 225, 158, 59, 108, 225, 31, 65, 215, 66, 66, 36, 92, 130, 34, 0, 0, 0, + 23, 0, 0, 0, 240, 8, 1, 0, 2, 255, 255, 255, 255, 0, 1, 1, 1, 115, 6, 83, 116, 114, 105, + 110, 103, 3, 97, 98, 99, + ]; + + let actual = compress(&source).unwrap(); + assert_eq!(actual, expected); +} diff --git a/src/compression/mod.rs b/src/compression/mod.rs index a27e2f7..0bbde61 100644 --- a/src/compression/mod.rs +++ b/src/compression/mod.rs @@ -23,12 +23,7 @@ impl Default for Compression { #[cfg(feature = "lz4")] #[inline] fn default() -> Self { - // TODO: remove when compression will be implemented. - if cfg!(feature = "test-util") { - Compression::None - } else { - Compression::Lz4 - } + Compression::Lz4 } #[cfg(not(feature = "lz4"))] diff --git a/src/error.rs b/src/error.rs index 248e5c8..cd782eb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,8 @@ pub enum Error { InvalidParams(#[source] Box), #[error("network error: {0}")] Network(#[source] Box), + #[error("compression error: {0}")] + Compression(#[source] Box), #[error("decompression error: {0}")] Decompression(#[source] Box), #[error("no rows returned by a query that expected to return at least one row")] diff --git a/src/insert.rs b/src/insert.rs index 8090a23..3195cbc 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -1,12 +1,13 @@ use std::{future::Future, marker::PhantomData, mem, panic}; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use hyper::{self, body, Body, Request}; use serde::Serialize; use tokio::task::JoinHandle; use url::Url; use crate::{ + compression, error::{Error, Result}, response::Response, row::{self, Row}, @@ -20,6 +21,8 @@ const MIN_CHUNK_SIZE: usize = BUFFER_SIZE - 1024; pub struct Insert { buffer: BytesMut, sender: Option, + #[cfg(feature = "lz4")] + compression: Compression, handle: JoinHandle>, _marker: PhantomData T>, // TODO: test contravariance. } @@ -44,6 +47,12 @@ impl Insert { // https://clickhouse.yandex/docs/en/query_language/syntax/#syntax-identifiers let query = format!("INSERT INTO {}({}) FORMAT RowBinary", table, fields); pairs.append_pair("query", &query); + + #[cfg(feature = "lz4")] + if client.compression == Compression::Lz4 { + pairs.append_pair("decompress", "1"); + } + drop(pairs); let mut builder = Request::post(url.as_str()); @@ -66,9 +75,18 @@ impl Insert { let handle = tokio::spawn(async move { Response::new(future, Compression::None).finish().await }); - Ok(Insert { + #[cfg(feature = "lz4")] + let compression = if client.compression == Compression::Lz4 { + Compression::Lz4 + } else { + Compression::None + }; + + Ok(Self { buffer: BytesMut::with_capacity(BUFFER_SIZE), sender: Some(sender), + #[cfg(feature = "lz4")] + compression, handle, _marker: PhantomData, }) @@ -102,10 +120,10 @@ impl Insert { // Hyper uses non-trivial and inefficient (see benches) schema of buffering chunks. // It's difficult to determine when allocations occur. // So, instead we control it manually here and rely on the system allocator. - let chunk = mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE)); + let chunk = self.take_and_prepare_chunk()?; if let Some(sender) = &mut self.sender { - if sender.send_data(chunk.freeze()).await.is_err() { + if sender.send_data(chunk).await.is_err() { self.abort(); self.wait_handle().await?; // real error should be here. return Err(Error::Network("channel closed".into())); @@ -129,6 +147,22 @@ impl Insert { } } + #[cfg(feature = "lz4")] + fn take_and_prepare_chunk(&mut self) -> Result { + Ok(if self.compression == Compression::Lz4 { + let compressed = compression::lz4::compress(&self.buffer)?; + self.buffer.clear(); + compressed + } else { + mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE)).freeze() + }) + } + + #[cfg(not(feature = "lz4"))] + fn take_and_prepare_chunk(&mut self) -> Result { + mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE)).freeze() + } + fn abort(&mut self) { if let Some(sender) = self.sender.take() { sender.abort(); diff --git a/src/lib.rs b/src/lib.rs index 6db26a2..3a0bc2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,12 +101,7 @@ impl Client { } pub fn with_compression(mut self, compression: Compression) -> Self { - // TODO: remove when compression will be implemented. - self.compression = if cfg!(feature = "test-util") { - Compression::None - } else { - compression - }; + self.compression = compression; self }