diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs
index f2fe7f2ae..3712ea768 100644
--- a/src/Renci.SshNet/Session.cs
+++ b/src/Renci.SshNet/Session.cs
@@ -154,6 +154,17 @@ public class Session : ISession
///
private bool _isDisconnecting;
+ ///
+ /// Indicates whether it is the init kex.
+ ///
+ private bool _isInitialKex;
+
+ ///
+ /// Indicates whether server supports strict key exchange.
+ /// 1.10.
+ ///
+ private bool _isStrictKex;
+
private IKeyExchange _keyExchange;
private HashAlgorithm _serverMac;
@@ -281,35 +292,11 @@ public bool IsConnected
///
public byte[] SessionId { get; private set; }
- private Message _clientInitMessage;
-
///
/// Gets the client init message.
///
/// The client init message.
- public Message ClientInitMessage
- {
- get
- {
- _clientInitMessage ??= new KeyExchangeInitMessage
- {
- KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
- ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
- EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
- EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
- MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
- MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
- CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
- CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
- LanguagesClientToServer = new[] { string.Empty },
- LanguagesServerToClient = new[] { string.Empty },
- FirstKexPacketFollows = false,
- Reserved = 0
- };
-
- return _clientInitMessage;
- }
- }
+ public Message ClientInitMessage { get; private set; }
///
/// Gets the server version string.
@@ -617,6 +604,8 @@ public void Connect()
// Send our key exchange init.
// We need to do this before starting the message listener to avoid the case where we receive the server
// key exchange init and we continue the key exchange before having sent our own init.
+ _isInitialKex = true;
+ ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
SendMessage(ClientInitMessage);
// Mark the message listener threads as started
@@ -741,6 +730,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
// Send our key exchange init.
// We need to do this before starting the message listener to avoid the case where we receive the server
// key exchange init and we continue the key exchange before having sent our own init.
+ _isInitialKex = true;
+ ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
SendMessage(ClientInitMessage);
// Mark the message listener threads as started
@@ -1107,13 +1098,20 @@ internal void SendMessage(Message message)
SendPacket(data, 0, data.Length);
}
- // increment the packet sequence number only after we're sure the packet has
- // been sent; even though it's only used for the MAC, it needs to be incremented
- // for each package sent.
- //
- // the server will use it to verify the data integrity, and as such the order in
- // which messages are sent must follow the outbound packet sequence number
- _outboundPacketSequence++;
+ if (_isStrictKex && message is NewKeysMessage)
+ {
+ _outboundPacketSequence = 0;
+ }
+ else
+ {
+ // increment the packet sequence number only after we're sure the packet has
+ // been sent; even though it's only used for the MAC, it needs to be incremented
+ // for each package sent.
+ //
+ // the server will use it to verify the data integrity, and as such the order in
+ // which messages are sent must follow the outbound packet sequence number
+ _outboundPacketSequence++;
+ }
}
}
@@ -1344,6 +1342,13 @@ private Message ReceiveMessage(Socket socket)
_inboundPacketSequence++;
+ // The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
+ // It ensures the integrity of key exchange process.
+ if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
+ {
+ throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
+ }
+
return LoadMessage(data, messagePayloadOffset, messagePayloadLength);
}
@@ -1455,8 +1460,20 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
_keyExchangeCompletedWaitHandle.Reset();
+ if (_isInitialKex && message.KeyExchangeAlgorithms.Contains("kex-strict-s-v00@openssh.com"))
+ {
+ _isStrictKex = true;
+
+ DiagnosticAbstraction.Log(string.Format("[{0}] Enabling strict key exchange extension.", ToHex(SessionId)));
+
+ if (_inboundPacketSequence != 1)
+ {
+ throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
+ }
+ }
+
// Disable messages that are not key exchange related
- _sshMessageFactory.DisableNonKeyExchangeMessages();
+ _sshMessageFactory.DisableNonKeyExchangeMessages(_isStrictKex);
_keyExchange = _serviceFactory.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms,
message.KeyExchangeAlgorithms);
@@ -1533,6 +1550,17 @@ internal void OnNewKeysReceived(NewKeysMessage message)
// Enable activated messages that are not key exchange related
_sshMessageFactory.EnableActivatedMessages();
+ if (_isInitialKex)
+ {
+ _isInitialKex = false;
+ ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: false);
+ }
+
+ if (_isStrictKex)
+ {
+ _inboundPacketSequence = 0;
+ }
+
NewKeysReceived?.Invoke(this, new MessageEventArgs(message));
// Signal that key exchange completed
@@ -2067,7 +2095,28 @@ private void Reset()
private static SshConnectionException CreateConnectionAbortedByServerException()
{
return new SshConnectionException("An established connection was aborted by the server.",
- DisconnectReason.ConnectionLost);
+ DisconnectReason.ConnectionLost);
+ }
+
+ private KeyExchangeInitMessage BuildClientInitMessage(bool includeStrictKexPseudoAlgorithm)
+ {
+ return new KeyExchangeInitMessage
+ {
+ KeyExchangeAlgorithms = includeStrictKexPseudoAlgorithm ?
+ ConnectionInfo.KeyExchangeAlgorithms.Keys.Concat(["kex-strict-c-v00@openssh.com"]).ToArray() :
+ ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
+ ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
+ EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
+ EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
+ MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
+ MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
+ CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
+ CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
+ LanguagesClientToServer = new[] { string.Empty },
+ LanguagesServerToClient = new[] { string.Empty },
+ FirstKexPacketFollows = false,
+ Reserved = 0,
+ };
}
private bool _disposed;
diff --git a/src/Renci.SshNet/SshMessageFactory.cs b/src/Renci.SshNet/SshMessageFactory.cs
index efa861256..2887559a6 100644
--- a/src/Renci.SshNet/SshMessageFactory.cs
+++ b/src/Renci.SshNet/SshMessageFactory.cs
@@ -115,16 +115,41 @@ public Message Create(byte messageNumber)
return enabledMessageMetadata.Create();
}
- public void DisableNonKeyExchangeMessages()
+ ///
+ /// Disables non-KeyExchange messages.
+ ///
+ ///
+ /// to indicate the strict key exchange mode; otherwise .
+ /// In strict key exchange mode, only below messages are allowed:
+ ///
+ /// - SSH_MSG_KEXINIT -> 20
+ /// - SSH_MSG_NEWKEYS -> 21
+ /// - SSH_MSG_DISCONNECT -> 1
+ ///
+ /// Note:
+ /// The relevant KEX Reply MSG will be allowed from a sub class of KeyExchange class.
+ /// For example, it calls Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY"); if the curve25519-sha256 KEX algorithm is selected per negotiation.
+ ///
+ public void DisableNonKeyExchangeMessages(bool strict)
{
for (var i = 0; i < AllMessages.Length; i++)
{
var messageMetadata = AllMessages[i];
var messageNumber = messageMetadata.Number;
- if (messageNumber is (> 2 and < 20) or > 30)
+ if (strict)
+ {
+ if (messageNumber is not 20 and not 21 and not 1)
+ {
+ _enabledMessagesByNumber[messageNumber] = null;
+ }
+ }
+ else
{
- _enabledMessagesByNumber[messageNumber] = null;
+ if (messageNumber is (> 2 and < 20) or > 30)
+ {
+ _enabledMessagesByNumber[messageNumber] = null;
+ }
}
}
}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs
index 1950f2759..326bae645 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs
@@ -87,7 +87,7 @@ public void IsConnectedShouldReturnFalse()
}
[TestMethod]
- public void SendMessageShouldThrowShhConnectionException()
+ public void SendMessageShouldThrowSshConnectionException()
{
try
{
@@ -189,7 +189,7 @@ public void ISession_MessageListenerCompletedShouldBeSignaled()
}
[TestMethod]
- public void ISession_SendMessageShouldThrowShhConnectionException()
+ public void ISession_SendMessageShouldThrowSshConnectionException()
{
var session = (ISession)_session;
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
index cdc95c12e..de7a89d6b 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
@@ -1,4 +1,5 @@
using System;
+using System.Linq;
using System.Threading;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
@@ -30,6 +31,31 @@ public void ClientVersionIsRenciSshNet()
Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", Session.ClientVersion);
}
+ [TestMethod]
+ public void IncludeStrictKexPseudoAlgorithmInInitKex()
+ {
+ Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);
+
+ var kexInitMessage = new KeyExchangeInitMessage();
+ kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
+ Assert.IsTrue(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
+ }
+
+ [TestMethod]
+ public void ShouldNotIncludeStrictKexPseudoAlgorithmInSubsequentKex()
+ {
+ ServerBytesReceivedRegister.Clear();
+ Session.SendMessage(Session.ClientInitMessage);
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);
+
+ var kexInitMessage = new KeyExchangeInitMessage();
+ kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
+ Assert.IsFalse(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
+ }
+
[TestMethod]
public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
index 42c1f54a6..64df26b5a 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
@@ -46,8 +46,7 @@ public abstract class SessionTest_ConnectedBase
protected Session Session { get; private set; }
protected Socket ClientSocket { get; private set; }
protected Socket ServerSocket { get; private set; }
- internal SshIdentification ServerIdentification { get; set; }
- protected bool CallSessionConnectWhenArrange { get; set; }
+ protected SshIdentification ServerIdentification { get; private set; }
///
/// Should the "server" wait for the client kexinit before sending its own.
@@ -163,8 +162,6 @@ protected virtual void SetupData()
ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
- CallSessionConnectWhenArrange = true;
-
void SendKeyExchangeInit()
{
var keyExchangeInitMessage = new KeyExchangeInitMessage
@@ -204,7 +201,7 @@ private void SetupMocks()
_ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
.Returns(_protocolVersionExchangeMock.Object);
_ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
- .Returns(() => ServerIdentification);
+ .Returns(ServerIdentification);
_ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
_ = _keyExchangeMock.Setup(p => p.Name)
.Returns(_keyExchangeAlgorithm);
@@ -252,10 +249,7 @@ protected void Arrange()
SetupData();
SetupMocks();
- if (CallSessionConnectWhenArrange)
- {
- Session.Connect();
- }
+ Session.Connect();
}
protected virtual void ClientAuthentication_Callback()
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs
new file mode 100644
index 000000000..f34634d7b
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectingBase.cs
@@ -0,0 +1,294 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Cryptography;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Moq;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Compression;
+using Renci.SshNet.Connection;
+using Renci.SshNet.Messages;
+using Renci.SshNet.Messages.Transport;
+using Renci.SshNet.Security;
+using Renci.SshNet.Security.Cryptography;
+using Renci.SshNet.Tests.Common;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public abstract class SessionTest_ConnectingBase
+ {
+ internal Mock ServiceFactoryMock { get; private set; }
+ internal Mock SocketFactoryMock { get; private set; }
+ internal Mock ConnectorMock { get; private set; }
+
+ private Mock _protocolVersionExchangeMock;
+ private Mock _keyExchangeMock;
+ private Mock _clientAuthenticationMock;
+ private IPEndPoint _serverEndPoint;
+ private string[] _keyExchangeAlgorithms;
+ private bool _authenticationStarted;
+ private SocketFactory _socketFactory;
+
+ protected Random Random { get; private set; }
+ protected byte[] SessionId { get; private set; }
+ protected ConnectionInfo ConnectionInfo { get; private set; }
+ protected IList DisconnectedRegister { get; private set; }
+ protected IList> DisconnectReceivedRegister { get; private set; }
+ protected IList ErrorOccurredRegister { get; private set; }
+ protected AsyncSocketListener ServerListener { get; private set; }
+ protected IList ServerBytesReceivedRegister { get; private set; }
+ protected Session Session { get; private set; }
+ protected Socket ClientSocket { get; private set; }
+ protected Socket ServerSocket { get; private set; }
+ protected SshIdentification ServerIdentification { get; set; }
+ protected virtual bool ServerSupportsStrictKex { get; }
+
+ protected virtual bool ServerResetsSequenceAfterSendingNewKeys
+ {
+ get
+ {
+ return ServerSupportsStrictKex;
+ }
+ }
+
+ protected uint ServerOutboundPacketSequence { get; set; }
+
+ [TestInitialize]
+ public void Setup()
+ {
+ CreateMocks();
+ SetupData();
+ SetupMocks();
+ }
+
+ protected virtual void ActionBeforeKexInit()
+ {
+ }
+
+ protected virtual void ActionAfterKexInit()
+ {
+ }
+
+ [TestCleanup]
+ public void TearDown()
+ {
+ if (ServerListener != null)
+ {
+ ServerListener.Dispose();
+ ServerListener = null;
+ }
+
+ if (ServerSocket != null)
+ {
+ ServerSocket.Dispose();
+ ServerSocket = null;
+ }
+
+ if (Session != null)
+ {
+ Session.Dispose();
+ Session = null;
+ }
+
+ if (ClientSocket != null && ClientSocket.Connected)
+ {
+ ClientSocket.Shutdown(SocketShutdown.Both);
+ ClientSocket.Dispose();
+ }
+ }
+
+ protected virtual void SetupData()
+ {
+ Random = new Random();
+
+ _serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
+ ConnectionInfo = new ConnectionInfo(
+ _serverEndPoint.Address.ToString(),
+ _serverEndPoint.Port,
+ "user",
+ new PasswordAuthenticationMethod("user", "password"))
+ { Timeout = TimeSpan.FromSeconds(20) };
+ _keyExchangeAlgorithms = ServerSupportsStrictKex ?
+ [Random.Next().ToString(CultureInfo.InvariantCulture), "kex-strict-s-v00@openssh.com"] :
+ [Random.Next().ToString(CultureInfo.InvariantCulture)];
+ SessionId = new byte[10];
+ Random.NextBytes(SessionId);
+ DisconnectedRegister = new List();
+ DisconnectReceivedRegister = new List>();
+ ErrorOccurredRegister = new List();
+ ServerBytesReceivedRegister = new List();
+ ServerIdentification = new SshIdentification("2.0", "OurServerStub");
+ _authenticationStarted = false;
+ _socketFactory = new SocketFactory();
+
+ Session = new Session(ConnectionInfo, ServiceFactoryMock.Object, SocketFactoryMock.Object);
+ Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args);
+ Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args);
+ Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args);
+
+ ServerListener = new AsyncSocketListener(_serverEndPoint)
+ {
+ ShutdownRemoteCommunicationSocket = false
+ };
+ ServerListener.Connected += socket =>
+ {
+ ServerSocket = socket;
+ ActionBeforeKexInit();
+ var keyExchangeInitMessage = new KeyExchangeInitMessage
+ {
+ CompressionAlgorithmsClientToServer = new string[0],
+ CompressionAlgorithmsServerToClient = new string[0],
+ EncryptionAlgorithmsClientToServer = new string[0],
+ EncryptionAlgorithmsServerToClient = new string[0],
+ KeyExchangeAlgorithms = _keyExchangeAlgorithms,
+ LanguagesClientToServer = new string[0],
+ LanguagesServerToClient = new string[0],
+ MacAlgorithmsClientToServer = new string[0],
+ MacAlgorithmsServerToClient = new string[0],
+ ServerHostKeyAlgorithms = new string[0]
+ };
+ var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
+ _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+ ServerOutboundPacketSequence++;
+ };
+ ServerListener.BytesReceived += (received, socket) =>
+ {
+ ServerBytesReceivedRegister.Add(received);
+
+ if (received.Length > 5 && received[5] == 20)
+ {
+ ActionAfterKexInit();
+ var newKeysMessage = new NewKeysMessage();
+ var newKeys = newKeysMessage.GetPacket(8, null);
+ _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+
+ if (ServerResetsSequenceAfterSendingNewKeys)
+ {
+ ServerOutboundPacketSequence = 0;
+ }
+ else
+ {
+ ServerOutboundPacketSequence++;
+ }
+
+ if (!_authenticationStarted)
+ {
+ var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
+ .Build(ServerOutboundPacketSequence);
+ var hash = Abstractions.CryptoAbstraction.CreateSHA256().ComputeHash(serviceAcceptMessage);
+
+ var packet = new byte[serviceAcceptMessage.Length - 4 + hash.Length];
+
+ Array.Copy(serviceAcceptMessage, 4, packet, 0, serviceAcceptMessage.Length - 4);
+ Array.Copy(hash, 0, packet, serviceAcceptMessage.Length - 4, hash.Length);
+
+ _ = ServerSocket.Send(packet, 0, packet.Length, SocketFlags.None);
+
+ ServerOutboundPacketSequence++;
+
+ _authenticationStarted = true;
+ }
+ }
+ };
+ ServerListener.Start();
+
+ ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
+ }
+
+ private void CreateMocks()
+ {
+ ServiceFactoryMock = new Mock(MockBehavior.Strict);
+ SocketFactoryMock = new Mock(MockBehavior.Strict);
+ ConnectorMock = new Mock(MockBehavior.Strict);
+ _protocolVersionExchangeMock = new Mock(MockBehavior.Strict);
+ _keyExchangeMock = new Mock(MockBehavior.Strict);
+ _clientAuthenticationMock = new Mock(MockBehavior.Strict);
+ }
+
+ private void SetupMocks()
+ {
+ _ = ServiceFactoryMock.Setup(p => p.CreateConnector(ConnectionInfo, SocketFactoryMock.Object))
+ .Returns(ConnectorMock.Object);
+ _ = ConnectorMock.Setup(p => p.Connect(ConnectionInfo))
+ .Returns(ClientSocket);
+ _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
+ .Returns(_protocolVersionExchangeMock.Object);
+ _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
+ .Returns(() => ServerIdentification);
+ _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, _keyExchangeAlgorithms)).Returns(_keyExchangeMock.Object);
+
+ _ = _keyExchangeMock.Setup(p => p.Name)
+ .Returns(_keyExchangeAlgorithms[0]);
+ _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false));
+ _ = _keyExchangeMock.Setup(p => p.ExchangeHash)
+ .Returns(SessionId);
+ _ = _keyExchangeMock.Setup(p => p.CreateServerCipher(out It.Ref.IsAny))
+ .Returns((ref bool serverAead) =>
+ {
+ serverAead = false;
+ return (Cipher) null;
+ });
+ _ = _keyExchangeMock.Setup(p => p.CreateClientCipher(out It.Ref.IsAny))
+ .Returns((ref bool clientAead) =>
+ {
+ clientAead = false;
+ return (Cipher) null;
+ });
+ _ = _keyExchangeMock.Setup(p => p.CreateServerHash(out It.Ref.IsAny))
+ .Returns((ref bool serverEtm) =>
+ {
+ serverEtm = false;
+ return SHA256.Create();
+ });
+ _ = _keyExchangeMock.Setup(p => p.CreateClientHash(out It.Ref.IsAny))
+ .Returns((ref bool clientEtm) =>
+ {
+ clientEtm = false;
+ return (HashAlgorithm) null;
+ });
+ _ = _keyExchangeMock.Setup(p => p.CreateCompressor())
+ .Returns((Compressor) null);
+ _ = _keyExchangeMock.Setup(p => p.CreateDecompressor())
+ .Returns((Compressor) null);
+ _ = _keyExchangeMock.Setup(p => p.Dispose());
+ _ = ServiceFactoryMock.Setup(p => p.CreateClientAuthentication())
+ .Returns(_clientAuthenticationMock.Object);
+ _ = _clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session));
+ }
+
+ private class ServiceAcceptMessageBuilder
+ {
+ private readonly ServiceName _serviceName;
+
+ private ServiceAcceptMessageBuilder(ServiceName serviceName)
+ {
+ _serviceName = serviceName;
+ }
+
+ public static ServiceAcceptMessageBuilder Create(ServiceName serviceName)
+ {
+ return new ServiceAcceptMessageBuilder(serviceName);
+ }
+
+ public byte[] Build(uint sequence)
+ {
+ var serviceName = _serviceName.ToArray();
+ var target = new ServiceAcceptMessage();
+
+ var sshDataStream = new SshDataStream(4 + 4 + 1 + 1 + 4 + serviceName.Length);
+ sshDataStream.Write(sequence);
+ sshDataStream.Write((uint) (sshDataStream.Capacity - 8)); //sequence and packet length
+ sshDataStream.WriteByte(0); // padding length
+ sshDataStream.WriteByte(target.MessageNumber);
+ sshDataStream.WriteBinary(serviceName);
+ return sshDataStream.ToArray();
+ }
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs
similarity index 91%
rename from test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs
rename to test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs
index 7b5ff1d86..fdb7cc79f 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerIdentificationReceived.cs
@@ -5,14 +5,12 @@
namespace Renci.SshNet.Tests.Classes
{
[TestClass]
- public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase
+ public class SessionTest_Connecting_ServerIdentificationReceived : SessionTest_ConnectingBase
{
protected override void SetupData()
{
base.SetupData();
- CallSessionConnectWhenArrange = false;
-
Session.ServerIdentificationReceived += (s, e) =>
{
if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal))
@@ -24,10 +22,6 @@ protected override void SetupData()
};
}
- protected override void Act()
- {
- }
-
[TestMethod]
[DataRow("OpenSSH_6.5")]
[DataRow("OpenSSH_6.5p1")]
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs
new file mode 100644
index 000000000..339f9df6c
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex.cs
@@ -0,0 +1,35 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerNotResetSequenceNumberAfterNewKeys_StrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected override bool ServerResetsSequenceAfterSendingNewKeys
+ {
+ get
+ {
+ return false;
+ }
+ }
+
+
+ [TestMethod]
+ public void ShouldThrowSshConnectionException()
+ {
+ var reason = Assert.ThrowsException(Session.Connect).DisconnectReason;
+ Assert.AreEqual(DisconnectReason.MacError, reason);
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs
new file mode 100644
index 000000000..a8f0680ba
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex.cs
@@ -0,0 +1,31 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerResetsSequenceNumberAfterNewKeys_StrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected override bool ServerResetsSequenceAfterSendingNewKeys
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+
+ [TestMethod]
+ public void ShouldNotThrowException()
+ {
+ Session.Connect();
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs
new file mode 100644
index 000000000..6fbdcceaa
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs
@@ -0,0 +1,48 @@
+using System.Globalization;
+using System.Net.Sockets;
+using System.Text;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected override void ActionAfterKexInit()
+ {
+ using var stream = new SshDataStream(0);
+ stream.WriteByte(1);
+ stream.Write("This is a debug message", Encoding.UTF8);
+ stream.Write(CultureInfo.CurrentCulture.Name, Encoding.UTF8);
+
+ var debugMessage = new DebugMessage();
+ debugMessage.Load(stream.ToArray());
+ var debug = debugMessage.GetPacket(8, null);
+
+ // MitM sends debug message to client
+ _ = ServerSocket.Send(debug, 4, debug.Length - 4, SocketFlags.None);
+
+ // MitM drops server message
+ ServerOutboundPacketSequence++;
+ }
+
+ [TestMethod]
+ public void ShouldThrowSshException()
+ {
+ var message = Assert.ThrowsException(Session.Connect).Message;
+ Assert.AreEqual("Message type 4 is not valid in the current context.", message);
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs
new file mode 100644
index 000000000..989d43e56
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs
@@ -0,0 +1,39 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected override void ActionAfterKexInit()
+ {
+ var disconnectMessage = new DisconnectMessage(DisconnectReason.TooManyConnections, "too many connections");
+ var disconnect = disconnectMessage.GetPacket(8, null);
+
+ // Server sends disconnect message to client
+ _ = ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
+
+ ServerOutboundPacketSequence++;
+ }
+
+ [TestMethod]
+ public void DisconnectIsAllowedDuringStrictKex()
+ {
+ var exception = Assert.ThrowsException(Session.Connect);
+ Assert.AreEqual(DisconnectReason.TooManyConnections, exception.DisconnectReason);
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs
new file mode 100644
index 000000000..f20d81d8a
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs
@@ -0,0 +1,38 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return false;
+ }
+ }
+
+ protected override void ActionAfterKexInit()
+ {
+ var ignoreMessage = new IgnoreMessage();
+ var ignore = ignoreMessage.GetPacket(8, null);
+
+ // MitM sends ignore message to client
+ _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+ // MitM drops server message
+ ServerOutboundPacketSequence++;
+ }
+
+ [TestMethod]
+ public void DoesNotThrowException()
+ {
+ Session.Connect();
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs
new file mode 100644
index 000000000..0179c0eb0
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs
@@ -0,0 +1,40 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected override void ActionAfterKexInit()
+ {
+ var ignoreMessage = new IgnoreMessage();
+ var ignore = ignoreMessage.GetPacket(8, null);
+
+ // MitM sends ignore message to client
+ _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+ // MitM drops server message
+ ServerOutboundPacketSequence++;
+ }
+
+ [TestMethod]
+ public void ShouldThrowSshException()
+ {
+ var message = Assert.ThrowsException(Session.Connect).Message;
+ Assert.AreEqual("Message type 2 is not valid in the current context.", message);
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs
new file mode 100644
index 000000000..c85d925b7
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs
@@ -0,0 +1,38 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return false;
+ }
+ }
+
+ protected override void ActionBeforeKexInit()
+ {
+ var ignoreMessage = new IgnoreMessage();
+ var ignore = ignoreMessage.GetPacket(8, null);
+
+ // MitM sends ignore message to client
+ _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+ // MitM drops server message
+ ServerOutboundPacketSequence++;
+ }
+
+ [TestMethod]
+ public void DoesNotThrowException()
+ {
+ Session.Connect();
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs
new file mode 100644
index 000000000..53dde0b3c
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs
@@ -0,0 +1,41 @@
+using System.Net.Sockets;
+
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Renci.SshNet.Common;
+using Renci.SshNet.Messages.Transport;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex : SessionTest_ConnectingBase
+ {
+ protected override bool ServerSupportsStrictKex
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected override void ActionBeforeKexInit()
+ {
+ var ignoreMessage = new IgnoreMessage();
+ var ignore = ignoreMessage.GetPacket(8, null);
+
+ // MitM sends ignore message to client
+ _ = ServerSocket.Send(ignore, 4, ignore.Length - 4, SocketFlags.None);
+
+ // MitM drops server message
+ ServerOutboundPacketSequence++;
+ }
+
+ [TestMethod]
+ public void ShouldThrowSshConnectionException()
+ {
+ var exception = Assert.ThrowsException(Session.Connect);
+ Assert.AreEqual(DisconnectReason.KeyExchangeFailed, exception.DisconnectReason);
+ Assert.AreEqual("KEXINIT was not the first packet during strict key exchange.", exception.Message);
+ }
+ }
+}
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs
index 4bd134348..c493b6df1 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_NotConnected.cs
@@ -57,7 +57,7 @@ public void IsConnectedShouldReturnFalse()
}
[TestMethod]
- public void SendMessageShouldThrowShhConnectionException()
+ public void SendMessageShouldThrowSshConnectionException()
{
try
{
@@ -159,7 +159,7 @@ public void ISession_MessageListenerCompletedShouldBeSignaled()
}
[TestMethod]
- public void ISession_SendMessageShouldThrowShhConnectionException()
+ public void ISession_SendMessageShouldThrowSshConnectionException()
{
var session = (ISession) _session;
diff --git a/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs b/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs
index 0f7dac81f..59317f44f 100644
--- a/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs
+++ b/test/Renci.SshNet.Tests/Common/AsyncSocketListener.cs
@@ -385,7 +385,7 @@ private class SocketStateObject
{
public Socket Socket { get; private set; }
- public readonly byte[] Buffer = new byte[1024];
+ public readonly byte[] Buffer = new byte[2048];
public SocketStateObject(Socket handler)
{