Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Hash for MapObserver #1989

Merged
merged 10 commits into from
Apr 19, 2024
135 changes: 105 additions & 30 deletions libafl/src/observers/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use alloc::{
};
use core::{
fmt::Debug,
hash::{BuildHasher, Hasher},
hash::{BuildHasher, Hash, Hasher},
iter::Flatten,
marker::PhantomData,
mem::size_of,
Expand Down Expand Up @@ -70,8 +70,7 @@ fn init_count_class_16() {
}

/// Compute the hash of a slice
fn hash_slice<T>(slice: &[T]) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
fn hash_slice<T, H: Hasher>(slice: &[T], hasher: &mut H) -> u64 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This entire function may be replaced with `.as_slice().hash(&mut hasher)`, and should be deleted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hash_slice should be renamed to hash_helper. It's used to define the hash function. I can't see how to remove it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@addisoncrump any last words on this? Otherwise we'll merge

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be removed. One sec.

let ptr = slice.as_ptr() as *const u8;
let map_size = slice.len() / size_of::<T>();
unsafe {
Expand All @@ -83,7 +82,7 @@ fn hash_slice<T>(slice: &[T]) -> u64 {
/// A [`MapObserver`] observes the static map, as oftentimes used for AFL-like coverage information
///
/// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize
pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned
pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Hash
// where
// for<'it> &'it Self: IntoIterator<Item = &'it Self::Entry>
{
Expand All @@ -102,8 +101,8 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned
/// Count the set bytes in the map
fn count_bytes(&self) -> u64;

/// Compute the hash of the map
fn hash(&self) -> u64;
/// Compute the hash of the map without needing to provide a hasher
///
/// Convenience counterpart to hashing through a caller-supplied [`Hasher`]:
/// the implementation picks the hasher itself and returns the finished `u64`.
fn hash_easy(&self) -> u64;

/// Get the initial value for `reset()`
fn initial(&self) -> Self::Entry;
Expand Down Expand Up @@ -346,6 +345,22 @@ where
}
}

/// Hashes a [`StdMapObserver`] by passing its map slice to `hash_slice`.
impl<'a, T, const DIFFERENTIAL: bool> Hash for StdMapObserver<'a, T, DIFFERENTIAL>
where
    T: 'static
        + Copy
        + Default
        + PartialEq
        + Bounded
        + Debug
        + Serialize
        + serde::de::DeserializeOwned,
{
    /// Feed the observer's map, as exposed by `as_slice()`, into `state`.
    fn hash<S>(&self, state: &mut S)
    where
        S: Hasher,
    {
        // `hash_slice` returns a `u64`, which is intentionally dropped:
        // `Hash::hash` only updates the hasher and returns nothing.
        let _ = hash_slice(self.as_slice(), state);
    }
}

impl<'a, T, const DIFFERENTIAL: bool> MapObserver for StdMapObserver<'a, T, DIFFERENTIAL>
where
T: Bounded
Expand Down Expand Up @@ -388,8 +403,10 @@ where
self.as_slice().len()
}

fn hash(&self) -> u64 {
hash_slice(self.as_slice())
fn hash_easy(&self) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
Hash::hash(self, &mut hasher);
hasher.finish()
}

#[inline]
Expand Down Expand Up @@ -830,6 +847,22 @@ where
}
}

/// Hashes a [`ConstMapObserver`] by passing its map slice to `hash_slice`.
impl<'a, T, const N: usize> Hash for ConstMapObserver<'a, T, N>
where
    T: 'static
        + Copy
        + Default
        + PartialEq
        + Bounded
        + Debug
        + Serialize
        + serde::de::DeserializeOwned,
{
    /// Feed the observer's map, as exposed by `as_slice()`, into `state`.
    fn hash<S>(&self, state: &mut S)
    where
        S: Hasher,
    {
        // `hash_slice` returns a `u64`, which is intentionally dropped:
        // `Hash::hash` only updates the hasher and returns nothing.
        let _ = hash_slice(self.as_slice(), state);
    }
}

