Skip to content

Commit

Permalink
add ServerOptionsSelectionCallback to SslStream (#38760)
Browse files Browse the repository at this point in the history
* add ServerOptionsSelectionCallback to SslStream

* remove System.IO

* add missing file

* more tests

* feedback from review

* feedback from review

* skip SniSetVersion on win7

* fix IsNotWindows7

Co-authored-by: Tomas Weinfurt <furt@DESKTOP-SUKDQFN.corp.microsoft.com>
  • Loading branch information
wfurt and Tomas Weinfurt authored Jul 14, 2020
1 parent 0fecc40 commit 328d0cf
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/libraries/System.Net.Security/ref/System.Net.Security.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,15 @@ public enum ProtectionLevel
Sign = 1,
EncryptAndSign = 2,
}
public readonly struct SslClientHelloInfo
{
public string ServerName { get; }
public System.Security.Authentication.SslProtocols SslProtocols { get; }
internal SslClientHelloInfo(string serverName, System.Security.Authentication.SslProtocols sslProtocol) { throw null; }
}
public delegate bool RemoteCertificateValidationCallback(object sender, System.Security.Cryptography.X509Certificates.X509Certificate? certificate, System.Security.Cryptography.X509Certificates.X509Chain? chain, System.Net.Security.SslPolicyErrors sslPolicyErrors);
public delegate System.Security.Cryptography.X509Certificates.X509Certificate ServerCertificateSelectionCallback(object sender, string? hostName);
public delegate System.Threading.Tasks.ValueTask<SslServerAuthenticationOptions> ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, System.Threading.CancellationToken cancellationToken);
public readonly partial struct SslApplicationProtocol : System.IEquatable<System.Net.Security.SslApplicationProtocol>
{
private readonly object _dummy;
Expand Down Expand Up @@ -236,6 +243,7 @@ public void Write(byte[] buffer) { }
public override void Write(byte[] buffer, int offset, int count) { }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, System.Threading.CancellationToken cancellationToken = default) { throw null; }
}
[System.CLSCompliantAttribute(false)]
public enum TlsCipherSuite : ushort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<Compile Include="System\Net\Security\SslApplicationProtocol.cs" />
<Compile Include="System\Net\Security\SslAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientHelloInfo.cs" />
<Compile Include="System\Net\Security\SslServerAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SecureChannel.cs" />
<Compile Include="System\Net\Security\SslSessionsCache.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen
}
}

internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state)
{
CheckCertName = false;
TargetHost = string.Empty;
IsServer = true;
UserState = state;
ServerOptionDelegate = optionCallback;
}

internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
{
AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols;
EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols);
EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy;
RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired;
CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy;
CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode;
if (sslServerAuthenticationOptions.ServerCertificateContext != null)
{
CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext;
}
else if (sslServerAuthenticationOptions.ServerCertificate is X509Certificate2 certificateWithKey &&
certificateWithKey.HasPrivateKey)
{
// given cert is X509Certificate2 with key. We can use it directly.
CertificateContext = SslStreamCertificateContext.Create(certificateWithKey);
}
}

private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols protocols)
{
if (protocols.HasFlag(SslProtocols.Tls12) || protocols.HasFlag(SslProtocols.Tls13))
Expand All @@ -98,7 +128,7 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto
internal bool AllowRenegotiation { get; set; }
internal string TargetHost { get; set; }
internal X509CertificateCollection? ClientCertificates { get; set; }
internal List<SslApplicationProtocol>? ApplicationProtocols { get; }
internal List<SslApplicationProtocol>? ApplicationProtocols { get; set; }
internal bool IsServer { get; set; }
internal SslStreamCertificateContext? CertificateContext { get; set; }
internal SslProtocols EnabledSslProtocols { get; set; }
Expand All @@ -110,5 +140,7 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto
internal LocalCertSelectionCallback? CertSelectionDelegate { get; set; }
internal ServerCertSelectionCallback? ServerCertSelectionDelegate { get; set; }
internal CipherSuitesPolicy? CipherSuitesPolicy { get; set; }
internal object? UserState { get; }
internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Security.Authentication;

namespace System.Net.Security
{
/// <summary>
/// This struct contains information from received TLS Client Hello frame.
/// </summary>
public readonly struct SslClientHelloInfo
{
public readonly string ServerName { get; }
public readonly SslProtocols SslProtocols { get; }

internal SslClientHelloInfo(string serverName, SslProtocols sslProtocols)
{
ServerName = serverName;
SslProtocols = sslProtocols;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private enum Framing

private TlsFrameHelper.TlsFrameInfo _lastFrame;

private readonly object _handshakeLock = new object();
private object _handshakeLock => _sslAuthenticationOptions!;
private volatile TaskCompletionSource<bool>? _handshakeWaiter;

private const int FrameOverhead = 32;
Expand Down Expand Up @@ -403,9 +403,9 @@ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter a
}
else if (_lastFrame.Header.Type == TlsContentType.Handshake)
{
if ((_handshakeBuffer.ActiveReadOnlySpan[TlsFrameHelper.HeaderSize] == (byte)TlsHandshakeType.ClientHello &&
_sslAuthenticationOptions!.ServerCertSelectionDelegate != null) ||
NetEventSource.Log.IsEnabled())
if (_handshakeBuffer.ActiveReadOnlySpan[TlsFrameHelper.HeaderSize] == (byte)TlsHandshakeType.ClientHello &&
(_sslAuthenticationOptions!.ServerCertSelectionDelegate != null ||
_sslAuthenticationOptions!.ServerOptionDelegate != null))
{
TlsFrameHelper.ProcessingOptions options = NetEventSource.Log.IsEnabled() ?
TlsFrameHelper.ProcessingOptions.All :
Expand All @@ -421,6 +421,14 @@ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter a
{
// SNI if it exist. Even if we could not parse the hello, we can fall-back to default certificate.
_sslAuthenticationOptions!.TargetHost = _lastFrame.TargetName;

if (_sslAuthenticationOptions.ServerOptionDelegate != null)
{
SslServerAuthenticationOptions userOptions =
await _sslAuthenticationOptions.ServerOptionDelegate(this, new SslClientHelloInfo(_lastFrame.TargetName, _lastFrame.SupportedVersions),
_sslAuthenticationOptions.UserState, adapter.CancellationToken).ConfigureAwait(false);
_sslAuthenticationOptions.UpdateOptions(userOptions);
}
}

if (NetEventSource.Log.IsEnabled())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public enum EncryptionPolicy

public delegate X509Certificate ServerCertificateSelectionCallback(object sender, string? hostName);

public delegate ValueTask<SslServerAuthenticationOptions> ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, CancellationToken cancellationToken);

// Internal versions of the above delegates.
internal delegate bool RemoteCertValidationCallback(string? host, X509Certificate2? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors);
internal delegate X509Certificate LocalCertSelectionCallback(string targetHost, X509CertificateCollection localCertificates, X509Certificate2? remoteCertificate, string[] acceptableIssuers);
Expand Down Expand Up @@ -453,6 +455,12 @@ private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAut
return ProcessAuthentication(true, true, cancellationToken)!;
}

public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, CancellationToken cancellationToken = default)
{
ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state));
return ProcessAuthentication(isAsync: true, isApm: false, cancellationToken)!;
}

