Skip to content

Commit

Permalink
Replace TlsStream type by using SslStream directly (#106451)
Browse files Browse the repository at this point in the history
* Remove TlsStream from System.Net.Mail

* Remove TlsStream from System.Net.Requests

* Delete TlsStream.cs

* Update src/libraries/System.Net.Requests/src/System/Net/FtpControlStream.cs

Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>

* Update src/libraries/System.Net.Requests/src/System/Net/FtpDataStream.cs

---------

Co-authored-by: Miha Zupan <mihazupan.zupan1@gmail.com>
  • Loading branch information
rzikm and MihaZupan authored Aug 16, 2024
1 parent 9230f2b commit c9e1d15
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 236 deletions.
106 changes: 0 additions & 106 deletions src/libraries/Common/src/System/Net/TlsStream.cs

This file was deleted.

2 changes: 0 additions & 2 deletions src/libraries/System.Net.Mail/src/System.Net.Mail.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@
Link="Common\System\Net\DebugSafeHandleZeroOrMinusOneIsInvalid.cs" />
<Compile Include="$(CommonPath)System\Net\DebugSafeHandle.cs"
Link="Common\System\Net\DebugSafeHandle.cs" />
<Compile Include="$(CommonPath)System\Net\TlsStream.cs"
Link="Common\System\Net\TlsStream.cs" />
<Compile Include="$(CommonPath)System\Net\InternalException.cs"
Link="Common\System\Net\InternalException.cs" />
<Compile Include="$(CommonPath)System\Net\ExceptionCheck.cs"
Expand Down
78 changes: 48 additions & 30 deletions src/libraries/System.Net.Mail/src/System/Net/Mail/SmtpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal sealed partial class SmtpConnection
private readonly EventHandler? _onCloseHandler;
internal SmtpTransport? _parent;
private readonly SmtpClient? _client;
private NetworkStream? _networkStream;
private Stream? _stream;
internal TcpClient? _tcpClient;
private SmtpReplyReaderFactory? _responseReader;

Expand Down Expand Up @@ -82,7 +82,7 @@ internal X509CertificateCollection? ClientCertificates
internal void InitializeConnection(string host, int port)
{
_tcpClient!.Connect(host, port);
_networkStream = _tcpClient.GetStream();
_stream = _tcpClient.GetStream();
}

internal IAsyncResult BeginInitializeConnection(string host, int port, AsyncCallback? callback, object? state)
Expand All @@ -93,7 +93,7 @@ internal IAsyncResult BeginInitializeConnection(string host, int port, AsyncCall
internal void EndInitializeConnection(IAsyncResult result)
{
_tcpClient!.EndConnect(result);
_networkStream = _tcpClient.GetStream();
_stream = _tcpClient.GetStream();
}

internal IAsyncResult BeginGetConnection(ContextAwareResult outerResult, AsyncCallback? callback, object? state, string host, int port)
Expand All @@ -105,18 +105,18 @@ internal IAsyncResult BeginGetConnection(ContextAwareResult outerResult, AsyncCa

internal IAsyncResult BeginFlush(AsyncCallback? callback, object? state)
{
return _networkStream!.BeginWrite(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length, callback, state);
return _stream!.BeginWrite(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length, callback, state);
}

internal void EndFlush(IAsyncResult result)
{
_networkStream!.EndWrite(result);
_stream!.EndWrite(result);
_bufferBuilder.Reset();
}

internal void Flush()
{
_networkStream!.Write(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length);
_stream!.Write(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length);
_bufferBuilder.Reset();
}

Expand Down Expand Up @@ -150,7 +150,7 @@ private void ShutdownConnection(bool isAbort)
finally
{
//free cbt buffer
_networkStream?.Close();
_stream?.Close();
_tcpClient.Dispose();
}
}
Expand Down Expand Up @@ -190,7 +190,7 @@ internal void GetConnection(string host, int port)
}

InitializeConnection(host, port);
_responseReader = new SmtpReplyReaderFactory(_networkStream!);
_responseReader = new SmtpReplyReaderFactory(_stream!);

LineInfo info = _responseReader.GetNextReplyReader().ReadLine();

Expand Down Expand Up @@ -225,17 +225,25 @@ internal void GetConnection(string host, int port)
if (!_serverSupportsStartTls)
{
// Either TLS is already established or server does not support TLS
if (!(_networkStream is TlsStream))
if (!(_stream is SslStream))
{
throw new SmtpException(SR.MailServerDoesNotSupportStartTls);
}
}

StartTlsCommand.Send(this);
TlsStream tlsStream = new TlsStream(_networkStream!, _tcpClient!.Client, host, _clientCertificates);
tlsStream.AuthenticateAsClient();
_networkStream = tlsStream;
_responseReader = new SmtpReplyReaderFactory(_networkStream);
#pragma warning disable SYSLIB0014 // ServicePointManager is obsolete
SslStream sslStream = new SslStream(_stream!, false, ServicePointManager.ServerCertificateValidationCallback);

sslStream.AuthenticateAsClient(
host,
_clientCertificates,
(SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values
ServicePointManager.CheckCertificateRevocationList);
#pragma warning restore SYSLIB0014 // ServicePointManager is obsolete

_stream = sslStream;
_responseReader = new SmtpReplyReaderFactory(_stream);

// According to RFC 3207: The client SHOULD send an EHLO command
// as the first command after a successful TLS negotiation.
Expand Down Expand Up @@ -362,7 +370,7 @@ internal static void EndGetConnection(IAsyncResult result)

internal Stream GetClosableStream()
{
ClosableStream cs = new ClosableStream(_networkStream!, _onCloseHandler);
ClosableStream cs = new ClosableStream(_stream!, _onCloseHandler);
_isStreamOpen = true;
return cs;
}
Expand Down Expand Up @@ -460,7 +468,7 @@ private static void InitializeConnectionCallback(IAsyncResult result)

private void Handshake()
{
_connection._responseReader = new SmtpReplyReaderFactory(_connection._networkStream!);
_connection._responseReader = new SmtpReplyReaderFactory(_connection._stream!);

SmtpReplyReader reader = _connection.Reader!.GetNextReplyReader();
IAsyncResult result = reader.BeginReadLine(s_handshakeCallback, this);
Expand Down Expand Up @@ -533,10 +541,10 @@ private bool SendEHello()
{
_connection._extensions = EHelloCommand.EndSend(result);
_connection.ParseExtensions(_connection._extensions);
// If we already have a TlsStream, this is the second EHLO cmd
// If we already have a SslStream, this is the second EHLO cmd
// that we sent after TLS handshake compelted. So skip TLS and
// continue with Authenticate.
if (_connection._networkStream is TlsStream)
if (_connection._stream is SslStream)
{
Authenticate();
return true;
Expand All @@ -547,7 +555,7 @@ private bool SendEHello()
if (!_connection._serverSupportsStartTls)
{
// Either TLS is already established or server does not support TLS
if (!(_connection._networkStream is TlsStream))
if (!(_connection._stream is SslStream))
{
throw new SmtpException(SR.MailServerDoesNotSupportStartTls);
}
Expand Down Expand Up @@ -579,7 +587,7 @@ private static void SendEHelloCallback(IAsyncResult result)
// If we already have a SSlStream, this is the second EHLO cmd
// that we sent after TLS handshake compelted. So skip TLS and
// continue with Authenticate.
if (thisPtr._connection._networkStream is TlsStream)
if (thisPtr._connection._stream is SslStream)
{
thisPtr.Authenticate();
return;
Expand All @@ -606,7 +614,7 @@ private static void SendEHelloCallback(IAsyncResult result)
if (!thisPtr._connection._serverSupportsStartTls)
{
// Either TLS is already established or server does not support TLS
if (!(thisPtr._connection._networkStream is TlsStream))
if (!(thisPtr._connection._stream is SslStream))
{
throw new SmtpException(SR.MailServerDoesNotSupportStartTls);
}
Expand Down Expand Up @@ -663,7 +671,7 @@ private bool SendStartTls()
if (result.CompletedSynchronously)
{
StartTlsCommand.EndSend(result);
TlsStreamAuthenticate();
SslStreamAuthenticate();
return true;
}
return false;
Expand All @@ -677,7 +685,7 @@ private static void SendStartTlsCallback(IAsyncResult result)
try
{
StartTlsCommand.EndSend(result);
thisPtr.TlsStreamAuthenticate();
thisPtr.SslStreamAuthenticate();
}
catch (Exception e)
{
Expand All @@ -686,29 +694,39 @@ private static void SendStartTlsCallback(IAsyncResult result)
}
}

private bool TlsStreamAuthenticate()
private bool SslStreamAuthenticate()
{
_connection._networkStream = new TlsStream(_connection._networkStream!, _connection._tcpClient!.Client, _host, _connection._clientCertificates);
IAsyncResult result = ((TlsStream)_connection._networkStream).BeginAuthenticateAsClient(TlsStreamAuthenticateCallback, this);
#pragma warning disable SYSLIB0014 // ServicePointManager is obsolete
_connection._stream = new SslStream(_connection._stream!, false, ServicePointManager.ServerCertificateValidationCallback);

IAsyncResult result = ((SslStream)_connection._stream).BeginAuthenticateAsClient(
_host,
_connection._clientCertificates,
(SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values
ServicePointManager.CheckCertificateRevocationList,
SslStreamAuthenticateCallback,
this);
#pragma warning restore SYSLIB0014 // ServicePointManager is obsolete

if (result.CompletedSynchronously)
{
((TlsStream)_connection._networkStream).EndAuthenticateAsClient(result);
_connection._responseReader = new SmtpReplyReaderFactory(_connection._networkStream);
((SslStream)_connection._stream).EndAuthenticateAsClient(result);
_connection._responseReader = new SmtpReplyReaderFactory(_connection._stream);
SendEHello();
return true;
}
return false;
}

private static void TlsStreamAuthenticateCallback(IAsyncResult result)
private static void SslStreamAuthenticateCallback(IAsyncResult result)
{
if (!result.CompletedSynchronously)
{
ConnectAndHandshakeAsyncResult thisPtr = (ConnectAndHandshakeAsyncResult)result.AsyncState!;
try
{
(thisPtr._connection._networkStream as TlsStream)!.EndAuthenticateAsClient(result);
thisPtr._connection._responseReader = new SmtpReplyReaderFactory(thisPtr._connection._networkStream);
(thisPtr._connection._stream as SslStream)!.EndAuthenticateAsClient(result);
thisPtr._connection._responseReader = new SmtpReplyReaderFactory(thisPtr._connection._stream);
thisPtr.SendEHello();
}
catch (Exception e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@
Link="ProductionCode\BufferBuilder.cs" />
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs"
Link="Common\DisableRuntimeMarshalling.cs" />
<Compile Include="$(CommonPath)System\Net\TlsStream.cs"
Link="Common\System\Net\TlsStream.cs" />
<Compile Include="$(CommonPath)System\Net\InternalException.cs"
Link="Common\System\Net\InternalException.cs" />
<Compile Include="$(CommonPath)System\Net\LazyAsyncResult.cs"
Expand All @@ -140,8 +138,8 @@
Link="Common\System\HexConverter.cs" />
<Compile Include="$(CommonPath)System\Obsoletions.cs"
Link="Common\System\Obsoletions.cs" />
<Compile Include="$(CommonPath)System\Text\ValueStringBuilder.cs"
Link="Common\System\Text\ValueStringBuilder.cs" />
<Compile Include="$(CommonPath)System\Text\ValueStringBuilder.cs"
Link="Common\System\Text\ValueStringBuilder.cs" />
</ItemGroup>
<!-- Unix specific files -->
<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'unix'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@
Link="Common\System\Net\ContextAwareResult.cs" />
<Compile Include="$(CommonPath)System\Net\ExceptionCheck.cs"
Link="Common\System\Net\ExceptionCheck.cs" />
<Compile Include="$(CommonPath)System\Net\TlsStream.cs"
Link="Common\System\Net\TlsStream.cs" />
<Compile Include="$(CommonPath)System\Net\SecurityProtocol.cs"
Link="Common\System\Net\SecurityProtocol.cs" />
<Compile Include="$(CommonPath)System\NotImplemented.cs"
Expand Down
Loading

0 comments on commit c9e1d15

Please sign in to comment.