impl<'a, T, const N: usize> MapObserver for ConstMapObserver<'a, T, N>
where
T: Bounded
Expand Down Expand Up @@ -876,8 +909,10 @@ where
self.as_slice().len()
}

fn hash(&self) -> u64 {
hash_slice(self.as_slice())
fn hash_easy(&self) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
Hash::hash(self, &mut hasher);
hasher.finish()
}

/// Reset the map
Expand Down Expand Up @@ -1142,6 +1177,24 @@ where
}
}

/// Hashes a [`VariableMapObserver`] by passing its map slice to `hash_slice`.
impl<'a, T> Hash for VariableMapObserver<'a, T>
where
    // Each bound is listed once; the previous list repeated `PartialEq` and `Bounded`.
    T: Bounded
        + PartialEq
        + Default
        + Copy
        + 'static
        + Serialize
        + serde::de::DeserializeOwned
        + Debug,
{
    /// Feed the observer's map, as exposed by `as_slice()`, into `hasher`.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        hash_slice(self.as_slice(), hasher);
    }
}

impl<'a, T> MapObserver for VariableMapObserver<'a, T>
where
T: Bounded
Expand Down Expand Up @@ -1188,8 +1241,10 @@ where
}
res
}
fn hash(&self) -> u64 {
hash_slice(self.as_slice())
fn hash_easy(&self) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
Hash::hash(self, &mut hasher);
hasher.finish()
}

/// Reset the map
Expand Down Expand Up @@ -1307,7 +1362,7 @@ where
///
/// [`MapObserver`]s that are not slice-backed,
/// such as [`MultiMapObserver`], can use [`HitcountsIterableMapObserver`] instead.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "M: serde::de::DeserializeOwned")]
pub struct HitcountsMapObserver<M>
where
Expand Down Expand Up @@ -1433,8 +1488,8 @@ where
self.base.reset_map()
}

fn hash(&self) -> u64 {
self.base.hash()
/// Hash the map by delegating to the wrapped `base` observer's `hash_easy`.
fn hash_easy(&self) -> u64 {
self.base.hash_easy()
}
fn to_vec(&self) -> Vec<u8> {
self.base.to_vec()
Expand Down Expand Up @@ -1589,7 +1644,7 @@ where
/// Map observer with hitcounts postprocessing
/// Less optimized version for non-slice iterators.
/// Slice-backed observers should use a [`HitcountsMapObserver`].
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "M: serde::de::DeserializeOwned")]
pub struct HitcountsIterableMapObserver<M>
where
Expand Down Expand Up @@ -1683,8 +1738,8 @@ where
self.base.reset_map()
}

fn hash(&self) -> u64 {
self.base.hash()
/// Hash the map by delegating to the wrapped `base` observer's `hash_easy`.
fn hash_easy(&self) -> u64 {
self.base.hash_easy()
}
fn to_vec(&self) -> Vec<u8> {
self.base.to_vec()
Expand Down Expand Up @@ -1890,6 +1945,22 @@ where
}
}

/// Hashes a [`MultiMapObserver`] by writing the raw bytes of each of its maps,
/// in iteration order, into the supplied hasher.
impl<'a, T, const DIFFERENTIAL: bool> Hash for MultiMapObserver<'a, T, DIFFERENTIAL>
where
    T: 'static + Default + Copy + Serialize + serde::de::DeserializeOwned + Debug,
{
    /// Feed every map in `self.maps` into `hasher`, one `write` call per map.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        for map in &self.maps {
            let slice = map.as_slice();
            let ptr = slice.as_ptr() as *const u8;
            // Length of the map in *bytes*: element count times element size.
            // Dividing by the element size here would hash only a fraction of
            // the map whenever `T` is wider than one byte.
            let map_size = slice.len() * size_of::<T>();
            // SAFETY: `ptr` comes from a live `&[T]`, and `map_size` is exactly the
            // number of bytes that slice occupies, so the byte view stays in bounds
            // for the duration of the borrow.
            // NOTE(review): assumes `T` has no padding bytes — confirm for non-integer entries.
            unsafe {
                hasher.write(slice::from_raw_parts(ptr, map_size));
            }
        }
    }
}

