Skip to content

Commit

Permalink
Use a shared_ptr and thread_local to store credentials for use (#202)
Browse files Browse the repository at this point in the history
Switched to using a thread local in conjunction with shared_ptr atomic operations.  With thread pooling this makes it so most threads don't need to access the primary credential when it's stable.
  • Loading branch information
pfifer committed Apr 16, 2018
1 parent b0eff40 commit 7b50aca
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 164 deletions.
151 changes: 39 additions & 112 deletions aws/auth/mutable_static_creds_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,140 +12,67 @@
// permissions and limitations under the License.

#include "mutable_static_creds_provider.h"
#include <limits>

namespace {
#ifdef DEBUG
thread_local aws::auth::DebugStats debug_stats;

void update_step(std::uint64_t& step, std::uint64_t& outcome) {
step++;
outcome++;
debug_stats.attempts_++;
}

void update_step(std::uint64_t& step) {
std::uint64_t ignored = 0;
update_step(step, ignored);
}
#else
#define update_step(...)
#endif
//
// Provides a thread scoped copy of the current credentials to an executing thread.
// This makes the most difference when using a thread pool, as the retrieval of the
// credentials will only require a lock when the credentials version changes.
//
thread_local aws::auth::VersionedCredentials current_creds;
}

using namespace aws::auth;

VersionedCredentials::VersionedCredentials(std::uint64_t version, const std::string& akid, const std::string& sk, const std::string& token) :
version_(version), creds_(Aws::Auth::AWSCredentials(akid, sk, token)) {
}

MutableStaticCredentialsProvider::MutableStaticCredentialsProvider(const std::string& akid,
const std::string& sk,
std::string token) {
Aws::Auth::AWSCredentials creds(akid, sk, token);
current_.creds_ = creds;
std::string token) :
creds_(std::make_shared<VersionedCredentials>(1, akid, sk, token)), version_(1) {
}

void MutableStaticCredentialsProvider::set_credentials(const std::string& akid, const std::string& sk, std::string token) {
std::lock_guard<std::mutex> lock(update_mutex_);

//
// The ordering of stores here is important. current_.updating_, and current_.version_ are atomics.
// Between the two of them they create an interlocked state change that allows consumers
// to detect that the credentials have changed during the process of copying the values.
//
// Specifically current_.version_ must be incremented before current_.updating_ is set to false.
// This ensures that consumers will either see a version mismatch or see the updating state change.
//

current_.updating_ = true;
current_.creds_.SetAWSAccessKeyId(akid);
current_.creds_.SetAWSSecretKey(sk);
current_.creds_.SetSessionToken(token);
current_.version_++;
current_.updating_ = false;
}
std::uint64_t next_version = version_ + 1;

bool MutableStaticCredentialsProvider::try_optimistic_read(Aws::Auth::AWSCredentials& destination) {
//
// This is an attempt to do an optimistic read. We assume that the contents of the
// credentials may change in while copying to the result.
// Since the credentials are created with the expected next version, and the entire update
// is protected by a lock we can't get into a scenario where one of the consumers has
// a mismatched version and credentials.
//
// 1. To handle this we first check if an update is in progress, and if it is we bounce out
// and retry.
//
// 2. If the credential isn't being updated we go ahead and load the current version of
// the credential.
std::shared_ptr<VersionedCredentials> new_credentials = std::make_shared<VersionedCredentials>(next_version, akid, sk, token);

//
// 3. At this point we go ahead and trigger a copy of the credential to the result. This
// is what we will return if our remaining checks indicate everything is still ok.
// This update the credentials atomically using the shared_ptr specific atomic operations,
// and doesn't require a specific lock on the shared_ptr during the update. The lock
// taken previously is to prevent two credential updates at the same time.
//
// 4. We check again to see if the credential has entered updating, if it has it's possible
// the version of the credential we copied was split between the two updates. So we
// discard our copy, and try again.
// The global version change allows the threads to detect the updated version. Once detected
// the threads will pull the updated credential to their own thread local copy.
//
// 5. Finally we make check to see that the version hasn't changed since we started. This
// ensures that the credential didn't enter and exit updates between our update checks.
std::atomic_store(&creds_, new_credentials);
version_ = next_version;
}

Aws::Auth::AWSCredentials MutableStaticCredentialsProvider::GetAWSCredentials() {
//
// If everything is ok we are safe to return the credential, while it may be a nanosecnds
// out date it should be fine.
// Check to see if the credentials have been updated. If they have load the credentials
// and update the thread local.
//

if (current_.updating_) {
//
// The credentials are currently being updated. It's not safe to read so
// spin while we wait for the update to complete.
//
update_step(debug_stats.update_before_load_, debug_stats.retried_);
return false;
}
std::uint64_t starting_version = current_.version_;

// If the credentials are changing rapidly it's possible that the thread will read an
// old version of the credentials. Should that occur the next read will update to the
// most current version.
//
// Trigger a trivial copy to the result
// This check still works in the very unlikely event that the next_version value
// wraps around.
//
destination = current_.creds_;

if (current_.updating_) {
//
// The credentials object is being updated and the update may have started while we were
// copying it. For safety we discard what we have and try again.
//
update_step(debug_stats.update_after_load_, debug_stats.retried_);
return false;
if (current_creds.version_ != version_) {
std::shared_ptr<VersionedCredentials> updated = std::atomic_load(&creds_);
current_creds = *updated;
}

std::uint64_t ending_version = current_.version_;

if (starting_version != ending_version) {
//
// The version changed in between the start of the copy, and the end
// of the copy. We can no longer trust that the resulting copy is
// correct. So we discard what we have and try again.
//
update_step(debug_stats.version_mismatch_, debug_stats.retried_);
return false;
}
update_step(debug_stats.success_);
return true;
}

Aws::Auth::AWSCredentials MutableStaticCredentialsProvider::GetAWSCredentials() {

Aws::Auth::AWSCredentials result;

if (!try_optimistic_read(result)) {
//
// The optimistic read failed, so just give up and use the lock to acquire the credentials.
//
std::lock_guard<std::mutex> lock(update_mutex_);
update_step(debug_stats.used_lock_, debug_stats.success_);
result = current_.creds_;
}

return result;
}

DebugStats MutableStaticCredentialsProvider::get_debug_stats() {
#ifdef DEBUG
return debug_stats;
#else
return DebugStats();
#endif
return current_creds.creds_;
}
32 changes: 8 additions & 24 deletions aws/auth/mutable_static_creds_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,17 @@
#include <atomic>
#include <cstdint>
#include <mutex>
#include <array>
#include <memory>

namespace aws {
namespace auth {

struct DebugStats {
std::uint64_t update_before_load_;
std::uint64_t update_after_load_;
std::uint64_t version_mismatch_;
std::uint64_t success_;
std::uint64_t retried_;
std::uint64_t attempts_;
std::uint64_t used_lock_;
struct VersionedCredentials {
std::uint64_t version_;
Aws::Auth::AWSCredentials creds_;

DebugStats() : update_before_load_(0), update_after_load_(0), version_mismatch_(0),
success_(0), retried_(0), attempts_(0), used_lock_(0) {}
VersionedCredentials() : version_(0), creds_("", "", "") {}
VersionedCredentials(std::uint64_t version, const std::string& akid, const std::string& sk, const std::string& token);
};

// Like basic static creds, but with an atomic set operation
Expand All @@ -46,21 +41,10 @@ class MutableStaticCredentialsProvider

Aws::Auth::AWSCredentials GetAWSCredentials() override;

DebugStats get_debug_stats();


private:
struct VersionedCredentials {
Aws::Auth::AWSCredentials creds_;
std::atomic<std::uint64_t> version_;
std::atomic<bool> updating_;
VersionedCredentials() : version_(0) {}
};
VersionedCredentials current_;

std::mutex update_mutex_;

bool try_optimistic_read(Aws::Auth::AWSCredentials& destination);
std::shared_ptr<VersionedCredentials> creds_;
std::atomic<std::uint64_t> version_;

};

Expand Down
33 changes: 5 additions & 28 deletions aws/auth/test/mutable_static_creds_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ struct ResultsCounter {
BOOST_AUTO_TEST_SUITE(MutableStaticCredsProviderTest)

BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
const std::uint32_t kReaderThreadCount = 20;
const std::uint32_t kReaderThreadCount = 128;
std::atomic<bool> test_running(true);
std::array<ResultsCounter, kReaderThreadCount> results;
std::array<aws::auth::DebugStats, kReaderThreadCount> debug_stats;
std::vector<std::thread> reader_threads;

aws::auth::MutableStaticCredentialsProvider provider("initial-0", "initial-0", "initial-0");
Expand All @@ -79,16 +78,18 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
++counter;
std::stringstream ss;
std::time_t now_time = system_clock::to_time_t(now);
ss << std::put_time(std::gmtime(&now_time), "%FT%T") << "-" << std::setfill('0') << std::setw(10) << counter;
ss << now_time << "-" << std::setfill('0') << std::setw(10) << counter;
std::string value = ss.str();
provider.set_credentials(value, value, value);
using namespace std::chrono_literals;
std::this_thread::sleep_for(20ms);
} while (now < end);
LOG(info) << "Producer thread completed";
});

LOG(info) << "Starting " << kReaderThreadCount << " consumer threads";
for(std::uint32_t i = 0; i < kReaderThreadCount; ++i) {
reader_threads.emplace_back([i, &provider, &results, &test_running, &debug_stats] {
reader_threads.emplace_back([i, &provider, &results, &test_running] {
using namespace std::chrono;
auto start = high_resolution_clock::now();
while(test_running) {
Expand All @@ -104,7 +105,6 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
milliseconds millis_taken = duration_cast<milliseconds>(taken);
double seconds = millis_taken.count() / 1000.0;
results[i].seconds_ = seconds;
debug_stats[i] = provider.get_debug_stats();
});
}

Expand Down Expand Up @@ -151,29 +151,6 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) {
<< std::setw(results_width) << total_average
<< std::setw(results_width) << overall_calls_per;

#ifdef DEBUG
std::uint32_t debug_width = 20;
LOG(info) << "Debug Stats";
LOG(info) << "\t"
<< std::setw(debug_width) << "Update Before Load"
<< std::setw(debug_width) << "Update After Load"
<< std::setw(debug_width) << "Version Mismatch"
<< std::setw(debug_width) << "Used Lock"
<< std::setw(debug_width) << "Retried"
<< std::setw(debug_width) << "Success"
<< std::setw(debug_width) << "Total";
std::for_each(debug_stats.begin(), debug_stats.end(), [debug_width](aws::auth::DebugStats& d) {
LOG(info) << "\t"
<< std::setw(debug_width) << d.update_before_load_
<< std::setw(debug_width) << d.update_after_load_
<< std::setw(debug_width) << d.version_mismatch_
<< std::setw(debug_width) << d.used_lock_
<< std::setw(debug_width) << d.retried_
<< std::setw(debug_width) << d.success_
<< std::setw(debug_width) << d.attempts_;

});
#endif
BOOST_CHECK_EQUAL(failures, 0);
LOG(info) << "Test Completed";
}
Expand Down

0 comments on commit 7b50aca

Please sign in to comment.