diff --git a/src/v/ssx/sharded_ptr.h b/src/v/ssx/sharded_ptr.h new file mode 100644 index 000000000000..bb61c900ff90 --- /dev/null +++ b/src/v/ssx/sharded_ptr.h @@ -0,0 +1,126 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.md + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0 + */ + +#pragma once + +#include "base/seastarx.h" +#include "base/vassert.h" +#include "utils/mutex.h" + +#include +#include + +#include + +namespace ssx { + +/// A pointer that is safe to share between shards. +/// +/// The pointer can be reset only by the home shard; other shards shall not +/// observe the change until the update has run on their reactor. +/// +/// As such, it is safe to maintain a pointer or reference to the pointee an any +/// shard until the next yield point. +template +class sharded_ptr { +public: + sharded_ptr() + : _shard{ss::this_shard_id()} {} + ~sharded_ptr() noexcept = default; + + sharded_ptr(sharded_ptr&& other) noexcept = default; + sharded_ptr& operator=(sharded_ptr&&) noexcept = default; + + sharded_ptr(sharded_ptr const&) = delete; + sharded_ptr& operator=(sharded_ptr const&) = delete; + + /// dereferences pointer to the managed object for the local shard. + /// + /// reset must have been called at least once. + /// stop must not have been called. + T& operator*() const { return local().operator*(); } + + /// dereferences pointer to the managed object for the local shard. + /// + /// reset must have been called at least once. + /// stop must not have been called. + T* operator->() const { return local().operator->(); } + + /// checks if there is an associated managed object on the local shard. + /// + /// This is safe to call at any time on any shard. + explicit operator bool() const { + return _state.size() > ss::this_shard_id() && local() != nullptr; + } + + /// replaces the managed object. + /// + /// Must be called on the home shard and is safe to call consurrently. + ss::future<> reset(std::shared_ptr u = nullptr) { + assert_shard(); + auto mu{co_await _mutex.get_units()}; + if (_state.empty()) { + _state.resize(ss::smp::count); + } + + co_await ss::smp::invoke_on_all([this, u]() noexcept { local() = u; }); + } + + /// replaces the managed object by constructing a new one. + /// + /// Must be called on the home shard and is safe to call concurrently. + /// returns an ss::broken_semaphore if stop() has been called. + template + ss::future<> reset(Args&&... args) { + return reset(std::make_shared(std::forward(args)...)); + } + + /// stop managing any object. + /// + /// Must be called on the home shard and is safe to call concurrently. + /// returns an ss::broken_semaphore if stop() has been called. + ss::future<> stop() { + co_await _mutex.with([this] { _mutex.broken(); }); + _state = {}; + } + + /// return the home shard. + /// + /// This is safe to call at any time on any shard. + auto shard_id() const { return _shard; } + + /// get a reference to the local instance + /// + /// reset must have been called at least once. + /// stop must not have been called. + std::shared_ptr const& local() const { + return _state[ss::this_shard_id()]; + } + + /// get a reference to the local instance + /// + /// reset must have been called at least once. + /// stop must not have been called. + std::shared_ptr& local() { return _state[ss::this_shard_id()]; } + +private: + void assert_shard() const { + vassert( + ss::this_shard_id() == _shard, + "reset must be called on home shard: ", + _shard); + } + ss::shard_id _shard; + std::vector> _state; + mutex _mutex; +}; + +} // namespace ssx diff --git a/src/v/ssx/tests/CMakeLists.txt b/src/v/ssx/tests/CMakeLists.txt index 8a4498047773..bff8bf5b723c 100644 --- a/src/v/ssx/tests/CMakeLists.txt +++ b/src/v/ssx/tests/CMakeLists.txt @@ -30,6 +30,7 @@ rp_test( BINARY_NAME ssx_multi_thread SOURCES abort_source_test.cc + sharded_ptr_test.cc LIBRARIES v::seastar_testing_main ARGS "-- -c 2" LABELS ssx diff --git a/src/v/ssx/tests/sharded_ptr_test.cc b/src/v/ssx/tests/sharded_ptr_test.cc new file mode 100644 index 000000000000..cd16088711dc --- /dev/null +++ b/src/v/ssx/tests/sharded_ptr_test.cc @@ -0,0 +1,105 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.md +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0 + +#include "base/seastarx.h" +#include "ssx/sharded_ptr.h" + +#include +#include +#include + +#include + +SEASTAR_THREAD_TEST_CASE(test_sharded_ptr_basic_ops) { + ssx::sharded_ptr p0; + BOOST_REQUIRE(!p0); + BOOST_REQUIRE_EQUAL(p0.shard_id(), ss::this_shard_id()); + + // Test operator bool (before reset) + for (auto i : boost::irange(0u, ss::smp::count)) { + ss::smp::submit_to(i, [&]() { BOOST_REQUIRE(!p0); }).get(); + } + + // Test reset + p0.reset(std::make_shared(43)).get(); + p0.reset(43).get(); + + // Test operator bool and deref (after reset) + for (auto i : boost::irange(0u, ss::smp::count)) { + ss::smp::submit_to(i, [&]() { + BOOST_REQUIRE(p0 && p0.operator*() == 43); + BOOST_REQUIRE(p0 && *p0.operator->() == 43); + }).get(); + } + + // Test operator bool (after stop) + p0.stop().get(); + for (auto i : boost::irange(0u, ss::smp::count)) { + ss::smp::submit_to(i, [&]() { BOOST_REQUIRE(!p0); }).get(); + } + + // Test reset (after stop) + try { + p0.reset().get(); + BOOST_FAIL("Expected exception"); + } catch (ss::broken_semaphore const&) { + // Success + } catch (...) { + BOOST_FAIL("Unexpected exception"); + } + + // Test stop (after stop) + try { + p0.stop().get(); + BOOST_FAIL("Expected exception"); + } catch (ss::broken_semaphore const&) { + // Success + } catch (...) { + BOOST_FAIL("Unexpected exception"); + } +} + +SEASTAR_THREAD_TEST_CASE(test_sharded_ptr_stop_without_reset) { + ssx::sharded_ptr p0; + p0.stop().get(); +} + +SEASTAR_THREAD_TEST_CASE(test_sharded_ptr_shared) { + ssx::sharded_ptr p0; + p0.reset(42).get(); + + std::shared_ptr shared = p0.local(); + std::weak_ptr weak = p0.local(); + BOOST_REQUIRE(p0 && *p0 == 42); + + p0.reset().get(); + BOOST_REQUIRE(shared.get() != nullptr); + BOOST_REQUIRE(weak.lock().get() != nullptr); + + shared.reset(); + BOOST_REQUIRE(shared.get() == nullptr); + BOOST_REQUIRE(weak.lock().get() == nullptr); +} + +SEASTAR_THREAD_TEST_CASE(test_sharded_ptr_move) { + ssx::sharded_ptr p0; + p0.reset(42).get(); + + std::shared_ptr shared = p0.local(); + + // Move construction + auto p1{std::move(p0)}; + BOOST_REQUIRE(shared && *shared == 42); + BOOST_REQUIRE(p1 && p1.local() && *p1 == 42); + + // Move assignment + p0 = std::move(p1); + BOOST_REQUIRE(shared && *shared == 42); + BOOST_REQUIRE(p0 && p0.local() && *p0 == 42); +}