diff --git a/rustworkx-core/src/generators/heavy_hex_graph.rs b/rustworkx-core/src/generators/heavy_hex_graph.rs new file mode 100644 index 000000000..2d2821d8b --- /dev/null +++ b/rustworkx-core/src/generators/heavy_hex_graph.rs @@ -0,0 +1,316 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::hash::Hash; + +use petgraph::data::{Build, Create}; +use petgraph::visit::{Data, NodeIndexable}; + +use super::InvalidInputError; + +/// Generate an undirected heavy hex graph. Fig. 2 of +/// https://arxiv.org/abs/1907.09528 +/// An ASCII diagram of the graph is given by: +/// +/// .. note:: +/// +/// ... D-S-D D ... +/// | | | +/// ...-F F-S-F ... +/// | | | +/// ... D D D ... +/// | | | +/// ... F-S-F F-... +/// | | | +/// ......... +/// | | | +/// ... D D D ... +/// | | | +/// ...-F F-S-F ... +/// | | | +/// ... D D D ... +/// | | | +/// ... F-S-F F-... +/// | | | +/// ......... +/// | | | +/// ... D D D ... +/// | | | +/// ...-F F-S-F ... +/// | | | +/// ... D D D ... +/// | | | +/// ... F-S-F F-... +/// | | | +/// ... D D-S-D ... +/// +/// +/// * `d` - Distance of the code. If ``d`` is set to ``1`` a graph with a +/// single node will be returned. d must be an odd number. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. This is ignored if `weights` is specified, +/// as the weights from that argument will be used instead. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// * `bidirectional` - Whether edges are added bidirectionally, if set to +/// `true` then for any edge `(u, v)` an edge `(v, u)` will also be added. +/// If the graph is undirected this will result in a pallel edge. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::heavy_hex_graph; +/// use rustworkx_core::petgraph::visit::EdgeRef; +/// +/// let expected_edge_list = vec![ +/// (0, 13), +/// (1, 13), +/// (1, 14), +/// (2, 14), +/// (3, 15), +/// (4, 15), +/// (4, 16), +/// (5, 16), +/// (6, 17), +/// (7, 17), +/// (7, 18), +/// (8, 18), +/// (0, 9), +/// (3, 9), +/// (5, 12), +/// (8, 12), +/// (10, 14), +/// (10, 16), +/// (11, 15), +/// (11, 17), +/// ]; +/// let d = 3; +/// let g: petgraph::graph::UnGraph<(), ()> = heavy_hex_graph(d, || (), || (), false).unwrap(); +/// assert_eq!(g.node_count(), (5 * d * d - 2 * d - 1) / 2); +/// assert_eq!(g.edge_count(), 2 * d * (d - 1) + (d + 1) * (d - 1)); +/// assert_eq!( +/// expected_edge_list, +/// g.edge_references() +/// .map(|edge| (edge.source().index(), edge.target().index())) +/// .collect::>(), +/// ) +/// ``` +pub fn heavy_hex_graph( + d: usize, + mut default_node_weight: F, + mut default_edge_weight: H, + bidirectional: bool, +) -> Result +where + G: Build + Create + Data + NodeIndexable, + F: FnMut() -> T, + H: FnMut() -> M, + G::NodeId: Eq + Hash, +{ + if d % 2 == 0 { + return Err(InvalidInputError {}); + } + let num_nodes = (5 * d * d - 2 * d - 1) / 2; + let num_edges = 2 * d * (d - 1) + (d + 1) * (d - 1); + let mut graph = G::with_capacity(num_nodes, num_edges); + + if d == 1 { + graph.add_node(default_node_weight()); + return Ok(graph); + } + let num_data = d * d; + let num_syndrome = (d - 1) * (d + 1) / 2; + let num_flag = d * (d - 1); + + let nodes_data: Vec = (0..num_data) + .map(|_| graph.add_node(default_node_weight())) + .collect(); + let nodes_syndrome: Vec = (0..num_syndrome) + .map(|_| graph.add_node(default_node_weight())) + .collect(); + let nodes_flag: Vec = (0..num_flag) + .map(|_| graph.add_node(default_node_weight())) + .collect(); + + // connect data and flags + for (i, flag_chunk) in nodes_flag.chunks(d - 1).enumerate() { + for (j, flag) in flag_chunk.iter().enumerate() { + graph.add_edge(nodes_data[i * d + j], *flag, default_edge_weight()); + graph.add_edge(nodes_data[i * d + j + 1], *flag, default_edge_weight()); + if bidirectional { + graph.add_edge(*flag, nodes_data[i * d + j], default_edge_weight()); + graph.add_edge(*flag, nodes_data[i * d + j + 1], default_edge_weight()); + } + } + } + + // connect data and syndromes + for (i, syndrome_chunk) in nodes_syndrome.chunks((d + 1) / 2).enumerate() { + if i % 2 == 0 { + graph.add_edge(nodes_data[i * d], syndrome_chunk[0], default_edge_weight()); + graph.add_edge( + nodes_data[(i + 1) * d], + syndrome_chunk[0], + default_edge_weight(), + ); + if bidirectional { + graph.add_edge(syndrome_chunk[0], nodes_data[i * d], default_edge_weight()); + graph.add_edge( + syndrome_chunk[0], + nodes_data[(i + 1) * d], + default_edge_weight(), + ); + } + } else if i % 2 == 1 { + graph.add_edge( + nodes_data[i * d + (d - 1)], + syndrome_chunk[syndrome_chunk.len() - 1], + default_edge_weight(), + ); + graph.add_edge( + nodes_data[i * d + (2 * d - 1)], + syndrome_chunk[syndrome_chunk.len() - 1], + default_edge_weight(), + ); + if bidirectional { + graph.add_edge( + syndrome_chunk[syndrome_chunk.len() - 1], + nodes_data[i * d + (d - 1)], + default_edge_weight(), + ); + graph.add_edge( + syndrome_chunk[syndrome_chunk.len() - 1], + nodes_data[i * d + (2 * d - 1)], + default_edge_weight(), + ); + } + } + } + + // connect flag and syndromes + for (i, syndrome_chunk) in nodes_syndrome.chunks((d + 1) / 2).enumerate() { + if i % 2 == 0 { + for (j, syndrome) in syndrome_chunk.iter().enumerate() { + if j != 0 { + graph.add_edge( + *syndrome, + nodes_flag[i * (d - 1) + 2 * (j - 1) + 1], + default_edge_weight(), + ); + graph.add_edge( + *syndrome, + nodes_flag[(i + 1) * (d - 1) + 2 * (j - 1) + 1], + default_edge_weight(), + ); + if bidirectional { + graph.add_edge( + nodes_flag[i * (d - 1) + 2 * (j - 1) + 1], + *syndrome, + default_edge_weight(), + ); + graph.add_edge( + nodes_flag[(i + 1) * (d - 1) + 2 * (j - 1) + 1], + *syndrome, + default_edge_weight(), + ); + } + } + } + } else if i % 2 == 1 { + for (j, syndrome) in syndrome_chunk.iter().enumerate() { + if j != syndrome_chunk.len() - 1 { + graph.add_edge( + *syndrome, + nodes_flag[i * (d - 1) + 2 * j], + default_edge_weight(), + ); + graph.add_edge( + *syndrome, + nodes_flag[(i + 1) * (d - 1) + 2 * j], + default_edge_weight(), + ); + if bidirectional { + graph.add_edge( + nodes_flag[i * (d - 1) + 2 * j], + *syndrome, + default_edge_weight(), + ); + graph.add_edge( + nodes_flag[(i + 1) * (d - 1) + 2 * j], + *syndrome, + default_edge_weight(), + ); + } + } + } + } + } + Ok(graph) +} + +#[cfg(test)] +mod tests { + use crate::generators::heavy_hex_graph; + use crate::generators::InvalidInputError; + use crate::petgraph; + use crate::petgraph::visit::EdgeRef; + + #[test] + fn test_heavy_hex_graph_3() { + let expected_edge_list = vec![ + (0, 13), + (1, 13), + (1, 14), + (2, 14), + (3, 15), + (4, 15), + (4, 16), + (5, 16), + (6, 17), + (7, 17), + (7, 18), + (8, 18), + (0, 9), + (3, 9), + (5, 12), + (8, 12), + (10, 14), + (10, 16), + (11, 15), + (11, 17), + ]; + let d = 3; + let g: petgraph::graph::UnGraph<(), ()> = heavy_hex_graph(d, || (), || (), false).unwrap(); + assert_eq!(g.node_count(), (5 * d * d - 2 * d - 1) / 2); + assert_eq!(g.edge_count(), 2 * d * (d - 1) + (d + 1) * (d - 1)); + assert_eq!( + expected_edge_list, + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } + + #[test] + fn test_heavy_hex_error() { + let d = 2; + match heavy_hex_graph::, (), _, _, ()>( + d, + || (), + || (), + false, + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } +} diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index f242b2979..febf9e187 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -14,6 +14,7 @@ mod cycle_graph; mod grid_graph; +mod heavy_hex_graph; mod path_graph; mod petersen_graph; mod star_graph; @@ -37,6 +38,7 @@ impl fmt::Display for InvalidInputError { pub use cycle_graph::cycle_graph; pub use grid_graph::grid_graph; +pub use heavy_hex_graph::heavy_hex_graph; pub use path_graph::path_graph; pub use petersen_graph::petersen_graph; pub use star_graph::star_graph; diff --git a/src/generators.rs b/src/generators.rs index a83f54a17..9d4f77c12 100644 --- a/src/generators.rs +++ b/src/generators.rs @@ -1325,86 +1325,12 @@ pub fn directed_heavy_square_graph( #[pyfunction(multigraph = true)] #[pyo3(text_signature = "(d, /, multigraph=True)")] pub fn heavy_hex_graph(py: Python, d: usize, multigraph: bool) -> PyResult { - let mut graph = StablePyGraph::::default(); - - if d % 2 == 0 { - return Err(PyIndexError::new_err("d must be odd")); - } - - if d == 1 { - graph.add_node(py.None()); - return Ok(graph::PyGraph { - graph, - node_removed: false, - multigraph, - attrs: py.None(), - }); - } - - let num_data = d * d; - let num_syndrome = (d - 1) * (d + 1) / 2; - let num_flag = d * (d - 1); - - let nodes_data: Vec = (0..num_data).map(|_| graph.add_node(py.None())).collect(); - let nodes_syndrome: Vec = (0..num_syndrome) - .map(|_| graph.add_node(py.None())) - .collect(); - let nodes_flag: Vec = (0..num_flag).map(|_| graph.add_node(py.None())).collect(); - - // connect data and flags - for (i, flag_chunk) in nodes_flag.chunks(d - 1).enumerate() { - for (j, flag) in flag_chunk.iter().enumerate() { - graph.add_edge(nodes_data[i * d + j], *flag, py.None()); - graph.add_edge(*flag, nodes_data[i * d + j + 1], py.None()); - } - } - - // connect data and syndromes - for (i, syndrome_chunk) in nodes_syndrome.chunks((d + 1) / 2).enumerate() { - if i % 2 == 0 { - graph.add_edge(nodes_data[i * d], syndrome_chunk[0], py.None()); - graph.add_edge(syndrome_chunk[0], nodes_data[(i + 1) * d], py.None()); - } else if i % 2 == 1 { - graph.add_edge( - nodes_data[i * d + (d - 1)], - syndrome_chunk[syndrome_chunk.len() - 1], - py.None(), - ); - graph.add_edge( - syndrome_chunk[syndrome_chunk.len() - 1], - nodes_data[i * d + (2 * d - 1)], - py.None(), - ); - } - } - - // connect flag and syndromes - for (i, syndrome_chunk) in nodes_syndrome.chunks((d + 1) / 2).enumerate() { - if i % 2 == 0 { - for (j, syndrome) in syndrome_chunk.iter().enumerate() { - if j != 0 { - graph.add_edge( - nodes_flag[i * (d - 1) + 2 * (j - 1) + 1], - *syndrome, - py.None(), - ); - graph.add_edge( - *syndrome, - nodes_flag[(i + 1) * (d - 1) + 2 * (j - 1) + 1], - py.None(), - ); - } - } - } else if i % 2 == 1 { - for (j, syndrome) in syndrome_chunk.iter().enumerate() { - if j != syndrome_chunk.len() - 1 { - graph.add_edge(nodes_flag[i * (d - 1) + 2 * j], *syndrome, py.None()); - graph.add_edge(*syndrome, nodes_flag[(i + 1) * (d - 1) + 2 * j], py.None()); - } - } - } - } - + let default_fn = || py.None(); + let graph: StablePyGraph = + match core_generators::heavy_hex_graph(d, default_fn, default_fn, false) { + Ok(graph) => graph, + Err(_) => return Err(PyIndexError::new_err("d must be an odd number.")), + }; Ok(graph::PyGraph { graph, node_removed: false, @@ -1478,124 +1404,12 @@ pub fn directed_heavy_hex_graph( bidirectional: bool, multigraph: bool, ) -> PyResult { - let mut graph = StablePyGraph::::default(); - - if d % 2 == 0 { - return Err(PyIndexError::new_err("d must be odd")); - } - - if d == 1 { - graph.add_node(py.None()); - return Ok(digraph::PyDiGraph { - graph, - node_removed: false, - check_cycle: false, - cycle_state: algo::DfsSpace::default(), - multigraph, - attrs: py.None(), - }); - } - - let num_data = d * d; - let num_syndrome = (d - 1) * (d + 1) / 2; - let num_flag = d * (d - 1); - - let nodes_data: Vec = (0..num_data).map(|_| graph.add_node(py.None())).collect(); - let nodes_syndrome: Vec = (0..num_syndrome) - .map(|_| graph.add_node(py.None())) - .collect(); - let nodes_flag: Vec = (0..num_flag).map(|_| graph.add_node(py.None())).collect(); - - // connect data and flags - for (i, flag_chunk) in nodes_flag.chunks(d - 1).enumerate() { - for (j, flag) in flag_chunk.iter().enumerate() { - graph.add_edge(nodes_data[i * d + j], *flag, py.None()); - graph.add_edge(nodes_data[i * d + j + 1], *flag, py.None()); - if bidirectional { - graph.add_edge(*flag, nodes_data[i * d + j], py.None()); - graph.add_edge(*flag, nodes_data[i * d + j + 1], py.None()); - } - } - } - - // connect data and syndromes - for (i, syndrome_chunk) in nodes_syndrome.chunks((d + 1) / 2).enumerate() { - if i % 2 == 0 { - graph.add_edge(nodes_data[i * d], syndrome_chunk[0], py.None()); - graph.add_edge(nodes_data[(i + 1) * d], syndrome_chunk[0], py.None()); - if bidirectional { - graph.add_edge(syndrome_chunk[0], nodes_data[i * d], py.None()); - graph.add_edge(syndrome_chunk[0], nodes_data[(i + 1) * d], py.None()); - } - } else if i % 2 == 1 { - graph.add_edge( - nodes_data[i * d + (d - 1)], - syndrome_chunk[syndrome_chunk.len() - 1], - py.None(), - ); - graph.add_edge( - nodes_data[i * d + (2 * d - 1)], - syndrome_chunk[syndrome_chunk.len() - 1], - py.None(), - ); - if bidirectional { - graph.add_edge( - syndrome_chunk[syndrome_chunk.len() - 1], - nodes_data[i * d + (d - 1)], - py.None(), - ); - graph.add_edge( - syndrome_chunk[syndrome_chunk.len() - 1], - nodes_data[i * d + (2 * d - 1)], - py.None(), - ); - } - } - } - - // connect flag and syndromes - for (i, syndrome_chunk) in nodes_syndrome.chunks((d + 1) / 2).enumerate() { - if i % 2 == 0 { - for (j, syndrome) in syndrome_chunk.iter().enumerate() { - if j != 0 { - graph.add_edge( - *syndrome, - nodes_flag[i * (d - 1) + 2 * (j - 1) + 1], - py.None(), - ); - graph.add_edge( - *syndrome, - nodes_flag[(i + 1) * (d - 1) + 2 * (j - 1) + 1], - py.None(), - ); - if bidirectional { - graph.add_edge( - nodes_flag[i * (d - 1) + 2 * (j - 1) + 1], - *syndrome, - py.None(), - ); - graph.add_edge( - nodes_flag[(i + 1) * (d - 1) + 2 * (j - 1) + 1], - *syndrome, - py.None(), - ); - } - } - } - } else if i % 2 == 1 { - for (j, syndrome) in syndrome_chunk.iter().enumerate() { - if j != syndrome_chunk.len() - 1 { - graph.add_edge(*syndrome, nodes_flag[i * (d - 1) + 2 * j], py.None()); - graph.add_edge(*syndrome, nodes_flag[(i + 1) * (d - 1) + 2 * j], py.None()); - if bidirectional { - graph.add_edge(nodes_flag[i * (d - 1) + 2 * j], *syndrome, py.None()); - graph.add_edge(nodes_flag[(i + 1) * (d - 1) + 2 * j], *syndrome, py.None()); - } - } - } - } - } - + let default_fn = || py.None(); + let graph: StablePyGraph = + match core_generators::heavy_hex_graph(d, default_fn, default_fn, bidirectional) { + Ok(graph) => graph, + Err(_) => return Err(PyIndexError::new_err("d must be an odd number.")), + }; Ok(digraph::PyDiGraph { graph, node_removed: false, diff --git a/tests/retworkx_backwards_compat/generators/test_heavy_hex.py b/tests/retworkx_backwards_compat/generators/test_heavy_hex.py index 904af3266..09e5093fc 100644 --- a/tests/retworkx_backwards_compat/generators/test_heavy_hex.py +++ b/tests/retworkx_backwards_compat/generators/test_heavy_hex.py @@ -113,24 +113,24 @@ def test_heavy_hex_graph_3(self): self.assertEqual(len(graph.edges()), 2 * d * (d - 1) + (d + 1) * (d - 1)) expected_edges = [ (0, 13), - (13, 1), + (1, 13), (1, 14), - (14, 2), + (2, 14), (3, 15), - (15, 4), + (4, 15), (4, 16), - (16, 5), + (5, 16), (6, 17), - (17, 7), + (7, 17), (7, 18), - (18, 8), + (8, 18), (0, 9), - (9, 3), + (3, 9), (5, 12), - (12, 8), - (14, 10), + (8, 12), + (10, 14), (10, 16), - (15, 11), + (11, 15), (11, 17), ] self.assertEqual(list(graph.edge_list()), expected_edges) @@ -352,68 +352,68 @@ def test_heavy_hex_graph_5(self): self.assertEqual(len(graph.edges()), 2 * d * (d - 1) + (d + 1) * (d - 1)) expected_edges = [ (0, 37), - (37, 1), + (1, 37), (1, 38), - (38, 2), + (2, 38), (2, 39), - (39, 3), + (3, 39), (3, 40), - (40, 4), + (4, 40), (5, 41), - (41, 6), + (6, 41), (6, 42), - (42, 7), + (7, 42), (7, 43), - (43, 8), + (8, 43), (8, 44), - (44, 9), + (9, 44), (10, 45), - (45, 11), + (11, 45), (11, 46), - (46, 12), + (12, 46), (12, 47), - (47, 13), + (13, 47), (13, 48), - (48, 14), + (14, 48), (15, 49), - (49, 16), + (16, 49), (16, 50), - (50, 17), + (17, 50), (17, 51), - (51, 18), + (18, 51), (18, 52), - (52, 19), + (19, 52), (20, 53), - (53, 21), + (21, 53), (21, 54), - (54, 22), + (22, 54), (22, 55), - (55, 23), + (23, 55), (23, 56), - (56, 24), + (24, 56), (0, 25), - (25, 5), + (5, 25), (9, 30), - (30, 14), + (14, 30), (10, 31), - (31, 15), + (15, 31), (19, 36), - (36, 24), - (38, 26), + (24, 36), + (26, 38), (26, 42), - (40, 27), + (27, 40), (27, 44), - (41, 28), + (28, 41), (28, 45), - (43, 29), + (29, 43), (29, 47), - (46, 32), + (32, 46), (32, 50), - (48, 33), + (33, 48), (33, 52), - (49, 34), + (34, 49), (34, 53), - (51, 35), + (35, 51), (35, 55), ] diff --git a/tests/rustworkx_tests/generators/test_heavy_hex.py b/tests/rustworkx_tests/generators/test_heavy_hex.py index 832683801..c97c1d08b 100644 --- a/tests/rustworkx_tests/generators/test_heavy_hex.py +++ b/tests/rustworkx_tests/generators/test_heavy_hex.py @@ -113,24 +113,24 @@ def test_heavy_hex_graph_3(self): self.assertEqual(len(graph.edges()), 2 * d * (d - 1) + (d + 1) * (d - 1)) expected_edges = [ (0, 13), - (13, 1), + (1, 13), (1, 14), - (14, 2), + (2, 14), (3, 15), - (15, 4), + (4, 15), (4, 16), - (16, 5), + (5, 16), (6, 17), - (17, 7), + (7, 17), (7, 18), - (18, 8), + (8, 18), (0, 9), - (9, 3), + (3, 9), (5, 12), - (12, 8), - (14, 10), + (8, 12), + (10, 14), (10, 16), - (15, 11), + (11, 15), (11, 17), ] self.assertEqual(list(graph.edge_list()), expected_edges) @@ -352,68 +352,68 @@ def test_heavy_hex_graph_5(self): self.assertEqual(len(graph.edges()), 2 * d * (d - 1) + (d + 1) * (d - 1)) expected_edges = [ (0, 37), - (37, 1), + (1, 37), (1, 38), - (38, 2), + (2, 38), (2, 39), - (39, 3), + (3, 39), (3, 40), - (40, 4), + (4, 40), (5, 41), - (41, 6), + (6, 41), (6, 42), - (42, 7), + (7, 42), (7, 43), - (43, 8), + (8, 43), (8, 44), - (44, 9), + (9, 44), (10, 45), - (45, 11), + (11, 45), (11, 46), - (46, 12), + (12, 46), (12, 47), - (47, 13), + (13, 47), (13, 48), - (48, 14), + (14, 48), (15, 49), - (49, 16), + (16, 49), (16, 50), - (50, 17), + (17, 50), (17, 51), - (51, 18), + (18, 51), (18, 52), - (52, 19), + (19, 52), (20, 53), - (53, 21), + (21, 53), (21, 54), - (54, 22), + (22, 54), (22, 55), - (55, 23), + (23, 55), (23, 56), - (56, 24), + (24, 56), (0, 25), - (25, 5), + (5, 25), (9, 30), - (30, 14), + (14, 30), (10, 31), - (31, 15), + (15, 31), (19, 36), - (36, 24), - (38, 26), + (24, 36), + (26, 38), (26, 42), - (40, 27), + (27, 40), (27, 44), - (41, 28), + (28, 41), (28, 45), - (43, 29), + (29, 43), (29, 47), - (46, 32), + (32, 46), (32, 50), - (48, 33), + (33, 48), (33, 52), - (49, 34), + (34, 49), (34, 53), - (51, 35), + (35, 51), (35, 55), ]