Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ServerOptionsSelectionCallback to SslStream #38760

Merged
merged 10 commits into from
Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/libraries/System.Net.Security/ref/System.Net.Security.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,16 @@ public enum ProtectionLevel
Sign = 1,
EncryptAndSign = 2,
}
public readonly struct SslClientHelloInfo
{
public string ServerName { get; }
public System.Security.Authentication.SslProtocols SslProtocols { get; }
internal SslClientHelloInfo(string serverName, System.Security.Authentication.SslProtocols sslProtocol) { throw null; }
}
public delegate bool RemoteCertificateValidationCallback(object sender, System.Security.Cryptography.X509Certificates.X509Certificate? certificate, System.Security.Cryptography.X509Certificates.X509Chain? chain, System.Net.Security.SslPolicyErrors sslPolicyErrors);
public delegate System.Security.Cryptography.X509Certificates.X509Certificate ServerCertificateSelectionCallback(object sender, string? hostName);
public delegate System.Threading.Tasks.ValueTask<SslServerAuthenticationOptions> ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, System.Threading.CancellationToken cancellationToken);

wfurt marked this conversation as resolved.
Show resolved Hide resolved
public readonly partial struct SslApplicationProtocol : System.IEquatable<System.Net.Security.SslApplicationProtocol>
{
private readonly object _dummy;
Expand Down Expand Up @@ -236,6 +244,7 @@ public void Write(byte[] buffer) { }
public override void Write(byte[] buffer, int offset, int count) { }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.Task AuthenticateAsServerAsync( ServerOptionsSelectionCallback optionCallback, object? state, System.Threading.CancellationToken cancellationToken = default) { throw null; }
wfurt marked this conversation as resolved.
Show resolved Hide resolved
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}
[System.CLSCompliantAttribute(false)]
public enum TlsCipherSuite : ushort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<Compile Include="System\Net\Security\SslApplicationProtocol.cs" />
<Compile Include="System\Net\Security\SslAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientHelloInfo.cs" />
<Compile Include="System\Net\Security\SslServerAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SecureChannel.cs" />
<Compile Include="System\Net\Security\SslSessionsCache.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,40 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen
}
}

internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state)
{
CheckCertName = false;
TargetHost = string.Empty;
IsServer = true;
UserState = state;
ServerOptionDelegate = optionCallback;
}

internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
{
AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols;
EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols);
EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy;
RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired;
CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy;
CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode;
if (sslServerAuthenticationOptions.ServerCertificateContext != null)
{
CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext;
}
else if (sslServerAuthenticationOptions.ServerCertificate != null)
{
X509Certificate2? certificateWithKey = sslServerAuthenticationOptions.ServerCertificate as X509Certificate2;

if (certificateWithKey != null && certificateWithKey.HasPrivateKey)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
// given cert is X509Certificate2 with key. We can use it directly.
CertificateContext = SslStreamCertificateContext.Create(certificateWithKey);
}
}
}

