diff --git a/src/sparsevec.rs b/src/sparsevec.rs index 0296bfc..6b9115a 100644 --- a/src/sparsevec.rs +++ b/src/sparsevec.rs @@ -35,12 +35,15 @@ impl SparseVector { } } - /// Creates a sparse vector from `(index, value)` pairs. - pub fn from_pairs>(pairs: I, dim: usize) -> SparseVector { - let mut elements: Vec<(i32, f32)> = pairs.into_iter().filter(|v| v.1 != 0.0).collect(); - elements.sort_by_key(|v| v.0); - let indices: Vec = elements.iter().map(|v| v.0).collect(); - let values: Vec = elements.iter().map(|v| v.1).collect(); + /// Creates a sparse vector from a map of non-zero elements. + pub fn from_map<'a, I: IntoIterator>( + map: I, + dim: usize, + ) -> SparseVector { + let mut elements: Vec<(&i32, &f32)> = map.into_iter().filter(|v| *v.1 != 0.0).collect(); + elements.sort_by_key(|v| *v.0); + let indices: Vec = elements.iter().map(|v| *v.0).collect(); + let values: Vec = elements.iter().map(|v| *v.1).collect(); SparseVector { dim, @@ -107,7 +110,7 @@ impl SparseVector { #[cfg(test)] mod tests { use crate::SparseVector; - use std::collections::HashMap; + use std::collections::{BTreeMap, HashMap}; #[test] fn test_from_dense() { @@ -119,9 +122,9 @@ mod tests { } #[test] - fn test_from_pairs_map() { - let pairs = HashMap::from([(0, 1.0), (2, 2.0), (4, 3.0)]); - let vec = SparseVector::from_pairs(pairs, 6); + fn test_from_hash_map() { + let map = HashMap::from([(0, 1.0), (2, 2.0), (4, 3.0)]); + let vec = SparseVector::from_map(&map, 6); assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec()); assert_eq!(6, vec.dimensions()); assert_eq!(&[0, 2, 4], vec.indices()); @@ -129,9 +132,12 @@ mod tests { } #[test] - fn test_from_pairs_vec() { - let pairs = vec![(0, 1.0), (2, 2.0), (4, 3.0)]; - let vec = SparseVector::from_pairs(pairs, 6); + fn test_from_btree_map() { + let map = BTreeMap::from([(0, 1.0), (2, 2.0), (4, 3.0)]); + let vec = SparseVector::from_map(&map, 6); assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec()); + assert_eq!(6, vec.dimensions()); + assert_eq!(&[0, 2, 4], vec.indices()); + assert_eq!(&[1.0, 2.0, 3.0], vec.values()); } }