Skip to content

Commit

Permalink
Fixed dequeuing of incoming queue (#1319)
Browse files Browse the repository at this point in the history
* Fixed dequeuing of incoming queue.

* Adjusted return of Expect to make sure it returns the full incoming queue.
  • Loading branch information
jscarle authored Feb 13, 2024
1 parent bcaf354 commit d07827b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 45 deletions.
95 changes: 52 additions & 43 deletions src/Renci.SshNet/ShellStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,14 @@ public void Expect(TimeSpan timeout, params ExpectAction[] expectActions)

if (match.Success)
{
var returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
#else
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
#endif

// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
var returnText = SyncQueuesAndReturn(returnLength);

expectAction.Action(returnText);
expectedFound = true;
Expand Down Expand Up @@ -385,19 +380,14 @@ public string Expect(Regex regex, TimeSpan timeout)

if (match.Success)
{
returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
#else
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
#endif

// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
returnText = SyncQueuesAndReturn(returnLength);

break;
}
Expand Down Expand Up @@ -501,19 +491,14 @@ public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object

if (match.Success)
{
returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
#else
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
#endif

// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
returnText = SyncQueuesAndReturn(returnLength);

expectAction.Action(returnText);
callback?.Invoke(asyncResult);
Expand Down Expand Up @@ -614,15 +599,7 @@ public string ReadLine(TimeSpan timeout)
var bytesProcessed = _encoding.GetByteCount(text + CrLf);

// remove processed bytes from the queue
for (var i = 0; i < bytesProcessed; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
SyncQueuesAndDequeue(bytesProcessed);

break;
}
Expand Down Expand Up @@ -687,7 +664,7 @@ public override int Read(byte[] buffer, int offset, int count)
{
for (; i < count && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
if (_incoming.Count == _expect.Count)
{
_ = _expect.Dequeue();
}
Expand Down Expand Up @@ -869,5 +846,37 @@ private void OnDataReceived(byte[] data)
{
DataReceived?.Invoke(this, new ShellDataEventArgs(data));
}

private string SyncQueuesAndReturn(int bytesToDequeue)
{
string incomingText;

lock (_incoming)
{
var incomingLength = _incoming.Count - _expect.Count + bytesToDequeue;
incomingText = _encoding.GetString(_incoming.ToArray(), 0, incomingLength);

SyncQueuesAndDequeue(bytesToDequeue);
}

return incomingText;
}

private void SyncQueuesAndDequeue(int bytesToDequeue)
{
lock (_incoming)
{
while (_incoming.Count > _expect.Count)
{
_ = _incoming.Dequeue();
}

for (var count = 0; count < bytesToDequeue && _incoming.Count > 0; count++)
{
_ = _incoming.Dequeue();
_ = _expect.Dequeue();
}
}
}
}
}
31 changes: 29 additions & 2 deletions test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace Renci.SshNet.Tests.Classes
[TestClass]
public class ShellStreamTest_ReadExpect
{
private const int BufferSize = 1024;
private const int ExpectSize = BufferSize * 2;
private ShellStream _shellStream;
private ChannelSessionStub _channelSessionStub;

Expand All @@ -42,8 +44,8 @@ public void Initialize()
width: 800,
height: 600,
terminalModeValues: null,
bufferSize: 1024,
expectSize: 2048);
bufferSize: BufferSize,
expectSize: ExpectSize);
}

[TestMethod]
Expand Down Expand Up @@ -244,6 +246,31 @@ public void Expect_String_LargeExpect()
Assert.AreEqual($"{new string('c', 100)}", _shellStream.Read());
}

[TestMethod]
public void Expect_String_DequeueChecks()
{
const string expected = "ccccc";

// Prime buffer
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', BufferSize)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', ExpectSize)));

// Test data
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('a', 100)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('b', 100)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(expected));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('d', 100)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('e', 100)));

// Expected result
var expectedResult = $"{new string(' ', BufferSize)}{new string(' ', ExpectSize)}{new string('a', 100)}{new string('b', 100)}{expected}";
var expectedRead = $"{new string('d', 100)}{new string('e', 100)}";

Assert.AreEqual(expectedResult, _shellStream.Expect(expected));

Assert.AreEqual(expectedRead, _shellStream.Read());
}

[TestMethod]
public void Expect_Timeout()
{
Expand Down

0 comments on commit d07827b

Please sign in to comment.