Skip to content

Commit

Permalink
Release an RxNodeComponent edge on error (#327)
Browse files Browse the repository at this point in the history
Followup to #326

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #327
  • Loading branch information
dagardner-nv authored May 10, 2023
1 parent 59a9474 commit 7b0d9fa
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/mrc/include/mrc/node/rx_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class RxNodeComponent : public WritableProvider<InputT>, public WritableAcceptor
this->get_writable_edge()->await_write(std::move(message));
},
[this](std::exception_ptr ptr) {
WritableAcceptor<OutputT>::release_edge_connection();
runnable::Context::get_runtime_context().set_exception(std::move(ptr));
},
[this]() {
Expand Down
69 changes: 69 additions & 0 deletions cpp/mrc/tests/test_edges.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mrc/node/operators/combine_latest.hpp"
#include "mrc/node/operators/node_component.hpp"
#include "mrc/node/operators/router.hpp"
#include "mrc/node/rx_node.hpp"
#include "mrc/node/sink_channel_owner.hpp"
#include "mrc/node/sink_properties.hpp"
#include "mrc/node/source_channel_owner.hpp"
Expand All @@ -36,6 +37,7 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/internal/gtest-internal.h>
#include <rxcpp/rx.hpp> // for observable_member

#include <functional>
#include <map>
Expand Down Expand Up @@ -278,6 +280,25 @@ class TestNodeComponent : public NodeComponent<T, T>
}
};

template <typename T>
class TestRxNodeComponent : public RxNodeComponent<T, T>
{
using base_t = node::RxNodeComponent<T, T>;

public:
using typename base_t::stream_fn_t;

void make_stream(stream_fn_t fn)
{
return base_t::make_stream([this, fn](auto&&... args) {
stream_fn_called = true;
return fn(std::forward<decltype(args)>(args)...);
});
}

bool stream_fn_called = false;
};

template <typename T>
class TestSinkComponent : public WritableProvider<T>
{
Expand Down Expand Up @@ -517,6 +538,26 @@ TEST_F(TestEdges, SourceToNodeComponentToSinkComponent)
source->run();
}

TEST_F(TestEdges, SourceToRxNodeComponentToSinkComponent)
{
auto source = std::make_shared<node::TestSource<int>>();
auto node = std::make_shared<node::TestRxNodeComponent<int>>();
auto sink = std::make_shared<node::TestSinkComponent<int>>();

mrc::make_edge(*source, *node);
mrc::make_edge(*node, *sink);

node->make_stream([=](rxcpp::observable<int> input) {
return input.map([](int i) {
return i * 2;
});
});

source->run();

EXPECT_TRUE(node->stream_fn_called);
}

TEST_F(TestEdges, SourceComponentToNodeToSinkComponent)
{
auto source = std::make_shared<node::TestSourceComponent<int>>();
Expand Down Expand Up @@ -825,6 +866,10 @@ TEST_F(TestEdges, CreateAndDestroy)
auto x = std::make_shared<node::TestNodeComponent<int>>();
}

{
auto x = std::make_shared<node::TestRxNodeComponent<int>>();
}

{
auto x = std::make_shared<node::TestSinkComponent<int>>();
}
Expand Down Expand Up @@ -927,4 +972,28 @@ TEST_F(TestEdges, EdgeTapWithSpliceComponent)
source->run();
sink->run();
}

TEST_F(TestEdges, EdgeTapWithSpliceRxComponent)
{
auto source = std::make_shared<node::TestSource<int>>();
auto node = std::make_shared<node::TestRxNodeComponent<int>>();
auto sink = std::make_shared<node::TestSink<int>>();

// Original edge
mrc::make_edge(*source, *sink);

node->make_stream([=](rxcpp::observable<int> input) {
return input.map([](int i) {
return i * 2;
});
});

// Tap edge
mrc::edge::EdgeBuilder::splice_edge<int>(*source, *sink, *node, *node);

source->run();
sink->run();

EXPECT_TRUE(node->stream_fn_called);
}
} // namespace mrc
50 changes: 50 additions & 0 deletions cpp/mrc/tests/test_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "test_mrc.hpp"

#include "mrc/channel/status.hpp" // for Status
#include "mrc/core/executor.hpp"
#include "mrc/engine/pipeline/ipipeline.hpp"
#include "mrc/node/rx_node.hpp"
Expand Down Expand Up @@ -45,6 +46,7 @@
#include <mutex>
#include <set>
#include <sstream>
#include <stdexcept> // for runtime_error
#include <string>
#include <thread>
#include <utility>
Expand Down Expand Up @@ -485,6 +487,54 @@ TEST_F(TestNode, NodePrologueEpilogue)
EXPECT_EQ(epilogue_tap_sum, 20);
}

TEST_F(TestNode, RxNodeComponentThrows)
{
auto p = pipeline::make_pipeline();
std::atomic<int> throw_count = 0;
std::atomic<int> sink_call_count = 0;
std::atomic<int> complete_count = 0;

auto my_segment = p->make_segment("test_segment", [&](segment::Builder& seg) {
auto source = seg.make_source<int>("source", [&](rxcpp::subscriber<int>& s) {
s.on_next(1);
s.on_next(2);
s.on_next(3);
s.on_completed();
});

auto node_comp = seg.make_node_component<int, int>("node", rxcpp::operators::map([&](int i) -> int {
++throw_count;
throw std::runtime_error("test");
return 0;
}));

auto sink = seg.make_sink<int>(
"sinkInt",
[&](const int& x) {
++sink_call_count;
},
[&]() {
++complete_count;
});

seg.make_edge(source, node_comp);
seg.make_edge(node_comp, sink);
});

auto options = std::make_unique<Options>();
options->topology().user_cpuset("0");

Executor exec(std::move(options));
exec.register_pipeline(std::move(p));
exec.start();

EXPECT_THROW(exec.join(), std::runtime_error);

EXPECT_EQ(throw_count, 1);
EXPECT_EQ(sink_call_count, 0);
EXPECT_EQ(complete_count, 0);
}

// the parallel tests:
// - SourceMultiThread
// - SinkMultiThread
Expand Down

0 comments on commit 7b0d9fa

Please sign in to comment.