Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid allocating collection for intermediate certificates #68188

Merged
merged 4 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ namespace System.Net
{
internal static partial class UnmanagedCertificateContext
{
internal static unsafe X509Certificate2Collection GetRemoteCertificatesFromStoreContext(IntPtr certContext)
internal static unsafe void GetRemoteCertificatesFromStoreContext(IntPtr certContext, X509Certificate2Collection result)
{
X509Certificate2Collection result = new X509Certificate2Collection();

if (certContext == IntPtr.Zero)
{
return result;
return;
}

Interop.Crypt32.CERT_CONTEXT context = *(Interop.Crypt32.CERT_CONTEXT*)certContext;
Expand Down Expand Up @@ -46,8 +44,6 @@ internal static unsafe X509Certificate2Collection GetRemoteCertificatesFromStore
last = next;
}
}

return result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ namespace System.Net
{
internal static partial class UnmanagedCertificateContext
{
internal static X509Certificate2Collection GetRemoteCertificatesFromStoreContext(SafeFreeCertContext certContext)
internal static void GetRemoteCertificatesFromStoreContext(SafeFreeCertContext certContext, X509Certificate2Collection collection)
{
if (certContext.IsInvalid)
{
return new X509Certificate2Collection();
return;
}

return GetRemoteCertificatesFromStoreContext(certContext.DangerousGetHandle());
GetRemoteCertificatesFromStoreContext(certContext.DangerousGetHandle(), collection);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ private static void OnRequestSendingRequest(WinHttpRequestState state)
throw WinHttpException.CreateExceptionUsingError(lastError, "WINHTTP_CALLBACK_STATUS_SENDING_REQUEST/WinHttpQueryOption");
}

// Get any additional certificates sent from the remote server during the TLS/SSL handshake.
X509Certificate2Collection remoteCertificateStore =
UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(certHandle);
// Get any additional certificates sent from the remote server during the TLS/SSL handshake.
X509Certificate2Collection remoteCertificateStore = new X509Certificate2Collection();
UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(certHandle, remoteCertificateStore);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(certHandle, remoteCertificateStore);
UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(certHandle, remoteCertificateStore);


// Create a managed wrapper around the certificate handle. Since this results in duplicating
// the handle, we will close the original handle after creating the wrapper.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,38 +40,18 @@ internal static SslPolicyErrors VerifyCertificateProperties(
//
// Extracts a remote certificate upon request.
//
internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext securityContext)
{
return GetRemoteCertificate(securityContext, null);
}

internal static X509Certificate2? GetRemoteCertificate(
SafeDeleteContext? securityContext,
out X509Certificate2Collection? remoteCertificateStore)
{
if (securityContext == null)
{
remoteCertificateStore = null;
return null;
}

remoteCertificateStore = new X509Certificate2Collection();
return GetRemoteCertificate(securityContext, remoteCertificateStore);
}

private static X509Certificate2? GetRemoteCertificate(
SafeDeleteContext securityContext,
X509Certificate2Collection? remoteCertificateStore)
bool retrieveChainCertificates,
ref X509Chain? chain)
{
if (securityContext == null)
return null;

SafeSslHandle sslContext = ((SafeDeleteSslContext)securityContext).SslContext;
if (sslContext == null)
return null;

X509Certificate2? cert = null;
if (remoteCertificateStore == null)
if (!retrieveChainCertificates)
{
// Constructing a new X509Certificate2 adds a global reference to the pointer, so we dispose this handle
using (SafeX509Handle handle = Interop.AndroidCrypto.SSLStreamGetPeerCertificate(sslContext))
Expand All @@ -84,6 +64,7 @@ internal static SslPolicyErrors VerifyCertificateProperties(
}
else
{
chain ??= new X509Chain();
IntPtr[]? ptrs = Interop.AndroidCrypto.SSLStreamGetPeerCertificates(sslContext);
if (ptrs != null && ptrs.Length > 0)
{
Expand All @@ -95,7 +76,7 @@ internal static SslPolicyErrors VerifyCertificateProperties(
// Constructing a new X509Certificate2 adds a global reference to the pointer, so we dispose this handle
using (var handle = new SafeX509Handle(ptr))
{
remoteCertificateStore.Add(new X509Certificate2(handle.DangerousGetHandle()));
chain.ChainPolicy.ExtraStore.Add(new X509Certificate2(handle.DangerousGetHandle()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,10 @@ internal static SslPolicyErrors VerifyCertificateProperties(
return errors;
}

//
// Extracts a remote certificate upon request.
//
internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext securityContext)
{
return GetRemoteCertificate(securityContext, null);
}

internal static X509Certificate2? GetRemoteCertificate(
SafeDeleteContext? securityContext,
out X509Certificate2Collection? remoteCertificateStore)
{
if (securityContext == null)
{
remoteCertificateStore = null;
return null;
}

remoteCertificateStore = new X509Certificate2Collection();
return GetRemoteCertificate(securityContext, remoteCertificateStore);
}

private static X509Certificate2? GetRemoteCertificate(
SafeDeleteContext securityContext,
X509Certificate2Collection? remoteCertificateStore)
bool retrieveChainCertificates,
ref X509Chain? chain)
{
if (securityContext == null)
{
Expand All @@ -91,12 +70,14 @@ internal static SslPolicyErrors VerifyCertificateProperties(
{
long chainSize = Interop.AppleCrypto.X509ChainGetChainSize(chainHandle);

if (remoteCertificateStore != null)
if (retrieveChainCertificates)
{
chain ??= new X509Chain();

for (int i = 0; i < chainSize; i++)
{
IntPtr certHandle = Interop.AppleCrypto.X509ChainGetCertificateAtIndex(chainHandle, i);
remoteCertificateStore.Add(new X509Certificate2(certHandle));
chain.ChainPolicy.ExtraStore.Add(new X509Certificate2(certHandle));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,7 @@ internal static SslPolicyErrors VerifyCertificateProperties(
//
// Extracts a remote certificate upon request.
//
internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext securityContext)
{
return GetRemoteCertificate(securityContext, null);
}

internal static X509Certificate2? GetRemoteCertificate(
SafeDeleteContext? securityContext,
out X509Certificate2Collection? remoteCertificateStore)
{
if (securityContext == null)
{
remoteCertificateStore = null;
return null;
}

remoteCertificateStore = new X509Certificate2Collection();
return GetRemoteCertificate(securityContext, remoteCertificateStore);
}

private static X509Certificate2? GetRemoteCertificate(SafeDeleteContext? securityContext, X509Certificate2Collection? remoteCertificateStore)
private static X509Certificate2? GetRemoteCertificate(SafeDeleteContext? securityContext, bool retrieveChainCertificates, ref X509Chain? chain)
{
bool gotReference = false;

Expand All @@ -64,8 +45,10 @@ internal static SslPolicyErrors VerifyCertificateProperties(
result = new X509Certificate2(remoteContext.DangerousGetHandle());
}

if (remoteCertificateStore != null)
if (retrieveChainCertificates)
{
chain ??= new X509Chain();

using (SafeSharedX509StackHandle chainStack =
Interop.OpenSsl.GetPeerCertificateChain(((SafeDeleteSslContext)securityContext).SslContext))
{
Expand All @@ -81,7 +64,7 @@ internal static SslPolicyErrors VerifyCertificateProperties(
{
// X509Certificate2(IntPtr) calls X509_dup, so the reference is appropriately tracked.
X509Certificate2 chainCert = new X509Certificate2(certPtr);
remoteCertificateStore.Add(chainCert);
chain.ChainPolicy.ExtraStore.Add(chainCert);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,9 @@ internal static SslPolicyErrors VerifyCertificateProperties(
// Extracts a remote certificate upon request.
//

internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext? securityContext) =>
GetRemoteCertificate(securityContext, retrieveCollection: false, out _);

internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext? securityContext, out X509Certificate2Collection? remoteCertificateCollection) =>
GetRemoteCertificate(securityContext, retrieveCollection: true, out remoteCertificateCollection);

private static X509Certificate2? GetRemoteCertificate(
SafeDeleteContext? securityContext, bool retrieveCollection, out X509Certificate2Collection? remoteCertificateCollection)
SafeDeleteContext? securityContext, bool retrieveChainCertificates, ref X509Chain? chain)
{
remoteCertificateCollection = null;

if (securityContext == null)
{
return null;
Expand All @@ -54,7 +46,7 @@ internal static SslPolicyErrors VerifyCertificateProperties(
//
// We can use retrieveCollection to distinguish between in-handshake and after-handshake calls, because
// the collection is retrieved for cert validation purposes after the handshake completes.
if (retrieveCollection) // handshake completed
if (retrieveChainCertificates) // handshake completed
{
SSPIWrapper.QueryContextAttributes_SECPKG_ATTR_REMOTE_CERT_CONTEXT(GlobalSSPI.SSPISecureChannel, securityContext, out remoteContext);
}
Expand All @@ -72,9 +64,11 @@ internal static SslPolicyErrors VerifyCertificateProperties(
{
if (remoteContext != null && !remoteContext.IsInvalid)
{
if (retrieveCollection)
if (retrieveChainCertificates)
{
remoteCertificateCollection = UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(remoteContext);
chain ??= new X509Chain();

UnmanagedCertificateContext.GetRemoteCertificatesFromStoreContext(remoteContext, chain.ChainPolicy.ExtraStore);
}

remoteContext.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Net.Security;
using System.Security;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Runtime.InteropServices;

rzikm marked this conversation as resolved.
Show resolved Hide resolved
namespace System.Net
{
Expand All @@ -14,6 +16,13 @@ internal static partial class CertificateValidationPal

private static volatile X509Store? s_myCertStoreEx;
private static volatile X509Store? s_myMachineCertStoreEx;
private static X509Chain? s_chain;

internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext securityContext) =>
GetRemoteCertificate(securityContext, retrieveChainCertificates: false, ref s_chain);

internal static X509Certificate2? GetRemoteCertificate(SafeDeleteContext securityContext, ref X509Chain? chain) =>
GetRemoteCertificate(securityContext, retrieveChainCertificates: true, ref chain);

rzikm marked this conversation as resolved.
Show resolved Hide resolved
static partial void CheckSupportsStore(StoreLocation storeLocation, ref bool hasSupport);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,12 +993,10 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
// We don't catch exceptions in this method, so it's safe for "accepted" be initialized with true.
bool success = false;
X509Chain? chain = null;
X509Certificate2Collection? remoteCertificateStore = null;

try
{
X509Certificate2? certificate = CertificateValidationPal.GetRemoteCertificate(_securityContext, out remoteCertificateStore);

X509Certificate2? certificate = CertificateValidationPal.GetRemoteCertificate(_securityContext!, ref chain);
if (_remoteCertificate != null && certificate != null &&
certificate.RawDataMemory.Span.SequenceEqual(_remoteCertificate.RawDataMemory.Span))
{
Expand All @@ -1016,18 +1014,17 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
}
else
{
chain = new X509Chain();
if (chain == null)
{
chain = new X509Chain();
}
rzikm marked this conversation as resolved.
Show resolved Hide resolved

chain.ChainPolicy.RevocationMode = _sslAuthenticationOptions.CertificateRevocationCheckMode;
chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;

// Authenticate the remote party: (e.g. when operating in server mode, authenticate the client).
chain.ChainPolicy.ApplicationPolicy.Add(_sslAuthenticationOptions.IsServer ? s_clientAuthOid : s_serverAuthOid);

if (remoteCertificateStore != null)
{
chain.ChainPolicy.ExtraStore.AddRange(remoteCertificateStore);
}

if (trust != null)
{
chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
Expand Down Expand Up @@ -1103,15 +1100,6 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot

chain.Dispose();
}

if (remoteCertificateStore != null)
{
int certCount = remoteCertificateStore.Count;
for (int i = 0; i < certCount; i++)
{
remoteCertificateStore[i].Dispose();
}
}
}

return success;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,26 @@ internal bool Build(X509Certificate2 certificate, bool throwOnException)
if (certificate == null || certificate.Pal == null)
throw new ArgumentException(SR.Cryptography_InvalidContextHandle, nameof(certificate));

if (_chainPolicy != null && _chainPolicy.CustomTrustStore != null)
if (_chainPolicy != null)
{
if (_chainPolicy.TrustMode == X509ChainTrustMode.System && _chainPolicy.CustomTrustStore.Count > 0)
throw new CryptographicException(SR.Cryptography_CustomTrustCertsInSystemMode);

foreach (X509Certificate2 customCertificate in _chainPolicy.CustomTrustStore)
if (_chainPolicy._customTrustStore != null)
{
if (customCertificate == null || customCertificate.Handle == IntPtr.Zero)
if (_chainPolicy.TrustMode == X509ChainTrustMode.System && _chainPolicy.CustomTrustStore.Count > 0)
throw new CryptographicException(SR.Cryptography_CustomTrustCertsInSystemMode);

foreach (X509Certificate2 customCertificate in _chainPolicy.CustomTrustStore)
{
throw new CryptographicException(SR.Cryptography_InvalidTrustCertificate);
if (customCertificate == null || customCertificate.Handle == IntPtr.Zero)
{
throw new CryptographicException(SR.Cryptography_InvalidTrustCertificate);
}
}
}

if (_chainPolicy.TrustMode == X509ChainTrustMode.CustomRootTrust && _chainPolicy._customTrustStore == null)
{
_chainPolicy._customTrustStore = new X509Certificate2Collection();
}
}

Reset();
Expand Down