Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Add ALPN support for SslStream. #24389

Merged
merged 12 commits into from
Oct 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ internal static partial class Interop
{
internal static partial class OpenSsl
{
private static Ssl.SslCtxSetVerifyCallback s_verifyClientCertificate = VerifyClientCertificate;
private static readonly Ssl.SslCtxSetVerifyCallback s_verifyClientCertificate = VerifyClientCertificate;
private static readonly Ssl.SslCtxSetAlpnCallback s_alpnServerCallback = AlpnServerSelectCallback;

#region internal methods

Expand All @@ -47,7 +48,7 @@ internal static SafeChannelBindingHandle QueryChannelBinding(SafeSslHandle conte
return bindingHandle;
}

internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX509Handle certHandle, SafeEvpPKeyHandle certKeyHandle, EncryptionPolicy policy, bool isServer, bool remoteCertRequired)
internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX509Handle certHandle, SafeEvpPKeyHandle certKeyHandle, EncryptionPolicy policy, SslAuthenticationOptions sslAuthenticationOptions)
{
SafeSslHandle context = null;

Expand Down Expand Up @@ -88,17 +89,32 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
SetSslCertificate(innerContext, certHandle, certKeyHandle);
}

if (remoteCertRequired)
if (sslAuthenticationOptions.IsServer && sslAuthenticationOptions.RemoteCertRequired)
{
Debug.Assert(isServer, "isServer flag should be true");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean if RemoteCertRequired is true and IsServer is false? The assert here took care of identifying that as an invalid state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This state will not be hit with the current implementation. Previously, these internal APIs were passing the RemoteCertRequired=false for client and RemoteCertRequired=true for server. This is in direct contradiction with external meaning of RemoteCertRequired, where for client it is always true, and server is got from RemoteCertRequired parameter on the authenticate methods. Since we don't flip these values internally anymore, it doesn't make sense to have this assert.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I get it. You've changed the client to be true for this instead of false. Really the client is N/A, and that's what the assert was guarding (that it made no sense for a client to change that value during the handshake).

This seems fine, now that I understand the change. There's a mild perf gain (ditching the &&) for having the client use false, but it's not something worth adding confusion over.

Ssl.SslCtxSetVerify(innerContext,
s_verifyClientCertificate);
Ssl.SslCtxSetVerify(innerContext, s_verifyClientCertificate);

//update the client CA list
UpdateCAListFromRootStore(innerContext);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method call does not match the behavior on Windows.
Why does it exist in this manner? See also: https://github.com/dotnet/corefx/issues/23938

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ayende This code is not related to the work done in this PR. Please file separate issue if there's a bug.

}

context = SafeSslHandle.Create(innerContext, isServer);
if (sslAuthenticationOptions.ApplicationProtocols != null)
{
if (sslAuthenticationOptions.IsServer)
{
byte[] protos = Interop.Ssl.ConvertAlpnProtocolListToByteArray(sslAuthenticationOptions.ApplicationProtocols);
sslAuthenticationOptions.AlpnProtocolsHandle = GCHandle.Alloc(protos);
Interop.Ssl.SslCtxSetAlpnSelectCb(innerContext, s_alpnServerCallback, GCHandle.ToIntPtr(sslAuthenticationOptions.AlpnProtocolsHandle));
}
else
{
if (Interop.Ssl.SslCtxSetAlpnProtos(innerContext, sslAuthenticationOptions.ApplicationProtocols) != 0)
{
throw CreateSslException(SR.net_alpn_config_failed);
}
}
}

