Skip to content

Commit

Permalink
avoid alpn allocation on windows (#74619)
Browse files Browse the repository at this point in the history
* avoid alpn allocation on windows

* updates

* remove dead code

* scoped

* Apply suggestions from code review

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* Fix build breaks

* SetAlpn

Co-authored-by: Stephen Toub <stoub@microsoft.com>
Co-authored-by: Jan Kotas <jkotas@microsoft.com>
  • Loading branch information
3 people authored Nov 22, 2022
1 parent 167b0a1 commit a8e2b6e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ internal struct Sec_Application_Protocols
public ApplicationProtocolNegotiationExt ProtocolExtensionType;
public short ProtocolListSize;

public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> applicationProtocols)
public static int GetProtocolLength(List<SslApplicationProtocol> applicationProtocols)
{
long protocolListSize = 0;
int protocolListSize = 0;
for (int i = 0; i < applicationProtocols.Count; i++)
{
int protocolLength = applicationProtocols[i].Protocol.Length;
Expand All @@ -36,6 +36,13 @@ public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> application
}
}

return protocolListSize;
}

public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> applicationProtocols)
{
int protocolListSize = GetProtocolLength(applicationProtocols);

Sec_Application_Protocols protocols = default;

int protocolListConstSize = sizeof(Sec_Application_Protocols) - sizeof(uint) /* offsetof(Sec_Application_Protocols, ProtocolExtensionType) */;
Expand All @@ -60,5 +67,24 @@ public static unsafe byte[] ToByteArray(List<SslApplicationProtocol> application

return buffer;
}

public static unsafe void SetProtocols(Span<byte> buffer, List<SslApplicationProtocol> applicationProtocols, int protocolLength)
{
Span<Sec_Application_Protocols> alpn = MemoryMarshal.Cast<byte, Sec_Application_Protocols>(buffer);
alpn[0].ProtocolListsSize = (uint)(sizeof(Sec_Application_Protocols) - sizeof(uint) + protocolLength);
alpn[0].ProtocolExtensionType = ApplicationProtocolNegotiationExt.ALPN;
alpn[0].ProtocolListSize = (short)protocolLength;

Span<byte> data = buffer.Slice(sizeof(Sec_Application_Protocols));
for (int i = 0; i < applicationProtocols.Count; i++)
{
ReadOnlySpan<byte> protocol = applicationProtocols[i].Protocol.Span;

data[0] = (byte)protocol.Length;
data = data.Slice(1);
protocol.CopyTo(data);
data = data.Slice(protocol.Length);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Authentication.ExtendedProtection;
using System.Security.Cryptography.X509Certificates;
using System.Security.Principal;
using System.Text;
using Microsoft.Win32.SafeHandles;

namespace System.Net.Security
{
internal static class SslStreamPal
{
private static readonly byte[] s_http1 = Interop.Sec_Application_Protocols.ToByteArray(new List<SslApplicationProtocol> { SslApplicationProtocol.Http11 });
private static readonly byte[] s_http2 = Interop.Sec_Application_Protocols.ToByteArray(new List<SslApplicationProtocol> { SslApplicationProtocol.Http2 });
private static readonly byte[] s_http12 = Interop.Sec_Application_Protocols.ToByteArray(new List<SslApplicationProtocol> { SslApplicationProtocol.Http11, SslApplicationProtocol.Http2 });
private static readonly byte[] s_http21 = Interop.Sec_Application_Protocols.ToByteArray(new List<SslApplicationProtocol> { SslApplicationProtocol.Http2, SslApplicationProtocol.Http11 });

private static readonly bool UseNewCryptoApi =
// On newer Windows version we use new API to get TLS1.3.
// API is supported since Windows 10 1809 (17763) but there is no reason to use at the moment.
Expand Down Expand Up @@ -47,12 +50,36 @@ public static void VerifyPackageInfo()
SSPIWrapper.GetVerifyPackageInfo(GlobalSSPI.SSPISecureChannel, SecurityPackage, true);
}

public static byte[] ConvertAlpnProtocolListToByteArray(List<SslApplicationProtocol> protocols)
private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List<SslApplicationProtocol> alpn, Span<byte> localBuffer)
{
return Interop.Sec_Application_Protocols.ToByteArray(protocols);
if (alpn.Count == 1 && alpn[0] == SslApplicationProtocol.Http11)
{
inputBuffers.SetNextBuffer(new InputSecurityBuffer(s_http1, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
}
else if (alpn.Count == 1 && alpn[0] == SslApplicationProtocol.Http2)
{
inputBuffers.SetNextBuffer(new InputSecurityBuffer(s_http2, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
}
else if (alpn.Count == 2 && alpn[0] == SslApplicationProtocol.Http11 && alpn[1] == SslApplicationProtocol.Http2)
{
inputBuffers.SetNextBuffer(new InputSecurityBuffer(s_http12, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
}
else if (alpn.Count == 2 && alpn[0] == SslApplicationProtocol.Http2 && alpn[1] == SslApplicationProtocol.Http11)
{
inputBuffers.SetNextBuffer(new InputSecurityBuffer(s_http21, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
}
else
{
int protocolLength = Interop.Sec_Application_Protocols.GetProtocolLength(alpn);
int bufferLength = sizeof(Interop.Sec_Application_Protocols) + protocolLength;

Span<byte> alpnBuffer = bufferLength <= localBuffer.Length ? localBuffer : new byte[bufferLength];
Interop.Sec_Application_Protocols.SetProtocols(alpnBuffer, alpn, protocolLength);
inputBuffers.SetNextBuffer(new InputSecurityBuffer(alpnBuffer, SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
}
}

public static SecurityStatusPal AcceptSecurityContext(
public static unsafe SecurityStatusPal AcceptSecurityContext(
ref SafeFreeCredentials? credentialsHandle,
ref SafeDeleteSslContext? context,
ReadOnlySpan<byte> inputBuffer,
Expand All @@ -61,14 +88,13 @@ public static SecurityStatusPal AcceptSecurityContext(
{
Interop.SspiCli.ContextFlags unusedAttributes = default;

InputSecurityBuffers inputBuffers = default;
scoped InputSecurityBuffers inputBuffers = default;
inputBuffers.SetNextBuffer(new InputSecurityBuffer(inputBuffer, SecurityBufferType.SECBUFFER_TOKEN));
inputBuffers.SetNextBuffer(new InputSecurityBuffer(default, SecurityBufferType.SECBUFFER_EMPTY));

if (context == null && sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0)
{
byte[] alpnBytes = ConvertAlpnProtocolListToByteArray(sslAuthenticationOptions.ApplicationProtocols);
inputBuffers.SetNextBuffer(new InputSecurityBuffer(new ReadOnlySpan<byte>(alpnBytes), SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
Span<byte> localBuffer = stackalloc byte[64];
SetAlpn(ref inputBuffers, sslAuthenticationOptions.ApplicationProtocols, localBuffer);
}

var resultBuffer = new SecurityBuffer(outputBuffer, SecurityBufferType.SECBUFFER_TOKEN);
Expand All @@ -87,7 +113,7 @@ public static SecurityStatusPal AcceptSecurityContext(
return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode);
}

public static SecurityStatusPal InitializeSecurityContext(
public static unsafe SecurityStatusPal InitializeSecurityContext(
ref SafeFreeCredentials? credentialsHandle,
ref SafeDeleteSslContext? context,
string? targetName,
Expand All @@ -98,13 +124,13 @@ public static SecurityStatusPal InitializeSecurityContext(
{
Interop.SspiCli.ContextFlags unusedAttributes = default;

InputSecurityBuffers inputBuffers = default;
scoped InputSecurityBuffers inputBuffers = default;
inputBuffers.SetNextBuffer(new InputSecurityBuffer(inputBuffer, SecurityBufferType.SECBUFFER_TOKEN));
inputBuffers.SetNextBuffer(new InputSecurityBuffer(default, SecurityBufferType.SECBUFFER_EMPTY));
if (context == null && sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0)
{
byte[] alpnBytes = ConvertAlpnProtocolListToByteArray(sslAuthenticationOptions.ApplicationProtocols);
inputBuffers.SetNextBuffer(new InputSecurityBuffer(new ReadOnlySpan<byte>(alpnBytes), SecurityBufferType.SECBUFFER_APPLICATION_PROTOCOLS));
Span<byte> localBuffer = stackalloc byte[64];
SetAlpn(ref inputBuffers, sslAuthenticationOptions.ApplicationProtocols, localBuffer);
}

var resultBuffer = new SecurityBuffer(outputBuffer, SecurityBufferType.SECBUFFER_TOKEN);
Expand Down

0 comments on commit a8e2b6e

Please sign in to comment.