Skip to content

Commit

Permalink
Add option for truncated stream detection (#75671)
Browse files Browse the repository at this point in the history
* Add option for truncated stream detection

fix #47563

* Use RemoteExecutor

move the test to concrete classes as abstracted classes are not supported by RemoteExecutor

* review feedback

* use same error text message

* cache appcontext getswitch

* fix failing test

* slice byte array for assertion

* renaming

* add missing RemoteExecutor.IsSupported

* fast check first
  • Loading branch information
mfkl authored Nov 2, 2022
1 parent d3a3eda commit f43d5e3
Show file tree
Hide file tree
Showing 12 changed files with 417 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.IO.Compression.Tests;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Sdk;

namespace System.IO.Compression
{
Expand Down Expand Up @@ -468,86 +465,6 @@ async Task<long> GetLengthAsync(CompressionLevel compressionLevel)
Assert.True(fastestLength >= optimalLength);
Assert.True(optimalLength >= smallestLength);
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/47563")]
[Theory]
[InlineData(TestScenario.ReadAsync)]
[InlineData(TestScenario.Read)]
[InlineData(TestScenario.Copy)]
[InlineData(TestScenario.CopyAsync)]
[InlineData(TestScenario.ReadByte)]
[InlineData(TestScenario.ReadByteAsync)]
public async Task StreamTruncation_IsDetected(TestScenario scenario)
{
var buffer = new byte[16];
byte[] source = Enumerable.Range(0, 64).Select(i => (byte)i).ToArray();
byte[] compressedData;
using (var compressed = new MemoryStream())
using (Stream compressor = CreateStream(compressed, CompressionMode.Compress))
{
foreach (byte b in source)
{
compressor.WriteByte(b);
}

compressor.Dispose();
compressedData = compressed.ToArray();
}

for (var i = 1; i <= compressedData.Length; i += 1)
{
bool expectException = i < compressedData.Length;
using (var compressedStream = new MemoryStream(compressedData.Take(i).ToArray()))
{
using (Stream decompressor = CreateStream(compressedStream, CompressionMode.Decompress))
{
var decompressedStream = new MemoryStream();

try
{
switch (scenario)
{
case TestScenario.Copy:
decompressor.CopyTo(decompressedStream);
break;

case TestScenario.CopyAsync:
await decompressor.CopyToAsync(decompressedStream);
break;

case TestScenario.Read:
while (ZipFileTestBase.ReadAllBytes(decompressor, buffer, 0, buffer.Length) != 0) { };
break;

case TestScenario.ReadAsync:
while (await ZipFileTestBase.ReadAllBytesAsync(decompressor, buffer, 0, buffer.Length) != 0) { };
break;

case TestScenario.ReadByte:
while (decompressor.ReadByte() != -1) { }
break;

case TestScenario.ReadByteAsync:
while (await decompressor.ReadByteAsync() != -1) { }
break;
}
}
catch (InvalidDataException e)
{
if (expectException)
continue;

throw new XunitException($"An unexpected error occurred while decompressing data:{e}");
}

if (expectException)
{
throw new XunitException($"Truncated stream was decompressed successfully but exception was expected: length={i}/{compressedData.Length}");
}
}
}
}
}
}

public enum TestScenario
Expand Down
41 changes: 41 additions & 0 deletions src/libraries/Common/tests/System/IO/Compression/ZipTestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public static void StreamsEqual(Stream ast, Stream bst)
StreamsEqual(ast, bst, -1);
}

public static async Task StreamsEqualAsync(Stream ast, Stream bst)
{
await StreamsEqualAsync(ast, bst, -1);
}

public static void StreamsEqual(Stream ast, Stream bst, int blocksToRead)
{
if (ast.CanSeek)
Expand Down Expand Up @@ -146,6 +151,42 @@ public static void StreamsEqual(Stream ast, Stream bst, int blocksToRead)
} while (ac == bufSize);
}

public static async Task StreamsEqualAsync(Stream ast, Stream bst, int blocksToRead)
{
if (ast.CanSeek)
ast.Seek(0, SeekOrigin.Begin);
if (bst.CanSeek)
bst.Seek(0, SeekOrigin.Begin);

const int bufSize = 4096;
byte[] ad = new byte[bufSize];
byte[] bd = new byte[bufSize];

int ac = 0;
int bc = 0;

int blocksRead = 0;

//assume read doesn't do weird things
do
{
if (blocksToRead != -1 && blocksRead >= blocksToRead)
break;

ac = await ast.ReadAtLeastAsync(ad, 4096, throwOnEndOfStream: false);
bc = await bst.ReadAtLeastAsync(bd, 4096, throwOnEndOfStream: false);

if (ac != bc)
{
bd = NormalizeLineEndings(bd);
}

AssertExtensions.SequenceEqual(ad.AsSpan(0, ac), bd.AsSpan(0, bc));

blocksRead++;
} while (ac == bufSize);
}

