diff --git a/cpp/mrc/include/mrc/node/rx_node.hpp b/cpp/mrc/include/mrc/node/rx_node.hpp index 4eb0e7697..34a48fd0d 100644 --- a/cpp/mrc/include/mrc/node/rx_node.hpp +++ b/cpp/mrc/include/mrc/node/rx_node.hpp @@ -235,6 +235,7 @@ class RxNodeComponent : public WritableProvider, public WritableAcceptor this->get_writable_edge()->await_write(std::move(message)); }, [this](std::exception_ptr ptr) { + WritableAcceptor::release_edge_connection(); runnable::Context::get_runtime_context().set_exception(std::move(ptr)); }, [this]() { diff --git a/cpp/mrc/tests/test_edges.cpp b/cpp/mrc/tests/test_edges.cpp index 631f6d27c..86e42dfb5 100644 --- a/cpp/mrc/tests/test_edges.cpp +++ b/cpp/mrc/tests/test_edges.cpp @@ -28,6 +28,7 @@ #include "mrc/node/operators/combine_latest.hpp" #include "mrc/node/operators/node_component.hpp" #include "mrc/node/operators/router.hpp" +#include "mrc/node/rx_node.hpp" #include "mrc/node/sink_channel_owner.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_channel_owner.hpp" @@ -36,6 +37,7 @@ #include #include #include +#include // for observable_member #include #include @@ -278,6 +280,25 @@ class TestNodeComponent : public NodeComponent } }; +template +class TestRxNodeComponent : public RxNodeComponent +{ + using base_t = node::RxNodeComponent; + + public: + using typename base_t::stream_fn_t; + + void make_stream(stream_fn_t fn) + { + return base_t::make_stream([this, fn](auto&&... args) { + stream_fn_called = true; + return fn(std::forward(args)...); + }); + } + + bool stream_fn_called = false; +}; + template class TestSinkComponent : public WritableProvider { @@ -517,6 +538,26 @@ TEST_F(TestEdges, SourceToNodeComponentToSinkComponent) source->run(); } +TEST_F(TestEdges, SourceToRxNodeComponentToSinkComponent) +{ + auto source = std::make_shared>(); + auto node = std::make_shared>(); + auto sink = std::make_shared>(); + + mrc::make_edge(*source, *node); + mrc::make_edge(*node, *sink); + + node->make_stream([=](rxcpp::observable input) { + return input.map([](int i) { + return i * 2; + }); + }); + + source->run(); + + EXPECT_TRUE(node->stream_fn_called); +} + TEST_F(TestEdges, SourceComponentToNodeToSinkComponent) { auto source = std::make_shared>(); @@ -825,6 +866,10 @@ TEST_F(TestEdges, CreateAndDestroy) auto x = std::make_shared>(); } + { + auto x = std::make_shared>(); + } + { auto x = std::make_shared>(); } @@ -927,4 +972,28 @@ TEST_F(TestEdges, EdgeTapWithSpliceComponent) source->run(); sink->run(); } + +TEST_F(TestEdges, EdgeTapWithSpliceRxComponent) +{ + auto source = std::make_shared>(); + auto node = std::make_shared>(); + auto sink = std::make_shared>(); + + // Original edge + mrc::make_edge(*source, *sink); + + node->make_stream([=](rxcpp::observable input) { + return input.map([](int i) { + return i * 2; + }); + }); + + // Tap edge + mrc::edge::EdgeBuilder::splice_edge(*source, *sink, *node, *node); + + source->run(); + sink->run(); + + EXPECT_TRUE(node->stream_fn_called); +} } // namespace mrc diff --git a/cpp/mrc/tests/test_node.cpp b/cpp/mrc/tests/test_node.cpp index 51e57eeb8..8ba8cf15e 100644 --- a/cpp/mrc/tests/test_node.cpp +++ b/cpp/mrc/tests/test_node.cpp @@ -17,6 +17,7 @@ #include "test_mrc.hpp" +#include "mrc/channel/status.hpp" // for Status #include "mrc/core/executor.hpp" #include "mrc/engine/pipeline/ipipeline.hpp" #include "mrc/node/rx_node.hpp" @@ -45,6 +46,7 @@ #include #include #include +#include // for runtime_error #include #include #include @@ -485,6 +487,54 @@ TEST_F(TestNode, NodePrologueEpilogue) EXPECT_EQ(epilogue_tap_sum, 20); } +TEST_F(TestNode, RxNodeComponentThrows) +{ + auto p = pipeline::make_pipeline(); + std::atomic throw_count = 0; + std::atomic sink_call_count = 0; + std::atomic complete_count = 0; + + auto my_segment = p->make_segment("test_segment", [&](segment::Builder& seg) { + auto source = seg.make_source("source", [&](rxcpp::subscriber& s) { + s.on_next(1); + s.on_next(2); + s.on_next(3); + s.on_completed(); + }); + + auto node_comp = seg.make_node_component("node", rxcpp::operators::map([&](int i) -> int { + ++throw_count; + throw std::runtime_error("test"); + return 0; + })); + + auto sink = seg.make_sink( + "sinkInt", + [&](const int& x) { + ++sink_call_count; + }, + [&]() { + ++complete_count; + }); + + seg.make_edge(source, node_comp); + seg.make_edge(node_comp, sink); + }); + + auto options = std::make_unique(); + options->topology().user_cpuset("0"); + + Executor exec(std::move(options)); + exec.register_pipeline(std::move(p)); + exec.start(); + + EXPECT_THROW(exec.join(), std::runtime_error); + + EXPECT_EQ(throw_count, 1); + EXPECT_EQ(sink_call_count, 0); + EXPECT_EQ(complete_count, 0); +} + // the parallel tests: // - SourceMultiThread // - SinkMultiThread