Skip to content

Commit

Permalink
Secrets manager integration test (#127)
Browse files Browse the repository at this point in the history
* create secrets in java util

* simple sm test

* simplify setup

* fix integration

* comment out toxiporxy

* comment out toxiporxy

* sm test not inherit from base anymore

* fix build

* fix set error message, fix dsn not picked up

* add {} in connection string

* fix curly braces

* fix log

* fix build

* more logging

* more logging

* fix test

* better error handling

* uncomment

* address comments

* revert
  • Loading branch information
yanw-bq authored and justing-bq committed May 4, 2023
1 parent b5ae0e5 commit 4d240ad
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/dockerized.yml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ jobs:
- name: 'Display and save log'
if: always()
working-directory: ${{ github.workspace }}/build
working-directory: ${{ github.workspace }}
run: |
echo "Displaying logs"
mkdir -p ./reports/tests
Expand Down
45 changes: 29 additions & 16 deletions driver/secrets_manager_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,25 @@ bool SECRETS_MANAGER_PROXY::connect(const char* host, const char* user, const ch

if (this->secret_key.first.empty()) {
const auto error = "Missing required config parameter for Secrets Manager: Secret ID";
MYLOG_DBC_TRACE(dbc, error);
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error);
this->set_custom_error_message(error);
return false;
}

bool fetched = update_secret(false);
std::string username = get_from_secret_json_value(USERNAME_KEY);
std::string password = get_from_secret_json_value(PASSWORD_KEY);
fetched = false;
if (username.empty() || password.empty()) {
const auto error = "Failed to fetch username or password from Secrets Manager.";
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error);
this->set_custom_error_message(error);
return false;
}
bool ret = next_proxy->connect(host, username.c_str(), password.c_str(), database, port, unix_socket, flags);

