diff --git a/integration/connection_string_builder.cc b/integration/connection_string_builder.cc index d850a0ffd..119209ea0 100644 --- a/integration/connection_string_builder.cc +++ b/integration/connection_string_builder.cc @@ -44,14 +44,16 @@ class ConnectionString { m_connect_timeout(-1), m_network_timeout(-1), m_host_pattern(""), m_enable_failure_detection(true), m_failure_detection_time(-1), m_failure_detection_timeout(-1), m_failure_detection_interval(-1), m_failure_detection_count(-1), m_monitor_disposal_time(-1), - m_read_timeout(-1), m_write_timeout(-1), + m_read_timeout(-1), m_write_timeout(-1), m_auth_mode(""), m_auth_region(""), m_auth_host(""), + m_auth_port(-1), m_auth_expiration(-1), m_secret_id(""), is_set_uid(false), is_set_pwd(false), is_set_db(false), is_set_log_query(false), is_set_allow_reader_connections(false), is_set_multi_statements(false), is_set_enable_cluster_failover(false), is_set_failover_timeout(false), is_set_connect_timeout(false), is_set_network_timeout(false), is_set_host_pattern(false), is_set_enable_failure_detection(false), is_set_failure_detection_time(false), is_set_failure_detection_timeout(false), is_set_failure_detection_interval(false), is_set_failure_detection_count(false), is_set_monitor_disposal_time(false), - is_set_read_timeout(false), is_set_write_timeout(false) {}; + is_set_read_timeout(false), is_set_write_timeout(false), is_set_auth_mode(false), is_set_auth_region(false), + is_set_auth_host(false), is_set_auth_port(false), is_set_auth_expiration(false), is_set_secret_id(false) {}; std::string get_connection_string() const { char conn_in[4096] = "\0"; @@ -115,6 +117,24 @@ class ConnectionString { if (is_set_write_timeout) { length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout); } + if (is_set_auth_mode) { + length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str()); + } + if (is_set_auth_region) { + length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str()); + } + if (is_set_auth_host) { + length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str()); + } + if (is_set_auth_port) { + length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port); + } + if (is_set_auth_expiration) { + length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration); + } + if (is_set_secret_id) { + length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str()); + } snprintf(conn_in + length, sizeof(conn_in) - length, "\0"); std::string connection_string(conn_in); @@ -133,6 +153,8 @@ class ConnectionString { std::string m_host_pattern; bool m_enable_failure_detection; int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, m_monitor_disposal_time, m_read_timeout, m_write_timeout; + std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id; + int m_auth_port, m_auth_expiration; bool is_set_uid, is_set_pwd, is_set_db; bool is_set_log_query, is_set_allow_reader_connections, is_set_multi_statements; @@ -143,6 +165,7 @@ class ConnectionString { bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, is_set_failure_detection_count; bool is_set_monitor_disposal_time; bool is_set_read_timeout, is_set_write_timeout; + bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, is_set_secret_id; void set_dsn(const std::string& dsn) { m_dsn = dsn; @@ -250,6 +273,36 @@ class ConnectionString { m_write_timeout = write_timeout; is_set_write_timeout = true; } + + void set_auth_mode(const std::string& auth_mode) { + m_auth_mode = auth_mode; + is_set_auth_mode = true; + } + + void set_auth_region(const std::string& auth_region) { + m_auth_region = auth_region; + is_set_auth_region = true; + } + + void set_auth_host(const std::string& auth_host) { + m_auth_host = auth_host; + is_set_auth_host = true; + } + + void set_auth_port(const int& auth_port) { + m_auth_port = auth_port; + is_set_auth_port = true; + } + + void set_auth_expiration(const int& auth_expiration) { + m_auth_expiration = auth_expiration; + is_set_auth_expiration = true; + } + + void set_secret_id(const std::string& secret_id) { + m_secret_id = secret_id; + is_set_secret_id = true; + } }; class ConnectionStringBuilder { @@ -368,6 +421,36 @@ class ConnectionStringBuilder { return *this; } + ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) { + connection_string->set_auth_mode(auth_mode); + return *this; + } + + ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) { + connection_string->set_auth_region(auth_region); + return *this; + } + + ConnectionStringBuilder& withAuthHost(const std::string& auth_host) { + connection_string->set_auth_host(auth_host); + return *this; + } + + ConnectionStringBuilder& withAuthPort(const int& auth_port) { + connection_string->set_auth_port(auth_port); + return *this; + } + + ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) { + connection_string->set_auth_expiration(auth_expiration); + return *this; + } + + ConnectionStringBuilder& withSecretId(const std::string& secret_id) { + connection_string->set_secret_id(secret_id); + return *this; + } + std::string build() const { if (connection_string->m_dsn.empty()) { throw std::runtime_error("DSN is a required field in a connection string."); diff --git a/util/installer.cc b/util/installer.cc index d7a3dd2ee..b23fa5f03 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -236,10 +236,10 @@ static SQLWCHAR W_SSL_CRLPATH[] = /* AWS Authentication */ static SQLWCHAR W_AUTH_MODE[] = { 'A', 'U', 'T', 'H', 'E', 'N', 'T', 'I', 'C', 'A', 'T', 'I', 'O', 'N', '_', 'M', 'O', 'D', 'E', 0}; -static SQLWCHAR W_AUTH_REGION[] = { 'I', 'A', 'M', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 }; +static SQLWCHAR W_AUTH_REGION[] = { 'A', 'W', 'S', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 }; static SQLWCHAR W_AUTH_HOST[] = { 'I', 'A', 'M', '_', 'H', 'O', 'S', 'T', 0 }; static SQLWCHAR W_AUTH_PORT[] = { 'I', 'A', 'M', '_', 'P', 'O', 'R', 'T', 0 }; -static SQLWCHAR W_AUTH_EXPIRATION[] = { 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 }; +static SQLWCHAR W_AUTH_EXPIRATION[] = { 'I', 'A', 'M', '_', 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 }; static SQLWCHAR W_AUTH_SECRET_ID[] = { 'S', 'E', 'C', 'R', 'E', 'T', '_', 'I', 'D', 0 }; /* Failover */