Skip to content

Commit

Permalink
CXXCBC-574: fix memory leak when open fails in public API
Browse files Browse the repository at this point in the history
  • Loading branch information
avsej committed Oct 15, 2024
1 parent 862fd4e commit 31fb90e
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 188 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ set(couchbase_cxx_client_FILES
core/io/mcbp_message.cxx
core/io/mcbp_parser.cxx
core/io/mcbp_session.cxx
core/io/streams.cxx
core/key_value_config.cxx
core/logger/custom_rotating_file_sink.cxx
core/logger/logger.cxx
Expand Down
63 changes: 29 additions & 34 deletions core/impl/cluster.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,8 @@ class cluster_impl : public std::enable_shared_from_this<cluster_impl>
// We cannot use close() method here, as it is capturing self as a shared
// pointer to extend lifetime for the user's callback. Here the reference
// counter has reached zero already, so we can only capture `*this`.
std::thread([this, &barrier]() mutable {
if (auto txns = std::move(transactions_); txns != nullptr) {
// blocks until cleanup is finished
txns->close();
}
std::promise<void> core_stopped;
auto f = core_stopped.get_future();
core_.close([&core_stopped]() {
core_stopped.set_value();
});
f.get();
io_.stop();
if (io_thread_.joinable()) {
io_thread_.join();
}
std::thread([this, barrier = std::move(barrier)]() mutable {
do_close();
barrier.set_value();
}).detach();

Expand Down Expand Up @@ -372,20 +359,7 @@ class cluster_impl : public std::enable_shared_from_this<cluster_impl>
{
// Spawn new thread to avoid joining IO thread from the same thread
std::thread([self = shared_from_this(), handler = std::move(handler)]() mutable {
if (auto txns = std::move(self->transactions_); txns != nullptr) {
// blocks until cleanup is finished
txns->close();
}
std::promise<void> barrier;
auto future = barrier.get_future();
self->core_.close([&barrier]() {
barrier.set_value();
});
future.get();
self->io_.stop();
if (self->io_thread_.joinable()) {
self->io_thread_.join();
}
self->do_close();
handler();
}).detach();
}
Expand All @@ -401,7 +375,25 @@ class cluster_impl : public std::enable_shared_from_this<cluster_impl>
}

private:
asio::io_context io_{ ASIO_CONCURRENCY_HINT_1 };
void do_close()
{
if (auto txns = std::move(transactions_); txns != nullptr) {
// blocks until cleanup is finished
txns->close();
}
std::promise<void> core_stopped;
auto f = core_stopped.get_future();
core_.close([core_stopped = std::move(core_stopped)]() mutable {
core_stopped.set_value();
});
f.get();
io_.stop();
if (io_thread_.joinable()) {
io_thread_.join();
}
}

asio::io_context io_{ ASIO_CONCURRENCY_HINT_SAFE };
core::cluster core_{ io_ };
std::shared_ptr<core::transactions::transactions> transactions_{ nullptr };
std::thread io_thread_{ [&io = io_] {
Expand Down Expand Up @@ -540,12 +532,15 @@ cluster::connect(const std::string& connection_string,
// Spawn new thread for connection to ensure that cluster_impl pointer will
// not be deallocated in IO thread in case of error.
std::thread([connection_string, options, handler = std::move(handler)]() {
auto impl = std::make_shared<cluster_impl>();
auto barrier = std::make_shared<std::promise<std::pair<error, cluster>>>();
auto future = barrier->get_future();
impl->open(connection_string, options, [barrier](auto err, auto c) {
barrier->set_value({ std::move(err), std::move(c) });
});
{
auto impl = std::make_shared<cluster_impl>();
impl->open(connection_string, options, [barrier](auto err, auto c) {
barrier->set_value({ std::move(err), std::move(c) });
});
}

auto [err, c] = future.get();
handler(std::move(err), std::move(c));
}).detach();
Expand Down
256 changes: 256 additions & 0 deletions core/io/streams.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
/*
* Copyright 2020-2024 Couchbase, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "streams.hxx"

#include "core/platform/uuid.h"

#include <asio.hpp>
#include <asio/error.hpp>
#include <asio/ssl.hpp>

namespace couchbase::core::io
{
stream_impl::stream_impl(asio::io_context& ctx, bool is_tls)
: strand_(asio::make_strand(ctx))
, tls_(is_tls)
, id_(uuid::to_string(uuid::random()))
{
}

auto
stream_impl::log_prefix() const -> std::string_view
{
return tls_ ? "tls" : "plain";
}

auto
stream_impl::id() const -> const std::string&
{
return id_;
}

plain_stream_impl::plain_stream_impl(asio::io_context& ctx)
: stream_impl(ctx, false)
, stream_(std::make_shared<asio::ip::tcp::socket>(strand_))
{
}

auto
plain_stream_impl::local_endpoint() const -> asio::ip::tcp::endpoint
{
if (!stream_) {
return {};
}
std::error_code ec;
auto res = stream_->local_endpoint(ec);
if (ec) {
return {};
}
return res;
}

auto
plain_stream_impl::is_open() const -> bool
{
if (stream_) {
return stream_->is_open();
}
return false;
}

void
plain_stream_impl::close(utils::movable_function<void(std::error_code)>&& handler)
{
if (!stream_) {
return handler(asio::error::bad_descriptor);
}
return asio::post(strand_, [stream = std::move(stream_), handler = std::move(handler)]() {
asio::error_code ec{};
stream->shutdown(asio::socket_base::shutdown_both, ec);
stream->close(ec);
handler(ec);
});
}

void
plain_stream_impl::set_options()
{
if (!is_open()) {
return;
}
std::error_code ec{};
stream_->set_option(asio::ip::tcp::no_delay{ true }, ec);
stream_->set_option(asio::socket_base::keep_alive{ true }, ec);
}

void
plain_stream_impl::async_connect(
const asio::ip::tcp::resolver::results_type::endpoint_type& endpoint,
utils::movable_function<void(std::error_code)>&& handler)
{
if (!stream_) {
id_ = uuid::to_string(uuid::random());
stream_ = std::make_shared<asio::ip::tcp::socket>(strand_);
}
return stream_->async_connect(endpoint,
[stream = stream_, handler = std::move(handler)](auto ec) {
return handler(ec);
});
}

void
plain_stream_impl::async_write(
std::vector<asio::const_buffer>& buffers,
utils::movable_function<void(std::error_code, std::size_t)>&& handler)
{
if (!is_open()) {
return handler(asio::error::bad_descriptor, {});
}
return asio::async_write(
*stream_,
buffers,
[stream = stream_, handler = std::move(handler)](auto ec, auto bytes_transferred) {
return handler(ec, bytes_transferred);
});
}

void
plain_stream_impl::async_read_some(
asio::mutable_buffer buffer,
utils::movable_function<void(std::error_code, std::size_t)>&& handler)
{
if (!is_open()) {
return handler(asio::error::bad_descriptor, {});
}
return stream_->async_read_some(buffer, std::move(handler));
}

tls_stream_impl::tls_stream_impl(asio::io_context& ctx, asio::ssl::context& tls)
: stream_impl(ctx, true)
, tls_(tls)
, stream_(
std::make_shared<asio::ssl::stream<asio::ip::tcp::socket>>(asio::ip::tcp::socket(strand_),
tls_))
{
}

auto
tls_stream_impl::local_endpoint() const -> asio::ip::tcp::endpoint
{
if (!stream_) {
return {};
}
std::error_code ec;
auto res = stream_->lowest_layer().local_endpoint(ec);
if (ec) {
return {};
}
return res;
}

auto
tls_stream_impl::is_open() const -> bool
{
if (stream_) {
return stream_->lowest_layer().is_open();
}
return false;
}

void
tls_stream_impl::close(utils::movable_function<void(std::error_code)>&& handler)
{
if (!stream_) {
return handler(asio::error::bad_descriptor);
}
return asio::post(strand_, [stream = std::move(stream_), handler = std::move(handler)]() {
asio::error_code ec{};
stream->lowest_layer().shutdown(asio::socket_base::shutdown_both, ec);
stream->lowest_layer().close(ec);
handler(ec);
});
}

void
tls_stream_impl::set_options()
{
if (!is_open()) {
return;
}
std::error_code ec{};
stream_->lowest_layer().set_option(asio::ip::tcp::no_delay{ true }, ec);
stream_->lowest_layer().set_option(asio::socket_base::keep_alive{ true }, ec);
}

void
tls_stream_impl::async_connect(const asio::ip::tcp::resolver::results_type::endpoint_type& endpoint,
utils::movable_function<void(std::error_code)>&& handler)
{
if (!stream_) {
id_ = uuid::to_string(uuid::random());
stream_ = std::make_shared<asio::ssl::stream<asio::ip::tcp::socket>>(
asio::ip::tcp::socket(strand_), tls_);
}
return stream_->lowest_layer().async_connect(
endpoint, [stream = stream_, handler = std::move(handler)](std::error_code ec_connect) mutable {
if (ec_connect == asio::error::operation_aborted) {
return;
}
if (ec_connect) {
return handler(ec_connect);
}
stream->async_handshake(
asio::ssl::stream_base::client,
[stream, handler = std::move(handler)](std::error_code ec_handshake) mutable {
if (ec_handshake == asio::error::operation_aborted) {
return;
}
return handler(ec_handshake);
});
});
}

void
tls_stream_impl::async_write(std::vector<asio::const_buffer>& buffers,
utils::movable_function<void(std::error_code, std::size_t)>&& handler)
{
if (!is_open()) {
return handler(asio::error::bad_descriptor, {});
}
return asio::async_write(
*stream_,
buffers,
[stream = stream_, handler = std::move(handler)](auto ec, auto bytes_transferred) {
return handler(ec, bytes_transferred);
});
}

void
tls_stream_impl::async_read_some(
asio::mutable_buffer buffer,
utils::movable_function<void(std::error_code, std::size_t)>&& handler)
{
if (!is_open()) {
return handler(asio::error::bad_descriptor, {});
}
return stream_->async_read_some(
buffer, [stream = stream_, handler = std::move(handler)](auto ec, auto bytes_transferred) {
return handler(ec, bytes_transferred);
});
}
} // namespace couchbase::core::io
Loading

0 comments on commit 31fb90e

Please sign in to comment.