diff --git a/driver/iam_proxy.cc b/driver/iam_proxy.cc index 79b8076ac..b633e7921 100644 --- a/driver/iam_proxy.cc +++ b/driver/iam_proxy.cc @@ -66,8 +66,9 @@ IAM_PROXY::~IAM_PROXY() { bool IAM_PROXY::connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, const char* socket, unsigned long flags) { + const char* auth_host = host; if (ds->auth_host) { - host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8); + auth_host = ds_get_utf8attr(ds->auth_host, &ds->auth_host8); } const char* region; @@ -76,12 +77,10 @@ bool IAM_PROXY::connect(const char* host, const char* user, const char* password } else { // Go with default region if region is not provided. - region = "us-east-1"; + region = Aws::Region::US_EAST_1; } - port = ds->auth_port; - - std::string auth_token = this->get_auth_token(host, region, port, user, ds->auth_expiration); + std::string auth_token = this->get_auth_token(auth_host, region, ds->auth_port, user, ds->auth_expiration); bool connect_result = next_proxy->connect(host, user, auth_token.c_str(), database, port, socket, flags); if (!connect_result) { diff --git a/integration/CMakeLists.txt b/integration/CMakeLists.txt index 29b68116d..b46000417 100644 --- a/integration/CMakeLists.txt +++ b/integration/CMakeLists.txt @@ -92,7 +92,7 @@ endif() enable_testing() set(TEST_SOURCES connection_string_builder.cc base_failover_integration_test.cc connection_string_builder_test.cc) -set(INTEGRATION_TESTS network_failover_integration_test.cc failover_integration_test.cc) +set(INTEGRATION_TESTS iam_authentication_integration_test.cc network_failover_integration_test.cc failover_integration_test.cc) if(NOT ENABLE_PERFORMANCE_TESTS) set(TEST_SOURCES ${TEST_SOURCES} ${INTEGRATION_TESTS}) diff --git a/integration/iam_authentication_integration_test.cc b/integration/iam_authentication_integration_test.cc new file mode 100644 index 000000000..f9a89ad71 --- /dev/null +++ b/integration/iam_authentication_integration_test.cc @@ -0,0 +1,362 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include +#include + +#include "connection_string_builder.cc" + +#define MAX_NAME_LEN 255 +#define AS_SQLCHAR(str) const_cast(reinterpret_cast(str)) + +// Connection string parameters +static char* test_dsn; +static char* test_db; +static char* test_user; +static char* test_pwd; +static char* iam_user; + +static std::string test_endpoint; + +static std::string host_to_IP(std::string hostname) { + int status; + struct addrinfo hints; + struct addrinfo* servinfo; + struct addrinfo* p; + char ipstr[INET_ADDRSTRLEN]; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; //IPv4 + hints.ai_socktype = SOCK_STREAM; + + if ((status = getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) != 0) { + ADD_FAILURE() << "The IP address of host " << hostname << " could not be determined." + << "getaddrinfo error:" << gai_strerror(status); + return {}; + } + + for (p = servinfo; p != NULL; p = p->ai_next) { + void* addr; + + struct sockaddr_in* ipv4 = (struct sockaddr_in*)p->ai_addr; + addr = &(ipv4->sin_addr); + inet_ntop(p->ai_family, addr, ipstr, sizeof(ipstr)); + } + + freeaddrinfo(servinfo); + return std::string(ipstr); +} + +class IamAuthenticationIntegrationTest : public testing::Test { +protected: + ConnectionStringBuilder builder; + SQLHENV env = nullptr; + SQLHDBC dbc = nullptr; + + static void SetUpTestSuite() { + test_endpoint = std::getenv("TEST_SERVER"); + + test_dsn = std::getenv("TEST_DSN"); + test_db = std::getenv("TEST_DATABASE"); + test_user = std::getenv("TEST_UID"); + test_pwd = std::getenv("TEST_PASSWORD"); + iam_user = "john_doe"; + + auto conn_str_builder = ConnectionStringBuilder(); + auto conn_str = conn_str_builder + .withDSN(test_dsn) + .withServer(test_endpoint) + .withUID(test_user) + .withPWD(test_pwd) + .withPort(3306) + .withDatabase(test_db).build(); + + SQLHENV env1 = nullptr; + SQLHDBC dbc1 = nullptr; + SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env1); + SQLSetEnvAttr(env1, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC3), 0); + SQLAllocHandle(SQL_HANDLE_DBC, env1, &dbc1); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + EXPECT_EQ(SQL_SUCCESS, + SQLDriverConnect(dbc1, nullptr, AS_SQLCHAR(conn_str.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT)); + + SQLHSTMT stmt = nullptr; + EXPECT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, dbc1, &stmt)); + + char query_buffer[200]; + sprintf(query_buffer, "DROP USER IF EXISTS %s;", iam_user); + SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS); + memset(query_buffer, 0, sizeof(query_buffer)); + + sprintf(query_buffer, "CREATE USER %s IDENTIFIED WITH AWSAuthenticationPlugin AS 'RDS';", iam_user); + EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS)); + memset(query_buffer, 0, sizeof(query_buffer)); + + sprintf(query_buffer, "GRANT ALL ON `%`.* TO %s@`%`;", iam_user); + EXPECT_EQ(SQL_SUCCESS, SQLExecDirect(stmt, AS_SQLCHAR(query_buffer), SQL_NTS)); + + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc1)); + + if (nullptr != stmt) { + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + } + if (nullptr != dbc1) { + SQLDisconnect(dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + } + if (nullptr != env1) { + SQLFreeHandle(SQL_HANDLE_ENV, env1); + } + } + + static void TearDownTestSuite() { + } + + void SetUp() override { + SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env); + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC3), 0); + SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); + + builder = ConnectionStringBuilder(); + builder + .withDSN(test_dsn) + .withDatabase(test_db) + .withEnableClusterFailover(false) // Failover interferes with some of our tests + .withAuthMode("IAM") + .withAuthRegion("us-east-2") + .withAuthExpiration(900); + } + + void TearDown() override { + if (nullptr != dbc) { + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + } + if (nullptr != env) { + SQLFreeHandle(SQL_HANDLE_ENV, env); + } + } +}; + +// Tests a simple IAM connection with all expected fields provided. +TEST_F(IamAuthenticationIntegrationTest, SimpleIamConnection) { + auto connection_string = builder + .withServer(test_endpoint) + .withAuthHost(test_endpoint) + .withUID(iam_user) + .withPort(3306) + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_SUCCESS, rc); + + rc = SQLDisconnect(dbc); + EXPECT_EQ(SQL_SUCCESS, rc); +} + +// Tests that IAM connection will still connect to the provided server +// when the Auth host is not provided. +TEST_F(IamAuthenticationIntegrationTest, ServerWithNoAuthHost) { + auto connection_string = builder + .withServer(test_endpoint) + .withUID(iam_user) + .withPort(3306) + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_SUCCESS, rc); + + rc = SQLDisconnect(dbc); + EXPECT_EQ(SQL_SUCCESS, rc); +} + +// Tests that IAM connection will fail when we do not provide the server. +TEST_F(IamAuthenticationIntegrationTest, AuthHostWithNoServer) { + auto connection_string = builder + .withAuthHost(test_endpoint) + .withUID(iam_user) + .withPort(3306) + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_ERROR, rc); + + SQLSMALLINT stmt_length; + SQLINTEGER native_err; + SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0"; + rc = SQLError(nullptr, dbc, nullptr, state, &native_err, + msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length); + EXPECT_EQ(SQL_SUCCESS, rc); + + const std::string state_str = reinterpret_cast(state); + EXPECT_EQ("HY000", state_str); +} + +// Tests that IAM connection will still connect via the provided port +// when the auth port is not provided. +TEST_F(IamAuthenticationIntegrationTest, PortWithNoAuthPort) { + auto connection_string = builder + .withServer(test_endpoint) + .withAuthHost(test_endpoint) + .withUID(iam_user) + .withPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_SUCCESS, rc); + + rc = SQLDisconnect(dbc); + EXPECT_EQ(SQL_SUCCESS, rc); +} + +// Tests that IAM connection will still connect via the provided auth port +// when the regular port is not provided. +TEST_F(IamAuthenticationIntegrationTest, AuthPortWithNoPort) { + auto connection_string = builder + .withServer(test_endpoint) + .withAuthHost(test_endpoint) + .withUID(iam_user) + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_SUCCESS, rc); + + rc = SQLDisconnect(dbc); + EXPECT_EQ(SQL_SUCCESS, rc); +} + +// Tests that IAM connection will still connect +// when given an IP address instead of a cluster name. +TEST_F(IamAuthenticationIntegrationTest, ConnectToIpAddress) { + auto ip_address = host_to_IP(test_endpoint); + + auto connection_string = builder + .withServer(ip_address) + .withAuthHost(test_endpoint) + .withUID(iam_user) + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_SUCCESS, rc); + + rc = SQLDisconnect(dbc); + EXPECT_EQ(SQL_SUCCESS, rc); +} + +// Tests that IAM connection will still connect +// when given a wrong password (because the password gets replaced by the auth token). +TEST_F(IamAuthenticationIntegrationTest, WrongPassword) { + auto connection_string = builder + .withServer(test_endpoint) + .withAuthHost(test_endpoint) + .withUID(iam_user) + .withPWD("WRONG_PASSWORD") + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_SUCCESS, rc); + + rc = SQLDisconnect(dbc); + EXPECT_EQ(SQL_SUCCESS, rc); +} + +// Tests that the IAM connection will fail when provided a wrong user. +TEST_F(IamAuthenticationIntegrationTest, WrongUser) { + auto connection_string = builder + .withServer(test_endpoint) + .withAuthHost(test_endpoint) + .withUID("WRONG_USER") + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_ERROR, rc); + + SQLSMALLINT stmt_length; + SQLINTEGER native_err; + SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0"; + rc = SQLError(nullptr, dbc, nullptr, state, &native_err, + msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length); + EXPECT_EQ(SQL_SUCCESS, rc); + + const std::string state_str = reinterpret_cast(state); + EXPECT_EQ("HY000", state_str); +} + +// Tests that the IAM connection will fail when provided an empty user. +TEST_F(IamAuthenticationIntegrationTest, EmptyUser) { + auto connection_string = builder + .withServer(test_endpoint) + .withAuthHost(test_endpoint) + .withUID("") + .withAuthPort(3306).build(); + + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, + conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT); + EXPECT_EQ(SQL_ERROR, rc); + + SQLSMALLINT stmt_length; + SQLINTEGER native_err; + SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH] = "\0", state[6] = "\0"; + rc = SQLError(nullptr, dbc, nullptr, state, &native_err, + msg, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length); + EXPECT_EQ(SQL_SUCCESS, rc); + + const std::string state_str = reinterpret_cast(state); + EXPECT_EQ("HY000", state_str); +} diff --git a/util/installer.cc b/util/installer.cc index 4a28c67ec..781da0554 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -749,8 +749,8 @@ DataSource *ds_new() ds->port = 3306; ds->has_port = false; ds->no_schema = 1; - ds->auth_port = 0; - ds->auth_expiration = 0; + ds->auth_port = 3306; + ds->auth_expiration = 900; // 15 minutes ds->enable_cluster_failover = true; ds->allow_reader_connections = false; ds->gather_perf_metrics = false;