diff --git a/aws/auth/mutable_static_creds_provider.cc b/aws/auth/mutable_static_creds_provider.cc index 22bf2fc7..4bbefaea 100644 --- a/aws/auth/mutable_static_creds_provider.cc +++ b/aws/auth/mutable_static_creds_provider.cc @@ -12,140 +12,67 @@ // permissions and limitations under the License. #include "mutable_static_creds_provider.h" -#include namespace { -#ifdef DEBUG - thread_local aws::auth::DebugStats debug_stats; - - void update_step(std::uint64_t& step, std::uint64_t& outcome) { - step++; - outcome++; - debug_stats.attempts_++; - } - - void update_step(std::uint64_t& step) { - std::uint64_t ignored = 0; - update_step(step, ignored); - } -#else -#define update_step(...) -#endif + // + // Provides a thread scoped copy of the current credentials to an executing thread. + // This makes the most difference when using a thread pool, as the retrieval of the + // credentials will only require a lock when the credentials version changes. + // + thread_local aws::auth::VersionedCredentials current_creds; } using namespace aws::auth; +VersionedCredentials::VersionedCredentials(std::uint64_t version, const std::string& akid, const std::string& sk, const std::string& token) : + version_(version), creds_(Aws::Auth::AWSCredentials(akid, sk, token)) { +} + MutableStaticCredentialsProvider::MutableStaticCredentialsProvider(const std::string& akid, const std::string& sk, - std::string token) { - Aws::Auth::AWSCredentials creds(akid, sk, token); - current_.creds_ = creds; + std::string token) : + creds_(std::make_shared(1, akid, sk, token)), version_(1) { } void MutableStaticCredentialsProvider::set_credentials(const std::string& akid, const std::string& sk, std::string token) { std::lock_guard lock(update_mutex_); - // - // The ordering of stores here is important. current_.updating_, and current_.version_ are atomics. - // Between the two of them they create an interlocked state change that allows consumers - // to detect that the credentials have changed during the process of copying the values. - // - // Specifically current_.version_ must be incremented before current_.updating_ is set to false. - // This ensures that consumers will either see a version mismatch or see the updating state change. - // - - current_.updating_ = true; - current_.creds_.SetAWSAccessKeyId(akid); - current_.creds_.SetAWSSecretKey(sk); - current_.creds_.SetSessionToken(token); - current_.version_++; - current_.updating_ = false; -} + std::uint64_t next_version = version_ + 1; -bool MutableStaticCredentialsProvider::try_optimistic_read(Aws::Auth::AWSCredentials& destination) { // - // This is an attempt to do an optimistic read. We assume that the contents of the - // credentials may change in while copying to the result. + // Since the credentials are created with the expected next version, and the entire update + // is protected by a lock we can't get into a scenario where one of the consumers has + // a mismatched version and credentials. // - // 1. To handle this we first check if an update is in progress, and if it is we bounce out - // and retry. - // - // 2. If the credential isn't being updated we go ahead and load the current version of - // the credential. + std::shared_ptr new_credentials = std::make_shared(next_version, akid, sk, token); + // - // 3. At this point we go ahead and trigger a copy of the credential to the result. This - // is what we will return if our remaining checks indicate everything is still ok. + // This update the credentials atomically using the shared_ptr specific atomic operations, + // and doesn't require a specific lock on the shared_ptr during the update. The lock + // taken previously is to prevent two credential updates at the same time. // - // 4. We check again to see if the credential has entered updating, if it has it's possible - // the version of the credential we copied was split between the two updates. So we - // discard our copy, and try again. + // The global version change allows the threads to detect the updated version. Once detected + // the threads will pull the updated credential to their own thread local copy. // - // 5. Finally we make check to see that the version hasn't changed since we started. This - // ensures that the credential didn't enter and exit updates between our update checks. + std::atomic_store(&creds_, new_credentials); + version_ = next_version; +} + +Aws::Auth::AWSCredentials MutableStaticCredentialsProvider::GetAWSCredentials() { // - // If everything is ok we are safe to return the credential, while it may be a nanosecnds - // out date it should be fine. + // Check to see if the credentials have been updated. If they have load the credentials + // and update the thread local. // - - if (current_.updating_) { - // - // The credentials are currently being updated. It's not safe to read so - // spin while we wait for the update to complete. - // - update_step(debug_stats.update_before_load_, debug_stats.retried_); - return false; - } - std::uint64_t starting_version = current_.version_; - + // If the credentials are changing rapidly it's possible that the thread will read an + // old version of the credentials. Should that occur the next read will update to the + // most current version. // - // Trigger a trivial copy to the result + // This check still works in the very unlikely event that the next_version value + // wraps around. // - destination = current_.creds_; - - if (current_.updating_) { - // - // The credentials object is being updated and the update may have started while we were - // copying it. For safety we discard what we have and try again. - // - update_step(debug_stats.update_after_load_, debug_stats.retried_); - return false; + if (current_creds.version_ != version_) { + std::shared_ptr updated = std::atomic_load(&creds_); + current_creds = *updated; } - - std::uint64_t ending_version = current_.version_; - - if (starting_version != ending_version) { - // - // The version changed in between the start of the copy, and the end - // of the copy. We can no longer trust that the resulting copy is - // correct. So we discard what we have and try again. - // - update_step(debug_stats.version_mismatch_, debug_stats.retried_); - return false; - } - update_step(debug_stats.success_); - return true; -} - -Aws::Auth::AWSCredentials MutableStaticCredentialsProvider::GetAWSCredentials() { - - Aws::Auth::AWSCredentials result; - - if (!try_optimistic_read(result)) { - // - // The optimistic read failed, so just give up and use the lock to acquire the credentials. - // - std::lock_guard lock(update_mutex_); - update_step(debug_stats.used_lock_, debug_stats.success_); - result = current_.creds_; - } - - return result; -} - -DebugStats MutableStaticCredentialsProvider::get_debug_stats() { -#ifdef DEBUG - return debug_stats; -#else - return DebugStats(); -#endif + return current_creds.creds_; } diff --git a/aws/auth/mutable_static_creds_provider.h b/aws/auth/mutable_static_creds_provider.h index ac7fc183..65800cdb 100644 --- a/aws/auth/mutable_static_creds_provider.h +++ b/aws/auth/mutable_static_creds_provider.h @@ -18,22 +18,17 @@ #include #include #include -#include +#include namespace aws { namespace auth { -struct DebugStats { - std::uint64_t update_before_load_; - std::uint64_t update_after_load_; - std::uint64_t version_mismatch_; - std::uint64_t success_; - std::uint64_t retried_; - std::uint64_t attempts_; - std::uint64_t used_lock_; +struct VersionedCredentials { + std::uint64_t version_; + Aws::Auth::AWSCredentials creds_; - DebugStats() : update_before_load_(0), update_after_load_(0), version_mismatch_(0), - success_(0), retried_(0), attempts_(0), used_lock_(0) {} + VersionedCredentials() : version_(0), creds_("", "", "") {} + VersionedCredentials(std::uint64_t version, const std::string& akid, const std::string& sk, const std::string& token); }; // Like basic static creds, but with an atomic set operation @@ -46,21 +41,10 @@ class MutableStaticCredentialsProvider Aws::Auth::AWSCredentials GetAWSCredentials() override; - DebugStats get_debug_stats(); - - private: - struct VersionedCredentials { - Aws::Auth::AWSCredentials creds_; - std::atomic version_; - std::atomic updating_; - VersionedCredentials() : version_(0) {} - }; - VersionedCredentials current_; - std::mutex update_mutex_; - - bool try_optimistic_read(Aws::Auth::AWSCredentials& destination); + std::shared_ptr creds_; + std::atomic version_; }; diff --git a/aws/auth/test/mutable_static_creds_provider_test.cc b/aws/auth/test/mutable_static_creds_provider_test.cc index 20e064c8..80946da2 100644 --- a/aws/auth/test/mutable_static_creds_provider_test.cc +++ b/aws/auth/test/mutable_static_creds_provider_test.cc @@ -58,10 +58,9 @@ struct ResultsCounter { BOOST_AUTO_TEST_SUITE(MutableStaticCredsProviderTest) BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) { - const std::uint32_t kReaderThreadCount = 20; + const std::uint32_t kReaderThreadCount = 128; std::atomic test_running(true); std::array results; - std::array debug_stats; std::vector reader_threads; aws::auth::MutableStaticCredentialsProvider provider("initial-0", "initial-0", "initial-0"); @@ -79,16 +78,18 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) { ++counter; std::stringstream ss; std::time_t now_time = system_clock::to_time_t(now); - ss << std::put_time(std::gmtime(&now_time), "%FT%T") << "-" << std::setfill('0') << std::setw(10) << counter; + ss << now_time << "-" << std::setfill('0') << std::setw(10) << counter; std::string value = ss.str(); provider.set_credentials(value, value, value); + using namespace std::chrono_literals; + std::this_thread::sleep_for(20ms); } while (now < end); LOG(info) << "Producer thread completed"; }); LOG(info) << "Starting " << kReaderThreadCount << " consumer threads"; for(std::uint32_t i = 0; i < kReaderThreadCount; ++i) { - reader_threads.emplace_back([i, &provider, &results, &test_running, &debug_stats] { + reader_threads.emplace_back([i, &provider, &results, &test_running] { using namespace std::chrono; auto start = high_resolution_clock::now(); while(test_running) { @@ -104,7 +105,6 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) { milliseconds millis_taken = duration_cast(taken); double seconds = millis_taken.count() / 1000.0; results[i].seconds_ = seconds; - debug_stats[i] = provider.get_debug_stats(); }); } @@ -151,29 +151,6 @@ BOOST_AUTO_TEST_CASE(SinglePublisherMultipleReaders) { << std::setw(results_width) << total_average << std::setw(results_width) << overall_calls_per; -#ifdef DEBUG - std::uint32_t debug_width = 20; - LOG(info) << "Debug Stats"; - LOG(info) << "\t" - << std::setw(debug_width) << "Update Before Load" - << std::setw(debug_width) << "Update After Load" - << std::setw(debug_width) << "Version Mismatch" - << std::setw(debug_width) << "Used Lock" - << std::setw(debug_width) << "Retried" - << std::setw(debug_width) << "Success" - << std::setw(debug_width) << "Total"; - std::for_each(debug_stats.begin(), debug_stats.end(), [debug_width](aws::auth::DebugStats& d) { - LOG(info) << "\t" - << std::setw(debug_width) << d.update_before_load_ - << std::setw(debug_width) << d.update_after_load_ - << std::setw(debug_width) << d.version_mismatch_ - << std::setw(debug_width) << d.used_lock_ - << std::setw(debug_width) << d.retried_ - << std::setw(debug_width) << d.success_ - << std::setw(debug_width) << d.attempts_; - - }); -#endif BOOST_CHECK_EQUAL(failures, 0); LOG(info) << "Test Completed"; }