From 4058664bee6880be73986afe6514df61680f5974 Mon Sep 17 00:00:00 2001 From: zhao liwei Date: Tue, 8 Sep 2020 14:51:11 +0800 Subject: [PATCH] feat(security): server_negotiation handle SASL_INITIATE request (#613) --- src/runtime/security/sasl_server_wrapper.cpp | 20 ++++-- src/runtime/security/server_negotiation.cpp | 38 ++++++++++++ src/runtime/security/server_negotiation.h | 2 + src/runtime/test/server_negotiation_test.cpp | 65 ++++++++++++++++++-- 4 files changed, 116 insertions(+), 9 deletions(-) diff --git a/src/runtime/security/sasl_server_wrapper.cpp b/src/runtime/security/sasl_server_wrapper.cpp index 6f4f070c87..ef3e892fc5 100644 --- a/src/runtime/security/sasl_server_wrapper.cpp +++ b/src/runtime/security/sasl_server_wrapper.cpp @@ -28,8 +28,10 @@ DSN_DECLARE_string(service_name); error_s sasl_server_wrapper::init() { - FAIL_POINT_INJECT_F("sasl_server_wrapper_init", - [](dsn::string_view) { return error_s::make(ERR_OK); }); + FAIL_POINT_INJECT_F("sasl_server_wrapper_init", [](dsn::string_view str) { + error_code err = error_code::try_get(str.data(), ERR_UNKNOWN); + return error_s::make(err); + }); int sasl_err = sasl_server_new( FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, nullptr, 0, &_conn); @@ -40,8 +42,18 @@ error_s sasl_server_wrapper::start(const std::string &mechanism, const std::string &input, std::string &output) { - // TBD(zlw) - return error_s::make(ERR_OK); + FAIL_POINT_INJECT_F("sasl_server_wrapper_start", [](dsn::string_view str) { + error_code err = error_code::try_get(str.data(), ERR_UNKNOWN); + return error_s::make(err); + }); + + const char *msg = nullptr; + unsigned msg_len = 0; + int sasl_err = + sasl_server_start(_conn, mechanism.c_str(), input.c_str(), input.length(), &msg, &msg_len); + + output.assign(msg, msg_len); + return wrap_error(sasl_err); } error_s sasl_server_wrapper::step(const std::string &input, std::string &output) diff --git a/src/runtime/security/server_negotiation.cpp b/src/runtime/security/server_negotiation.cpp index 890cf79224..807689ce53 100644 --- a/src/runtime/security/server_negotiation.cpp +++ b/src/runtime/security/server_negotiation.cpp @@ -50,6 +50,8 @@ void server_negotiation::handle_request(negotiation_rpc rpc) on_select_mechanism(rpc); break; case negotiation_status::type::SASL_SELECT_MECHANISMS_RESP: + on_initiate(rpc); + break; case negotiation_status::type::SASL_CHALLENGE: // TBD(zlw) break; @@ -99,5 +101,41 @@ void server_negotiation::on_select_mechanism(negotiation_rpc rpc) negotiation_response &response = rpc.response(); _status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP; } + +void server_negotiation::on_initiate(negotiation_rpc rpc) +{ + const negotiation_request &request = rpc.request(); + if (!check_status(request.status, negotiation_status::type::SASL_INITIATE)) { + fail_negotiation(); + return; + } + + std::string start_output; + error_s err_s = _sasl->start(_selected_mechanism, request.msg, start_output); + return do_challenge(rpc, err_s, start_output); +} + +void server_negotiation::do_challenge(negotiation_rpc rpc, + error_s err_s, + const std::string &resp_msg) +{ + if (!err_s.is_ok() && err_s.code() != ERR_SASL_INCOMPLETE) { + dwarn_f("{}: negotiation failed, with err = {}, msg = {}", + _name, + err_s.code().to_string(), + err_s.description()); + fail_negotiation(); + return; + } + + if (err_s.is_ok()) { + negotiation_response &response = rpc.response(); + _status = response.status = negotiation_status::type::SASL_SUCC; + } else { + negotiation_response &challenge = rpc.response(); + _status = challenge.status = negotiation_status::type::SASL_CHALLENGE; + challenge.msg = resp_msg; + } +} } // namespace security } // namespace dsn diff --git a/src/runtime/security/server_negotiation.h b/src/runtime/security/server_negotiation.h index 9270a3e8ca..a8d97afb20 100644 --- a/src/runtime/security/server_negotiation.h +++ b/src/runtime/security/server_negotiation.h @@ -36,6 +36,8 @@ class server_negotiation : public negotiation private: void on_list_mechanisms(negotiation_rpc rpc); void on_select_mechanism(negotiation_rpc rpc); + void on_initiate(negotiation_rpc rpc); + void do_challenge(negotiation_rpc rpc, error_s err_s, const std::string &resp_msg); friend class server_negotiation_test; }; diff --git a/src/runtime/test/server_negotiation_test.cpp b/src/runtime/test/server_negotiation_test.cpp index fbf2847725..abc69e64fb 100644 --- a/src/runtime/test/server_negotiation_test.cpp +++ b/src/runtime/test/server_negotiation_test.cpp @@ -48,6 +48,8 @@ class server_negotiation_test : public testing::Test void on_select_mechanism(negotiation_rpc rpc) { _srv_negotiation->on_select_mechanism(rpc); } + void on_initiate(negotiation_rpc rpc) { _srv_negotiation->on_initiate(rpc); } + negotiation_status::type get_negotiation_status() { return _srv_negotiation->_status; } // _sim_session is used for holding the sim_rpc_session which is created in ctor, @@ -90,38 +92,91 @@ TEST_F(server_negotiation_test, on_select_mechanism) { struct { + std::string sasl_init_result; negotiation_status::type req_status; std::string req_msg; negotiation_status::type resp_status; negotiation_status::type nego_status; } tests[] = {{ + "ERR_OK", negotiation_status::type::SASL_SELECT_MECHANISMS, "GSSAPI", negotiation_status::type::SASL_SELECT_MECHANISMS_RESP, negotiation_status::type::SASL_SELECT_MECHANISMS_RESP, }, - {negotiation_status::type::SASL_SELECT_MECHANISMS, + {"ERR_OK", + negotiation_status::type::SASL_SELECT_MECHANISMS, "TEST", negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}, - {negotiation_status::type::SASL_INITIATE, + {"ERR_TIMEOUT", + negotiation_status::type::SASL_SELECT_MECHANISMS, + "GSSAPI", + negotiation_status::type::INVALID, + negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_OK", + negotiation_status::type::SASL_INITIATE, "GSSAPI", negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}}; - fail::setup(); - fail::cfg("sasl_server_wrapper_init", "return()"); RPC_MOCKING(negotiation_rpc) { for (const auto &test : tests) { + fail::setup(); + fail::cfg("sasl_server_wrapper_init", "return(" + test.sasl_init_result + ")"); + auto rpc = create_negotiation_rpc(test.req_status, test.req_msg); on_select_mechanism(rpc); + ASSERT_EQ(rpc.response().status, test.resp_status); + ASSERT_EQ(get_negotiation_status(), test.nego_status); + + fail::teardown(); + } + } +} +TEST_F(server_negotiation_test, on_initiate) +{ + struct + { + std::string sasl_start_result; + negotiation_status::type req_status; + negotiation_status::type resp_status; + negotiation_status::type nego_status; + } tests[] = { + {"ERR_TIMEOUT", + negotiation_status::type::SASL_INITIATE, + negotiation_status::type::INVALID, + negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_OK", + negotiation_status::type::SASL_SELECT_MECHANISMS, + negotiation_status::type::INVALID, + negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_SASL_INCOMPLETE", + negotiation_status::type::SASL_INITIATE, + negotiation_status::type::SASL_CHALLENGE, + negotiation_status::type::SASL_CHALLENGE}, + {"ERR_OK", + negotiation_status::type::SASL_INITIATE, + negotiation_status::type::SASL_SUCC, + negotiation_status::type::SASL_SUCC}, + }; + + RPC_MOCKING(negotiation_rpc) + { + for (const auto &test : tests) { + fail::setup(); + fail::cfg("sasl_server_wrapper_start", "return(" + test.sasl_start_result + ")"); + + auto rpc = create_negotiation_rpc(test.req_status, ""); + on_initiate(rpc); ASSERT_EQ(rpc.response().status, test.resp_status); ASSERT_EQ(get_negotiation_status(), test.nego_status); + + fail::teardown(); } } - fail::teardown(); } } // namespace security } // namespace dsn