Skip to content

Commit

Permalink
refactor integration tests to use shared pointer for rds client
Browse files Browse the repository at this point in the history
  • Loading branch information
yanw-bq committed Apr 21, 2023
1 parent 254cb1a commit 2ec1ae5
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 27 deletions.
32 changes: 17 additions & 15 deletions integration/base_failover_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <climits>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <sql.h>
Expand All @@ -65,6 +66,7 @@ static std::string DOWN_STREAM_STR = "DOWNSTREAM";
static std::string UP_STREAM_STR = "UPSTREAM";

static Aws::SDKOptions options;
static std::shared_ptr<Aws::RDS::RDSClient> rds_client;

class BaseFailoverIntegrationTest : public testing::Test {
protected:
Expand Down Expand Up @@ -219,15 +221,15 @@ class BaseFailoverIntegrationTest : public testing::Test {

// Helper functions from integration tests

static std::vector<std::string> retrieve_topology_via_SDK(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) {
static std::vector<std::string> retrieve_topology_via_SDK(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id) {
std::vector<std::string> instances;

std::string writer;
std::vector<std::string> readers;

Aws::RDS::Model::DescribeDBClustersRequest rds_req;
rds_req.WithDBClusterIdentifier(cluster_id);
auto outcome = client.DescribeDBClusters(rds_req);
auto outcome = client->DescribeDBClusters(rds_req);

if (!outcome.IsSuccess()) {
throw std::runtime_error("Unable to get cluster topology using SDK.");
Expand All @@ -252,15 +254,15 @@ class BaseFailoverIntegrationTest : public testing::Test {
return instances;
}

static Aws::RDS::Model::DBCluster get_DB_cluster(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) {
static Aws::RDS::Model::DBCluster get_DB_cluster(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id) {
Aws::RDS::Model::DescribeDBClustersRequest rds_req;
rds_req.WithDBClusterIdentifier(cluster_id);
auto outcome = client.DescribeDBClusters(rds_req);
auto outcome = client->DescribeDBClusters(rds_req);
const auto result = outcome.GetResult();
return result.GetDBClusters().at(0);
}

static void wait_until_cluster_has_right_state(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) {
static void wait_until_cluster_has_right_state(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id) {
Aws::String status = get_DB_cluster(client, cluster_id).GetStatus();

while (status != "available") {
Expand All @@ -269,7 +271,7 @@ class BaseFailoverIntegrationTest : public testing::Test {
}
}

static Aws::RDS::Model::DBClusterMember get_DB_cluster_writer_instance(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) {
static Aws::RDS::Model::DBClusterMember get_DB_cluster_writer_instance(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id) {
Aws::RDS::Model::DBClusterMember instance;
const Aws::RDS::Model::DBCluster cluster = get_DB_cluster(client, cluster_id);
for (const auto& member : cluster.GetDBClusterMembers()) {
Expand All @@ -280,11 +282,11 @@ class BaseFailoverIntegrationTest : public testing::Test {
return instance;
}

static Aws::String get_DB_cluster_writer_instance_id(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) {
static Aws::String get_DB_cluster_writer_instance_id(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id) {
return get_DB_cluster_writer_instance(client, cluster_id).GetDBInstanceIdentifier();
}

static void wait_until_writer_instance_changed(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id,
static void wait_until_writer_instance_changed(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id,
const Aws::String& initial_writer_instance_id) {
Aws::String next_cluster_writer_id = get_DB_cluster_writer_instance_id(client, cluster_id);
while (initial_writer_instance_id == next_cluster_writer_id) {
Expand All @@ -293,14 +295,14 @@ class BaseFailoverIntegrationTest : public testing::Test {
}
}

static void failover_cluster(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id, const Aws::String& target_instance_id = "") {
static void failover_cluster(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id, const Aws::String& target_instance_id = "") {
wait_until_cluster_has_right_state(client, cluster_id);
Aws::RDS::Model::FailoverDBClusterRequest rds_req;
rds_req.WithDBClusterIdentifier(cluster_id);
if (!target_instance_id.empty()) {
rds_req.WithTargetDBInstanceIdentifier(target_instance_id);
}
auto outcome = client.FailoverDBCluster(rds_req);
auto outcome = client->FailoverDBCluster(rds_req);
}

static Aws::String get_random_DB_cluster_reader_instance_id(std::vector<std::string> readers) {
Expand All @@ -310,7 +312,7 @@ class BaseFailoverIntegrationTest : public testing::Test {
return readers.at(distribution(generator));
}

static bool has_writer_changed(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id, std::string initial_writer_id, std::chrono::nanoseconds timeout) {
static bool has_writer_changed(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id, std::string initial_writer_id, std::chrono::nanoseconds timeout) {
auto start = std::chrono::high_resolution_clock::now();

std::string current_writer_id = get_DB_cluster_writer_instance_id(client, cluster_id);
Expand All @@ -325,7 +327,7 @@ class BaseFailoverIntegrationTest : public testing::Test {
return true;
}

static void failover_cluster_and_wait_until_writer_changed(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id,
static void failover_cluster_and_wait_until_writer_changed(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id,
const Aws::String& initial_writer_id,
const Aws::String& target_writer_id = "") {

Expand Down Expand Up @@ -357,7 +359,7 @@ class BaseFailoverIntegrationTest : public testing::Test {
}
}

static Aws::RDS::Model::DBClusterMember get_matched_DBClusterMember(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id,
static Aws::RDS::Model::DBClusterMember get_matched_DBClusterMember(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id,
const Aws::String& instance_id) {
Aws::RDS::Model::DBClusterMember instance;
const Aws::RDS::Model::DBCluster cluster = get_DB_cluster(client, cluster_id);
Expand All @@ -370,11 +372,11 @@ class BaseFailoverIntegrationTest : public testing::Test {
return instance;
}

static bool is_DB_instance_writer(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id, const Aws::String& instance_id) {
static bool is_DB_instance_writer(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id, const Aws::String& instance_id) {
return get_matched_DBClusterMember(client, cluster_id, instance_id).GetIsClusterWriter();
}

static bool is_DB_instance_reader(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id, const Aws::String& instance_id) {
static bool is_DB_instance_reader(std::shared_ptr<Aws::RDS::RDSClient> client, const Aws::String& cluster_id, const Aws::String& instance_id) {
return !get_matched_DBClusterMember(client, cluster_id, instance_id).GetIsClusterWriter();
}

Expand Down
4 changes: 2 additions & 2 deletions integration/failover_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ class FailoverIntegrationTest : public BaseFailoverIntegrationTest {
Aws::String(SECRET_ACCESS_KEY),
Aws::String(SESSION_TOKEN));
Aws::Client::ClientConfiguration client_config;
Aws::RDS::RDSClient rds_client;
SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

static void SetUpTestSuite() {
Aws::InitAPI(options);
rds_client = std::make_shared<Aws::RDS::RDSClient>(credentials, client_config);
}

static void TearDownTestSuite() {
rds_client->reset();
Aws::ShutdownAPI(options);
}

Expand All @@ -55,7 +56,6 @@ class FailoverIntegrationTest : public BaseFailoverIntegrationTest {
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast<SQLPOINTER>(SQL_OV_ODBC3), 0);
SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc);
client_config.region = "us-east-2";
rds_client = Aws::RDS::RDSClient(credentials, client_config);

cluster_instances = retrieve_topology_via_SDK(rds_client, cluster_id);
writer_id = get_writer_id(cluster_instances);
Expand Down
3 changes: 2 additions & 1 deletion integration/failover_performance_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ class FailoverPerformanceTest :
Aws::String(SECRET_ACCESS_KEY),
Aws::String(SESSION_TOKEN));
Aws::Client::ClientConfiguration client_config;
Aws::RDS::RDSClient rds_client;
SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

Expand All @@ -161,9 +160,11 @@ class FailoverPerformanceTest :

static void SetUpTestSuite() {
Aws::InitAPI(options);
rds_client = std::make_shared<Aws::RDS::RDSClient>(credentials, client_config);
}

static void TearDownTestSuite() {
rds_client->reset();
Aws::ShutdownAPI(options);

// Save results to spreadsheet
Expand Down
12 changes: 6 additions & 6 deletions integration/network_failover_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ class NetworkFailoverIntegrationTest : public BaseFailoverIntegrationTest {
Aws::String(SECRET_ACCESS_KEY),
Aws::String(SESSION_TOKEN));
Aws::Client::ClientConfiguration client_config;
Aws::RDS::RDSClient rds_client;
SQLHENV env = nullptr;
SQLHDBC dbc = nullptr;

static void SetUpTestSuite() {
Aws::InitAPI(options);
rds_client = std::make_shared<Aws::RDS::RDSClient>(credentials, client_config);
}

static void TearDownTestSuite() {
rds_client->reset();
Aws::ShutdownAPI(options);
}

Expand All @@ -56,7 +57,6 @@ class NetworkFailoverIntegrationTest : public BaseFailoverIntegrationTest {
SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc);

client_config.region = "us-east-2";
rds_client = Aws::RDS::RDSClient(credentials, client_config);

for (const auto& x : proxy_map) {
enable_connectivity(x.second);
Expand Down Expand Up @@ -91,7 +91,7 @@ class NetworkFailoverIntegrationTest : public BaseFailoverIntegrationTest {
}
};

TEST_F(NetworkFailoverIntegrationTest, DISABLED_connection_test) {
TEST_F(NetworkFailoverIntegrationTest, connection_test) {
test_connection(dbc, MYSQL_INSTANCE_1_URL, MYSQL_PORT);
test_connection(dbc, MYSQL_INSTANCE_1_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT);
test_connection(dbc, MYSQL_CLUSTER_URL, MYSQL_PORT);
Expand All @@ -100,7 +100,7 @@ TEST_F(NetworkFailoverIntegrationTest, DISABLED_connection_test) {
test_connection(dbc, MYSQL_RO_CLUSTER_URL + PROXIED_DOMAIN_NAME_SUFFIX, MYSQL_PROXY_PORT);
}

TEST_F(NetworkFailoverIntegrationTest, DISABLED_lost_connection_to_writer) {
TEST_F(NetworkFailoverIntegrationTest, lost_connection_to_writer) {
const std::string server = get_proxied_endpoint(writer_id);
connection_string = builder.withServer(server).withFailoverTimeout(GLOBAL_FAILOVER_TIMEOUT).build();
EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));
Expand All @@ -121,7 +121,7 @@ TEST_F(NetworkFailoverIntegrationTest, DISABLED_lost_connection_to_writer) {
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc));
}

TEST_F(NetworkFailoverIntegrationTest, DISABLED_use_same_connection_after_failing_failover) {
TEST_F(NetworkFailoverIntegrationTest, use_same_connection_after_failing_failover) {
const std::string server = get_proxied_endpoint(writer_id);
connection_string = builder.withServer(server).withFailoverTimeout(GLOBAL_FAILOVER_TIMEOUT).build();
EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));
Expand All @@ -148,7 +148,7 @@ TEST_F(NetworkFailoverIntegrationTest, DISABLED_use_same_connection_after_failin
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc));
}

TEST_F(NetworkFailoverIntegrationTest, DISABLED_lost_connection_to_all_readers) {
TEST_F(NetworkFailoverIntegrationTest, lost_connection_to_all_readers) {
connection_string = builder.withServer(reader_endpoint).build();
EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

Expand Down
4 changes: 1 addition & 3 deletions testframework/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ group = "software.aws.rds"
version = "1.0-SNAPSHOT"

repositories {
maven {
url = uri("https://repo1.maven.org/maven2")
}
mavenCentral()
}

dependencies {
Expand Down

0 comments on commit 2ec1ae5

Please sign in to comment.