Skip to content

Commit

Permalink
Add Authentication Parameters to ConnectionStringBuilder (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
justing-bq committed Apr 28, 2023
1 parent 10db71f commit 0796afa
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 4 deletions.
87 changes: 85 additions & 2 deletions integration/connection_string_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ class ConnectionString {
m_connect_timeout(-1), m_network_timeout(-1), m_host_pattern(""),
m_enable_failure_detection(true), m_failure_detection_time(-1), m_failure_detection_timeout(-1),
m_failure_detection_interval(-1), m_failure_detection_count(-1), m_monitor_disposal_time(-1),
m_read_timeout(-1), m_write_timeout(-1),
m_read_timeout(-1), m_write_timeout(-1), m_auth_mode(""), m_auth_region(""), m_auth_host(""),
m_auth_port(-1), m_auth_expiration(-1), m_secret_id(""),

is_set_uid(false), is_set_pwd(false), is_set_db(false), is_set_log_query(false),
is_set_allow_reader_connections(false), is_set_multi_statements(false), is_set_enable_cluster_failover(false),
is_set_failover_timeout(false), is_set_connect_timeout(false), is_set_network_timeout(false), is_set_host_pattern(false),
is_set_enable_failure_detection(false), is_set_failure_detection_time(false), is_set_failure_detection_timeout(false),
is_set_failure_detection_interval(false), is_set_failure_detection_count(false), is_set_monitor_disposal_time(false),
is_set_read_timeout(false), is_set_write_timeout(false) {};
is_set_read_timeout(false), is_set_write_timeout(false), is_set_auth_mode(false), is_set_auth_region(false),
is_set_auth_host(false), is_set_auth_port(false), is_set_auth_expiration(false), is_set_secret_id(false) {};

std::string get_connection_string() const {
char conn_in[4096] = "\0";
Expand Down Expand Up @@ -115,6 +117,24 @@ class ConnectionString {
if (is_set_write_timeout) {
length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout);
}
if (is_set_auth_mode) {
length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str());
}
if (is_set_auth_region) {
length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str());
}
if (is_set_auth_host) {
length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str());
}
if (is_set_auth_port) {
length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port);
}
if (is_set_auth_expiration) {
length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration);
}
if (is_set_secret_id) {
length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str());
}
snprintf(conn_in + length, sizeof(conn_in) - length, "\0");

std::string connection_string(conn_in);
Expand All @@ -133,6 +153,8 @@ class ConnectionString {
std::string m_host_pattern;
bool m_enable_failure_detection;
int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, m_monitor_disposal_time, m_read_timeout, m_write_timeout;
std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id;
int m_auth_port, m_auth_expiration;

bool is_set_uid, is_set_pwd, is_set_db;
bool is_set_log_query, is_set_allow_reader_connections, is_set_multi_statements;
Expand All @@ -143,6 +165,7 @@ class ConnectionString {
bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, is_set_failure_detection_count;
bool is_set_monitor_disposal_time;
bool is_set_read_timeout, is_set_write_timeout;
bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, is_set_secret_id;

void set_dsn(const std::string& dsn) {
m_dsn = dsn;
Expand Down Expand Up @@ -250,6 +273,36 @@ class ConnectionString {
m_write_timeout = write_timeout;
is_set_write_timeout = true;
}

void set_auth_mode(const std::string& auth_mode) {
m_auth_mode = auth_mode;
is_set_auth_mode = true;
}

void set_auth_region(const std::string& auth_region) {
m_auth_region = auth_region;
is_set_auth_region = true;
}

void set_auth_host(const std::string& auth_host) {
m_auth_host = auth_host;
is_set_auth_host = true;
}

void set_auth_port(const int& auth_port) {
m_auth_port = auth_port;
is_set_auth_port = true;
}

void set_auth_expiration(const int& auth_expiration) {
m_auth_expiration = auth_expiration;
is_set_auth_expiration = true;
}

void set_secret_id(const std::string& secret_id) {
m_secret_id = secret_id;
is_set_secret_id = true;
}
};

class ConnectionStringBuilder {
Expand Down Expand Up @@ -368,6 +421,36 @@ class ConnectionStringBuilder {
return *this;
}

ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) {
connection_string->set_auth_mode(auth_mode);
return *this;
}

ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) {
connection_string->set_auth_region(auth_region);
return *this;
}

ConnectionStringBuilder& withAuthHost(const std::string& auth_host) {
connection_string->set_auth_host(auth_host);
return *this;
}

ConnectionStringBuilder& withAuthPort(const int& auth_port) {
connection_string->set_auth_port(auth_port);
return *this;
}

ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) {
connection_string->set_auth_expiration(auth_expiration);
return *this;
}

ConnectionStringBuilder& withSecretId(const std::string& secret_id) {
connection_string->set_secret_id(secret_id);
return *this;
}

std::string build() const {
if (connection_string->m_dsn.empty()) {
throw std::runtime_error("DSN is a required field in a connection string.");
Expand Down
4 changes: 2 additions & 2 deletions util/installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,10 @@ static SQLWCHAR W_SSL_CRLPATH[] =

/* AWS Authentication */
static SQLWCHAR W_AUTH_MODE[] = { 'A', 'U', 'T', 'H', 'E', 'N', 'T', 'I', 'C', 'A', 'T', 'I', 'O', 'N', '_', 'M', 'O', 'D', 'E', 0};
static SQLWCHAR W_AUTH_REGION[] = { 'I', 'A', 'M', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 };
static SQLWCHAR W_AUTH_REGION[] = { 'A', 'W', 'S', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 };
static SQLWCHAR W_AUTH_HOST[] = { 'I', 'A', 'M', '_', 'H', 'O', 'S', 'T', 0 };
static SQLWCHAR W_AUTH_PORT[] = { 'I', 'A', 'M', '_', 'P', 'O', 'R', 'T', 0 };
static SQLWCHAR W_AUTH_EXPIRATION[] = { 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 };
static SQLWCHAR W_AUTH_EXPIRATION[] = { 'I', 'A', 'M', '_', 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 };
static SQLWCHAR W_AUTH_SECRET_ID[] = { 'S', 'E', 'C', 'R', 'E', 'T', '_', 'I', 'D', 0 };

/* Failover */
Expand Down

0 comments on commit 0796afa

Please sign in to comment.