From d0c6f74f8158ce3fd5e151ed87de495aae209296 Mon Sep 17 00:00:00 2001 From: justing-bq <62349012+justing-bq@users.noreply.github.com> Date: Thu, 27 Apr 2023 11:05:44 -0700 Subject: [PATCH] Verify Writer Cluster Connections (#139) --- driver/connect.cc | 6 +- driver/failover.h | 15 +- driver/failover_handler.cc | 189 ++++++++++++++---- integration/base_failover_integration_test.cc | 6 + .../network_failover_integration_test.cc | 6 +- unit_testing/failover_handler_test.cc | 89 +++++++-- unit_testing/mock_objects.h | 8 + unit_testing/test_utils.cc | 4 + unit_testing/test_utils.h | 2 + 9 files changed, 262 insertions(+), 63 deletions(-) diff --git a/driver/connect.cc b/driver/connect.cc index 86dfb3f1d..4aa1b9ce0 100644 --- a/driver/connect.cc +++ b/driver/connect.cc @@ -1047,7 +1047,7 @@ SQLRETURN SQL_API MySQLConnect(SQLHDBC hdbc, dbc->init_proxy_chain(ds); dbc->connection_handler = std::make_shared(dbc); dbc->fh = new FAILOVER_HANDLER(dbc, ds); - rc = dbc->fh->init_cluster_info(); + rc = dbc->fh->init_connection(); if (!dbc->ds) ds_delete(ds); return rc; @@ -1165,7 +1165,7 @@ SQLRETURN SQL_API MySQLDriverConnect(SQLHDBC hdbc, SQLHWND hwnd, dbc->init_proxy_chain(ds); dbc->connection_handler = std::make_shared(dbc); dbc->fh = new FAILOVER_HANDLER(dbc, ds); - rc = dbc->fh->init_cluster_info(); + rc = dbc->fh->init_connection(); if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) goto connected; bPrompt= TRUE; @@ -1345,7 +1345,7 @@ SQLRETURN SQL_API MySQLDriverConnect(SQLHDBC hdbc, SQLHWND hwnd, dbc->init_proxy_chain(ds); dbc->connection_handler = std::make_shared(dbc); dbc->fh = new FAILOVER_HANDLER(dbc, ds); - rc = dbc->fh->init_cluster_info(); + rc = dbc->fh->init_connection(); if (rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO) { goto error; diff --git a/driver/failover.h b/driver/failover.h index a4ae9e034..90900ddc6 100644 --- a/driver/failover.h +++ b/driver/failover.h @@ -167,7 +167,7 @@ class FAILOVER_HANDLER { std::shared_ptr topology_service, std::shared_ptr metrics_container); ~FAILOVER_HANDLER(); - SQLRETURN init_cluster_info(); + SQLRETURN init_connection(); bool trigger_failover_if_needed(const char* error_code, const char*& new_error_code, const char*& error_msg); bool is_failover_enabled(); bool is_rds(); @@ -190,15 +190,20 @@ class FAILOVER_HANDLER { bool m_is_rds_proxy = false; bool m_is_rds = false; bool m_is_rds_custom_cluster = false; - bool initialized = false; + bool is_cluster_info_initialized = false; + void init_cluster_info(); + bool should_connect_to_new_writer(); + void initialize_topology(); + bool is_read_only(); + virtual std::string host_to_IP(std::string host); + SQLRETURN reconnect(bool failover_enabled); static bool is_dns_pattern_valid(std::string host); static bool is_rds_dns(std::string host); static bool is_rds_cluster_dns(std::string host); static bool is_rds_proxy_dns(std::string host); + static bool is_rds_writer_cluster_dns(std::string host); static bool is_rds_custom_cluster_dns(std::string host); - SQLRETURN create_connection_and_initialize_topology(); - SQLRETURN reconnect(bool failover_enabled); static std::string get_rds_cluster_host_url(std::string host); static std::string get_rds_instance_host_pattern(std::string host); bool is_ipv4(std::string host); @@ -213,7 +218,7 @@ class FAILOVER_HANDLER { std::chrono::steady_clock::time_point failover_start_time_ms; #ifdef UNIT_TEST_BUILD - // Allows for testing private methods + // Allows for testing private/protected methods friend class TEST_UTILS; #endif }; diff --git a/driver/failover_handler.cc b/driver/failover_handler.cc index 5724d2496..94966301a 100644 --- a/driver/failover_handler.cc +++ b/driver/failover_handler.cc @@ -36,6 +36,15 @@ #include #include "driver.h" +#include "mylog.h" + +#if defined(__APPLE__) || defined(__linux__) + #include + #include + #include + #include + #include +#endif namespace { const std::regex AURORA_DNS_PATTERN( @@ -47,6 +56,9 @@ const std::regex AURORA_PROXY_DNS_PATTERN( const std::regex AURORA_CLUSTER_PATTERN( R"#((.+)\.(cluster-|cluster-ro-)+([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", std::regex_constants::icase); +const std::regex AURORA_WRITER_CLUSTER_PATTERN( + R"#((.+)\.(cluster-)+([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", + std::regex_constants::icase); const std::regex AURORA_CUSTOM_CLUSTER_PATTERN( R"#((.+)\.(cluster-custom-)+([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", std::regex_constants::icase); @@ -59,6 +71,9 @@ const std::regex AURORA_CHINA_PROXY_DNS_PATTERN( const std::regex AURORA_CHINA_CLUSTER_PATTERN( R"#((.+)\.(cluster-|cluster-ro-)+([a-zA-Z0-9]+\.rds\.[a-zA-Z0-9\-]+\.amazonaws\.com\.cn))#", std::regex_constants::icase); +const std::regex AURORA_CHINA_WRITER_CLUSTER_PATTERN( + R"#((.+)\.(cluster-)+([a-zA-Z0-9]+\.rds\.[a-zA-Z0-9\-]+\.amazonaws\.com\.cn))#", + std::regex_constants::icase); const std::regex AURORA_CHINA_CUSTOM_CLUSTER_PATTERN( R"#((.+)\.(cluster-custom-)+([a-zA-Z0-9]+\.rds\.[a-zA-Z0-9\-]+\.amazonaws\.com\.cn))#", std::regex_constants::icase); @@ -67,6 +82,8 @@ const std::regex IPV4_PATTERN( const std::regex IPV6_PATTERN(R"#(^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}$)#"); const std::regex IPV6_COMPRESSED_PATTERN( R"#(^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$)#"); + +const char* MYSQL_READONLY_QUERY = "SELECT @@innodb_read_only AS is_reader"; } // namespace FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds) @@ -108,18 +125,45 @@ FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds, FAILOVER_HANDLER::~FAILOVER_HANDLER() {} -SQLRETURN FAILOVER_HANDLER::init_cluster_info() { - SQLRETURN rc = SQL_ERROR; - if (initialized) { - return rc; +SQLRETURN FAILOVER_HANDLER::init_connection() { + SQLRETURN rc = connection_handler->do_connect(dbc, ds, false); + if (SQL_SUCCEEDED(rc)) { + metrics_container->register_invalid_initial_connection(false); } - - if (!ds->enable_cluster_failover) { - // Use a standard default connection - no further initialization required - rc = connection_handler->do_connect(dbc, ds, false); - initialized = true; + else { + metrics_container->register_invalid_initial_connection(true); return rc; } + + bool reconnect_with_updated_timeouts = false; + if (ds->enable_cluster_failover) { + this->init_cluster_info(); + + if (is_failover_enabled()) { + // Since we can't determine whether failover should be enabled + // before we connect, there is a possibility we need to reconnect + // again with the correct connection settings for failover. + const unsigned int connect_timeout = get_connect_timeout(ds->connect_timeout); + const unsigned int network_timeout = get_network_timeout(ds->network_timeout); + + reconnect_with_updated_timeouts = (connect_timeout != dbc->login_timeout || + network_timeout != ds->read_timeout || + network_timeout != ds->write_timeout); + } + } + + if (should_connect_to_new_writer() || reconnect_with_updated_timeouts) { + rc = reconnect(reconnect_with_updated_timeouts); + } + + return rc; +} + +void FAILOVER_HANDLER::init_cluster_info() { + if (is_cluster_info_initialized) { + return; + } + std::stringstream err; // Cluster-aware failover is enabled @@ -199,7 +243,7 @@ SQLRETURN FAILOVER_HANDLER::init_cluster_info() { } } - rc = create_connection_and_initialize_topology(); + initialize_topology(); } else if (is_ipv4(main_host) || is_ipv6(main_host)) { // TODO: do we need to setup host template in this case? // HOST_INFO* host_template = new HOST_INFO(); @@ -211,7 +255,7 @@ SQLRETURN FAILOVER_HANDLER::init_cluster_info() { set_cluster_id(clid_str); } - rc = create_connection_and_initialize_topology(); + initialize_topology(); if (m_is_cluster_topology_available) { err << "Host Pattern configuration setting is required when IP " @@ -242,7 +286,7 @@ SQLRETURN FAILOVER_HANDLER::init_cluster_info() { set_cluster_id(clid_str); } - rc = create_connection_and_initialize_topology(); + initialize_topology(); if (m_is_cluster_topology_available) { err << "The provided host appears to be a custom domain. The " @@ -288,12 +332,58 @@ SQLRETURN FAILOVER_HANDLER::init_cluster_info() { } } - rc = create_connection_and_initialize_topology(); + initialize_topology(); } } - initialized = true; - return rc; + is_cluster_info_initialized = true; +} + +bool FAILOVER_HANDLER::should_connect_to_new_writer() { + auto host = (const char*)ds->server8; + if (host == nullptr || host == "") { + return false; + } + + if (!is_rds_writer_cluster_dns(host)) { + return false; + } + + std::string host_ip = host_to_IP(host); + if (host_ip == "") { + return false; + } + + this->init_cluster_info(); + + // We need to force refresh the topology if we are connected to a read only instance. + auto topology = topology_service->get_topology(dbc->connection_proxy, is_read_only()); + + std::shared_ptr writer; + try { + writer = topology->get_writer(); + } + catch (std::runtime_error) { + return false; + } + + std::string writer_host = writer->get_host(); + if (is_rds_cluster_dns(writer_host.c_str())) { + return false; + } + + std::string writer_host_ip = host_to_IP(writer_host); + if (writer_host_ip == "" || writer_host_ip == host_ip) { + return false; + } + + // DNS must have resolved the cluster endpoint to a wrong writer + // so we should reconnect to a proper writer node. + const sqlwchar_string writer_host_wstr = to_sqlwchar_string(writer_host); + ds_set_wstrnattr(&ds->server, (SQLWCHAR*)writer_host_wstr.c_str(), writer_host_wstr.size()); + ds_set_strnattr(&ds->server8, (SQLCHAR*)writer_host.c_str(), writer_host.size()); + + return true; } void FAILOVER_HANDLER::set_cluster_id(std::string host, int port) { @@ -322,10 +412,55 @@ bool FAILOVER_HANDLER::is_rds_proxy_dns(std::string host) { return std::regex_match(host, AURORA_PROXY_DNS_PATTERN) || std::regex_match(host, AURORA_CHINA_PROXY_DNS_PATTERN); } +bool FAILOVER_HANDLER::is_rds_writer_cluster_dns(std::string host) { + return std::regex_match(host, AURORA_WRITER_CLUSTER_PATTERN) || std::regex_match(host, AURORA_CHINA_WRITER_CLUSTER_PATTERN); +} + bool FAILOVER_HANDLER::is_rds_custom_cluster_dns(std::string host) { return std::regex_match(host, AURORA_CUSTOM_CLUSTER_PATTERN) || std::regex_match(host, AURORA_CHINA_CUSTOM_CLUSTER_PATTERN); } +bool FAILOVER_HANDLER::is_read_only() { + bool read_only = false; + if (dbc->connection_proxy->query(MYSQL_READONLY_QUERY) == 0) { + auto result = dbc->connection_proxy->store_result(); + MYSQL_ROW row; + if (row = dbc->connection_proxy->fetch_row(result)) { + read_only = (strcmp(row[0], "1") == 0); + } + dbc->connection_proxy->free_result(result); + } + + return read_only; +} + +std::string FAILOVER_HANDLER::host_to_IP(std::string host) { + int status; + struct addrinfo hints; + struct addrinfo* servinfo; + struct addrinfo* p; + char ipstr[INET_ADDRSTRLEN]; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; //IPv4 + hints.ai_socktype = SOCK_STREAM; + + if ((status = getaddrinfo(host.c_str(), NULL, &hints, &servinfo)) != 0) { + return ""; + } + + for (p = servinfo; p != NULL; p = p->ai_next) { + void* addr; + + struct sockaddr_in* ipv4 = (struct sockaddr_in*)p->ai_addr; + addr = &(ipv4->sin_addr); + inet_ntop(p->ai_family, addr, ipstr, sizeof(ipstr)); + } + + freeaddrinfo(servinfo); + return std::string(ipstr); +} + #if defined(__APPLE__) || defined(__linux__) #define strcmp_case_insensitive(str1, str2) strcasecmp(str1, str2) #else @@ -398,14 +533,8 @@ bool FAILOVER_HANDLER::is_cluster_topology_available() { return m_is_cluster_topology_available; } -SQLRETURN FAILOVER_HANDLER::create_connection_and_initialize_topology() { - SQLRETURN rc = connection_handler->do_connect(dbc, ds, false); - if (!SQL_SUCCEEDED(rc)) { - metrics_container->register_invalid_initial_connection(true); - return rc; - } - - metrics_container->register_invalid_initial_connection(false); +void FAILOVER_HANDLER::initialize_topology() { + current_topology = topology_service->get_topology(dbc->connection_proxy, false); if (current_topology) { m_is_multi_writer_cluster = current_topology->is_multi_writer_cluster; @@ -413,24 +542,10 @@ SQLRETURN FAILOVER_HANDLER::create_connection_and_initialize_topology() { MYLOG_DBC_TRACE(dbc, "[FAILOVER_HANDLER] m_is_cluster_topology_available=%s", m_is_cluster_topology_available ? "true" : "false"); - - // Since we can't determine whether failover should be enabled - // before we connect, there is a possibility we need to reconnect - // again with the correct connection settings for failover. - const unsigned int connect_timeout = get_connect_timeout(ds->connect_timeout); - const unsigned int network_timeout = get_network_timeout(ds->network_timeout); - - if (is_failover_enabled() && (connect_timeout != dbc->login_timeout || - network_timeout != ds->read_timeout || - network_timeout != ds->write_timeout)) { - rc = reconnect(true); - } if (is_failover_enabled()) { this->dbc->env->failover_thread_pool.resize(current_topology->total_hosts()); } } - - return rc; } SQLRETURN FAILOVER_HANDLER::reconnect(bool failover_enabled) { diff --git a/integration/base_failover_integration_test.cc b/integration/base_failover_integration_test.cc index e1b1bf25e..52e42e879 100644 --- a/integration/base_failover_integration_test.cc +++ b/integration/base_failover_integration_test.cc @@ -486,6 +486,12 @@ class BaseFailoverIntegrationTest : public testing::Test { } void test_connection(const SQLHDBC dbc, const std::string& test_server, const int test_port) { + sprintf(reinterpret_cast(conn_in), "%sSERVER=%s;PORT=%d;", get_default_config().c_str(), test_server.c_str(), test_port); + EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, conn_in, SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT)); + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc)); + } + + void test_connection_with_proxy_pattern(const SQLHDBC dbc, const std::string& test_server, const int test_port) { sprintf(reinterpret_cast(conn_in), "%sSERVER=%s;PORT=%d;", get_default_proxied_config().c_str(), test_server.c_str(), test_port); EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, conn_in, SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT)); EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc)); diff --git a/integration/network_failover_integration_test.cc b/integration/network_failover_integration_test.cc index e9a7cde8c..2e73a0301 100644 --- a/integration/network_failover_integration_test.cc +++ b/integration/network_failover_integration_test.cc @@ -93,11 +93,11 @@ class NetworkFailoverIntegrationTest : public BaseFailoverIntegrationTest { TEST_F(NetworkFailoverIntegrationTest, connection_test) { test_connection(dbc, MYSQL_INSTANCE_1_URL, MYSQL_PORT); - test_connection(dbc, MYSQL_INSTANCE_1_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT); + test_connection_with_proxy_pattern(dbc, MYSQL_INSTANCE_1_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT); test_connection(dbc, MYSQL_CLUSTER_URL, MYSQL_PORT); - test_connection(dbc, MYSQL_CLUSTER_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT); + test_connection_with_proxy_pattern(dbc, MYSQL_CLUSTER_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT); test_connection(dbc, MYSQL_RO_CLUSTER_URL, MYSQL_PORT); - test_connection(dbc, MYSQL_RO_CLUSTER_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT); + test_connection_with_proxy_pattern(dbc, MYSQL_RO_CLUSTER_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT); } TEST_F(NetworkFailoverIntegrationTest, lost_connection_to_writer) { diff --git a/unit_testing/failover_handler_test.cc b/unit_testing/failover_handler_test.cc index a78fd847b..a055e824f 100644 --- a/unit_testing/failover_handler_test.cc +++ b/unit_testing/failover_handler_test.cc @@ -35,7 +35,9 @@ using ::testing::_; using ::testing::AtLeast; +using ::testing::DeleteArg; using ::testing::Return; +using ::testing::ReturnNew; using ::testing::StrEq; namespace { @@ -119,7 +121,7 @@ TEST_F(FailoverHandlerTest, CustomDomain) { .WillOnce(Return(SQL_SUCCESS)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_FALSE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -133,7 +135,7 @@ TEST_F(FailoverHandlerTest, FailoverDisabled) { EXPECT_CALL(*mock_connection_handler, do_connect(dbc, ds, false)).Times(1); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_FALSE(failover_handler.is_failover_enabled()); } @@ -145,13 +147,11 @@ TEST_F(FailoverHandlerTest, IP_TopologyAvailable_PatternRequired) { EXPECT_CALL(*mock_connection_handler, do_connect(dbc, ds, false)) .WillOnce(Return(SQL_SUCCESS)); - EXPECT_CALL(*mock_connection_handler, do_connect(dbc, ds, true)) - .WillOnce(Return(SQL_SUCCESS)); EXPECT_CALL(*mock_ts, get_topology(_, false)).WillOnce(Return(topology)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - EXPECT_THROW(failover_handler.init_cluster_info(), std::runtime_error); + EXPECT_THROW(failover_handler.init_connection(), std::runtime_error); } TEST_F(FailoverHandlerTest, IP_TopologyNotAvailable) { @@ -166,7 +166,7 @@ TEST_F(FailoverHandlerTest, IP_TopologyNotAvailable) { .WillOnce(Return(std::make_shared())); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_FALSE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -191,7 +191,7 @@ TEST_F(FailoverHandlerTest, IP_Cluster) { EXPECT_CALL(*mock_ts, set_cluster_id(_)).Times(0); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_FALSE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -218,7 +218,7 @@ TEST_F(FailoverHandlerTest, IP_Cluster_ClusterID) { EXPECT_CALL(*mock_ts, set_cluster_id(StrEq(reinterpret_cast(cluster_id)))).Times(AtLeast(1)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_FALSE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -242,7 +242,7 @@ TEST_F(FailoverHandlerTest, RDS_Cluster) { EXPECT_CALL(*mock_ts, set_cluster_id(StrEq("my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com:1234"))).Times(AtLeast(1)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -265,7 +265,7 @@ TEST_F(FailoverHandlerTest, RDS_CustomCluster) { EXPECT_CALL(*mock_ts, set_cluster_id(_)).Times(1); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -288,7 +288,7 @@ TEST_F(FailoverHandlerTest, RDS_Instance) { EXPECT_CALL(*mock_ts, set_cluster_id(_)).Times(1); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -310,7 +310,7 @@ TEST_F(FailoverHandlerTest, RDS_Proxy) { EXPECT_CALL(*mock_ts, set_cluster_id(StrEq("test-proxy.proxy-XYZ.us-east-2.rds.amazonaws.com:1234"))).Times(AtLeast(1)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_rds()); EXPECT_TRUE(failover_handler.is_rds_proxy()); @@ -337,7 +337,7 @@ TEST_F(FailoverHandlerTest, RDS_ReaderCluster) { .Times(AtLeast(1)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -368,7 +368,7 @@ TEST_F(FailoverHandlerTest, RDS_MultiWriterCluster) { .Times(AtLeast(1)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_rds()); EXPECT_FALSE(failover_handler.is_rds_proxy()); @@ -396,7 +396,7 @@ TEST_F(FailoverHandlerTest, ReconnectWithFailoverSettings) { .WillOnce(Return(SQL_SUCCESS)); FAILOVER_HANDLER failover_handler(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); - failover_handler.init_cluster_info(); + failover_handler.init_connection(); EXPECT_TRUE(failover_handler.is_failover_enabled()); } @@ -437,6 +437,18 @@ TEST_F(FailoverHandlerTest, IsRdsProxyDns) { EXPECT_FALSE(TEST_UTILS::is_rds_proxy_dns(CHINA_REGION_CUSTON_DOMAIN)); } +TEST_F(FailoverHandlerTest, IsRdsWriterClusterDns) { + EXPECT_TRUE(TEST_UTILS::is_rds_writer_cluster_dns(US_EAST_REGION_CLUSTER)); + EXPECT_FALSE(TEST_UTILS::is_rds_writer_cluster_dns(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_FALSE(TEST_UTILS::is_rds_writer_cluster_dns(US_EAST_REGION_PROXY)); + EXPECT_FALSE(TEST_UTILS::is_rds_writer_cluster_dns(US_EAST_REGION_CUSTON_DOMAIN)); + + EXPECT_TRUE(TEST_UTILS::is_rds_writer_cluster_dns(CHINA_REGION_CLUSTER)); + EXPECT_FALSE(TEST_UTILS::is_rds_writer_cluster_dns(CHINA_REGION_CLUSTER_READ_ONLY)); + EXPECT_FALSE(TEST_UTILS::is_rds_writer_cluster_dns(CHINA_REGION_PROXY)); + EXPECT_FALSE(TEST_UTILS::is_rds_writer_cluster_dns(CHINA_REGION_CUSTON_DOMAIN)); +} + TEST_F(FailoverHandlerTest, IsRdsCustomClusterDns) { EXPECT_FALSE(TEST_UTILS::is_rds_custom_cluster_dns(US_EAST_REGION_CLUSTER)); EXPECT_FALSE(TEST_UTILS::is_rds_custom_cluster_dns(US_EAST_REGION_CLUSTER_READ_ONLY)); @@ -474,3 +486,50 @@ TEST_F(FailoverHandlerTest, GetRdsClusterHostUrl) { EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_host_url(CHINA_REGION_PROXY)); EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_host_url(CHINA_REGION_CUSTON_DOMAIN)); } + +TEST_F(FailoverHandlerTest, ConnectToNewWriter) { + SQLCHAR server[] = "my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; + + EXPECT_CALL(*mock_connection_handler, do_connect(dbc, ds, _)) + .WillOnce(Return(SQL_SUCCESS)) + .WillOnce(Return(SQL_SUCCESS)); + + ds_setattr_from_utf8(&ds->server, server); + ds->port = 1234; + ds->enable_cluster_failover = true; + + auto mock_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + delete dbc->connection_proxy; + dbc->connection_proxy = mock_proxy; + + EXPECT_CALL(*mock_proxy, query(_)) + .WillOnce(Return(0)); + + EXPECT_CALL(*mock_proxy, store_result()) + .WillOnce(ReturnNew()); + + char* row[1] = { "1" }; + EXPECT_CALL(*mock_proxy, fetch_row(_)) + .WillOnce(Return(row)); + + EXPECT_CALL(*mock_proxy, free_result(_)) + .WillOnce(DeleteArg<0>()); + + auto topology = std::make_shared(); + topology->add_host(writer_host); + topology->add_host(reader_host); + + EXPECT_CALL(*mock_ts, get_topology(_, false)) + .WillOnce(Return(topology)); + EXPECT_CALL(*mock_ts, get_topology(_, true)) + .WillOnce(Return(topology)); + + auto mock_failover_handler = std::make_shared(dbc, ds, mock_connection_handler, mock_ts, mock_metrics); + EXPECT_CALL(*mock_failover_handler, host_to_IP(_)) + .WillOnce(Return("10.10.10.10")) + .WillOnce(Return("20.20.20.20")); + + mock_failover_handler->init_connection(); + + EXPECT_EQ(std::string("writer-host.com"), std::string((const char*)ds->server8)); +} diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 156b18e15..157f44866 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -122,6 +122,14 @@ class MOCK_CONNECTION_HANDLER : public CONNECTION_HANDLER { MOCK_METHOD(SQLRETURN, do_connect, (DBC*, DataSource*, bool)); }; +class MOCK_FAILOVER_HANDLER : public FAILOVER_HANDLER { +public: + MOCK_FAILOVER_HANDLER(DBC* dbc, DataSource* ds, std::shared_ptr ch, + std::shared_ptr ts, std::shared_ptr mc) : + FAILOVER_HANDLER(dbc, ds, ch, ts, mc) {} + MOCK_METHOD(std::string, host_to_IP, (std::string)); +}; + class MOCK_FAILOVER_SYNC : public FAILOVER_SYNC { public: MOCK_FAILOVER_SYNC() : FAILOVER_SYNC(1) {} diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index 5b7b438a0..0b2b42cf6 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -147,6 +147,10 @@ bool TEST_UTILS::is_rds_proxy_dns(std::string host) { return FAILOVER_HANDLER::is_rds_proxy_dns(host); } +bool TEST_UTILS::is_rds_writer_cluster_dns(std::string host) { + return FAILOVER_HANDLER::is_rds_writer_cluster_dns(host); +} + bool TEST_UTILS::is_rds_custom_cluster_dns(std::string host) { return FAILOVER_HANDLER::is_rds_custom_cluster_dns(host); } diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index c4567604b..ac83dd0a5 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -31,6 +31,7 @@ #define __TESTUTILS_H__ #include "driver/driver.h" +#include "driver/failover.h" #include "driver/iam_proxy.h" #include "driver/monitor.h" #include "driver/monitor_thread_container.h" @@ -63,6 +64,7 @@ class TEST_UTILS { static bool is_rds_dns(std::string host); static bool is_rds_cluster_dns(std::string host); static bool is_rds_proxy_dns(std::string host); + static bool is_rds_writer_cluster_dns(std::string host); static bool is_rds_custom_cluster_dns(std::string host); static std::string get_rds_cluster_host_url(std::string host); static std::string get_rds_instance_host_pattern(std::string host);