diff --git a/driver/iam_proxy.cc b/driver/iam_proxy.cc index b633e7921..fef3a3222 100644 --- a/driver/iam_proxy.cc +++ b/driver/iam_proxy.cc @@ -28,6 +28,7 @@ // http://www.gnu.org/licenses/gpl-2.0.html. #include +#include #include "aws_sdk_helper.h" #include "driver.h" @@ -64,39 +65,16 @@ IAM_PROXY::~IAM_PROXY() { } bool IAM_PROXY::connect(const char* host, const char* user, const char* password, - const char* database, unsigned int port, const char* socket, unsigned long flags) { + const char* database, unsigned int port, const char* socket, unsigned long flags) { - const char* auth_host = host; - if (ds->auth_host) { - auth_host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8); - } - - const char* region; - if (ds->auth_region8) { - region = (const char*)ds->auth_region8; - } - else { - // Go with default region if region is not provided. - region = Aws::Region::US_EAST_1; - } - - std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, user, ds->auth_expiration); - - bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags); - if (!connect_result) { - Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider; - Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials(); - if (credentials.IsEmpty()) { - this->set_custom_error_message( - "Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials."); - } - else if (credentials.IsExpired()) { - this->set_custom_error_message( - "AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials."); - } - } + auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, + socket, flags); + return invoke_func_with_generated_token(f); +} - return connect_result; +bool IAM_PROXY::change_user(const char* user, const char* passwd, const char* db) { + auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, user, std::placeholders::_1, db); + return invoke_func_with_generated_token(f); } std::string IAM_PROXY::get_auth_token( @@ -161,3 +139,30 @@ void IAM_PROXY::clear_token_cache() { std::unique_lock lock(token_cache_mutex); token_cache.clear(); } + +bool IAM_PROXY::invoke_func_with_generated_token(std::function func) { + + // Use user provided auth host if present, otherwise, use server host + const char* auth_host = ds->auth_host ? ds_get_utf8attr(ds->auth_host, &ds->auth_host8) : (const char*)ds->server8; + + // Go with default region if region is not provided. + const char* region = ds->auth_region8 ? (const char*)ds->auth_region8 : Aws::Region::US_EAST_1; + + std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, (const char*)ds->uid8, ds->auth_expiration); + + bool connect_result = func(auth_token.c_str()); + if (!connect_result) { + Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider; + Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials(); + if (credentials.IsEmpty()) { + this->set_custom_error_message( + "Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials."); + } + else if (credentials.IsExpired()) { + this->set_custom_error_message( + "AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials."); + } + } + + return connect_result; +} diff --git a/driver/iam_proxy.h b/driver/iam_proxy.h index eb1601698..7b56833ce 100644 --- a/driver/iam_proxy.h +++ b/driver/iam_proxy.h @@ -74,6 +74,9 @@ class IAM_PROXY : public CONNECTION_PROXY { const char* socket, unsigned long flags) override; + bool change_user(const char* user, const char* passwd, + const char* db) override; + std::string get_auth_token( const char* host,const char* region, unsigned int port, const char* user, unsigned int time_until_expiration); @@ -91,6 +94,8 @@ class IAM_PROXY : public CONNECTION_PROXY { static void clear_token_cache(); + bool invoke_func_with_generated_token(std::function func); + #ifdef UNIT_TEST_BUILD // Allows for testing private/protected methods friend class TEST_UTILS; diff --git a/driver/secrets_manager_proxy.cc b/driver/secrets_manager_proxy.cc index 2f8f7ccdc..0c4fec915 100644 --- a/driver/secrets_manager_proxy.cc +++ b/driver/secrets_manager_proxy.cc @@ -29,6 +29,7 @@ #include #include +#include #include #include "aws_sdk_helper.h" @@ -94,6 +95,17 @@ SECRETS_MANAGER_PROXY::~SECRETS_MANAGER_PROXY() { bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const char* passwd, const char* database, unsigned int port, const char* unix_socket, unsigned long flags) { + auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, std::placeholders::_1, std::placeholders::_2, database, port, unix_socket, flags); + return invoke_func_with_retrieved_secret(f); +} + +bool SECRETS_MANAGER_PROXY::change_user(const char* user, const char* passwd, const char* db) { + auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, std::placeholders::_1, std::placeholders::_2, db); + return invoke_func_with_retrieved_secret(f); +} + +bool SECRETS_MANAGER_PROXY::invoke_func_with_retrieved_secret(std::function func) { + if (this->secret_key.first.empty()) { const auto error = "Missing required config parameter for Secrets Manager: Secret ID"; MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error); @@ -110,7 +122,7 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch this->set_custom_error_message(error); return false; } - bool ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags); + bool ret = func(username.c_str(), password.c_str()); if (!ret && next_proxy->error_code() == ER_ACCESS_DENIED_ERROR && !fetched) { // Login unsuccessful with cached credentials // Try to re-fetch credentials and try again @@ -119,7 +131,7 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch if (fetched) { username = get_from_secret_json_value(USERNAME_KEY); password = get_from_secret_json_value(PASSWORD_KEY); - ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags); + ret = func(username.c_str(), password.c_str()); } } diff --git a/driver/secrets_manager_proxy.h b/driver/secrets_manager_proxy.h index d1798df6f..7a6176309 100644 --- a/driver/secrets_manager_proxy.h +++ b/driver/secrets_manager_proxy.h @@ -48,12 +48,15 @@ class SECRETS_MANAGER_PROXY : public CONNECTION_PROXY { bool connect(const char* host, const char* user, const char* passwd, const char* database, unsigned int port, const char* unix_socket, unsigned long flags) override; + + bool change_user(const char* user, const char* passwd, + const char* db) override; private: std::shared_ptr sm_client; std::pair secret_key; Aws::Utils::Json::JsonValue secret_json_value; - + bool invoke_func_with_retrieved_secret(std::function func); bool update_secret(bool force_re_fetch); bool fetch_latest_credentials(); bool parse_json_value(Aws::String json_string);