From c9e862380a79b545e69db43e8cd8da61440302c3 Mon Sep 17 00:00:00 2001 From: Antonio Abbatangelo Date: Mon, 13 Mar 2023 03:01:28 -0400 Subject: [PATCH] Add support for no_std --- .github/workflows/ci.yml | 71 +++++++++++ Cargo.toml | 16 ++- src/blocks/block.rs | 4 +- src/blocks/literals_section.rs | 4 +- src/blocks/sequence_section.rs | 4 +- src/decoding/block_decoder.rs | 42 +++---- src/decoding/decodebuffer.rs | 41 +++--- src/decoding/dictionary.rs | 3 +- src/decoding/literals_section_decoder.rs | 5 +- src/decoding/ringbuffer.rs | 19 +-- src/decoding/scratch.rs | 1 + src/decoding/sequence_execution.rs | 2 - src/decoding/sequence_section_decoder.rs | 68 ++++------ src/frame.rs | 18 +-- src/frame_decoder.rs | 62 ++++----- src/fse/fse_decoder.rs | 1 + src/huff0/huff0_decoder.rs | 11 +- src/io.rs | 2 + src/io_nostd.rs | 152 +++++++++++++++++++++++ src/lib.rs | 29 ++++- src/streaming_decoder.rs | 10 +- src/tests/decode_corpus.rs | 5 + src/tests/dict_test.rs | 6 + src/tests/fuzz_regressions.rs | 2 + src/tests/mod.rs | 131 +++++++++++++++++-- 25 files changed, 536 insertions(+), 173 deletions(-) create mode 100644 src/io.rs create mode 100644 src/io_nostd.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d4efd0d..3255ac5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,6 +68,77 @@ jobs: command: clippy args: -- -D warnings + check-no-std: + name: Check (no_std) + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install nightly toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true + + - name: Run cargo check without std feature + uses: actions-rs/cargo@v1 + with: + command: check + toolchain: nightly + args: --no-default-features + + test-no-std: + name: Test Suite (no_std) + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install nightly toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true + + - name: Run cargo test without std feature + uses: actions-rs/cargo@v1 + with: + command: test + toolchain: nightly + args: --no-default-features + + lints-no-std: + name: Lints + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install nightly toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true + components: rustfmt, clippy + + - name: Run cargo fmt + uses: actions-rs/cargo@v1 + with: + command: fmt + toolchain: nightly + args: --all -- --check + + - name: Run cargo clippy + uses: actions-rs/cargo@v1 + with: + command: clippy + toolchain: nightly + args: --no-default-features -- -D warnings + # fails CI because criterion needs two versions of autocfg #cargo-deny: # name: Cargo Deny diff --git a/Cargo.toml b/Cargo.toml index ec4e0e8..da76a80 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,26 @@ exclude = ["decodecorpus_files/*", "dict_tests/*", "fuzz_decodecorpus/*"] readme = "Readme.md" [dependencies] -byteorder = "1.4" +byteorder = { version = "1.4", default-features = false } twox-hash = { version = "1.6", default-features = false } -thiserror = "1" +thiserror = { package = "thiserror-core", version = "1.0.38", default-features = false } [dev-dependencies] criterion = "0.3" rand = "0.8.5" +[features] +default = ["std"] +std = ["thiserror/std"] + [[bench]] name = "reversedbitreader_bench" harness = false + +[[bin]] +name = "zstd" +required-features = ["std"] + +[[bin]] +name = "zstd_stream" +required-features = ["std"] diff --git a/src/blocks/block.rs b/src/blocks/block.rs index 9c872eb..078eb44 100644 --- a/src/blocks/block.rs +++ b/src/blocks/block.rs @@ -6,8 +6,8 @@ pub enum BlockType { Reserved, } -impl std::fmt::Display for BlockType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { +impl core::fmt::Display for BlockType { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { match self { BlockType::Compressed => write!(f, "Compressed"), BlockType::Raw => write!(f, "Raw"), diff --git a/src/blocks/literals_section.rs b/src/blocks/literals_section.rs index 9d83072..159e3b3 100644 --- a/src/blocks/literals_section.rs +++ b/src/blocks/literals_section.rs @@ -25,8 +25,8 @@ pub enum LiteralsSectionParseError { NotEnoughBytes { have: usize, need: u8 }, } -impl std::fmt::Display for LiteralsSectionType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { +impl core::fmt::Display for LiteralsSectionType { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { match self { LiteralsSectionType::Compressed => write!(f, "Compressed"), LiteralsSectionType::Raw => write!(f, "Raw"), diff --git a/src/blocks/sequence_section.rs b/src/blocks/sequence_section.rs index 28b9eff..544755e 100644 --- a/src/blocks/sequence_section.rs +++ b/src/blocks/sequence_section.rs @@ -10,8 +10,8 @@ pub struct Sequence { pub of: u32, } -impl std::fmt::Display for Sequence { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { +impl core::fmt::Display for Sequence { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { write!(f, "LL: {}, ML: {}, OF: {}", self.ll, self.ml, self.of) } } diff --git a/src/decoding/block_decoder.rs b/src/decoding/block_decoder.rs index 11a4c28..0fbce8f 100644 --- a/src/decoding/block_decoder.rs +++ b/src/decoding/block_decoder.rs @@ -11,7 +11,7 @@ use crate::blocks::literals_section::LiteralsSectionParseError; use crate::blocks::sequence_section::SequencesHeaderParseError; use crate::decoding::scratch::DecoderScratch; use crate::decoding::sequence_execution::execute_sequences; -use std::io::{self, Read}; +use crate::io::{self, Read}; pub struct BlockDecoder { header_buffer: [u8; 3], @@ -203,12 +203,12 @@ impl BlockDecoder { let mut section = LiteralsSection::new(); let bytes_in_literals_header = section.parse_from_header(raw)?; let raw = &raw[bytes_in_literals_header as usize..]; - if crate::VERBOSE { - println!( - "Found {} literalssection with regenerated size: {}, and compressed size: {:?}", - section.ls_type, section.regenerated_size, section.compressed_size - ); - } + vprintln!( + "Found {} literalssection with regenerated size: {}, and compressed size: {:?}", + section.ls_type, + section.regenerated_size, + section.compressed_size + ); let upper_limit_for_literals = match section.compressed_size { Some(x) => x as usize, @@ -227,9 +227,7 @@ impl BlockDecoder { } let raw_literals = &raw[..upper_limit_for_literals]; - if crate::VERBOSE { - println!("Slice for literals: {}", raw_literals.len()); - } + vprintln!("Slice for literals: {}", raw_literals.len()); workspace.literals_buffer.clear(); //all literals of the previous block must have been used in the sequence execution anyways. just be defensive here let bytes_used_in_literals_section = decode_literals( @@ -247,20 +245,16 @@ impl BlockDecoder { assert!(bytes_used_in_literals_section == upper_limit_for_literals as u32); let raw = &raw[upper_limit_for_literals..]; - if crate::VERBOSE { - println!("Slice for sequences with headers: {}", raw.len()); - } + vprintln!("Slice for sequences with headers: {}", raw.len()); let mut seq_section = SequencesHeader::new(); let bytes_in_sequence_header = seq_section.parse_from_header(raw)?; let raw = &raw[bytes_in_sequence_header as usize..]; - if crate::VERBOSE { - println!( - "Found sequencessection with sequences: {} and size: {}", - seq_section.num_sequences, - raw.len() - ); - } + vprintln!( + "Found sequencessection with sequences: {} and size: {}", + seq_section.num_sequences, + raw.len() + ); assert!( u32::from(bytes_in_literals_header) @@ -269,9 +263,7 @@ impl BlockDecoder { + raw.len() as u32 == header.content_size ); - if crate::VERBOSE { - println!("Slice for sequences: {}", raw.len()); - } + vprintln!("Slice for sequences: {}", raw.len()); if seq_section.num_sequences != 0 { decode_sequences( @@ -280,9 +272,7 @@ impl BlockDecoder { &mut workspace.fse, &mut workspace.sequences, )?; - if crate::VERBOSE { - println!("Executing sequences"); - } + vprintln!("Executing sequences"); execute_sequences(workspace)?; } else { workspace.buffer.push(&workspace.literals_buffer); diff --git a/src/decoding/decodebuffer.rs b/src/decoding/decodebuffer.rs index 0dea701..33cc58c 100644 --- a/src/decoding/decodebuffer.rs +++ b/src/decoding/decodebuffer.rs @@ -1,5 +1,6 @@ -use std::hash::Hasher; -use std::io; +use crate::io::{Error, Read, Write}; +use alloc::vec::Vec; +use core::hash::Hasher; use twox_hash::XxHash64; @@ -23,8 +24,8 @@ pub enum DecodebufferError { OffsetTooBig { offset: usize, buf_len: usize }, } -impl io::Read for Decodebuffer { - fn read(&mut self, target: &mut [u8]) -> io::Result { +impl Read for Decodebuffer { + fn read(&mut self, target: &mut [u8]) -> Result { let max_amount = self.can_drain_to_window_size().unwrap_or(0); let amount = max_amount.min(target.len()); @@ -176,7 +177,7 @@ impl Decodebuffer { } } - pub fn drain_to_window_size_writer(&mut self, mut sink: impl io::Write) -> io::Result { + pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result { match self.can_drain_to_window_size() { None => Ok(0), Some(can_drain) => { @@ -199,14 +200,14 @@ impl Decodebuffer { vec } - pub fn drain_to_writer(&mut self, mut sink: impl io::Write) -> io::Result { + pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result { let len = self.buffer.len(); self.drain_to(len, |buf| write_all_bytes(&mut sink, buf))?; Ok(len) } - pub fn read_all(&mut self, target: &mut [u8]) -> io::Result { + pub fn read_all(&mut self, target: &mut [u8]) -> Result { let amount = self.buffer.len().min(target.len()); let mut written = 0; @@ -224,8 +225,8 @@ impl Decodebuffer { fn drain_to( &mut self, amount: usize, - mut write_bytes: impl FnMut(&[u8]) -> (usize, io::Result<()>), - ) -> io::Result<()> { + mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>), + ) -> Result<(), Error> { if amount == 0 { return Ok(()); } @@ -280,7 +281,7 @@ impl Decodebuffer { } /// Like Write::write_all but returns partial write length even on error -fn write_all_bytes(mut sink: impl io::Write, buf: &[u8]) -> (usize, io::Result<()>) { +fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) { let mut written = 0; while written < buf.len() { match sink.write(&buf[written..]) { @@ -294,7 +295,11 @@ fn write_all_bytes(mut sink: impl io::Write, buf: &[u8]) -> (usize, io::Result<( #[cfg(test)] mod tests { use super::Decodebuffer; - use std::io::Write; + use crate::io::{Error, ErrorKind, Write}; + + extern crate std; + use alloc::vec; + use alloc::vec::Vec; #[test] fn short_writer() { @@ -304,7 +309,7 @@ mod tests { } impl Write for ShortWriter { - fn write(&mut self, buf: &[u8]) -> std::result::Result { + fn write(&mut self, buf: &[u8]) -> std::result::Result { if buf.len() > self.write_len { self.buf.extend_from_slice(&buf[..self.write_len]); Ok(self.write_len) @@ -314,7 +319,7 @@ mod tests { } } - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { + fn flush(&mut self) -> std::result::Result<(), Error> { Ok(()) } } @@ -352,18 +357,18 @@ mod tests { } impl Write for WouldblockWriter { - fn write(&mut self, buf: &[u8]) -> std::result::Result { + fn write(&mut self, buf: &[u8]) -> std::result::Result { if self.last_blocked < self.block_every { self.buf.extend_from_slice(buf); self.last_blocked += 1; Ok(buf.len()) } else { self.last_blocked = 0; - Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)) + Err(Error::from(ErrorKind::WouldBlock)) } } - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { + fn flush(&mut self) -> std::result::Result<(), Error> { Ok(()) } } @@ -390,7 +395,7 @@ mod tests { } } Err(e) => { - if e.kind() == std::io::ErrorKind::WouldBlock { + if e.kind() == ErrorKind::WouldBlock { continue; } else { panic!("Unexpected error {:?}", e); @@ -410,7 +415,7 @@ mod tests { } } Err(e) => { - if e.kind() == std::io::ErrorKind::WouldBlock { + if e.kind() == ErrorKind::WouldBlock { continue; } else { panic!("Unexpected error {:?}", e); diff --git a/src/decoding/dictionary.rs b/src/decoding/dictionary.rs index 51fbcdf..aa67693 100644 --- a/src/decoding/dictionary.rs +++ b/src/decoding/dictionary.rs @@ -1,4 +1,5 @@ -use std::convert::TryInto; +use alloc::vec::Vec; +use core::convert::TryInto; use crate::decoding::scratch::FSEScratch; use crate::decoding::scratch::HuffmanScratch; diff --git a/src/decoding/literals_section_decoder.rs b/src/decoding/literals_section_decoder.rs index d947f87..bd7fb18 100644 --- a/src/decoding/literals_section_decoder.rs +++ b/src/decoding/literals_section_decoder.rs @@ -2,6 +2,7 @@ use super::super::blocks::literals_section::{LiteralsSection, LiteralsSectionTyp use super::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use super::scratch::HuffmanScratch; use crate::huff0::{HuffmanDecoder, HuffmanDecoderError, HuffmanTableError}; +use alloc::vec::Vec; #[derive(Debug, thiserror::Error)] #[non_exhaustive] @@ -75,9 +76,7 @@ fn decompress_literals( LiteralsSectionType::Compressed => { //read Huffman tree description bytes_read += scratch.table.build_decoder(source)?; - if crate::VERBOSE { - println!("Built huffman table using {} bytes", bytes_read); - } + vprintln!("Built huffman table using {} bytes", bytes_read); } LiteralsSectionType::Treeless => { if scratch.table.max_num_bits == 0 { diff --git a/src/decoding/ringbuffer.rs b/src/decoding/ringbuffer.rs index 9e3e9ba..fc9a2e3 100644 --- a/src/decoding/ringbuffer.rs +++ b/src/decoding/ringbuffer.rs @@ -1,4 +1,5 @@ -use std::{alloc::Layout, ptr::NonNull, slice}; +use alloc::alloc::{alloc, dealloc}; +use core::{alloc::Layout, ptr::NonNull, slice}; pub struct RingBuffer { buf: NonNull, @@ -70,7 +71,7 @@ impl RingBuffer { // alloc the new memory region and panic if alloc fails // TODO maybe rework this to generate an error? let new_buf = unsafe { - let new_buf = std::alloc::alloc(new_layout); + let new_buf = alloc(new_layout); NonNull::new(new_buf).expect("Allocating new space for the ringbuffer failed") }; @@ -85,7 +86,7 @@ impl RingBuffer { .as_ptr() .add(s1_len) .copy_from_nonoverlapping(s2_ptr, s2_len); - std::alloc::dealloc(self.buf.as_ptr(), current_layout); + dealloc(self.buf.as_ptr(), current_layout); } self.tail = s1_len + s2_len; @@ -341,7 +342,7 @@ impl Drop for RingBuffer { let current_layout = unsafe { Layout::array::(self.cap).unwrap_unchecked() }; unsafe { - std::alloc::dealloc(self.buf.as_ptr(), current_layout); + dealloc(self.buf.as_ptr(), current_layout); } } } @@ -448,8 +449,8 @@ unsafe fn copy_with_nobranch_check( f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2); } - 6 => std::hint::unreachable_unchecked(), - 7 => std::hint::unreachable_unchecked(), + 6 => core::hint::unreachable_unchecked(), + 7 => core::hint::unreachable_unchecked(), 9 => { f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1); f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2); @@ -480,9 +481,9 @@ unsafe fn copy_with_nobranch_check( .add(m1_in_f2) .copy_from_nonoverlapping(m2_ptr, m2_in_f2); } - 14 => std::hint::unreachable_unchecked(), - 15 => std::hint::unreachable_unchecked(), - _ => std::hint::unreachable_unchecked(), + 14 => core::hint::unreachable_unchecked(), + 15 => core::hint::unreachable_unchecked(), + _ => core::hint::unreachable_unchecked(), } } diff --git a/src/decoding/scratch.rs b/src/decoding/scratch.rs index 2bd753b..5b905a7 100644 --- a/src/decoding/scratch.rs +++ b/src/decoding/scratch.rs @@ -3,6 +3,7 @@ use super::decodebuffer::Decodebuffer; use crate::decoding::dictionary::Dictionary; use crate::fse::FSETable; use crate::huff0::HuffmanTable; +use alloc::vec::Vec; pub struct DecoderScratch { pub huf: HuffmanScratch, diff --git a/src/decoding/sequence_execution.rs b/src/decoding/sequence_execution.rs index 19247ec..5946df1 100644 --- a/src/decoding/sequence_execution.rs +++ b/src/decoding/sequence_execution.rs @@ -18,8 +18,6 @@ pub fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), ExecuteSequ for idx in 0..scratch.sequences.len() { let seq = scratch.sequences[idx]; - if crate::VERBOSE {} - //println!("{}: {}", idx, seq); if seq.ll > 0 { let high = literals_copy_counter + seq.ll as usize; diff --git a/src/decoding/sequence_section_decoder.rs b/src/decoding/sequence_section_decoder.rs index 3d5990c..6c366fb 100644 --- a/src/decoding/sequence_section_decoder.rs +++ b/src/decoding/sequence_section_decoder.rs @@ -4,6 +4,7 @@ use super::super::blocks::sequence_section::SequencesHeader; use super::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use super::scratch::FSEScratch; use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError}; +use alloc::vec::Vec; #[derive(Debug, thiserror::Error)] #[non_exhaustive] @@ -42,9 +43,7 @@ pub fn decode_sequences( ) -> Result<(), DecodeSequenceError> { let bytes_read = maybe_update_fse_tables(section, source, scratch)?; - if crate::VERBOSE { - println!("Updating tables used {} bytes", bytes_read); - } + vprintln!("Updating tables used {} bytes", bytes_read); let bit_stream = &source[bytes_read..]; @@ -319,16 +318,13 @@ fn maybe_update_fse_tables( ModeType::FSECompressed => { let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?; bytes_read += bytes; - if crate::VERBOSE { - println!("Updating ll table"); - println!("Used bytes: {}", bytes); - } + + vprintln!("Updating ll table"); + vprintln!("Used bytes: {}", bytes); scratch.ll_rle = None; } ModeType::RLE => { - if crate::VERBOSE { - println!("Use RLE ll table"); - } + vprintln!("Use RLE ll table"); if source.is_empty() { return Err(DecodeSequenceError::MissingByteForRleLlTable); } @@ -336,9 +332,7 @@ fn maybe_update_fse_tables( scratch.ll_rle = Some(source[0]); } ModeType::Predefined => { - if crate::VERBOSE { - println!("Use predefined ll table"); - } + vprintln!("Use predefined ll table"); scratch.literal_lengths.build_from_probabilities( LL_DEFAULT_ACC_LOG, &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]), @@ -346,9 +340,7 @@ fn maybe_update_fse_tables( scratch.ll_rle = None; } ModeType::Repeat => { - if crate::VERBOSE { - println!("Repeat ll table"); - } + vprintln!("Repeat ll table"); /* Nothing to do */ } }; @@ -358,17 +350,13 @@ fn maybe_update_fse_tables( match modes.of_mode() { ModeType::FSECompressed => { let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?; - if crate::VERBOSE { - println!("Updating of table"); - println!("Used bytes: {}", bytes); - } + vprintln!("Updating of table"); + vprintln!("Used bytes: {}", bytes); bytes_read += bytes; scratch.of_rle = None; } ModeType::RLE => { - if crate::VERBOSE { - println!("Use RLE of table"); - } + vprintln!("Use RLE of table"); if of_source.is_empty() { return Err(DecodeSequenceError::MissingByteForRleOfTable); } @@ -376,9 +364,7 @@ fn maybe_update_fse_tables( scratch.of_rle = Some(of_source[0]); } ModeType::Predefined => { - if crate::VERBOSE { - println!("Use predefined of table"); - } + vprintln!("Use predefined of table"); scratch.offsets.build_from_probabilities( OF_DEFAULT_ACC_LOG, &Vec::from(&OFFSET_DEFAULT_DISTRIBUTION[..]), @@ -386,9 +372,7 @@ fn maybe_update_fse_tables( scratch.of_rle = None; } ModeType::Repeat => { - if crate::VERBOSE { - println!("Repeat of table"); - } + vprintln!("Repeat of table"); /* Nothing to do */ } }; @@ -399,16 +383,12 @@ fn maybe_update_fse_tables( ModeType::FSECompressed => { let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?; bytes_read += bytes; - if crate::VERBOSE { - println!("Updating ml table"); - println!("Used bytes: {}", bytes); - } + vprintln!("Updating ml table"); + vprintln!("Used bytes: {}", bytes); scratch.ml_rle = None; } ModeType::RLE => { - if crate::VERBOSE { - println!("Use RLE ml table"); - } + vprintln!("Use RLE ml table"); if ml_source.is_empty() { return Err(DecodeSequenceError::MissingByteForRleMlTable); } @@ -416,9 +396,7 @@ fn maybe_update_fse_tables( scratch.ml_rle = Some(ml_source[0]); } ModeType::Predefined => { - if crate::VERBOSE { - println!("Use predefined ml table"); - } + vprintln!("Use predefined ml table"); scratch.match_lengths.build_from_probabilities( ML_DEFAULT_ACC_LOG, &Vec::from(&MATCH_LENGTH_DEFAULT_DISTRIBUTION[..]), @@ -426,9 +404,7 @@ fn maybe_update_fse_tables( scratch.ml_rle = None; } ModeType::Repeat => { - if crate::VERBOSE { - println!("Repeat ml table"); - } + vprintln!("Repeat ml table"); /* Nothing to do */ } }; @@ -463,10 +439,14 @@ fn test_ll_default() { ) .unwrap(); + #[cfg(feature = "std")] for idx in 0..table.decode.len() { - println!( + std::println!( "{:3}: {:3} {:3} {:3}", - idx, table.decode[idx].symbol, table.decode[idx].num_bits, table.decode[idx].base_line + idx, + table.decode[idx].symbol, + table.decode[idx].num_bits, + table.decode[idx].base_line ); } diff --git a/src/frame.rs b/src/frame.rs index 1eb3f89..b9f32e0 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -1,5 +1,8 @@ -use std::convert::TryInto; -use std::io; +use crate::io::{Error, Read}; +use core::convert::TryInto; + +use alloc::vec; +use alloc::vec::Vec; pub const MAGIC_NUM: u32 = 0xFD2F_B528; pub const MIN_WINDOW_SIZE: u64 = 1024; @@ -214,22 +217,21 @@ impl Frame { #[non_exhaustive] pub enum ReadFrameHeaderError { #[error("Error while reading magic number: {0}")] - MagicNumberReadError(#[source] io::Error), + MagicNumberReadError(#[source] Error), #[error("Error while reading frame descriptor: {0}")] - FrameDescriptorReadError(#[source] io::Error), + FrameDescriptorReadError(#[source] Error), #[error(transparent)] InvalidFrameDescriptor(#[from] FrameDescriptorError), #[error("Error while reading window descriptor: {0}")] - WindowDescriptorReadError(#[source] io::Error), + WindowDescriptorReadError(#[source] Error), #[error("Error while reading dictionary id: {0}")] - DictionaryIdReadError(#[source] io::Error), + DictionaryIdReadError(#[source] Error), #[error("Error while reading frame content size: {0}")] - FrameContentSizeReadError(#[source] io::Error), + FrameContentSizeReadError(#[source] Error), #[error("SkippableFrame encountered with MagicNumber 0x{0:X} and length {1} bytes")] SkipFrame(u32, u32), } -use std::io::Read; pub fn read_frame_header(mut r: impl Read) -> Result<(Frame, u8), ReadFrameHeaderError> { use ReadFrameHeaderError as err; let mut buf = [0u8; 4]; diff --git a/src/frame_decoder.rs b/src/frame_decoder.rs index 560e828..db9ca76 100644 --- a/src/frame_decoder.rs +++ b/src/frame_decoder.rs @@ -2,10 +2,11 @@ use super::frame; use crate::decoding::dictionary::Dictionary; use crate::decoding::scratch::DecoderScratch; use crate::decoding::{self, dictionary}; -use std::collections::HashMap; -use std::convert::TryInto; -use std::hash::Hasher; -use std::io::{self, Read}; +use crate::io::{Error, Read, Write}; +use alloc::collections::BTreeMap; +use alloc::vec::Vec; +use core::convert::TryInto; +use core::hash::Hasher; /// This implements a decoder for zstd frames. This decoder is able to decode frames only partially and gives control /// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once). @@ -17,11 +18,15 @@ use std::io::{self, Read}; /// Workflow is as follows: /// ``` /// use ruzstd::frame_decoder::BlockDecodingStrategy; -/// use std::io::Read; -/// use std::io::Write; /// +/// # #[cfg(feature = "std")] +/// use std::io::{Read, Write}; /// -/// fn decode_this(mut file: impl std::io::Read) { +/// // no_std environments can use the crate's own Read traits +/// # #[cfg(not(feature = "std"))] +/// use ruzstd::io::{Read, Write}; +/// +/// fn decode_this(mut file: impl Read) { /// //Create a new decoder /// let mut frame_dec = ruzstd::FrameDecoder::new(); /// let mut result = Vec::new(); @@ -50,12 +55,13 @@ use std::io::{self, Read}; /// } /// /// fn do_something(data: &[u8]) { +/// # #[cfg(feature = "std")] /// std::io::stdout().write_all(data).unwrap(); /// } /// ``` pub struct FrameDecoder { state: Option, - dicts: HashMap, + dicts: BTreeMap, } struct FrameDecoderState { @@ -92,13 +98,13 @@ pub enum FrameDecoderError { #[error("Failed to parse block header: {0}")] FailedToReadBlockBody(decoding::block_decoder::DecodeBlockContentError), #[error("Failed to read checksum: {0}")] - FailedToReadChecksum(#[source] io::Error), + FailedToReadChecksum(#[source] Error), #[error("Decoder must initialized or reset before using it")] NotYetInitialized, #[error("Decoder encountered error while initializing: {0}")] FailedToInitialize(frame::FrameHeaderError), #[error("Decoder encountered error while draining the decodebuffer: {0}")] - FailedToDrainDecodebuffer(#[source] io::Error), + FailedToDrainDecodebuffer(#[source] Error), #[error("Target must have at least as many bytes as the contentsize of the frame reports")] TargetTooSmall, #[error("Frame header specified dictionary id that wasnt provided by add_dict() or reset_with_dict()")] @@ -158,7 +164,7 @@ impl FrameDecoder { pub fn new() -> FrameDecoder { FrameDecoder { state: None, - dicts: HashMap::new(), + dicts: BTreeMap::new(), } } @@ -319,25 +325,21 @@ impl FrameDecoder { let buffer_size_before = state.decoder_scratch.buffer.len(); let block_counter_before = state.block_counter; loop { - if crate::VERBOSE { - println!("################"); - println!("Next Block: {}", state.block_counter); - println!("################"); - } + vprintln!("################"); + vprintln!("Next Block: {}", state.block_counter); + vprintln!("################"); let (block_header, block_header_size) = block_dec .read_block_header(&mut source) .map_err(err::FailedToReadBlockHeader)?; state.bytes_read_counter += u64::from(block_header_size); - if crate::VERBOSE { - println!(); - println!( - "Found {} block with size: {}, which will be of size: {}", - block_header.block_type, - block_header.content_size, - block_header.decompressed_size - ); - } + vprintln!(); + vprintln!( + "Found {} block with size: {}, which will be of size: {}", + block_header.block_type, + block_header.content_size, + block_header.decompressed_size + ); let bytes_read_in_block_body = block_dec .decode_block_content(&block_header, &mut state.decoder_scratch, &mut source) @@ -346,9 +348,7 @@ impl FrameDecoder { state.block_counter += 1; - if crate::VERBOSE { - println!("Output: {}", state.decoder_scratch.buffer.len()); - } + vprintln!("Output: {}", state.decoder_scratch.buffer.len()); if block_header.last_block { state.frame_finished = true; @@ -396,7 +396,7 @@ impl FrameDecoder { /// Collect bytes and retain window_size bytes while decoding is still going on. /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes - pub fn collect_to_writer(&mut self, w: impl std::io::Write) -> Result { + pub fn collect_to_writer(&mut self, w: impl Write) -> Result { let finished = self.is_finished(); let state = match &mut self.state { None => return Ok(0), @@ -554,8 +554,8 @@ impl FrameDecoder { /// Read bytes from the decode_buffer that are no longer needed. While the frame is not yet finished /// this will retain window_size bytes, else it will drain it completely -impl std::io::Read for FrameDecoder { - fn read(&mut self, target: &mut [u8]) -> std::result::Result { +impl Read for FrameDecoder { + fn read(&mut self, target: &mut [u8]) -> Result { let state = match &mut self.state { None => return Ok(0), Some(s) => s, diff --git a/src/fse/fse_decoder.rs b/src/fse/fse_decoder.rs index 1847da7..5cd776f 100644 --- a/src/fse/fse_decoder.rs +++ b/src/fse/fse_decoder.rs @@ -1,5 +1,6 @@ use crate::decoding::bit_reader::BitReader; use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; +use alloc::vec::Vec; #[derive(Clone)] pub struct FSETable { diff --git a/src/huff0/huff0_decoder.rs b/src/huff0/huff0_decoder.rs index 831ddd6..3da62d6 100644 --- a/src/huff0/huff0_decoder.rs +++ b/src/huff0/huff0_decoder.rs @@ -1,5 +1,6 @@ use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError}; +use alloc::vec::Vec; #[derive(Clone)] pub struct HuffmanTable { @@ -182,12 +183,10 @@ impl HuffmanTable { }); } - if crate::VERBOSE { - println!( - "Building fse table for huffman weights used: {}", - bytes_used_by_fse_header - ); - } + vprintln!( + "Building fse table for huffman weights used: {}", + bytes_used_by_fse_header + ); let mut dec1 = FSEDecoder::new(&self.fse_table); let mut dec2 = FSEDecoder::new(&self.fse_table); diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..6970cd1 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "std")] +pub use std::io::{Error, ErrorKind, Read, Write}; diff --git a/src/io_nostd.rs b/src/io_nostd.rs new file mode 100644 index 0000000..1e5d141 --- /dev/null +++ b/src/io_nostd.rs @@ -0,0 +1,152 @@ +use alloc::boxed::Box; + +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub enum ErrorKind { + Interrupted, + UnexpectedEof, + WouldBlock, + Other, +} + +impl ErrorKind { + fn as_str(&self) -> &'static str { + use ErrorKind::*; + match *self { + Interrupted => "operation interrupted", + UnexpectedEof => "unexpected end of file", + WouldBlock => "operation would block", + Other => "other error", + } + } +} + +impl core::fmt::Display for ErrorKind { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.as_str()) + } +} + +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, + err: Option>, +} + +impl Error { + pub fn new(kind: ErrorKind, err: E) -> Self + where + E: Into>, + { + Self { + kind, + err: Some(err.into()), + } + } + + pub fn from(kind: ErrorKind) -> Self { + Self { kind, err: None } + } + + pub fn kind(&self) -> ErrorKind { + self.kind + } + + pub fn get_ref(&self) -> Option<&(dyn core::error::Error + Send + Sync + 'static)> { + self.err.as_ref().map(|e| e.as_ref()) + } + + pub fn get_mut(&mut self) -> Option<&mut (dyn core::error::Error + Send + Sync + 'static)> { + self.err.as_mut().map(|e| e.as_mut()) + } + + pub fn into_inner(self) -> Option> { + self.err + } +} + +impl core::fmt::Display for Error { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.kind.as_str())?; + + if let Some(ref e) = self.err { + e.fmt(f)?; + } + + Ok(()) + } +} + +impl core::error::Error for Error {} + +impl From for Error { + fn from(value: ErrorKind) -> Self { + Self::from(value) + } +} + +pub trait Read { + fn read(&mut self, buf: &mut [u8]) -> Result; + + fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), Error> { + while !buf.is_empty() { + match self.read(buf) { + Ok(0) => break, + Ok(n) => { + let tmp = buf; + buf = &mut tmp[n..]; + } + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + if !buf.is_empty() { + Err(Error::from(ErrorKind::UnexpectedEof)) + } else { + Ok(()) + } + } +} + +impl Read for &[u8] { + fn read(&mut self, buf: &mut [u8]) -> Result { + let size = core::cmp::min(self.len(), buf.len()); + let (to_copy, rest) = self.split_at(size); + + if size == 1 { + buf[0] = to_copy[0]; + } else { + buf[..size].copy_from_slice(to_copy); + } + + *self = rest; + Ok(size) + } +} + +impl<'a, T> Read for &'a mut T +where + T: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> Result { + (*self).read(buf) + } +} + +pub trait Write { + fn write(&mut self, buf: &[u8]) -> Result; + fn flush(&mut self) -> Result<(), Error>; +} + +impl<'a, T> Write for &'a mut T +where + T: Write, +{ + fn write(&mut self, buf: &[u8]) -> Result { + (*self).write(buf) + } + + fn flush(&mut self) -> Result<(), Error> { + (*self).flush() + } +} diff --git a/src/lib.rs b/src/lib.rs index e079ad3..aecef62 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,23 @@ +#![no_std] #![deny(trivial_casts, trivial_numeric_casts, rust_2018_idioms)] +#![cfg_attr(not(feature = "std"), feature(error_in_core))] + +#[cfg(feature = "std")] +extern crate std; + +extern crate alloc; + +#[cfg(feature = "std")] +pub const VERBOSE: bool = false; + +macro_rules! vprintln { + ($($x:expr),*) => { + #[cfg(feature = "std")] + if crate::VERBOSE { + std::println!($($x),*); + } + } +} pub mod blocks; pub mod decoding; @@ -9,7 +28,15 @@ pub mod huff0; pub mod streaming_decoder; mod tests; -pub const VERBOSE: bool = false; +#[cfg(feature = "std")] +pub mod io; + +#[cfg(not(feature = "std"))] +pub mod io_nostd; + +#[cfg(not(feature = "std"))] +pub use io_nostd as io; + pub use frame_decoder::BlockDecodingStrategy; pub use frame_decoder::FrameDecoder; pub use streaming_decoder::StreamingDecoder; diff --git a/src/streaming_decoder.rs b/src/streaming_decoder.rs index 2613a36..d033fbc 100644 --- a/src/streaming_decoder.rs +++ b/src/streaming_decoder.rs @@ -1,5 +1,5 @@ use crate::frame_decoder::{BlockDecodingStrategy, FrameDecoder, FrameDecoderError}; -use std::io::Read; +use crate::io::{Error, ErrorKind, Read}; /// High level decoder that implements a io::Read that can be used with /// io::Read::read_to_end / io::Read::read_exact or passing this to another library / module as a source for the decoded content @@ -32,7 +32,7 @@ impl StreamingDecoder { } impl Read for StreamingDecoder { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + fn read(&mut self, buf: &mut [u8]) -> Result { if self.decoder.is_finished() && self.decoder.can_collect() == 0 { //No more bytes can ever be decoded return Ok(0); @@ -52,9 +52,9 @@ impl Read for StreamingDecoder { ) { Ok(_) => { /*Nothing to do*/ } Err(e) => { - let err = std::io::Error::new( - std::io::ErrorKind::Other, - format!("Error in the zstd decoder: {:?}", e), + let err = Error::new( + ErrorKind::Other, + alloc::format!("Error in the zstd decoder: {:?}", e), ); return Err(err); } diff --git a/src/tests/decode_corpus.rs b/src/tests/decode_corpus.rs index ea1a723..a8e2168 100644 --- a/src/tests/decode_corpus.rs +++ b/src/tests/decode_corpus.rs @@ -1,8 +1,13 @@ #[test] fn test_decode_corpus_files() { + extern crate std; use crate::frame_decoder; + use alloc::borrow::ToOwned; + use alloc::string::{String, ToString}; + use alloc::vec::Vec; use std::fs; use std::io::Read; + use std::println; let mut success_counter = 0; let mut fail_counter_diff = 0; diff --git a/src/tests/dict_test.rs b/src/tests/dict_test.rs index a8a5fd4..9cf105c 100644 --- a/src/tests/dict_test.rs +++ b/src/tests/dict_test.rs @@ -1,6 +1,7 @@ #[test] fn test_dict_parsing() { use crate::decoding::dictionary::Dictionary; + use alloc::vec; let mut raw = vec![0u8; 8]; // correct magic num @@ -75,9 +76,14 @@ fn test_dict_parsing() { #[test] fn test_dict_decoding() { + extern crate std; use crate::frame_decoder; + use alloc::borrow::ToOwned; + use alloc::string::{String, ToString}; + use alloc::vec::Vec; use std::fs; use std::io::Read; + use std::println; let mut success_counter = 0; let mut fail_counter_diff = 0; diff --git a/src/tests/fuzz_regressions.rs b/src/tests/fuzz_regressions.rs index 2e293af..bc675ca 100644 --- a/src/tests/fuzz_regressions.rs +++ b/src/tests/fuzz_regressions.rs @@ -1,6 +1,8 @@ #[test] fn test_all_artifacts() { + extern crate std; use crate::frame_decoder; + use std::borrow::ToOwned; use std::fs; use std::fs::File; diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 070565c..c38cba1 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,4 +1,28 @@ #[cfg(test)] +use alloc::vec; + +#[cfg(test)] +use alloc::vec::Vec; + +#[cfg(test)] +extern crate std; + +#[cfg(all(test, not(feature = "std")))] +impl crate::io_nostd::Read for std::fs::File { + fn read(&mut self, buf: &mut [u8]) -> Result { + std::io::Read::read(self, buf).map_err(|e| { + if e.get_ref().is_none() { + crate::io_nostd::Error::from(crate::io_nostd::ErrorKind::Other) + } else { + crate::io_nostd::Error::new( + crate::io_nostd::ErrorKind::Other, + e.into_inner().unwrap(), + ) + } + }) + } +} + #[test] fn skippable_frame() { use crate::frame; @@ -131,16 +155,16 @@ fn test_decode_from_to() { match frame_dec.get_checksum_from_data() { Some(chksum) => { if frame_dec.get_calculated_checksum().unwrap() != chksum { - println!( + std::println!( "Checksum did not match! From data: {}, calculated while decoding: {}\n", chksum, frame_dec.get_calculated_checksum().unwrap() ); } else { - println!("Checksums are ok!\n"); + std::println!("Checksums are ok!\n"); } } - None => println!("No checksums to test\n"), + None => std::println!("No checksums to test\n"), } let original_f = File::open("./decodecorpus_files/z000088").unwrap(); @@ -163,7 +187,7 @@ fn test_decode_from_to() { for idx in 0..min { if original[idx] != result[idx] { counter += 1; - //println!( + //std::println!( // "Original {:3} not equal to result {:3} at byte: {}", // original[idx], result[idx], idx, //); @@ -204,10 +228,10 @@ fn test_specific_file() { let original_f = fs::File::open("./decodecorpus_files/z000088").unwrap(); let original: Vec = original_f.bytes().map(|x| x.unwrap()).collect(); - println!("Results for file: {}", path); + std::println!("Results for file: {}", path); if original.len() != result.len() { - println!( + std::println!( "Result has wrong length: {}, should be: {}", result.len(), original.len() @@ -223,18 +247,19 @@ fn test_specific_file() { for idx in 0..min { if original[idx] != result[idx] { counter += 1; - //println!( + //std::println!( // "Original {:3} not equal to result {:3} at byte: {}", // original[idx], result[idx], idx, //); } } if counter > 0 { - println!("Result differs in at least {} bytes from original", counter); + std::println!("Result differs in at least {} bytes from original", counter); } } #[test] +#[cfg(feature = "std")] fn test_streaming() { use std::fs; use std::io::Read; @@ -265,7 +290,7 @@ fn test_streaming() { for idx in 0..min { if original[idx] != result[idx] { counter += 1; - //println!( + //std::println!( // "Original {:3} not equal to result {:3} at byte: {}", // original[idx], result[idx], idx, //); @@ -288,7 +313,91 @@ fn test_streaming() { let original_f = fs::File::open("./decodecorpus_files/z000068").unwrap(); let original: Vec = original_f.bytes().map(|x| x.unwrap()).collect(); - println!("Results for file:"); + std::println!("Results for file:"); + + if original.len() != result.len() { + panic!( + "Result has wrong length: {}, should be: {}", + result.len(), + original.len() + ); + } + + let mut counter = 0; + let min = if original.len() < result.len() { + original.len() + } else { + result.len() + }; + for idx in 0..min { + if original[idx] != result[idx] { + counter += 1; + //std::println!( + // "Original {:3} not equal to result {:3} at byte: {}", + // original[idx], result[idx], idx, + //); + } + } + if counter > 0 { + panic!("Result differs in at least {} bytes from original", counter); + } +} + +#[test] +#[cfg(not(feature = "std"))] +fn test_streaming_no_std() { + use crate::io::Read; + + let content = include_bytes!("../../decodecorpus_files/z000088.zst"); + let mut content = content.as_slice(); + let mut stream = crate::streaming_decoder::StreamingDecoder::new(&mut content).unwrap(); + + let original = include_bytes!("../../decodecorpus_files/z000088"); + let mut result = Vec::new(); + result.resize(original.len(), 0); + Read::read_exact(&mut stream, &mut result).unwrap(); + + if original.len() != result.len() { + panic!( + "Result has wrong length: {}, should be: {}", + result.len(), + original.len() + ); + } + + let mut counter = 0; + let min = if original.len() < result.len() { + original.len() + } else { + result.len() + }; + for idx in 0..min { + if original[idx] != result[idx] { + counter += 1; + //std::println!( + // "Original {:3} not equal to result {:3} at byte: {}", + // original[idx], result[idx], idx, + //); + } + } + if counter > 0 { + panic!("Result differs in at least {} bytes from original", counter); + } + + // Test resetting to a new file while keeping the old decoder + + let content = include_bytes!("../../decodecorpus_files/z000068.zst"); + let mut content = content.as_slice(); + let mut stream = + crate::streaming_decoder::StreamingDecoder::new_with_decoder(&mut content, stream.inner()) + .unwrap(); + + let original = include_bytes!("../../decodecorpus_files/z000068"); + let mut result = Vec::new(); + result.resize(original.len(), 0); + Read::read_exact(&mut stream, &mut result).unwrap(); + + std::println!("Results for file:"); if original.len() != result.len() { panic!( @@ -307,7 +416,7 @@ fn test_streaming() { for idx in 0..min { if original[idx] != result[idx] { counter += 1; - //println!( + //std::println!( // "Original {:3} not equal to result {:3} at byte: {}", // original[idx], result[idx], idx, //);