Skip to content

Commit

Permalink
Set up IAM user in ODBC
Browse files Browse the repository at this point in the history
  • Loading branch information
justing-bq committed Mar 30, 2023
1 parent 7d9b843 commit 1d920bc
Showing 1 changed file with 80 additions and 22 deletions.
102 changes: 80 additions & 22 deletions integration/iam_authentication_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,95 @@
class IamAuthenticationIntegrationTest : public BaseFailoverIntegrationTest {
protected:
// Connection string parameters
char* dsn = std::getenv("TEST_DSN");
char* db = std::getenv("TEST_DATABASE");
char* pwd = std::getenv("TEST_PASSWORD");
char* user = get_env_var("TEST_IAM_USER", "john_doe");
static char* dsn;
static char* db;
static char* user;
static char* pwd;
static char* iam_user;

std::string ACCESS_KEY = std::getenv("AWS_ACCESS_KEY_ID");
std::string SECRET_ACCESS_KEY = std::getenv("AWS_SECRET_ACCESS_KEY");
std::string SESSION_TOKEN = std::getenv("AWS_SESSION_TOKEN");
Aws::Auth::AWSCredentials credentials = Aws::Auth::AWSCredentials(
Aws::String(ACCESS_KEY),
Aws::String(SECRET_ACCESS_KEY),
Aws::String(SESSION_TOKEN));
Aws::Client::ClientConfiguration client_config;
Aws::RDS::RDSClient rds_client;
static std::string ACCESS_KEY;
static std::string SECRET_ACCESS_KEY;
static std::string SESSION_TOKEN;
static Aws::Auth::AWSCredentials credentials;
static Aws::Client::ClientConfiguration client_config;
static Aws::RDS::RDSClient rds_client;

static std::string server_endpoint;

ConnectionStringBuilder builder;
std::string connection_string;

std::string server_endpoint;

SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

static void SetUpTestSuite() {
Aws::InitAPI(options);

ACCESS_KEY = std::getenv("AWS_ACCESS_KEY_ID");
SECRET_ACCESS_KEY = std::getenv("AWS_SECRET_ACCESS_KEY");
SESSION_TOKEN = std::getenv("AWS_SESSION_TOKEN");
credentials = Aws::Auth::AWSCredentials(
Aws::String(ACCESS_KEY),
Aws::String(SECRET_ACCESS_KEY),
Aws::String(SESSION_TOKEN));

client_config.region = "us-east-2";
rds_client = Aws::RDS::RDSClient(credentials, client_config);

std::string MYSQL_CLUSTER_URL = std::getenv("TEST_SERVER");
Aws::String cluster_id = MYSQL_CLUSTER_URL.substr(0, MYSQL_CLUSTER_URL.find('.'));
std::string DB_CONN_STR_SUFFIX = std::getenv("DB_CONN_STR_SUFFIX");

auto cluster_instances = retrieve_topology_via_SDK(rds_client, cluster_id);
server_endpoint = get_writer_id(cluster_instances) + DB_CONN_STR_SUFFIX;

dsn = std::getenv("TEST_DSN");
db = std::getenv("TEST_DATABASE");
user = std::getenv("TEST_UID");
pwd = std::getenv("TEST_PASSWORD");
iam_user = get_env_var("TEST_IAM_USER", "john_doe");

auto conn_str_builder = ConnectionStringBuilder();
auto conn_str = conn_str_builder.withDSN(dsn).withServer(server_endpoint).withUID(user).withPWD(pwd).withDatabase(db).build();

SQLHENV env1 = nullptr;
SQLHDBC dbc1 = nullptr;
SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env1);
SQLSetEnvAttr(env1, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env1, &dbc1);

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc1, nullptr, AS_SQLCHAR(conn_str.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

SQLHSTMT stmt = nullptr;
EXPECT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, dbc1, &stmt));

char query_buffer[200];
sprintf(query_buffer, "DROP USER IF EXISTS %s;", iam_user);
SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS);
memset(query_buffer, 0, sizeof(query_buffer));

sprintf(query_buffer, "CREATE USER %s IDENTIFIED WITH AWSAuthenticationPlugin AS 'RDS';", iam_user);
EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS));
memset(query_buffer, 0, sizeof(query_buffer));

sprintf(query_buffer, "GRANT ALL ON `%`.* TO %s@`%`;", iam_user);
EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS));

EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt));
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc1));

if (nullptr != stmt) {
SQLFreeHandle(SQL_HANDLE_STMT, stmt);
}
if (nullptr != dbc1) {
SQLDisconnect(dbc1);
SQLFreeHandle(SQL_HANDLE_DBC, dbc1);
}
if (nullptr != env1) {
SQLFreeHandle(SQL_HANDLE_ENV, env1);
}
}

static void TearDownTestSuite() {
Expand All @@ -68,12 +132,6 @@ class IamAuthenticationIntegrationTest : public BaseFailoverIntegrationTest {
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc);

client_config.region = "us-east-2";
rds_client = Aws::RDS::RDSClient(credentials, client_config);

cluster_instances = retrieve_topology_via_SDK(rds_client, cluster_id);
server_endpoint = get_endpoint(get_writer_id(cluster_instances));

builder = ConnectionStringBuilder();
builder
.withAuthMode("IAM")
Expand All @@ -92,7 +150,7 @@ class IamAuthenticationIntegrationTest : public BaseFailoverIntegrationTest {
};

TEST_F(IamAuthenticationIntegrationTest, SimpleIamConnection) {
connection_string = builder.withDSN(dsn).withAuthHost(server_endpoint).withUID(user).withPWD(pwd).withDatabase(db).build();
connection_string = builder.withDSN(dsn).withAuthHost(server_endpoint).withUID(iam_user).withPWD(pwd).withDatabase(db).build();
SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));
Expand Down

0 comments on commit 1d920bc

Please sign in to comment.