Skip to content

Commit

Permalink
IAM Authentication Integration Tests (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
justing-bq committed Apr 28, 2023
1 parent 3d562d9 commit dfbbbe0
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 9 deletions.
9 changes: 4 additions & 5 deletions driver/iam_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ IAM_PROXY::~IAM_PROXY() {
bool IAM_PROXY::connect(const char* host, const char* user, const char* password,
const char* database, unsigned int port, const char* socket, unsigned long flags) {

const char* auth_host = host;
if (ds->auth_host) {
host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8);
auth_host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8);
}

const char* region;
Expand All @@ -76,12 +77,10 @@ bool IAM_PROXY::connect(const char* host, const char* user, const char* password
}
else {
// Go with default region if region is not provided.
region = "us-east-1";
region = Aws::Region::US_EAST_1;
}

port = ds->auth_port;

std::string auth_token = this->get_auth_token(host, region, port, user, ds->auth_expiration);
std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, user, ds->auth_expiration);

bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags);
if (!connect_result) {
Expand Down
11 changes: 9 additions & 2 deletions integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,15 @@ endif()

enable_testing()

set(TEST_SOURCES connection_string_builder.cc base_failover_integration_test.cc connection_string_builder_test.cc)
set(INTEGRATION_TESTS secrets_manager_integration_test.cc network_failover_integration_test.cc failover_integration_test.cc)
set(TEST_SOURCES
connection_string_builder.cc
base_failover_integration_test.cc
connection_string_builder_test.cc)
set(INTEGRATION_TESTS
iam_authentication_integration_test.cc
secrets_manager_integration_test.cc
network_failover_integration_test.cc
failover_integration_test.cc)

if(NOT ENABLE_PERFORMANCE_TESTS)
set(TEST_SOURCES ${TEST_SOURCES} ${INTEGRATION_TESTS})
Expand Down
318 changes: 318 additions & 0 deletions integration/iam_authentication_integration_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include <gtest/gtest.h>
#include <httplib.h>
#include <sql.h>
#include <sqlext.h>

#include "connection_string_builder.cc"

#define MAX_NAME_LEN 255
#define AS_SQLCHAR(str) const_cast<SQLCHAR*>(reinterpret_cast<const SQLCHAR*>(str))

// Connection string parameters
static char* test_dsn;
static char* test_db;
static char* test_user;
static char* test_pwd;
static char* iam_user;

static std::string test_endpoint;

static std::string host_to_IP(std::string hostname) {
int status;
struct addrinfo hints;
struct addrinfo* servinfo;
struct addrinfo* p;
char ipstr[INET_ADDRSTRLEN];

memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET; //IPv4
hints.ai_socktype = SOCK_STREAM;

if ((status = getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) != 0) {
ADD_FAILURE() << "The IP address of host " << hostname << " could not be determined."
<< "getaddrinfo error:" << gai_strerror(status);
return {};
}

for (p = servinfo; p != NULL; p = p->ai_next) {
void* addr;

struct sockaddr_in* ipv4 = (struct sockaddr_in*)p->ai_addr;
addr = &(ipv4->sin_addr);
inet_ntop(p->ai_family, addr, ipstr, sizeof(ipstr));
}

freeaddrinfo(servinfo);
return std::string(ipstr);
}

class IamAuthenticationIntegrationTest : public testing::Test {
protected:
ConnectionStringBuilder builder;
SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

static void SetUpTestSuite() {
test_endpoint = std::getenv("TEST_SERVER");

test_dsn = std::getenv("TEST_DSN");
test_db = std::getenv("TEST_DATABASE");
test_user = std::getenv("TEST_UID");
test_pwd = std::getenv("TEST_PASSWORD");
iam_user = "john_doe";

auto conn_str_builder = ConnectionStringBuilder();
auto conn_str = conn_str_builder
.withDSN(test_dsn)
.withServer(test_endpoint)
.withUID(test_user)
.withPWD(test_pwd)
.withPort(3306)
.withDatabase(test_db).build();

SQLHENV env1 = nullptr;
SQLHDBC dbc1 = nullptr;
SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env1);
SQLSetEnvAttr(env1, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env1, &dbc1);

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
EXPECT_EQ(SQL_SUCCESS,
SQLDriverConnect(dbc1, nullptr, AS_SQLCHAR(conn_str.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

SQLHSTMT stmt = nullptr;
EXPECT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, dbc1, &stmt));

char query_buffer[200];
sprintf(query_buffer, "DROP USER IF EXISTS %s;", iam_user);
SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS);
memset(query_buffer, 0, sizeof(query_buffer));

sprintf(query_buffer, "CREATE USER %s IDENTIFIED WITH AWSAuthenticationPlugin AS 'RDS';", iam_user);
EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS));
memset(query_buffer, 0, sizeof(query_buffer));

