Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor hashchainholder #12

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/complevel_estimator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use crate::{
hash_algorithm::HashAlgorithm,
hash_chain::DictionaryAddPolicy,
hash_chain_holder::{new_hash_chain_holder, HashChainHolderTrait},
hash_chain_holder::{new_hash_chain_holder, HashChainHolder},
preflate_constants,
preflate_input::PreflateInput,
preflate_parameter_estimator::PreflateStrategy,
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct CompLevelInfo {
struct CandidateInfo {
hash_algorithm: HashAlgorithm,
add_policy: DictionaryAddPolicy,
hash_chain: Box<dyn HashChainHolderTrait>,
hash_chain: Box<dyn HashChainHolder>,

longest_dist_at_hop_0: u32,
longest_dist_at_hop_1_plus: u32,
Expand Down Expand Up @@ -207,6 +207,13 @@ impl<'a> CompLevelEstimatorState<'a> {
wbits,
)));

// RandomVector candidate
candidates.push(Box::new(CandidateInfo::new(
add_policy,
HashAlgorithm::RandomVector,
wbits,
)));

CompLevelEstimatorState {
input,
candidates,
Expand Down
6 changes: 3 additions & 3 deletions src/hash_algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::hash_chain::{HashChain, HashChainNormalize, HashChainNormalizeLibflate4};
use crate::hash_chain::{HashChain, HashChainAbs, HashChainNormalize, HashChainNormalizeLibflate4};

#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
pub enum HashAlgorithm {
Expand Down Expand Up @@ -207,7 +207,7 @@ const RANDOM_VECTOR: [u16; 768] = [
];

impl HashImplementation for RandomVectorHash {
type HashChainType = HashChainNormalize<RandomVectorHash>;
type HashChainType = HashChainAbs<RandomVectorHash>;

fn get_hash(&self, b: &[u8]) -> usize {
(RANDOM_VECTOR[b[0] as usize]
Expand All @@ -220,6 +220,6 @@ impl HashImplementation for RandomVectorHash {
}

fn new_hash_chain(self) -> Self::HashChainType {
crate::hash_chain::HashChainNormalize::<RandomVectorHash>::new(self)
crate::hash_chain::HashChainAbs::<RandomVectorHash>::new(self)
}
}
225 changes: 178 additions & 47 deletions src/hash_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,22 @@ pub enum DictionaryAddPolicy {
/// Add only the first and last substring of a match to the dictionary that are larger than the limit
AddFirstAndLast(u16),
}

/// Abstraction over the position type stored in the `head` and `prev` arrays
/// of `HashTable`, so the same table code can work with either the 16-bit
/// `InternalPositionRel` or the 32-bit `InternalPositionAbs` representation.
trait InternalPosition: Copy + Clone + Eq + PartialEq + Default + std::fmt::Debug {
/// Returns this position moved back by `other`, clamping instead of wrapping.
/// `InternalPositionRel` implements this with `u16::saturating_sub`;
/// `InternalPositionAbs` leaves it `unimplemented!()`.
fn saturating_sub(&self, other: u16) -> Self;
/// Converts the position into the index used for the per-position tables
/// (`chain_depth` and `prev`).
fn to_index(self) -> usize;
/// Returns the position immediately after this one.
fn inc(&self) -> Self;
/// Builds an internal position from the absolute input position `pos`.
/// `total_shift` is the chain's current shift; `InternalPositionAbs` ignores it.
/// NOTE(review): how `InternalPositionRel` applies `total_shift` is not visible here.
fn from_absolute(pos: u32, total_shift: i32) -> Self;
/// Returns `false` for the end-of-chain marker; the visible implementations
/// treat a stored value of 0 as "no entry".
fn is_valid(&self) -> bool;
/// Returns `self - pos` as a distance. The visible implementations use a plain
/// subtraction, so `pos` must not lie after `self`.
fn dist(&self, pos: Self) -> u32;
}

#[derive(Default, Copy, Clone, Eq, PartialEq, Debug)]
struct InternalPosition {
struct InternalPositionRel {
pos: u16,
}

impl InternalPosition {
impl InternalPosition for InternalPositionRel {
fn saturating_sub(&self, other: u16) -> Self {
Self {
pos: self.pos.saturating_sub(other),
Expand All @@ -61,16 +71,47 @@ impl InternalPosition {
self.pos > 0
}

fn dist(&self, pos: InternalPosition) -> u32 {
fn dist(&self, pos: InternalPositionRel) -> u32 {
u32::from(self.pos - pos.pos)
}
}

/// Position kept as the full 32-bit absolute value handed to `from_absolute`.
///
/// A stored value of 0 (also the `Default`) is treated as "no entry" by
/// `is_valid`.
#[derive(Default, Copy, Clone, Eq, PartialEq, Debug)]
struct InternalPositionAbs {
    pos: u32,
}

impl InternalPosition for InternalPositionAbs {
    /// Not supported for absolute positions; calling it panics.
    fn saturating_sub(&self, _other: u16) -> Self {
        unimplemented!()
    }

    /// Maps the position onto a table slot by keeping only its low 15 bits.
    fn to_index(self) -> usize {
        const SLOT_MASK: u32 = 0x7fff;
        (self.pos & SLOT_MASK) as usize
    }

    /// Returns the position one step further into the input.
    fn inc(&self) -> Self {
        let next = self.pos + 1;
        Self { pos: next }
    }

    /// Wraps `pos` unchanged; the shift argument plays no role for absolute positions.
    fn from_absolute(pos: u32, _total_shift: i32) -> Self {
        Self { pos }
    }

    /// Zero marks the end of a chain, so only non-zero positions are valid.
    fn is_valid(&self) -> bool {
        self.pos != 0
    }

    /// Distance from `pos` up to `self`, i.e. `self - pos`.
    fn dist(&self, pos: Self) -> u32 {
        self.pos - pos.pos
    }
}

#[derive(DefaultBoxed)]
struct HashTable<H: HashImplementation> {
struct HashTable<H: HashImplementation, I: InternalPosition> {
/// Represents the head of the hash chain for a given hash value. In order
/// to find additional matches, you follow the prev chain from the head.
head: [InternalPosition; 65536],
head: [I; 65536],

/// Represents the number of following nodes in the chain for a given
/// position. For example, if chainDepth[100] == 5, then there are 5 more
Expand All @@ -90,25 +131,25 @@ struct HashTable<H: HashImplementation> {
/// all the potential matches for a given hash. The value points to previous
/// position in the chain, or 0 if there are no more matches. (We start
/// with an offset of 8 to avoid confusion with the end of the chain)
prev: [InternalPosition; 65536],
prev: [I; 65536],

/// hash function used to calculate the hash
hash: H,
}

impl<H: HashImplementation> HashTable<H> {
fn get_head(&self, h: usize) -> InternalPosition {
impl<H: HashImplementation, I: InternalPosition> HashTable<H, I> {
fn get_head(&self, h: usize) -> I {
self.head[h]
}

fn get_node_depth(&self, node: InternalPosition) -> i32 {
fn get_node_depth(&self, node: I) -> i32 {
self.chain_depth[node.to_index()]
}

fn update_chain<const MAINTAIN_DEPTH: bool, const UPDATE_MODE: u32>(
&mut self,
chars: &[u8],
mut pos: InternalPosition,
mut pos: I,
length: u32,
) {
let offset = H::num_hash_bytes() as usize - 1;
Expand Down Expand Up @@ -157,7 +198,7 @@ impl<H: HashImplementation> HashTable<H> {
}
}

pub fn match_depth(&self, end_pos: InternalPosition, input: &PreflateInput) -> u32 {
pub fn match_depth(&self, end_pos: I, input: &PreflateInput) -> u32 {
let h = self.hash.get_hash(input.cur_chars(0));
let head = self.get_head(h);

Expand Down Expand Up @@ -195,38 +236,11 @@ pub trait HashChain {
) -> u32;

fn checksum(&self, checksum: &mut DebugHash);

fn update_hash_with_policy<const MAINTAIN_DEPTH: bool>(
&mut self,
length: u32,
input: &PreflateInput,
add_policy: DictionaryAddPolicy,
) {
match add_policy {
DictionaryAddPolicy::AddAll => {
self.update_hash::<MAINTAIN_DEPTH, UPDATE_MODE_ALL>(length, input);
}
DictionaryAddPolicy::AddFirst(limit) => {
if length > limit.into() {
self.update_hash::<MAINTAIN_DEPTH, UPDATE_MODE_FIRST>(length, input);
} else {
self.update_hash::<MAINTAIN_DEPTH, UPDATE_MODE_ALL>(length, input);
}
}
DictionaryAddPolicy::AddFirstAndLast(limit) => {
if length > limit.into() {
self.update_hash::<MAINTAIN_DEPTH, UPDATE_MODE_FIRST_AND_LAST>(length, input);
} else {
self.update_hash::<MAINTAIN_DEPTH, UPDATE_MODE_ALL>(length, input);
}
}
}
}
}

/// This hash chain algorithm periodically normalizes the hash table
pub struct HashChainNormalize<H: HashImplementation> {
hash_table: Box<HashTable<H>>,
hash_table: Box<HashTable<H, InternalPositionRel>>,
total_shift: i32,
}

Expand All @@ -248,7 +262,7 @@ impl<H: HashImplementation> HashChainNormalize<H> {

impl<H: HashImplementation> HashChain for HashChainNormalize<H> {
fn iterate<'a>(&'a self, input: &PreflateInput, offset: u32) -> impl Iterator<Item = u32> + 'a {
let ref_pos = InternalPosition::from_absolute(input.pos() + offset, self.total_shift);
let ref_pos = InternalPositionRel::from_absolute(input.pos() + offset, self.total_shift);

// if we have a match that needs to be inserted at the head first before
// we start walking the chain
Expand Down Expand Up @@ -311,7 +325,7 @@ impl<H: HashImplementation> HashChain for HashChainNormalize<H> {
}

let end_pos =
InternalPosition::from_absolute(cur_pos - target_reference.dist(), self.total_shift);
InternalPositionRel::from_absolute(cur_pos - target_reference.dist(), self.total_shift);

self.hash_table.match_depth(end_pos, input)
}
Expand Down Expand Up @@ -341,7 +355,7 @@ impl<H: HashImplementation> HashChain for HashChainNormalize<H> {
self.total_shift += DELTA as i32;
}

let pos = InternalPosition::from_absolute(input.pos(), self.total_shift);
let pos = InternalPositionRel::from_absolute(input.pos(), self.total_shift);
let chars = input.cur_chars(0);

self.hash_table
Expand All @@ -352,8 +366,8 @@ impl<H: HashImplementation> HashChain for HashChainNormalize<H> {
/// implementation of the hash chain that uses the libdeflate rotating hash.
/// This consists of two hash tables, one for length 3 and one for length 4.
pub struct HashChainNormalizeLibflate4 {
hash_table: Box<HashTable<LibdeflateRotatingHash4>>,
hash_table_3: Box<HashTable<LibdeflateRotatingHash3>>,
hash_table: Box<HashTable<LibdeflateRotatingHash4, InternalPositionRel>>,
hash_table_3: Box<HashTable<LibdeflateRotatingHash3, InternalPositionRel>>,
total_shift: i32,
}

Expand All @@ -372,7 +386,7 @@ impl HashChainNormalizeLibflate4 {

impl HashChain for HashChainNormalizeLibflate4 {
fn iterate<'a>(&'a self, input: &PreflateInput, offset: u32) -> impl Iterator<Item = u32> + 'a {
let ref_pos = InternalPosition::from_absolute(input.pos() + offset, self.total_shift);
let ref_pos = InternalPositionRel::from_absolute(input.pos() + offset, self.total_shift);

// if we have a match that needs to be inserted at the head first before
// we start walking the chain
Expand Down Expand Up @@ -444,7 +458,7 @@ impl HashChain for HashChainNormalizeLibflate4 {
}

let end_pos =
InternalPosition::from_absolute(cur_pos - target_reference.dist(), self.total_shift);
InternalPositionRel::from_absolute(cur_pos - target_reference.dist(), self.total_shift);

if target_reference.len() == 3 {
// libdeflate uses the 3 byte hash table only for a single match attempt
Expand Down Expand Up @@ -494,7 +508,7 @@ impl HashChain for HashChainNormalizeLibflate4 {
self.total_shift += DELTA as i32;
}

let pos = InternalPosition::from_absolute(input.pos(), self.total_shift);
let pos = InternalPositionRel::from_absolute(input.pos(), self.total_shift);
let chars = input.cur_chars(0);

self.hash_table
Expand Down Expand Up @@ -630,3 +644,120 @@ impl<H: RotatingHashTrait> HashChain for HashChainAbs<H> {
}
}
*/

/// Hash chain that stores positions as absolute `InternalPositionAbs` values.
///
/// Unlike `HashChainNormalize`, whose `update_hash` advances `total_shift`,
/// the `update_hash` of this type only forwards to `HashTable::update_chain`
/// and never changes `total_shift`.
pub struct HashChainAbs<H: HashImplementation> {
/// Head/prev tables together with the hash implementation `H`, holding
/// `InternalPositionAbs` entries.
hash_table: Box<HashTable<H, InternalPositionAbs>>,
/// Set to -8 by `new` and passed to `InternalPositionAbs::from_absolute`,
/// which ignores it.
total_shift: i32,
}

impl<H: HashImplementation> HashChainAbs<H> {
pub fn new(hash: H) -> Self {
// Important: total_shift starts at -8 since 0 indicates the end of the hash chain
// so this means that all valid values will be >= 8, otherwise the very first hash
// offset would be zero and so it would get missed
let mut c = HashChainAbs {
total_shift: -8,
hash_table: HashTable::default_boxed(),
};

c.hash_table.hash = hash;

c
}
}

impl<H: HashImplementation> HashChain for HashChainAbs<H> {
    /// Walks the hash chain for the bytes at `input` + `offset` and yields the
    /// distance from that position back to each chain entry, head first.
    ///
    /// `offset` must be 0 or 1. With `offset == 1` (a lazy match) the byte at
    /// the current position has not been added to the table yet; if it hashes
    /// to the same chain, a distance of 1 is yielded before the stored entries.
    fn iterate<'a>(&'a self, input: &PreflateInput, offset: u32) -> impl Iterator<Item = u32> + 'a {
        let origin = InternalPositionAbs::from_absolute(input.pos() + offset, self.total_shift);

        let hasher = &self.hash_table.hash;
        let hash_here = hasher.get_hash(input.cur_chars(0));

        // `pending` holds an extra distance that must be reported before the
        // chain itself is walked.
        let (chain_hash, mut pending) = if offset == 0 {
            (hash_here, None)
        } else {
            assert_eq!(offset, 1);

            // Starting at offset 1, so the chain to walk is the one for the next position.
            let hash_next = hasher.get_hash(input.cur_chars(1));

            // The current position is not in the table yet. If it belongs to the
            // same chain it would be the nearest entry, and skipping it matters
            // when only a limited number of steps through the chain are allowed,
            // so it is injected here as a distance-1 entry.
            (hash_next, (hash_here == hash_next).then_some(1))
        };

        let mut node = self.hash_table.get_head(chain_hash);

        std::iter::from_fn(move || {
            if let Some(dist) = pending.take() {
                return Some(dist);
            }

            if !node.is_valid() {
                return None;
            }

            let dist = origin.dist(node);
            node = self.hash_table.prev[node.to_index()];
            Some(dist)
        })
    }

    /// Returns the chain depth reported by the hash table for the position
    /// `target_reference.dist()` bytes behind the current input position, or
    /// `0xffff` when that distance exceeds both the current position and
    /// `window_size`'s smaller value.
    fn match_depth(
        &self,
        target_reference: &PreflateTokenReference,
        window_size: u32,
        input: &PreflateInput,
    ) -> u32 {
        let here = input.pos();
        let back = target_reference.dist();

        // The reference may not reach before the start of the input or
        // outside the window.
        if back > here.min(window_size) {
            return 0xffff;
        }

        let target = InternalPositionAbs::from_absolute(here - back, self.total_shift);
        self.hash_table.match_depth(target, input)
    }

    /// Feeds the table's `chain_depth` array into `checksum`; no other state
    /// of the chain is included.
    #[allow(dead_code)]
    fn checksum(&self, checksum: &mut DebugHash) {
        checksum.update_slice(&self.hash_table.chain_depth);
    }

    /// Adds `length` positions starting at the current input position to the
    /// hash table. `length` must not exceed `MAX_UPDATE_HASH_BATCH`.
    fn update_hash<const MAINTAIN_DEPTH: bool, const UPDATE_MODE: u32>(
        &mut self,
        length: u32,
        input: &PreflateInput,
    ) {
        assert!(length <= MAX_UPDATE_HASH_BATCH);

        let start = InternalPositionAbs::from_absolute(input.pos(), self.total_shift);
        let bytes = input.cur_chars(0);

        self.hash_table
            .update_chain::<MAINTAIN_DEPTH, UPDATE_MODE>(bytes, start, length);
    }
}
Loading
Loading