if (!ret && next_proxy->error_code() == ER_ACCESS_DENIED_ERROR && !fetched) {
// Login unsuccessful with cached credentials
// Try to re-fetch credentials and try again
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] Login failed with cached credentials");
fetched = update_secret(true);
if (fetched) {
username = get_from_secret_json_value(USERNAME_KEY);
Expand All @@ -115,10 +120,10 @@ bool SECRETS_MANAGER_PROXY::update_secret(bool force_re_fetch) {

const auto search = secrets_cache.find(this->secret_key);
if (search != secrets_cache.end() && !force_re_fetch) {
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] Fetching credentials from cache.");
this->secret_json_value = search->second;
}
else {
this->secret_json_value = fetch_latest_credentials();
else if (fetch_latest_credentials()) {
fetched = true;
secrets_cache[this->secret_key] = this->secret_json_value;
}
Expand All @@ -127,32 +132,39 @@ bool SECRETS_MANAGER_PROXY::update_secret(bool force_re_fetch) {
return fetched;
}

Aws::Utils::Json::JsonValue SECRETS_MANAGER_PROXY::fetch_latest_credentials() const {
bool SECRETS_MANAGER_PROXY::fetch_latest_credentials() {
Aws::String secret_string;
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] Fetching credentials from Secrets Manager Service.");

Model::GetSecretValueRequest request;
request.SetSecretId(this->secret_key.first);
auto get_secret_value_outcome = this->sm_client->GetSecretValue(request);

if (get_secret_value_outcome.IsSuccess()) {
secret_string = get_secret_value_outcome.GetResult().GetSecretString();
}
else {
MYLOG_DBC_TRACE(dbc, get_secret_value_outcome.GetError().GetMessage().c_str());
const auto error_message = get_secret_value_outcome.GetError().GetMessage().c_str();
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error_message);
this->set_custom_error_message(error_message);
return false;
}
return parse_json_value(secret_string);
}

Aws::Utils::Json::JsonValue SECRETS_MANAGER_PROXY::parse_json_value(Aws::String json_string) const {
auto res_json = Aws::Utils::Json::JsonValue(json_string);
bool SECRETS_MANAGER_PROXY::parse_json_value(Aws::String json_string) {
const auto res_json = Aws::Utils::Json::JsonValue(json_string);
if (!res_json.WasParseSuccessful()) {
MYLOG_DBC_TRACE(dbc, res_json.GetErrorMessage().c_str());
throw std::runtime_error("Error parsing secrets manager response body. " + res_json.GetErrorMessage());
const auto error_message = res_json.GetErrorMessage().c_str();
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error_message);
this->set_custom_error_message(error_message);
return false;
}
return res_json;

this->secret_json_value = res_json;
return true;
}

std::string SECRETS_MANAGER_PROXY::get_from_secret_json_value(std::string key) const {
std::string SECRETS_MANAGER_PROXY::get_from_secret_json_value(std::string key) {
std::string value;
const auto view = this->secret_json_value.View();

Expand All @@ -161,8 +173,9 @@ std::string SECRETS_MANAGER_PROXY::get_from_secret_json_value(std::string key) c
}
else {
const auto error = "Unable to extract the " + key + " from secrets manager response.";
MYLOG_DBC_TRACE(dbc, error.c_str());
throw std::runtime_error(error);
MYLOG_DBC_TRACE(dbc, "[SECRETS_MANAGER_PROXY] %s", error.c_str());
this->set_custom_error_message(error.c_str());
return std::string();
}
return value;
}
6 changes: 3 additions & 3 deletions driver/secrets_manager_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class SECRETS_MANAGER_PROXY : public CONNECTION_PROXY {
Aws::Utils::Json::JsonValue secret_json_value;

bool update_secret(bool force_re_fetch);
Aws::Utils::Json::JsonValue fetch_latest_credentials() const;
Aws::Utils::Json::JsonValue parse_json_value(Aws::String json_string) const;
std::string get_from_secret_json_value(std::string key) const;
bool fetch_latest_credentials();
bool parse_json_value(Aws::String json_string);
std::string get_from_secret_json_value(std::string key);

static std::map<std::pair<Aws::String, Aws::String>, Aws::Utils::Json::JsonValue> secrets_cache;
static std::mutex secrets_cache_mutex;
Expand Down
2 changes: 1 addition & 1 deletion integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ endif()
enable_testing()

set(TEST_SOURCES connection_string_builder.cc base_failover_integration_test.cc connection_string_builder_test.cc)
set(INTEGRATION_TESTS network_failover_integration_test.cc failover_integration_test.cc)
set(INTEGRATION_TESTS secrets_manager_integration_test.cc network_failover_integration_test.cc failover_integration_test.cc)

if(NOT ENABLE_PERFORMANCE_TESTS)
set(TEST_SOURCES ${TEST_SOURCES} ${INTEGRATION_TESTS})
Expand Down
152 changes: 152 additions & 0 deletions integration/secrets_manager_integration_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include <gtest/gtest.h>
#include <sql.h>
#include <sqlext.h>

#include <cassert>
#include <chrono>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <random>
#include <stdexcept>

#include "connection_string_builder.cc"

#define MAX_NAME_LEN 255

#define AS_SQLCHAR(str) const_cast<SQLCHAR*>(reinterpret_cast<const SQLCHAR*>(str))

static int str_to_int(const char* str) {
const long int x = strtol(str, nullptr, 10);
assert(x <= INT_MAX);
assert(x >= INT_MIN);
return static_cast<int>(x);
}

class SecretsManagerIntegrationTest : public testing::Test {
protected:
std::string SECRETS_ARN = std::getenv("SECRETS_ARN");
char* dsn = std::getenv("TEST_DSN");

int MYSQL_PORT = str_to_int(std::getenv("MYSQL_PORT"));

std::string MYSQL_CLUSTER_URL = std::getenv("TEST_SERVER");

SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

ConnectionStringBuilder builder;
std::string connection_string;

static void SetUpTestSuite() {
}

static void TearDownTestSuite() {
}

void SetUp() override {
SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env);
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc);

builder = ConnectionStringBuilder();
builder.withPort(MYSQL_PORT).withLogQuery(true);
}

void TearDown() override {
if (nullptr != dbc) {
SQLFreeHandle(SQL_HANDLE_DBC, dbc);
}
if (nullptr != env) {
SQLFreeHandle(SQL_HANDLE_ENV, env);
}
}
};

TEST_F(SecretsManagerIntegrationTest, EnableSecretsManager) {
connection_string = builder
.withDSN(dsn)
.withServer(MYSQL_CLUSTER_URL)
.withAuthMode("SECRETS MANAGER")
.withAuthRegion("us-east-2")
.withSecretId(SECRETS_ARN)
.build();
SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;

EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc));
}

