diff --git a/src/Renci.SshNet/ShellStream.cs b/src/Renci.SshNet/ShellStream.cs index 190392051..a737cc674 100644 --- a/src/Renci.SshNet/ShellStream.cs +++ b/src/Renci.SshNet/ShellStream.cs @@ -282,19 +282,14 @@ public void Expect(TimeSpan timeout, params ExpectAction[] expectActions) if (match.Success) { - var returnText = matchText.Substring(0, match.Index + match.Length); - var returnLength = _encoding.GetByteCount(returnText); +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length)); +#else + var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length)); +#endif // Remove processed items from the queue - for (var i = 0; i < returnLength && _incoming.Count > 0; i++) - { - if (_expect.Count == _incoming.Count) - { - _ = _expect.Dequeue(); - } - - _ = _incoming.Dequeue(); - } + var returnText = SyncQueuesAndReturn(returnLength); expectAction.Action(returnText); expectedFound = true; @@ -385,19 +380,14 @@ public string Expect(Regex regex, TimeSpan timeout) if (match.Success) { - returnText = matchText.Substring(0, match.Index + match.Length); - var returnLength = _encoding.GetByteCount(returnText); +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length)); +#else + var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length)); +#endif // Remove processed items from the queue - for (var i = 0; i < returnLength && _incoming.Count > 0; i++) - { - if (_expect.Count == _incoming.Count) - { - _ = _expect.Dequeue(); - } - - _ = _incoming.Dequeue(); - } + returnText = SyncQueuesAndReturn(returnLength); break; } @@ -501,19 +491,14 @@ public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object if (match.Success) { - returnText = matchText.Substring(0, match.Index + match.Length); - var returnLength = _encoding.GetByteCount(returnText); +#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length)); +#else + var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length)); +#endif // Remove processed items from the queue - for (var i = 0; i < returnLength && _incoming.Count > 0; i++) - { - if (_expect.Count == _incoming.Count) - { - _ = _expect.Dequeue(); - } - - _ = _incoming.Dequeue(); - } + returnText = SyncQueuesAndReturn(returnLength); expectAction.Action(returnText); callback?.Invoke(asyncResult); @@ -614,15 +599,7 @@ public string ReadLine(TimeSpan timeout) var bytesProcessed = _encoding.GetByteCount(text + CrLf); // remove processed bytes from the queue - for (var i = 0; i < bytesProcessed; i++) - { - if (_expect.Count == _incoming.Count) - { - _ = _expect.Dequeue(); - } - - _ = _incoming.Dequeue(); - } + SyncQueuesAndDequeue(bytesProcessed); break; } @@ -687,7 +664,7 @@ public override int Read(byte[] buffer, int offset, int count) { for (; i < count && _incoming.Count > 0; i++) { - if (_expect.Count == _incoming.Count) + if (_incoming.Count == _expect.Count) { _ = _expect.Dequeue(); } @@ -869,5 +846,37 @@ private void OnDataReceived(byte[] data) { DataReceived?.Invoke(this, new ShellDataEventArgs(data)); } + + private string SyncQueuesAndReturn(int bytesToDequeue) + { + string incomingText; + + lock (_incoming) + { + var incomingLength = _incoming.Count - _expect.Count + bytesToDequeue; + incomingText = _encoding.GetString(_incoming.ToArray(), 0, incomingLength); + + SyncQueuesAndDequeue(bytesToDequeue); + } + + return incomingText; + } + + private void SyncQueuesAndDequeue(int bytesToDequeue) + { + lock (_incoming) + { + while (_incoming.Count > _expect.Count) + { + _ = _incoming.Dequeue(); + } + + for (var count = 0; count < bytesToDequeue && _incoming.Count > 0; count++) + { + _ = _incoming.Dequeue(); + _ = _expect.Dequeue(); + } + } + } } } diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs index b3daf51fd..40fc8c883 100644 --- a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs @@ -17,6 +17,8 @@ namespace Renci.SshNet.Tests.Classes [TestClass] public class ShellStreamTest_ReadExpect { + private const int BufferSize = 1024; + private const int ExpectSize = BufferSize * 2; private ShellStream _shellStream; private ChannelSessionStub _channelSessionStub; @@ -42,8 +44,8 @@ public void Initialize() width: 800, height: 600, terminalModeValues: null, - bufferSize: 1024, - expectSize: 2048); + bufferSize: BufferSize, + expectSize: ExpectSize); } [TestMethod] @@ -244,6 +246,31 @@ public void Expect_String_LargeExpect() Assert.AreEqual($"{new string('c', 100)}", _shellStream.Read()); } + [TestMethod] + public void Expect_String_DequeueChecks() + { + const string expected = "ccccc"; + + // Prime buffer + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', BufferSize))); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', ExpectSize))); + + // Test data + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('a', 100))); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('b', 100))); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(expected)); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('d', 100))); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('e', 100))); + + // Expected result + var expectedResult = $"{new string(' ', BufferSize)}{new string(' ', ExpectSize)}{new string('a', 100)}{new string('b', 100)}{expected}"; + var expectedRead = $"{new string('d', 100)}{new string('e', 100)}"; + + Assert.AreEqual(expectedResult, _shellStream.Expect(expected)); + + Assert.AreEqual(expectedRead, _shellStream.Read()); + } + [TestMethod] public void Expect_Timeout() {