From c9e1d1549b99ef3bb21b499a15893a410decf23c Mon Sep 17 00:00:00 2001 From: Radek Zikmund <32671551+rzikm@users.noreply.github.com> Date: Fri, 16 Aug 2024 14:43:00 +0200 Subject: [PATCH] Replace TlsStream type by using SslStream directly (#106451) * Remove TlsStream from System.Net.Mail * Remove TlsStream from System.Net.Requests * Delete TlsStream.cs * Update src/libraries/System.Net.Requests/src/System/Net/FtpControlStream.cs Co-authored-by: Miha Zupan * Update src/libraries/System.Net.Requests/src/System/Net/FtpDataStream.cs --------- Co-authored-by: Miha Zupan --- .../Common/src/System/Net/TlsStream.cs | 106 ------------------ .../src/System.Net.Mail.csproj | 2 - .../src/System/Net/Mail/SmtpConnection.cs | 78 ++++++++----- .../Unit/System.Net.Mail.Unit.Tests.csproj | 6 +- .../src/System.Net.Requests.csproj | 2 - .../src/System/Net/FtpControlStream.cs | 95 ++++++++++------ .../src/System/Net/FtpDataStream.cs | 57 +++++----- .../src/System/Net/NetworkStreamWrapper.cs | 69 ++++++------ 8 files changed, 179 insertions(+), 236 deletions(-) delete mode 100644 src/libraries/Common/src/System/Net/TlsStream.cs diff --git a/src/libraries/Common/src/System/Net/TlsStream.cs b/src/libraries/Common/src/System/Net/TlsStream.cs deleted file mode 100644 index 286d02688b38b..0000000000000 --- a/src/libraries/Common/src/System/Net/TlsStream.cs +++ /dev/null @@ -1,106 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Net.Security; -using System.Net.Sockets; -using System.Security.Authentication; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Tasks; - -#pragma warning disable SYSLIB0014 // ServicePointManager is obsolete -// This type is used by FtpWebRequest (already obsolete) and SmtpClient (discouraged). - -namespace System.Net -{ - internal sealed class TlsStream : NetworkStream - { - private readonly SslStream _sslStream; - private readonly string _host; - private readonly X509CertificateCollection? _clientCertificates; - - public TlsStream(NetworkStream stream, Socket socket, string host, X509CertificateCollection? clientCertificates) : base(socket) - { - _sslStream = new SslStream(stream, false, ServicePointManager.ServerCertificateValidationCallback); - _host = host; - _clientCertificates = clientCertificates; - } - - public void AuthenticateAsClient() - { - _sslStream.AuthenticateAsClient( - _host, - _clientCertificates, - (SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values - ServicePointManager.CheckCertificateRevocationList); - } - - public IAsyncResult BeginAuthenticateAsClient(AsyncCallback? asyncCallback, object? state) - { - return _sslStream.BeginAuthenticateAsClient( - _host, - _clientCertificates, - (SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values - ServicePointManager.CheckCertificateRevocationList, - asyncCallback, - state); - } - - public void EndAuthenticateAsClient(IAsyncResult asyncResult) - { - _sslStream.EndAuthenticateAsClient(asyncResult); - } - - public override void Write(byte[] buffer, int offset, int size) - { - _sslStream.Write(buffer, offset, size); - } - - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback? callback, object? state) - { - return _sslStream.BeginWrite(buffer, offset, size, callback, state); - } - - public override void EndWrite(IAsyncResult result) - { - _sslStream.EndWrite(result); - } - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return _sslStream.WriteAsync(buffer, offset, count, cancellationToken); - } - - public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default(CancellationToken)) - { - return _sslStream.WriteAsync(buffer, cancellationToken); - } - - public override int Read(byte[] buffer, int offset, int size) - { - return _sslStream.Read(buffer, offset, size); - } - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - return _sslStream.ReadAsync(buffer, offset, count, cancellationToken); - } - - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) - { - return _sslStream.BeginRead(buffer, offset, count, callback, state); - } - - public override int EndRead(IAsyncResult asyncResult) - { - return _sslStream.EndRead(asyncResult); - } - - public override void Close() - { - base.Close(); - - _sslStream?.Close(); - } - } -} diff --git a/src/libraries/System.Net.Mail/src/System.Net.Mail.csproj b/src/libraries/System.Net.Mail/src/System.Net.Mail.csproj index db37efc40b049..c07188adcd018 100644 --- a/src/libraries/System.Net.Mail/src/System.Net.Mail.csproj +++ b/src/libraries/System.Net.Mail/src/System.Net.Mail.csproj @@ -110,8 +110,6 @@ Link="Common\System\Net\DebugSafeHandleZeroOrMinusOneIsInvalid.cs" /> - - - + diff --git a/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj b/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj index 46bda299d9a64..9d879b250d948 100644 --- a/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj +++ b/src/libraries/System.Net.Requests/src/System.Net.Requests.csproj @@ -81,8 +81,6 @@ Link="Common\System\Net\ContextAwareResult.cs" /> - - { - try - { - tlsStream.EndAuthenticateAsClient(ar); - NetworkStream = tlsStream; - this.ContinueCommandPipeline(); - } - catch (Exception e) + sslStream.BeginAuthenticateAsClient( + request.RequestUri.Host, + request.ClientCertificates, + (SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values + ServicePointManager.CheckCertificateRevocationList, + ar => { - this.CloseSocket(); - this.InvokeRequestCallback(e); - } - }, null); + try + { + sslStream.EndAuthenticateAsClient(ar); + Stream = sslStream; + this.ContinueCommandPipeline(); + } + catch (Exception e) + { + this.CloseSocket(); + this.InvokeRequestCallback(e); + } + }, + null); return PipelineInstruction.Pause; } else { - tlsStream.AuthenticateAsClient(); - NetworkStream = tlsStream; + sslStream.AuthenticateAsClient( + request.RequestUri.Host, + request.ClientCertificates, + (SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values + ServicePointManager.CheckCertificateRevocationList); + Stream = sslStream; } } +#pragma warning restore SYSLIB0014 // ServicePointManager is obsolete } // OR parse out the file size or file time, usually a result of sending SIZE/MDTM commands else if (status == FtpStatusCode.FileStatus) diff --git a/src/libraries/System.Net.Requests/src/System/Net/FtpDataStream.cs b/src/libraries/System.Net.Requests/src/System/Net/FtpDataStream.cs index 9cd8b06a61042..129052a971307 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/FtpDataStream.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/FtpDataStream.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Net.Security; using System.Net.Sockets; using System.Runtime.ExceptionServices; @@ -15,7 +16,8 @@ namespace System.Net internal sealed class FtpDataStream : Stream, ICloseEx { private readonly FtpWebRequest _request; - private readonly NetworkStream _networkStream; + private readonly Stream _stream; + private readonly NetworkStream _originalStream; private bool _writeable; private bool _readable; private bool _isFullyRead; @@ -23,7 +25,7 @@ internal sealed class FtpDataStream : Stream, ICloseEx private const int DefaultCloseTimeout = -1; - internal FtpDataStream(NetworkStream networkStream, FtpWebRequest request, TriState writeOnly) + internal FtpDataStream(Stream stream, NetworkStream originalStream, FtpWebRequest request, TriState writeOnly) { if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this); @@ -37,7 +39,8 @@ internal FtpDataStream(NetworkStream networkStream, FtpWebRequest request, TriSt { _writeable = false; } - _networkStream = networkStream; + _stream = stream; + _originalStream = originalStream; _request = request; } @@ -75,9 +78,9 @@ void ICloseEx.CloseEx(CloseExState closeState) try { if ((closeState & CloseExState.Abort) == 0) - _networkStream.Close(DefaultCloseTimeout); + _originalStream.Close(DefaultCloseTimeout); else - _networkStream.Close(0); + _originalStream.Close(0); } finally { @@ -125,7 +128,7 @@ public override bool CanSeek { get { - return _networkStream.CanSeek; + return _stream.CanSeek; } } @@ -141,7 +144,7 @@ public override long Length { get { - return _networkStream.Length; + return _stream.Length; } } @@ -149,12 +152,12 @@ public override long Position { get { - return _networkStream.Position; + return _stream.Position; } set { - _networkStream.Position = value; + _stream.Position = value; } } @@ -163,7 +166,7 @@ public override long Seek(long offset, SeekOrigin origin) CheckError(); try { - return _networkStream.Seek(offset, origin); + return _stream.Seek(offset, origin); } catch { @@ -178,7 +181,7 @@ public override int Read(byte[] buffer, int offset, int size) int readBytes; try { - readBytes = _networkStream.Read(buffer, offset, size); + readBytes = _stream.Read(buffer, offset, size); } catch { @@ -199,7 +202,7 @@ public override int Read(Span buffer) int readBytes; try { - readBytes = _networkStream.Read(buffer); + readBytes = _stream.Read(buffer); } catch { @@ -219,7 +222,7 @@ public override void Write(byte[] buffer, int offset, int size) CheckError(); try { - _networkStream.Write(buffer, offset, size); + _stream.Write(buffer, offset, size); } catch { @@ -233,7 +236,7 @@ public override void Write(ReadOnlySpan buffer) CheckError(); try { - _networkStream.Write(buffer); + _stream.Write(buffer); } catch { @@ -249,7 +252,7 @@ private void AsyncReadCallback(IAsyncResult ar) { try { - int readBytes = _networkStream.EndRead(ar); + int readBytes = _stream.EndRead(ar); if (readBytes == 0) { _isFullyRead = true; @@ -273,7 +276,7 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, Asyn LazyAsyncResult userResult = new LazyAsyncResult(this, state, callback); try { - _networkStream.BeginRead(buffer, offset, size, new AsyncCallback(AsyncReadCallback), userResult); + _stream.BeginRead(buffer, offset, size, new AsyncCallback(AsyncReadCallback), userResult); } catch { @@ -307,7 +310,7 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, Asy CheckError(); try { - return _networkStream.BeginWrite(buffer, offset, size, callback, state); + return _stream.BeginWrite(buffer, offset, size, callback, state); } catch { @@ -320,7 +323,7 @@ public override void EndWrite(IAsyncResult asyncResult) { try { - _networkStream.EndWrite(asyncResult); + _stream.EndWrite(asyncResult); } finally { @@ -330,19 +333,19 @@ public override void EndWrite(IAsyncResult asyncResult) public override void Flush() { - _networkStream.Flush(); + _stream.Flush(); } public override void SetLength(long value) { - _networkStream.SetLength(value); + _stream.SetLength(value); } public override bool CanTimeout { get { - return _networkStream.CanTimeout; + return _stream.CanTimeout; } } @@ -350,11 +353,11 @@ public override int ReadTimeout { get { - return _networkStream.ReadTimeout; + return _stream.ReadTimeout; } set { - _networkStream.ReadTimeout = value; + _stream.ReadTimeout = value; } } @@ -362,18 +365,18 @@ public override int WriteTimeout { get { - return _networkStream.WriteTimeout; + return _stream.WriteTimeout; } set { - _networkStream.WriteTimeout = value; + _stream.WriteTimeout = value; } } internal void SetSocketTimeoutOption(int timeout) { - _networkStream.ReadTimeout = timeout; - _networkStream.WriteTimeout = timeout; + _stream.ReadTimeout = timeout; + _stream.WriteTimeout = timeout; } } } diff --git a/src/libraries/System.Net.Requests/src/System/Net/NetworkStreamWrapper.cs b/src/libraries/System.Net.Requests/src/System/Net/NetworkStreamWrapper.cs index 134c773d74a73..3d8b4e2604fef 100644 --- a/src/libraries/System.Net.Requests/src/System/Net/NetworkStreamWrapper.cs +++ b/src/libraries/System.Net.Requests/src/System/Net/NetworkStreamWrapper.cs @@ -1,7 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.IO; +using System.Net.Security; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; @@ -11,6 +13,7 @@ namespace System.Net internal class NetworkStreamWrapper : Stream { private NetworkStream _networkStream; + private SslStream? _sslStream; internal NetworkStreamWrapper(NetworkStream stream) { @@ -21,7 +24,7 @@ protected bool UsingSecureStream { get { - return (_networkStream is TlsStream); + return _sslStream != null; } } @@ -41,15 +44,17 @@ internal Socket Socket } } - internal NetworkStream NetworkStream + internal Stream Stream { get { - return _networkStream; + return (Stream?)_sslStream ?? _networkStream; } set { - _networkStream = value; + // The setter is only used to upgrade to secure connection by wrapping the _networkStream + Debug.Assert(value is SslStream, "Expected SslStream"); + _sslStream = (SslStream)value; } } @@ -57,7 +62,7 @@ public override bool CanRead { get { - return _networkStream.CanRead; + return Stream.CanRead; } } @@ -65,7 +70,7 @@ public override bool CanSeek { get { - return _networkStream.CanSeek; + return Stream.CanSeek; } } @@ -73,7 +78,7 @@ public override bool CanWrite { get { - return _networkStream.CanWrite; + return Stream.CanWrite; } } @@ -81,7 +86,7 @@ public override bool CanTimeout { get { - return _networkStream.CanTimeout; + return Stream.CanTimeout; } } @@ -89,11 +94,11 @@ public override int ReadTimeout { get { - return _networkStream.ReadTimeout; + return Stream.ReadTimeout; } set { - _networkStream.ReadTimeout = value; + Stream.ReadTimeout = value; } } @@ -101,11 +106,11 @@ public override int WriteTimeout { get { - return _networkStream.WriteTimeout; + return Stream.WriteTimeout; } set { - _networkStream.WriteTimeout = value; + Stream.WriteTimeout = value; } } @@ -113,7 +118,7 @@ public override long Length { get { - return _networkStream.Length; + return Stream.Length; } } @@ -121,27 +126,27 @@ public override long Position { get { - return _networkStream.Position; + return Stream.Position; } set { - _networkStream.Position = value; + Stream.Position = value; } } public override long Seek(long offset, SeekOrigin origin) { - return _networkStream.Seek(offset, origin); + return Stream.Seek(offset, origin); } public override int Read(byte[] buffer, int offset, int size) { - return _networkStream.Read(buffer, offset, size); + return Stream.Read(buffer, offset, size); } public override void Write(byte[] buffer, int offset, int size) { - _networkStream.Write(buffer, offset, size); + Stream.Write(buffer, offset, size); } protected override void Dispose(bool disposing) @@ -162,7 +167,7 @@ protected override void Dispose(bool disposing) internal void CloseSocket() { - _networkStream.Close(); + Stream.Close(); } public void Close(int timeout) @@ -172,63 +177,63 @@ public void Close(int timeout) public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback? callback, object? state) { - return _networkStream.BeginRead(buffer, offset, size, callback, state); + return Stream.BeginRead(buffer, offset, size, callback, state); } public override int EndRead(IAsyncResult asyncResult) { - return _networkStream.EndRead(asyncResult); + return Stream.EndRead(asyncResult); } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return _networkStream.ReadAsync(buffer, offset, count, cancellationToken); + return Stream.ReadAsync(buffer, offset, count, cancellationToken); } public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - return _networkStream.ReadAsync(buffer, cancellationToken); + return Stream.ReadAsync(buffer, cancellationToken); } public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback? callback, object? state) { - return _networkStream.BeginWrite(buffer, offset, size, callback, state); + return Stream.BeginWrite(buffer, offset, size, callback, state); } public override void EndWrite(IAsyncResult asyncResult) { - _networkStream.EndWrite(asyncResult); + Stream.EndWrite(asyncResult); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return _networkStream.WriteAsync(buffer, offset, count, cancellationToken); + return Stream.WriteAsync(buffer, offset, count, cancellationToken); } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - return _networkStream.WriteAsync(buffer, cancellationToken); + return Stream.WriteAsync(buffer, cancellationToken); } public override void Flush() { - _networkStream.Flush(); + Stream.Flush(); } public override Task FlushAsync(CancellationToken cancellationToken) { - return _networkStream.FlushAsync(cancellationToken); + return Stream.FlushAsync(cancellationToken); } public override void SetLength(long value) { - _networkStream.SetLength(value); + Stream.SetLength(value); } internal void SetSocketTimeoutOption(int timeout) { - _networkStream.ReadTimeout = timeout; - _networkStream.WriteTimeout = timeout; + Stream.ReadTimeout = timeout; + Stream.WriteTimeout = timeout; } } } // System.Net