Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IAM Authentication Integration Tests #126

Merged
merged 2 commits into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions driver/iam_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ IAM_PROXY::~IAM_PROXY() {
bool IAM_PROXY::connect(const char* host, const char* user, const char* password,
const char* database, unsigned int port, const char* socket, unsigned long flags) {

const char* auth_host = host;
if (ds->auth_host) {
host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8);
auth_host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8);
}

const char* region;
Expand All @@ -76,12 +77,10 @@ bool IAM_PROXY::connect(const char* host, const char* user, const char* password
}
else {
// Go with default region if region is not provided.
region = "us-east-1";
region = Aws::Region::US_EAST_1;
}

port = ds->auth_port;

std::string auth_token = this->get_auth_token(host, region, port, user, ds->auth_expiration);
std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, user, ds->auth_expiration);

bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags);
if (!connect_result) {
Expand Down
11 changes: 9 additions & 2 deletions integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,15 @@ endif()

enable_testing()

set(TEST_SOURCES connection_string_builder.cc base_failover_integration_test.cc connection_string_builder_test.cc)
set(INTEGRATION_TESTS secrets_manager_integration_test.cc network_failover_integration_test.cc failover_integration_test.cc)
set(TEST_SOURCES
connection_string_builder.cc
base_failover_integration_test.cc
connection_string_builder_test.cc)
set(INTEGRATION_TESTS
iam_authentication_integration_test.cc
secrets_manager_integration_test.cc
network_failover_integration_test.cc
failover_integration_test.cc)

if(NOT ENABLE_PERFORMANCE_TESTS)
set(TEST_SOURCES ${TEST_SOURCES} ${INTEGRATION_TESTS})
Expand Down
318 changes: 318 additions & 0 deletions integration/iam_authentication_integration_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include <gtest/gtest.h>
#include <httplib.h>
#include <sql.h>
#include <sqlext.h>

#include "connection_string_builder.cc"

#define MAX_NAME_LEN 255
#define AS_SQLCHAR(str) const_cast<SQLCHAR*>(reinterpret_cast<const SQLCHAR*>(str))

// Connection string parameters
static char* test_dsn;
static char* test_db;
static char* test_user;
static char* test_pwd;
static char* iam_user;

static std::string test_endpoint;

static std::string host_to_IP(std::string hostname) {
int status;
struct addrinfo hints;
struct addrinfo* servinfo;
struct addrinfo* p;
char ipstr[INET_ADDRSTRLEN];

memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET; //IPv4
hints.ai_socktype = SOCK_STREAM;

if ((status = getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) != 0) {
ADD_FAILURE() << "The IP address of host " << hostname << " could not be determined."
<< "getaddrinfo error:" << gai_strerror(status);
return {};
}

for (p = servinfo; p != NULL; p = p->ai_next) {
void* addr;

struct sockaddr_in* ipv4 = (struct sockaddr_in*)p->ai_addr;
addr = &(ipv4->sin_addr);
inet_ntop(p->ai_family, addr, ipstr, sizeof(ipstr));
}

freeaddrinfo(servinfo);
return std::string(ipstr);
}

class IamAuthenticationIntegrationTest : public testing::Test {
protected:
ConnectionStringBuilder builder;
SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

static void SetUpTestSuite() {
test_endpoint = std::getenv("TEST_SERVER");

test_dsn = std::getenv("TEST_DSN");
test_db = std::getenv("TEST_DATABASE");
test_user = std::getenv("TEST_UID");
test_pwd = std::getenv("TEST_PASSWORD");
iam_user = "john_doe";

auto conn_str_builder = ConnectionStringBuilder();
auto conn_str = conn_str_builder
.withDSN(test_dsn)
.withServer(test_endpoint)
.withUID(test_user)
.withPWD(test_pwd)
.withPort(3306)
.withDatabase(test_db).build();

SQLHENV env1 = nullptr;
SQLHDBC dbc1 = nullptr;
SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env1);
SQLSetEnvAttr(env1, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env1, &dbc1);

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
EXPECT_EQ(SQL_SUCCESS,
SQLDriverConnect(dbc1, nullptr, AS_SQLCHAR(conn_str.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

SQLHSTMT stmt = nullptr;
EXPECT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, dbc1, &stmt));

char query_buffer[200];
sprintf(query_buffer, "DROP USER IF EXISTS %s;", iam_user);
SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS);
memset(query_buffer, 0, sizeof(query_buffer));

sprintf(query_buffer, "CREATE USER %s IDENTIFIED WITH AWSAuthenticationPlugin AS 'RDS';", iam_user);
EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS));
memset(query_buffer, 0, sizeof(query_buffer));

