Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Authentication Parameters to ConnectionStringBuilder #124

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions integration/connection_string_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ class ConnectionString {
m_connect_timeout(-1), m_network_timeout(-1), m_host_pattern(""),
m_enable_failure_detection(true), m_failure_detection_time(-1), m_failure_detection_timeout(-1),
m_failure_detection_interval(-1), m_failure_detection_count(-1), m_monitor_disposal_time(-1),
m_read_timeout(-1), m_write_timeout(-1),
m_read_timeout(-1), m_write_timeout(-1), m_auth_mode(""), m_auth_region(""), m_auth_host(""),
m_auth_port(-1), m_auth_expiration(-1), m_secret_id(""),

is_set_uid(false), is_set_pwd(false), is_set_db(false), is_set_log_query(false),
is_set_allow_reader_connections(false), is_set_multi_statements(false), is_set_enable_cluster_failover(false),
is_set_failover_timeout(false), is_set_connect_timeout(false), is_set_network_timeout(false), is_set_host_pattern(false),
is_set_enable_failure_detection(false), is_set_failure_detection_time(false), is_set_failure_detection_timeout(false),
is_set_failure_detection_interval(false), is_set_failure_detection_count(false), is_set_monitor_disposal_time(false),
is_set_read_timeout(false), is_set_write_timeout(false) {};
is_set_read_timeout(false), is_set_write_timeout(false), is_set_auth_mode(false), is_set_auth_region(false),
is_set_auth_host(false), is_set_auth_port(false), is_set_auth_expiration(false), is_set_secret_id(false) {};

std::string get_connection_string() const {
char conn_in[4096] = "\0";
Expand Down Expand Up @@ -115,6 +117,24 @@ class ConnectionString {
if (is_set_write_timeout) {
length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout);
}
if (is_set_auth_mode) {
length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str());
}
if (is_set_auth_region) {
length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str());
}
if (is_set_auth_host) {
length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str());
}
if (is_set_auth_port) {
length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port);
}
if (is_set_auth_expiration) {
length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration);
}
if (is_set_secret_id) {
length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str());
}
snprintf(conn_in + length, sizeof(conn_in) - length, "\0");

std::string connection_string(conn_in);
Expand All @@ -133,6 +153,8 @@ class ConnectionString {
std::string m_host_pattern;
bool m_enable_failure_detection;
int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, m_monitor_disposal_time, m_read_timeout, m_write_timeout;
std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id;
int m_auth_port, m_auth_expiration;

bool is_set_uid, is_set_pwd, is_set_db;
bool is_set_log_query, is_set_allow_reader_connections, is_set_multi_statements;
Expand All @@ -143,6 +165,7 @@ class ConnectionString {
bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, is_set_failure_detection_count;
bool is_set_monitor_disposal_time;
bool is_set_read_timeout, is_set_write_timeout;
bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, is_set_secret_id;

void set_dsn(const std::string& dsn) {
m_dsn = dsn;
Expand Down Expand Up @@ -250,6 +273,36 @@ class ConnectionString {
m_write_timeout = write_timeout;
is_set_write_timeout = true;
}

void set_auth_mode(const std::string& auth_mode) {
m_auth_mode = auth_mode;
is_set_auth_mode = true;
}

void set_auth_region(const std::string& auth_region) {
m_auth_region = auth_region;
is_set_auth_region = true;
}

void set_auth_host(const std::string& auth_host) {
m_auth_host = auth_host;
is_set_auth_host = true;
}

void set_auth_port(const int& auth_port) {
m_auth_port = auth_port;
is_set_auth_port = true;
}

void set_auth_expiration(const int& auth_expiration) {
m_auth_expiration = auth_expiration;
is_set_auth_expiration = true;
}

void set_secret_id(const std::string& secret_id) {
m_secret_id = secret_id;
is_set_secret_id = true;
}
};

class ConnectionStringBuilder {
Expand Down Expand Up @@ -368,6 +421,36 @@ class ConnectionStringBuilder {
return *this;
}

ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) {
connection_string->set_auth_mode(auth_mode);
return *this;
}

ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) {
connection_string->set_auth_region(auth_region);
return *this;
}

ConnectionStringBuilder& withAuthHost(const std::string& auth_host) {
connection_string->set_auth_host(auth_host);
return *this;
}

ConnectionStringBuilder& withAuthPort(const int& auth_port) {
connection_string->set_auth_port(auth_port);
return *this;
}

ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) {
connection_string->set_auth_expiration(auth_expiration);
return *this;
}

ConnectionStringBuilder& withSecretId(const std::string& secret_id) {
connection_string->set_secret_id(secret_id);
return *this;
}

std::string build() const {
if (connection_string->m_dsn.empty()) {
throw std::runtime_error("DSN is a required field in a connection string.");
Expand Down
4 changes: 2 additions & 2 deletions util/installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,10 @@ static SQLWCHAR W_SSL_CRLPATH[] =

/* AWS Authentication */
static SQLWCHAR W_AUTH_MODE[] = { 'A', 'U', 'T', 'H', 'E', 'N', 'T', 'I', 'C', 'A', 'T', 'I', 'O', 'N', '_', 'M', 'O', 'D', 'E', 0};
static SQLWCHAR W_AUTH_REGION[] = { 'I', 'A', 'M', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 };
static SQLWCHAR W_AUTH_REGION[] = { 'A', 'W', 'S', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 };
static SQLWCHAR W_AUTH_HOST[] = { 'I', 'A', 'M', '_', 'H', 'O', 'S', 'T', 0 };
static SQLWCHAR W_AUTH_PORT[] = { 'I', 'A', 'M', '_', 'P', 'O', 'R', 'T', 0 };
static SQLWCHAR W_AUTH_EXPIRATION[] = { 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 };
static SQLWCHAR W_AUTH_EXPIRATION[] = { 'I', 'A', 'M', '_', 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 };
static SQLWCHAR W_AUTH_SECRET_ID[] = { 'S', 'E', 'C', 'R', 'E', 'T', '_', 'I', 'D', 0 };

/* Failover */
Expand Down