Skip to content

Commit

Permalink
Added from_coordinates method to SparseVector
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jun 4, 2024
1 parent 8b6102f commit dc8b4b8
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/sparsevec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ impl SparseVector {
}
}

/// Creates a sparse vector from coordinates.
pub fn from_coordinates<I: IntoIterator<Item = (i32, f32)>>(
iter: I,
dim: usize,
) -> SparseVector {
let mut elements: Vec<(i32, f32)> = iter.into_iter().collect();
elements.sort_by_key(|v| v.0);
let indices: Vec<i32> = elements.iter().map(|v| v.0).collect();
let values: Vec<f32> = elements.iter().map(|v| v.1).collect();

SparseVector {
dim,
indices,
values,
}
}

/// Returns the sparse vector as a `Vec<f32>`.
pub fn to_vec(&self) -> Vec<f32> {
let mut vec = vec![0.0; self.dim];
Expand Down Expand Up @@ -91,13 +108,28 @@ impl SparseVector {
#[cfg(test)]
mod tests {
use crate::SparseVector;
use std::collections::HashMap;

#[test]
fn test_from_dense() {
let vec = SparseVector::from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec());
}

#[test]
fn test_from_coo_map() {
let elements = HashMap::from([(0, 1.0), (2, 2.0), (4, 3.0)]);
let vec = SparseVector::from_coordinates(elements, 6);
assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec());
}

#[test]
fn test_from_coo_vec() {
let elements = vec![(0, 1.0), (2, 2.0), (4, 3.0)];
let vec = SparseVector::from_coordinates(elements, 6);
assert_eq!(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], vec.to_vec());
}

#[test]
fn test_to_vec() {
let vec = SparseVector::new(6, vec![0, 2, 4], vec![1.0, 2.0, 3.0]);
Expand Down

0 comments on commit dc8b4b8

Please sign in to comment.