Skip to content

Commit

Permalink
Override change_user() in IAM/Secrets Manager proxy (#137)
Browse files Browse the repository at this point in the history
* override change_user() in IAM/Secrets Manager proxy

* better naming
  • Loading branch information
yanw-bq authored and justing-bq committed May 4, 2023
1 parent 53715a5 commit 1e6bd15
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 34 deletions.
67 changes: 36 additions & 31 deletions driver/iam_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
// http://www.gnu.org/licenses/gpl-2.0.html.

#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <functional>

#include "aws_sdk_helper.h"
#include "driver.h"
Expand Down Expand Up @@ -64,39 +65,16 @@ IAM_PROXY::~IAM_PROXY() {
}

bool IAM_PROXY::connect(const char* host, const char* user, const char* password,
const char* database, unsigned int port, const char* socket, unsigned long flags) {
const char* database, unsigned int port, const char* socket, unsigned long flags) {

const char* auth_host = host;
if (ds->auth_host) {
auth_host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8);
}

const char* region;
if (ds->auth_region8) {
region = (const char*)ds->auth_region8;
}
else {
// Go with default region if region is not provided.
region = Aws::Region::US_EAST_1;
}

std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, user, ds->auth_expiration);

bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags);
if (!connect_result) {
Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials.");
}
else if (credentials.IsExpired()) {
this->set_custom_error_message(
"AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials.");
}
}
auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port,
socket, flags);
return invoke_func_with_generated_token(f);
}

return connect_result;
bool IAM_PROXY::change_user(const char* user, const char* passwd, const char* db) {
auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, user, std::placeholders::_1, db);
return invoke_func_with_generated_token(f);
}

std::string IAM_PROXY::get_auth_token(
Expand Down Expand Up @@ -161,3 +139,30 @@ void IAM_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
}

bool IAM_PROXY::invoke_func_with_generated_token(std::function<bool(const char*)> func) {

// Use user provided auth host if present, otherwise, use server host
const char* auth_host = ds->auth_host ? ds_get_utf8attr(ds->auth_host, &ds->auth_host8) : (const char*)ds->server8;

// Go with default region if region is not provided.
const char* region = ds->auth_region8 ? (const char*)ds->auth_region8 : Aws::Region::US_EAST_1;

std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, (const char*)ds->uid8, ds->auth_expiration);

bool connect_result = func(auth_token.c_str());
if (!connect_result) {
Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials.");
}
else if (credentials.IsExpired()) {
this->set_custom_error_message(
"AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials.");
}
}

return connect_result;
}
5 changes: 5 additions & 0 deletions driver/iam_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class IAM_PROXY : public CONNECTION_PROXY {
const char* socket,
unsigned long flags) override;

bool change_user(const char* user, const char* passwd,
const char* db) override;

std::string get_auth_token(
const char* host,const char* region, unsigned int port,
const char* user, unsigned int time_until_expiration);
Expand All @@ -91,6 +94,8 @@ class IAM_PROXY : public CONNECTION_PROXY {

static void clear_token_cache();

bool invoke_func_with_generated_token(std::function<bool(const char*)> func);

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
Expand Down
16 changes: 14 additions & 2 deletions driver/secrets_manager_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <aws/secretsmanager/SecretsManagerServiceClientModel.h>
#include <aws/secretsmanager/model/GetSecretValueRequest.h>
#include <functional>
#include <regex>

#include "aws_sdk_helper.h"
Expand Down Expand Up @@ -94,6 +95,17 @@ SECRETS_MANAGER_PROXY::~SECRETS_MANAGER_PROXY() {
bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const char* passwd, const char* database,
unsigned int port, const char* unix_socket, unsigned long flags) {

auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, std::placeholders::_1, std::placeholders::_2, database, port, unix_socket, flags);
return invoke_func_with_retrieved_secret(f);
}

bool SECRETS_MANAGER_PROXY::change_user(const char* user, const char* passwd, const char* db) {
auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, std::placeholders::_1, std::placeholders::_2, db);
return invoke_func_with_retrieved_secret(f);
}

bool SECRETS_MANAGER_PROXY::invoke_func_with_retrieved_secret(std::function<bool(const char*, const char*)> func) {

if (this->secret_key.first.empty()) {
const auto error = "Missing required config parameter for Secrets Manager: Secret ID";
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error);
Expand All @@ -110,7 +122,7 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch
this->set_custom_error_message(error);
return false;
}
bool ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags);
bool ret = func(username.c_str(), password.c_str());
if (!ret && next_proxy->error_code() == ER_ACCESS_DENIED_ERROR && !fetched) {
// Login unsuccessful with cached credentials
// Try to re-fetch credentials and try again
Expand All @@ -119,7 +131,7 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch
if (fetched) {
username = get_from_secret_json_value(USERNAME_KEY);
password = get_from_secret_json_value(PASSWORD_KEY);
ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags);
ret = func(username.c_str(), password.c_str());
}
}

Expand Down
5 changes: 4 additions & 1 deletion driver/secrets_manager_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ class SECRETS_MANAGER_PROXY : public CONNECTION_PROXY {

bool connect(const char* host, const char* user, const char* passwd, const char* database,
unsigned int port, const char* unix_socket, unsigned long flags) override;

bool change_user(const char* user, const char* passwd,
const char* db) override;

private:
std::shared_ptr<Aws::SecretsManager::SecretsManagerClient> sm_client;
std::pair<Aws::String, Aws::String> secret_key;
Aws::Utils::Json::JsonValue secret_json_value;

bool invoke_func_with_retrieved_secret(std::function<bool(const char*, const char*)> func);
bool update_secret(bool force_re_fetch);
bool fetch_latest_credentials();
bool parse_json_value(Aws::String json_string);
Expand Down

0 comments on commit 1e6bd15

Please sign in to comment.