TEST_F(SecretsManagerIntegrationTest, EnableSecretsManagerWrongRegion) {
connection_string = builder
.withDSN(dsn)
.withServer(MYSQL_CLUSTER_URL)
.withAuthMode("SECRETS MANAGER")
.withAuthRegion("us-east-1")
.withSecretId(SECRETS_ARN)
.build();
SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;

EXPECT_EQ(SQL_ERROR, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

// Check state
SQLCHAR sqlstate[6] = "\0", message[SQL_MAX_MESSAGE_LENGTH] = "\0";;
SQLINTEGER native_error = 0;
SQLSMALLINT stmt_length;
EXPECT_EQ(SQL_SUCCESS, SQLError(env, dbc, nullptr, sqlstate, &native_error, message, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length));
const std::string state = reinterpret_cast<char*>(sqlstate);
EXPECT_EQ("HY000", state);
}

TEST_F(SecretsManagerIntegrationTest, EnableSecretsManagerInvalidSecretID) {
connection_string = builder
.withDSN(dsn)
.withServer(MYSQL_CLUSTER_URL)
.withAuthMode("SECRETS MANAGER")
.withAuthRegion("us-east-2")
.withSecretId("invalid-id")
.build();
SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;

EXPECT_EQ(SQL_ERROR, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

// Check state
SQLCHAR sqlstate[6] = "\0", message[SQL_MAX_MESSAGE_LENGTH] = "\0";;
SQLINTEGER native_error = 0;
SQLSMALLINT stmt_length;
EXPECT_EQ(SQL_SUCCESS, SQLError(env, dbc, nullptr, sqlstate, &native_error, message, SQL_MAX_MESSAGE_LENGTH - 1, &stmt_length));
const std::string state = reinterpret_cast<char*>(sqlstate);
EXPECT_EQ("HY000", state);
}
3 changes: 3 additions & 0 deletions testframework/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ repositories {
dependencies {
testImplementation("com.amazonaws:aws-java-sdk-rds:1.12.150")
testImplementation("com.amazonaws:aws-java-sdk-ec2:1.12.154")
testImplementation("software.amazon.awssdk:secretsmanager:2.20.34")
testImplementation("org.junit.jupiter:junit-jupiter-api:5.8.2")
testImplementation("org.testcontainers:toxiproxy:1.16.3")
testImplementation("org.testcontainers:mysql:1.16.3")
testImplementation("org.json:json:20230227")

testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine")
}

Expand Down
11 changes: 9 additions & 2 deletions testframework/src/test/java/host/IntegrationContainerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public class IntegrationContainerTest {
private static String dbHostClusterRo = "";
private static String runnerIP = null;
private static String dbConnStrSuffix = "";
private static String secretsArn = "";

private static final Network NETWORK = Network.newNetwork();

Expand All @@ -97,13 +98,16 @@ static void tearDown() {
auroraUtil.deleteCluster(TEST_DB_CLUSTER_IDENTIFIER);
}

if (!StringUtils.isNullOrEmpty(secretsArn)) {
auroraUtil.deleteSecrets(secretsArn);
}

auroraUtil.ec2DeauthorizesIP(runnerIP);

for (ToxiproxyContainer proxy : proxyContainers) {
proxy.stop();
}
}

testContainer.stop();
if (mysqlContainer != null) {
mysqlContainer.stop();
Expand Down Expand Up @@ -164,6 +168,8 @@ private void setupFailoverIntegrationTests(final Network network) throws Interru
dbHostClusterRo = clusterInfo.getClusterROEndpoint();

mySqlInstances = clusterInfo.getInstances();
String secretValue = auroraUtil.createSecretValue(dbHostCluster, TEST_USERNAME, TEST_PASSWORD);
secretsArn = auroraUtil.createSecrets("AWS-MySQL-ODBC-Tests-" + dbHostCluster, secretValue);

proxyContainers = containerHelper.createProxyContainers(network, mySqlInstances, PROXIED_DOMAIN_NAME_SUFFIX);
for (ToxiproxyContainer container : proxyContainers) {
Expand Down Expand Up @@ -200,7 +206,8 @@ private void setupFailoverIntegrationTests(final Network network) throws Interru
.withEnv("TEST_SERVER", dbHostCluster)
.withEnv("TEST_RO_SERVER", dbHostClusterRo)
.withEnv("DB_CONN_STR_SUFFIX", "." + dbConnStrSuffix)
.withEnv("PROXIED_CLUSTER_TEMPLATE", "?." + dbConnStrSuffix + PROXIED_DOMAIN_NAME_SUFFIX);
.withEnv("PROXIED_CLUSTER_TEMPLATE", "?." + dbConnStrSuffix + PROXIED_DOMAIN_NAME_SUFFIX)
.withEnv("SECRETS_ARN", secretsArn);

// Add mysql instances & proxies to container env
for (int i = 0; i < mySqlInstances.size(); i++) {
Expand Down
Loading

0 comments on commit 4d240ad

Please sign in to comment.