context = SafeSslHandle.Create(innerContext, sslAuthenticationOptions.IsServer);
Debug.Assert(context != null, "Expected non-null return value from SafeSslHandle.Create");
if (context.IsInvalid)
{
Expand Down Expand Up @@ -314,6 +330,18 @@ private static int VerifyClientCertificate(int preverify_ok, IntPtr x509_ctx_ptr
return OpenSslSuccess;
}

private static unsafe int AlpnServerSelectCallback(IntPtr ssl, out IntPtr outp, out byte outlen, IntPtr inp, uint inlen, IntPtr arg)
{
GCHandle protocols = GCHandle.FromIntPtr(arg);
byte[] server = (byte[])protocols.Target;

fixed (byte* sp = server)
{
return Interop.Ssl.SslSelectNextProto(out outp, out outlen, (IntPtr)sp, (uint)server.Length, inp, inlen) == Interop.Ssl.OPENSSL_NPN_NEGOTIATED ?
Interop.Ssl.SSL_TLSEXT_ERR_OK : Interop.Ssl.SSL_TLSEXT_ERR_NOACK;
}
}

private static void UpdateCAListFromRootStore(SafeSslContextHandle context)
{
using (SafeX509NameStackHandle nameStack = Crypto.NewX509NameStack())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ internal static partial class Interop
{
internal static partial class Ssl
{
internal const int SSL_TLSEXT_ERR_OK = 0;
internal const int OPENSSL_NPN_NEGOTIATED = 1;
internal const int SSL_TLSEXT_ERR_NOACK = 3;

internal delegate int SslCtxSetVerifyCallback(int preverify_ok, IntPtr x509_ctx);

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_EnsureLibSslInitialized")]
Expand Down Expand Up @@ -44,6 +48,26 @@ internal static partial class Ssl
[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetVersion")]
private static extern IntPtr SslGetVersion(SafeSslHandle ssl);

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSelectNextProto")]
internal static extern int SslSelectNextProto(out IntPtr outp, out byte outlen, IntPtr server, uint serverlen, IntPtr client, uint clientlen);

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGet0AlpnSelected")]
internal static extern void SslGetAlpnSelected(SafeSslHandle ssl, out IntPtr protocol, out int len);

internal static byte[] SslGetAlpnSelected(SafeSslHandle ssl)
{
IntPtr protocol;
int len;
SslGetAlpnSelected(ssl, out protocol, out len);

if (len == 0)
return null;

byte[] result = new byte[len];
Marshal.Copy(protocol, result, 0, len);
return result;
}

internal static string GetProtocolVersion(SafeSslHandle ssl)
{
return Marshal.PtrToStringAnsi(SslGetVersion(ssl));
Expand Down Expand Up @@ -156,12 +180,12 @@ internal enum SslErrorCode
SSL_ERROR_WANT_WRITE = 3,
SSL_ERROR_SYSCALL = 5,
SSL_ERROR_ZERO_RETURN = 6,

// NOTE: this SslErrorCode value doesn't exist in OpenSSL, but
// we use it to distinguish when a renegotiation is pending.
// Choosing an arbitrarily large value that shouldn't conflict
// with any actual OpenSSL error codes
SSL_ERROR_RENEGOTIATE = 29304
SSL_ERROR_RENEGOTIATE = 29304
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.Win32.SafeHandles;

internal static partial class Interop
Expand All @@ -12,6 +15,7 @@ internal static partial class Ssl
{
internal delegate int AppVerifyCallback(IntPtr storeCtx, IntPtr arg);
internal delegate int ClientCertCallback(IntPtr ssl, out IntPtr x509, out IntPtr pkey);
internal delegate int SslCtxSetAlpnCallback(IntPtr ssl, out IntPtr outp, out byte outlen, IntPtr inp, uint inlen, IntPtr arg);

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxCreate")]
internal static extern SafeSslContextHandle SslCtxCreate(IntPtr method);
Expand All @@ -24,6 +28,46 @@ internal static partial class Ssl

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetClientCertCallback")]
internal static extern void SslCtxSetClientCertCallback(IntPtr ctx, ClientCertCallback callback);

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetAlpnProtos")]
internal static extern int SslCtxSetAlpnProtos(SafeSslContextHandle ctx, IntPtr protos, int len);

[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslCtxSetAlpnSelectCb")]
internal static unsafe extern void SslCtxSetAlpnSelectCb(SafeSslContextHandle ctx, SslCtxSetAlpnCallback callback, IntPtr arg);

internal static unsafe int SslCtxSetAlpnProtos(SafeSslContextHandle ctx, List<SslApplicationProtocol> protocols)
{
byte[] buffer = ConvertAlpnProtocolListToByteArray(protocols);
fixed (byte* b = buffer)
{
return SslCtxSetAlpnProtos(ctx, (IntPtr)b, buffer.Length);
}
}

internal static byte[] ConvertAlpnProtocolListToByteArray(List<SslApplicationProtocol> applicationProtocols)
{
int protocolSize = 0;
foreach (SslApplicationProtocol protocol in applicationProtocols)
{
if (protocol.Protocol.Length == 0 || protocol.Protocol.Length > byte.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}

protocolSize += protocol.Protocol.Length + 1;
}

byte[] buffer = new byte[protocolSize];
var offset = 0;
foreach (SslApplicationProtocol protocol in applicationProtocols)
{
buffer[offset++] = (byte)(protocol.Protocol.Length);
protocol.Protocol.Span.CopyTo(new Span<byte>(buffer).Slice(offset));
offset += protocol.Protocol.Length;
}

return buffer;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ internal enum SECURITY_STATUS
SmartcardLogonRequired = unchecked((int)0x8009033E),
UnsupportedPreauth = unchecked((int)0x80090343),
BadBinding = unchecked((int)0x80090346),
DowngradeDetected = unchecked((int)0x80090350)
DowngradeDetected = unchecked((int)0x80090350),
ApplicationProtocolMismatch = unchecked((int)0x80090367),
}

#if TRACE_VERBOSE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal enum ApplicationProtocolNegotiationStatus
{
None = 0,
Success,
SelectedClientOnly
}

internal enum ApplicationProtocolNegotiationExt
{
None = 0,
NPN,
ALPN
}

[StructLayout(LayoutKind.Sequential)]
internal class SecPkgContext_ApplicationProtocol
{
private const int MaxProtocolIdSize = 0xFF;

public ApplicationProtocolNegotiationStatus ProtoNegoStatus;
public ApplicationProtocolNegotiationExt ProtoNegoExt;
public byte ProtocolIdSize;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = MaxProtocolIdSize)]
public byte[] ProtocolId;
public byte[] Protocol
{
get
{
return new Span<byte>(ProtocolId, 0, ProtocolIdSize).ToArray();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Security;
using System.Runtime.InteropServices;

internal static partial class Interop
{
[StructLayout(LayoutKind.Sequential, Pack = 1)]
internal struct Sec_Application_Protocols
{
private static readonly int ProtocolListOffset = Marshal.SizeOf<Sec_Application_Protocols>();
private static readonly int ProtocolListConstSize = ProtocolListOffset - (int)Marshal.OffsetOf<Sec_Application_Protocols>(nameof(ProtocolExtenstionType));
public uint ProtocolListsSize;
public ApplicationProtocolNegotiationExt ProtocolExtenstionType;
public short ProtocolListSize;

public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> applicationProtocols)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stephentoub, should we also have a method that takes in a span and puts the bytes into it, i.e. TryCopyTo(Span buffer)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is internal. Is the equivalent exposed publicly somewhere?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's exposed, but: a) maybe we should expose it, b) maybe the internal code that uses it would be written with less copies/allocations if we had such internal API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For b), at that point it's an implementation detail and we should do whatever's most efficient, assuming it matters.

For a), what would it be used for?

{
long protocolListSize = 0;
for (int i = 0; i < applicationProtocols.Count; i++)
{
if (applicationProtocols[i].Protocol.Length == 0 || applicationProtocols[i].Protocol.Length > byte.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}

protocolListSize += applicationProtocols[i].Protocol.Length + 1;

if (protocolListSize > short.MaxValue)
{
throw new ArgumentException(SR.net_ssl_app_protocols_invalid, nameof(applicationProtocols));
}
}

Sec_Application_Protocols protocols = new Sec_Application_Protocols();
protocols.ProtocolListsSize = (uint)(ProtocolListConstSize + protocolListSize);
protocols.ProtocolExtenstionType = ApplicationProtocolNegotiationExt.ALPN;
protocols.ProtocolListSize = (short)protocolListSize;

Span<byte> pBuffer = new byte[protocolListSize];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: pBuffer kind of implies that the variable is a pointer. it might be better to call it protocolsBuffer.

int index = 0;
for (int i = 0; i < applicationProtocols.Count; i++)
{
pBuffer[index++] = (byte)applicationProtocols[i].Protocol.Length;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug.Assert(applicationProtocols[i].Protocol.Length < 255)? (or <=, if appropriate)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is done, in the code just before in the same function, so a debug assert is not required here.

applicationProtocols[i].Protocol.Span.CopyTo(pBuffer.Slice(index));
index += applicationProtocols[i].Protocol.Length;
}

byte[] buffer = new byte[ProtocolListOffset + protocolListSize];
fixed (byte* bufferPtr = buffer)
{
Marshal.StructureToPtr(protocols, new IntPtr(bufferPtr), false);
byte* pList = bufferPtr + ProtocolListOffset;
pBuffer.CopyTo(new Span<byte>(pList, index));
}

return buffer;
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a lot of allocation happening in these conversion routines: enumerators, byte arrays, etc. Not to block this PR, but it'd be good subsequently to see what kind of impact that has and whether there are ways to reduce it.

}
1 change: 1 addition & 0 deletions src/Common/src/Interop/Windows/sspicli/Interop.SSPI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ internal enum ContextAttribute
SECPKG_ATTR_UNIQUE_BINDINGS = 25,
SECPKG_ATTR_ENDPOINT_BINDINGS = 26,
SECPKG_ATTR_CLIENT_SPECIFIED_TARGET = 27,
SECPKG_ATTR_APPLICATION_PROTOCOL = 35,

// minschannel.h
SECPKG_ATTR_REMOTE_CERT_CONTEXT = 0x53, // returns PCCERT_CONTEXT
Expand Down
15 changes: 15 additions & 0 deletions src/Common/src/Interop/Windows/sspicli/SSPIWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,10 @@ public static object QueryContextAttributes(SSPIInterface secModule, SafeDeleteC
nativeBlockSize = Marshal.SizeOf<SecPkgContext_ConnectionInfo>();
break;

case Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL:
nativeBlockSize = Marshal.SizeOf<Interop.SecPkgContext_ApplicationProtocol>();
break;

default:
throw new ArgumentException(SR.Format(SR.net_invalid_enum, nameof(contextAttribute)), nameof(contextAttribute));
}
Expand Down Expand Up @@ -540,6 +544,17 @@ public static object QueryContextAttributes(SSPIInterface secModule, SafeDeleteC
case Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CONNECTION_INFO:
attribute = new SecPkgContext_ConnectionInfo(nativeBuffer);
break;

case Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL:
unsafe
{
fixed (void *ptr = nativeBuffer)
{
attribute = Marshal.PtrToStructure<Interop.SecPkgContext_ApplicationProtocol>(new IntPtr(ptr));
}
}
break;

default:
// Will return null.
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Microsoft.Win32.SafeHandles;

using System.Diagnostics;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Authentication.ExtendedProtection;
Expand All @@ -25,7 +26,7 @@ public SafeSslHandle SslContext
}
}

public SafeDeleteSslContext(SafeFreeSslCredentials credential, bool isServer, bool remoteCertRequired)
public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthenticationOptions sslAuthenticationOptions)
: base(credential)
{
Debug.Assert((null != credential) && !credential.IsInvalid, "Invalid credential used in SafeDeleteSslContext");
Expand All @@ -37,8 +38,7 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, bool isServer, bo
credential.CertHandle,
credential.CertKeyHandle,
credential.Policy,
isServer,
remoteCertRequired);
sslAuthenticationOptions);
}
catch(Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace System.Net
{
internal static class SecurityStatusAdapterPal
{
private const int StatusDictionarySize = 40;
private const int StatusDictionarySize = 41;

#if DEBUG
static SecurityStatusAdapterPal()
Expand All @@ -22,6 +22,7 @@ static SecurityStatusAdapterPal()
private static readonly BidirectionalDictionary<Interop.SECURITY_STATUS, SecurityStatusPalErrorCode> s_statusDictionary = new BidirectionalDictionary<Interop.SECURITY_STATUS, SecurityStatusPalErrorCode>(StatusDictionarySize)
{
{ Interop.SECURITY_STATUS.AlgorithmMismatch, SecurityStatusPalErrorCode.AlgorithmMismatch },
{ Interop.SECURITY_STATUS.ApplicationProtocolMismatch, SecurityStatusPalErrorCode.ApplicationProtocolMismatch },
{ Interop.SECURITY_STATUS.BadBinding, SecurityStatusPalErrorCode.BadBinding },
{ Interop.SECURITY_STATUS.BufferNotEnough, SecurityStatusPalErrorCode.BufferNotEnough },
{ Interop.SECURITY_STATUS.CannotInstall, SecurityStatusPalErrorCode.CannotInstall },
Expand Down
Loading