Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientCertificateAzureServiceTokenProvider dispose of certificate obj… #17266

Merged
1 commit merged into from
Dec 3, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
Expand Down Expand Up @@ -128,76 +129,95 @@ public override async Task<AppAuthenticationResult> GetAuthResultAsync(string re
}

List<X509Certificate2> certs = null;
switch (_certificateIdentifierType)
{
case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier:
// Get certificate for the given Key Vault secret identifier
try
{
var keyVaultCert = await _keyVaultClient.GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false);
certs = new List<X509Certificate2>() { keyVaultCert };
Dictionary<string, string> exceptionDictionary = new Dictionary<string, string>();

// If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token.
if (string.IsNullOrWhiteSpace(authority))
try
{
switch (_certificateIdentifierType)
{
case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier:
// Get certificate for the given Key Vault secret identifier
try
{
_tenantId = _keyVaultClient.PrincipalUsed.TenantId;
authority = $"{_azureAdInstance}{_tenantId}";
var keyVaultCert = await _keyVaultClient
.GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false);
certs = new List<X509Certificate2>() { keyVaultCert };

// If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token.
if (string.IsNullOrWhiteSpace(authority))
{
_tenantId = _keyVaultClient.PrincipalUsed.TenantId;
authority = $"{_azureAdInstance}{_tenantId}";
}
}
}
catch (Exception exp)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
$"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}");
}
break;
case CertificateIdentifierType.SubjectName:
case CertificateIdentifierType.Thumbprint:
// Get certificates for the given thumbprint or subject name.
bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint;
certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint,
_storeLocation);

if (certs == null || certs.Count == 0)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
AzureServiceTokenProviderException.LocalCertificateNotFound);
}
break;
}

// If multiple certs were found, use in order of most recently created.
// This helps if old cert is rolled over, but not removed.
certs = certs.OrderByDescending(p => p.NotBefore).ToList();
catch (Exception exp)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
$"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}");
}
break;
case CertificateIdentifierType.SubjectName:
case CertificateIdentifierType.Thumbprint:
// Get certificates for the given thumbprint or subject name.
bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint;
certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint,
_storeLocation);

if (certs == null || certs.Count == 0)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
AzureServiceTokenProviderException.LocalCertificateNotFound);
}
break;
}

// To hold reason why token could not be acquired per cert tried.
Dictionary<string, string> exceptionDictionary = new Dictionary<string, string>();
Debug.Assert(certs != null, "Probably wrong certificateIdentifierType was used to instantiate this class!");

foreach (X509Certificate2 cert in certs)
{
if (!string.IsNullOrEmpty(cert.Thumbprint))
// If multiple certs were found, use in order of most recently created.
// This helps if old cert is rolled over, but not removed.
// To hold reason why token could not be acquired per cert tried.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment was for the exception dictionary that was moved up

foreach (X509Certificate2 cert in certs.OrderByDescending(p => p.NotBefore))
{
try
if (!string.IsNullOrEmpty(cert.Thumbprint))
{
ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert);
try
{
ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert);

var authResult =
await _authenticationContext.AcquireTokenAsync(authority, resource, certCred).ConfigureAwait(false);
var authResult =
await _authenticationContext.AcquireTokenAsync(authority, resource, certCred)
.ConfigureAwait(false);

var accessToken = authResult?.AccessToken;
var accessToken = authResult?.AccessToken;

if (accessToken != null)
{
PrincipalUsed.CertificateThumbprint = cert.Thumbprint;
PrincipalUsed.IsAuthenticated = true;
PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId;
if (accessToken != null)
{
PrincipalUsed.CertificateThumbprint = cert.Thumbprint;
PrincipalUsed.IsAuthenticated = true;
PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId;

return authResult;
return authResult;
}
}
catch (Exception exp)
{
// If token cannot be acquired using a cert, try the next one
exceptionDictionary[cert.Thumbprint] = exp.Message;
}
}
catch (Exception exp)
}
}
finally
{
if (certs != null)
{
foreach (var cert in certs)
{
// If token cannot be acquired using a cert, try the next one
exceptionDictionary[cert.Thumbprint] = exp.Message;
#if net452
cert.Reset();
#else
cert.Dispose();
#endif
}
}
}
Expand Down