Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

Commit

Permalink
feat(security): server_negotiation handle SASL_INITIATE request (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
levy5307 authored Sep 8, 2020
1 parent 4f4ba53 commit 4058664
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 9 deletions.
20 changes: 16 additions & 4 deletions src/runtime/security/sasl_server_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ DSN_DECLARE_string(service_name);

error_s sasl_server_wrapper::init()
{
FAIL_POINT_INJECT_F("sasl_server_wrapper_init",
[](dsn::string_view) { return error_s::make(ERR_OK); });
FAIL_POINT_INJECT_F("sasl_server_wrapper_init", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
return error_s::make(err);
});

int sasl_err = sasl_server_new(
FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, nullptr, 0, &_conn);
Expand All @@ -40,8 +42,18 @@ error_s sasl_server_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
{
// TBD(zlw)
return error_s::make(ERR_OK);
FAIL_POINT_INJECT_F("sasl_server_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
return error_s::make(err);
});

const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err =
sasl_server_start(_conn, mechanism.c_str(), input.c_str(), input.length(), &msg, &msg_len);

output.assign(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_server_wrapper::step(const std::string &input, std::string &output)
Expand Down
38 changes: 38 additions & 0 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ void server_negotiation::handle_request(negotiation_rpc rpc)
on_select_mechanism(rpc);
break;
case negotiation_status::type::SASL_SELECT_MECHANISMS_RESP:
on_initiate(rpc);
break;
case negotiation_status::type::SASL_CHALLENGE:
// TBD(zlw)
break;
Expand Down Expand Up @@ -99,5 +101,41 @@ void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP;
}

void server_negotiation::on_initiate(negotiation_rpc rpc)
{
const negotiation_request &request = rpc.request();
if (!check_status(request.status, negotiation_status::type::SASL_INITIATE)) {
fail_negotiation();
return;
}

std::string start_output;
error_s err_s = _sasl->start(_selected_mechanism, request.msg, start_output);
return do_challenge(rpc, err_s, start_output);
}

void server_negotiation::do_challenge(negotiation_rpc rpc,
error_s err_s,
const std::string &resp_msg)
{
if (!err_s.is_ok() && err_s.code() != ERR_SASL_INCOMPLETE) {
dwarn_f("{}: negotiation failed, with err = {}, msg = {}",
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}

if (err_s.is_ok()) {
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SUCC;
} else {
negotiation_response &challenge = rpc.response();
_status = challenge.status = negotiation_status::type::SASL_CHALLENGE;
challenge.msg = resp_msg;
}
}
} // namespace security
} // namespace dsn
2 changes: 2 additions & 0 deletions src/runtime/security/server_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class server_negotiation : public negotiation
private:
void on_list_mechanisms(negotiation_rpc rpc);
void on_select_mechanism(negotiation_rpc rpc);
void on_initiate(negotiation_rpc rpc);
void do_challenge(negotiation_rpc rpc, error_s err_s, const std::string &resp_msg);

friend class server_negotiation_test;
};
Expand Down
65 changes: 60 additions & 5 deletions src/runtime/test/server_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class server_negotiation_test : public testing::Test

void on_select_mechanism(negotiation_rpc rpc) { _srv_negotiation->on_select_mechanism(rpc); }

void on_initiate(negotiation_rpc rpc) { _srv_negotiation->on_initiate(rpc); }

negotiation_status::type get_negotiation_status() { return _srv_negotiation->_status; }

// _sim_session is used for holding the sim_rpc_session which is created in ctor,
Expand Down Expand Up @@ -90,38 +92,91 @@ TEST_F(server_negotiation_test, on_select_mechanism)
{
struct
{
std::string sasl_init_result;
negotiation_status::type req_status;
std::string req_msg;
negotiation_status::type resp_status;
negotiation_status::type nego_status;
} tests[] = {{
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS,
"GSSAPI",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
},
{negotiation_status::type::SASL_SELECT_MECHANISMS,
{"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS,
"TEST",
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL},
{negotiation_status::type::SASL_INITIATE,
{"ERR_TIMEOUT",
negotiation_status::type::SASL_SELECT_MECHANISMS,
"GSSAPI",
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_OK",
negotiation_status::type::SASL_INITIATE,
"GSSAPI",
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL}};

fail::setup();
fail::cfg("sasl_server_wrapper_init", "return()");
RPC_MOCKING(negotiation_rpc)
{
for (const auto &test : tests) {
fail::setup();
fail::cfg("sasl_server_wrapper_init", "return(" + test.sasl_init_result + ")");

auto rpc = create_negotiation_rpc(test.req_status, test.req_msg);
on_select_mechanism(rpc);
ASSERT_EQ(rpc.response().status, test.resp_status);
ASSERT_EQ(get_negotiation_status(), test.nego_status);

fail::teardown();
}
}
}

TEST_F(server_negotiation_test, on_initiate)
{
struct
{
std::string sasl_start_result;
negotiation_status::type req_status;
negotiation_status::type resp_status;
negotiation_status::type nego_status;
} tests[] = {
{"ERR_TIMEOUT",
negotiation_status::type::SASL_INITIATE,
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS,
negotiation_status::type::INVALID,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_SASL_INCOMPLETE",
negotiation_status::type::SASL_INITIATE,
negotiation_status::type::SASL_CHALLENGE,
negotiation_status::type::SASL_CHALLENGE},
{"ERR_OK",
negotiation_status::type::SASL_INITIATE,
negotiation_status::type::SASL_SUCC,
negotiation_status::type::SASL_SUCC},
};

RPC_MOCKING(negotiation_rpc)
{
for (const auto &test : tests) {
fail::setup();
fail::cfg("sasl_server_wrapper_start", "return(" + test.sasl_start_result + ")");

auto rpc = create_negotiation_rpc(test.req_status, "");
on_initiate(rpc);
ASSERT_EQ(rpc.response().status, test.resp_status);
ASSERT_EQ(get_negotiation_status(), test.nego_status);

fail::teardown();
}
}
fail::teardown();
}
} // namespace security
} // namespace dsn

0 comments on commit 4058664

Please sign in to comment.