public virtual Task ShutdownAsync()
{
ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Net.Test.Common;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

using Xunit;
Expand All @@ -22,6 +23,8 @@ public class ServerAsyncAuthenticateTest : IDisposable
private readonly ITestOutputHelper _logVerbose;
private readonly X509Certificate2 _serverCertificate;

public static bool IsNotWindows7 => !PlatformDetection.IsWindows7;

public ServerAsyncAuthenticateTest(ITestOutputHelper output)
{
_log = output;
Expand Down Expand Up @@ -69,6 +72,159 @@ public async Task ServerAsyncAuthenticate_AllClientVsIndividualServerSupportedPr
await ServerAsyncSslHelper(SslProtocolSupport.SupportedSslProtocols, serverProtocol);
}

[Fact]
public async Task ServerAsyncAuthenticate_SimpleSniOptions_Success()
{
var state = new object();
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
Assert.Equal(server, stream);
Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
Assert.True(object.ReferenceEquals(state, userState));
return new ValueTask<SslServerAuthenticationOptions>(serverOptions);
},
state, CancellationToken.None);

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}

[ConditionalTheory(nameof(IsNotWindows7))]
[InlineData(SslProtocols.Tls11)]
[InlineData(SslProtocols.Tls12)]
public async Task ServerAsyncAuthenticate_SniSetVersion_Success(SslProtocols version)
{
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate, EnabledSslProtocols = version };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, forIssuer: false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
Assert.Equal(server, stream);
Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
return new ValueTask<SslServerAuthenticationOptions>(serverOptions);
},
null, CancellationToken.None);

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
// Verify that the SNI callback can impact version.
Assert.Equal(version, client.SslProtocol);
}
}

private async Task<SslServerAuthenticationOptions> FailedTask()
{
await Task.Yield();
throw new InvalidOperationException("foo");
}

private async Task<SslServerAuthenticationOptions> OptionsTask(SslServerAuthenticationOptions value)
{
await Task.Yield();
return value;
}

[Fact]
public async Task ServerAsyncAuthenticate_AsyncOptions_Success()
{
var state = new object();
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
Assert.Equal(server, stream);
Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
Assert.True(object.ReferenceEquals(state, userState));
return new ValueTask<SslServerAuthenticationOptions>(OptionsTask(serverOptions));
},
state, CancellationToken.None);

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ServerAsyncAuthenticate_FailingOptionCallback_Throws(bool useAsync)
{
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
if (useAsync)
{
return new ValueTask<SslServerAuthenticationOptions>(FailedTask());
}

throw new InvalidOperationException("foo");
},
null, CancellationToken.None);
await Assert.ThrowsAsync<InvalidOperationException>(() => t2);
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ServerAsyncAuthenticate_NoCertificate_Throws(bool useAsync)
{
var serverOptions = new SslServerAuthenticationOptions();
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
if (useAsync)
{
return new ValueTask<SslServerAuthenticationOptions>(serverOptions);
}

return new ValueTask<SslServerAuthenticationOptions>(OptionsTask(serverOptions));
},
null, CancellationToken.None);
await Assert.ThrowsAsync<System.NotSupportedException>(() => t2);
}
}

public static IEnumerable<object[]> ProtocolMismatchData()
{
if (PlatformDetection.SupportsSsl3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ public static class TestHelper
private static readonly X509BasicConstraintsExtension s_eeConstraints =
new X509BasicConstraintsExtension(false, false, 0, false);

public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams()
{
(Stream clientStream, Stream serverStream) = GetConnectedStreams();
return (new SslStream(clientStream), new SslStream(serverStream));
}

public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()
{
if (Capability.SecurityForceSocketStreams())
Expand Down

0 comments on commit 328d0cf

Please sign in to comment.