Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a shared_ptr and thread_local to store credentials for use #202

Merged
merged 11 commits into from
Apr 16, 2018
133 changes: 16 additions & 117 deletions aws/auth/mutable_static_creds_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,137 +15,36 @@
#include <limits>

namespace {
#ifdef DEBUG
thread_local aws::auth::DebugStats debug_stats;

void update_step(std::uint64_t& step, std::uint64_t& outcome) {
step++;
outcome++;
debug_stats.attempts_++;
}

void update_step(std::uint64_t& step) {
std::uint64_t ignored = 0;
update_step(step, ignored);
}
#else
#define update_step(...)
#endif
thread_local aws::auth::VersionedCredentials current_creds;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a short comment about the role of the thread local current creds and how its version is compared against the global version would make it easier to follow.

}

using namespace aws::auth;

VersionedCredentials::VersionedCredentials(std::uint64_t version, const std::string& akid, const std::string& sk, const std::string& token) :
version_(version), creds_(Aws::Auth::AWSCredentials(akid, sk, token)) {
}

MutableStaticCredentialsProvider::MutableStaticCredentialsProvider(const std::string& akid,
const std::string& sk,
std::string token) {
Aws::Auth::AWSCredentials creds(akid, sk, token);
current_.creds_ = creds;
std::string token) :
creds_(std::make_shared<VersionedCredentials>(1, akid, sk, token)), version_(1) {
}

void MutableStaticCredentialsProvider::set_credentials(const std::string& akid, const std::string& sk, std::string token) {
std::lock_guard<std::mutex> lock(update_mutex_);

//
// The ordering of stores here is important. current_.updating_, and current_.version_ are atomics.
// Between the two of them they create an interlocked state change that allows consumers
// to detect that the credentials have changed during the process of copying the values.
//
// Specifically current_.version_ must be incremented before current_.updating_ is set to false.
// This ensures that consumers will either see a version mismatch or see the updating state change.
//
std::uint64_t next_version = version_ + 1;


current_.updating_ = true;
current_.creds_.SetAWSAccessKeyId(akid);
current_.creds_.SetAWSSecretKey(sk);
current_.creds_.SetSessionToken(token);
current_.version_++;
current_.updating_ = false;
}

bool MutableStaticCredentialsProvider::try_optimistic_read(Aws::Auth::AWSCredentials& destination) {
//
// This is an attempt to do an optimistic read. We assume that the contents of the
// credentials may change in while copying to the result.
//
// 1. To handle this we first check if an update is in progress, and if it is we bounce out
// and retry.
//
// 2. If the credential isn't being updated we go ahead and load the current version of
// the credential.
//
// 3. At this point we go ahead and trigger a copy of the credential to the result. This
// is what we will return if our remaining checks indicate everything is still ok.
//
// 4. We check again to see if the credential has entered updating, if it has it's possible
// the version of the credential we copied was split between the two updates. So we
// discard our copy, and try again.
//
// 5. Finally we make check to see that the version hasn't changed since we started. This
// ensures that the credential didn't enter and exit updates between our update checks.
//
// If everything is ok we are safe to return the credential, while it may be a nanosecnds
// out date it should be fine.
//

if (current_.updating_) {
//
// The credentials are currently being updated. It's not safe to read so
// spin while we wait for the update to complete.
//
update_step(debug_stats.update_before_load_, debug_stats.retried_);
return false;
}
std::uint64_t starting_version = current_.version_;

//
// Trigger a trivial copy to the result
//
destination = current_.creds_;

if (current_.updating_) {
//
// The credentials object is being updated and the update may have started while we were
// copying it. For safety we discard what we have and try again.
//
update_step(debug_stats.update_after_load_, debug_stats.retried_);
return false;
}

std::uint64_t ending_version = current_.version_;

if (starting_version != ending_version) {
//
// The version changed in between the start of the copy, and the end
// of the copy. We can no longer trust that the resulting copy is
// correct. So we discard what we have and try again.
//
update_step(debug_stats.version_mismatch_, debug_stats.retried_);
return false;
}
update_step(debug_stats.success_);
return true;
std::shared_ptr<VersionedCredentials> new_credentials = std::make_shared<VersionedCredentials>(next_version, akid, sk, token);
std::atomic_store(&creds_, new_credentials);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be good to add a comment that this is updating the global credentials ?

version_ = next_version;
}

