diff --git a/src/Renci.SshNet/ShellStream.cs b/src/Renci.SshNet/ShellStream.cs index e8945bf9c..4e53fb16f 100644 --- a/src/Renci.SshNet/ShellStream.cs +++ b/src/Renci.SshNet/ShellStream.cs @@ -26,9 +26,12 @@ public class ShellStream : Stream private readonly object _sync = new object(); - private byte[] _buffer; - private int _head; // The index from which the data starts in _buffer. - private int _tail; // The index at which to add new data into _buffer. + private readonly byte[] _writeBuffer; + private int _writeLength; // The length of the data in _writeBuffer. + + private byte[] _readBuffer; + private int _readHead; // The index from which the data starts in _readBuffer. + private int _readTail; // The index at which to add new data into _readBuffer. private bool _disposed; /// @@ -54,7 +57,7 @@ public bool DataAvailable lock (_sync) { AssertValid(); - return _tail != _head; + return _readTail != _readHead; } } } @@ -64,11 +67,11 @@ public bool DataAvailable private void AssertValid() { Debug.Assert(Monitor.IsEntered(_sync), $"Should be in lock on {nameof(_sync)}"); - Debug.Assert(_head >= 0, $"{nameof(_head)} should be non-negative but is {_head}"); - Debug.Assert(_tail >= 0, $"{nameof(_tail)} should be non-negative but is {_tail}"); - Debug.Assert(_head < _buffer.Length || _buffer.Length == 0, $"{nameof(_head)} should be < {nameof(_buffer)}.Length but is {_head}"); - Debug.Assert(_tail <= _buffer.Length, $"{nameof(_tail)} should be <= {nameof(_buffer)}.Length but is {_tail}"); - Debug.Assert(_head <= _tail, $"Should have {nameof(_head)} <= {nameof(_tail)} but have {_head} <= {_tail}"); + Debug.Assert(_readHead >= 0, $"{nameof(_readHead)} should be non-negative but is {_readHead}"); + Debug.Assert(_readTail >= 0, $"{nameof(_readTail)} should be non-negative but is {_readTail}"); + Debug.Assert(_readHead < _readBuffer.Length || _readBuffer.Length == 0, $"{nameof(_readHead)} should be < {nameof(_readBuffer)}.Length but is {_readHead}"); + Debug.Assert(_readTail <= _readBuffer.Length, $"{nameof(_readTail)} should be <= {nameof(_readBuffer)}.Length but is {_readTail}"); + Debug.Assert(_readHead <= _readTail, $"Should have {nameof(_readHead)} <= {nameof(_readTail)} but have {_readHead} <= {_readTail}"); } #pragma warning restore MA0076 // Do not use implicit culture-sensitive ToString in interpolated strings @@ -108,7 +111,8 @@ internal ShellStream(ISession session, string terminalName, uint columns, uint r _session.Disconnected += Session_Disconnected; _session.ErrorOccured += Session_ErrorOccured; - _buffer = new byte[bufferSize]; + _readBuffer = new byte[bufferSize]; + _writeBuffer = new byte[bufferSize]; try { @@ -178,6 +182,15 @@ public override bool CanWrite /// public override void Flush() { + ThrowIfDisposed(); + + Debug.Assert(_writeLength >= 0 && _writeLength <= _writeBuffer.Length); + + if (_writeLength > 0) + { + _channel.SendData(_writeBuffer, 0, _writeLength); + _writeLength = 0; + } } /// @@ -191,7 +204,7 @@ public override long Length lock (_sync) { AssertValid(); - return _tail - _head; + return _readTail - _readHead; } } } @@ -326,22 +339,22 @@ public void Expect(TimeSpan timeout, int lookback, params ExpectAction[] expectA AssertValid(); var searchHead = lookback == -1 - ? _head - : Math.Max(_tail - lookback, _head); + ? _readHead + : Math.Max(_readTail - lookback, _readHead); - Debug.Assert(_head <= searchHead && searchHead <= _tail); + Debug.Assert(_readHead <= searchHead && searchHead <= _readTail); #if NETFRAMEWORK || NETSTANDARD2_0 - var indexOfMatch = _buffer.IndexOf(expectBytes, searchHead, _tail - searchHead); + var indexOfMatch = _readBuffer.IndexOf(expectBytes, searchHead, _readTail - searchHead); #else - var indexOfMatch = _buffer.AsSpan(searchHead, _tail - searchHead).IndexOf(expectBytes); + var indexOfMatch = _readBuffer.AsSpan(searchHead, _readTail - searchHead).IndexOf(expectBytes); #endif if (indexOfMatch >= 0) { - var returnText = _encoding.GetString(_buffer, _head, searchHead - _head + indexOfMatch + expectBytes.Length); + var returnText = _encoding.GetString(_readBuffer, _readHead, searchHead - _readHead + indexOfMatch + expectBytes.Length); - _head = searchHead + indexOfMatch + expectBytes.Length; + _readHead = searchHead + indexOfMatch + expectBytes.Length; AssertValid(); @@ -415,7 +428,7 @@ public void Expect(TimeSpan timeout, int lookback, params ExpectAction[] expectA { AssertValid(); - var bufferText = _encoding.GetString(_buffer, _head, _tail - _head); + var bufferText = _encoding.GetString(_readBuffer, _readHead, _readTail - _readHead); var searchStart = lookback == -1 ? 0 @@ -438,7 +451,7 @@ public void Expect(TimeSpan timeout, int lookback, params ExpectAction[] expectA { var returnText = bufferText.Substring(0, match.Index + match.Length); #endif - _head += _encoding.GetByteCount(returnText); + _readHead += _encoding.GetByteCount(returnText); AssertValid(); @@ -604,29 +617,29 @@ public IAsyncResult BeginExpect(TimeSpan timeout, int lookback, AsyncCallback? c AssertValid(); #if NETFRAMEWORK || NETSTANDARD2_0 - var indexOfCr = _buffer.IndexOf(_carriageReturnBytes, _head, _tail - _head); + var indexOfCr = _readBuffer.IndexOf(_carriageReturnBytes, _readHead, _readTail - _readHead); #else - var indexOfCr = _buffer.AsSpan(_head, _tail - _head).IndexOf(_carriageReturnBytes); + var indexOfCr = _readBuffer.AsSpan(_readHead, _readTail - _readHead).IndexOf(_carriageReturnBytes); #endif if (indexOfCr >= 0) { // We have found \r. We only need to search for \n up to and just after the \r // (in order to consume \r\n if we can). #if NETFRAMEWORK || NETSTANDARD2_0 - var indexOfLf = indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length <= _tail - _head - ? _buffer.IndexOf(_lineFeedBytes, _head, indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length) - : _buffer.IndexOf(_lineFeedBytes, _head, indexOfCr); + var indexOfLf = indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length <= _readTail - _readHead + ? _readBuffer.IndexOf(_lineFeedBytes, _readHead, indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length) + : _readBuffer.IndexOf(_lineFeedBytes, _readHead, indexOfCr); #else - var indexOfLf = indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length <= _tail - _head - ? _buffer.AsSpan(_head, indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length).IndexOf(_lineFeedBytes) - : _buffer.AsSpan(_head, indexOfCr).IndexOf(_lineFeedBytes); + var indexOfLf = indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length <= _readTail - _readHead + ? _readBuffer.AsSpan(_readHead, indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length).IndexOf(_lineFeedBytes) + : _readBuffer.AsSpan(_readHead, indexOfCr).IndexOf(_lineFeedBytes); #endif if (indexOfLf >= 0 && indexOfLf < indexOfCr) { // If there is \n before the \r, then return up to the \n - var returnText = _encoding.GetString(_buffer, _head, indexOfLf); + var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfLf); - _head += indexOfLf + _lineFeedBytes.Length; + _readHead += indexOfLf + _lineFeedBytes.Length; AssertValid(); @@ -635,9 +648,9 @@ public IAsyncResult BeginExpect(TimeSpan timeout, int lookback, AsyncCallback? c else if (indexOfLf == indexOfCr + _carriageReturnBytes.Length) { // If we have \r\n, then consume both - var returnText = _encoding.GetString(_buffer, _head, indexOfCr); + var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfCr); - _head += indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length; + _readHead += indexOfCr + _carriageReturnBytes.Length + _lineFeedBytes.Length; AssertValid(); @@ -646,9 +659,9 @@ public IAsyncResult BeginExpect(TimeSpan timeout, int lookback, AsyncCallback? c else { // Return up to the \r - var returnText = _encoding.GetString(_buffer, _head, indexOfCr); + var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfCr); - _head += indexOfCr + _carriageReturnBytes.Length; + _readHead += indexOfCr + _carriageReturnBytes.Length; AssertValid(); @@ -659,15 +672,15 @@ public IAsyncResult BeginExpect(TimeSpan timeout, int lookback, AsyncCallback? c { // There is no \r. What about \n? #if NETFRAMEWORK || NETSTANDARD2_0 - var indexOfLf = _buffer.IndexOf(_lineFeedBytes, _head, _tail - _head); + var indexOfLf = _readBuffer.IndexOf(_lineFeedBytes, _readHead, _readTail - _readHead); #else - var indexOfLf = _buffer.AsSpan(_head, _tail - _head).IndexOf(_lineFeedBytes); + var indexOfLf = _readBuffer.AsSpan(_readHead, _readTail - _readHead).IndexOf(_lineFeedBytes); #endif if (indexOfLf >= 0) { - var returnText = _encoding.GetString(_buffer, _head, indexOfLf); + var returnText = _encoding.GetString(_readBuffer, _readHead, indexOfLf); - _head += indexOfLf + _lineFeedBytes.Length; + _readHead += indexOfLf + _lineFeedBytes.Length; AssertValid(); @@ -677,11 +690,11 @@ public IAsyncResult BeginExpect(TimeSpan timeout, int lookback, AsyncCallback? c if (_disposed) { - var lastLine = _head == _tail + var lastLine = _readHead == _readTail ? null - : _encoding.GetString(_buffer, _head, _tail - _head); + : _encoding.GetString(_readBuffer, _readHead, _readTail - _readHead); - _head = _tail = 0; + _readHead = _readTail = 0; return lastLine; } @@ -719,6 +732,18 @@ private static void ValidateLookback(int lookback) } } + private void ThrowIfDisposed() + { +#if NET7_0_OR_GREATER + ObjectDisposedException.ThrowIf(_disposed, this); +#else + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } +#endif // NET7_0_OR_GREATER + } + /// /// Reads all of the text currently available in the shell. /// @@ -731,9 +756,9 @@ public string Read() { AssertValid(); - var text = _encoding.GetString(_buffer, _head, _tail - _head); + var text = _encoding.GetString(_readBuffer, _readHead, _readTail - _readHead); - _head = _tail = 0; + _readHead = _readTail = 0; return text; } @@ -744,18 +769,18 @@ public override int Read(byte[] buffer, int offset, int count) { lock (_sync) { - while (_head == _tail && !_disposed) + while (_readHead == _readTail && !_disposed) { _ = Monitor.Wait(_sync); } AssertValid(); - var bytesRead = Math.Min(count, _tail - _head); + var bytesRead = Math.Min(count, _readTail - _readHead); - Buffer.BlockCopy(_buffer, _head, buffer, offset, bytesRead); + Buffer.BlockCopy(_readBuffer, _readHead, buffer, offset, bytesRead); - _head += bytesRead; + _readHead += bytesRead; AssertValid(); @@ -768,13 +793,8 @@ public override int Read(byte[] buffer, int offset, int count) /// /// The text to be written to the shell. /// - /// /// If is , nothing is written. - /// - /// - /// Data is not buffered before being written to the shell. If you have text to send in many pieces, - /// consider wrapping this stream in a . - /// + /// Otherwise, is called after writing the data to the buffer. /// /// The stream is closed. public void Write(string? text) @@ -787,31 +807,31 @@ public void Write(string? text) var data = _encoding.GetBytes(text); Write(data, 0, data.Length); + Flush(); } - /// - /// Writes a sequence of bytes to the shell. - /// - /// An array of bytes. This method sends bytes from buffer to the shell. - /// The zero-based byte offset in at which to begin sending bytes to the shell. - /// The number of bytes to be sent to the shell. - /// - /// Data is not buffered before being written to the shell. If you have data to send in many pieces, - /// consider wrapping this stream in a . - /// - /// The stream is closed. + /// public override void Write(byte[] buffer, int offset, int count) { -#if NET7_0_OR_GREATER - ObjectDisposedException.ThrowIf(_disposed, this); -#else - if (_disposed) + ThrowIfDisposed(); + + while (count > 0) { - throw new ObjectDisposedException(GetType().FullName); - } -#endif // NET7_0_OR_GREATER + if (_writeLength == _writeBuffer.Length) + { + Flush(); + } + + var bytesToCopy = Math.Min(count, _writeBuffer.Length - _writeLength); + + Buffer.BlockCopy(buffer, offset, _writeBuffer, _writeLength, bytesToCopy); - _channel.SendData(buffer, offset, count); + offset += bytesToCopy; + count -= bytesToCopy; + _writeLength += bytesToCopy; + + Debug.Assert(_writeLength >= 0 && _writeLength <= _writeBuffer.Length); + } } /// @@ -820,6 +840,7 @@ public override void Write(byte[] buffer, int offset, int count) /// The line to be written to the shell. /// /// If is , only the line terminator is written. + /// is called once the data is written. /// /// The stream is closed. public void WriteLine(string line) @@ -883,38 +904,38 @@ private void Channel_DataReceived(object? sender, ChannelDataEventArgs e) // Ensure sufficient buffer space and copy the new data in. - if (_buffer.Length - _tail >= e.Data.Length) + if (_readBuffer.Length - _readTail >= e.Data.Length) { // If there is enough space after _tail for the new data, // then copy the data there. - Buffer.BlockCopy(e.Data, 0, _buffer, _tail, e.Data.Length); - _tail += e.Data.Length; + Buffer.BlockCopy(e.Data, 0, _readBuffer, _readTail, e.Data.Length); + _readTail += e.Data.Length; } else { // We can't fit the new data after _tail. - var newLength = _tail - _head + e.Data.Length; + var newLength = _readTail - _readHead + e.Data.Length; - if (newLength <= _buffer.Length) + if (newLength <= _readBuffer.Length) { // If there is sufficient space at the start of the buffer, // then move the current data to the start of the buffer. - Buffer.BlockCopy(_buffer, _head, _buffer, 0, _tail - _head); + Buffer.BlockCopy(_readBuffer, _readHead, _readBuffer, 0, _readTail - _readHead); } else { // Otherwise, we're gonna need a bigger buffer. - var newBuffer = new byte[_buffer.Length * 2]; - Buffer.BlockCopy(_buffer, _head, newBuffer, 0, _tail - _head); - _buffer = newBuffer; + var newBuffer = new byte[_readBuffer.Length * 2]; + Buffer.BlockCopy(_readBuffer, _readHead, newBuffer, 0, _readTail - _readHead); + _readBuffer = newBuffer; } // Copy the new data into the freed-up space. - Buffer.BlockCopy(e.Data, 0, _buffer, _tail - _head, e.Data.Length); + Buffer.BlockCopy(e.Data, 0, _readBuffer, _readTail - _readHead, e.Data.Length); - _head = 0; - _tail = newLength; + _readHead = 0; + _readTail = newLength; } AssertValid(); diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest.cs index f0cd18dac..bc7f10d65 100644 --- a/test/Renci.SshNet.Tests/Classes/ShellStreamTest.cs +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest.cs @@ -86,7 +86,7 @@ public void Write_Text_ShouldWriteNothingWhenTextIsNull() shellStream.Write(text); - _channelSessionMock.Verify(p => p.SendData(It.IsAny()), Times.Never); + _channelSessionMock.VerifyAll(); } [TestMethod] @@ -95,33 +95,15 @@ public void WriteLine_Line_ShouldOnlyWriteLineTerminatorWhenLineIsNull() var shellStream = CreateShellStream(); const string line = null; var lineTerminator = _encoding.GetBytes("\r"); - - _channelSessionMock.Setup(p => p.SendData(lineTerminator, 0, lineTerminator.Length)); + + _ = _channelSessionMock.Setup(p => p.SendData( + It.Is(data => data.Take(lineTerminator.Length).IsEqualTo(lineTerminator)), + 0, + lineTerminator.Length)); shellStream.WriteLine(line); - _channelSessionMock.Verify(p => p.SendData(lineTerminator, 0, lineTerminator.Length), Times.Once); - } - - [TestMethod] - public void Write_Bytes_SendsToChannel() - { - var shellStream = CreateShellStream(); - - var bytes1 = _encoding.GetBytes("Hello World!"); - var bytes2 = _encoding.GetBytes("Some more bytes!"); - - _channelSessionMock.Setup(p => p.SendData(bytes1, 0, bytes1.Length)); - _channelSessionMock.Setup(p => p.SendData(bytes2, 0, bytes2.Length)); - - shellStream.Write(bytes1, 0, bytes1.Length); - - _channelSessionMock.Verify(p => p.SendData(bytes1, 0, bytes1.Length), Times.Once); - - shellStream.Write(bytes2, 0, bytes2.Length); - - _channelSessionMock.Verify(p => p.SendData(bytes1, 0, bytes1.Length), Times.Once); - _channelSessionMock.Verify(p => p.SendData(bytes2, 0, bytes2.Length), Times.Once); + _channelSessionMock.VerifyAll(); } [TestMethod] diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteLessBytesThanBufferSize.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteLessBytesThanBufferSize.cs new file mode 100644 index 000000000..0668177d2 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteLessBytesThanBufferSize.cs @@ -0,0 +1,135 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferEmptyAndWriteLessBytesThanBufferSize + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint) random.Next(); + _heightRows = (uint) random.Next(); + _widthPixels = (uint) random.Next(); + _heightPixels = (uint) random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _data = CryptoAbstraction.GenerateRandom(_bufferSize - 10); + _offset = random.Next(1, 5); + _count = _data.Length - _offset - random.Next(1, 10); + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void NoDataShouldBeSentToServer() + { + _channelSessionMock.Verify(p => p.SendData(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + } + + [TestMethod] + public void FlushShouldSendWrittenBytesToServer() + { + byte[] bytesSent = null; + + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((data, offset, count) => bytesSent = data.Take(offset, count)); + + _shellStream.Flush(); + + Assert.IsNotNull(bytesSent); + Assert.IsTrue(_data.Take(_offset, _count).IsEqualTo(bytesSent)); + + _channelSessionMock.Verify(p => p.SendData(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteMoreBytesThanBufferSize.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteMoreBytesThanBufferSize.cs new file mode 100644 index 000000000..64f95687e --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteMoreBytesThanBufferSize.cs @@ -0,0 +1,144 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferEmptyAndWriteMoreBytesThanBufferSize + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private MockSequence _mockSequence; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + + private byte[] _expectedBytesSent1; + private byte[] _expectedBytesSent2; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint)random.Next(); + _heightRows = (uint)random.Next(); + _widthPixels = (uint)random.Next(); + _heightPixels = (uint)random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _data = CryptoAbstraction.GenerateRandom((_bufferSize * 2) + 10); + _offset = 0; + _count = _data.Length; + + _expectedBytesSent1 = _data.Take(0, _bufferSize); + _expectedBytesSent2 = _data.Take(_bufferSize, _bufferSize); + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _ = _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _ = _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _ = _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(_expectedBytesSent1, 0, _bufferSize)); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(_expectedBytesSent2, 0, _bufferSize)); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void BufferShouldHaveBeenFlushedTwice() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendRemaningBytesToServer() + { + var expectedBytesSent = _data.Take(_bufferSize * 2, _data.Length - (_bufferSize * 2)); + + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData( + It.Is(data => data.Take(expectedBytesSent.Length).IsEqualTo(expectedBytesSent)), + 0, + expectedBytesSent.Length)); + + _shellStream.Flush(); + + _channelSessionMock.VerifyAll(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteNumberOfBytesEqualToBufferSize.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteNumberOfBytesEqualToBufferSize.cs new file mode 100644 index 000000000..d3d380036 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteNumberOfBytesEqualToBufferSize.cs @@ -0,0 +1,132 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferEmptyAndWriteNumberOfBytesEqualToBufferSize + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint)random.Next(); + _heightRows = (uint)random.Next(); + _widthPixels = (uint)random.Next(); + _heightPixels = (uint)random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _data = CryptoAbstraction.GenerateRandom(_bufferSize); + _offset = 0; + _count = _data.Length; + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void NoDataShouldBeSentToServer() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendWrittenBytesToServer() + { + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData( + It.Is(data => data.Take(_data.Length).IsEqualTo(_data)), + 0, + _data.Length)); + + _shellStream.Flush(); + + _channelSessionMock.VerifyAll(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteZeroBytes.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteZeroBytes.cs new file mode 100644 index 000000000..94d55ca87 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferEmptyAndWriteZeroBytes.cs @@ -0,0 +1,128 @@ +using System; +using System.Collections.Generic; +using System.Text; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Moq; + +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferEmptyAndWriteZeroBytes + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint)random.Next(); + _heightRows = (uint)random.Next(); + _widthPixels = (uint)random.Next(); + _heightPixels = (uint)random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _data = new byte[0]; + _offset = 0; + _count = _data.Length; + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _ = _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _ = _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _ = _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void NoDataShouldBeSentToServer() + { + _channelSessionMock.Verify(p => p.SendData(It.IsAny()), Times.Never); + } + + [TestMethod] + public void FlushShouldSendNoBytesToServer() + { + _shellStream.Flush(); + + _channelSessionMock.Verify(p => p.SendData(It.IsAny()), Times.Never); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferFullAndWriteLessBytesThanBufferSize.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferFullAndWriteLessBytesThanBufferSize.cs new file mode 100644 index 000000000..40b095355 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferFullAndWriteLessBytesThanBufferSize.cs @@ -0,0 +1,138 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferFullAndWriteLessBytesThanBufferSize + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + private byte[] _bufferData; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint)random.Next(); + _heightRows = (uint)random.Next(); + _widthPixels = (uint)random.Next(); + _heightPixels = (uint)random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _bufferData = CryptoAbstraction.GenerateRandom(_bufferSize); + _data = CryptoAbstraction.GenerateRandom(_bufferSize - 10); + _offset = 0; + _count = _data.Length; + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(_bufferData, 0, _bufferData.Length)); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + + _shellStream.Write(_bufferData, 0, _bufferData.Length); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void BufferShouldBeSentToServer() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendRemainingBytesInBufferToServer() + { + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData( + It.Is(data => data.Take(_data.Length).IsEqualTo(_data)), + 0, + _data.Length)); + + _shellStream.Flush(); + + _channelSessionMock.VerifyAll(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferFullAndWriteZeroBytes.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferFullAndWriteZeroBytes.cs new file mode 100644 index 000000000..b58c0e970 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferFullAndWriteZeroBytes.cs @@ -0,0 +1,136 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferFullAndWriteZeroBytes + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + private byte[] _bufferData; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint) random.Next(); + _heightRows = (uint) random.Next(); + _widthPixels = (uint) random.Next(); + _heightPixels = (uint) random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _bufferData = CryptoAbstraction.GenerateRandom(_bufferSize); + _data = new byte[0]; + _offset = 0; + _count = _data.Length; + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + + _shellStream.Write(_bufferData, 0, _bufferData.Length); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void NoDataShouldBeSentToServer() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendBufferToServer() + { + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData( + It.Is(data => data.Take(_bufferData.Length).IsEqualTo(_bufferData)), + 0, + _bufferData.Length)); + + _shellStream.Flush(); + + _channelSessionMock.VerifyAll(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteLessBytesThanBufferCanContain.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteLessBytesThanBufferCanContain.cs new file mode 100644 index 000000000..99ed903e9 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteLessBytesThanBufferCanContain.cs @@ -0,0 +1,141 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferNotEmptyAndWriteLessBytesThanBufferCanContain + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + private byte[] _bufferData; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint) random.Next(); + _heightRows = (uint) random.Next(); + _widthPixels = (uint) random.Next(); + _heightPixels = (uint) random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _bufferData = CryptoAbstraction.GenerateRandom(_bufferSize - 60); + _data = CryptoAbstraction.GenerateRandom(_bufferSize + 100); + _offset = 0; + _count = _bufferSize - _bufferData.Length - random.Next(1, 10); + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + + _shellStream.Write(_bufferData, 0, _bufferData.Length); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void NoDataShouldBeSentToServer() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendWrittenBytesToServer() + { + byte[] bytesSent = null; + + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((data, offset, count) => bytesSent = data.Take(offset, count)); + + _shellStream.Flush(); + + Assert.IsNotNull(bytesSent); + Assert.AreEqual(_bufferData.Length + _count, bytesSent.Length); + Assert.IsTrue(_bufferData.IsEqualTo(bytesSent.Take(_bufferData.Length))); + Assert.IsTrue(_data.Take(0, _count).IsEqualTo(bytesSent.Take(_bufferData.Length, _count))); + + _channelSessionMock.VerifyAll(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteMoreBytesThanBufferCanContain.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteMoreBytesThanBufferCanContain.cs new file mode 100644 index 000000000..de57d0512 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteMoreBytesThanBufferCanContain.cs @@ -0,0 +1,149 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; +using Renci.SshNet.Tests.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferNotEmptyAndWriteMoreBytesThanBufferCanContain + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + private byte[] _bufferData; + private byte[] _expectedBytesSent; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint) random.Next(); + _heightRows = (uint) random.Next(); + _widthPixels = (uint) random.Next(); + _heightPixels = (uint) random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _bufferData = CryptoAbstraction.GenerateRandom(_bufferSize - 60); + _data = CryptoAbstraction.GenerateRandom(_bufferSize - _bufferData.Length + random.Next(1, 10)); + _offset = 0; + _count = _data.Length; + + _expectedBytesSent = new ArrayBuilder().Add(_bufferData) + .Add(_data, 0, _bufferSize - _bufferData.Length) + .Build(); + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(_expectedBytesSent, 0, _expectedBytesSent.Length)); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + + _shellStream.Write(_bufferData, 0, _bufferData.Length); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void BufferShouldBeSentToServer() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendRemainingBytesInBufferToServer() + { + var expectedBytesSent = _data.Take(_bufferSize - _bufferData.Length, _data.Length + _bufferData.Length - _bufferSize); + byte[] actualBytesSent = null; + + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((data, offset, count) => actualBytesSent = data.Take(offset, count)); + + _shellStream.Flush(); + + Assert.IsNotNull(actualBytesSent); + Assert.AreEqual(expectedBytesSent.Length, actualBytesSent.Length); + Assert.IsTrue(expectedBytesSent.IsEqualTo(actualBytesSent)); + + _channelSessionMock.VerifyAll(); + } + } +} diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteZeroBytes.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteZeroBytes.cs new file mode 100644 index 000000000..c1c7d39ec --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_Write_WriteBufferNotEmptyAndWriteZeroBytes.cs @@ -0,0 +1,136 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Renci.SshNet.Abstractions; +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_Write_WriteBufferNotEmptyAndWriteZeroBytes + { + private Mock _sessionMock; + private Mock _connectionInfoMock; + private Mock _channelSessionMock; + private string _terminalName; + private uint _widthColumns; + private uint _heightRows; + private uint _widthPixels; + private uint _heightPixels; + private Dictionary _terminalModes; + private ShellStream _shellStream; + private int _bufferSize; + + private byte[] _data; + private int _offset; + private int _count; + private MockSequence _mockSequence; + private byte[] _bufferData; + + [TestInitialize] + public void Initialize() + { + Arrange(); + Act(); + } + + private void SetupData() + { + var random = new Random(); + + _terminalName = random.Next().ToString(); + _widthColumns = (uint)random.Next(); + _heightRows = (uint)random.Next(); + _widthPixels = (uint)random.Next(); + _heightPixels = (uint)random.Next(); + _terminalModes = new Dictionary(); + _bufferSize = random.Next(100, 1000); + + _bufferData = CryptoAbstraction.GenerateRandom(_bufferSize - 60); + _data = new byte[0]; + _offset = 0; + _count = _data.Length; + } + + private void CreateMocks() + { + _sessionMock = new Mock(MockBehavior.Strict); + _connectionInfoMock = new Mock(MockBehavior.Strict); + _channelSessionMock = new Mock(MockBehavior.Strict); + } + + private void SetupMocks() + { + _mockSequence = new MockSequence(); + + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.ConnectionInfo) + .Returns(_connectionInfoMock.Object); + _connectionInfoMock.InSequence(_mockSequence) + .Setup(p => p.Encoding) + .Returns(new UTF8Encoding()); + _sessionMock.InSequence(_mockSequence) + .Setup(p => p.CreateChannelSession()) + .Returns(_channelSessionMock.Object); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.Open()); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendPseudoTerminalRequest(_terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes)) + .Returns(true); + _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendShellRequest()) + .Returns(true); + } + + private void Arrange() + { + SetupData(); + CreateMocks(); + SetupMocks(); + + _shellStream = new ShellStream(_sessionMock.Object, + _terminalName, + _widthColumns, + _heightRows, + _widthPixels, + _heightPixels, + _terminalModes, + _bufferSize); + + _shellStream.Write(_bufferData, 0, _bufferData.Length); + } + + private void Act() + { + _shellStream.Write(_data, _offset, _count); + } + + [TestMethod] + public void NoDataShouldBeSentToServer() + { + _channelSessionMock.VerifyAll(); + } + + [TestMethod] + public void FlushShouldSendWrittenBytesToServer() + { + _ = _channelSessionMock.InSequence(_mockSequence) + .Setup(p => p.SendData( + It.Is(data => data.Take(_bufferData.Length).IsEqualTo(_bufferData)), + 0, + _bufferData.Length)); + + _shellStream.Flush(); + + _channelSessionMock.VerifyAll(); + } + } +}