From c71060fa0118ac9de2b292fe0c2ef587e0cb35ce Mon Sep 17 00:00:00 2001 From: jaymode Date: Wed, 9 Jan 2019 12:17:43 -0700 Subject: [PATCH] Test: fix race in auth result propagation test This commit fixes a race condition in a test introduced by #36900 that verifies concurrent authentications get a result propagated from the first thread that attempts to authenticate. Previously, a thread may be in a state where it had not attempted to authenticate when the first thread that authenticates finishes the authentication, which would cause the test to fail as there would be an additional authentication attempt. This change adds additional latches to ensure all threads have attempted to authenticate before a result gets returned in the thread that is performing authentication. --- .../CachingUsernamePasswordRealmTests.java | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java index 4ed04864041d6..2fed720e23c09 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java @@ -484,13 +484,27 @@ public void testUnauthenticatedResultPropagatesWithSameCreds() throws Exception final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); - final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads); List threads = new ArrayList<>(numberOfThreads); final SecureString credsToUse = new SecureString(randomAlphaOfLength(12).toCharArray()); + + // we use a bunch of different latches here, the first `latch` is used to ensure all threads have been started + // before they start to execute. The `authWaitLatch` is there to ensure we have all threads waiting on the + // listener before we auth otherwise we may run into a race condition where we auth and one of the threads is + // not waiting on auth yet. Finally, the completedLatch is used to signal that each thread received a response! + final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads); + final CountDownLatch authWaitLatch = new CountDownLatch(numberOfThreads); + final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads); final CachingUsernamePasswordRealm realm = new CachingUsernamePasswordRealm(config, threadPool) { @Override protected void doAuthenticate(UsernamePasswordToken token, ActionListener listener) { authCounter.incrementAndGet(); + authWaitLatch.countDown(); + try { + authWaitLatch.await(); + } catch (InterruptedException e) { + logger.info("authentication was interrupted", e); + Thread.currentThread().interrupt(); + } // do something slow if (pwdHasher.verify(token.credentials(), passwordHash.toCharArray())) { listener.onFailure(new IllegalStateException("password auth should never succeed")); @@ -513,14 +527,17 @@ protected void doLookupUser(String username, ActionListener listener) { realm.authenticate(token, ActionListener.wrap((result) -> { if (result.isAuthenticated()) { + completedLatch.countDown(); throw new IllegalStateException("invalid password led to an authenticated result: " + result); } assertThat(result.getMessage(), containsString("password verification failed")); + completedLatch.countDown(); }, (e) -> { logger.error("caught exception", e); + completedLatch.countDown(); fail("unexpected exception - " + e); })); - + authWaitLatch.countDown(); } catch (InterruptedException e) { logger.error("thread was interrupted", e); Thread.currentThread().interrupt(); @@ -535,6 +552,7 @@ protected void doLookupUser(String username, ActionListener listener) { for (Thread thread : threads) { thread.join(); } + completedLatch.await(); assertEquals(1, authCounter.get()); }