private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols protocols)
{
if (protocols.HasFlag(SslProtocols.Tls12) || protocols.HasFlag(SslProtocols.Tls13))
Expand All @@ -98,7 +132,7 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto
internal bool AllowRenegotiation { get; set; }
internal string TargetHost { get; set; }
internal X509CertificateCollection? ClientCertificates { get; set; }
internal List<SslApplicationProtocol>? ApplicationProtocols { get; }
internal List<SslApplicationProtocol>? ApplicationProtocols { get; set; }
internal bool IsServer { get; set; }
internal SslStreamCertificateContext? CertificateContext { get; set; }
internal SslProtocols EnabledSslProtocols { get; set; }
Expand All @@ -110,5 +144,7 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto
internal LocalCertSelectionCallback? CertSelectionDelegate { get; set; }
internal ServerCertSelectionCallback? ServerCertSelectionDelegate { get; set; }
internal CipherSuitesPolicy? CipherSuitesPolicy { get; set; }
internal object? UserState { get; set; }
internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; set; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these ever set outside of the ctor? Looks like they can be made read-only.

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Security.Authentication;

namespace System.Net.Security
{
public readonly struct SslClientHelloInfo
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
public readonly string ServerName { get; }
public readonly SslProtocols SslProtocols { get; }

internal SslClientHelloInfo(string serverName, SslProtocols sslProtocols)
{
ServerName = serverName;
SslProtocols = sslProtocols;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private enum Framing

private TlsFrameHelper.TlsFrameInfo _lastFrame;

private readonly object _handshakeLock = new object();
private object _handshakeLock => _sslAuthenticationOptions!;
private volatile TaskCompletionSource<bool>? _handshakeWaiter;

private const int FrameOverhead = 32;
Expand Down Expand Up @@ -408,7 +408,8 @@ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter a
else if (_lastFrame.Header.Type == TlsContentType.Handshake)
{
if ((_handshakeBuffer.ActiveReadOnlySpan[TlsFrameHelper.HeaderSize] == (byte)TlsHandshakeType.ClientHello &&
_sslAuthenticationOptions!.ServerCertSelectionDelegate != null) ||
(_sslAuthenticationOptions!.ServerCertSelectionDelegate != null ||
_sslAuthenticationOptions!.ServerOptionDelegate != null)) ||
NetEventSource.IsEnabled)
{
TlsFrameHelper.ProcessingOptions options = NetEventSource.IsEnabled ?
Expand All @@ -425,6 +426,23 @@ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter a
{
// SNI if it exist. Even if we could not parse the hello, we can fall-back to default certificate.
_sslAuthenticationOptions!.TargetHost = _lastFrame.TargetName;

if (_sslAuthenticationOptions.ServerOptionDelegate != null)
{
ValueTask<SslServerAuthenticationOptions> t =
_sslAuthenticationOptions.ServerOptionDelegate(this, new SslClientHelloInfo(_lastFrame.TargetName, _lastFrame.SupportedVersions),
_sslAuthenticationOptions.UserState, adapter.CancellationToken);

if (t.IsCompletedSuccessfully)
{
_sslAuthenticationOptions.UpdateOptions(t.Result);
}
else
{
SslServerAuthenticationOptions userOptions = await t.ConfigureAwait(false);
_sslAuthenticationOptions.UpdateOptions(userOptions);
}
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}

if (NetEventSource.IsEnabled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public enum EncryptionPolicy

public delegate X509Certificate ServerCertificateSelectionCallback(object sender, string? hostName);

public delegate ValueTask<SslServerAuthenticationOptions> ServerOptionsSelectionCallback(SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, CancellationToken cancellationToken);

// Internal versions of the above delegates.
internal delegate bool RemoteCertValidationCallback(string? host, X509Certificate2? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors);
internal delegate X509Certificate LocalCertSelectionCallback(string targetHost, X509CertificateCollection localCertificates, X509Certificate2? remoteCertificate, string[] acceptableIssuers);
Expand Down Expand Up @@ -453,6 +455,12 @@ private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAut
return ProcessAuthentication(true, true, cancellationToken)!;
}

public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionCallback, object? state, CancellationToken cancellationToken = default)
{
ValidateCreateContext(new SslAuthenticationOptions(optionCallback, state));
return ProcessAuthentication(true, false, cancellationToken)!;
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}

public virtual Task ShutdownAsync()
{
ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,159 @@ public async Task ServerAsyncAuthenticate_AllClientVsIndividualServerSupportedPr
await ServerAsyncSslHelper(SslProtocolSupport.SupportedSslProtocols, serverProtocol);
}

[Fact]
public async Task ServerAsyncAuthenticate_SimpleSniOptions_Success()
{
var state = new Object();
wfurt marked this conversation as resolved.
Show resolved Hide resolved
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, forIssuer: false) };

clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really care about this but for future reference if you aren't using any of the parameters you can also do:

Suggested change
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
clientOptions.RemoteCertificateValidationCallback = delegate { return true; }


(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, default);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
Assert.Equal(server, stream);
Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
Assert.True(object.ReferenceEquals(state, userState));
return new ValueTask<SslServerAuthenticationOptions>(serverOptions);
},
state, default);
wfurt marked this conversation as resolved.
Show resolved Hide resolved

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}

