From 06433dec34885ad960b86f3c2ab956d05f0acabe Mon Sep 17 00:00:00 2001 From: Moritz Borcherding Date: Wed, 29 Mar 2023 19:53:25 +0200 Subject: [PATCH] start overhauling dict API --- src/decoding/scratch.rs | 30 +++----- src/frame.rs | 146 +++++++++++-------------------------- src/frame_decoder.rs | 52 ++----------- src/fse/fse_decoder.rs | 10 ++- src/huff0/huff0_decoder.rs | 11 ++- src/tests/dict_test.rs | 3 +- src/tests/mod.rs | 6 +- 7 files changed, 82 insertions(+), 176 deletions(-) diff --git a/src/decoding/scratch.rs b/src/decoding/scratch.rs index 5b905a7..bd6e5f9 100644 --- a/src/decoding/scratch.rs +++ b/src/decoding/scratch.rs @@ -58,29 +58,13 @@ impl DecoderScratch { } pub fn use_dict(&mut self, dict: &Dictionary) { - self.fse = dict.fse.clone(); - self.huf = dict.huf.clone(); + self.fse.reinit_from(&dict.fse); + self.huf.table.reinit_from(&dict.huf.table); self.offset_hist = dict.offset_hist; self.buffer.dict_content = dict.dict_content.clone(); } - - /// parses the dictionary and set the tables - /// it returns the dict_id for checking with the frame's dict_id - pub fn load_dict( - &mut self, - raw: &[u8], - ) -> Result { - let dict = super::dictionary::Dictionary::decode_dict(raw)?; - - self.huf = dict.huf.clone(); - self.fse = dict.fse.clone(); - self.offset_hist = dict.offset_hist; - self.buffer.dict_content = dict.dict_content.clone(); - Ok(dict.id) - } } -#[derive(Clone)] pub struct HuffmanScratch { pub table: HuffmanTable, } @@ -99,7 +83,6 @@ impl Default for HuffmanScratch { } } -#[derive(Clone)] pub struct FSEScratch { pub offsets: FSETable, pub of_rle: Option, @@ -120,6 +103,15 @@ impl FSEScratch { ml_rle: None, } } + + pub fn reinit_from(&mut self, other: &Self) { + self.offsets.reinit_from(&other.offsets); + self.literal_lengths.reinit_from(&other.literal_lengths); + self.match_lengths.reinit_from(&other.match_lengths); + self.of_rle = other.of_rle; + self.ll_rle = other.ll_rle; + self.ml_rle = other.ml_rle; + } } impl Default for FSEScratch { diff --git a/src/frame.rs b/src/frame.rs index b9f32e0..50c7086 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -1,23 +1,17 @@ use crate::io::{Error, Read}; -use core::convert::TryInto; - -use alloc::vec; -use alloc::vec::Vec; - pub const MAGIC_NUM: u32 = 0xFD2F_B528; pub const MIN_WINDOW_SIZE: u64 = 1024; pub const MAX_WINDOW_SIZE: u64 = (1 << 41) + 7 * (1 << 38); pub struct Frame { - magic_num: u32, pub header: FrameHeader, } pub struct FrameHeader { pub descriptor: FrameDescriptor, window_descriptor: u8, - dict_id: Vec, - frame_content_size: Vec, + dict_id: Option, + frame_content_size: u64, } pub struct FrameDescriptor(u8); @@ -100,7 +94,7 @@ pub enum FrameHeaderError { impl FrameHeader { pub fn window_size(&self) -> Result { if self.descriptor.single_segment_flag() { - self.frame_content_size() + Ok(self.frame_content_size()) } else { let exp = self.window_descriptor >> 3; let mantissa = self.window_descriptor & 0x7; @@ -123,93 +117,12 @@ impl FrameHeader { } } - pub fn dictionary_id(&self) -> Result, FrameHeaderError> { - if self.descriptor.dict_id_flag() == 0 { - Ok(None) - } else { - let bytes = self.descriptor.dictionary_id_bytes()?; - if self.dict_id.len() != bytes as usize { - Err(FrameHeaderError::DictIdTooSmall { - got: self.dict_id.len(), - expected: bytes as usize, - }) - } else { - let mut value: u32 = 0; - let mut shift = 0; - for x in &self.dict_id { - value |= u32::from(*x) << shift; - shift += 8; - } - - Ok(Some(value)) - } - } + pub fn dictionary_id(&self) -> Option { + self.dict_id } - pub fn frame_content_size(&self) -> Result { - let bytes = self.descriptor.frame_content_size_bytes()?; - - if self.frame_content_size.len() != (bytes as usize) { - return Err(FrameHeaderError::MismatchedFrameSize { - got: self.frame_content_size.len(), - expected: bytes, - }); - } - - match bytes { - 0 => Err(FrameHeaderError::FrameSizeIsZero), - 1 => Ok(u64::from(self.frame_content_size[0])), - 2 => { - let val = (u64::from(self.frame_content_size[1]) << 8) - + (u64::from(self.frame_content_size[0])); - Ok(val + 256) //this weird offset is from the documentation. Only if bytes == 2 - } - 4 => { - let val = self.frame_content_size[..4] - .try_into() - .expect("optimized away"); - let val = u32::from_le_bytes(val); - Ok(u64::from(val)) - } - 8 => { - let val = self.frame_content_size[..8] - .try_into() - .expect("optimized away"); - let val = u64::from_le_bytes(val); - Ok(val) - } - other => Err(FrameHeaderError::InvalidFrameSize { got: other }), - } - } -} - -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum FrameCheckError { - #[error("magic_num wrong. Is: {got}. Should be: {MAGIC_NUM}")] - WrongMagicNum { got: u32 }, - #[error("Reserved Flag set. Must be zero")] - ReservedFlagSet, - #[error(transparent)] - FrameHeaderError(#[from] FrameHeaderError), -} - -impl Frame { - pub fn check_valid(&self) -> Result<(), FrameCheckError> { - if self.magic_num != MAGIC_NUM { - Err(FrameCheckError::WrongMagicNum { - got: self.magic_num, - }) - } else if self.header.descriptor.reserved_flag() { - Err(FrameCheckError::ReservedFlagSet) - } else { - self.header.dictionary_id()?; - self.header.window_size()?; - if self.header.descriptor.single_segment_flag() { - self.header.frame_content_size()?; - } - Ok(()) - } + pub fn frame_content_size(&self) -> u64 { + self.frame_content_size } } @@ -218,6 +131,8 @@ impl Frame { pub enum ReadFrameHeaderError { #[error("Error while reading magic number: {0}")] MagicNumberReadError(#[source] Error), + #[error("Read wrong magic number: 0x{0:X}")] + BadMagicNumber(u32), #[error("Error while reading frame descriptor: {0}")] FrameDescriptorReadError(#[source] Error), #[error(transparent)] @@ -237,6 +152,7 @@ pub fn read_frame_header(mut r: impl Read) -> Result<(Frame, u8), ReadFrameHeade let mut buf = [0u8; 4]; r.read_exact(&mut buf).map_err(err::MagicNumberReadError)?; + let mut bytes_read = 4; let magic_num = u32::from_le_bytes(buf); // Skippable frames have a magic number in this interval @@ -247,7 +163,9 @@ pub fn read_frame_header(mut r: impl Read) -> Result<(Frame, u8), ReadFrameHeade return Err(ReadFrameHeaderError::SkipFrame(magic_num, skip_size)); } - let mut bytes_read = 4; + if magic_num != MAGIC_NUM { + return Err(ReadFrameHeaderError::BadMagicNumber(magic_num)); + } r.read_exact(&mut buf[0..1]) .map_err(err::FrameDescriptorReadError)?; @@ -257,8 +175,8 @@ pub fn read_frame_header(mut r: impl Read) -> Result<(Frame, u8), ReadFrameHeade let mut frame_header = FrameHeader { descriptor: FrameDescriptor(desc.0), - dict_id: vec![0; desc.dictionary_id_bytes()? as usize], - frame_content_size: vec![0; desc.frame_content_size_bytes()? as usize], + dict_id: None, + frame_content_size: 0, window_descriptor: 0, }; @@ -269,20 +187,38 @@ pub fn read_frame_header(mut r: impl Read) -> Result<(Frame, u8), ReadFrameHeade bytes_read += 1; } - if !frame_header.dict_id.is_empty() { - r.read_exact(frame_header.dict_id.as_mut_slice()) - .map_err(err::DictionaryIdReadError)?; - bytes_read += frame_header.dict_id.len(); + let dict_id_len = desc.dictionary_id_bytes()? as usize; + if dict_id_len != 0 { + let buf = &mut buf[..dict_id_len]; + r.read_exact(buf).map_err(err::DictionaryIdReadError)?; + bytes_read += dict_id_len; + let mut dict_id = 0u32; + for i in 0..dict_id_len { + dict_id += (buf[i] as u32) << (8 * i); + } + if dict_id != 0 { + frame_header.dict_id = Some(dict_id); + } } - if !frame_header.frame_content_size.is_empty() { - r.read_exact(frame_header.frame_content_size.as_mut_slice()) + let fcs_len = desc.frame_content_size_bytes()? as usize; + if fcs_len != 0 { + let mut fcs_buf = [0u8; 8]; + let fcs_buf = &mut fcs_buf[..fcs_len]; + r.read_exact(fcs_buf) .map_err(err::FrameContentSizeReadError)?; - bytes_read += frame_header.frame_content_size.len(); + bytes_read += fcs_len; + let mut fcs = 0u64; + for i in 0..fcs_len { + fcs += (fcs_buf[i] as u64) << (8 * i); + } + if fcs_len == 2 { + fcs += 256; + } + frame_header.frame_content_size = fcs; } let frame: Frame = Frame { - magic_num, header: frame_header, }; diff --git a/src/frame_decoder.rs b/src/frame_decoder.rs index db9ca76..b5f5879 100644 --- a/src/frame_decoder.rs +++ b/src/frame_decoder.rs @@ -87,8 +87,6 @@ pub enum FrameDecoderError { ReadFrameHeaderError(#[from] frame::ReadFrameHeaderError), #[error(transparent)] FrameHeaderError(#[from] frame::FrameHeaderError), - #[error(transparent)] - FrameCheckError(#[from] frame::FrameCheckError), #[error("Specified window_size is too big; Requested: {requested}, Max: {MAX_WINDOW_SIZE}")] WindowSizeTooBig { requested: u64 }, #[error(transparent)] @@ -117,7 +115,6 @@ impl FrameDecoderState { pub fn new(source: impl Read) -> Result { let (frame, header_size) = frame::read_frame_header(source)?; let window_size = frame.header.window_size()?; - frame.check_valid()?; Ok(FrameDecoderState { frame, frame_finished: false, @@ -132,7 +129,6 @@ impl FrameDecoderState { pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> { let (frame, header_size) = frame::read_frame_header(source)?; let window_size = frame.header.window_size()?; - frame.check_valid()?; if window_size > MAX_WINDOW_SIZE { return Err(FrameDecoderError::WindowSizeTooBig { @@ -177,14 +173,6 @@ impl FrameDecoder { pub fn init(&mut self, source: impl Read) -> Result<(), FrameDecoderError> { self.reset(source) } - /// Like init but provides the dict to use for the next frame - pub fn init_with_dict( - &mut self, - source: impl Read, - dict: &[u8], - ) -> Result<(), FrameDecoderError> { - self.reset_with_dict(source, dict) - } /// reset() will allocate all needed buffers if it is the first time this decoder is used /// else they just reset these buffers with not further allocations @@ -202,37 +190,17 @@ impl FrameDecoder { } } - /// Like reset but provides the dict to use for the next frame - pub fn reset_with_dict( - &mut self, - source: impl Read, - dict: &[u8], - ) -> Result<(), FrameDecoderError> { - self.reset(source)?; - if let Some(state) = &mut self.state { - let id = state.decoder_scratch.load_dict(dict)?; - state.using_dict = Some(id); - }; - Ok(()) - } - /// Add a dict to the FrameDecoder that can be used when needed. The FrameDecoder uses the appropriate one dynamically - pub fn add_dict(&mut self, raw_dict: &[u8]) -> Result<(), FrameDecoderError> { - let dict = Dictionary::decode_dict(raw_dict)?; + pub fn add_dict(&mut self, dict: Dictionary) -> Result<(), FrameDecoderError> { self.dicts.insert(dict.id, dict); Ok(()) } /// Returns how many bytes the frame contains after decompression - pub fn content_size(&self) -> Option { - let state = match &self.state { - None => return Some(0), - Some(s) => s, - }; - - match state.frame.header.frame_content_size() { - Err(_) => None, - Ok(x) => Some(x), + pub fn content_size(&self) -> u64 { + match &self.state { + None => return 0, + Some(s) => s.frame.header.frame_content_size(), } } @@ -303,10 +271,7 @@ impl FrameDecoder { use FrameDecoderError as err; let state = self.state.as_mut().ok_or(err::NotYetInitialized)?; - if let Some(id) = state.frame.header.dictionary_id().map_err( - //should never happen we check this directly after decoding the frame header - err::FailedToInitialize, - )? { + if let Some(id) = state.frame.header.dictionary_id() { match state.using_dict { Some(using_id) => { //happy @@ -482,10 +447,7 @@ impl FrameDecoder { return Ok((4, 0)); } - if let Some(id) = state.frame.header.dictionary_id().map_err( - //should never happen we check this directly after decoding the frame header - err::FailedToInitialize, - )? { + if let Some(id) = state.frame.header.dictionary_id() { match state.using_dict { Some(using_id) => { //happy diff --git a/src/fse/fse_decoder.rs b/src/fse/fse_decoder.rs index 5cd776f..21868ff 100644 --- a/src/fse/fse_decoder.rs +++ b/src/fse/fse_decoder.rs @@ -2,7 +2,6 @@ use crate::decoding::bit_reader::BitReader; use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use alloc::vec::Vec; -#[derive(Clone)] pub struct FSETable { pub decode: Vec, //used to decode symbols, and calculate the next state @@ -114,6 +113,15 @@ impl FSETable { } } + pub fn reinit_from(&mut self, other: &Self) { + self.reset(); + self.symbol_counter.extend_from_slice(&other.symbol_counter); + self.symbol_probabilities + .extend_from_slice(&other.symbol_probabilities); + self.decode.extend_from_slice(&other.decode); + self.accuracy_log = other.accuracy_log; + } + pub fn reset(&mut self) { self.symbol_counter.clear(); self.symbol_probabilities.clear(); diff --git a/src/huff0/huff0_decoder.rs b/src/huff0/huff0_decoder.rs index 3da62d6..a8aa745 100644 --- a/src/huff0/huff0_decoder.rs +++ b/src/huff0/huff0_decoder.rs @@ -2,7 +2,6 @@ use crate::decoding::bit_reader_reverse::{BitReaderReversed, GetBitsError}; use crate::fse::{FSEDecoder, FSEDecoderError, FSETable, FSETableError}; use alloc::vec::Vec; -#[derive(Clone)] pub struct HuffmanTable { decode: Vec, @@ -135,6 +134,16 @@ impl HuffmanTable { } } + pub fn reinit_from(&mut self, other: &Self) { + self.reset(); + self.decode.extend_from_slice(&other.decode); + self.weights.extend_from_slice(&other.weights); + self.max_num_bits = other.max_num_bits; + self.bits.extend_from_slice(&other.bits); + self.rank_indexes.extend_from_slice(&other.rank_indexes); + self.fse_table.reinit_from(&other.fse_table); + } + pub fn reset(&mut self) { self.decode.clear(); self.weights.clear(); diff --git a/src/tests/dict_test.rs b/src/tests/dict_test.rs index 9cf105c..5bf76f9 100644 --- a/src/tests/dict_test.rs +++ b/src/tests/dict_test.rs @@ -105,7 +105,8 @@ fn test_dict_decoding() { }); let mut frame_dec = frame_decoder::FrameDecoder::new(); - frame_dec.add_dict(&dict).unwrap(); + let dict = crate::decoding::dictionary::Dictionary::decode_dict(&dict).unwrap(); + frame_dec.add_dict(dict).unwrap(); for file in files { let f = file.unwrap(); diff --git a/src/tests/mod.rs b/src/tests/mod.rs index c38cba1..0f25073 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -58,8 +58,7 @@ fn test_frame_header_reading() { use std::fs; let mut content = fs::File::open("./decodecorpus_files/z000088.zst").unwrap(); - let (frame, _) = frame::read_frame_header(&mut content).unwrap(); - frame.check_valid().unwrap(); + let (_frame, _) = frame::read_frame_header(&mut content).unwrap(); } #[test] @@ -69,8 +68,7 @@ fn test_block_header_reading() { use std::fs; let mut content = fs::File::open("./decodecorpus_files/z000088.zst").unwrap(); - let (frame, _) = frame::read_frame_header(&mut content).unwrap(); - frame.check_valid().unwrap(); + let (_frame, _) = frame::read_frame_header(&mut content).unwrap(); let mut block_dec = decoding::block_decoder::new(); let block_header = block_dec.read_block_header(&mut content).unwrap();