sprintf(query_buffer, "GRANT ALL ON `%`.* TO %s@`%`;", iam_user);
EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS));

EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt));
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc1));

if (nullptr != stmt) {
SQLFreeHandle(SQL_HANDLE_STMT, stmt);
}
if (nullptr != dbc1) {
SQLDisconnect(dbc1);
SQLFreeHandle(SQL_HANDLE_DBC, dbc1);
}
if (nullptr != env1) {
SQLFreeHandle(SQL_HANDLE_ENV, env1);
}
}

static void TearDownTestSuite() {
}

void SetUp() override {
SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env);
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc);

builder = ConnectionStringBuilder();
builder
.withDSN(test_dsn)
.withDatabase(test_db)
.withEnableClusterFailover(false) // Failover interferes with some of our tests
.withAuthMode("IAM")
.withAuthRegion("us-east-2")
.withAuthExpiration(900);
}

void TearDown() override {
if (nullptr != dbc) {
SQLFreeHandle(SQL_HANDLE_DBC, dbc);
}
if (nullptr != env) {
SQLFreeHandle(SQL_HANDLE_ENV, env);
}
}
};

// Tests a simple IAM connection with all expected fields provided.
TEST_F(IamAuthenticationIntegrationTest, SimpleIamConnection) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPort(3306)
.withAuthPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect to the provided server
// when the Auth host is not provided.
TEST_F(IamAuthenticationIntegrationTest, ServerWithNoAuthHost) {
auto connection_string = builder
.withServer(test_endpoint)
.withUID(iam_user)
.withPort(3306)
.withAuthPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect via the provided port
// when the auth port is not provided.
TEST_F(IamAuthenticationIntegrationTest, PortWithNoAuthPort) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect
// when given an IP address instead of a cluster name.
TEST_F(IamAuthenticationIntegrationTest, ConnectToIpAddress) {
auto ip_address = host_to_IP(test_endpoint);

auto connection_string = builder
.withServer(ip_address)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that IAM connection will still connect
// when given a wrong password (because the password gets replaced by the auth token).
TEST_F(IamAuthenticationIntegrationTest, WrongPassword) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID(iam_user)
.withPWD("WRONG_PASSWORD")
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_SUCCESS, rc);

rc = SQLDisconnect(dbc);
EXPECT_EQ(SQL_SUCCESS, rc);
}

// Tests that the IAM connection will fail when provided a wrong user.
TEST_F(IamAuthenticationIntegrationTest, WrongUser) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID("WRONG_USER")
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_ERROR, rc);

SQLSMALLINT stmt_length;
SQLINTEGER native_err;
SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0";
rc = SQLError(nullptr, dbc, nullptr, state, &native_err,
msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length);
EXPECT_EQ(SQL_SUCCESS, rc);

const std::string state_str = reinterpret_cast<char*>(state);
EXPECT_EQ("HY000", state_str);
}

// Tests that the IAM connection will fail when provided an empty user.
TEST_F(IamAuthenticationIntegrationTest, EmptyUser) {
auto connection_string = builder
.withServer(test_endpoint)
.withAuthHost(test_endpoint)
.withUID("")
.withPort(3306).build();

SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS,
conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT);
EXPECT_EQ(SQL_ERROR, rc);

SQLSMALLINT stmt_length;
SQLINTEGER native_err;
SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0";
rc = SQLError(nullptr, dbc, nullptr, state, &native_err,
msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length);
EXPECT_EQ(SQL_SUCCESS, rc);

const std::string state_str = reinterpret_cast<char*>(state);
EXPECT_EQ("HY000", state_str);
}
4 changes: 2 additions & 2 deletions util/installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,8 @@ DataSource *ds_new()
ds->port = 3306;
ds->has_port = false;
ds->no_schema = 1;
ds->auth_port = 0;
ds->auth_expiration = 0;
ds->auth_port = 3306;
ds->auth_expiration = 900; // 15 minutes
ds->enable_cluster_failover = true;
ds->allow_reader_connections = false;
ds->gather_perf_metrics = false;
Expand Down