public static async Task IsZipSameAsDirAsync(string archiveFile, string directory, ZipArchiveMode mode)
{
await IsZipSameAsDirAsync(archiveFile, directory, mode, requireExplicit: false, checkTimes: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@
<data name="BrotliStream_Decompress_InvalidStream" xml:space="preserve">
<value>BrotliStream.BaseStream returned more bytes than requested in Read.</value>
</data>
<data name="BrotliStream_Decompress_TruncatedData" xml:space="preserve">
<value>Found truncated data while decoding.</value>
</data>
<data name="IOCompressionBrotli_PlatformNotSupported" xml:space="preserve">
<value>System.IO.Compression.Brotli is not supported on this platform.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public sealed partial class BrotliStream : Stream
private BrotliDecoder _decoder;
private int _bufferOffset;
private int _bufferCount;
private bool _nonEmptyInput;

/// <summary>Reads a number of decompressed bytes into the specified byte array.</summary>
/// <param name="buffer">The array used to store decompressed bytes.</param>
Expand Down Expand Up @@ -65,9 +66,12 @@ public override int Read(Span<byte> buffer)
int bytesRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount);
if (bytesRead <= 0)
{
if (s_useStrictValidation && _nonEmptyInput && !buffer.IsEmpty)
ThrowTruncatedInvalidData();
break;
}

_nonEmptyInput = true;
_bufferCount += bytesRead;

if (_bufferCount > _buffer.Length)
Expand Down Expand Up @@ -150,10 +154,13 @@ async ValueTask<int> Core(Memory<byte> buffer, CancellationToken cancellationTok
int bytesRead = await _stream.ReadAsync(_buffer.AsMemory(_bufferCount), cancellationToken).ConfigureAwait(false);
if (bytesRead <= 0)
{
if (s_useStrictValidation && _nonEmptyInput && !buffer.IsEmpty)
ThrowTruncatedInvalidData();
break;
}

_bufferCount += bytesRead;
_nonEmptyInput = true;

if (_bufferCount > _buffer.Length)
{
Expand Down Expand Up @@ -227,9 +234,15 @@ private bool TryDecompress(Span<byte> destination, out int bytesWritten)
return false;
}

private static readonly bool s_useStrictValidation =
AppContext.TryGetSwitch("System.IO.Compression.UseStrictValidation", out bool strictValidation) ? strictValidation : false;

private static void ThrowInvalidStream() =>
// The stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);

private static void ThrowTruncatedInvalidData() =>
throw new InvalidDataException(SR.BrotliStream_Decompress_TruncatedData);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ public void CreateFromDirectory_IncludeBaseDirectory()
}
}

[Fact]
public async Task CreateFromDirectory_IncludeBaseDirectoryAsync()
{
string folderName = zfolder("normal");
string withBaseDir = GetTestFilePath();
ZipFile.CreateFromDirectory(folderName, withBaseDir, CompressionLevel.Optimal, true);

IEnumerable<string> expected = Directory.EnumerateFiles(zfolder("normal"), "*", SearchOption.AllDirectories);
using (ZipArchive actual_withbasedir = ZipFile.Open(withBaseDir, ZipArchiveMode.Read))
{
foreach (ZipArchiveEntry actualEntry in actual_withbasedir.Entries)
{
string expectedFile = expected.Single(i => Path.GetFileName(i).Equals(actualEntry.Name));
Assert.StartsWith("normal", actualEntry.FullName);
Assert.Equal(new FileInfo(expectedFile).Length, actualEntry.Length);
using (Stream expectedStream = File.OpenRead(expectedFile))
using (Stream actualStream = actualEntry.Open())
{
await StreamsEqualAsync(expectedStream, actualStream);
}
}
}
}

