diff --git a/crates/bevy_render/src/render_graph/graph.rs b/crates/bevy_render/src/render_graph/graph.rs index d975ad68e6b8ec..da54065fd28acc 100644 --- a/crates/bevy_render/src/render_graph/graph.rs +++ b/crates/bevy_render/src/render_graph/graph.rs @@ -685,7 +685,7 @@ mod tests { }, renderer::RenderContext, }; - use bevy_ecs::world::World; + use bevy_ecs::world::{FromWorld, World}; use bevy_utils::HashSet; #[derive(Debug)] @@ -726,6 +726,22 @@ mod tests { } } + fn input_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { + graph + .iter_node_inputs(name) + .unwrap() + .map(|(_edge, node)| node.id) + .collect::>() + } + + fn output_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { + graph + .iter_node_outputs(name) + .unwrap() + .map(|(_edge, node)| node.id) + .collect::>() + } + #[test] fn test_graph_edges() { let mut graph = RenderGraph::default(); @@ -738,22 +754,6 @@ mod tests { graph.add_node_edge("B", "C"); graph.add_slot_edge("C", 0, "D", 0); - fn input_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { - graph - .iter_node_inputs(name) - .unwrap() - .map(|(_edge, node)| node.id) - .collect::>() - } - - fn output_nodes(name: &'static str, graph: &RenderGraph) -> HashSet { - graph - .iter_node_outputs(name) - .unwrap() - .map(|(_edge, node)| node.id) - .collect::>() - } - assert!(input_nodes("A", &graph).is_empty(), "A has no inputs"); assert!( output_nodes("A", &graph) == HashSet::from_iter(vec![c_id]), @@ -853,4 +853,48 @@ mod tests { "Adding to a duplicate edge should return an error" ); } + + #[test] + fn test_add_node_with_edges() { + struct SimpleNode; + impl Node for SimpleNode { + fn run( + &self, + _graph: &mut RenderGraphContext, + _render_context: &mut RenderContext, + _world: &World, + ) -> Result<(), NodeRunError> { + Ok(()) + } + } + impl FromWorld for SimpleNode { + fn from_world(_world: &mut World) -> Self { + Self + } + } + + let mut graph = RenderGraph::default(); + let a_id = graph.add_node("A", SimpleNode); + let c_id = graph.add_node("C", SimpleNode); + + // A and C need to exist first + let b_id = graph.add_node_with_edges("B", SimpleNode, &["A", "B", "C"]); + + assert!( + output_nodes("A", &graph) == HashSet::from_iter(vec![b_id]), + "A -> B" + ); + assert!( + input_nodes("B", &graph) == HashSet::from_iter(vec![a_id]), + "B -> C" + ); + assert!( + output_nodes("B", &graph) == HashSet::from_iter(vec![c_id]), + "B -> C" + ); + assert!( + input_nodes("C", &graph) == HashSet::from_iter(vec![b_id]), + "B -> C" + ); + } }