diff --git a/src/libraries/System.Net.Security/ref/System.Net.Security.cs b/src/libraries/System.Net.Security/ref/System.Net.Security.cs index 0369a60af9c7a..3ef7ad8ab16fd 100644 --- a/src/libraries/System.Net.Security/ref/System.Net.Security.cs +++ b/src/libraries/System.Net.Security/ref/System.Net.Security.cs @@ -104,8 +104,15 @@ public enum ProtectionLevel Sign = 1, EncryptAndSign = 2, } + public readonly struct SslClientHelloInfo + { + public string ServerName { get; } + public System.Security.Authentication.SslProtocols SslProtocols { get; } + internal SslClientHelloInfo(string serverName, System.Security.Authentication.SslProtocols sslProtocol) { throw null; } + } public delegate bool RemoteCertificateValidationCallback(object sender, System.Security.Cryptography.X509Certificates.X509Certificate? certificate, System.Security.Cryptography.X509Certificates.X509Chain? chain, System.Net.Security.SslPolicyErrors sslPolicyErrors); public delegate System.Security.Cryptography.X509Certificates.X509Certificate ServerCertificateSelectionCallback(object sender, string? hostName); + public delegate System.Threading.Tasks.ValueTask ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, System.Threading.CancellationToken cancellationToken); public readonly partial struct SslApplicationProtocol : System.IEquatable { private readonly object _dummy; @@ -236,6 +243,7 @@ public void Write(byte[] buffer) { } public override void Write(byte[] buffer, int offset, int count) { } public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, System.Threading.CancellationToken cancellationToken = default) { throw null; } } [System.CLSCompliantAttribute(false)] public enum TlsCipherSuite : ushort diff --git a/src/libraries/System.Net.Security/src/System.Net.Security.csproj b/src/libraries/System.Net.Security/src/System.Net.Security.csproj index e0edf84a0faa3..30737df95bbcc 100644 --- a/src/libraries/System.Net.Security/src/System.Net.Security.csproj +++ b/src/libraries/System.Net.Security/src/System.Net.Security.csproj @@ -23,6 +23,7 @@ + diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs index 311346fab0f59..e229ddd0e72e2 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslAuthenticationOptions.cs @@ -80,6 +80,36 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen } } + internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state) + { + CheckCertName = false; + TargetHost = string.Empty; + IsServer = true; + UserState = state; + ServerOptionDelegate = optionCallback; + } + + internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions) + { + AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation; + ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols; + EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols); + EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy; + RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired; + CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy; + CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode; + if (sslServerAuthenticationOptions.ServerCertificateContext != null) + { + CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext; + } + else if (sslServerAuthenticationOptions.ServerCertificate is X509Certificate2 certificateWithKey && + certificateWithKey.HasPrivateKey) + { + // given cert is X509Certificate2 with key. We can use it directly. + CertificateContext = SslStreamCertificateContext.Create(certificateWithKey); + } + } + private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols protocols) { if (protocols.HasFlag(SslProtocols.Tls12) || protocols.HasFlag(SslProtocols.Tls13)) @@ -98,7 +128,7 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto internal bool AllowRenegotiation { get; set; } internal string TargetHost { get; set; } internal X509CertificateCollection? ClientCertificates { get; set; } - internal List? ApplicationProtocols { get; } + internal List? ApplicationProtocols { get; set; } internal bool IsServer { get; set; } internal SslStreamCertificateContext? CertificateContext { get; set; } internal SslProtocols EnabledSslProtocols { get; set; } @@ -110,5 +140,7 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto internal LocalCertSelectionCallback? CertSelectionDelegate { get; set; } internal ServerCertSelectionCallback? ServerCertSelectionDelegate { get; set; } internal CipherSuitesPolicy? CipherSuitesPolicy { get; set; } + internal object? UserState { get; } + internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslClientHelloInfo.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslClientHelloInfo.cs new file mode 100644 index 0000000000000..16a19935c206f --- /dev/null +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslClientHelloInfo.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Security.Authentication; + +namespace System.Net.Security +{ + /// + /// This struct contains information from received TLS Client Hello frame. + /// + public readonly struct SslClientHelloInfo + { + public readonly string ServerName { get; } + public readonly SslProtocols SslProtocols { get; } + + internal SslClientHelloInfo(string serverName, SslProtocols sslProtocols) + { + ServerName = serverName; + SslProtocols = sslProtocols; + } + } +} diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 95644169dd15a..afda3eb1de5e7 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -35,7 +35,7 @@ private enum Framing private TlsFrameHelper.TlsFrameInfo _lastFrame; - private readonly object _handshakeLock = new object(); + private object _handshakeLock => _sslAuthenticationOptions!; private volatile TaskCompletionSource? _handshakeWaiter; private const int FrameOverhead = 32; @@ -403,9 +403,9 @@ private async ValueTask ReceiveBlobAsync(TIOAdapter a } else if (_lastFrame.Header.Type == TlsContentType.Handshake) { - if ((_handshakeBuffer.ActiveReadOnlySpan[TlsFrameHelper.HeaderSize] == (byte)TlsHandshakeType.ClientHello && - _sslAuthenticationOptions!.ServerCertSelectionDelegate != null) || - NetEventSource.Log.IsEnabled()) + if (_handshakeBuffer.ActiveReadOnlySpan[TlsFrameHelper.HeaderSize] == (byte)TlsHandshakeType.ClientHello && + (_sslAuthenticationOptions!.ServerCertSelectionDelegate != null || + _sslAuthenticationOptions!.ServerOptionDelegate != null)) { TlsFrameHelper.ProcessingOptions options = NetEventSource.Log.IsEnabled() ? TlsFrameHelper.ProcessingOptions.All : @@ -421,6 +421,14 @@ private async ValueTask ReceiveBlobAsync(TIOAdapter a { // SNI if it exist. Even if we could not parse the hello, we can fall-back to default certificate. _sslAuthenticationOptions!.TargetHost = _lastFrame.TargetName; + + if (_sslAuthenticationOptions.ServerOptionDelegate != null) + { + SslServerAuthenticationOptions userOptions = + await _sslAuthenticationOptions.ServerOptionDelegate(this, new SslClientHelloInfo(_lastFrame.TargetName, _lastFrame.SupportedVersions), + _sslAuthenticationOptions.UserState, adapter.CancellationToken).ConfigureAwait(false); + _sslAuthenticationOptions.UpdateOptions(userOptions); + } } if (NetEventSource.Log.IsEnabled()) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index e4817767dc2ac..523935e3533bb 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -34,6 +34,8 @@ public enum EncryptionPolicy public delegate X509Certificate ServerCertificateSelectionCallback(object sender, string? hostName); + public delegate ValueTask ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, CancellationToken cancellationToken); + // Internal versions of the above delegates. internal delegate bool RemoteCertValidationCallback(string? host, X509Certificate2? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors); internal delegate X509Certificate LocalCertSelectionCallback(string targetHost, X509CertificateCollection localCertificates, X509Certificate2? remoteCertificate, string[] acceptableIssuers); @@ -453,6 +455,12 @@ private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAut return ProcessAuthentication(true, true, cancellationToken)!; } + public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, CancellationToken cancellationToken = default) + { + ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state)); + return ProcessAuthentication(isAsync: true, isApm: false, cancellationToken)!; + } + public virtual Task ShutdownAsync() { ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs index 3425fd4f7a0f4..1518cfc3fa545 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs @@ -7,6 +7,7 @@ using System.Net.Test.Common; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -22,6 +23,8 @@ public class ServerAsyncAuthenticateTest : IDisposable private readonly ITestOutputHelper _logVerbose; private readonly X509Certificate2 _serverCertificate; + public static bool IsNotWindows7 => !PlatformDetection.IsWindows7; + public ServerAsyncAuthenticateTest(ITestOutputHelper output) { _log = output; @@ -69,6 +72,159 @@ public async Task ServerAsyncAuthenticate_AllClientVsIndividualServerSupportedPr await ServerAsyncSslHelper(SslProtocolSupport.SupportedSslProtocols, serverProtocol); } + [Fact] + public async Task ServerAsyncAuthenticate_SimpleSniOptions_Success() + { + var state = new object(); + var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate }; + var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) }; + clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true; + + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(); + using (client) + using (server) + { + Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); + Task t2 = server.AuthenticateAsServerAsync( + (stream, clientHelloInfo, userState, cancellationToken) => + { + Assert.Equal(server, stream); + Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName); + Assert.True(object.ReferenceEquals(state, userState)); + return new ValueTask(serverOptions); + }, + state, CancellationToken.None); + + await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); + } + } + + [ConditionalTheory(nameof(IsNotWindows7))] + [InlineData(SslProtocols.Tls11)] + [InlineData(SslProtocols.Tls12)] + public async Task ServerAsyncAuthenticate_SniSetVersion_Success(SslProtocols version) + { + var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate, EnabledSslProtocols = version }; + var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, forIssuer: false) }; + clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true; + + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(); + using (client) + using (server) + { + Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); + Task t2 = server.AuthenticateAsServerAsync( + (stream, clientHelloInfo, userState, cancellationToken) => + { + Assert.Equal(server, stream); + Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName); + return new ValueTask(serverOptions); + }, + null, CancellationToken.None); + + await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); + // Verify that the SNI callback can impact version. + Assert.Equal(version, client.SslProtocol); + } + } + + private async Task FailedTask() + { + await Task.Yield(); + throw new InvalidOperationException("foo"); + } + + private async Task OptionsTask(SslServerAuthenticationOptions value) + { + await Task.Yield(); + return value; + } + + [Fact] + public async Task ServerAsyncAuthenticate_AsyncOptions_Success() + { + var state = new object(); + var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate }; + var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) }; + clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true; + + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(); + using (client) + using (server) + { + Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); + Task t2 = server.AuthenticateAsServerAsync( + (stream, clientHelloInfo, userState, cancellationToken) => + { + Assert.Equal(server, stream); + Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName); + Assert.True(object.ReferenceEquals(state, userState)); + return new ValueTask(OptionsTask(serverOptions)); + }, + state, CancellationToken.None); + + await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ServerAsyncAuthenticate_FailingOptionCallback_Throws(bool useAsync) + { + var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate }; + var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) }; + clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true; + + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(); + using (client) + using (server) + { + Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); + Task t2 = server.AuthenticateAsServerAsync( + (stream, clientHelloInfo, userState, cancellationToken) => + { + if (useAsync) + { + return new ValueTask(FailedTask()); + } + + throw new InvalidOperationException("foo"); + }, + null, CancellationToken.None); + await Assert.ThrowsAsync(() => t2); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ServerAsyncAuthenticate_NoCertificate_Throws(bool useAsync) + { + var serverOptions = new SslServerAuthenticationOptions(); + var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) }; + clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true; + + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(); + using (client) + using (server) + { + Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); + Task t2 = server.AuthenticateAsServerAsync( + (stream, clientHelloInfo, userState, cancellationToken) => + { + if (useAsync) + { + return new ValueTask(serverOptions); + } + + return new ValueTask(OptionsTask(serverOptions)); + }, + null, CancellationToken.None); + await Assert.ThrowsAsync(() => t2); + } + } + public static IEnumerable ProtocolMismatchData() { if (PlatformDetection.SupportsSsl3) diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs index 9de9618267131..b89ecd664da32 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs @@ -29,6 +29,12 @@ public static class TestHelper private static readonly X509BasicConstraintsExtension s_eeConstraints = new X509BasicConstraintsExtension(false, false, 0, false); + public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams() + { + (Stream clientStream, Stream serverStream) = GetConnectedStreams(); + return (new SslStream(clientStream), new SslStream(serverStream)); + } + public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams() { if (Capability.SecurityForceSocketStreams())