[Fact]
public void CreateFromDirectoryUnicode()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@
<data name="SplitSpanned" xml:space="preserve">
<value>Split or spanned archives are not supported.</value>
</data>
<data name="TruncatedData" xml:space="preserve">
<value>Found truncated data while decoding.</value>
</data>
<data name="UnexpectedEndOfStream" xml:space="preserve">
<value>Zip file corrupt: unexpected end of stream reached.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ internal int ReadCore(Span<byte> buffer)
int n = _stream.Read(_buffer, 0, _buffer.Length);
if (n <= 0)
{
// - Inflater didn't return any data although a non-empty output buffer was passed by the caller.
// - More input is needed but there is no more input available.
// - Inflation is not finished yet.
// - Provided input wasn't completely empty
// In such case, we are dealing with a truncated input stream.
if (s_useStrictValidation && !buffer.IsEmpty && !_inflater.Finished() && _inflater.NonEmptyInput())
{
ThrowTruncatedInvalidData();
}
break;
}
else if (n > _buffer.Length)
Expand Down Expand Up @@ -347,6 +356,9 @@ private static void ThrowGenericInvalidData() =>
// bytes < 0 || > than the buffer supplied to it.
throw new InvalidDataException(SR.GenericInvalidData);

private static void ThrowTruncatedInvalidData() =>
throw new InvalidDataException(SR.TruncatedData);

public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) =>
TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState);

Expand Down Expand Up @@ -416,6 +428,15 @@ async ValueTask<int> Core(Memory<byte> buffer, CancellationToken cancellationTok
int n = await _stream.ReadAsync(new Memory<byte>(_buffer, 0, _buffer.Length), cancellationToken).ConfigureAwait(false);
if (n <= 0)
{
// - Inflater didn't return any data although a non-empty output buffer was passed by the caller.
// - More input is needed but there is no more input available.
// - Inflation is not finished yet.
// - Provided input wasn't completely empty
// In such case, we are dealing with a truncated input stream.
if (s_useStrictValidation && !_inflater.Finished() && _inflater.NonEmptyInput() && !buffer.IsEmpty)
{
ThrowTruncatedInvalidData();
}
break;
}
else if (n > _buffer.Length)
Expand Down Expand Up @@ -893,6 +914,10 @@ public async Task CopyFromSourceToDestinationAsync()

// Now, use the source stream's CopyToAsync to push directly to our inflater via this helper stream
await _deflateStream._stream.CopyToAsync(this, _arrayPoolBuffer.Length, _cancellationToken).ConfigureAwait(false);
if (s_useStrictValidation && !_deflateStream._inflater.Finished())
{
ThrowTruncatedInvalidData();
}
}
finally
{
Expand Down Expand Up @@ -925,6 +950,10 @@ public void CopyFromSourceToDestination()

// Now, use the source stream's CopyToAsync to push directly to our inflater via this helper stream
_deflateStream._stream.CopyTo(this, _arrayPoolBuffer.Length);
if (s_useStrictValidation && !_deflateStream._inflater.Finished())
{
ThrowTruncatedInvalidData();
}
}
finally
{
Expand Down Expand Up @@ -1049,5 +1078,8 @@ private void AsyncOperationCompleting() =>

private static void ThrowInvalidBeginCall() =>
throw new InvalidOperationException(SR.InvalidBeginCall);

private static readonly bool s_useStrictValidation =
AppContext.TryGetSwitch("System.IO.Compression.UseStrictValidation", out bool strictValidation) ? strictValidation : false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ internal sealed class Inflater : IDisposable
private const int MinWindowBits = -15; // WindowBits must be between -8..-15 to ignore the header, 8..15 for
private const int MaxWindowBits = 47; // zlib headers, 24..31 for GZip headers, or 40..47 for either Zlib or GZip

private bool _nonEmptyInput; // Whether there is any non empty input
private bool _finished; // Whether the end of the stream has been reached
private bool _isDisposed; // Prevents multiple disposals
private readonly int _windowBits; // The WindowBits parameter passed to Inflater construction
Expand All @@ -34,6 +35,7 @@ internal Inflater(int windowBits, long uncompressedSize = -1)
{
Debug.Assert(windowBits >= MinWindowBits && windowBits <= MaxWindowBits);
_finished = false;
_nonEmptyInput = false;
_isDisposed = false;
_windowBits = windowBits;
InflateInit(windowBits);
Expand Down Expand Up @@ -176,6 +178,8 @@ private unsafe bool ResetStreamForLeftoverInput()

public bool NeedsInput() => _zlibStream.AvailIn == 0;

public bool NonEmptyInput() => _nonEmptyInput;

public void SetInput(byte[] inputBuffer, int startIndex, int count)
{
Debug.Assert(NeedsInput(), "We have something left in previous input!");
Expand All @@ -200,6 +204,7 @@ public unsafe void SetInput(ReadOnlyMemory<byte> inputBuffer)
_zlibStream.NextIn = (IntPtr)_inputBufferHandle.Pointer;
_zlibStream.AvailIn = (uint)inputBuffer.Length;
_finished = false;
_nonEmptyInput = true;
}
}

Expand Down
Loading

0 comments on commit f43d5e3

Please sign in to comment.