Skip to content

Commit

Permalink
Derive node count from node values if present (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
s1ck authored Nov 17, 2023
1 parent 77e74bf commit 7dbb175
Showing 1 changed file with 50 additions and 5 deletions.
55 changes: 50 additions & 5 deletions crates/builder/src/graph/csr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,13 +543,14 @@ where
{
fn from((node_values, edge_list, csr_option): (NodeValues<NV>, E, CsrLayout)) -> Self {
info!("Creating directed graph");
let node_count = edge_list.max_node_id() + NI::new(1);
let node_count = NI::new(node_values.0.len());
let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1);

assert!(
node_values.0.len() >= node_count.index(),
node_count >= node_count_from_edge_list,
"number of node values ({}) does not match node count of edge list ({})",
node_values.0.len(),
node_count.index()
node_count.index(),
node_count_from_edge_list.index()
);

let start = Instant::now();
Expand Down Expand Up @@ -743,7 +744,15 @@ where
{
fn from((node_values, edge_list, csr_option): (NodeValues<NV>, E, CsrLayout)) -> Self {
info!("Creating undirected graph");
let node_count = edge_list.max_node_id() + NI::new(1);
let node_count = NI::new(node_values.0.len());
let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1);

assert!(
node_count >= node_count_from_edge_list,
"number of node values ({}) does not match node count of edge list ({})",
node_count.index(),
node_count_from_edge_list.index()
);

let start = Instant::now();
let csr = Csr::from((&edge_list, node_count, Direction::Undirected, csr_option));
Expand Down Expand Up @@ -1193,4 +1202,40 @@ mod tests {
assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 7, 21, 42]);
});
}

#[test]
fn directed_from_node_values_exceeding_edge_list_max_id() {
let g0: DirectedCsrGraph<u32, u32> = GraphBuilder::new()
.edges(vec![(0, 1), (1, 2)])
.node_values(vec![0, 1, 2, 3])
.build();

assert_eq!(g0.node_count(), 4);
for node in 0..4 {
assert_eq!(g0.node_value(node), &node);
}

assert_eq!(g0.out_degree(0), 1);
assert_eq!(g0.out_degree(1), 1);
assert_eq!(g0.out_degree(2), 0);
assert_eq!(g0.out_degree(3), 0);
}

#[test]
fn undirected_from_node_values_exceeding_edge_list_max_id() {
let g0: UndirectedCsrGraph<u32, u32> = GraphBuilder::new()
.edges(vec![(0, 1), (1, 2)])
.node_values(vec![0, 1, 2, 3])
.build();

assert_eq!(g0.node_count(), 4);
for node in 0..4 {
assert_eq!(g0.node_value(node), &node);
}

assert_eq!(g0.degree(0), 1);
assert_eq!(g0.degree(1), 2);
assert_eq!(g0.degree(2), 1);
assert_eq!(g0.degree(3), 0);
}
}

0 comments on commit 7dbb175

Please sign in to comment.