diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 9b299c96e..b52fa9c21 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -57,15 +57,48 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) SET(DRIVER_NAME "awsmysqlodbc${CONNECTOR_DRIVER_TYPE_SHORT}") SET(DRIVER_SRCS - aws_sdk_helper.cc base_metrics_holder.cc catalog.cc catalog_no_i_s.cc cluster_topology_info.cc - cluster_aware_hit_metrics_holder.cc cluster_aware_metrics_container.cc - cluster_aware_metrics.cc cluster_aware_time_metrics_holder.cc - connect.cc connection_handler.cc connection_proxy.cc cursor.cc desc.cc dll.cc driver.cc efm_proxy.cc - error.cc execute.cc failover_handler.cc - failover_reader_handler.cc failover_writer_handler.cc handle.cc host_info.cc info.cc - monitor.cc monitor_connection_context.cc monitor_service.cc monitor_thread_container.cc - my_prepared_stmt.cc my_stmt.cc mylog.cc mysql_proxy.cc options.cc parse.cc prepare.cc query_parsing.cc - results.cc topology_service.cc transact.cc utility.cc) + aws_sdk_helper.cc + base_metrics_holder.cc + catalog.cc + catalog_no_i_s.cc + cluster_topology_info.cc + cluster_aware_hit_metrics_holder.cc + cluster_aware_metrics_container.cc + cluster_aware_metrics.cc + cluster_aware_time_metrics_holder.cc + connect.cc + connection_handler.cc + connection_proxy.cc + cursor.cc + desc.cc + dll.cc + driver.cc + efm_proxy.cc + error.cc + execute.cc + failover_handler.cc + failover_reader_handler.cc + failover_writer_handler.cc + handle.cc + host_info.cc + iam_proxy.cc + info.cc + monitor.cc + monitor_connection_context.cc + monitor_service.cc + monitor_thread_container.cc + my_prepared_stmt.cc + my_stmt.cc + mylog.cc + mysql_proxy.cc + options.cc + parse.cc + prepare.cc + query_parsing.cc + results.cc + topology_service.cc + transact.cc + utility.cc) IF(UNICODE) SET(DRIVER_SRCS ${DRIVER_SRCS} unicode.cc) @@ -82,10 +115,32 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.def.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.def @ONLY) CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY) SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc - aws_sdk_helper.h base_metrics_holder.h catalog.h cluster_aware_hit_metrics_holder.h cluster_aware_metrics_container.h - cluster_aware_metrics.h cluster_aware_time_metrics_holder.h cluster_topology_info.h connection_handler.h connection_proxy.h - driver.h efm_proxy.h error.h failover.h host_info.h monitor.h monitor_connection_context.h monitor_service.h - monitor_thread_container.h mylog.h mysql_proxy.h myutil.h parse.h query_parsing.h topology_service.h + aws_sdk_helper.h + base_metrics_holder.h + catalog.h + cluster_aware_hit_metrics_holder.h + cluster_aware_metrics_container.h + cluster_aware_metrics.h + cluster_aware_time_metrics_holder.h + cluster_topology_info.h + connection_handler.h + connection_proxy.h + driver.h + efm_proxy.h + error.h + failover.h + host_info.h + iam_proxy.h + monitor.h + monitor_connection_context.h + monitor_service.h + monitor_thread_container.h + mylog.h + mysql_proxy.h + myutil.h + parse.h + query_parsing.h + topology_service.h ../MYODBC_MYSQL.h ../MYODBC_CONF.h ../MYODBC_ODBC.h) ENDIF(WIN32) diff --git a/driver/connect.cc b/driver/connect.cc index 38068bd08..cd2b6f76b 100644 --- a/driver/connect.cc +++ b/driver/connect.cc @@ -666,7 +666,9 @@ SQLRETURN DBC::connect(DataSource *dsrc, bool failover_enabled) #endif #if (MYSQL_VERSION_ID >= 50527 && MYSQL_VERSION_ID < 50600) || MYSQL_VERSION_ID >= 50607 - if (dsrc->enable_cleartext_plugin) + // IAM authentication requires the plugin to be set. + if (dsrc->enable_cleartext_plugin || + (dsrc->auth_mode8 && !myodbc_strcasecmp(AUTH_MODE_IAM, (const char*)dsrc->auth_mode8))) { connection_proxy->options(MYSQL_ENABLE_CLEARTEXT_PLUGIN, (char *)&on); } @@ -823,20 +825,12 @@ SQLRETURN DBC::connect(DataSource *dsrc, bool failover_enabled) ds_set_strnattr(&dsrc->server8, (SQLCHAR*)host, strlen(host)); dsrc->port = port; - const bool connect_result = dsrc->enable_dns_srv ? - connection_proxy->real_connect_dns_srv(host, - ds_get_utf8attr(dsrc->uid, &dsrc->uid8), - ds_get_utf8attr(dsrc->pwd, &dsrc->pwd8), - ds_get_utf8attr(dsrc->database, &dsrc->database8), - flags) - : - connection_proxy->real_connect(host, - ds_get_utf8attr(dsrc->uid, &dsrc->uid8), - ds_get_utf8attr(dsrc->pwd, &dsrc->pwd8), - ds_get_utf8attr(dsrc->database, &dsrc->database8), - port, - ds_get_utf8attr(dsrc->socket, &dsrc->socket8), - flags); + const char* user = ds_get_utf8attr(dsrc->uid, &dsrc->uid8); + const char* password = ds_get_utf8attr(dsrc->pwd, &dsrc->pwd8); + const char* database = ds_get_utf8attr(dsrc->database, &dsrc->database8); + const char* socket = ds_get_utf8attr(dsrc->socket, &dsrc->socket8); + + const bool connect_result = connection_proxy->connect(host, user, password, database, port, socket, flags); if (!connect_result) { unsigned int native_error= connection_proxy->error_code(); diff --git a/driver/connection_proxy.cc b/driver/connection_proxy.cc index 7073ac3aa..db8716aad 100644 --- a/driver/connection_proxy.cc +++ b/driver/connection_proxy.cc @@ -45,6 +45,16 @@ CONNECTION_PROXY::~CONNECTION_PROXY() { } } +bool CONNECTION_PROXY::connect(const char* host, const char* user, const char* password, + const char* database, unsigned int port, const char* socket, unsigned long flags) { + + if (ds->enable_dns_srv) { + return this->real_connect_dns_srv(host, user, password, database, flags); + } + + return this->real_connect(host, user, password, database, port, socket, flags); +} + void CONNECTION_PROXY::delete_ds() { next_proxy->delete_ds(); } @@ -78,6 +88,14 @@ unsigned int CONNECTION_PROXY::error_code() { } const char* CONNECTION_PROXY::error() { + if (has_custom_error_message) { + // We disable this flag after fetching the custom message once + // so it does not obscure future proxy errors. + has_custom_error_message = false; + + return this->custom_error_message.c_str(); + } + return next_proxy->error(); } @@ -393,3 +411,8 @@ void CONNECTION_PROXY::set_next_proxy(CONNECTION_PROXY* next_proxy) { MYSQL* CONNECTION_PROXY::move_mysql_connection() { return next_proxy ? next_proxy->move_mysql_connection() : nullptr; } + +void CONNECTION_PROXY::set_custom_error_message(const char* error_message) { + this->custom_error_message = error_message; + has_custom_error_message = true; +} diff --git a/driver/connection_proxy.h b/driver/connection_proxy.h index b12160548..d3539b0c2 100644 --- a/driver/connection_proxy.h +++ b/driver/connection_proxy.h @@ -37,9 +37,19 @@ struct DataSource; class CONNECTION_PROXY { public: + CONNECTION_PROXY() = default; CONNECTION_PROXY(DBC* dbc, DataSource* ds); virtual ~CONNECTION_PROXY(); + virtual bool connect( + const char* host, + const char* user, + const char* password, + const char* database, + unsigned int port, + const char* socket, + unsigned long flags); + virtual void delete_ds(); virtual uint64_t num_rows(MYSQL_RES* res); virtual unsigned int num_fields(MYSQL_RES* res); @@ -163,10 +173,14 @@ class CONNECTION_PROXY { virtual MYSQL* move_mysql_connection(); + void set_custom_error_message(const char* error_message); + protected: DBC* dbc = nullptr; DataSource* ds = nullptr; CONNECTION_PROXY* next_proxy = nullptr; + bool has_custom_error_message = false; + std::string custom_error_message = ""; }; #endif /* __CONNECTION_PROXY__ */ diff --git a/driver/handle.cc b/driver/handle.cc index 59b0401df..cfc2f98ff 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -49,6 +49,7 @@ #include "driver.h" #include "efm_proxy.h" +#include "iam_proxy.h" #include "mysql_proxy.h" #include @@ -125,18 +126,18 @@ void DBC::init_proxy_chain(DataSource* dsrc) head = efm_proxy; } - ds_get_utf8attr(dsrc->auth_mode, &dsrc->auth_mode8); - - if (!myodbc_strcasecmp(AUTH_MODE_IAM, reinterpret_cast(dsrc->auth_mode8))) { - // CONNECTION_PROXY* iam_proxy = new IAM_PROXY(his, dsrc); - // iam_proxy->set_next_proxy(head); - // head = iam_proxy; - } - - if (!myodbc_strcasecmp(AUTH_MODE_SECRETS_MANAGER, reinterpret_cast(dsrc->auth_mode8))) { - // CONNECTION_PROXY* secrets_manager_proxy = new SECRETS_MANAGER_PROXY(his, dsrc); - // secrets_manager_proxy->set_next_proxy(head); - // head = secrets_manager_proxy; + if (dsrc->auth_mode) { + const char* auth_mode = ds_get_utf8attr(dsrc->auth_mode, &dsrc->auth_mode8); + if (!myodbc_strcasecmp(AUTH_MODE_IAM, auth_mode)) { + CONNECTION_PROXY* iam_proxy = new IAM_PROXY(this, dsrc); + iam_proxy->set_next_proxy(head); + head = iam_proxy; + } + else if (!myodbc_strcasecmp(AUTH_MODE_SECRETS_MANAGER, auth_mode)) { + // CONNECTION_PROXY* secrets_manager_proxy = new SECRETS_MANAGER_PROXY(his, dsrc); + // secrets_manager_proxy->set_next_proxy(head); + // head = secrets_manager_proxy; + } } this->connection_proxy = head; diff --git a/driver/iam_proxy.cc b/driver/iam_proxy.cc new file mode 100644 index 000000000..79b8076ac --- /dev/null +++ b/driver/iam_proxy.cc @@ -0,0 +1,164 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include + +#include "aws_sdk_helper.h" +#include "driver.h" +#include "iam_proxy.h" + +namespace { + AWS_SDK_HELPER SDK_HELPER; +} + +std::unordered_map IAM_PROXY::token_cache; +std::mutex IAM_PROXY::token_cache_mutex; + +IAM_PROXY::IAM_PROXY(DBC* dbc, DataSource* ds) : IAM_PROXY(dbc, ds, nullptr) {}; + +IAM_PROXY::IAM_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) { + ++SDK_HELPER; + + this->next_proxy = next_proxy; + + Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider; + Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials(); + + Aws::RDS::RDSClientConfiguration client_config; + if (ds->auth_region) { + client_config.region = ds_get_utf8attr(ds->auth_region, &ds->auth_region8); + } + + this->rds_client = std::make_shared(credentials, client_config); +} + +IAM_PROXY::~IAM_PROXY() { + this->rds_client.reset(); + --SDK_HELPER; +} + +bool IAM_PROXY::connect(const char* host, const char* user, const char* password, + const char* database, unsigned int port, const char* socket, unsigned long flags) { + + if (ds->auth_host) { + host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8); + } + + const char* region; + if (ds->auth_region8) { + region = (const char*)ds->auth_region8; + } + else { + // Go with default region if region is not provided. + region = "us-east-1"; + } + + port = ds->auth_port; + + std::string auth_token = this->get_auth_token(host, region, port, user, ds->auth_expiration); + + bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags); + if (!connect_result) { + Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider; + Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials(); + if (credentials.IsEmpty()) { + this->set_custom_error_message( + "Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials."); + } + else if (credentials.IsExpired()) { + this->set_custom_error_message( + "AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials."); + } + } + + return connect_result; +} + +std::string IAM_PROXY::get_auth_token( + const char* host, const char* region, unsigned int port, + const char* user, unsigned int time_until_expiration) { + + if (!host) { + host = ""; + } + if (!region) { + region = ""; + } + if (!user) { + user = ""; + } + + std::string auth_token; + std::string cache_key = build_cache_key(host, region, port, user); + + { + std::unique_lock lock(token_cache_mutex); + + // Search for token in cache + auto find_token = token_cache.find(cache_key); + if (find_token != token_cache.end()) + { + TOKEN_INFO info = find_token->second; + if (info.is_expired()) { + token_cache.erase(cache_key); + } + else { + return info.token; + } + } + + // Generate new token + auth_token = generate_auth_token(host, region, port, user); + + token_cache[cache_key] = TOKEN_INFO(auth_token, time_until_expiration); + } + + return auth_token; +} + +std::string IAM_PROXY::build_cache_key( + const char* host, const char* region, unsigned int port, const char* user) { + + // Format should be ":::" + return std::string(region) + .append(":").append(host) + .append(":").append(std::to_string(port)) + .append(":").append(user); +} + +std::string IAM_PROXY::generate_auth_token( + const char* host, const char* region, unsigned int port, const char* user) { + + return this->rds_client->GenerateConnectAuthToken(host, region, port, user); +} + +void IAM_PROXY::clear_token_cache() { + std::unique_lock lock(token_cache_mutex); + token_cache.clear(); +} diff --git a/driver/iam_proxy.h b/driver/iam_proxy.h new file mode 100644 index 000000000..eb1601698 --- /dev/null +++ b/driver/iam_proxy.h @@ -0,0 +1,100 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + + +#ifndef __IAM_PROXY__ +#define __IAM_PROXY__ + +#include +#include + +#include "connection_proxy.h" + +constexpr auto DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60; + +class TOKEN_INFO { +public: + TOKEN_INFO() {}; + TOKEN_INFO(std::string token) : TOKEN_INFO(token, DEFAULT_TOKEN_EXPIRATION_SEC) {}; + TOKEN_INFO(std::string token, unsigned int seconds_until_expiration) { + this->token = token; + this->expiration_time = std::chrono::system_clock::now() + std::chrono::seconds(seconds_until_expiration); + } + + bool is_expired() { + std::chrono::system_clock::time_point current_time = std::chrono::system_clock::now(); + return current_time > this->expiration_time; + } + + std::string token; + +private: + std::chrono::system_clock::time_point expiration_time; +}; + +class IAM_PROXY : public CONNECTION_PROXY { +public: + IAM_PROXY() = default; + IAM_PROXY(DBC* dbc, DataSource* ds); + IAM_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); + ~IAM_PROXY() override; + + bool connect( + const char* host, + const char* user, + const char* password, + const char* database, + unsigned int port, + const char* socket, + unsigned long flags) override; + + std::string get_auth_token( + const char* host,const char* region, unsigned int port, + const char* user, unsigned int time_until_expiration); + +protected: + static std::unordered_map token_cache; + static std::mutex token_cache_mutex; + std::shared_ptr rds_client; + + static std::string build_cache_key( + const char* host, const char* region, unsigned int port, const char* user); + + virtual std::string generate_auth_token( + const char* host, const char* region, unsigned int port, const char* user); + + static void clear_token_cache(); + +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif +}; + +#endif /* __IAM_PROXY__ */ diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 2f2f350e2..42845af77 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -55,6 +55,7 @@ add_executable( cluster_aware_metrics_test.cc efm_proxy_test.cc + iam_proxy_test.cc failover_handler_test.cc failover_reader_handler_test.cc failover_writer_handler_test.cc diff --git a/unit_testing/iam_proxy_test.cc b/unit_testing/iam_proxy_test.cc new file mode 100644 index 000000000..ed19e417e --- /dev/null +++ b/unit_testing/iam_proxy_test.cc @@ -0,0 +1,133 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include + +#include "test_utils.h" +#include "mock_objects.h" + +using ::testing::_; +using ::testing::Return; + +static Aws::SDKOptions options; +static MOCK_IAM_PROXY* mock_proxy; + +class IamProxyTest : public testing::Test { +protected: + const char* host = "test_host"; + const char* region = "test_region"; + unsigned int port = 3306; + const char* user = "test_user"; + const char* token = "test_token"; + + static void SetUpTestSuite() { + Aws::InitAPI(options); + mock_proxy = new MOCK_IAM_PROXY(); + } + + static void TearDownTestSuite() { + delete mock_proxy; + Aws::ShutdownAPI(options); + } + + void SetUp() override { + } + + void TearDown() override { + TEST_UTILS::clear_token_cache(mock_proxy); + } +}; + +TEST_F(IamProxyTest, TokenExpiration) { + const unsigned int time_to_expire = 5; + TOKEN_INFO info = TOKEN_INFO("test_key", time_to_expire); + EXPECT_FALSE(info.is_expired()); + + std::this_thread::sleep_for(std::chrono::seconds(time_to_expire + 1)); + EXPECT_TRUE(info.is_expired()); +} + +TEST_F(IamProxyTest, TokenGetsCachedAndRetrieved) { + std::string cache_key = TEST_UTILS::build_cache_key(host, region, port, user); + EXPECT_FALSE(TEST_UTILS::token_cache_contains_key(cache_key)); + + // We should only generate the token once. + EXPECT_CALL(*mock_proxy, generate_auth_token(host, region, port, user)).WillOnce(Return(token)); + + std::string token1 = mock_proxy->get_auth_token(host, region, port, user, 100); + + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key)); + + // This 2nd call to get_auth_token() will retrieve the cached token. + std::string token2 = mock_proxy->get_auth_token(host, region, port, user, 100); + + EXPECT_EQ(token, token1); + EXPECT_TRUE(token1 == token2); +} + +TEST_F(IamProxyTest, MultipleCachedTokens) { + // Two separate tokens should be generated. + EXPECT_CALL(*mock_proxy, generate_auth_token(_, region, port, user)) + .WillOnce(Return(token)) + .WillOnce(Return(token)); + + const char* host2 = "test_host2"; + + mock_proxy->get_auth_token(host, region, port, user, 100); + mock_proxy->get_auth_token(host2, region, port, user, 100); + + std::string cache_key1 = TEST_UTILS::build_cache_key(host, region, port, user); + std::string cache_key2 = TEST_UTILS::build_cache_key(host2, region, port, user); + + EXPECT_NE(cache_key1, cache_key2); + + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key1)); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key2)); +} + +TEST_F(IamProxyTest, RegenerateTokenAfterExpiration) { + // We will generate the token twice because the first token will expire before the 2nd call to get_auth_token(). + EXPECT_CALL(*mock_proxy, generate_auth_token(host, region, port, user)) + .WillOnce(Return(token)) + .WillOnce(Return(token)); + + const unsigned int time_to_expire = 5; + mock_proxy->get_auth_token(host, region, port, user, time_to_expire); + + std::string cache_key = TEST_UTILS::build_cache_key(host, region, port, user); + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key)); + + // Wait for first token to expire. + std::this_thread::sleep_for(std::chrono::seconds(time_to_expire + 1)); + mock_proxy->get_auth_token(host, region, port, user, time_to_expire); + + EXPECT_TRUE(TEST_UTILS::token_cache_contains_key(cache_key)); +} diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index ec26c0d50..345c10e79 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -34,6 +34,7 @@ #include "driver/connection_proxy.h" #include "driver/failover.h" +#include "driver/iam_proxy.h" #include "driver/monitor_thread_container.h" #include "driver/monitor_service.h" @@ -84,6 +85,13 @@ class MOCK_CONNECTION_PROXY : public CONNECTION_PROXY { MOCK_METHOD(void, delete_ds, ()); }; +class MOCK_IAM_PROXY : public IAM_PROXY { +public: + MOCK_IAM_PROXY() : IAM_PROXY() {}; + + MOCK_METHOD(std::string , generate_auth_token, (const char*, const char*, unsigned int, const char*)); +}; + class MOCK_TOPOLOGY_SERVICE : public TOPOLOGY_SERVICE { public: MOCK_TOPOLOGY_SERVICE() : TOPOLOGY_SERVICE(0) {}; diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index bd3d83cf8..56533056b 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -110,3 +110,15 @@ size_t TEST_UTILS::get_map_size(std::shared_ptr contai std::list> TEST_UTILS::get_contexts(std::shared_ptr monitor) { return monitor->contexts; } + +std::string TEST_UTILS::build_cache_key(const char* host, const char* region, unsigned int port, const char* user) { + return IAM_PROXY::build_cache_key(host, region, port, user); +} + +bool TEST_UTILS::token_cache_contains_key(std::string cache_key) { + return IAM_PROXY::token_cache.find(cache_key) != IAM_PROXY::token_cache.end(); +} + +void TEST_UTILS::clear_token_cache(IAM_PROXY* iam_proxy) { + iam_proxy->clear_token_cache(); +} diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index 2231adcca..1412992f6 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -31,6 +31,7 @@ #define __TESTUTILS_H__ #include "driver/driver.h" +#include "driver/iam_proxy.h" #include "driver/monitor.h" #include "driver/monitor_thread_container.h" @@ -52,6 +53,9 @@ class TEST_UTILS { static std::shared_ptr get_available_monitor(std::shared_ptr container); static size_t get_map_size(std::shared_ptr container); static std::list> get_contexts(std::shared_ptr monitor); + static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); + static bool token_cache_contains_key(std::string cache_key); + static void clear_token_cache(IAM_PROXY* iam_proxy); }; #endif /* __TESTUTILS_H__ */