Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override change_user() in IAM/Secrets Manager proxy #137

Merged
merged 2 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions driver/iam_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
// http://www.gnu.org/licenses/gpl-2.0.html.

#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <functional>

#include "aws_sdk_helper.h"
#include "driver.h"
Expand Down Expand Up @@ -64,39 +65,16 @@ IAM_PROXY::~IAM_PROXY() {
}

bool IAM_PROXY::connect(const char* host, const char* user, const char* password,
const char* database, unsigned int port, const char* socket, unsigned long flags) {
const char* database, unsigned int port, const char* socket, unsigned long flags) {

const char* auth_host = host;
if (ds->auth_host) {
auth_host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8);
}

const char* region;
if (ds->auth_region8) {
region = (const char*)ds->auth_region8;
}
else {
// Go with default region if region is not provided.
region = Aws::Region::US_EAST_1;
}

std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, user, ds->auth_expiration);

bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags);
if (!connect_result) {
Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials.");
}
else if (credentials.IsExpired()) {
this->set_custom_error_message(
"AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials.");
}
}
auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port,
socket, flags);
return invoke_func_with_generated_token(f);
}

return connect_result;
bool IAM_PROXY::change_user(const char* user, const char* passwd, const char* db) {
auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, user, std::placeholders::_1, db);
return invoke_func_with_generated_token(f);
}

std::string IAM_PROXY::get_auth_token(
Expand Down Expand Up @@ -161,3 +139,30 @@ void IAM_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
}

bool IAM_PROXY::invoke_func_with_generated_token(std::function<bool(const char*)> func) {

// Use user provided auth host if present, otherwise, use server host
const char* auth_host = ds->auth_host ? ds_get_utf8attr(ds->auth_host, &ds->auth_host8) : (const char*)ds->server8;

// Go with default region if region is not provided.
const char* region = ds->auth_region8 ? (const char*)ds->auth_region8 : Aws::Region::US_EAST_1;

std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, (const char*)ds->uid8, ds->auth_expiration);

bool connect_result = func(auth_token.c_str());
if (!connect_result) {
Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials.");
}
else if (credentials.IsExpired()) {
this->set_custom_error_message(
"AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials.");
}
}

return connect_result;
}
5 changes: 5 additions & 0 deletions driver/iam_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class IAM_PROXY : public CONNECTION_PROXY {
const char* socket,
unsigned long flags) override;

bool change_user(const char* user, const char* passwd,
const char* db) override;

std::string get_auth_token(
const char* host,const char* region, unsigned int port,
const char* user, unsigned int time_until_expiration);
Expand All @@ -91,6 +94,8 @@ class IAM_PROXY : public CONNECTION_PROXY {

static void clear_token_cache();

bool invoke_func_with_generated_token(std::function<bool(const char*)> func);

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
Expand Down
16 changes: 14 additions & 2 deletions driver/secrets_manager_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <aws/secretsmanager/SecretsManagerServiceClientModel.h>
#include <aws/secretsmanager/model/GetSecretValueRequest.h>
#include <functional>
#include <regex>

#include "aws_sdk_helper.h"
Expand Down Expand Up @@ -94,6 +95,17 @@ SECRETS_MANAGER_PROXY::~SECRETS_MANAGER_PROXY() {
bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const char* passwd, const char* database,
unsigned int port, const char* unix_socket, unsigned long flags) {

auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, std::placeholders::_1, std::placeholders::_2, database, port, unix_socket, flags);
return invoke_func_with_retrieved_secret(f);
}

bool SECRETS_MANAGER_PROXY::change_user(const char* user, const char* passwd, const char* db) {
auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, std::placeholders::_1, std::placeholders::_2, db);
return invoke_func_with_retrieved_secret(f);
}

bool SECRETS_MANAGER_PROXY::invoke_func_with_retrieved_secret(std::function<bool(const char*, const char*)> func) {

if (this->secret_key.first.empty()) {
const auto error = "Missing required config parameter for Secrets Manager: Secret ID";
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error);
Expand All @@ -110,7 +122,7 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch
this->set_custom_error_message(error);
return false;
}
bool ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags);
bool ret = func(username.c_str(), password.c_str());
if (!ret && next_proxy->error_code() == ER_ACCESS_DENIED_ERROR && !fetched) {
// Login unsuccessful with cached credentials
// Try to re-fetch credentials and try again
Expand All @@ -119,7 +131,7 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch
if (fetched) {
username = get_from_secret_json_value(USERNAME_KEY);
password = get_from_secret_json_value(PASSWORD_KEY);
ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags);
ret = func(username.c_str(), password.c_str());
}
}

Expand Down
5 changes: 4 additions & 1 deletion driver/secrets_manager_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ class SECRETS_MANAGER_PROXY : public CONNECTION_PROXY {

bool connect(const char* host, const char* user, const char* passwd, const char* database,
unsigned int port, const char* unix_socket, unsigned long flags) override;

bool change_user(const char* user, const char* passwd,
const char* db) override;

private:
std::shared_ptr<Aws::SecretsManager::SecretsManagerClient> sm_client;
std::pair<Aws::String, Aws::String> secret_key;
Aws::Utils::Json::JsonValue secret_json_value;

bool invoke_func_with_retrieved_secret(std::function<bool(const char*, const char*)> func);
bool update_secret(bool force_re_fetch);
bool fetch_latest_credentials();
bool parse_json_value(Aws::String json_string);
Expand Down