impl<'a, T, const DIFFERENTIAL: bool> MapObserver for MultiMapObserver<'a, T, DIFFERENTIAL>
where
T: 'static
Expand Down Expand Up @@ -1937,16 +2008,9 @@ where
res
}

fn hash(&self) -> u64 {
fn hash_easy(&self) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
for map in &self.maps {
let slice = map.as_slice();
let ptr = slice.as_ptr() as *const u8;
let map_size = slice.len() / size_of::<T>();
unsafe {
hasher.write(slice::from_raw_parts(ptr, map_size));
}
}
Hash::hash(self, &mut hasher);
hasher.finish()
}

Expand Down Expand Up @@ -2246,6 +2310,15 @@ where
}
}

/// Hashes an [`OwnedMapObserver`] by passing its map slice to `hash_slice`.
impl<T> Hash for OwnedMapObserver<T>
where
    T: 'static + Copy + Default + Debug + Serialize + serde::de::DeserializeOwned,
{
    /// Feed the observer's map, as exposed by `as_slice()`, into `state`.
    fn hash<S>(&self, state: &mut S)
    where
        S: Hasher,
    {
        // `hash_slice` returns a `u64`, which is intentionally dropped:
        // `Hash::hash` only updates the hasher and returns nothing.
        let _ = hash_slice(self.as_slice(), state);
    }
}

impl<T> MapObserver for OwnedMapObserver<T>
where
T: 'static
Expand Down Expand Up @@ -2288,8 +2361,10 @@ where
self.as_slice().len()
}

fn hash(&self) -> u64 {
hash_slice(self.as_slice())
fn hash_easy(&self) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
Hash::hash(self, &mut hasher);
hasher.finish()
}

#[inline]
Expand Down Expand Up @@ -2666,7 +2741,7 @@ pub mod pybind {
mapob_unwrap_me!($wrapper_name, self.wrapper, m, { m.usable_count() })
}

fn hash(&self) -> u64 {
/// Hash the wrapped map observer without asking the caller for a hasher.
///
/// Forwards to `hash_easy` on the observer that `mapob_unwrap_me!` binds to `m`.
fn hash_easy(&self) -> u64 {
    // The zero-argument `hash()` on `MapObserver` was renamed to `hash_easy()`;
    // a bare `m.hash()` no longer matches any zero-argument method.
    mapob_unwrap_me!($wrapper_name, self.wrapper, m, { m.hash_easy() })
}

Expand Down
2 changes: 1 addition & 1 deletion libafl/src/schedulers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ where
.match_name::<O>(self.map_observer_name())
.ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))?;

let mut hash = observer.hash() as usize;
let mut hash = observer.hash_easy() as usize;

let psmeta = state.metadata_mut::<SchedulerMetadata>()?;

Expand Down
2 changes: 1 addition & 1 deletion libafl/src/stages/colorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ where
.match_name::<O>(name)
.ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))?;

let hash = observer.hash() as usize;
let hash = observer.hash_easy() as usize;

executor
.observers_mut()
Expand Down
4 changes: 2 additions & 2 deletions libafl/src/stages/tmin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ where
let obs = observers
.match_name::<M>(self.observer_name())
.expect("Should have been provided valid observer name.");
Ok(obs.hash() == self.orig_hash)
Ok(obs.hash_easy() == self.orig_hash)
}
}

Expand Down Expand Up @@ -444,7 +444,7 @@ where
MapEqualityFeedback {
name: "MapEq".to_string(),
obs_name: self.obs_name.clone(),
orig_hash: obs.hash(),
orig_hash: obs.hash_easy(),
phantom: PhantomData,
}
}
Expand Down
Loading