diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 5ed1ce8e6..c2a473451 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -60,8 +60,8 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) base_metrics_holder.cc catalog.cc catalog_no_i_s.cc cluster_topology_info.cc cluster_aware_hit_metrics_holder.cc cluster_aware_metrics_container.cc cluster_aware_metrics.cc cluster_aware_time_metrics_holder.cc - connect.cc cursor.cc desc.cc dll.cc driver.cc - error.cc execute.cc failover_connection_handler.cc failover_handler.cc + connect.cc connection_handler.cc cursor.cc desc.cc dll.cc driver.cc efm_proxy.cc + error.cc execute.cc failover_handler.cc failover_reader_handler.cc failover_writer_handler.cc handle.cc host_info.cc info.cc monitor.cc monitor_connection_context.cc monitor_service.cc monitor_thread_container.cc my_prepared_stmt.cc my_stmt.cc mylog.cc mysql_proxy.cc options.cc parse.cc prepare.cc query_parsing.cc @@ -83,8 +83,8 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY) SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc base_metrics_holder.h catalog.h cluster_aware_hit_metrics_holder.h cluster_aware_metrics_container.h - cluster_aware_metrics.h cluster_aware_time_metrics_holder.h cluster_topology_info.h - driver.h error.h failover.h host_info.h monitor.h monitor_connection_context.h monitor_service.h + cluster_aware_metrics.h cluster_aware_time_metrics_holder.h cluster_topology_info.h connection_handler.h + driver.h efm_proxy.h error.h failover.h host_info.h monitor.h monitor_connection_context.h monitor_service.h monitor_thread_container.h mylog.h mysql_proxy.h myutil.h parse.h query_parsing.h topology_service.h ../MYODBC_MYSQL.h ../MYODBC_CONF.h ../MYODBC_ODBC.h) ENDIF(WIN32) diff --git a/driver/connect.cc b/driver/connect.cc index 5933749d3..ac2d71079 100644 --- a/driver/connect.cc +++ b/driver/connect.cc @@ -751,7 +751,7 @@ SQLRETURN DBC::connect(DataSource *dsrc, bool failover_enabled) ds_set_strnattr(&dsrc->server8, (SQLCHAR*)host, strlen(host)); dsrc->port = port; - Aws::SDKOptions options; + //Aws::SDKOptions options; //Aws::InitAPI(options); TODO: causing SSL connection error: SSL_CTX_new failed //Aws::ShutdownAPI(options); @@ -1056,7 +1056,8 @@ SQLRETURN SQL_API MySQLConnect(SQLHDBC hdbc, if (ds->save_queries && !dbc->log_file) dbc->log_file = init_log_file(); - dbc->mysql_proxy = new MYSQL_PROXY(dbc, ds); + dbc->init_proxy_chain(ds); + dbc->connection_handler = std::make_shared(dbc); dbc->fh = new FAILOVER_HANDLER(dbc, ds); rc = dbc->fh->init_cluster_info(); if (!dbc->ds) @@ -1173,7 +1174,8 @@ SQLRETURN SQL_API MySQLDriverConnect(SQLHDBC hdbc, SQLHWND hwnd, if (ds->save_queries && !dbc->log_file) dbc->log_file = init_log_file(); - dbc->mysql_proxy = new MYSQL_PROXY(dbc, ds); + dbc->init_proxy_chain(ds); + dbc->connection_handler = std::make_shared(dbc); dbc->fh = new FAILOVER_HANDLER(dbc, ds); rc = dbc->fh->init_cluster_info(); if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) @@ -1352,7 +1354,8 @@ SQLRETURN SQL_API MySQLDriverConnect(SQLHDBC hdbc, SQLHWND hwnd, if (ds->save_queries && !dbc->log_file) dbc->log_file = init_log_file(); - dbc->mysql_proxy = new MYSQL_PROXY(dbc, ds); + dbc->init_proxy_chain(ds); + dbc->connection_handler = std::make_shared(dbc); dbc->fh = new FAILOVER_HANDLER(dbc, ds); rc = dbc->fh->init_cluster_info(); if (rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO) diff --git a/driver/failover_connection_handler.cc b/driver/connection_handler.cc similarity index 77% rename from driver/failover_connection_handler.cc rename to driver/connection_handler.cc index 453ece447..863b39055 100644 --- a/driver/failover_connection_handler.cc +++ b/driver/connection_handler.cc @@ -28,12 +28,13 @@ // http://www.gnu.org/licenses/gpl-2.0.html. /** - @file failover_connection_handler.c - @brief Failover connection functions. + @file connection_handler.c + @brief connection functions. */ +#include "connection_handler.h" #include "driver.h" -#include "failover.h" +#include "mysql_proxy.h" #include #include @@ -51,32 +52,36 @@ } #endif -FAILOVER_CONNECTION_HANDLER::FAILOVER_CONNECTION_HANDLER(DBC* dbc) : dbc{dbc} {} +CONNECTION_HANDLER::CONNECTION_HANDLER(DBC* dbc) : dbc{dbc} {} -FAILOVER_CONNECTION_HANDLER::~FAILOVER_CONNECTION_HANDLER() {} +CONNECTION_HANDLER::~CONNECTION_HANDLER() = default; -SQLRETURN FAILOVER_CONNECTION_HANDLER::do_connect(DBC* dbc_ptr, DataSource* ds, bool failover_enabled) { +SQLRETURN CONNECTION_HANDLER::do_connect(DBC* dbc_ptr, DataSource* ds, bool failover_enabled) { return dbc_ptr->connect(ds, failover_enabled); } -MYSQL_PROXY* FAILOVER_CONNECTION_HANDLER::connect(const std::shared_ptr& host_info) { +MYSQL_PROXY* CONNECTION_HANDLER::connect(const std::shared_ptr& host_info, DataSource* ds) { - if (dbc == nullptr || dbc->ds == nullptr || host_info == nullptr) { + if (dbc == nullptr || host_info == nullptr) { return nullptr; } + DataSource* ds_to_use = ds_new(); + ds_copy(ds_to_use, ds ? ds : dbc->ds); + const auto new_host = to_sqlwchar_string(host_info->get_host()); - DBC* dbc_clone = clone_dbc(dbc); - ds_set_wstrnattr(&dbc_clone->ds->server, (SQLWCHAR*)new_host.c_str(), new_host.size()); + DBC* dbc_clone = clone_dbc(dbc, ds_to_use); + ds_set_wstrnattr(&ds_to_use->server, (SQLWCHAR*)new_host.c_str(), new_host.size()); MYSQL_PROXY* new_connection = nullptr; CLEAR_DBC_ERROR(dbc_clone); - const SQLRETURN rc = do_connect(dbc_clone, dbc_clone->ds, true); + const SQLRETURN rc = do_connect(dbc_clone, ds_to_use, ds_to_use->enable_cluster_failover); if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) { new_connection = dbc_clone->mysql_proxy; dbc_clone->mysql_proxy = nullptr; + // postpone the deletion of ds_to_use/dbc_clone->ds until we are done with new_connection dbc_clone->ds = nullptr; } @@ -85,7 +90,7 @@ MYSQL_PROXY* FAILOVER_CONNECTION_HANDLER::connect(const std::shared_ptris_connected()) { @@ -102,7 +107,7 @@ void FAILOVER_CONNECTION_HANDLER::update_connection( } } -DBC* FAILOVER_CONNECTION_HANDLER::clone_dbc(DBC* source_dbc) { +DBC* CONNECTION_HANDLER::clone_dbc(DBC* source_dbc, DataSource* ds) { DBC* dbc_clone = nullptr; @@ -114,11 +119,9 @@ DBC* FAILOVER_CONNECTION_HANDLER::clone_dbc(DBC* source_dbc) { status = my_SQLAllocConnect(henv, &hdbc); if (status == SQL_SUCCESS || status == SQL_SUCCESS_WITH_INFO) { dbc_clone = static_cast(hdbc); - dbc_clone->ds = ds_new(); - ds_copy(dbc_clone->ds, source_dbc->ds); - dbc_clone->mysql_proxy = new MYSQL_PROXY(dbc_clone, dbc_clone->ds); + dbc_clone->init_proxy_chain(ds); } else { - const char* err = "Cannot allocate connection handle when cloning DBC in writer failover process"; + const char* err = "Cannot allocate connection handle when cloning DBC"; MYLOG_DBC_TRACE(dbc, err); throw std::runtime_error(err); } diff --git a/driver/connection_handler.h b/driver/connection_handler.h new file mode 100644 index 000000000..1cab2745a --- /dev/null +++ b/driver/connection_handler.h @@ -0,0 +1,64 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CONNECTION_HANDLER_H__ +#define __CONNECTION_HANDLER_H__ + +#include "host_info.h" + +#include + +#ifdef __linux__ +typedef std::u16string sqlwchar_string; +#else +typedef std::wstring sqlwchar_string; +#endif + +sqlwchar_string to_sqlwchar_string(const std::string& src); + +struct DBC; +struct DataSource; +class MYSQL_PROXY; +typedef short SQLRETURN; + +class CONNECTION_HANDLER { + public: + CONNECTION_HANDLER(DBC* dbc); + virtual ~CONNECTION_HANDLER(); + + virtual SQLRETURN do_connect(DBC* dbc_ptr, DataSource* ds, bool failover_enabled); + virtual MYSQL_PROXY* connect(const std::shared_ptr& host_info, DataSource* ds); + void update_connection(MYSQL_PROXY* new_connection, const std::string& new_host_name); + + private: + DBC* dbc; + DBC* clone_dbc(DBC* source_dbc, DataSource* ds); +}; + +#endif /* __CONNECTION_HANDLER_H__ */ diff --git a/driver/driver.h b/driver/driver.h index 684a2fec1..a9cd79015 100644 --- a/driver/driver.h +++ b/driver/driver.h @@ -41,9 +41,11 @@ #include "../MYODBC_MYSQL.h" #include "../MYODBC_CONF.h" #include "../MYODBC_ODBC.h" -#include "util/installer.h" +#include "connection_handler.h" +#include "efm_proxy.h" #include "failover.h" #include "mysql_proxy.h" +#include "util/installer.h" /* Disable _attribute__ on non-gcc compilers. */ #if !defined(__attribute__) && !defined(__GNUC__) @@ -643,6 +645,7 @@ struct DBC fido_callback_func fido_callback = nullptr; FAILOVER_HANDLER *fh = nullptr; /* Failover handler */ + std::shared_ptr connection_handler = nullptr; DBC(ENV *p_env); void free_explicit_descriptors(); @@ -654,6 +657,7 @@ struct DBC SQLRETURN connect(DataSource *dsrc, bool failover_enabled); void execute_prep_stmt(MYSQL_STMT *pstmt, std::string &query, MYSQL_BIND *param_bind, MYSQL_BIND *result_bind); + void init_proxy_chain(DataSource *dsrc); inline bool transactions_supported() { return mysql_proxy->get_server_capabilities() & CLIENT_TRANSACTIONS; diff --git a/driver/efm_proxy.cc b/driver/efm_proxy.cc new file mode 100644 index 000000000..4329c1ea8 --- /dev/null +++ b/driver/efm_proxy.cc @@ -0,0 +1,535 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "driver.h" + +namespace { + const char* RETRIEVE_HOST_PORT_SQL = "SELECT CONCAT(@@hostname, ':', @@port)"; +} + +EFM_PROXY::EFM_PROXY(DBC* dbc, DataSource* ds) : EFM_PROXY( + dbc, ds, nullptr, std::make_shared(ds && ds->save_queries)) {} + +EFM_PROXY::EFM_PROXY(DBC* dbc, DataSource* ds, MYSQL_PROXY* next_proxy) : EFM_PROXY( + dbc, ds, next_proxy, std::make_shared(ds && ds->save_queries)) {} + +EFM_PROXY::EFM_PROXY(DBC* dbc, DataSource* ds, MYSQL_PROXY* next_proxy, + std::shared_ptr monitor_service) + : MYSQL_PROXY(dbc, ds), + monitor_service{std::move(monitor_service)} { + + this->next_proxy = next_proxy; +} + +std::shared_ptr EFM_PROXY::start_monitoring() { + if (!ds || !ds->enable_failure_detection) { + return nullptr; + } + + auto failure_detection_timeout = ds->failure_detection_timeout; + // Use network timeout defined if failure detection timeout is not set + if (failure_detection_timeout == 0) { + failure_detection_timeout = ds->network_timeout == 0 ? failure_detection_timeout_default : ds->network_timeout; + } + + return monitor_service->start_monitoring( + dbc, + ds, + node_keys, + std::make_shared(get_host(), get_port()), + std::chrono::milliseconds{ds->failure_detection_time}, + std::chrono::seconds{failure_detection_timeout}, + std::chrono::milliseconds{ds->failure_detection_interval}, + ds->failure_detection_count, + std::chrono::milliseconds{ds->monitor_disposal_time}); +} + +void EFM_PROXY::stop_monitoring(std::shared_ptr context) { + if (!ds || !ds->enable_failure_detection || context == nullptr) { + return; + } + monitor_service->stop_monitoring(context); + if (context->is_node_unhealthy() && is_connected()) { + close_socket(); + } +} + +void EFM_PROXY::generate_node_keys() { + node_keys.clear(); + node_keys.insert(std::string(get_host()) + ":" + std::to_string(get_port())); + + if (is_connected()) { + // Temporarily turn off failure detection if on + const auto failure_detection_old_state = ds->enable_failure_detection; + ds->enable_failure_detection = false; + + const auto error = query(RETRIEVE_HOST_PORT_SQL); + if (error == 0) { + MYSQL_RES* result = store_result(); + MYSQL_ROW row; + while ((row = fetch_row(result))) { + node_keys.insert(std::string(row[0])); + } + free_result(result); + } + + ds->enable_failure_detection = failure_detection_old_state; + } +} + +void EFM_PROXY::set_next_proxy(MYSQL_PROXY* next_proxy) { + MYSQL_PROXY::set_next_proxy(next_proxy); + generate_node_keys(); +} + +void EFM_PROXY::set_connection(MYSQL_PROXY* mysql_proxy) { + next_proxy->set_connection(mysql_proxy); + + if (monitor_service != nullptr && !node_keys.empty()) { + monitor_service->stop_monitoring_for_all_connections(node_keys); + } + + generate_node_keys(); +} + +void EFM_PROXY::close_socket() { + next_proxy->close_socket(); +} + +void EFM_PROXY::delete_ds() { + next_proxy->delete_ds(); +} + +uint64_t EFM_PROXY::num_rows(MYSQL_RES* res) { + return next_proxy->num_rows(res); +} + +unsigned EFM_PROXY::num_fields(MYSQL_RES* res) { + return next_proxy->num_fields(res); +} + +MYSQL_FIELD* EFM_PROXY::fetch_field_direct(MYSQL_RES* res, unsigned fieldnr) { + return next_proxy->fetch_field_direct(res, fieldnr); +} + +MYSQL_ROW_OFFSET EFM_PROXY::row_tell(MYSQL_RES* res) { + return next_proxy->row_tell(res); +} + +unsigned EFM_PROXY::field_count() { + return next_proxy->field_count(); +} + +uint64_t EFM_PROXY::affected_rows() { + return next_proxy->affected_rows(); +} + +unsigned EFM_PROXY::error_code() { + return next_proxy->error_code(); +} + +const char* EFM_PROXY::error() { + return next_proxy->error(); +} + +const char* EFM_PROXY::sqlstate() { + return next_proxy->sqlstate(); +} + +unsigned long EFM_PROXY::thread_id() { + return next_proxy->thread_id(); +} + +int EFM_PROXY::set_character_set(const char* csname) { + const auto context = start_monitoring(); + const int ret = next_proxy->set_character_set(csname); + stop_monitoring(context); + return ret; +} + +void EFM_PROXY::init() { + next_proxy->init(); +} + +bool EFM_PROXY::ssl_set(const char* key, const char* cert, const char* ca, const char* capath, const char* cipher) { + return next_proxy->ssl_set(key, cert, ca, capath, cipher); +} + +bool EFM_PROXY::change_user(const char* user, const char* passwd, const char* db) { + const auto context = start_monitoring(); + const bool ret = next_proxy->change_user(user, passwd, db); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::real_connect( + const char* host, const char* user, const char* passwd, + const char* db, unsigned int port, const char* unix_socket, + unsigned long clientflag) { + + const bool ret = next_proxy->real_connect(host, user, passwd, db, port, unix_socket, clientflag); + + generate_node_keys(); + + return ret; +} + +int EFM_PROXY::select_db(const char* db) { + const auto context = start_monitoring(); + const int ret = next_proxy->select_db(db); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::query(const char* q) { + const auto context = start_monitoring(); + const int ret = next_proxy->query(q); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::real_query(const char* q, unsigned long length) { + const auto context = start_monitoring(); + const int ret = next_proxy->real_query(q, length); + stop_monitoring(context); + return ret; +} + +MYSQL_RES* EFM_PROXY::store_result() { + const auto context = start_monitoring(); + MYSQL_RES* ret = next_proxy->store_result(); + stop_monitoring(context); + return ret; +} + +MYSQL_RES* EFM_PROXY::use_result() { + const auto context = start_monitoring(); + MYSQL_RES* ret = next_proxy->use_result(); + stop_monitoring(context); + return ret; +} + +CHARSET_INFO* EFM_PROXY::get_character_set() const { + return next_proxy->get_character_set(); +} + +void EFM_PROXY::get_character_set_info(MY_CHARSET_INFO* charset) { + next_proxy->get_character_set_info(charset); +} + +bool EFM_PROXY::autocommit(bool auto_mode) { + return next_proxy->autocommit(auto_mode); +} + +int EFM_PROXY::next_result() { + const auto context = start_monitoring(); + const int ret = next_proxy->next_result(); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::stmt_next_result(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const int ret = next_proxy->stmt_next_result(stmt); + stop_monitoring(context); + return ret; +} + +void EFM_PROXY::close() { + next_proxy->close(); +} + +bool EFM_PROXY::real_connect_dns_srv( + const char* dns_srv_name, const char* user, + const char* passwd, const char* db, unsigned long client_flag) { + + const bool ret = next_proxy->real_connect_dns_srv(dns_srv_name, user, passwd, db, client_flag); + + generate_node_keys(); + + return ret; +} + +int EFM_PROXY::ping() { + return next_proxy->ping(); +} + +int EFM_PROXY::options4(mysql_option option, const void* arg1, const void* arg2) { + return next_proxy->options4(option, arg1, arg2); +} + +int EFM_PROXY::get_option(mysql_option option, const void* arg) { + return next_proxy->get_option(option, arg); +} + +int EFM_PROXY::options(mysql_option option, const void* arg) { + return next_proxy->options(option, arg); +} + +void EFM_PROXY::free_result(MYSQL_RES* result) { + const auto context = start_monitoring(); + next_proxy->free_result(result); + stop_monitoring(context); +} + +void EFM_PROXY::data_seek(MYSQL_RES* result, uint64_t offset) { + next_proxy->data_seek(result, offset); +} + +MYSQL_ROW_OFFSET EFM_PROXY::row_seek(MYSQL_RES* result, MYSQL_ROW_OFFSET offset) { + return next_proxy->row_seek(result, offset); +} + +MYSQL_FIELD_OFFSET EFM_PROXY::field_seek(MYSQL_RES* result, MYSQL_FIELD_OFFSET offset) { + return next_proxy->field_seek(result, offset); +} + +MYSQL_ROW EFM_PROXY::fetch_row(MYSQL_RES* result) { + const auto context = start_monitoring(); + const MYSQL_ROW ret = next_proxy->fetch_row(result); + stop_monitoring(context); + return ret; +} + +unsigned long* EFM_PROXY::fetch_lengths(MYSQL_RES* result) { + return next_proxy->fetch_lengths(result); +} + +MYSQL_FIELD* EFM_PROXY::fetch_field(MYSQL_RES* result) { + return next_proxy->fetch_field(result); +} + +MYSQL_RES* EFM_PROXY::list_fields(const char* table, const char* wild) { + const auto context = start_monitoring(); + MYSQL_RES* ret = next_proxy->list_fields(table, wild); + stop_monitoring(context); + return ret; +} + +unsigned long EFM_PROXY::real_escape_string(char* to, const char* from, unsigned long length) { + const auto context = start_monitoring(); + const unsigned long ret = next_proxy->real_escape_string(to, from, length); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::bind_param(unsigned n_params, MYSQL_BIND* binds, const char** names) { + const auto context = start_monitoring(); + const bool ret = next_proxy->bind_param(n_params, binds, names); + stop_monitoring(context); + return ret; +} + +MYSQL_STMT* EFM_PROXY::stmt_init() { + const auto context = start_monitoring(); + MYSQL_STMT* ret = next_proxy->stmt_init(); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::stmt_prepare(MYSQL_STMT* stmt, const char* query, unsigned long length) { + const auto context = start_monitoring(); + const int ret = next_proxy->stmt_prepare(stmt, query, length); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::stmt_execute(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const int ret = next_proxy->stmt_execute(stmt); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::stmt_fetch(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const int ret = next_proxy->stmt_fetch(stmt); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::stmt_fetch_column(MYSQL_STMT* stmt, MYSQL_BIND* bind_arg, unsigned int column, unsigned long offset) { + const auto context = start_monitoring(); + const int ret = next_proxy->stmt_fetch_column(stmt, bind_arg, column, offset); + stop_monitoring(context); + return ret; +} + +int EFM_PROXY::stmt_store_result(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const int ret = next_proxy->stmt_store_result(stmt); + stop_monitoring(context); + return ret; +} + +unsigned long EFM_PROXY::stmt_param_count(MYSQL_STMT* stmt) { + return next_proxy->stmt_param_count(stmt); +} + +bool EFM_PROXY::stmt_bind_param(MYSQL_STMT* stmt, MYSQL_BIND* bnd) { + const auto context = start_monitoring(); + const bool ret = next_proxy->stmt_bind_param(stmt, bnd); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) { + const auto context = start_monitoring(); + const bool ret = next_proxy->stmt_bind_result(stmt, bnd); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::stmt_close(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const bool ret = next_proxy->stmt_close(stmt); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::stmt_reset(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const bool ret = next_proxy->stmt_reset(stmt); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::stmt_free_result(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + const bool ret = next_proxy->stmt_free_result(stmt); + stop_monitoring(context); + return ret; +} + +bool EFM_PROXY::stmt_send_long_data(MYSQL_STMT* stmt, unsigned int param_number, const char* data, + unsigned long length) { + const auto context = start_monitoring(); + const bool ret = next_proxy->stmt_send_long_data(stmt, param_number, data, length); + stop_monitoring(context); + return ret; +} + +MYSQL_RES* EFM_PROXY::stmt_result_metadata(MYSQL_STMT* stmt) { + const auto context = start_monitoring(); + MYSQL_RES* ret = next_proxy->stmt_result_metadata(stmt); + stop_monitoring(context); + return ret; +} + +unsigned EFM_PROXY::stmt_errno(MYSQL_STMT* stmt) { + return next_proxy->stmt_errno(stmt); +} + +const char* EFM_PROXY::stmt_error(MYSQL_STMT* stmt) { + return next_proxy->stmt_error(stmt); +} + +MYSQL_ROW_OFFSET EFM_PROXY::stmt_row_seek(MYSQL_STMT* stmt, MYSQL_ROW_OFFSET offset) { + return next_proxy->stmt_row_seek(stmt, offset); +} + +MYSQL_ROW_OFFSET EFM_PROXY::stmt_row_tell(MYSQL_STMT* stmt) { + return next_proxy->stmt_row_tell(stmt); +} + +void EFM_PROXY::stmt_data_seek(MYSQL_STMT* stmt, uint64_t offset) { + next_proxy->stmt_data_seek(stmt, offset); +} + +uint64_t EFM_PROXY::stmt_num_rows(MYSQL_STMT* stmt) { + return next_proxy->stmt_num_rows(stmt); +} + +uint64_t EFM_PROXY::stmt_affected_rows(MYSQL_STMT* stmt) { + return next_proxy->stmt_affected_rows(stmt); +} + +unsigned EFM_PROXY::stmt_field_count(MYSQL_STMT* stmt) { + return next_proxy->stmt_field_count(stmt); +} + +st_mysql_client_plugin* EFM_PROXY::client_find_plugin(const char* name, int type) { + return next_proxy->client_find_plugin(name, type); +} + +bool EFM_PROXY::is_connected() { + return next_proxy->is_connected(); +} + +void EFM_PROXY::set_last_error_code(unsigned error_code) { + next_proxy->set_last_error_code(error_code); +} + +char* EFM_PROXY::get_last_error() const { + return next_proxy->get_last_error(); +} + +unsigned EFM_PROXY::get_last_error_code() const { + return next_proxy->get_last_error_code(); +} + +char* EFM_PROXY::get_sqlstate() const { + return next_proxy->get_sqlstate(); +} + +char* EFM_PROXY::get_server_version() const { + return next_proxy->get_server_version(); +} + +uint64_t EFM_PROXY::get_affected_rows() const { + return next_proxy->get_affected_rows(); +} + +void EFM_PROXY::set_affected_rows(uint64_t num_rows) { + next_proxy->set_affected_rows(num_rows); +} + +char* EFM_PROXY::get_host_info() const { + return next_proxy->get_host_info(); +} + +std::string EFM_PROXY::get_host() { + return next_proxy->get_host(); +} + +unsigned EFM_PROXY::get_port() { + return next_proxy->get_port(); +} + +unsigned long EFM_PROXY::get_max_packet() const { + return next_proxy->get_max_packet(); +} + +unsigned long EFM_PROXY::get_server_capabilities() const { + return next_proxy->get_server_capabilities(); +} + +unsigned EFM_PROXY::get_server_status() const { + return next_proxy->get_server_status(); +} diff --git a/driver/efm_proxy.h b/driver/efm_proxy.h new file mode 100644 index 000000000..a2168d78f --- /dev/null +++ b/driver/efm_proxy.h @@ -0,0 +1,173 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __EFM_PROXY__ +#define __EFM_PROXY__ + +#include + +#include "monitor_service.h" +#include "mysql_proxy.h" + +class EFM_PROXY : public MYSQL_PROXY { +public: + EFM_PROXY(DBC* dbc, DataSource* ds); + EFM_PROXY(DBC* dbc, DataSource* ds, MYSQL_PROXY* next_proxy); + EFM_PROXY(DBC* dbc, DataSource* ds, MYSQL_PROXY* next_proxy, std::shared_ptr monitor_service); + + void delete_ds() override; + uint64_t num_rows(MYSQL_RES* res) override; + unsigned int num_fields(MYSQL_RES* res) override; + MYSQL_FIELD* fetch_field_direct(MYSQL_RES* res, unsigned int fieldnr) override; + MYSQL_ROW_OFFSET row_tell(MYSQL_RES* res) override; + + unsigned int field_count() override; + uint64_t affected_rows() override; + unsigned int error_code() override; + const char* error() override; + const char* sqlstate() override; + unsigned long thread_id() override; + int set_character_set(const char* csname) override; + + void init() override; + bool ssl_set(const char* key, const char* cert, const char* ca, + const char* capath, const char* cipher) override; + bool change_user(const char* user, const char* passwd, + const char* db) override; + bool real_connect(const char* host, const char* user, + const char* passwd, const char* db, unsigned int port, + const char* unix_socket, unsigned long clientflag) override; + int select_db(const char* db) override; + int query(const char* q) override; + int real_query(const char* q, unsigned long length) override; + MYSQL_RES* store_result() override; + MYSQL_RES* use_result() override; + struct CHARSET_INFO* get_character_set() const override; + void get_character_set_info(MY_CHARSET_INFO* charset) override; + + int ping() override; + int options(enum mysql_option option, const void* arg) override; + int options4(enum mysql_option option, const void* arg1, + const void* arg2) override; + int get_option(enum mysql_option option, const void* arg) override; + void free_result(MYSQL_RES* result) override; + void data_seek(MYSQL_RES* result, uint64_t offset) override; + MYSQL_ROW_OFFSET row_seek(MYSQL_RES* result, MYSQL_ROW_OFFSET offset) override; + MYSQL_FIELD_OFFSET field_seek(MYSQL_RES* result, MYSQL_FIELD_OFFSET offset) override; + MYSQL_ROW fetch_row(MYSQL_RES* result) override; + + unsigned long* fetch_lengths(MYSQL_RES* result) override; + MYSQL_FIELD* fetch_field(MYSQL_RES* result) override; + MYSQL_RES* list_fields(const char* table, const char* wild) override; + unsigned long real_escape_string(char* to, const char* from, + unsigned long length) override; + + bool bind_param(unsigned n_params, MYSQL_BIND* binds, + const char** names) override; + + MYSQL_STMT* stmt_init() override; + int stmt_prepare(MYSQL_STMT* stmt, const char* query, unsigned long length) override; + int stmt_execute(MYSQL_STMT* stmt) override; + int stmt_fetch(MYSQL_STMT* stmt) override; + int stmt_fetch_column(MYSQL_STMT* stmt, MYSQL_BIND* bind_arg, + unsigned int column, unsigned long offset) override; + int stmt_store_result(MYSQL_STMT* stmt) override; + unsigned long stmt_param_count(MYSQL_STMT* stmt) override; + bool stmt_bind_param(MYSQL_STMT* stmt, MYSQL_BIND* bnd) override; + bool stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) override; + bool stmt_close(MYSQL_STMT* stmt) override; + bool stmt_reset(MYSQL_STMT* stmt) override; + bool stmt_free_result(MYSQL_STMT* stmt) override; + bool stmt_send_long_data(MYSQL_STMT* stmt, unsigned int param_number, + const char* data, unsigned long length) override; + MYSQL_RES* stmt_result_metadata(MYSQL_STMT* stmt) override; + unsigned int stmt_errno(MYSQL_STMT* stmt) override; + const char* stmt_error(MYSQL_STMT* stmt) override; + MYSQL_ROW_OFFSET stmt_row_seek(MYSQL_STMT* stmt, MYSQL_ROW_OFFSET offset) override; + MYSQL_ROW_OFFSET stmt_row_tell(MYSQL_STMT* stmt) override; + void stmt_data_seek(MYSQL_STMT* stmt, uint64_t offset) override; + uint64_t stmt_num_rows(MYSQL_STMT* stmt) override; + uint64_t stmt_affected_rows(MYSQL_STMT* stmt) override; + unsigned int stmt_field_count(MYSQL_STMT* stmt) override; + + bool autocommit(bool auto_mode) override; + int next_result() override; + int stmt_next_result(MYSQL_STMT* stmt) override; + void close() override; + + bool real_connect_dns_srv(const char* dns_srv_name, + const char* user, const char* passwd, + const char* db, unsigned long client_flag) override; + struct st_mysql_client_plugin* client_find_plugin( + const char* name, int type) override; + + bool is_connected() override; + + void set_last_error_code(unsigned int error_code) override; + + char* get_last_error() const override; + + unsigned int get_last_error_code() const override; + + char* get_sqlstate() const override; + + char* get_server_version() const override; + + uint64_t get_affected_rows() const override; + + void set_affected_rows(uint64_t num_rows) override; + + char* get_host_info() const override; + + std::string get_host() override; + + unsigned int get_port() override; + + unsigned long get_max_packet() const override; + + unsigned long get_server_capabilities() const override; + + unsigned int get_server_status() const override; + + void set_connection(MYSQL_PROXY* mysql_proxy) override; + + void close_socket() override; + + void set_next_proxy(MYSQL_PROXY* next_proxy) override; + +private: + std::shared_ptr monitor_service = nullptr; + std::set node_keys; + + std::shared_ptr start_monitoring(); + void stop_monitoring(std::shared_ptr context); + void generate_node_keys(); +}; + +#endif /* __EFM_PROXY__ */ diff --git a/driver/failover.h b/driver/failover.h index 0b19e86f1..cae82c4f1 100644 --- a/driver/failover.h +++ b/driver/failover.h @@ -30,38 +30,16 @@ #ifndef __FAILOVER_H__ #define __FAILOVER_H__ +#include "connection_handler.h" #include "topology_service.h" #include "mylog.h" -#include #include -#ifdef __linux__ -typedef std::u16string sqlwchar_string; -#else -typedef std::wstring sqlwchar_string; -#endif - -sqlwchar_string to_sqlwchar_string(const std::string& src); - struct DBC; struct DataSource; typedef short SQLRETURN; -class FAILOVER_CONNECTION_HANDLER { - public: - FAILOVER_CONNECTION_HANDLER(DBC* dbc); - virtual ~FAILOVER_CONNECTION_HANDLER(); - - virtual SQLRETURN do_connect(DBC* dbc_ptr, DataSource* ds, bool failover_enabled); - virtual MYSQL_PROXY* connect(const std::shared_ptr& host_info); - void update_connection(MYSQL_PROXY* new_connection, const std::string& new_host_name); - - private: - DBC* dbc; - DBC* clone_dbc(DBC* source_dbc); -}; - struct READER_FAILOVER_RESULT { bool connected = false; std::shared_ptr new_host; @@ -96,7 +74,7 @@ class FAILOVER_READER_HANDLER { public: FAILOVER_READER_HANDLER( std::shared_ptr topology_service, - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, int failover_timeout_ms, int failover_reader_connect_timeout, unsigned long dbc_id, bool enable_logging = false); @@ -123,7 +101,7 @@ class FAILOVER_READER_HANDLER { private: std::shared_ptr topology_service; - std::shared_ptr connection_handler; + std::shared_ptr connection_handler; const int READER_CONNECT_INTERVAL_SEC = 1; // 1 sec std::shared_ptr logger = nullptr; unsigned long dbc_id = 0; @@ -157,7 +135,7 @@ class FAILOVER_WRITER_HANDLER { FAILOVER_WRITER_HANDLER( std::shared_ptr topology_service, std::shared_ptr reader_handler, - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, int writer_failover_timeout_ms, int read_topology_interval_ms, int reconnect_writer_interval_ms, unsigned long dbc_id, bool enable_logging = false); ~FAILOVER_WRITER_HANDLER(); @@ -171,7 +149,7 @@ class FAILOVER_WRITER_HANDLER { private: std::shared_ptr topology_service; - std::shared_ptr connection_handler; + std::shared_ptr connection_handler; std::shared_ptr reader_handler; std::shared_ptr logger = nullptr; unsigned long dbc_id = 0; @@ -182,7 +160,7 @@ class FAILOVER_HANDLER { FAILOVER_HANDLER(DBC* dbc, DataSource* ds); FAILOVER_HANDLER( DBC* dbc, DataSource* ds, - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, std::shared_ptr metrics_container); ~FAILOVER_HANDLER(); @@ -203,7 +181,7 @@ class FAILOVER_HANDLER { std::shared_ptr failover_writer_handler; std::shared_ptr current_topology; std::shared_ptr current_host = nullptr; - std::shared_ptr connection_handler = nullptr; + std::shared_ptr connection_handler = nullptr; bool m_is_cluster_topology_available = false; bool m_is_multi_writer_cluster = false; bool m_is_rds_proxy = false; @@ -238,7 +216,7 @@ class FAILOVER_HANDLER { class FAILOVER { public: - FAILOVER(std::shared_ptr connection_handler, + FAILOVER(std::shared_ptr connection_handler, std::shared_ptr topology_service, unsigned long dbc_id, bool enable_logging = false); virtual ~FAILOVER() = default; @@ -248,7 +226,7 @@ class FAILOVER { bool connect(const std::shared_ptr& host_info); void sleep(int miliseconds); void release_new_connection(); - std::shared_ptr connection_handler; + std::shared_ptr connection_handler; std::shared_ptr topology_service; MYSQL_PROXY* new_connection; std::shared_ptr logger = nullptr; @@ -258,7 +236,7 @@ class FAILOVER { class CONNECT_TO_READER_HANDLER : public FAILOVER { public: CONNECT_TO_READER_HANDLER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, unsigned long dbc_id, bool enable_logging = false); ~CONNECT_TO_READER_HANDLER(); @@ -272,7 +250,7 @@ class CONNECT_TO_READER_HANDLER : public FAILOVER { class RECONNECT_TO_WRITER_HANDLER : public FAILOVER { public: RECONNECT_TO_WRITER_HANDLER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, int connection_interval, unsigned long dbc_id, bool enable_logging = false); ~RECONNECT_TO_WRITER_HANDLER(); @@ -293,7 +271,7 @@ class RECONNECT_TO_WRITER_HANDLER : public FAILOVER { class WAIT_NEW_WRITER_HANDLER : public FAILOVER { public: WAIT_NEW_WRITER_HANDLER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, std::shared_ptr current_topology, std::shared_ptr reader_handler, diff --git a/driver/failover_handler.cc b/driver/failover_handler.cc index 73b7b2a7a..467f64ff5 100644 --- a/driver/failover_handler.cc +++ b/driver/failover_handler.cc @@ -56,12 +56,12 @@ const std::regex IPV6_COMPRESSED_PATTERN( FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds) : FAILOVER_HANDLER( - dbc, ds, std::make_shared(dbc), - std::make_shared(dbc ? dbc->id : 0, ds ? ds->save_queries : false), - std::make_shared(dbc, ds)) {} + dbc, ds, dbc ? dbc->connection_handler : nullptr, + std::make_shared(dbc ? dbc->id : 0, ds ? ds->save_queries : false), + std::make_shared(dbc, ds)) {} FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds, - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, std::shared_ptr metrics_container) { if (!dbc) { diff --git a/driver/failover_reader_handler.cc b/driver/failover_reader_handler.cc index 43f1d9b32..99eafbf5c 100644 --- a/driver/failover_reader_handler.cc +++ b/driver/failover_reader_handler.cc @@ -38,7 +38,7 @@ FAILOVER_READER_HANDLER::FAILOVER_READER_HANDLER( std::shared_ptr topology_service, - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, int failover_timeout_ms, int failover_reader_connect_timeout, unsigned long dbc_id, bool enable_logging) : topology_service{topology_service}, @@ -223,7 +223,7 @@ READER_FAILOVER_RESULT FAILOVER_READER_HANDLER::get_connection_from_hosts( // *** CONNECT_TO_READER_HANDLER // Handler to connect to a reader host. CONNECT_TO_READER_HANDLER::CONNECT_TO_READER_HANDLER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, unsigned long dbc_id, bool enable_logging) : FAILOVER{connection_handler, topology_service, dbc_id, enable_logging} {} diff --git a/driver/failover_writer_handler.cc b/driver/failover_writer_handler.cc index da27dfc32..33f718ba6 100644 --- a/driver/failover_writer_handler.cc +++ b/driver/failover_writer_handler.cc @@ -70,7 +70,7 @@ bool FAILOVER_SYNC::is_completed() { // ************* FAILOVER *********************************** // Base class of two writer failover task handlers FAILOVER::FAILOVER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, unsigned long dbc_id, bool enable_logging) : connection_handler{connection_handler}, @@ -87,7 +87,7 @@ bool FAILOVER::is_writer_connected() { } bool FAILOVER::connect(const std::shared_ptr& host_info) { - new_connection = connection_handler->connect(host_info); + new_connection = connection_handler->connect(host_info, nullptr); return is_writer_connected(); } @@ -107,7 +107,7 @@ void FAILOVER::release_new_connection() { // ************************ RECONNECT_TO_WRITER_HANDLER // handler reconnecting to a given host, e.g. reconnect to a current writer RECONNECT_TO_WRITER_HANDLER::RECONNECT_TO_WRITER_HANDLER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, int connection_interval, unsigned long dbc_id, bool enable_logging) : FAILOVER{connection_handler, topology_service, dbc_id, enable_logging}, @@ -168,7 +168,7 @@ bool RECONNECT_TO_WRITER_HANDLER::is_current_host_writer( // handler getting the latest cluster topology and connecting to a newly elected // writer WAIT_NEW_WRITER_HANDLER::WAIT_NEW_WRITER_HANDLER( - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, std::shared_ptr topology_service, std::shared_ptr current_topology, std::shared_ptr reader_handler, @@ -274,7 +274,7 @@ void WAIT_NEW_WRITER_HANDLER::clean_up_reader_connection() { FAILOVER_WRITER_HANDLER::FAILOVER_WRITER_HANDLER( std::shared_ptr topology_service, std::shared_ptr reader_handler, - std::shared_ptr connection_handler, + std::shared_ptr connection_handler, int writer_failover_timeout_ms, int read_topology_interval_ms, int reconnect_writer_interval_ms, unsigned long dbc_id, bool enable_logging) : connection_handler{connection_handler}, diff --git a/driver/handle.cc b/driver/handle.cc index 3a669d58a..78f291b37 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -111,6 +111,34 @@ void DBC::close() mysql_proxy->close(); } +// construct a proxy chain, example: iam->efm->mysql +void DBC::init_proxy_chain(DataSource* dsrc) +{ + MYSQL_PROXY *head = new MYSQL_PROXY(this, dsrc); + + if (dsrc->enable_failure_detection) { + MYSQL_PROXY* efm_proxy = new EFM_PROXY(this, dsrc); + efm_proxy->set_next_proxy(head); + head = efm_proxy; + } + + ds_get_utf8attr(dsrc->auth_mode, &dsrc->auth_mode8); + + if (!myodbc_strcasecmp(AUTH_MODE_IAM, reinterpret_cast(dsrc->auth_mode8))) { + // MYSQL_PROXY* iam_proxy = new IAM_PROXY(his, dsrc); + // iam_proxy->set_next_proxy(head); + // head = iam_proxy; + } + + if (!myodbc_strcasecmp(AUTH_MODE_SECRETS_MANAGER, reinterpret_cast(dsrc->auth_mode8))) { + // MYSQL_PROXY* secrets_manager_proxy = new SECRETS_MANAGER_PROXY(his, dsrc); + // secrets_manager_proxy->set_next_proxy(head); + // head = secrets_manager_proxy; + } + + this->mysql_proxy = head; +} + DBC::~DBC() { if (env) @@ -190,12 +218,13 @@ SQLRETURN SQL_API SQLAllocEnv(SQLHENV *phenv) SQLRETURN SQL_API my_SQLFreeEnv(SQLHENV henv) { + MONITOR_THREAD_CONTAINER::release_instance(); + ENV *env= (ENV *) henv; delete env; #ifdef _UNIX_ myodbc_end(); #endif /* _UNIX_ */ - MONITOR_THREAD_CONTAINER::release_instance(); return(SQL_SUCCESS); } diff --git a/driver/monitor.cc b/driver/monitor.cc index a63232727..7aa0341e8 100644 --- a/driver/monitor.cc +++ b/driver/monitor.cc @@ -27,34 +27,45 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. + +#include "driver.h" #include "monitor.h" + #include "monitor_service.h" #include "mylog.h" #include "mysql_proxy.h" MONITOR::MONITOR( std::shared_ptr host_info, + std::shared_ptr connection_handler, std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds monitor_disposal_time, DataSource* ds, bool enable_logging) : MONITOR( std::move(host_info), + std::move(connection_handler), failure_detection_timeout, monitor_disposal_time, - new MYSQL_MONITOR_PROXY(ds), + ds, + nullptr, enable_logging) {}; MONITOR::MONITOR( std::shared_ptr host_info, + std::shared_ptr connection_handler, std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds monitor_disposal_time, - MYSQL_MONITOR_PROXY* proxy, + DataSource* ds, + MYSQL_PROXY* proxy, bool enable_logging) { this->host = std::move(host_info); + this->connection_handler = std::move(connection_handler); this->failure_detection_timeout = failure_detection_timeout; this->disposal_time = monitor_disposal_time; + this->ds = ds_new(); + ds_copy(this->ds, ds); this->mysql_proxy = proxy; this->connection_check_interval = (std::chrono::milliseconds::max)(); if (enable_logging) @@ -62,6 +73,11 @@ MONITOR::MONITOR( } MONITOR::~MONITOR() { + if (this->ds) { + ds_delete(this->ds); + this->ds = nullptr; + } + if (this->mysql_proxy) { delete this->mysql_proxy; this->mysql_proxy = nullptr; @@ -195,21 +211,31 @@ CONNECTION_STATUS MONITOR::check_connection_status() { } bool MONITOR::connect() { - this->mysql_proxy->close(); - this->mysql_proxy->init(); - + if (this->mysql_proxy) { + this->mysql_proxy->close(); + delete this->mysql_proxy; + } // Timeout shouldn't be 0 by now, but double check just in case unsigned int timeout_sec = this->failure_detection_timeout.count() == 0 ? failure_detection_timeout_default : this->failure_detection_timeout.count(); - this->mysql_proxy->options(MYSQL_OPT_CONNECT_TIMEOUT, &timeout_sec); - this->mysql_proxy->options(MYSQL_OPT_READ_TIMEOUT, &timeout_sec); - if (!this->mysql_proxy->connect()) { - MYLOG_TRACE(this->logger.get(), 0, this->mysql_proxy->error()); - this->mysql_proxy->close(); + // timeout should be set in DBC::connect() + if (this->ds->enable_cluster_failover) { + this->ds->connect_timeout = timeout_sec; + this->ds->network_timeout = timeout_sec; + } else { + // cannot change login_timeout here because no access to dbc + this->ds->read_timeout = timeout_sec; + } + + this->ds->enable_cluster_failover = false; + this->ds->enable_failure_detection= false; + + this->mysql_proxy = this->connection_handler->connect(this->host, this->ds); + if (!this->mysql_proxy) { return false; } - return true; + return this->mysql_proxy->is_connected(); } std::chrono::milliseconds MONITOR::find_shortest_interval() { diff --git a/driver/monitor.h b/driver/monitor.h index 761db2876..e946678fd 100644 --- a/driver/monitor.h +++ b/driver/monitor.h @@ -30,6 +30,7 @@ #ifndef __MONITOR_H__ #define __MONITOR_H__ +#include "connection_handler.h" #include "host_info.h" #include "monitor_connection_context.h" @@ -43,8 +44,7 @@ struct CONNECTION_STATUS { struct DataSource; class MONITOR_SERVICE; -class MYSQL_MONITOR_PROXY; - +class MYSQL_PROXY; namespace { const std::chrono::milliseconds thread_sleep_when_inactive = std::chrono::milliseconds(100); @@ -55,15 +55,18 @@ class MONITOR : public std::enable_shared_from_this { public: MONITOR( std::shared_ptr host_info, + std::shared_ptr connection_handler, std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds monitor_disposal_time, DataSource* ds, bool enable_logging = false); MONITOR( std::shared_ptr host_info, + std::shared_ptr connection_handler, std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds monitor_disposal_time, - MYSQL_MONITOR_PROXY* proxy, + DataSource* ds, + MYSQL_PROXY* proxy, bool enable_logging = false); virtual ~MONITOR(); @@ -77,12 +80,14 @@ class MONITOR : public std::enable_shared_from_this { private: std::atomic_bool stopped{true}; std::shared_ptr host; + std::shared_ptr connection_handler; std::chrono::milliseconds connection_check_interval; std::chrono::seconds failure_detection_timeout; std::chrono::milliseconds disposal_time; std::list> contexts; std::chrono::steady_clock::time_point last_context_timestamp; - MYSQL_MONITOR_PROXY* mysql_proxy = nullptr; + MYSQL_PROXY* mysql_proxy = nullptr; + DataSource* ds = nullptr; std::shared_ptr logger; std::mutex mutex_; diff --git a/driver/monitor_service.cc b/driver/monitor_service.cc index 4d4cb3a2b..bf2e0a2e4 100644 --- a/driver/monitor_service.cc +++ b/driver/monitor_service.cc @@ -56,6 +56,12 @@ std::shared_ptr MONITOR_SERVICE::start_monitoring( int failure_detection_count, std::chrono::milliseconds disposal_time) { + if (!dbc || !ds) { + auto msg = "[MONITOR_SERVICE] Parameter dbc or ds cannot be null"; + MYLOG_TRACE(this->logger.get(), dbc ? dbc->id : 0, msg); + throw std::invalid_argument(msg); + } + if (node_keys.empty()) { auto msg = "[MONITOR_SERVICE] Parameter node_keys cannot be empty"; MYLOG_TRACE(this->logger.get(), dbc ? dbc->id : 0, msg); @@ -70,6 +76,7 @@ std::shared_ptr MONITOR_SERVICE::start_monitoring( failure_detection_timeout, disposal_time, ds, + dbc ? dbc->connection_handler : nullptr, enable_logging); auto context = std::make_shared( diff --git a/driver/monitor_thread_container.cc b/driver/monitor_thread_container.cc index c6e47931d..40a63fea9 100644 --- a/driver/monitor_thread_container.cc +++ b/driver/monitor_thread_container.cc @@ -78,6 +78,7 @@ std::shared_ptr MONITOR_THREAD_CONTAINER::get_or_create_monitor( std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds disposal_time, DataSource* ds, + std::shared_ptr connection_handler, bool enable_logging) { std::shared_ptr monitor; @@ -91,7 +92,7 @@ std::shared_ptr MONITOR_THREAD_CONTAINER::get_or_create_monitor( else { monitor = this->get_available_monitor(); if (monitor == nullptr) { - monitor = this->create_monitor(std::move(host), failure_detection_timeout, disposal_time, ds, enable_logging); + monitor = this->create_monitor(std::move(host), std::move(connection_handler), failure_detection_timeout, disposal_time, ds, enable_logging); } } @@ -187,12 +188,17 @@ std::shared_ptr MONITOR_THREAD_CONTAINER::get_available_monitor() { std::shared_ptr MONITOR_THREAD_CONTAINER::create_monitor( std::shared_ptr host, + std::shared_ptr connection_handler, std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds disposal_time, DataSource* ds, bool enable_logging) { - return std::make_shared(host, failure_detection_timeout, disposal_time, ds, enable_logging); + if (!connection_handler) { + return nullptr; + } + + return std::make_shared(host, connection_handler, failure_detection_timeout, disposal_time, ds, enable_logging); } void MONITOR_THREAD_CONTAINER::release_resources() { diff --git a/driver/monitor_thread_container.h b/driver/monitor_thread_container.h index e855fee1c..d8c9f58fa 100644 --- a/driver/monitor_thread_container.h +++ b/driver/monitor_thread_container.h @@ -30,6 +30,7 @@ #ifndef __MONITORTHREADCONTAINER_H__ #define __MONITORTHREADCONTAINER_H__ +#include "connection_handler.h" #include "monitor.h" #include @@ -50,6 +51,7 @@ class MONITOR_THREAD_CONTAINER { std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds disposal_time, DataSource* ds, + std::shared_ptr connection_handler, bool enable_logging = false); virtual void add_task(const std::shared_ptr& monitor, const std::shared_ptr& service); void reset_resource(const std::shared_ptr& monitor); @@ -65,6 +67,7 @@ class MONITOR_THREAD_CONTAINER { std::shared_ptr get_available_monitor(); virtual std::shared_ptr create_monitor( std::shared_ptr host, + std::shared_ptr connection_handler, std::chrono::seconds failure_detection_timeout, std::chrono::milliseconds disposal_time, DataSource* ds, diff --git a/driver/my_stmt.cc b/driver/my_stmt.cc index 6c4eccae4..70c708fec 100644 --- a/driver/my_stmt.cc +++ b/driver/my_stmt.cc @@ -462,7 +462,7 @@ SQLRETURN prepare(STMT *stmt, char * query, SQLINTEGER query_length, return SQL_ERROR; } - stmt->param_count= mysql_stmt_param_count(stmt->ssps); + stmt->param_count= stmt->dbc->mysql_proxy->stmt_param_count(stmt->ssps); free_internal_result_buffers(stmt); /* make sure we free the result from the previous time */ diff --git a/driver/mysql_proxy.cc b/driver/mysql_proxy.cc index 546222599..d3c936af8 100644 --- a/driver/mysql_proxy.cc +++ b/driver/mysql_proxy.cc @@ -32,14 +32,10 @@ #include namespace { - const char* RETRIEVE_HOST_PORT_SQL = "SELECT CONCAT(@@hostname, ':', @@port)"; const auto SOCKET_CLOSE_DELAY = std::chrono::milliseconds(100); } -MYSQL_PROXY::MYSQL_PROXY(DBC* dbc, DataSource* ds) - : MYSQL_PROXY(dbc, ds, std::make_shared(ds && ds->save_queries)) {} - -MYSQL_PROXY::MYSQL_PROXY(DBC* dbc, DataSource* ds, std::shared_ptr monitor_service) : dbc{dbc}, ds{ds} { +MYSQL_PROXY::MYSQL_PROXY(DBC* dbc, DataSource* ds) : dbc{dbc}, ds{ds} { if (!this->dbc) { throw std::runtime_error("DBC cannot be null."); } @@ -48,18 +44,17 @@ MYSQL_PROXY::MYSQL_PROXY(DBC* dbc, DataSource* ds, std::shared_ptrds->enable_failure_detection) { - this->monitor_service = std::move(monitor_service); - } - this->host = get_host_info_from_ds(ds); - generate_node_keys(); } MYSQL_PROXY::~MYSQL_PROXY() { if (this->mysql) { close(); } + + if (this->next_proxy) { + delete next_proxy; + } } void MYSQL_PROXY::delete_ds() { @@ -110,30 +105,19 @@ unsigned long MYSQL_PROXY::thread_id() { } int MYSQL_PROXY::set_character_set(const char* csname) { - const auto context = start_monitoring(); - const int ret = mysql_set_character_set(mysql, csname); - stop_monitoring(context); - return ret; + return mysql_set_character_set(mysql, csname); } void MYSQL_PROXY::init() { - const auto context = start_monitoring(); this->mysql = mysql_init(nullptr); - stop_monitoring(context); } bool MYSQL_PROXY::ssl_set(const char* key, const char* cert, const char* ca, const char* capath, const char* cipher) { - const auto context = start_monitoring(); - const bool ret = mysql_ssl_set(mysql, key, cert, ca, capath, cipher); - stop_monitoring(context); - return ret; + return mysql_ssl_set(mysql, key, cert, ca, capath, cipher); } bool MYSQL_PROXY::change_user(const char* user, const char* passwd, const char* db) { - const auto context = start_monitoring(); - const bool ret = mysql_change_user(mysql, user, passwd, db); - stop_monitoring(context); - return ret; + return mysql_change_user(mysql, user, passwd, db); } bool MYSQL_PROXY::real_connect( @@ -141,45 +125,28 @@ bool MYSQL_PROXY::real_connect( const char* db, unsigned int port, const char* unix_socket, unsigned long clientflag) { - const auto context = start_monitoring(); const MYSQL* new_mysql = mysql_real_connect(mysql, host, user, passwd, db, port, unix_socket, clientflag); - stop_monitoring(context); return new_mysql != nullptr; } int MYSQL_PROXY::select_db(const char* db) { - const auto context = start_monitoring(); - const int ret = mysql_select_db(mysql, db); - stop_monitoring(context); - return ret; + return mysql_select_db(mysql, db); } int MYSQL_PROXY::query(const char* q) { - const auto context = start_monitoring(); - const int ret = mysql_query(mysql, q); - stop_monitoring(context); - return ret; + return mysql_query(mysql, q); } int MYSQL_PROXY::real_query(const char* q, unsigned long length) { - const auto context = start_monitoring(); - const int ret = mysql_real_query(mysql, q, length); - stop_monitoring(context); - return ret; + return mysql_real_query(mysql, q, length); } MYSQL_RES* MYSQL_PROXY::store_result() { - const auto context = start_monitoring(); - MYSQL_RES* ret = mysql_store_result(mysql); - stop_monitoring(context); - return ret; + return mysql_store_result(mysql); } MYSQL_RES* MYSQL_PROXY::use_result() { - const auto context = start_monitoring(); - MYSQL_RES* ret = mysql_use_result(mysql); - stop_monitoring(context); - return ret; + return mysql_use_result(mysql); } struct CHARSET_INFO* MYSQL_PROXY::get_character_set() const { @@ -191,24 +158,15 @@ void MYSQL_PROXY::get_character_set_info(MY_CHARSET_INFO* charset) { } bool MYSQL_PROXY::autocommit(bool auto_mode) { - const auto context = start_monitoring(); - const bool ret = mysql_autocommit(mysql, auto_mode); - stop_monitoring(context); - return ret; + return mysql_autocommit(mysql, auto_mode); } int MYSQL_PROXY::next_result() { - const auto context = start_monitoring(); - const int ret = mysql_next_result(mysql); - stop_monitoring(context); - return ret; + return mysql_next_result(mysql); } int MYSQL_PROXY::stmt_next_result(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const int ret = mysql_stmt_next_result(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_next_result(stmt); } void MYSQL_PROXY::close() { @@ -220,17 +178,12 @@ bool MYSQL_PROXY::real_connect_dns_srv( const char* dns_srv_name, const char* user, const char* passwd, const char* db, unsigned long client_flag) { - const auto context = start_monitoring(); const MYSQL* new_mysql = mysql_real_connect_dns_srv(mysql, dns_srv_name, user, passwd, db, client_flag); - stop_monitoring(context); return new_mysql != nullptr; } int MYSQL_PROXY::ping() { - const auto context = start_monitoring(); - const int ret = mysql_ping(mysql); - stop_monitoring(context); - return ret; + return mysql_ping(mysql); } unsigned long MYSQL_PROXY::get_client_version(void) { @@ -242,23 +195,15 @@ int MYSQL_PROXY::get_option(mysql_option option, const void* arg) { } int MYSQL_PROXY::options4(mysql_option option, const void* arg1, const void* arg2) { - const auto context = start_monitoring(); - const int ret = mysql_options4(mysql, option, arg1, arg2); - stop_monitoring(context); - return ret; + return mysql_options4(mysql, option, arg1, arg2); } int MYSQL_PROXY::options(mysql_option option, const void* arg) { - const auto context = start_monitoring(); - const int ret = mysql_options(mysql, option, arg); - stop_monitoring(context); - return ret; + return mysql_options(mysql, option, arg); } void MYSQL_PROXY::free_result(MYSQL_RES* result) { - const auto context = start_monitoring(); mysql_free_result(result); - stop_monitoring(context); } void MYSQL_PROXY::data_seek(MYSQL_RES* result, uint64_t offset) { @@ -274,10 +219,7 @@ MYSQL_FIELD_OFFSET MYSQL_PROXY::field_seek(MYSQL_RES* result, MYSQL_FIELD_OFFSET } MYSQL_ROW MYSQL_PROXY::fetch_row(MYSQL_RES* result) { - const auto context = start_monitoring(); - const MYSQL_ROW ret = mysql_fetch_row(result); - stop_monitoring(context); - return ret; + return mysql_fetch_row(result); } unsigned long* MYSQL_PROXY::fetch_lengths(MYSQL_RES* result) { @@ -289,117 +231,73 @@ MYSQL_FIELD* MYSQL_PROXY::fetch_field(MYSQL_RES* result) { } MYSQL_RES* MYSQL_PROXY::list_fields(const char* table, const char* wild) { - const auto context = start_monitoring(); - MYSQL_RES* ret = mysql_list_fields(mysql, table, wild); - stop_monitoring(context); - return ret; + return mysql_list_fields(mysql, table, wild); } unsigned long MYSQL_PROXY::real_escape_string(char* to, const char* from, unsigned long length) { - const auto context = start_monitoring(); - const unsigned long ret = mysql_real_escape_string(mysql, to, from, length); - stop_monitoring(context); - return ret; + return mysql_real_escape_string(mysql, to, from, length); } bool MYSQL_PROXY::bind_param(unsigned n_params, MYSQL_BIND* binds, const char** names) { - const auto context = start_monitoring(); - const bool ret = mysql_bind_param(mysql, n_params, binds, names); - stop_monitoring(context); - return ret; + return mysql_bind_param(mysql, n_params, binds, names); } MYSQL_STMT* MYSQL_PROXY::stmt_init() { - const auto context = start_monitoring(); - MYSQL_STMT* ret = mysql_stmt_init(mysql); - stop_monitoring(context); - return ret; + return mysql_stmt_init(mysql); } int MYSQL_PROXY::stmt_prepare(MYSQL_STMT* stmt, const char* query, unsigned long length) { - const auto context = start_monitoring(); - const int ret = mysql_stmt_prepare(stmt, query, length); - stop_monitoring(context); - return ret; + return mysql_stmt_prepare(stmt, query, length); } int MYSQL_PROXY::stmt_execute(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const int ret = mysql_stmt_execute(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_execute(stmt); } int MYSQL_PROXY::stmt_fetch(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const int ret = mysql_stmt_fetch(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_fetch(stmt); } int MYSQL_PROXY::stmt_fetch_column(MYSQL_STMT* stmt, MYSQL_BIND* bind_arg, unsigned int column, unsigned long offset) { - const auto context = start_monitoring(); - const int ret = mysql_stmt_fetch_column(stmt, bind_arg, column, offset); - stop_monitoring(context); - return ret; + return mysql_stmt_fetch_column(stmt, bind_arg, column, offset); } int MYSQL_PROXY::stmt_store_result(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const int ret = mysql_stmt_store_result(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_store_result(stmt); +} + +unsigned long MYSQL_PROXY::stmt_param_count(MYSQL_STMT* stmt) { + return mysql_stmt_param_count(stmt); } bool MYSQL_PROXY::stmt_bind_param(MYSQL_STMT* stmt, MYSQL_BIND* bnd) { - const auto context = start_monitoring(); - const bool ret = mysql_stmt_bind_param(stmt, bnd); - stop_monitoring(context); - return ret; + return mysql_stmt_bind_param(stmt, bnd); } bool MYSQL_PROXY::stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd) { - const auto context = start_monitoring(); - const bool ret = mysql_stmt_bind_result(stmt, bnd); - stop_monitoring(context); - return ret; + return mysql_stmt_bind_result(stmt, bnd); } bool MYSQL_PROXY::stmt_close(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const bool ret = mysql_stmt_close(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_close(stmt); } bool MYSQL_PROXY::stmt_reset(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const bool ret = mysql_stmt_reset(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_reset(stmt); } bool MYSQL_PROXY::stmt_free_result(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - const bool ret = mysql_stmt_free_result(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_free_result(stmt); } bool MYSQL_PROXY::stmt_send_long_data(MYSQL_STMT* stmt, unsigned int param_number, const char* data, unsigned long length) { - const auto context = start_monitoring(); - const bool ret = mysql_stmt_send_long_data(stmt, param_number, data, length); - stop_monitoring(context); - return ret; + return mysql_stmt_send_long_data(stmt, param_number, data, length); } MYSQL_RES* MYSQL_PROXY::stmt_result_metadata(MYSQL_STMT* stmt) { - const auto context = start_monitoring(); - MYSQL_RES* ret = mysql_stmt_result_metadata(stmt); - stop_monitoring(context); - return ret; + return mysql_stmt_result_metadata(stmt); } unsigned int MYSQL_PROXY::stmt_errno(MYSQL_STMT* stmt) { @@ -435,10 +333,7 @@ unsigned int MYSQL_PROXY::stmt_field_count(MYSQL_STMT* stmt) { } st_mysql_client_plugin* MYSQL_PROXY::client_find_plugin(const char* name, int type) { - const auto context = start_monitoring(); - st_mysql_client_plugin* ret = mysql_client_find_plugin(mysql, name, type); - stop_monitoring(context); - return ret; + return mysql_client_find_plugin(mysql, name, type); } bool MYSQL_PROXY::is_connected() { @@ -501,14 +396,9 @@ void MYSQL_PROXY::set_connection(MYSQL_PROXY* mysql_proxy) { close(); this->mysql = mysql_proxy->mysql; mysql_proxy->mysql = nullptr; - - ds_delete(mysql_proxy->ds); + // delete the ds initialized in CONNECTION_HANDLER::clone_dbc() + mysql_proxy->delete_ds(); delete mysql_proxy; - - if (monitor_service != nullptr && !node_keys.empty()) { - monitor_service->stop_monitoring_for_all_connections(node_keys); - } - generate_node_keys(); } void MYSQL_PROXY::close_socket() { @@ -524,108 +414,10 @@ void MYSQL_PROXY::close_socket() { } } -std::shared_ptr MYSQL_PROXY::start_monitoring() { - if (!ds || !ds->enable_failure_detection) { - return nullptr; - } - - - auto failure_detection_timeout = ds->failure_detection_timeout; - // Use network timeout defined if failure detection timeout is not set - if (failure_detection_timeout == 0) { - failure_detection_timeout = ds->network_timeout == 0 ? failure_detection_timeout_default : ds->network_timeout; - } - - return monitor_service->start_monitoring( - dbc, - ds, - node_keys, - std::make_shared(get_host(), get_port()), - std::chrono::milliseconds{ds->failure_detection_time}, - std::chrono::seconds{failure_detection_timeout}, - std::chrono::milliseconds{ds->failure_detection_interval}, - ds->failure_detection_count, - std::chrono::milliseconds{ds->monitor_disposal_time}); -} - -void MYSQL_PROXY::stop_monitoring(std::shared_ptr context) { - if (!ds ||!ds->enable_failure_detection || context == nullptr) { - return; - } - monitor_service->stop_monitoring(context); - if (context->is_node_unhealthy() && is_connected()) { - close_socket(); - } -} - -void MYSQL_PROXY::generate_node_keys() { - node_keys.clear(); - node_keys.insert(std::string(get_host()) + ":" + std::to_string(get_port())); - - if (is_connected()) { - // Temporarily turn off failure detection if on - const auto failure_detection_old_state = ds->enable_failure_detection; - ds->enable_failure_detection = false; - - const auto error = query(RETRIEVE_HOST_PORT_SQL); - if (error == 0) { - MYSQL_RES* result = store_result(); - MYSQL_ROW row; - while ((row = fetch_row(result))) { - node_keys.insert(std::string(row[0])); - } - free_result(result); - } - - ds->enable_failure_detection = failure_detection_old_state; - } -} - -MYSQL_MONITOR_PROXY::MYSQL_MONITOR_PROXY(DataSource* ds) { - this->ds = ds_new(); - ds_copy(this->ds, ds); -} - -MYSQL_MONITOR_PROXY::~MYSQL_MONITOR_PROXY() { - if (this->mysql) { - mysql_close(this->mysql); - } - if (this->ds) { - ds_delete(this->ds); +void MYSQL_PROXY::set_next_proxy(MYSQL_PROXY* next_proxy) { + if (this->next_proxy) { + throw std::runtime_error("There is already a next proxy present!"); } -} - -void MYSQL_MONITOR_PROXY::init() { - this->mysql = mysql_init(nullptr); -} - -int MYSQL_MONITOR_PROXY::ping() { - return mysql_ping(mysql); -} -int MYSQL_MONITOR_PROXY::options(enum mysql_option option, const void* arg) { - return mysql_options(mysql, option, arg); -} - -bool MYSQL_MONITOR_PROXY::connect() { - if (!ds) - return false; - - const auto host = get_host_info_from_ds(ds); - - return mysql_real_connect(mysql, host->get_host().c_str(), ds_get_utf8attr(ds->uid, &ds->uid8), - ds_get_utf8attr(ds->pwd, &ds->pwd8), nullptr, host->get_port(), - ds_get_utf8attr(ds->socket, &ds->socket8), 0) != nullptr; -} - -bool MYSQL_MONITOR_PROXY::is_connected() { - return this->mysql != nullptr && this->mysql->net.vio; -} - -const char* MYSQL_MONITOR_PROXY::error() { - return mysql_error(mysql); } - -void MYSQL_MONITOR_PROXY::close() { - mysql_close(mysql); - mysql= nullptr; + this->next_proxy = next_proxy; } diff --git a/driver/mysql_proxy.h b/driver/mysql_proxy.h index 20c9de2c2..5ec55477d 100644 --- a/driver/mysql_proxy.h +++ b/driver/mysql_proxy.h @@ -32,7 +32,7 @@ #include -#include "monitor_service.h" +#include "host_info.h" struct DBC; struct DataSource; @@ -40,157 +40,137 @@ struct DataSource; class MYSQL_PROXY { public: MYSQL_PROXY(DBC* dbc, DataSource* ds); - MYSQL_PROXY(DBC* dbc, DataSource* ds, std::shared_ptr monitor_service); virtual ~MYSQL_PROXY(); - void delete_ds(); - uint64_t num_rows(MYSQL_RES* res); - unsigned int num_fields(MYSQL_RES* res); - MYSQL_FIELD* fetch_field_direct(MYSQL_RES* res, unsigned int fieldnr); - MYSQL_ROW_OFFSET row_tell(MYSQL_RES* res); - - unsigned int field_count(); - uint64_t affected_rows(); - unsigned int error_code(); - const char* error(); - const char* sqlstate(); - unsigned long thread_id(); - int set_character_set(const char* csname); - - void init(); - bool ssl_set(const char* key, const char* cert, const char* ca, - const char* capath, const char* cipher); - bool change_user(const char* user, const char* passwd, - const char* db); - bool real_connect(const char* host, const char* user, - const char* passwd, const char* db, unsigned int port, - const char* unix_socket, unsigned long clientflag); - int select_db(const char* db); + virtual void delete_ds(); + virtual uint64_t num_rows(MYSQL_RES* res); + virtual unsigned int num_fields(MYSQL_RES* res); + virtual MYSQL_FIELD* fetch_field_direct(MYSQL_RES* res, unsigned int fieldnr); + virtual MYSQL_ROW_OFFSET row_tell(MYSQL_RES* res); + + virtual unsigned int field_count(); + virtual uint64_t affected_rows(); + virtual unsigned int error_code(); + virtual const char* error(); + virtual const char* sqlstate(); + virtual unsigned long thread_id(); + virtual int set_character_set(const char* csname); + + virtual void init(); + virtual bool ssl_set(const char* key, const char* cert, const char* ca, + const char* capath, const char* cipher); + virtual bool change_user(const char* user, const char* passwd, + const char* db); + virtual bool real_connect(const char* host, const char* user, + const char* passwd, const char* db, unsigned int port, + const char* unix_socket, unsigned long clientflag); + virtual int select_db(const char* db); virtual int query(const char* q); - int real_query(const char* q, unsigned long length); + virtual int real_query(const char* q, unsigned long length); virtual MYSQL_RES* store_result(); - MYSQL_RES* use_result(); - struct CHARSET_INFO* get_character_set() const; - void get_character_set_info(MY_CHARSET_INFO* charset); + virtual MYSQL_RES* use_result(); + virtual struct CHARSET_INFO* get_character_set() const; + virtual void get_character_set_info(MY_CHARSET_INFO* charset); virtual int ping(); static unsigned long get_client_version(void); virtual int options(enum mysql_option option, const void* arg); - int options4(enum mysql_option option, const void* arg1, + virtual int options4(enum mysql_option option, const void* arg1, const void* arg2); - int get_option(enum mysql_option option, const void* arg); + virtual int get_option(enum mysql_option option, const void* arg); virtual void free_result(MYSQL_RES* result); - void data_seek(MYSQL_RES* result, uint64_t offset); - MYSQL_ROW_OFFSET row_seek(MYSQL_RES* result, MYSQL_ROW_OFFSET offset); - MYSQL_FIELD_OFFSET field_seek(MYSQL_RES* result, MYSQL_FIELD_OFFSET offset); + virtual void data_seek(MYSQL_RES* result, uint64_t offset); + virtual MYSQL_ROW_OFFSET row_seek(MYSQL_RES* result, MYSQL_ROW_OFFSET offset); + virtual MYSQL_FIELD_OFFSET field_seek(MYSQL_RES* result, MYSQL_FIELD_OFFSET offset); virtual MYSQL_ROW fetch_row(MYSQL_RES* result); - unsigned long* fetch_lengths(MYSQL_RES* result); - MYSQL_FIELD* fetch_field(MYSQL_RES* result); - MYSQL_RES* list_fields(const char* table, const char* wild); - unsigned long real_escape_string(char* to, const char* from, - unsigned long length); - - bool bind_param(unsigned n_params, MYSQL_BIND* binds, - const char** names); - - MYSQL_STMT* stmt_init(); - int stmt_prepare(MYSQL_STMT* stmt, const char* query, unsigned long length); - int stmt_execute(MYSQL_STMT* stmt); - int stmt_fetch(MYSQL_STMT* stmt); - int stmt_fetch_column(MYSQL_STMT* stmt, MYSQL_BIND* bind_arg, - unsigned int column, unsigned long offset); - int stmt_store_result(MYSQL_STMT* stmt); - unsigned long stmt_param_count(MYSQL_STMT* stmt); - bool stmt_bind_param(MYSQL_STMT* stmt, MYSQL_BIND* bnd); - bool stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd); - bool stmt_close(MYSQL_STMT* stmt); - bool stmt_reset(MYSQL_STMT* stmt); - bool stmt_free_result(MYSQL_STMT* stmt); - bool stmt_send_long_data(MYSQL_STMT* stmt, unsigned int param_number, - const char* data, unsigned long length); - MYSQL_RES* stmt_result_metadata(MYSQL_STMT* stmt); - unsigned int stmt_errno(MYSQL_STMT* stmt); - const char* stmt_error(MYSQL_STMT* stmt); - MYSQL_ROW_OFFSET stmt_row_seek(MYSQL_STMT* stmt, MYSQL_ROW_OFFSET offset); - MYSQL_ROW_OFFSET stmt_row_tell(MYSQL_STMT* stmt); - void stmt_data_seek(MYSQL_STMT* stmt, uint64_t offset); - uint64_t stmt_num_rows(MYSQL_STMT* stmt); - uint64_t stmt_affected_rows(MYSQL_STMT* stmt); - unsigned int stmt_field_count(MYSQL_STMT* stmt); - - bool autocommit(bool auto_mode); - int next_result(); - int stmt_next_result(MYSQL_STMT* stmt); - void close(); - - bool real_connect_dns_srv(const char* dns_srv_name, - const char* user, const char* passwd, - const char* db, unsigned long client_flag); - struct st_mysql_client_plugin* client_find_plugin( + virtual unsigned long* fetch_lengths(MYSQL_RES* result); + virtual MYSQL_FIELD* fetch_field(MYSQL_RES* result); + virtual MYSQL_RES* list_fields(const char* table, const char* wild); + virtual unsigned long real_escape_string(char* to, const char* from, + unsigned long length); + + virtual bool bind_param(unsigned n_params, MYSQL_BIND* binds, + const char** names); + + virtual MYSQL_STMT* stmt_init(); + virtual int stmt_prepare(MYSQL_STMT* stmt, const char* query, unsigned long length); + virtual int stmt_execute(MYSQL_STMT* stmt); + virtual int stmt_fetch(MYSQL_STMT* stmt); + virtual int stmt_fetch_column(MYSQL_STMT* stmt, MYSQL_BIND* bind_arg, + unsigned int column, unsigned long offset); + virtual int stmt_store_result(MYSQL_STMT* stmt); + virtual unsigned long stmt_param_count(MYSQL_STMT* stmt); + virtual bool stmt_bind_param(MYSQL_STMT* stmt, MYSQL_BIND* bnd); + virtual bool stmt_bind_result(MYSQL_STMT* stmt, MYSQL_BIND* bnd); + virtual bool stmt_close(MYSQL_STMT* stmt); + virtual bool stmt_reset(MYSQL_STMT* stmt); + virtual bool stmt_free_result(MYSQL_STMT* stmt); + virtual bool stmt_send_long_data(MYSQL_STMT* stmt, unsigned int param_number, + const char* data, unsigned long length); + virtual MYSQL_RES* stmt_result_metadata(MYSQL_STMT* stmt); + virtual unsigned int stmt_errno(MYSQL_STMT* stmt); + virtual const char* stmt_error(MYSQL_STMT* stmt); + virtual MYSQL_ROW_OFFSET stmt_row_seek(MYSQL_STMT* stmt, MYSQL_ROW_OFFSET offset); + virtual MYSQL_ROW_OFFSET stmt_row_tell(MYSQL_STMT* stmt); + virtual void stmt_data_seek(MYSQL_STMT* stmt, uint64_t offset); + virtual uint64_t stmt_num_rows(MYSQL_STMT* stmt); + virtual uint64_t stmt_affected_rows(MYSQL_STMT* stmt); + virtual unsigned int stmt_field_count(MYSQL_STMT* stmt); + + virtual bool autocommit(bool auto_mode); + virtual int next_result(); + virtual int stmt_next_result(MYSQL_STMT* stmt); + virtual void close(); + + virtual bool real_connect_dns_srv(const char* dns_srv_name, + const char* user, const char* passwd, + const char* db, unsigned long client_flag); + virtual struct st_mysql_client_plugin* client_find_plugin( const char* name, int type); virtual bool is_connected(); - void set_last_error_code(unsigned int error_code); + virtual void set_last_error_code(unsigned int error_code); - char* get_last_error() const; + virtual char* get_last_error() const; - unsigned int get_last_error_code() const; + virtual unsigned int get_last_error_code() const; - char* get_sqlstate() const; + virtual char* get_sqlstate() const; - char* get_server_version() const; + virtual char* get_server_version() const; - uint64_t get_affected_rows() const; + virtual uint64_t get_affected_rows() const; - void set_affected_rows(uint64_t num_rows); + virtual void set_affected_rows(uint64_t num_rows); - char* get_host_info() const; + virtual char* get_host_info() const; - std::string get_host(); + virtual std::string get_host(); - unsigned int get_port(); + virtual unsigned int get_port(); - unsigned long get_max_packet() const; + virtual unsigned long get_max_packet() const; - unsigned long get_server_capabilities() const; + virtual unsigned long get_server_capabilities() const; - unsigned int get_server_status() const; + virtual unsigned int get_server_status() const; - void set_connection(MYSQL_PROXY* mysql_proxy); + virtual void set_connection(MYSQL_PROXY* mysql_proxy); virtual void close_socket(); + virtual void set_next_proxy(MYSQL_PROXY* next_proxy); + protected: DBC* dbc = nullptr; DataSource* ds = nullptr; - MYSQL* mysql = nullptr; - std::shared_ptr monitor_service = nullptr; - std::shared_ptr host = nullptr; - std::set node_keys; - - std::shared_ptr start_monitoring(); - void stop_monitoring(std::shared_ptr context); - void generate_node_keys(); -}; - -class MYSQL_MONITOR_PROXY { -public: - MYSQL_MONITOR_PROXY(DataSource* ds); - virtual ~MYSQL_MONITOR_PROXY(); - - virtual void init(); - virtual int ping(); - virtual int options(enum mysql_option option, const void* arg); - virtual bool connect(); - virtual bool is_connected(); - virtual const char* error(); - virtual void close(); + MYSQL_PROXY* next_proxy = nullptr; private: - DataSource* ds = nullptr; MYSQL* mysql = nullptr; + std::shared_ptr host = nullptr; }; #endif /* __MYSQL_PROXY__ */ diff --git a/driver/topology_service.h b/driver/topology_service.h index a95ed8b54..cf5dfd8dd 100644 --- a/driver/topology_service.h +++ b/driver/topology_service.h @@ -33,11 +33,9 @@ #include "cluster_aware_metrics_container.h" #include "cluster_topology_info.h" #include "mysql_proxy.h" -#include "mylog.h" #include #include -#include #include #include diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 29bf13af0..2f2f350e2 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -54,6 +54,7 @@ add_executable( test_utils.cc cluster_aware_metrics_test.cc + efm_proxy_test.cc failover_handler_test.cc failover_reader_handler_test.cc failover_writer_handler_test.cc @@ -62,7 +63,6 @@ add_executable( monitor_test.cc monitor_thread_container_test.cc multi_threaded_monitor_service_test.cc - mysql_proxy_test.cc query_parsing_test.cc main.cc topology_service_test.cc diff --git a/unit_testing/mysql_proxy_test.cc b/unit_testing/efm_proxy_test.cc similarity index 70% rename from unit_testing/mysql_proxy_test.cc rename to unit_testing/efm_proxy_test.cc index 5bcc949ff..ee0412088 100644 --- a/unit_testing/mysql_proxy_test.cc +++ b/unit_testing/efm_proxy_test.cc @@ -36,12 +36,14 @@ using testing::_; using testing::Return; -class MySQLProxyTest : public testing::Test { +class EFMProxyTest : public testing::Test { protected: SQLHENV env; DBC* dbc; DataSource* ds; std::shared_ptr mock_monitor_service; + MOCK_MYSQL_PROXY* mock_mysql_proxy; + static void SetUpTestSuite() {} @@ -55,6 +57,7 @@ class MySQLProxyTest : public testing::Test { ds->enable_failure_detection = true; mock_monitor_service = std::make_shared(); + mock_mysql_proxy = new MOCK_MYSQL_PROXY(dbc, ds); } void TearDown() override { @@ -62,40 +65,48 @@ class MySQLProxyTest : public testing::Test { } }; -TEST_F(MySQLProxyTest, NullDBC) { - EXPECT_THROW(MYSQL_PROXY mysql_proxy(nullptr, ds), std::runtime_error); +TEST_F(EFMProxyTest, NullDBC) { + EXPECT_THROW(EFM_PROXY efm_proxy(nullptr, ds, mock_mysql_proxy), std::runtime_error); + delete mock_mysql_proxy; } -TEST_F(MySQLProxyTest, NullDS) { - EXPECT_THROW(MYSQL_PROXY mysql_proxy(dbc, nullptr), std::runtime_error); +TEST_F(EFMProxyTest, NullDS) { + EXPECT_THROW(EFM_PROXY efm_proxy(dbc, nullptr, mock_mysql_proxy), std::runtime_error); + delete mock_mysql_proxy; } -TEST_F(MySQLProxyTest, FailureDetectionDisabled) { +TEST_F(EFMProxyTest, FailureDetectionDisabled) { ds->enable_failure_detection = false; EXPECT_CALL(*mock_monitor_service, start_monitoring(_, _, _, _, _, _, _, _, _)).Times(0); EXPECT_CALL(*mock_monitor_service, stop_monitoring(_)).Times(0); + EXPECT_CALL(*mock_mysql_proxy, init()); + EXPECT_CALL(*mock_mysql_proxy, mock_mysql_proxy_destructor()); - MYSQL_PROXY mysql_proxy(dbc, ds, mock_monitor_service); - mysql_proxy.init(); + EFM_PROXY efm_proxy(dbc, ds, mock_mysql_proxy, mock_monitor_service); + efm_proxy.init(); } -TEST_F(MySQLProxyTest, FailureDetectionEnabled) { +TEST_F(EFMProxyTest, FailureDetectionEnabled) { auto mock_context = std::make_shared( nullptr, std::set(), std::chrono::milliseconds(0), std::chrono::milliseconds(0), 0); EXPECT_CALL(*mock_monitor_service, start_monitoring(_, _, _, _, _, _, _, _, _)).WillOnce(Return(mock_context)); EXPECT_CALL(*mock_monitor_service, stop_monitoring(mock_context)).Times(1); + EXPECT_CALL(*mock_mysql_proxy, query("")); + EXPECT_CALL(*mock_mysql_proxy, mock_mysql_proxy_destructor()); - MYSQL_PROXY mysql_proxy(dbc, ds, mock_monitor_service); - mysql_proxy.init(); + EFM_PROXY efm_proxy(dbc, ds, mock_mysql_proxy, mock_monitor_service); + efm_proxy.query(""); } -TEST_F(MySQLProxyTest, DoesNotNeedMonitoring) { +TEST_F(EFMProxyTest, DoesNotNeedMonitoring) { EXPECT_CALL(*mock_monitor_service, start_monitoring(_, _, _, _, _, _, _, _, _)).Times(0); EXPECT_CALL(*mock_monitor_service, stop_monitoring(_)).Times(0); + EXPECT_CALL(*mock_mysql_proxy, close()); + EXPECT_CALL(*mock_mysql_proxy, mock_mysql_proxy_destructor()); - MYSQL_PROXY mysql_proxy(dbc, ds, mock_monitor_service); - mysql_proxy.close(); + EFM_PROXY efm_proxy(dbc, ds, mock_mysql_proxy, mock_monitor_service); + efm_proxy.close(); } diff --git a/unit_testing/failover_reader_handler_test.cc b/unit_testing/failover_reader_handler_test.cc index f7987f80b..04724b510 100644 --- a/unit_testing/failover_reader_handler_test.cc +++ b/unit_testing/failover_reader_handler_test.cc @@ -220,7 +220,7 @@ TEST_F(FailoverReaderHandlerTest, BuildHostsList) { TEST_F(FailoverReaderHandlerTest, GetConnectionFromHosts_Failure) { EXPECT_CALL(*mock_ts, get_topology(_, true)).WillRepeatedly(Return(topology)); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); EXPECT_CALL(*mock_ts, mark_host_down(reader_a_host)).Times(1); EXPECT_CALL(*mock_ts, mark_host_down(reader_b_host)).Times(1); @@ -244,8 +244,8 @@ TEST_F(FailoverReaderHandlerTest, GetConnectionFromHosts_Success_Reader) { EXPECT_CALL(*mock_reader_a_proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)).WillRepeatedly(Return(mock_reader_a_proxy)); + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)).WillRepeatedly(Return(mock_reader_a_proxy)); // Reader C will not be used as it is put at the end. Will only try to connect to A and B EXPECT_CALL(*mock_ts, mark_host_up(reader_a_host)).Times(1); @@ -271,8 +271,8 @@ TEST_F(FailoverReaderHandlerTest, GetConnectionFromHosts_Success_Writer) { EXPECT_CALL(*mock_writer_proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)).WillRepeatedly(Return(mock_writer_proxy)); + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)).WillRepeatedly(Return(mock_writer_proxy)); EXPECT_CALL(*mock_ts, mark_host_up(writer_host)).Times(1); @@ -306,9 +306,9 @@ TEST_F(FailoverReaderHandlerTest, GetConnectionFromHosts_FastestHost) { EXPECT_CALL(*mock_reader_a_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_reader_b_proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)).WillRepeatedly(Return(mock_reader_a_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)).WillRepeatedly(Invoke([&]() { + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)).WillRepeatedly(Return(mock_reader_a_proxy)); + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)).WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_reader_b_proxy; })); @@ -339,12 +339,12 @@ TEST_F(FailoverReaderHandlerTest, GetConnectionFromHosts_Timeout) { EXPECT_CALL(*mock_reader_b_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_reader_b_proxy, mock_mysql_proxy_destructor()); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)).WillRepeatedly(Invoke([&]() { + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)).WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_reader_a_proxy; })); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)).WillRepeatedly(Invoke([&]() { + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)).WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_reader_b_proxy; })); @@ -365,7 +365,7 @@ TEST_F(FailoverReaderHandlerTest, GetConnectionFromHosts_Timeout) { TEST_F(FailoverReaderHandlerTest, Failover_Failure) { EXPECT_CALL(*mock_ts, get_topology(_, true)).WillRepeatedly(Return(topology)); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); EXPECT_CALL(*mock_ts, mark_host_down(reader_a_host)).Times(1); EXPECT_CALL(*mock_ts, mark_host_down(reader_b_host)).Times(1); @@ -401,10 +401,10 @@ TEST_F(FailoverReaderHandlerTest, Failover_Success_Reader) { EXPECT_CALL(*mock_reader_b_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_reader_b_proxy, mock_mysql_proxy_destructor()); - EXPECT_CALL(*mock_connection_handler, connect(_)).WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)).WillRepeatedly( + EXPECT_CALL(*mock_connection_handler, connect(_, nullptr)).WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)).WillRepeatedly( Return(mock_reader_a_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)).WillRepeatedly(Invoke([&]() { + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)).WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_reader_b_proxy; })); diff --git a/unit_testing/failover_writer_handler_test.cc b/unit_testing/failover_writer_handler_test.cc index d03bbdedf..d21cead6a 100644 --- a/unit_testing/failover_writer_handler_test.cc +++ b/unit_testing/failover_writer_handler_test.cc @@ -126,11 +126,11 @@ TEST_F(FailoverWriterHandlerTest, ReconnectToWriter_TaskBEmptyReaderResult) { EXPECT_CALL(*mock_reader_handler, get_reader_connection(_, _)) .WillRepeatedly(Return(READER_FAILOVER_RESULT(false, nullptr, nullptr))); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillRepeatedly(Return(mock_writer_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)) .WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)) .WillRepeatedly(Return(nullptr)); FAILOVER_WRITER_HANDLER writer_handler( @@ -170,9 +170,9 @@ TEST_F(FailoverWriterHandlerTest, ReconnectToWriter_SlowReaderA) { EXPECT_CALL(*mock_reader_a_proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillRepeatedly(Return(mock_writer_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)) .WillRepeatedly(Return(nullptr)); EXPECT_CALL(*mock_ts, get_topology(mock_writer_proxy, true)) @@ -220,12 +220,12 @@ TEST_F(FailoverWriterHandlerTest, ReconnectToWriter_TaskBDefers) { EXPECT_CALL(*mock_reader_a_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_reader_a_proxy, mock_mysql_proxy_destructor()); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_writer_proxy; })); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)) .WillRepeatedly(Return(nullptr)); EXPECT_CALL(*mock_ts, get_topology(_, true)) @@ -281,13 +281,13 @@ TEST_F(FailoverWriterHandlerTest, ConnectToReaderA_SlowWriter) { EXPECT_CALL(*mock_reader_a_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_reader_a_proxy, mock_mysql_proxy_destructor()); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_writer_proxy; })); - EXPECT_CALL(*mock_connection_handler, connect(new_writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(new_writer_host, nullptr)) .WillRepeatedly(Return(mock_new_writer_proxy)); EXPECT_CALL(*mock_ts, get_topology(mock_writer_proxy, true)) @@ -341,12 +341,12 @@ TEST_F(FailoverWriterHandlerTest, ConnectToReaderA_TaskADefers) { EXPECT_CALL(*mock_reader_a_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_reader_a_proxy, mock_mysql_proxy_destructor()); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillOnce(Return(mock_writer_proxy)) .WillRepeatedly(Return(nullptr)); // Connection is deleted after first connect - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)) .WillRepeatedly(Return(mock_reader_a_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(new_writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(new_writer_host, nullptr)) .WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_new_writer_proxy; @@ -403,16 +403,16 @@ TEST_F(FailoverWriterHandlerTest, FailedToConnect_FailoverTimeout) { EXPECT_CALL(*mock_reader_b_proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_writer_proxy; })); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)) .WillRepeatedly(Return(mock_reader_a_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)) .WillRepeatedly(Return(mock_reader_b_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(new_writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(new_writer_host, nullptr)) .WillRepeatedly(Invoke([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return mock_new_writer_proxy; @@ -461,13 +461,13 @@ TEST_F(FailoverWriterHandlerTest, FailedToConnect_TaskAFailed_TaskBWriterFailed) EXPECT_CALL(*mock_reader_b_proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_connection_handler, connect(writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(writer_host, nullptr)) .WillRepeatedly(Return(nullptr)); - EXPECT_CALL(*mock_connection_handler, connect(reader_a_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_a_host, nullptr)) .WillRepeatedly(Return(mock_reader_a_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(reader_b_host)) + EXPECT_CALL(*mock_connection_handler, connect(reader_b_host, nullptr)) .WillRepeatedly(Return(mock_reader_b_proxy)); - EXPECT_CALL(*mock_connection_handler, connect(new_writer_host)) + EXPECT_CALL(*mock_connection_handler, connect(new_writer_host, nullptr)) .WillRepeatedly(Return(nullptr)); EXPECT_CALL(*mock_ts, get_topology(_, _)) diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 520623699..bc22bd6b0 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -56,7 +56,7 @@ class MOCK_MYSQL_PROXY : public MYSQL_PROXY { this->ds = ds_new(); ds_copy(this->ds, ds); }; - ~MOCK_MYSQL_PROXY() { + ~MOCK_MYSQL_PROXY() override { mock_mysql_proxy_destructor(); if (this->ds) { ds_delete(this->ds); @@ -77,6 +77,9 @@ class MOCK_MYSQL_PROXY : public MYSQL_PROXY { MOCK_METHOD(void, free_result, (MYSQL_RES*)); MOCK_METHOD(void, close_socket, ()); MOCK_METHOD(void, mock_mysql_proxy_destructor, ()); + MOCK_METHOD(void, close, ()); + MOCK_METHOD(void, init, ()); + MOCK_METHOD(int, ping, ()); }; class MOCK_TOPOLOGY_SERVICE : public TOPOLOGY_SERVICE { @@ -97,10 +100,10 @@ class MOCK_READER_HANDLER : public FAILOVER_READER_HANDLER { (std::shared_ptr, FAILOVER_SYNC&)); }; -class MOCK_CONNECTION_HANDLER : public FAILOVER_CONNECTION_HANDLER { +class MOCK_CONNECTION_HANDLER : public CONNECTION_HANDLER { public: - MOCK_CONNECTION_HANDLER() : FAILOVER_CONNECTION_HANDLER(nullptr) {} - MOCK_METHOD(MYSQL_PROXY*, connect, (const std::shared_ptr&)); + MOCK_CONNECTION_HANDLER() : CONNECTION_HANDLER(nullptr) {} + MOCK_METHOD(MYSQL_PROXY*, connect, (const std::shared_ptr&, DataSource*)); MOCK_METHOD(SQLRETURN, do_connect, (DBC*, DataSource*, bool)); }; @@ -125,8 +128,8 @@ class MOCK_CLUSTER_AWARE_METRICS_CONTAINER : public CLUSTER_AWARE_METRICS_CONTAI class MOCK_MONITOR : public MONITOR { public: MOCK_MONITOR(std::shared_ptr host, std::chrono::milliseconds disposal_time, - MYSQL_MONITOR_PROXY* monitor_proxy) - : MONITOR(host, std::chrono::seconds{5}, disposal_time, monitor_proxy) {} + MYSQL_PROXY* monitor_proxy) + : MONITOR(host, nullptr, std::chrono::seconds{5}, disposal_time, nullptr, monitor_proxy) {} MOCK_METHOD(void, start_monitoring, (std::shared_ptr)); MOCK_METHOD(void, stop_monitoring, (std::shared_ptr)); @@ -138,7 +141,7 @@ class MOCK_MONITOR : public MONITOR { class MOCK_MONITOR2 : public MONITOR { public: MOCK_MONITOR2(std::shared_ptr host, std::chrono::milliseconds disposal_time) - : MONITOR(host, std::chrono::seconds{5}, disposal_time, (MYSQL_MONITOR_PROXY*)nullptr) {} + : MONITOR(host, nullptr, std::chrono::seconds{5}, disposal_time, nullptr, nullptr) {} MOCK_METHOD(void, run, (std::shared_ptr)); }; @@ -147,9 +150,8 @@ class MOCK_MONITOR2 : public MONITOR { class MOCK_MONITOR3 : public MONITOR { public: MOCK_MONITOR3(std::shared_ptr host, - std::chrono::milliseconds disposal_time, - MYSQL_MONITOR_PROXY* monitor_proxy) - : MONITOR(host, std::chrono::seconds{5}, disposal_time, monitor_proxy) {} + std::chrono::milliseconds disposal_time) + : MONITOR(host, nullptr, std::chrono::seconds{5}, disposal_time, nullptr, nullptr) {} MOCK_METHOD(std::chrono::steady_clock::time_point, get_current_time, ()); }; @@ -163,7 +165,7 @@ class MOCK_MONITOR_THREAD_CONTAINER : public MONITOR_THREAD_CONTAINER { } MOCK_METHOD(std::shared_ptr, create_monitor, - (std::shared_ptr, std::chrono::seconds, std::chrono::milliseconds, DataSource*, bool)); + (std::shared_ptr, std::shared_ptr, std::chrono::seconds, std::chrono::milliseconds, DataSource*, bool)); }; class MOCK_MONITOR_CONNECTION_CONTEXT : public MONITOR_CONNECTION_CONTEXT { @@ -179,17 +181,6 @@ class MOCK_MONITOR_CONNECTION_CONTEXT : public MONITOR_CONNECTION_CONTEXT { MOCK_METHOD(void, set_start_monitor_time, (std::chrono::steady_clock::time_point)); }; -class MOCK_MYSQL_MONITOR_PROXY : public MYSQL_MONITOR_PROXY { -public: - MOCK_MYSQL_MONITOR_PROXY() : MYSQL_MONITOR_PROXY(nullptr) {}; - - MOCK_METHOD(bool, is_connected, ()); - MOCK_METHOD(void, init, ()); - MOCK_METHOD(int, ping, ()); - MOCK_METHOD(int, options, (enum mysql_option, const void*)); - MOCK_METHOD(bool, connect, ()); -}; - class MOCK_MONITOR_SERVICE : public MONITOR_SERVICE { public: MOCK_MONITOR_SERVICE() : MONITOR_SERVICE() {}; diff --git a/unit_testing/monitor_service_test.cc b/unit_testing/monitor_service_test.cc index c79e2ebe5..3d5094f9b 100644 --- a/unit_testing/monitor_service_test.cc +++ b/unit_testing/monitor_service_test.cc @@ -29,6 +29,7 @@ #include "driver/monitor_service.h" +#include "test_utils.h" #include "mock_objects.h" #include @@ -48,6 +49,9 @@ namespace { class MonitorServiceTest : public testing::Test { protected: + SQLHENV env; + DBC* dbc; + DataSource* ds; std::shared_ptr host; std::shared_ptr mock_monitor; std::shared_ptr mock_thread_container; @@ -58,6 +62,7 @@ class MonitorServiceTest : public testing::Test { static void TearDownTestSuite() {} void SetUp() override { + allocate_odbc_handles(env, dbc, ds); host = std::make_shared("host", 1234); mock_thread_container = std::make_shared(); monitor_service = std::make_shared(mock_thread_container); @@ -67,19 +72,20 @@ class MonitorServiceTest : public testing::Test { void TearDown() override { monitor_service->release_resources(); mock_thread_container->release_resources(); + cleanup_odbc_handles(env, dbc, ds); } }; TEST_F(MonitorServiceTest, StartMonitoring) { - EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _, _)) .WillOnce(Return(mock_monitor)); EXPECT_CALL(*mock_monitor, start_monitoring(_)).Times(1); EXPECT_CALL(*mock_monitor, run(_)).Times(1); auto context = monitor_service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, node_keys, host, failure_detection_time, @@ -91,7 +97,7 @@ TEST_F(MonitorServiceTest, StartMonitoring) { } TEST_F(MonitorServiceTest, StartMonitoringCalledMultipleTimes) { - EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _, _)) .WillOnce(Return(mock_monitor)); const int runs = 5; @@ -101,8 +107,8 @@ TEST_F(MonitorServiceTest, StartMonitoringCalledMultipleTimes) { for (int i = 0; i < runs; i++) { auto context = monitor_service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, node_keys, host, failure_detection_time, @@ -115,15 +121,15 @@ TEST_F(MonitorServiceTest, StartMonitoringCalledMultipleTimes) { } TEST_F(MonitorServiceTest, StopMonitoring) { - EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _, _)) .WillOnce(Return(mock_monitor)); EXPECT_CALL(*mock_monitor, start_monitoring(_)).Times(1); EXPECT_CALL(*mock_monitor, run(_)).Times(1); auto context = monitor_service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, node_keys, host, failure_detection_time, @@ -139,15 +145,15 @@ TEST_F(MonitorServiceTest, StopMonitoring) { } TEST_F(MonitorServiceTest, StopMonitoringCalledTwice) { - EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _, _)) .WillOnce(Return(mock_monitor)); EXPECT_CALL(*mock_monitor, start_monitoring(_)).Times(1); EXPECT_CALL(*mock_monitor, run(_)).Times(1); auto context = monitor_service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, node_keys, host, failure_detection_time, @@ -168,8 +174,8 @@ TEST_F(MonitorServiceTest, EmptyNodeKeys) { EXPECT_THROW( monitor_service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, keys, host, failure_detection_time, diff --git a/unit_testing/monitor_test.cc b/unit_testing/monitor_test.cc index d13d486ac..52e27d8c7 100644 --- a/unit_testing/monitor_test.cc +++ b/unit_testing/monitor_test.cc @@ -37,12 +37,11 @@ #include using ::testing::_; +using ::testing::AllOf; using ::testing::AtLeast; using ::testing::Eq; -using ::testing::MatcherCast; -using ::testing::Pointee; +using ::testing::Field; using ::testing::Return; -using ::testing::SafeMatcherCast; namespace { const std::set node_keys = { "any.node.domain" }; @@ -58,9 +57,13 @@ namespace { class MonitorTest : public testing::Test { protected: + SQLHENV env; + DBC* dbc; + DataSource* ds; std::shared_ptr host; std::shared_ptr monitor; - MOCK_MYSQL_MONITOR_PROXY* mock_proxy; + MOCK_MYSQL_PROXY* mock_proxy; + std::shared_ptr mock_connection_handler; std::shared_ptr mock_context_short_interval; std::shared_ptr mock_context_long_interval; @@ -69,10 +72,11 @@ class MonitorTest : public testing::Test { static void TearDownTestSuite() {} void SetUp() override { + allocate_odbc_handles(env, dbc, ds); host = std::make_shared("host", 1234); - mock_proxy = new MOCK_MYSQL_MONITOR_PROXY(); - monitor = std::make_shared(host, failure_detection_timeout, monitor_disposal_time, mock_proxy); - + mock_connection_handler = std::make_shared(); + monitor = std::make_shared(host, mock_connection_handler, failure_detection_timeout, monitor_disposal_time, ds, nullptr, false); + mock_context_short_interval = std::make_shared( node_keys, failure_detection_time, @@ -86,7 +90,9 @@ class MonitorTest : public testing::Test { failure_detection_count); } - void TearDown() override {} + void TearDown() override { + cleanup_odbc_handles(env, dbc, ds); + } }; TEST_F(MonitorTest, StartMonitoringWithDifferentContexts) { @@ -136,15 +142,12 @@ TEST_F(MonitorTest, StopMonitoringTwiceWithSameContext) { } TEST_F(MonitorTest, IsConnectionHealthyWithNoExistingConnection) { - EXPECT_CALL(*mock_proxy, is_connected()) - .WillOnce(Return(false)) - .WillRepeatedly(Return(true)); - - EXPECT_CALL(*mock_proxy, init()).Times(1); - - EXPECT_CALL(*mock_proxy, connect()) - .WillOnce(Return(true)); + mock_proxy = new MOCK_MYSQL_PROXY(dbc, ds); + EXPECT_CALL(*mock_connection_handler, connect(host, _)) + .WillOnce(Return(mock_proxy)); + + EXPECT_CALL(*mock_proxy, is_connected()).WillRepeatedly(Return(true)); EXPECT_CALL(*mock_proxy, ping()).Times(0); CONNECTION_STATUS status = TEST_UTILS::check_connection_status(monitor); @@ -153,22 +156,20 @@ TEST_F(MonitorTest, IsConnectionHealthyWithNoExistingConnection) { } TEST_F(MonitorTest, IsConnectionHealthyOrUnhealthy) { + mock_proxy = new MOCK_MYSQL_PROXY(dbc, ds); + + EXPECT_CALL(*mock_connection_handler, connect(host, _)) + .WillRepeatedly(Return(mock_proxy)); + EXPECT_CALL(*mock_proxy, is_connected()) .WillOnce(Return(false)) .WillRepeatedly(Return(true)); - - EXPECT_CALL(*mock_proxy, init()).Times(AtLeast(1)); - EXPECT_CALL(*mock_proxy, options(_, _)).Times(AtLeast(1)); - - EXPECT_CALL(*mock_proxy, connect()) - .WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_proxy, ping()) .WillOnce(Return(0)) .WillOnce(Return(1)); CONNECTION_STATUS status1 = TEST_UTILS::check_connection_status(monitor); - EXPECT_TRUE(status1.is_valid); + EXPECT_FALSE(status1.is_valid); CONNECTION_STATUS status2 = TEST_UTILS::check_connection_status(monitor); EXPECT_TRUE(status2.is_valid); @@ -178,16 +179,13 @@ TEST_F(MonitorTest, IsConnectionHealthyOrUnhealthy) { } TEST_F(MonitorTest, IsConnectionHealthyAfterFailedConnection) { + mock_proxy = new MOCK_MYSQL_PROXY(dbc, ds); + + EXPECT_CALL(*mock_connection_handler, connect(host, _)) + .WillOnce(Return(mock_proxy)); + EXPECT_CALL(*mock_proxy, is_connected()) - .WillOnce(Return(false)) .WillRepeatedly(Return(true)); - - EXPECT_CALL(*mock_proxy, init()).Times(AtLeast(1)); - EXPECT_CALL(*mock_proxy, options(_, _)).Times(AtLeast(1)); - - EXPECT_CALL(*mock_proxy, connect()) - .WillOnce(Return(true)); - EXPECT_CALL(*mock_proxy, ping()) .WillOnce(Return(1)); @@ -201,11 +199,10 @@ TEST_F(MonitorTest, IsConnectionHealthyAfterFailedConnection) { } TEST_F(MonitorTest, RunWithoutContext) { - auto proxy = new MOCK_MYSQL_MONITOR_PROXY(); std::shared_ptr container = MONITOR_THREAD_CONTAINER::get_instance(); auto monitor_service = std::make_shared(container); - auto mock_monitor = std::make_shared(host, short_interval, proxy); + auto mock_monitor = std::make_shared(host, short_interval); EXPECT_CALL(*mock_monitor, get_current_time()) .WillRepeatedly(Return(short_interval_time)); @@ -227,23 +224,16 @@ TEST_F(MonitorTest, RunWithoutContext) { } TEST_F(MonitorTest, RunWithContext) { - auto proxy = new MOCK_MYSQL_MONITOR_PROXY(); - - EXPECT_CALL(*proxy, is_connected()) - .WillOnce(Return(false)) - .WillRepeatedly(Return(true)); - - EXPECT_CALL(*proxy, init()).Times(AtLeast(1)); - EXPECT_CALL(*proxy, options(_, _)).Times(AtLeast(1)); + auto proxy = new MOCK_MYSQL_PROXY(dbc, ds); - EXPECT_CALL(*proxy, connect()) - .WillRepeatedly(Return(true)); + EXPECT_CALL(*mock_connection_handler, connect(host, _)) + .WillOnce(Return(proxy)); - EXPECT_CALL(*proxy, ping()) - .WillRepeatedly(Return(0)); + EXPECT_CALL(*proxy, is_connected()).WillRepeatedly(Return(true)); + EXPECT_CALL(*proxy, ping()).WillRepeatedly(Return(0)); std::shared_ptr monitorA = - std::make_shared(host, failure_detection_timeout, short_interval, proxy); + std::make_shared(host, mock_connection_handler, failure_detection_timeout, short_interval, ds, nullptr); auto container = MONITOR_THREAD_CONTAINER::get_instance(); auto monitor_service = std::make_shared(container); @@ -285,34 +275,23 @@ TEST_F(MonitorTest, RunWithContext) { // Verify that if 0 timeout is passed in, we should set it to default value TEST_F(MonitorTest, ZeroEFMTimeout) { - auto proxy = new MOCK_MYSQL_MONITOR_PROXY(); - - EXPECT_CALL(*proxy, is_connected()) - .WillOnce(Return(false)) - .WillRepeatedly(Return(true)); + auto proxy = new MOCK_MYSQL_PROXY(dbc, ds); - EXPECT_CALL(*proxy, init()).Times(AtLeast(1)); - EXPECT_CALL( - *proxy, - options(MYSQL_OPT_CONNECT_TIMEOUT, - MatcherCast(SafeMatcherCast( - Pointee(Eq(failure_detection_timeout_default)))))) - .Times(1); - EXPECT_CALL( - *proxy, - options(MYSQL_OPT_READ_TIMEOUT, - MatcherCast(SafeMatcherCast( - Pointee(Eq(failure_detection_timeout_default)))))) - .Times(1); - - EXPECT_CALL(*proxy, connect()).WillOnce(Return(true)); + EXPECT_CALL(*proxy, is_connected()).WillRepeatedly(Return(true)); + + EXPECT_CALL(*mock_connection_handler, + connect(host, + AllOf( + Field("connect_timeout",&DataSource::connect_timeout, Eq(failure_detection_timeout_default)), + Field("network_timeout", &DataSource::network_timeout, Eq(failure_detection_timeout_default))))) + .WillOnce(Return(proxy)); EXPECT_CALL(*proxy, ping()).WillRepeatedly(Return(0)); - std::chrono::seconds zero_timeout = std::chrono::seconds(0); + auto zero_timeout = std::chrono::seconds(0); - std::shared_ptr monitorA = - std::make_shared(host, zero_timeout, short_interval, proxy); + auto monitorA = + std::make_shared(host, mock_connection_handler, zero_timeout, short_interval, ds, nullptr); CONNECTION_STATUS status1 = TEST_UTILS::check_connection_status(monitorA); EXPECT_TRUE(status1.is_valid); @@ -320,33 +299,24 @@ TEST_F(MonitorTest, ZeroEFMTimeout) { // Verify that if non-zero timeout is passed in, we should set it to that value TEST_F(MonitorTest, NonZeroEFMTimeout) { - auto proxy = new MOCK_MYSQL_MONITOR_PROXY(); - std::chrono::seconds timeout = std::chrono::seconds(1); + auto proxy = new MOCK_MYSQL_PROXY(dbc, ds); + auto timeout = std::chrono::seconds(1); - EXPECT_CALL(*proxy, is_connected()) - .WillOnce(Return(false)) - .WillRepeatedly(Return(true)); + EXPECT_CALL(*proxy, is_connected()).WillRepeatedly(Return(true)); - EXPECT_CALL(*proxy, init()).Times(AtLeast(1)); - EXPECT_CALL( - *proxy, - options(MYSQL_OPT_CONNECT_TIMEOUT, - MatcherCast(SafeMatcherCast(Pointee(Eq(timeout.count())))))) - .Times(1); - EXPECT_CALL( - *proxy, - options(MYSQL_OPT_READ_TIMEOUT, - MatcherCast(SafeMatcherCast(Pointee(Eq(timeout.count())))))) - .Times(1); + EXPECT_CALL(*mock_connection_handler, + connect(host, + AllOf( + Field("connect_timeout", &DataSource::connect_timeout, Eq(timeout.count())), + Field("network_timeout", &DataSource::network_timeout, Eq(timeout.count()))))) + .WillOnce(Return(proxy)); - EXPECT_CALL(*proxy, connect()).WillOnce(Return(true)); - - EXPECT_CALL(*proxy, ping()).WillRepeatedly(Return(0)); + EXPECT_CALL(*proxy, ping()).WillRepeatedly(Return(0)); - std::shared_ptr monitorA = - std::make_shared(host, timeout, short_interval, proxy); + auto monitorA = + std::make_shared(host, mock_connection_handler, timeout, short_interval, ds, nullptr); - CONNECTION_STATUS status1 = TEST_UTILS::check_connection_status(monitorA); - EXPECT_TRUE(status1.is_valid); + CONNECTION_STATUS status1 = TEST_UTILS::check_connection_status(monitorA); + EXPECT_TRUE(status1.is_valid); } diff --git a/unit_testing/monitor_thread_container_test.cc b/unit_testing/monitor_thread_container_test.cc index 8194b6278..3edd583a9 100644 --- a/unit_testing/monitor_thread_container_test.cc +++ b/unit_testing/monitor_thread_container_test.cc @@ -78,14 +78,16 @@ TEST_F(MonitorThreadContainerTest, MultipleNodeKeys) { std::set node_keys1 = { "nodeOne.domain", "nodeTwo.domain" }; std::set node_keys2 = { "nodeTwo.domain" }; + auto mock_connection_handler = std::make_shared(); + auto monitor1 = thread_container->get_or_create_monitor( - node_keys1, host, failure_detection_timeout, monitor_disposal_time, nullptr); + node_keys1, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor1); // Should return the same monitor again because the first call to get_or_create_monitor() // mapped the monitor to both "nodeOne.domain" and "nodeTwo.domain". auto monitor2 = thread_container->get_or_create_monitor( - node_keys2, host, failure_detection_timeout, monitor_disposal_time, nullptr); + node_keys2, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor2); EXPECT_TRUE(monitor1 == monitor2); @@ -94,12 +96,14 @@ TEST_F(MonitorThreadContainerTest, MultipleNodeKeys) { TEST_F(MonitorThreadContainerTest, DifferentNodeKeys) { std::set keys = { "nodeNEW.domain" }; + auto mock_connection_handler = std::make_shared(); + auto monitor1 = thread_container->get_or_create_monitor( - keys, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor1); auto monitor2 = thread_container->get_or_create_monitor( - keys, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor2); // Monitors should be the same because both calls to get_or_create_monitor() @@ -107,7 +111,7 @@ TEST_F(MonitorThreadContainerTest, DifferentNodeKeys) { EXPECT_TRUE(monitor1 == monitor2); auto monitor3 = thread_container->get_or_create_monitor( - node_keys, host, failure_detection_timeout, monitor_disposal_time, nullptr); + node_keys, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor3); // Last monitor should be different because it has a different node key. @@ -120,19 +124,21 @@ TEST_F(MonitorThreadContainerTest, SameKeysInDifferentNodeKeys) { std::set keys2 = { "nodeA", "nodeB" }; std::set keys3 = { "nodeB" }; + auto mock_connection_handler = std::make_shared(); + auto monitor1 = thread_container->get_or_create_monitor( - keys1, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys1, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor1); auto monitor2 = thread_container->get_or_create_monitor( - keys2, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys2, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor2); // Monitors should be the same because both sets of keys have "nodeA". EXPECT_TRUE(monitor1 == monitor2); auto monitor3 = thread_container->get_or_create_monitor( - keys3, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys3, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); EXPECT_NE(nullptr, monitor3); // Last monitor should be also be the same because the 2nd call to get_or_create_monitor() @@ -143,8 +149,11 @@ TEST_F(MonitorThreadContainerTest, SameKeysInDifferentNodeKeys) { TEST_F(MonitorThreadContainerTest, PopulateMonitorMap) { std::set keys = { "nodeA", "nodeB", "nodeC", "nodeD" }; + + auto mock_connection_handler = std::make_shared(); + auto monitor = thread_container->get_or_create_monitor( - keys, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); // Check that we now have mappings for all the keys. for (auto it = keys.begin(); it != keys.end(); ++it) { @@ -158,12 +167,15 @@ TEST_F(MonitorThreadContainerTest, PopulateMonitorMap) { TEST_F(MonitorThreadContainerTest, RemoveMonitorMapping) { std::set keys1 = { "nodeA", "nodeB", "nodeC", "nodeD" }; + + auto mock_connection_handler = std::make_shared(); + auto monitor1 = thread_container->get_or_create_monitor( - keys1, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys1, host, failure_detection_timeout, monitor_disposal_time, nullptr, mock_connection_handler); std::set keys2 = { "nodeE", "nodeF", "nodeG", "nodeH" }; auto monitor2 = thread_container->get_or_create_monitor( - keys2, host, failure_detection_timeout, std::chrono::milliseconds(100), nullptr); + keys2, host, failure_detection_timeout, std::chrono::milliseconds(100), nullptr, mock_connection_handler); // This should remove the mappings for keys1 but not keys2. thread_container->reset_resource(monitor1); @@ -193,7 +205,7 @@ TEST_F(MonitorThreadContainerTest, AvailableMonitorsQueue) { auto mock_monitor2 = std::make_shared(host, monitor_disposal_time, nullptr); // While we have three get_or_create_monitor() calls, we only call create_monitor() twice. - EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_thread_container, create_monitor(_, _, _, _, _, _)) .WillOnce(Return(mock_monitor1)) .WillOnce(Return(mock_monitor2)); @@ -205,7 +217,7 @@ TEST_F(MonitorThreadContainerTest, AvailableMonitorsQueue) { // This first call should create the monitor. auto monitor1 = mock_thread_container->get_or_create_monitor( - keys, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys, host, failure_detection_timeout, monitor_disposal_time, nullptr, nullptr); EXPECT_NE(nullptr, monitor1); mock_thread_container->add_task(monitor1, monitor_service); @@ -225,7 +237,7 @@ TEST_F(MonitorThreadContainerTest, AvailableMonitorsQueue) { // This second call should get the monitor from the available monitors queue // instead of creating a new monitor. auto monitor2 = mock_thread_container->get_or_create_monitor( - keys, host, failure_detection_timeout, monitor_disposal_time, nullptr); + keys, host, failure_detection_timeout, monitor_disposal_time, nullptr, nullptr); EXPECT_NE(nullptr, monitor2); EXPECT_TRUE(monitor2 == available_monitor); @@ -236,7 +248,7 @@ TEST_F(MonitorThreadContainerTest, AvailableMonitorsQueue) { // This call will discard the available monitor because it is now stopped // and create a new monitor. auto monitor3 = mock_thread_container->get_or_create_monitor( - { node_key }, host, failure_detection_timeout, monitor_disposal_time, nullptr); + { node_key }, host, failure_detection_timeout, monitor_disposal_time, nullptr, nullptr); EXPECT_NE(nullptr, monitor3); EXPECT_NE(monitor1, monitor3); @@ -247,9 +259,15 @@ TEST_F(MonitorThreadContainerTest, AvailableMonitorsQueue) { } TEST_F(MonitorThreadContainerTest, PopulateAndRemoveMappings) { + SQLHENV env; + DBC* dbc; + DataSource* ds; + allocate_odbc_handles(env, dbc, ds); + dbc->connection_handler = std::make_shared(); + auto context = service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, node_keys, host, failure_detection_time, @@ -267,4 +285,6 @@ TEST_F(MonitorThreadContainerTest, PopulateAndRemoveMappings) { EXPECT_FALSE(TEST_UTILS::has_monitor(thread_container, node_key)); EXPECT_FALSE(TEST_UTILS::has_any_tasks(thread_container)); + + cleanup_odbc_handles(env, dbc, ds); } diff --git a/unit_testing/multi_threaded_monitor_service_test.cc b/unit_testing/multi_threaded_monitor_service_test.cc index 3b886fee2..8b57546c2 100644 --- a/unit_testing/multi_threaded_monitor_service_test.cc +++ b/unit_testing/multi_threaded_monitor_service_test.cc @@ -48,6 +48,9 @@ namespace { class MultiThreadedMonitorServiceTest : public testing::Test { protected: + SQLHENV env; + DBC* dbc; + DataSource* ds; const int num_connections = 10; std::shared_ptr host; std::shared_ptr mock_container; @@ -58,6 +61,7 @@ class MultiThreadedMonitorServiceTest : public testing::Test { static void TearDownTestSuite() {} void SetUp() override { + allocate_odbc_handles(env, dbc, ds); host = std::make_shared("host", 1234); mock_container = std::make_shared(); services = generate_services(num_connections); @@ -68,6 +72,7 @@ class MultiThreadedMonitorServiceTest : public testing::Test { service->release_resources(); } MONITOR_THREAD_CONTAINER::release_instance(); + cleanup_odbc_handles(env, dbc, ds); } std::vector> generate_services(int num_services) { @@ -110,10 +115,10 @@ class MultiThreadedMonitorServiceTest : public testing::Test { auto service = services.at(i); auto node_keys = node_key_list.at(i); - auto thread = std::thread([&service, node_keys, host]() { + auto thread = std::thread([&service, node_keys, host, this]() { service->start_monitoring( - nullptr, - nullptr, + dbc, + ds, node_keys, host, failure_detection_time, @@ -167,7 +172,7 @@ TEST_F(MultiThreadedMonitorServiceTest, StartAndStopMonitoring_MultipleConnectio Sequence s1; for (int i = 0; i < num_connections; i++) { - EXPECT_CALL(*mock_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_container, create_monitor(_, _, _, _, _, _)) .InSequence(s1) .WillOnce(Return(monitors[i])); } @@ -199,7 +204,7 @@ TEST_F(MultiThreadedMonitorServiceTest, StartAndStopMonitoring_MultipleConnectio auto mock_monitor = std::make_shared(host, monitor_disposal_time); - EXPECT_CALL(*mock_container, create_monitor(_, _, _, _, _)) + EXPECT_CALL(*mock_container, create_monitor(_, _, _, _, _, _)) .WillOnce(Return(mock_monitor)); EXPECT_CALL(*mock_monitor, run(_)).Times(AtLeast(1)); diff --git a/util/installer.cc b/util/installer.cc index cb777fad0..e216fe66b 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -749,12 +749,8 @@ DataSource *ds_new() ds->port = 3306; ds->has_port = false; ds->no_schema = 1; - ds->auth_mode = 0; - ds->auth_region = 0; - ds->auth_host = 0; ds->auth_port = 0; ds->auth_expiration = 0; - ds->auth_secret_id = 0; ds->enable_cluster_failover = true; ds->allow_reader_connections = false; ds->gather_perf_metrics = false; @@ -846,6 +842,15 @@ void ds_delete(DataSource *ds) x_free(ds->ssl_crlpath8); x_free(ds->load_data_local_dir8); + x_free(ds->auth_mode); + x_free(ds->auth_region); + x_free(ds->auth_host); + x_free(ds->auth_secret_id); + x_free(ds->auth_mode8); + x_free(ds->auth_region8); + x_free(ds->auth_host8); + x_free(ds->auth_secret_id8); + x_free(ds->host_pattern); x_free(ds->cluster_id); x_free(ds->host_pattern8); @@ -2083,12 +2088,6 @@ void ds_copy(DataSource *ds, DataSource *ds_source) { sqlwcharlen(ds_source->cluster_id)); } - ds->auth_mode = ds_source->auth_mode; - ds->auth_region = ds_source->auth_region; - ds->auth_host = ds_source->auth_host; - ds->auth_port = ds_source->auth_port; - ds->auth_expiration = ds_source->auth_expiration; - ds->auth_secret_id = ds_source->auth_secret_id; ds->enable_cluster_failover = ds_source->enable_cluster_failover; ds->allow_reader_connections = ds_source->allow_reader_connections; ds->gather_perf_metrics = ds_source->gather_perf_metrics; @@ -2109,4 +2108,24 @@ void ds_copy(DataSource *ds, DataSource *ds_source) { ds->failure_detection_count = ds_source->failure_detection_count; ds->monitor_disposal_time = ds_source->monitor_disposal_time; ds->failure_detection_timeout = ds_source->failure_detection_timeout; + + /* AWS Authentication */ + if (ds_source->auth_mode != nullptr) { + ds_set_wstrnattr(&ds->auth_mode, ds_source->auth_mode, + sqlwcharlen(ds_source->auth_mode)); + } + if (ds_source->auth_region != nullptr) { + ds_set_wstrnattr(&ds->auth_region, ds_source->auth_region, + sqlwcharlen(ds_source->auth_region)); + } + if (ds_source->auth_host != nullptr) { + ds_set_wstrnattr(&ds->auth_host, ds_source->auth_host, + sqlwcharlen(ds_source->auth_host)); + } + ds->auth_port = ds_source->auth_port; + ds->auth_expiration = ds_source->auth_expiration; + if (ds_source->auth_secret_id != nullptr) { + ds_set_wstrnattr(&ds->auth_secret_id, ds_source->auth_secret_id, + sqlwcharlen(ds_source->auth_secret_id)); + } }