Aws::Auth::AWSCredentials MutableStaticCredentialsProvider::GetAWSCredentials() {

Aws::Auth::AWSCredentials result;

if (!try_optimistic_read(result)) {
//
// The optimistic read failed, so just give up and use the lock to acquire the credentials.
//
std::lock_guard<std::mutex> lock(update_mutex_);
update_step(debug_stats.used_lock_, debug_stats.success_);
result = current_.creds_;
if (current_creds.version_ != version_) {
std::shared_ptr<VersionedCredentials> updated = std::atomic_load(&creds_);
current_creds = *updated;
}

return result;
}

DebugStats MutableStaticCredentialsProvider::get_debug_stats() {
#ifdef DEBUG
return debug_stats;
#else
return DebugStats();
#endif
return current_creds.creds_;
}
32 changes: 8 additions & 24 deletions aws/auth/mutable_static_creds_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,17 @@
#include <atomic>
#include <cstdint>
#include <mutex>
#include <array>
#include <memory>

namespace aws {
namespace auth {

struct DebugStats {
std::uint64_t update_before_load_;
std::uint64_t update_after_load_;
std::uint64_t version_mismatch_;
std::uint64_t success_;
std::uint64_t retried_;
std::uint64_t attempts_;
std::uint64_t used_lock_;
struct VersionedCredentials {
std::uint64_t version_;
Aws::Auth::AWSCredentials creds_;

DebugStats() : update_before_load_(0), update_after_load_(0), version_mismatch_(0),
success_(0), retried_(0), attempts_(0), used_lock_(0) {}
VersionedCredentials() : version_(0), creds_("", "", "") {}
VersionedCredentials(std::uint64_t version, const std::string& akid, const std::string& sk, const std::string& token);
};

// Like basic static creds, but with an atomic set operation
Expand All @@ -46,21 +41,10 @@ class MutableStaticCredentialsProvider

Aws::Auth::AWSCredentials GetAWSCredentials() override;

DebugStats get_debug_stats();


private:
struct VersionedCredentials {
Aws::Auth::AWSCredentials creds_;
std::atomic<std::uint64_t> version_;
std::atomic<bool> updating_;
VersionedCredentials() : version_(0) {}
};
VersionedCredentials current_;

std::mutex update_mutex_;

bool try_optimistic_read(Aws::Auth::AWSCredentials& destination);
std::shared_ptr<VersionedCredentials> creds_;
std::atomic<std::uint64_t> version_;

};

Expand Down
33 changes: 5 additions & 28 deletions aws/auth/test/mutable_static_creds_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ struct ResultsCounter {
BOOST_AUTO_TEST_SUITE(MutableStaticCredsProviderTest)

BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
const std::uint32_t kReaderThreadCount = 20;
const std::uint32_t kReaderThreadCount = 128;
std::atomic<bool> test_running(true);
std::array<ResultsCounter, kReaderThreadCount> results;
std::array<aws::auth::DebugStats, kReaderThreadCount> debug_stats;
std::vector<std::thread> reader_threads;

aws::auth::MutableStaticCredentialsProvider provider("initial-0", "initial-0", "initial-0");
Expand All @@ -79,16 +78,18 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
++counter;
std::stringstream ss;
std::time_t now_time = system_clock::to_time_t(now);
ss << std::put_time(std::gmtime(&now_time), "%FT%T") << "-" << std::setfill('0') << std::setw(10) << counter;
ss << now_time << "-" << std::setfill('0') << std::setw(10) << counter;
std::string value = ss.str();
provider.set_credentials(value, value, value);
using namespace std::chrono_literals;
std::this_thread::sleep_for(20ms);
} while (now < end);
LOG(info) << "Producer thread completed";
});

LOG(info) << "Starting " << kReaderThreadCount << " consumer threads";
for(std::uint32_t i = 0; i < kReaderThreadCount; ++i) {
reader_threads.emplace_back([i, &provider, &results, &test_running, &debug_stats] {
reader_threads.emplace_back([i, &provider, &results, &test_running] {
using namespace std::chrono;
auto start = high_resolution_clock::now();
while(test_running) {
Expand All @@ -104,7 +105,6 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
milliseconds millis_taken = duration_cast<milliseconds>(taken);
double seconds = millis_taken.count() / 1000.0;
results[i].seconds_ = seconds;
debug_stats[i] = provider.get_debug_stats();
});
}

Expand Down Expand Up @@ -151,29 +151,6 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
<< std::setw(results_width) << total_average
<< std::setw(results_width) << overall_calls_per;

#ifdef DEBUG
std::uint32_t debug_width = 20;
LOG(info) << "Debug Stats";
LOG(info) << "\t"
<< std::setw(debug_width) << "Update Before Load"
<< std::setw(debug_width) << "Update After Load"
<< std::setw(debug_width) << "Version Mismatch"
<< std::setw(debug_width) << "Used Lock"
<< std::setw(debug_width) << "Retried"
<< std::setw(debug_width) << "Success"
<< std::setw(debug_width) << "Total";
std::for_each(debug_stats.begin(), debug_stats.end(), [debug_width](aws::auth::DebugStats& d) {
LOG(info) << "\t"
<< std::setw(debug_width) << d.update_before_load_
<< std::setw(debug_width) << d.update_after_load_
<< std::setw(debug_width) << d.version_mismatch_
<< std::setw(debug_width) << d.used_lock_
<< std::setw(debug_width) << d.retried_
<< std::setw(debug_width) << d.success_
<< std::setw(debug_width) << d.attempts_;

});
#endif
BOOST_CHECK_EQUAL(failures, 0);
LOG(info) << "Test Completed";
}
Expand Down