diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index f2fe7f2ae..3712ea768 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -154,6 +154,17 @@ public class Session : ISession /// private bool _isDisconnecting; + /// + /// Indicates whether it is the init kex. + /// + private bool _isInitialKex; + + /// + /// Indicates whether server supports strict key exchange. + /// 1.10. + /// + private bool _isStrictKex; + private IKeyExchange _keyExchange; private HashAlgorithm _serverMac; @@ -281,35 +292,11 @@ public bool IsConnected /// public byte[] SessionId { get; private set; } - private Message _clientInitMessage; - /// /// Gets the client init message. /// /// The client init message. - public Message ClientInitMessage - { - get - { - _clientInitMessage ??= new KeyExchangeInitMessage - { - KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(), - ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(), - EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(), - EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(), - MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), - MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), - CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), - CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), - LanguagesClientToServer = new[] { string.Empty }, - LanguagesServerToClient = new[] { string.Empty }, - FirstKexPacketFollows = false, - Reserved = 0 - }; - - return _clientInitMessage; - } - } + public Message ClientInitMessage { get; private set; } /// /// Gets the server version string. @@ -617,6 +604,8 @@ public void Connect() // Send our key exchange init. // We need to do this before starting the message listener to avoid the case where we receive the server // key exchange init and we continue the key exchange before having sent our own init. + _isInitialKex = true; + ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true); SendMessage(ClientInitMessage); // Mark the message listener threads as started @@ -741,6 +730,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken) // Send our key exchange init. // We need to do this before starting the message listener to avoid the case where we receive the server // key exchange init and we continue the key exchange before having sent our own init. + _isInitialKex = true; + ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true); SendMessage(ClientInitMessage); // Mark the message listener threads as started @@ -1107,13 +1098,20 @@ internal void SendMessage(Message message) SendPacket(data, 0, data.Length); } - // increment the packet sequence number only after we're sure the packet has - // been sent; even though it's only used for the MAC, it needs to be incremented - // for each package sent. - // - // the server will use it to verify the data integrity, and as such the order in - // which messages are sent must follow the outbound packet sequence number - _outboundPacketSequence++; + if (_isStrictKex && message is NewKeysMessage) + { + _outboundPacketSequence = 0; + } + else + { + // increment the packet sequence number only after we're sure the packet has + // been sent; even though it's only used for the MAC, it needs to be incremented + // for each package sent. + // + // the server will use it to verify the data integrity, and as such the order in + // which messages are sent must follow the outbound packet sequence number + _outboundPacketSequence++; + } } } @@ -1344,6 +1342,13 @@ private Message ReceiveMessage(Socket socket) _inboundPacketSequence++; + // The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5 + // It ensures the integrity of key exchange process. + if (_inboundPacketSequence == uint.MaxValue && _isInitialKex) + { + throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed); + } + return LoadMessage(data, messagePayloadOffset, messagePayloadLength); } @@ -1455,8 +1460,20 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message) _keyExchangeCompletedWaitHandle.Reset(); + if (_isInitialKex && message.KeyExchangeAlgorithms.Contains("kex-strict-s-v00@openssh.com")) + { + _isStrictKex = true; + + DiagnosticAbstraction.Log(string.Format("[{0}] Enabling strict key exchange extension.", ToHex(SessionId))); + + if (_inboundPacketSequence != 1) + { + throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed); + } + } + // Disable messages that are not key exchange related - _sshMessageFactory.DisableNonKeyExchangeMessages(); + _sshMessageFactory.DisableNonKeyExchangeMessages(_isStrictKex); _keyExchange = _serviceFactory.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, message.KeyExchangeAlgorithms); @@ -1533,6 +1550,17 @@ internal void OnNewKeysReceived(NewKeysMessage message) // Enable activated messages that are not key exchange related _sshMessageFactory.EnableActivatedMessages(); + if (_isInitialKex) + { + _isInitialKex = false; + ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: false); + } + + if (_isStrictKex) + { + _inboundPacketSequence = 0; + } + NewKeysReceived?.Invoke(this, new MessageEventArgs(message)); // Signal that key exchange completed @@ -2067,7 +2095,28 @@ private void Reset() private static SshConnectionException CreateConnectionAbortedByServerException() { return new SshConnectionException("An established connection was aborted by the server.", - DisconnectReason.ConnectionLost); + DisconnectReason.ConnectionLost); + } + + private KeyExchangeInitMessage BuildClientInitMessage(bool includeStrictKexPseudoAlgorithm) + { + return new KeyExchangeInitMessage + { + KeyExchangeAlgorithms = includeStrictKexPseudoAlgorithm ? + ConnectionInfo.KeyExchangeAlgorithms.Keys.Concat(["kex-strict-c-v00@openssh.com"]).ToArray() : + ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(), + ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(), + EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(), + EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(), + MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), + MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), + CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), + CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), + LanguagesClientToServer = new[] { string.Empty }, + LanguagesServerToClient = new[] { string.Empty }, + FirstKexPacketFollows = false, + Reserved = 0, + }; } private bool _disposed; diff --git a/src/Renci.SshNet/SshMessageFactory.cs b/src/Renci.SshNet/SshMessageFactory.cs index efa861256..2887559a6 100644 --- a/src/Renci.SshNet/SshMessageFactory.cs +++ b/src/Renci.SshNet/SshMessageFactory.cs @@ -115,16 +115,41 @@ public Message Create(byte messageNumber) return enabledMessageMetadata.Create(); } - public void DisableNonKeyExchangeMessages() + /// + /// Disables non-KeyExchange messages. + /// + /// + /// to indicate the strict key exchange mode; otherwise . + /// In strict key exchange mode, only below messages are allowed: + /// + /// SSH_MSG_KEXINIT -> 20 + /// SSH_MSG_NEWKEYS -> 21 + /// SSH_MSG_DISCONNECT -> 1 + /// + /// Note: + /// The relevant KEX Reply MSG will be allowed from a sub class of KeyExchange class. + /// For example, it calls Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY"); if the curve25519-sha256 KEX algorithm is selected per negotiation. + /// + public void DisableNonKeyExchangeMessages(bool strict) { for (var i = 0; i < AllMessages.Length; i++) { var messageMetadata = AllMessages[i]; var messageNumber = messageMetadata.Number; - if (messageNumber is (> 2 and < 20) or > 30) + if (strict) + { + if (messageNumber is not 20 and not 21 and not 1) + { + _enabledMessagesByNumber[messageNumber] = null; + } + } + else { - _enabledMessagesByNumber[messageNumber] = null; + if (messageNumber is (> 2 and < 20) or > 30) + { + _enabledMessagesByNumber[messageNumber] = null; + } } } } diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs index 1950f2759..326bae645 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs @@ -87,7 +87,7 @@ public void IsConnectedShouldReturnFalse() } [TestMethod] - public void SendMessageShouldThrowShhConnectionException() + public void SendMessageShouldThrowSshConnectionException() { try { @@ -189,7 +189,7 @@ public void ISession_MessageListenerCompletedShouldBeSignaled() } [TestMethod] - public void ISession_SendMessageShouldThrowShhConnectionException() + public void ISession_SendMessageShouldThrowSshConnectionException() { var session = (ISession)_session; diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs index cdc95c12e..de7a89d6b 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Threading; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; @@ -30,6 +31,31 @@ public void ClientVersionIsRenciSshNet() Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", Session.ClientVersion); } + [TestMethod] + public void IncludeStrictKexPseudoAlgorithmInInitKex() + { + Assert.IsTrue(ServerBytesReceivedRegister.Count > 0); + + var kexInitMessage = new KeyExchangeInitMessage(); + kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1); + Assert.IsTrue(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com")); + } + + [TestMethod] + public void ShouldNotIncludeStrictKexPseudoAlgorithmInSubsequentKex() + { + ServerBytesReceivedRegister.Clear(); + Session.SendMessage(Session.ClientInitMessage); + + Thread.Sleep(100); + + Assert.IsTrue(ServerBytesReceivedRegister.Count > 0); + + var kexInitMessage = new KeyExchangeInitMessage(); + kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1); + Assert.IsFalse(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com")); + } + [TestMethod] public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor() { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs index 42c1f54a6..64df26b5a 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs @@ -46,8 +46,7 @@ public abstract class SessionTest_ConnectedBase protected Session Session { get; private set; } protected Socket ClientSocket { get; private set; } protected Socket ServerSocket { get; private set; } - internal SshIdentification ServerIdentification { get; set; } - protected bool CallSessionConnectWhenArrange { get; set; } + protected SshIdentification ServerIdentification { get; private set; } /// /// Should the "server" wait for the client kexinit before sending its own. @@ -163,8 +162,6 @@ protected virtual void SetupData() ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo); - CallSessionConnectWhenArrange = true; - void SendKeyExchangeInit() { var keyExchangeInitMessage = new KeyExchangeInitMessage @@ -204,7 +201,7 @@ private void SetupMocks() _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange()) .Returns(_protocolVersionExchangeMock.Object); _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout)) - .Returns(() => ServerIdentification); + .Returns(ServerIdentification); _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object); _ = _keyExchangeMock.Setup(p => p.Name) .Returns(_keyExchangeAlgorithm); @@ -252,10 +249,7 @@ protected void Arrange() SetupData(); SetupMocks(); - if (CallSessionConnectWhenArrange) - { - Session.Connect(); - } + Session.Connect(); } protected virtual void ClientAuthentication_Callback() diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs new file mode 100644 index 000000000..f34634d7b --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs @@ -0,0 +1,294 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using System.Net.Sockets; +using System.Security.Cryptography; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Moq; + +using Renci.SshNet.Common; +using Renci.SshNet.Compression; +using Renci.SshNet.Connection; +using Renci.SshNet.Messages; +using Renci.SshNet.Messages.Transport; +using Renci.SshNet.Security; +using Renci.SshNet.Security.Cryptography; +using Renci.SshNet.Tests.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public abstract class SessionTest_ConnectingBase + { + internal Mock ServiceFactoryMock { get; private set; } + internal Mock SocketFactoryMock { get; private set; } + internal Mock ConnectorMock { get; private set; } + + private Mock _protocolVersionExchangeMock; + private Mock _keyExchangeMock; + private Mock _clientAuthenticationMock; + private IPEndPoint _serverEndPoint; + private string[] _keyExchangeAlgorithms; + private bool _authenticationStarted; + private SocketFactory _socketFactory; + + protected Random Random { get; private set; } + protected byte[] SessionId { get; private set; } + protected ConnectionInfo ConnectionInfo { get; private set; } + protected IList DisconnectedRegister { get; private set; } + protected IList> DisconnectReceivedRegister { get; private set; } + protected IList ErrorOccurredRegister { get; private set; } + protected AsyncSocketListener ServerListener { get; private set; } + protected IList ServerBytesReceivedRegister { get; private set; } + protected Session Session { get; private set; } + protected Socket ClientSocket { get; private set; } + protected Socket ServerSocket { get; private set; } + protected SshIdentification ServerIdentification { get; set; } + protected virtual bool ServerSupportsStrictKex { get; } + + protected virtual bool ServerResetsSequenceAfterSendingNewKeys + { + get + { + return ServerSupportsStrictKex; + } + } + + protected uint ServerOutboundPacketSequence { get; set; } + + [TestInitialize] + public void Setup() + { + CreateMocks(); + SetupData(); + SetupMocks(); + } + + protected virtual void ActionBeforeKexInit() + { + } + + protected virtual void ActionAfterKexInit() + { + } + + [TestCleanup] + public void TearDown() + { + if (ServerListener != null) + { + ServerListener.Dispose(); + ServerListener = null; + } + + if (ServerSocket != null) + { + ServerSocket.Dispose(); + ServerSocket = null; + } + + if (Session != null) + { + Session.Dispose(); + Session = null; + } + + if (ClientSocket != null && ClientSocket.Connected) + { + ClientSocket.Shutdown(SocketShutdown.Both); + ClientSocket.Dispose(); + } + } + + protected virtual void SetupData() + { + Random = new Random(); + + _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122); + ConnectionInfo = new ConnectionInfo( + _serverEndPoint.Address.ToString(), + _serverEndPoint.Port, + "user", + new PasswordAuthenticationMethod("user", "password")) + { Timeout = TimeSpan.FromSeconds(20) }; + _keyExchangeAlgorithms = ServerSupportsStrictKex ? + [Random.Next().ToString(CultureInfo.InvariantCulture), "kex-strict-s-v00@openssh.com"] : + [Random.Next().ToString(CultureInfo.InvariantCulture)]; + SessionId = new byte[10]; + Random.NextBytes(SessionId); + DisconnectedRegister = new List(); + DisconnectReceivedRegister = new List>(); + ErrorOccurredRegister = new List(); + ServerBytesReceivedRegister = new List(); + ServerIdentification = new SshIdentification("2.0", "OurServerStub"); + _authenticationStarted = false; + _socketFactory = new SocketFactory(); + + Session = new Session(ConnectionInfo, ServiceFactoryMock.Object, SocketFactoryMock.Object); + Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args); + Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args); + Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args); + + ServerListener = new AsyncSocketListener(_serverEndPoint) + { + ShutdownRemoteCommunicationSocket = false + }; + ServerListener.Connected += socket => + { + ServerSocket = socket; + ActionBeforeKexInit(); + var keyExchangeInitMessage = new KeyExchangeInitMessage + { + CompressionAlgorithmsClientToServer = new string[0], + CompressionAlgorithmsServerToClient = new string[0], + EncryptionAlgorithmsClientToServer = new string[0], + EncryptionAlgorithmsServerToClient = new string[0], + KeyExchangeAlgorithms = _keyExchangeAlgorithms, + LanguagesClientToServer = new string[0], + LanguagesServerToClient = new string[0], + MacAlgorithmsClientToServer = new string[0], + MacAlgorithmsServerToClient = new string[0], + ServerHostKeyAlgorithms = new string[0] + }; + var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null); + _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None); + ServerOutboundPacketSequence++; + }; + ServerListener.BytesReceived += (received, socket) => + { + ServerBytesReceivedRegister.Add(received); + + if (received.Length > 5 && received[5] == 20) + { + ActionAfterKexInit(); + var newKeysMessage = new NewKeysMessage(); + var newKeys = newKeysMessage.GetPacket(8, null); + _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None); + + if (ServerResetsSequenceAfterSendingNewKeys) + { + ServerOutboundPacketSequence = 0; + } + else + { + ServerOutboundPacketSequence++; + } + + if (!_authenticationStarted) + { + var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication) + .Build(ServerOutboundPacketSequence); + var hash = Abstractions.CryptoAbstraction.CreateSHA256().ComputeHash(serviceAcceptMessage); + + var packet = new byte[serviceAcceptMessage.Length - 4 + hash.Length]; + + Array.Copy(serviceAcceptMessage, 4, packet, 0, serviceAcceptMessage.Length - 4); + Array.Copy(hash, 0, packet, serviceAcceptMessage.Length - 4, hash.Length); + + _ = ServerSocket.Send(packet, 0, packet.Length, SocketFlags.None); + + ServerOutboundPacketSequence++; + + _authenticationStarted = true; + } + } + }; + ServerListener.Start(); + + ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo); + } + + private void CreateMocks() + { + ServiceFactoryMock = new Mock(MockBehavior.Strict); + SocketFactoryMock = new Mock(MockBehavior.Strict); + ConnectorMock = new Mock(MockBehavior.Strict); + _protocolVersionExchangeMock = new Mock(MockBehavior.Strict); + _keyExchangeMock = new Mock(MockBehavior.Strict); + _clientAuthenticationMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _ = ServiceFactoryMock.Setup(p => p.CreateConnector(ConnectionInfo, SocketFactoryMock.Object)) + .Returns(ConnectorMock.Object); + _ = ConnectorMock.Setup(p => p.Connect(ConnectionInfo)) + .Returns(ClientSocket); + _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange()) + .Returns(_protocolVersionExchangeMock.Object); + _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout)) + .Returns(() => ServerIdentification); + _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, _keyExchangeAlgorithms)).Returns(_keyExchangeMock.Object); + + _ = _keyExchangeMock.Setup(p => p.Name) + .Returns(_keyExchangeAlgorithms[0]); + _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false)); + _ = _keyExchangeMock.Setup(p => p.ExchangeHash) + .Returns(SessionId); + _ = _keyExchangeMock.Setup(p => p.CreateServerCipher(out It.Ref.IsAny)) + .Returns((ref bool serverAead) => + { + serverAead = false; + return (Cipher) null; + }); + _ = _keyExchangeMock.Setup(p => p.CreateClientCipher(out It.Ref.IsAny)) + .Returns((ref bool clientAead) => + { + clientAead = false; + return (Cipher) null; + }); + _ = _keyExchangeMock.Setup(p => p.CreateServerHash(out It.Ref.IsAny)) + .Returns((ref bool serverEtm) => + { + serverEtm = false; + return SHA256.Create(); + }); + _ = _keyExchangeMock.Setup(p => p.CreateClientHash(out It.Ref.IsAny)) + .Returns((ref bool clientEtm) => + { + clientEtm = false; + return (HashAlgorithm) null; + }); + _ = _keyExchangeMock.Setup(p => p.CreateCompressor()) + .Returns((Compressor) null); + _ = _keyExchangeMock.Setup(p => p.CreateDecompressor()) + .Returns((Compressor) null); + _ = _keyExchangeMock.Setup(p => p.Dispose()); + _ = ServiceFactoryMock.Setup(p => p.CreateClientAuthentication()) + .Returns(_clientAuthenticationMock.Object); + _ = _clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session)); + } + + private class ServiceAcceptMessageBuilder + { + private readonly ServiceName _serviceName; + + private ServiceAcceptMessageBuilder(ServiceName serviceName) + { + _serviceName = serviceName; + } + + public static ServiceAcceptMessageBuilder Create(ServiceName serviceName) + { + return new ServiceAcceptMessageBuilder(serviceName); + } + + public byte[] Build(uint sequence) + { + var serviceName = _serviceName.ToArray(); + var target = new ServiceAcceptMessage(); + + var sshDataStream = new SshDataStream(4 + 4 + 1 + 1 + 4 + serviceName.Length); + sshDataStream.Write(sequence); + sshDataStream.Write((uint) (sshDataStream.Capacity - 8)); //sequence and packet length + sshDataStream.WriteByte(0); // padding length + sshDataStream.WriteByte(target.MessageNumber); + sshDataStream.WriteBinary(serviceName); + return sshDataStream.ToArray(); + } + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs similarity index 91% rename from test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs rename to test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs index 7b5ff1d86..fdb7cc79f 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs @@ -5,14 +5,12 @@ namespace Renci.SshNet.Tests.Classes { [TestClass] - public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase + public class SessionTest_Connecting_ServerIdentificationReceived : SessionTest_ConnectingBase { protected override void SetupData() { base.SetupData(); - CallSessionConnectWhenArrange = false; - Session.ServerIdentificationReceived += (s, e) => { if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal)) @@ -24,10 +22,6 @@ protected override void SetupData() }; } - protected override void Act() - { - } - [TestMethod] [DataRow("OpenSSH_6.5")] [DataRow("OpenSSH_6.5p1")] diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs new file mode 100644 index 000000000..339f9df6c --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs @@ -0,0 +1,35 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return true; + } + } + + protected override bool ServerResetsSequenceAfterSendingNewKeys + { + get + { + return false; + } + } + + + [TestMethod] + public void ShouldThrowSshConnectionException() + { + var reason = Assert.ThrowsException(Session.Connect).DisconnectReason; + Assert.AreEqual(DisconnectReason.MacError, reason); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs new file mode 100644 index 000000000..a8f0680ba --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs @@ -0,0 +1,31 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return true; + } + } + + protected override bool ServerResetsSequenceAfterSendingNewKeys + { + get + { + return true; + } + } + + + [TestMethod] + public void ShouldNotThrowException() + { + Session.Connect(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs new file mode 100644 index 000000000..6fbdcceaa --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs @@ -0,0 +1,48 @@ +using System.Globalization; +using System.Net.Sockets; +using System.Text; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return true; + } + } + + protected override void ActionAfterKexInit() + { + using var stream = new SshDataStream(0); + stream.WriteByte(1); + stream.Write("This is a debug message", Encoding.UTF8); + stream.Write(CultureInfo.CurrentCulture.Name, Encoding.UTF8); + + var debugMessage = new DebugMessage(); + debugMessage.Load(stream.ToArray()); + var debug = debugMessage.GetPacket(8, null); + + // MitM sends debug message to client + _ = ServerSocket.Send(debug, 4, debug.Length - 4, SocketFlags.None); + + // MitM drops server message + ServerOutboundPacketSequence++; + } + + [TestMethod] + public void ShouldThrowSshException() + { + var message = Assert.ThrowsException(Session.Connect).Message; + Assert.AreEqual("Message type 4 is not valid in the current context.", message); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs new file mode 100644 index 000000000..989d43e56 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs @@ -0,0 +1,39 @@ +using System.Net.Sockets; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return true; + } + } + + protected override void ActionAfterKexInit() + { + var disconnectMessage = new DisconnectMessage(DisconnectReason.TooManyConnections, "too many connections"); + var disconnect = disconnectMessage.GetPacket(8, null); + + // Server sends disconnect message to client + _ = ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None); + + ServerOutboundPacketSequence++; + } + + [TestMethod] + public void DisconnectIsAllowedDuringStrictKex() + { + var exception = Assert.ThrowsException(Session.Connect); + Assert.AreEqual(DisconnectReason.TooManyConnections, exception.DisconnectReason); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs new file mode 100644 index 000000000..f20d81d8a --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs @@ -0,0 +1,38 @@ +using System.Net.Sockets; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return false; + } + } + + protected override void ActionAfterKexInit() + { + var ignoreMessage = new IgnoreMessage(); + var ignore = ignoreMessage.GetPacket(8, null); + + // MitM sends ignore message to client + _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None); + + // MitM drops server message + ServerOutboundPacketSequence++; + } + + [TestMethod] + public void DoesNotThrowException() + { + Session.Connect(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs new file mode 100644 index 000000000..0179c0eb0 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs @@ -0,0 +1,40 @@ +using System.Net.Sockets; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return true; + } + } + + protected override void ActionAfterKexInit() + { + var ignoreMessage = new IgnoreMessage(); + var ignore = ignoreMessage.GetPacket(8, null); + + // MitM sends ignore message to client + _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None); + + // MitM drops server message + ServerOutboundPacketSequence++; + } + + [TestMethod] + public void ShouldThrowSshException() + { + var message = Assert.ThrowsException(Session.Connect).Message; + Assert.AreEqual("Message type 2 is not valid in the current context.", message); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs new file mode 100644 index 000000000..c85d925b7 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs @@ -0,0 +1,38 @@ +using System.Net.Sockets; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return false; + } + } + + protected override void ActionBeforeKexInit() + { + var ignoreMessage = new IgnoreMessage(); + var ignore = ignoreMessage.GetPacket(8, null); + + // MitM sends ignore message to client + _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None); + + // MitM drops server message + ServerOutboundPacketSequence++; + } + + [TestMethod] + public void DoesNotThrowException() + { + Session.Connect(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs new file mode 100644 index 000000000..53dde0b3c --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs @@ -0,0 +1,41 @@ +using System.Net.Sockets; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Common; +using Renci.SshNet.Messages.Transport; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex : SessionTest_ConnectingBase + { + protected override bool ServerSupportsStrictKex + { + get + { + return true; + } + } + + protected override void ActionBeforeKexInit() + { + var ignoreMessage = new IgnoreMessage(); + var ignore = ignoreMessage.GetPacket(8, null); + + // MitM sends ignore message to client + _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None); + + // MitM drops server message + ServerOutboundPacketSequence++; + } + + [TestMethod] + public void ShouldThrowSshConnectionException() + { + var exception = Assert.ThrowsException(Session.Connect); + Assert.AreEqual(DisconnectReason.KeyExchangeFailed, exception.DisconnectReason); + Assert.AreEqual("KEXINIT was not the first packet during strict key exchange.", exception.Message); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs index 4bd134348..c493b6df1 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs @@ -57,7 +57,7 @@ public void IsConnectedShouldReturnFalse() } [TestMethod] - public void SendMessageShouldThrowShhConnectionException() + public void SendMessageShouldThrowSshConnectionException() { try { @@ -159,7 +159,7 @@ public void ISession_MessageListenerCompletedShouldBeSignaled() } [TestMethod] - public void ISession_SendMessageShouldThrowShhConnectionException() + public void ISession_SendMessageShouldThrowSshConnectionException() { var session = (ISession) _session; diff --git a/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs b/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs index 0f7dac81f..59317f44f 100644 --- a/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs +++ b/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs @@ -385,7 +385,7 @@ private class SocketStateObject { public Socket Socket { get; private set; } - public readonly byte[] Buffer = new byte[1024]; + public readonly byte[] Buffer = new byte[2048]; public SocketStateObject(Socket handler) {