sprintf(query_buffer, "GRANT ALL ON `%`.* TO %s@`%`;", iam_user);
EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS));

EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt));
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc1));

if (nullptr != stmt) {
SQLFreeHandle(SQL_HANDLE_STMT, stmt);
}
if (nullptr != dbc1) {
SQLDisconnect(dbc1);
SQLFreeHandle(SQL_HANDLE_DBC, dbc1);
}
if (nullptr != env1) {
SQLFreeHandle(SQL_HANDLE_ENV, env1);
}
}

static void TearDownTestSuite() {
}

void SetUp() override {
SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env);
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc);

builder = ConnectionStringBuilder();
builder
.withDSN(test_dsn)
.withDatabase(test_db)
.withEnableClusterFailover(false) // Failover interferes with some of our tests
.withAuthMode("IAM")
.withAuthRegion("us-east-2")
.withAuthExpiration(900);
}

void TearDown() override {
if (nullptr != dbc) {
SQLFreeHandle(SQL_HANDLE_DBC, dbc);
}
if (nullptr != env) {
SQLFreeHandle(SQL_HANDLE_ENV, env);
}
}
};

// Tests a simple IAM connection with all expected fields provided.
TEST_F(IamAuthenticationIntegrationTest, SimpleIamConnection) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPort(3306)
.withAuthPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect to the provided server
// when the Auth host is not provided.
TEST_F(IamAuthenticationIntegrationTest, ServerWithNoAuthHost) {
auto connection_string = builder
.withServer(test_endpoint)
.withUID(iam_user)
.withPort(3306)
.withAuthPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect via the provided port
// when the auth port is not provided.
TEST_F(IamAuthenticationIntegrationTest, PortWithNoAuthPort) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect
// when given an IP address instead of a cluster name.
TEST_F(IamAuthenticationIntegrationTest, ConnectToIpAddress) {
auto ip_address = host_to_IP(test_endpoint);

auto connection_string = builder
.withServer(ip_address)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect
// when given a wrong password (because the password gets replaced by the auth token).
TEST_F(IamAuthenticationIntegrationTest, WrongPassword) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPWD("WRONG_PASSWORD")
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that the IAM connection will fail when provided a wrong user.
TEST_F(IamAuthenticationIntegrationTest, WrongUser) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID("WRONG_USER")
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_ERROR, rc);

SQLSMALLINT stmt_length;
SQLINTEGER native_err;
SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0";
rc = SQLError(nullptr, dbc, nullptr, state, &native_err,
msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length);
EXPECT_EQ(SQL_SUCCESS, rc);

const std::string state_str = reinterpret_cast<char*>(state);
EXPECT_EQ("HY000", state_str);
}

// Tests that the IAM connection will fail when provided an empty user.
TEST_F(IamAuthenticationIntegrationTest, EmptyUser) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID("")
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_ERROR, rc);

SQLSMALLINT stmt_length;
SQLINTEGER native_err;
SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0";
rc = SQLError(nullptr, dbc, nullptr, state, &native_err,
msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length);
EXPECT_EQ(SQL_SUCCESS, rc);

const std::string state_str = reinterpret_cast<char*>(state);
EXPECT_EQ("HY000", state_str);
}
4 changes: 2 additions & 2 deletions util/installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,8 @@ DataSource *ds_new()
ds->port = 3306;
ds->has_port = false;
ds->no_schema = 1;
ds->auth_port = 0;
ds->auth_expiration = 0;
ds->auth_port = 3306;
ds->auth_expiration = 900; // 15 minutes
ds->enable_cluster_failover = true;
ds->allow_reader_connections = false;
ds->gather_perf_metrics = false;
Expand Down

0 comments on commit dfbbbe0

Please sign in to comment.