[Theory]
[InlineData(SslProtocols.Tls11)]
[InlineData(SslProtocols.Tls12)]
public async Task ServerAsyncAuthenticate_SniSetVersion_Success(SslProtocols version)
{
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate, EnabledSslProtocols = version };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
wfurt marked this conversation as resolved.
Show resolved Hide resolved
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, default);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
Assert.Equal(server, stream);
Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
return new ValueTask<SslServerAuthenticationOptions>(serverOptions);
},
null, default);
wfurt marked this conversation as resolved.
Show resolved Hide resolved

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
// Verify that the SNI callback can impact version.
Assert.Equal(version, client.SslProtocol);
}
}

private async Task<SslServerAuthenticationOptions> FailedTask()
{
await Task.Delay(10);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
throw new InvalidOperationException("foo");
}

private async Task<SslServerAuthenticationOptions> OptionsTask(SslServerAuthenticationOptions value)
{
await Task.Delay(10);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
return value;
}

[Fact]
public async Task ServerAsyncAuthenticate_AsyncOptions_Success()
{
var state = new Object();
wfurt marked this conversation as resolved.
Show resolved Hide resolved
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, forIssuer: false) };

(there's more of this elsewhere in the other tests, going to stop adding comments now!)

clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, default);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Task t1 = client.AuthenticateAsClientAsync(clientOptions, default);
Task t1 = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);

Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
Assert.Equal(server, stream);
Assert.Equal(clientOptions.TargetHost, clientHelloInfo.ServerName);
Assert.True(object.ReferenceEquals(state, userState));
return new ValueTask<SslServerAuthenticationOptions>(OptionsTask(serverOptions));
},
state, default);
wfurt marked this conversation as resolved.
Show resolved Hide resolved

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ServerAsyncAuthenticate_FailingOptionCallback_Throws(bool useAsync)
{
var serverOptions = new SslServerAuthenticationOptions() { ServerCertificate = _serverCertificate };
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, default);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
if (useAsync)
{
return new ValueTask<SslServerAuthenticationOptions>(FailedTask());
}

throw new InvalidOperationException("foo");
},
null, default);
await Assert.ThrowsAsync<InvalidOperationException>(() => t2);
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ServerAsyncAuthenticate_NoCertificate_Throws(bool useAsync)
{
var serverOptions = new SslServerAuthenticationOptions();
var clientOptions = new SslClientAuthenticationOptions() { TargetHost = _serverCertificate.GetNameInfo(X509NameType.SimpleName, false) };
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams();
using (client)
using (server)
{
Task t1 = client.AuthenticateAsClientAsync(clientOptions, default);
Task t2 = server.AuthenticateAsServerAsync(
(stream, clientHelloInfo, userState, cancellationToken) =>
{
if (useAsync)
{
return new ValueTask<SslServerAuthenticationOptions>(serverOptions);
}

return new ValueTask<SslServerAuthenticationOptions>(OptionsTask(serverOptions));
},
null, default);
await Assert.ThrowsAsync<System.NotSupportedException>(() => t2);
}
}

public static IEnumerable<object[]> ProtocolMismatchData()
{
if (PlatformDetection.SupportsSsl3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ public static class TestHelper
private static readonly X509BasicConstraintsExtension s_eeConstraints =
new X509BasicConstraintsExtension(false, false, 0, false);

public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams()
{
(Stream clientStream, Stream serverStream) = GetConnectedStreams();
return (new SslStream(clientStream), new SslStream(serverStream));
}

public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()
{
if (Capability.SecurityForceSocketStreams())
Expand Down