diff --git a/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java b/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java index c821475a..f35386b5 100644 --- a/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java +++ b/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java @@ -8,18 +8,21 @@ import javax.naming.ServiceUnavailableException; import java.awt.*; -import java.io.IOException; import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.Date; import java.util.concurrent.*; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; -public class AadAuthenticationHelper { +class AadAuthenticationHelper { private final static String DEFAULT_AAD_TENANT = "common"; private final static String CLIENT_ID = "db662dc1-0cfe-4e1c-a843-19a68e65be58"; + final static long MIN_ACCESS_TOKEN_VALIDITY_IN_MILLISECS = 60000; private ClientCredential clientCredential; private String userUsername; @@ -30,6 +33,9 @@ public class AadAuthenticationHelper { private PrivateKey privateKey; private AuthenticationType authenticationType; private String accessToken; + private AuthenticationResult lastAuthenticationResult; + private Lock lastAuthenticationResultLock = new ReentrantLock(); + private String applicationClientId; private enum AuthenticationType { AAD_USERNAME_PASSWORD, @@ -39,7 +45,8 @@ private enum AuthenticationType { AAD_ACCESS_TOKEN } - public AadAuthenticationHelper(@NotNull ConnectionStringBuilder csb) throws URISyntaxException { + AadAuthenticationHelper(@NotNull ConnectionStringBuilder csb) throws URISyntaxException { + URI clusterUri = new URI(csb.getClusterUrl()); clusterUrl = String.format("%s://%s", clusterUri.getScheme(), clusterUri.getHost()); if (StringUtils.isNotEmpty(csb.getApplicationClientId()) && StringUtils.isNotEmpty(csb.getApplicationKey())) { @@ -52,7 +59,7 @@ public AadAuthenticationHelper(@NotNull ConnectionStringBuilder csb) throws URIS } else if (csb.getX509Certificate() != null && csb.getPrivateKey() != null) { x509Certificate = csb.getX509Certificate(); privateKey = csb.getPrivateKey(); - clientCredential = new ClientCredential(csb.getApplicationClientId(), null); + applicationClientId = csb.getApplicationClientId(); authenticationType = AuthenticationType.AAD_APPLICATION_CERTIFICATE; } else if (StringUtils.isNotBlank(csb.getAccessToken())) { authenticationType = AuthenticationType.AAD_ACCESS_TOKEN; @@ -67,24 +74,25 @@ public AadAuthenticationHelper(@NotNull ConnectionStringBuilder csb) throws URIS } String acquireAccessToken() throws DataServiceException { - try { - switch (authenticationType) { - case AAD_APPLICATION_KEY: - return acquireAadApplicationAccessToken().getAccessToken(); - case AAD_USERNAME_PASSWORD: - return acquireAadUserAccessToken().getAccessToken(); - case AAD_DEVICE_LOGIN: - return acquireAccessTokenUsingDeviceCodeFlow().getAccessToken(); - case AAD_APPLICATION_CERTIFICATE: - return acquireWithClientCertificate().getAccessToken(); - case AAD_ACCESS_TOKEN: - return accessToken; - default: - throw new DataServiceException("Authentication type: " + authenticationType.name() + " is invalid"); + if (authenticationType == AuthenticationType.AAD_ACCESS_TOKEN) { + return accessToken; + } + + if (lastAuthenticationResult == null) { + acquireToken(); + } else if (isTokenExpired()) { + if (lastAuthenticationResult.getRefreshToken() == null) { + acquireToken(); + } else { + lastAuthenticationResultLock.lock(); + if (isTokenExpired()) { + lastAuthenticationResult = acquireAccessTokenByRefreshToken(); + } + lastAuthenticationResultLock.unlock(); } - } catch (Exception e) { - throw new DataServiceException(e.getMessage()); } + + return lastAuthenticationResult.getAccessToken(); } private AuthenticationResult acquireAadUserAccessToken() throws DataServiceException, DataClientException { @@ -180,19 +188,21 @@ private AuthenticationResult waitAndAcquireTokenByDeviceCode(DeviceCode deviceCo } AuthenticationResult acquireWithClientCertificate() - throws IOException, InterruptedException, ExecutionException, ServiceUnavailableException { + throws InterruptedException, ExecutionException, ServiceUnavailableException { AuthenticationContext context; - AuthenticationResult result; + AuthenticationResult result = null; ExecutorService service = null; try { service = Executors.newSingleThreadExecutor(); context = new AuthenticationContext(aadAuthorityUri, false, service); - AsymmetricKeyCredential asymmetricKeyCredential = AsymmetricKeyCredential.create(clientCredential.getClientId(), + AsymmetricKeyCredential asymmetricKeyCredential = AsymmetricKeyCredential.create(applicationClientId, privateKey, x509Certificate); // pass null value for optional callback function and acquire access token result = context.acquireToken(clusterUrl, asymmetricKeyCredential, null).get(); + } catch (MalformedURLException e) { + e.printStackTrace(); } finally { if (service != null) { service.shutdown(); @@ -204,4 +214,64 @@ AuthenticationResult acquireWithClientCertificate() return result; } + private void acquireToken() throws DataServiceException { + lastAuthenticationResultLock.lock(); + if (lastAuthenticationResult == null || isTokenExpired()) { + try { + switch (authenticationType) { + case AAD_APPLICATION_KEY: + lastAuthenticationResult = acquireAadApplicationAccessToken(); + break; + case AAD_USERNAME_PASSWORD: + lastAuthenticationResult = acquireAadUserAccessToken(); + break; + case AAD_DEVICE_LOGIN: + lastAuthenticationResult = acquireAccessTokenUsingDeviceCodeFlow(); + break; + case AAD_APPLICATION_CERTIFICATE: + lastAuthenticationResult = acquireWithClientCertificate(); + break; + default: + throw new DataServiceException("Authentication type: " + authenticationType.name() + " is invalid"); + } + } catch (Exception e) { + throw new DataServiceException(e.getMessage()); + } + } + lastAuthenticationResultLock.unlock(); + } + + private boolean isTokenExpired() { + return lastAuthenticationResult.getExpiresOnDate().before(dateInAMinute()); + } + + AuthenticationResult acquireAccessTokenByRefreshToken() throws DataServiceException { + AuthenticationContext context; + ExecutorService service = null; + + try { + service = Executors.newSingleThreadExecutor(); + context = new AuthenticationContext(aadAuthorityUri, false, service); + switch (authenticationType) { + case AAD_APPLICATION_KEY: + case AAD_APPLICATION_CERTIFICATE: + return context.acquireTokenByRefreshToken(lastAuthenticationResult.getRefreshToken(), clientCredential, null).get(); + case AAD_USERNAME_PASSWORD: + case AAD_DEVICE_LOGIN: + return context.acquireTokenByRefreshToken(lastAuthenticationResult.getRefreshToken(), CLIENT_ID, clusterUrl, null).get(); + default: + throw new DataServiceException("Authentication type: " + authenticationType.name() + " is invalid"); + } + } catch (Exception e) { + throw new DataServiceException(e.getMessage()); + } finally { + if (service != null) { + service.shutdown(); + } + } + } + + Date dateInAMinute() { + return new Date(System.currentTimeMillis() + MIN_ACCESS_TOKEN_VALIDITY_IN_MILLISECS); + } } diff --git a/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java b/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java index 9ea407e7..bb1a791c 100644 --- a/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java +++ b/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java @@ -1,6 +1,9 @@ package com.microsoft.azure.kusto.data; +import com.microsoft.aad.adal4j.AuthenticationResult; +import com.microsoft.aad.adal4j.UserInfo; +import com.microsoft.azure.kusto.data.exceptions.DataServiceException; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; import org.bouncycastle.cert.X509CertificateHolder; import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; @@ -25,17 +28,27 @@ import java.security.Security; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.Date; import java.util.concurrent.ExecutionException; +import javax.naming.ServiceUnavailableException; + +import static com.microsoft.azure.kusto.data.AadAuthenticationHelper.MIN_ACCESS_TOKEN_VALIDITY_IN_MILLISECS; +import static org.mockito.Mockito.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; + public class AadAuthenticationHelperTest { + @Test @DisplayName("validate auth with certificate throws exception when missing or invalid parameters") void acquireWithClientCertificateNullKey() throws CertificateException, OperatorCreationException, PKCSException, IOException, URISyntaxException { - String certFilePath = Paths.get("src","test","resources", "cert.cer").toString(); - String privateKeyPath = Paths.get("src","test","resources","key.pem").toString(); + String certFilePath = Paths.get("src", "test", "resources", "cert.cer").toString(); + String privateKeyPath = Paths.get("src", "test", "resources", "key.pem").toString(); X509Certificate x509Certificate = readPem(certFilePath, "basic").getCertificate(); PrivateKey privateKey = readPem(privateKeyPath, "basic").getKey(); @@ -47,7 +60,6 @@ void acquireWithClientCertificateNullKey() throws CertificateException, Operator Assertions.assertThrows(ExecutionException.class, () -> aadAuthenticationHelper.acquireWithClientCertificate()); - } static KeyCert readPem(String path, String password) @@ -83,4 +95,39 @@ static KeyCert readPem(String path, String password) return keycert; } + @Test + @DisplayName("validate cached token. Refresh if needed. Call regularly if no refresh token") + void useCachedTokenAndRefreshWhenNeeded() throws InterruptedException, ExecutionException, ServiceUnavailableException, IOException, DataServiceException, URISyntaxException, CertificateException, OperatorCreationException, PKCSException { + String certFilePath = Paths.get("src", "test", "resources", "cert.cer").toString(); + String privateKeyPath = Paths.get("src", "test", "resources", "key.pem").toString(); + + X509Certificate x509Certificate = readPem(certFilePath, "basic").getCertificate(); + PrivateKey privateKey = readPem(privateKeyPath, "basic").getKey(); + + ConnectionStringBuilder csb = ConnectionStringBuilder + .createWithAadApplicationCertificate("resource.uri", "client-id", x509Certificate, privateKey); + + AadAuthenticationHelper aadAuthenticationHelperSpy = spy(new AadAuthenticationHelper(csb)); + + AuthenticationResult authenticationResult = new AuthenticationResult("testType", "firstToken", "refreshToken", 0, "id", mock(UserInfo.class), false); + AuthenticationResult authenticationResultFromRefresh = new AuthenticationResult("testType", "fromRefresh", null, 90, "id", mock(UserInfo.class), false); + AuthenticationResult authenticationResultNullRefreshTokenResult = new AuthenticationResult("testType", "nullRefreshResult", null, 0, "id", mock(UserInfo.class), false); + + doReturn(authenticationResultFromRefresh).when(aadAuthenticationHelperSpy).acquireAccessTokenByRefreshToken(); + doReturn(authenticationResult).when(aadAuthenticationHelperSpy).acquireWithClientCertificate(); + + assertEquals("firstToken", aadAuthenticationHelperSpy.acquireAccessToken()); + + // Token was passed as expired - expected to be refreshed + assertEquals("fromRefresh", aadAuthenticationHelperSpy.acquireAccessToken()); + + // Token is still valid - expected to return the same + assertEquals("fromRefresh", aadAuthenticationHelperSpy.acquireAccessToken()); + + doReturn(new Date(System.currentTimeMillis() + MIN_ACCESS_TOKEN_VALIDITY_IN_MILLISECS * 2)).when(aadAuthenticationHelperSpy).dateInAMinute(); + doReturn(authenticationResultNullRefreshTokenResult).when(aadAuthenticationHelperSpy).acquireWithClientCertificate(); + + // Null refresh token + token is now expired- expected to authenticate again and reacquire token + assertEquals("nullRefreshResult", aadAuthenticationHelperSpy.acquireAccessToken()); + } } diff --git a/ingest/src/main/java/com/microsoft/azure/kusto/ingest/IngestClientFactory.java b/ingest/src/main/java/com/microsoft/azure/kusto/ingest/IngestClientFactory.java index 787f9edf..2fbc1e08 100644 --- a/ingest/src/main/java/com/microsoft/azure/kusto/ingest/IngestClientFactory.java +++ b/ingest/src/main/java/com/microsoft/azure/kusto/ingest/IngestClientFactory.java @@ -2,11 +2,12 @@ import com.microsoft.azure.kusto.data.ConnectionStringBuilder; +import java.net.MalformedURLException; import java.net.URISyntaxException; public class IngestClientFactory { - public static IngestClient createClient(ConnectionStringBuilder csb) throws URISyntaxException { + public static IngestClient createClient(ConnectionStringBuilder csb) throws URISyntaxException, MalformedURLException { return new IngestClientImpl(csb); } }