Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

fix(security): fix bug in negotiation_service::on_negotiation_request when rpc_session is closed #652

Merged
merged 18 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "client_negotiation.h"
#include "negotiation_utils.h"
#include "negotiation_manager.h"

#include <boost/algorithm/string/join.hpp>
#include <dsn/dist/fmt_logging.h>
Expand All @@ -29,7 +30,7 @@ namespace security {
DSN_DECLARE_bool(mandatory_auth);
extern const std::set<std::string> supported_mechanisms;

client_negotiation::client_negotiation(rpc_session *session) : negotiation(session)
client_negotiation::client_negotiation(rpc_session_ptr session) : negotiation(session)
{
_name = fmt::format("CLIENT_NEGOTIATION(SERVER={})", _session->remote_address().to_string());
}
Expand Down Expand Up @@ -179,8 +180,8 @@ void client_negotiation::send(negotiation_status::type status, const blob &msg)
req->msg = msg;

negotiation_rpc rpc(std::move(req), RPC_NEGOTIATION);
rpc.call(_session->remote_address(), nullptr, [this, rpc](error_code err) mutable {
handle_response(err, std::move(rpc.response()));
rpc.call(_session->remote_address(), nullptr, [rpc](error_code err) mutable {
negotiation_manager::on_negotiation_response(err, rpc);
});
}

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ namespace security {
class client_negotiation : public negotiation
{
public:
client_negotiation(rpc_session *session);
client_negotiation(rpc_session_ptr session);

void start();
void handle_response(error_code err, const negotiation_response &&response);

private:
void handle_response(error_code err, const negotiation_response &&response);
void on_recv_mechanisms(const negotiation_response &resp);
void on_mechanism_selected(const negotiation_response &resp);
void on_challenge(const negotiation_response &resp);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#include "kinit_context.h"
#include "sasl_init.h"
#include "negotiation_service.h"
#include "negotiation_manager.h"

#include <dsn/dist/fmt_logging.h>
#include <dsn/utility/flags.h>
Expand Down
6 changes: 2 additions & 4 deletions src/runtime/security/negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ typedef rpc_holder<negotiation_request, negotiation_response> negotiation_rpc;
class negotiation
{
public:
negotiation(rpc_session *session)
negotiation(rpc_session_ptr session)
: _session(session), _status(negotiation_status::type::INVALID)
{
_sasl = create_sasl_wrapper(_session->is_client());
Expand All @@ -49,9 +49,7 @@ class negotiation
bool check_status(negotiation_status::type status, negotiation_status::type expected_status);

protected:
// The ownership of the negotiation instance is held by rpc_session.
// So negotiation keeps only a raw pointer.
rpc_session *_session;
rpc_session_ptr _session;
std::string _name;
negotiation_status::type _status;
std::string _selected_mechanism;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
// specific language governing permissions and limitations
// under the License.

#include "negotiation_service.h"
#include "negotiation_manager.h"
#include "negotiation_utils.h"
#include "server_negotiation.h"
#include "client_negotiation.h"

#include <dsn/utility/flags.h>
#include <dsn/tool-api/zlocks.h>
#include <dsn/dist/failure_detector/fd.code.definition.h>
#include <dsn/dist/fmt_logging.h>

namespace dsn {
namespace security {
Expand All @@ -38,21 +40,21 @@ inline bool in_white_list(task_code code)
return is_negotiation_message(code) || fd::is_failure_detector_message(code);
}

negotiation_map negotiation_service::_negotiations;
utils::rw_lock_nr negotiation_service::_lock;
negotiation_map negotiation_manager::_negotiations;
utils::rw_lock_nr negotiation_manager::_lock;

negotiation_service::negotiation_service() : serverlet("negotiation_service") {}
negotiation_manager::negotiation_manager() : serverlet("negotiation_manager") {}

void negotiation_service::open_service()
void negotiation_manager::open_service()
{
register_rpc_handler_with_rpc_holder(
RPC_NEGOTIATION, "Negotiation", &negotiation_service::on_negotiation_request);
RPC_NEGOTIATION, "Negotiation", &negotiation_manager::on_negotiation_request);
}

void negotiation_service::on_negotiation_request(negotiation_rpc rpc)
void negotiation_manager::on_negotiation_request(negotiation_rpc rpc)
{
dassert(!rpc.dsn_request()->io_session->is_client(),
"only server session receive negotiation request");
"only server session receives negotiation request");

// reply SASL_AUTH_DISABLE if auth is not enable
if (!security::FLAGS_enable_auth) {
Expand All @@ -67,13 +69,33 @@ void negotiation_service::on_negotiation_request(negotiation_rpc rpc)
static_cast<server_negotiation *>(_negotiations[rpc.dsn_request()->io_session].get());
}

dassert(srv_negotiation != nullptr,
"negotiation is null for msg: {}",
rpc.dsn_request()->rpc_code().to_string());
if (nullptr == srv_negotiation) {
derror_f("negotiation is null for msg: {}", rpc.dsn_request()->rpc_code().to_string());
return;
}
srv_negotiation->handle_request(rpc);
}
hycdong marked this conversation as resolved.
Show resolved Hide resolved

void negotiation_service::on_rpc_connected(rpc_session *session)
void negotiation_manager::on_negotiation_response(error_code err, negotiation_rpc rpc)
{
dassert(rpc.dsn_request()->io_session->is_client(),
"only client session receives negotiation response");

client_negotiation *cli_negotiation = nullptr;
{
utils::auto_read_lock l(_lock);
cli_negotiation =
static_cast<client_negotiation *>(_negotiations[rpc.dsn_request()->io_session].get());
}

if (nullptr == cli_negotiation) {
derror_f("negotiation is null for msg: {}", rpc.dsn_request()->rpc_code().to_string());
levy5307 marked this conversation as resolved.
Show resolved Hide resolved
return;
}
cli_negotiation->handle_response(err, std::move(rpc.response()));
}

void negotiation_manager::on_rpc_connected(rpc_session *session)
{
std::unique_ptr<negotiation> nego = security::create_negotiation(session->is_client(), session);
nego->start();
Expand All @@ -83,21 +105,21 @@ void negotiation_service::on_rpc_connected(rpc_session *session)
}
}

void negotiation_service::on_rpc_disconnected(rpc_session *session)
void negotiation_manager::on_rpc_disconnected(rpc_session *session)
{
{
utils::auto_write_lock l(_lock);
_negotiations.erase(session);
}
}

bool negotiation_service::on_rpc_recv_msg(message_ex *msg)
bool negotiation_manager::on_rpc_recv_msg(message_ex *msg)
{
return !FLAGS_mandatory_auth || in_white_list(msg->rpc_code()) ||
msg->io_session->is_negotiation_succeed();
}

bool negotiation_service::on_rpc_send_msg(message_ex *msg)
bool negotiation_manager::on_rpc_send_msg(message_ex *msg)
{
// if try_pend_message return true, it means the msg is pended to the resend message queue
return !FLAGS_mandatory_auth || in_white_list(msg->rpc_code()) ||
Expand All @@ -106,12 +128,12 @@ bool negotiation_service::on_rpc_send_msg(message_ex *msg)

void init_join_point()
{
rpc_session::on_rpc_session_connected.put_back(negotiation_service::on_rpc_connected,
rpc_session::on_rpc_session_connected.put_back(negotiation_manager::on_rpc_connected,
"security");
rpc_session::on_rpc_session_disconnected.put_back(negotiation_service::on_rpc_disconnected,
rpc_session::on_rpc_session_disconnected.put_back(negotiation_manager::on_rpc_disconnected,
"security");
rpc_session::on_rpc_recv_message.put_native(negotiation_service::on_rpc_recv_msg);
rpc_session::on_rpc_send_message.put_native(negotiation_service::on_rpc_send_msg);
rpc_session::on_rpc_recv_message.put_native(negotiation_manager::on_rpc_recv_msg);
rpc_session::on_rpc_send_message.put_native(negotiation_manager::on_rpc_send_msg);
}
} // namespace security
} // namespace dsn
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,23 @@ namespace dsn {
namespace security {
typedef std::unordered_map<rpc_session *, std::unique_ptr<negotiation>> negotiation_map;

class negotiation_service : public serverlet<negotiation_service>,
public utils::singleton<negotiation_service>
class negotiation_manager : public serverlet<negotiation_manager>,
public utils::singleton<negotiation_manager>
{
public:
static void on_rpc_connected(rpc_session *session);
static void on_rpc_disconnected(rpc_session *session);
static bool on_rpc_recv_msg(message_ex *msg);
static bool on_rpc_send_msg(message_ex *msg);
static void on_negotiation_response(error_code err, negotiation_rpc rpc);

void open_service();

private:
negotiation_service();
negotiation_manager();
void on_negotiation_request(negotiation_rpc rpc);
friend class utils::singleton<negotiation_service>;
friend class negotiation_service_test;
friend class utils::singleton<negotiation_manager>;
friend class negotiation_manager_test;

static utils::rw_lock_nr _lock; // [
static negotiation_map _negotiations;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace security {
DSN_DECLARE_string(service_fqdn);
DSN_DECLARE_string(service_name);

server_negotiation::server_negotiation(rpc_session *session) : negotiation(session)
server_negotiation::server_negotiation(rpc_session_ptr session) : negotiation(session)
{
_name = fmt::format("SERVER_NEGOTIATION(CLIENT={})", _session->remote_address().to_string());
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/server_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extern const std::set<std::string> supported_mechanisms;
class server_negotiation : public negotiation
{
public:
server_negotiation(rpc_session *session);
server_negotiation(rpc_session_ptr session);

void start();
void handle_request(negotiation_rpc rpc);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/service_api_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
#include "runtime/rpc/rpc_engine.h"
#include "runtime/task/task_engine.h"
#include "utils/coredump.h"
#include "runtime/security/negotiation_service.h"
#include "runtime/security/negotiation_manager.h"

namespace dsn {
namespace security {
Expand Down Expand Up @@ -562,7 +562,7 @@ service_app *service_app::new_service_app(const std::string &type,

service_app::service_app(const dsn::service_app_info *info) : _info(info), _started(false)
{
security::negotiation_service::instance().open_service();
security::negotiation_manager::instance().open_service();
}

const service_app_info &service_app::info() const { return *_info; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include "runtime/security/negotiation_service.h"
#include "runtime/security/negotiation_manager.h"
#include "runtime/security/negotiation_utils.h"
#include "runtime/rpc/network.sim.h"

Expand All @@ -29,7 +29,7 @@ namespace security {
DSN_DECLARE_bool(enable_auth);
DSN_DECLARE_bool(mandatory_auth);

class negotiation_service_test : public testing::Test
class negotiation_manager_test : public testing::Test
{
public:
negotiation_rpc create_fake_rpc()
Expand All @@ -52,21 +52,21 @@ class negotiation_service_test : public testing::Test

void on_negotiation_request(negotiation_rpc rpc)
{
negotiation_service::instance().on_negotiation_request(rpc);
negotiation_manager::instance().on_negotiation_request(rpc);
}

bool on_rpc_recv_msg(message_ex *msg)
{
return negotiation_service::instance().on_rpc_recv_msg(msg);
return negotiation_manager::instance().on_rpc_recv_msg(msg);
}

bool on_rpc_send_msg(message_ex *msg)
{
return negotiation_service::instance().on_rpc_send_msg(msg);
return negotiation_manager::instance().on_rpc_send_msg(msg);
}
};

TEST_F(negotiation_service_test, disable_auth)
TEST_F(negotiation_manager_test, disable_auth)
{
RPC_MOCKING(negotiation_rpc)
{
Expand All @@ -78,7 +78,7 @@ TEST_F(negotiation_service_test, disable_auth)
}
}

TEST_F(negotiation_service_test, on_rpc_recv_msg)
TEST_F(negotiation_manager_test, on_rpc_recv_msg)
{
struct
{
Expand Down Expand Up @@ -107,7 +107,7 @@ TEST_F(negotiation_service_test, on_rpc_recv_msg)
}
}

TEST_F(negotiation_service_test, on_rpc_send_msg)
TEST_F(negotiation_manager_test, on_rpc_send_msg)
{
struct
{
Expand Down