diff --git a/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj b/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj index 2b77907e9aaac..286ab317256bf 100644 --- a/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs.Batch/tests/Azure.Storage.Blobs.Batch.Tests.csproj @@ -23,6 +23,7 @@ + PreserveNewest @@ -42,4 +43,4 @@ - \ No newline at end of file + diff --git a/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj b/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj index 9682ab15ecd60..8cf13cd60744f 100644 --- a/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj +++ b/sdk/storage/Azure.Storage.Blobs.ChangeFeed/tests/Azure.Storage.Blobs.ChangeFeed.Tests.csproj @@ -17,6 +17,7 @@ + @@ -28,4 +29,4 @@ PreserveNewest - \ No newline at end of file + diff --git a/sdk/storage/Azure.Storage.Blobs/assets.json b/sdk/storage/Azure.Storage.Blobs/assets.json index f659bfa082944..6a8a60c8c101a 100644 --- a/sdk/storage/Azure.Storage.Blobs/assets.json +++ b/sdk/storage/Azure.Storage.Blobs/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/storage/Azure.Storage.Blobs", - "Tag": "net/storage/Azure.Storage.Blobs_dcc7be748a" + "Tag": "net/storage/Azure.Storage.Blobs_8b3f7ac2a4" } diff --git a/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs b/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs index 5a396e60a0598..9a110cf8eb13a 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/AppendBlobClient.cs @@ -1248,7 +1248,6 @@ internal async Task> AppendBlockInternal( string structuredBodyType = null; if (validationOptions != null && validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && - validationOptions.PrecalculatedChecksum.IsEmpty && ClientSideEncryption == null) // don't allow feature combination { // report progress in terms of caller bytes, not encoded bytes @@ -1256,10 +1255,14 @@ internal async Task> AppendBlockInternal( contentLength = (content?.Length - content?.Position) ?? 0; structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; content = content.WithNoDispose().WithProgress(progressHandler); - content = new StructuredMessageEncodingStream( - content, - Constants.StructuredMessage.DefaultSegmentContentLength, - StructuredMessage.Flags.StorageCrc64); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); contentLength = (content?.Length - content?.Position) ?? 0; } else diff --git a/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj b/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj index 527ebfabde810..851474c2d0dab 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj +++ b/sdk/storage/Azure.Storage.Blobs/src/Azure.Storage.Blobs.csproj @@ -95,6 +95,7 @@ + diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs index bb591b8d3c0e9..d271615f6f4b7 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs @@ -1574,6 +1574,7 @@ ValueTask> Factory(long offset, bool force .EnsureCompleted(), async startOffset => await StructuredMessageFactory(startOffset, async: true, cancellationToken) .ConfigureAwait(false), + default, //decodedData => response.Value.Details.ContentCrc = decodedData.TotalCrc.ToArray(), ClientConfiguration.Pipeline.ResponseClassifier, Constants.MaxReliabilityRetries); } @@ -1723,14 +1724,7 @@ private async ValueTask> StartDownloadAsyn rangeGetContentMD5 = true; break; case StorageChecksumAlgorithm.StorageCrc64: - if (!forceStructuredMessage && pageRange?.Length <= Constants.StructuredMessage.MaxDownloadCrcWithHeader) - { - rangeGetContentCRC64 = true; - } - else - { - structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; - } + structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; break; default: break; diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs index 0a90c66c851b8..fe9e6af997b6c 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlockBlobClient.cs @@ -1337,7 +1337,6 @@ internal virtual async Task> StageBlockInternal( string structuredBodyType = null; if (validationOptions != null && validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && - validationOptions.PrecalculatedChecksum.IsEmpty && ClientSideEncryption == null) // don't allow feature combination { // report progress in terms of caller bytes, not encoded bytes @@ -1345,10 +1344,14 @@ internal virtual async Task> StageBlockInternal( contentLength = (content?.Length - content?.Position) ?? 0; structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; content = content.WithNoDispose().WithProgress(progressHandler); - content = new StructuredMessageEncodingStream( - content, - Constants.StructuredMessage.DefaultSegmentContentLength, - StructuredMessage.Flags.StorageCrc64); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); contentLength = (content?.Length - content?.Position) ?? 0; } else diff --git a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs index bc119822cdc12..6104abfd9ac5f 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadDetails.cs @@ -34,6 +34,15 @@ public class BlobDownloadDetails public byte[] ContentHash { get; internal set; } #pragma warning restore CA1819 // Properties should not return arrays + // TODO enable in following PR + ///// + ///// When requested using , this value contains the CRC for the download blob range. + ///// This value may only become populated once the network stream is fully consumed. If this instance is accessed through + ///// , the network stream has already been consumed. Otherwise, consume the content stream before + ///// checking this value. + ///// + //public byte[] ContentCrc { get; internal set; } + /// /// Returns the date and time the container was last modified. Any operation that modifies the blob, including an update of the blob's metadata or properties, changes the last-modified time of the blob. /// diff --git a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs index e034573b54b3a..1a525c718d1b4 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/Models/BlobDownloadInfo.cs @@ -4,6 +4,8 @@ using System; using System.ComponentModel; using System.IO; +using System.Threading.Tasks; +using Azure.Core; using Azure.Storage.Shared; namespace Azure.Storage.Blobs.Models @@ -49,6 +51,15 @@ public class BlobDownloadInfo : IDisposable, IDownloadedContent /// public BlobDownloadDetails Details { get; internal set; } + // TODO enable in following PR + ///// + ///// Indicates some contents of are mixed into the response stream. + ///// They will not be set until has been fully consumed. These details + ///// will be extracted from the content stream by the library before the calling code can + ///// encounter them. + ///// + //public bool ExpectTrailingDetails { get; internal set; } + /// /// Constructor. /// diff --git a/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs b/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs index 34befc6d4efe7..7038897531fbb 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/PageBlobClient.cs @@ -1370,7 +1370,6 @@ internal async Task> UploadPagesInternal( HttpRange range; if (validationOptions != null && validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && - validationOptions.PrecalculatedChecksum.IsEmpty && ClientSideEncryption == null) // don't allow feature combination { // report progress in terms of caller bytes, not encoded bytes @@ -1379,10 +1378,14 @@ internal async Task> UploadPagesInternal( range = new HttpRange(offset, (content?.Length - content?.Position) ?? null); structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; content = content?.WithNoDispose().WithProgress(progressHandler); - content = new StructuredMessageEncodingStream( - content, - Constants.StructuredMessage.DefaultSegmentContentLength, - StructuredMessage.Flags.StorageCrc64); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); contentLength = (content?.Length - content?.Position) ?? 0; } else diff --git a/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs b/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs index 2c52d0c256e34..276bdadb673fa 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs @@ -48,7 +48,8 @@ internal class PartitionedDownloader /// private readonly StorageChecksumAlgorithm _validationAlgorithm; private readonly int _checksumSize; - private bool UseMasterCrc => _validationAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64; + // TODO disabling master crc temporarily. segment CRCs still handled. + private bool UseMasterCrc => false; // _validationAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64; private StorageCrc64HashAlgorithm _masterCrcCalculator = null; /// @@ -212,8 +213,20 @@ public async Task DownloadToInternal( // If the first segment was the entire blob, we'll copy that to // the output stream and finish now - long initialLength = initialResponse.Value.Details.ContentLength; - long totalLength = ParseRangeTotalLength(initialResponse.Value.Details.ContentRange); + long initialLength; + long totalLength; + // Get blob content length downloaded from content range when available to handle transit encoding + if (string.IsNullOrWhiteSpace(initialResponse.Value.Details.ContentRange)) + { + initialLength = initialResponse.Value.Details.ContentLength; + totalLength = 0; + } + else + { + ContentRange recievedRange = ContentRange.Parse(initialResponse.Value.Details.ContentRange); + initialLength = recievedRange.End.Value - recievedRange.Start.Value + 1; + totalLength = recievedRange.Size.Value; + } if (initialLength == totalLength) { await HandleOneShotDownload(initialResponse, destination, async, cancellationToken) @@ -395,20 +408,6 @@ private async Task FinalizeDownloadInternal( } } - private static long ParseRangeTotalLength(string range) - { - if (range == null) - { - return 0; - } - int lengthSeparator = range.IndexOf("/", StringComparison.InvariantCultureIgnoreCase); - if (lengthSeparator == -1) - { - throw BlobErrors.ParsingFullHttpRangeFailed(range); - } - return long.Parse(range.Substring(lengthSeparator + 1), CultureInfo.InvariantCulture); - } - private async Task CopyToInternal( Response response, Stream destination, @@ -417,7 +416,10 @@ private async Task CopyToInternal( CancellationToken cancellationToken) { CancellationHelper.ThrowIfCancellationRequested(cancellationToken); - using IHasher hasher = ContentHasher.GetHasherFromAlgorithmId(_validationAlgorithm); + // if structured message, this crc is validated in the decoding process. don't decode it here. + using IHasher hasher = response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader) + ? null + : ContentHasher.GetHasherFromAlgorithmId(_validationAlgorithm); using Stream rawSource = response.Value.Content; using Stream source = hasher != null ? ChecksumCalculatingStream.GetReadStream(rawSource, hasher.AppendHash) @@ -432,13 +434,13 @@ await source.CopyToInternal( if (hasher != null) { hasher.GetFinalHash(checksumBuffer.Span); - (ReadOnlyMemory checksum, StorageChecksumAlgorithm _) - = ContentHasher.GetResponseChecksumOrDefault(response.GetRawResponse()); - if (!checksumBuffer.Span.SequenceEqual(checksum.Span)) - { - throw Errors.HashMismatchOnStreamedDownload(response.Value.Details.ContentRange); + (ReadOnlyMemory checksum, StorageChecksumAlgorithm _) + = ContentHasher.GetResponseChecksumOrDefault(response.GetRawResponse()); + if (!checksumBuffer.Span.SequenceEqual(checksum.Span)) + { + throw Errors.HashMismatchOnStreamedDownload(response.Value.Details.ContentRange); + } } - } } private IEnumerable GetRanges(long initialLength, long totalLength) diff --git a/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs b/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs index 5653e40af16e3..76d807835873c 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs +++ b/sdk/storage/Azure.Storage.Blobs/tests/BlobBaseClientTransferValidationTests.cs @@ -95,59 +95,6 @@ public override void TestAutoResolve() } #region Added Tests - [TestCaseSource("GetValidationAlgorithms")] - public async Task ExpectedDownloadStreamingStreamTypeReturned(StorageChecksumAlgorithm algorithm) - { - await using var test = await GetDisposingContainerAsync(); - - // Arrange - var data = GetRandomBuffer(Constants.KB); - BlobClient blob = InstrumentClient(test.Container.GetBlobClient(GetNewResourceName())); - using (var stream = new MemoryStream(data)) - { - await blob.UploadAsync(stream); - } - // don't make options instance at all for no hash request - DownloadTransferValidationOptions transferValidation = algorithm == StorageChecksumAlgorithm.None - ? default - : new DownloadTransferValidationOptions { ChecksumAlgorithm = algorithm }; - - // Act - Response response = await blob.DownloadStreamingAsync(new BlobDownloadOptions - { - TransferValidation = transferValidation, - Range = new HttpRange(length: data.Length) - }); - - // Assert - // validated stream is buffered - Assert.AreEqual(typeof(MemoryStream), response.Value.Content.GetType()); - } - - [Test] - public async Task ExpectedDownloadStreamingStreamTypeReturned_None() - { - await using var test = await GetDisposingContainerAsync(); - - // Arrange - var data = GetRandomBuffer(Constants.KB); - BlobClient blob = InstrumentClient(test.Container.GetBlobClient(GetNewResourceName())); - using (var stream = new MemoryStream(data)) - { - await blob.UploadAsync(stream); - } - - // Act - Response response = await blob.DownloadStreamingAsync(new BlobDownloadOptions - { - Range = new HttpRange(length: data.Length) - }); - - // Assert - // unvalidated stream type is private; just check we didn't get back a buffered stream - Assert.AreNotEqual(typeof(MemoryStream), response.Value.Content.GetType()); - } - [Test] public virtual async Task OlderServiceVersionThrowsOnStructuredMessage() { diff --git a/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs b/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs index d8d4756a510c1..af408264c5bfa 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs +++ b/sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs @@ -305,7 +305,7 @@ public Response GetStream(HttpRange range, BlobRequ ContentHash = new byte[] { 1, 2, 3 }, LastModified = DateTimeOffset.Now, Metadata = new Dictionary() { { "meta", "data" } }, - ContentRange = $"bytes {range.Offset}-{range.Offset + contentLength}/{_length}", + ContentRange = $"bytes {range.Offset}-{Math.Max(1, range.Offset + contentLength - 1)}/{_length}", ETag = s_etag, ContentEncoding = "test", CacheControl = "test", diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs index c3e9c641c3fea..fe2db427bef02 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/LazyLoadingReadOnlyStream.cs @@ -249,41 +249,9 @@ private async Task DownloadInternal(bool async, CancellationToken cancellat response = await _downloadInternalFunc(range, _validationOptions, async, cancellationToken).ConfigureAwait(false); using Stream networkStream = response.Value.Content; - - // The number of bytes we just downloaded. - long downloadSize = GetResponseRange(response.GetRawResponse()).Length.Value; - - // The number of bytes we copied in the last loop. - int copiedBytes; - - // Bytes we have copied so far. - int totalCopiedBytes = 0; - - // Bytes remaining to copy. It is save to truncate the long because we asked for a max of int _buffer size bytes. - int remainingBytes = (int)downloadSize; - - do - { - if (async) - { - copiedBytes = await networkStream.ReadAsync( - buffer: _buffer, - offset: totalCopiedBytes, - count: remainingBytes, - cancellationToken: cancellationToken).ConfigureAwait(false); - } - else - { - copiedBytes = networkStream.Read( - buffer: _buffer, - offset: totalCopiedBytes, - count: remainingBytes); - } - - totalCopiedBytes += copiedBytes; - remainingBytes -= copiedBytes; - } - while (copiedBytes != 0); + // use stream copy to ensure consumption of any trailing metadata (e.g. structured message) + // allow buffer limits to catch the error of data size mismatch + int totalCopiedBytes = (int) await networkStream.CopyToInternal(new MemoryStream(_buffer), async, cancellationToken).ConfigureAwait((false)); _bufferPosition = 0; _bufferLength = totalCopiedBytes; @@ -291,7 +259,7 @@ private async Task DownloadInternal(bool async, CancellationToken cancellat // if we deferred transactional hash validation on download, validate now // currently we always defer but that may change - if (_validationOptions != default && _validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && !_validationOptions.AutoValidateChecksum) + if (_validationOptions != default && _validationOptions.ChecksumAlgorithm == StorageChecksumAlgorithm.MD5 && !_validationOptions.AutoValidateChecksum) // TODO better condition { ContentHasher.AssertResponseHashMatch(_buffer, _bufferPosition, _bufferLength, _validationOptions.ChecksumAlgorithm, response.GetRawResponse()); } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs index 3e218d18a90af..6070329d10d3d 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/PooledMemoryStream.cs @@ -251,7 +251,7 @@ public override int Read(byte[] buffer, int offset, int count) Length - Position, bufferCount - (Position - offsetOfBuffer), count - read); - Array.Copy(currentBuffer, Position - offsetOfBuffer, buffer, read, toCopy); + Array.Copy(currentBuffer, Position - offsetOfBuffer, buffer, offset + read, toCopy); read += toCopy; Position += toCopy; } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs index 31f121d414ea4..c8803ecf421e7 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StreamExtensions.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; +using System.Buffers; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -48,7 +50,7 @@ public static async Task WriteInternal( } } - public static Task CopyToInternal( + public static Task CopyToInternal( this Stream src, Stream dest, bool async, @@ -79,21 +81,33 @@ public static Task CopyToInternal( /// Cancellation token for the operation. /// /// - public static async Task CopyToInternal( + public static async Task CopyToInternal( this Stream src, Stream dest, int bufferSize, bool async, CancellationToken cancellationToken) { + using IDisposable _ = ArrayPool.Shared.RentDisposable(bufferSize, out byte[] buffer); + long totalRead = 0; + int read; if (async) { - await src.CopyToAsync(dest, bufferSize, cancellationToken).ConfigureAwait(false); + while (0 < (read = await src.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false))) + { + totalRead += read; + await dest.WriteAsync(buffer, 0, read, cancellationToken).ConfigureAwait(false); + } } else { - src.CopyTo(dest, bufferSize); + while (0 < (read = src.Read(buffer, 0, buffer.Length))) + { + totalRead += read; + dest.Write(buffer, 0, read); + } } + return totalRead; } } } diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs index 444fe3eb2e0a9..fe2d6697a4621 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -19,6 +19,7 @@ internal class StructuredMessageDecodingRetriableStream : Stream private long _decodedBytesRead; private readonly List _decodedDatas; + private readonly Action _onComplete; private readonly Func _decodingStreamFactory; private readonly Func> _decodingAsyncStreamFactory; @@ -28,6 +29,7 @@ public StructuredMessageDecodingRetriableStream( StructuredMessageDecodingStream.DecodedData initialDecodedData, Func decodingStreamFactory, Func> decodingAsyncStreamFactory, + Action onComplete, ResponseClassifier responseClassifier, int maxRetries) { @@ -35,6 +37,7 @@ public StructuredMessageDecodingRetriableStream( _decodingAsyncStreamFactory = decodingAsyncStreamFactory; _innerRetriable = RetriableStream.Create(initialDecodingStream, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries); _decodedDatas = new() { initialDecodedData }; + _onComplete = onComplete; } private Stream StreamFactory(long _) @@ -86,11 +89,22 @@ protected override void Dispose(bool disposing) _innerRetriable.Dispose(); } + private void OnCompleted() + { + StructuredMessageDecodingStream.DecodedData final = new(); + // TODO + _onComplete?.Invoke(final); + } + #region Read public override int Read(byte[] buffer, int offset, int count) { int read = _innerRetriable.Read(buffer, offset, count); _decodedBytesRead += read; + if (read == 0) + { + OnCompleted(); + } return read; } @@ -98,6 +112,10 @@ public override async Task ReadAsync(byte[] buffer, int offset, int count, { int read = await _innerRetriable.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); _decodedBytesRead += read; + if (read == 0) + { + OnCompleted(); + } return read; } @@ -106,6 +124,10 @@ public override int Read(Span buffer) { int read = _innerRetriable.Read(buffer); _decodedBytesRead += read; + if (read == 0) + { + OnCompleted(); + } return read; } @@ -113,6 +135,10 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { int read = await _innerRetriable.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); _decodedBytesRead += read; + if (read == 0) + { + OnCompleted(); + } return read; } #endif @@ -121,6 +147,10 @@ public override int ReadByte() { int val = _innerRetriable.ReadByte(); _decodedBytesRead += 1; + if (val == -1) + { + OnCompleted(); + } return val; } @@ -128,6 +158,10 @@ public override int EndRead(IAsyncResult asyncResult) { int read = _innerRetriable.EndRead(asyncResult); _decodedBytesRead += read; + if (read == 0) + { + OnCompleted(); + } return read; } #endregion diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs new file mode 100644 index 0000000000000..3569ef4339735 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessagePrecalculatedCrcWrapperStream.cs @@ -0,0 +1,451 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core.Pipeline; +using Azure.Storage.Common; + +namespace Azure.Storage.Shared; + +internal class StructuredMessagePrecalculatedCrcWrapperStream : Stream +{ + private readonly Stream _innerStream; + + private readonly int _streamHeaderLength; + private readonly int _streamFooterLength; + private readonly int _segmentHeaderLength; + private readonly int _segmentFooterLength; + + private bool _disposed; + + private readonly byte[] _crc; + + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override int ReadTimeout => _innerStream.ReadTimeout; + + public override int WriteTimeout => _innerStream.WriteTimeout; + + public override long Length => + _streamHeaderLength + _streamFooterLength + + _segmentHeaderLength + _segmentFooterLength + + _innerStream.Length; + + #region Position + private enum SMRegion + { + StreamHeader, + StreamFooter, + SegmentHeader, + SegmentFooter, + SegmentContent, + } + + private SMRegion _currentRegion = SMRegion.StreamHeader; + private int _currentRegionPosition = 0; + + private long _maxSeekPosition = 0; + + public override long Position + { + get + { + return _currentRegion switch + { + SMRegion.StreamHeader => _currentRegionPosition, + SMRegion.SegmentHeader => _innerStream.Position + + _streamHeaderLength + + _currentRegionPosition, + SMRegion.SegmentContent => _streamHeaderLength + + _segmentHeaderLength + + _innerStream.Position, + SMRegion.SegmentFooter => _streamHeaderLength + + _segmentHeaderLength + + _innerStream.Length + + _currentRegionPosition, + SMRegion.StreamFooter => _streamHeaderLength + + _segmentHeaderLength + + _innerStream.Length + + _segmentFooterLength + + _currentRegionPosition, + _ => throw new InvalidDataException($"{nameof(StructuredMessageEncodingStream)} invalid state."), + }; + } + set + { + Argument.AssertInRange(value, 0, _maxSeekPosition, nameof(value)); + if (value < _streamHeaderLength) + { + _currentRegion = SMRegion.StreamHeader; + _currentRegionPosition = (int)value; + _innerStream.Position = 0; + return; + } + if (value < _streamHeaderLength + _segmentHeaderLength) + { + _currentRegion = SMRegion.SegmentHeader; + _currentRegionPosition = (int)(value - _streamHeaderLength); + _innerStream.Position = 0; + return; + } + if (value < _streamHeaderLength + _segmentHeaderLength + _innerStream.Length) + { + _currentRegion = SMRegion.SegmentContent; + _currentRegionPosition = (int)(value - _streamHeaderLength - _segmentHeaderLength); + _innerStream.Position = value - _streamHeaderLength - _segmentHeaderLength; + return; + } + if (value < _streamHeaderLength + _segmentHeaderLength + _innerStream.Length + _segmentFooterLength) + { + _currentRegion = SMRegion.SegmentFooter; + _currentRegionPosition = (int)(value - _streamHeaderLength - _segmentHeaderLength - _innerStream.Length); + _innerStream.Position = _innerStream.Length; + return; + } + + _currentRegion = SMRegion.StreamFooter; + _currentRegionPosition = (int)(value - _streamHeaderLength - _segmentHeaderLength - _innerStream.Length - _segmentFooterLength); + _innerStream.Position = _innerStream.Length; + } + } + #endregion + + public StructuredMessagePrecalculatedCrcWrapperStream( + Stream innerStream, + ReadOnlySpan precalculatedCrc) + { + Argument.AssertNotNull(innerStream, nameof(innerStream)); + if (innerStream.GetLengthOrDefault() == default) + { + throw new ArgumentException("Stream must have known length.", nameof(innerStream)); + } + if (innerStream.Position != 0) + { + throw new ArgumentException("Stream must be at starting position.", nameof(innerStream)); + } + + _streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength; + _streamFooterLength = StructuredMessage.Crc64Length; + _segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength; + _segmentFooterLength = StructuredMessage.Crc64Length; + + _crc = ArrayPool.Shared.Rent(StructuredMessage.Crc64Length); + precalculatedCrc.CopyTo(_crc); + + _innerStream = innerStream; + } + + #region Write + public override void Flush() => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + #endregion + + #region Read + public override int Read(byte[] buffer, int offset, int count) + => ReadInternal(buffer, offset, count, async: false, cancellationToken: default).EnsureCompleted(); + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => await ReadInternal(buffer, offset, count, async: true, cancellationToken).ConfigureAwait(false); + + private async ValueTask ReadInternal(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < count && Position < Length) + { + int subreadOffset = offset + totalRead; + int subreadCount = count - totalRead; + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(new Span(buffer, subreadOffset, subreadCount)); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. + if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += await ReadFromInnerStreamInternal( + buffer, subreadOffset, subreadCount, async, cancellationToken).ConfigureAwait(false); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + public override int Read(Span buffer) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < buffer.Length && Position < Length) + { + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(buffer.Slice(totalRead)); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(buffer.Slice(totalRead)); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. + if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += ReadFromInnerStream(buffer.Slice(totalRead)); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + int totalRead = 0; + bool readInner = false; + while (totalRead < buffer.Length && Position < Length) + { + switch (_currentRegion) + { + case SMRegion.StreamHeader: + totalRead += ReadFromStreamHeader(buffer.Slice(totalRead).Span); + break; + case SMRegion.StreamFooter: + totalRead += ReadFromStreamFooter(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentHeader: + totalRead += ReadFromSegmentHeader(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentFooter: + totalRead += ReadFromSegmentFooter(buffer.Slice(totalRead).Span); + break; + case SMRegion.SegmentContent: + // don't double read from stream. Allow caller to multi-read when desired. + if (readInner) + { + UpdateLatestPosition(); + return totalRead; + } + totalRead += await ReadFromInnerStreamAsync(buffer.Slice(totalRead), cancellationToken).ConfigureAwait(false); + readInner = true; + break; + default: + break; + } + } + UpdateLatestPosition(); + return totalRead; + } +#endif + + #region Read Headers/Footers + private int ReadFromStreamHeader(Span buffer) + { + int read = Math.Min(buffer.Length, _streamHeaderLength - _currentRegionPosition); + using IDisposable _ = StructuredMessage.V1_0.GetStreamHeaderBytes( + ArrayPool.Shared, + out Memory headerBytes, + Length, + StructuredMessage.Flags.StorageCrc64, + totalSegments: 1); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _streamHeaderLength) + { + _currentRegion = SMRegion.SegmentHeader; + _currentRegionPosition = 0; + } + + return read; + } + + private int ReadFromStreamFooter(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition); + if (read <= 0) + { + return 0; + } + + using IDisposable _ = StructuredMessage.V1_0.GetStreamFooterBytes( + ArrayPool.Shared, + out Memory footerBytes, + new ReadOnlySpan(_crc, 0, StructuredMessage.Crc64Length)); + footerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + return read; + } + + private int ReadFromSegmentHeader(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentHeaderLength - _currentRegionPosition); + using IDisposable _ = StructuredMessage.V1_0.GetSegmentHeaderBytes( + ArrayPool.Shared, + out Memory headerBytes, + segmentNum: 1, + _innerStream.Length); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _segmentHeaderLength) + { + _currentRegion = SMRegion.SegmentContent; + _currentRegionPosition = 0; + } + + return read; + } + + private int ReadFromSegmentFooter(Span buffer) + { + int read = Math.Min(buffer.Length, _segmentFooterLength - _currentRegionPosition); + if (read < 0) + { + return 0; + } + + using IDisposable _ = StructuredMessage.V1_0.GetSegmentFooterBytes( + ArrayPool.Shared, + out Memory headerBytes, + new ReadOnlySpan(_crc, 0, StructuredMessage.Crc64Length)); + headerBytes.Slice(_currentRegionPosition, read).Span.CopyTo(buffer); + _currentRegionPosition += read; + + if (_currentRegionPosition == _segmentFooterLength) + { + _currentRegion = _innerStream.Position == _innerStream.Length + ? SMRegion.StreamFooter : SMRegion.SegmentHeader; + _currentRegionPosition = 0; + } + + return read; + } + #endregion + + #region ReadUnderlyingStream + private void CleanupContentSegment() + { + if (_innerStream.Position >= _innerStream.Length) + { + _currentRegion = SMRegion.SegmentFooter; + _currentRegionPosition = 0; + } + } + + private async ValueTask ReadFromInnerStreamInternal( + byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken) + { + int read = async + ? await _innerStream.ReadAsync(buffer, offset, count).ConfigureAwait(false) + : _innerStream.Read(buffer, offset, count); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + private int ReadFromInnerStream(Span buffer) + { + int read = _innerStream.Read(buffer); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } + + private async ValueTask ReadFromInnerStreamAsync(Memory buffer, CancellationToken cancellationToken) + { + int read = await _innerStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _currentRegionPosition += read; + CleanupContentSegment(); + return read; + } +#endif + #endregion + + // don't allow stream to seek too far forward. track how far the stream has been naturally read. + private void UpdateLatestPosition() + { + if (_maxSeekPosition < Position) + { + _maxSeekPosition = Position; + } + } + #endregion + + public override long Seek(long offset, SeekOrigin origin) + { + switch (origin) + { + case SeekOrigin.Begin: + Position = offset; + break; + case SeekOrigin.Current: + Position += offset; + break; + case SeekOrigin.End: + Position = Length + offset; + break; + } + return Position; + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (_disposed) + { + return; + } + + if (disposing) + { + ArrayPool.Shared.Return(_crc); + _innerStream.Dispose(); + _disposed = true; + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj index c4f9b715a692a..6a0dc0506be51 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj +++ b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj @@ -18,6 +18,7 @@ + @@ -52,6 +53,7 @@ + diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs new file mode 100644 index 0000000000000..1414f4ec80076 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/ObserveStructuredMessagePolicy.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using Azure.Core; +using Azure.Core.Pipeline; +using Azure.Storage.Shared; + +namespace Azure.Storage.Test.Shared +{ + internal class ObserveStructuredMessagePolicy : HttpPipelineSynchronousPolicy + { + private readonly HashSet _requestScopes = new(); + + private readonly HashSet _responseScopes = new(); + + public ObserveStructuredMessagePolicy() + { + } + + public override void OnSendingRequest(HttpMessage message) + { + if (_requestScopes.Count > 0) + { + byte[] encodedContent; + byte[] underlyingContent; + StructuredMessageDecodingStream.DecodedData decodedData; + using (MemoryStream ms = new()) + { + message.Request.Content.WriteTo(ms, default); + encodedContent = ms.ToArray(); + using (MemoryStream ms2 = new()) + { + (Stream s, decodedData) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedContent)); + s.CopyTo(ms2); + underlyingContent = ms2.ToArray(); + } + } + } + } + + public override void OnReceivedResponse(HttpMessage message) + { + } + + public IDisposable CheckRequestScope() => CheckMessageScope.CheckRequestScope(this); + + public IDisposable CheckResponseScope() => CheckMessageScope.CheckResponseScope(this); + + private class CheckMessageScope : IDisposable + { + private bool _isRequestScope; + private ObserveStructuredMessagePolicy _policy; + + public static CheckMessageScope CheckRequestScope(ObserveStructuredMessagePolicy policy) + { + CheckMessageScope result = new() + { + _isRequestScope = true, + _policy = policy + }; + result._policy._requestScopes.Add(result); + return result; + } + + public static CheckMessageScope CheckResponseScope(ObserveStructuredMessagePolicy policy) + { + CheckMessageScope result = new() + { + _isRequestScope = false, + _policy = policy + }; + result._policy._responseScopes.Add(result); + return result; + } + + public void Dispose() + { + (_isRequestScope ? _policy._requestScopes : _policy._responseScopes).Remove(this); + } + } + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs new file mode 100644 index 0000000000000..0f9cdb07773f7 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Linq; +using System.Text; +using Azure.Core; +using NUnit.Framework; + +namespace Azure.Storage; + +public static partial class RequestExtensions +{ + public static string AssertHeaderPresent(this Request request, string headerName) + { + if (request.Headers.TryGetValue(headerName, out string value)) + { + return headerName == Constants.StructuredMessage.CrcStructuredMessageHeader ? null : value; + } + StringBuilder sb = new StringBuilder() + .AppendLine($"`{headerName}` expected on request but was not found.") + .AppendLine($"{request.Method} {request.Uri}") + .AppendLine(string.Join("\n", request.Headers.Select(h => $"{h.Name}: {h.Value}s"))) + ; + Assert.Fail(sb.ToString()); + return null; + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs index f4198e9dfd532..7e6c78117f53b 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/TamperStreamContentsPolicy.cs @@ -14,7 +14,7 @@ internal class TamperStreamContentsPolicy : HttpPipelineSynchronousPolicy /// /// Default tampering that changes the first byte of the stream. /// - private static readonly Func _defaultStreamTransform = stream => + private static Func GetTamperByteStreamTransform(long position) => stream => { if (stream is not MemoryStream) { @@ -23,10 +23,10 @@ internal class TamperStreamContentsPolicy : HttpPipelineSynchronousPolicy stream = buffer; } - stream.Position = 0; + stream.Position = position; var firstByte = stream.ReadByte(); - stream.Position = 0; + stream.Position = position; stream.WriteByte((byte)((firstByte + 1) % byte.MaxValue)); stream.Position = 0; @@ -37,9 +37,12 @@ internal class TamperStreamContentsPolicy : HttpPipelineSynchronousPolicy public TamperStreamContentsPolicy(Func streamTransform = default) { - _streamTransform = streamTransform ?? _defaultStreamTransform; + _streamTransform = streamTransform ?? GetTamperByteStreamTransform(0); } + public static TamperStreamContentsPolicy TamperByteAt(long position) + => new(GetTamperByteStreamTransform(position)); + public bool TransformRequestBody { get; set; } public bool TransformResponseBody { get; set; } diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs index 201092978627c..ff06ddc4cbd92 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs @@ -196,18 +196,12 @@ protected string GetNewResourceName() internal static Action GetRequestChecksumHeaderAssertion(StorageChecksumAlgorithm algorithm, Func isChecksumExpected = default, byte[] expectedChecksum = default) { // action to assert a request header is as expected - void AssertChecksum(RequestHeaders headers, string headerName) + void AssertChecksum(Request req, string headerName) { - if (headers.TryGetValue(headerName, out string checksum)) - { - if (expectedChecksum != default) - { - Assert.AreEqual(Convert.ToBase64String(expectedChecksum), checksum); - } - } - else + string checksum = req.AssertHeaderPresent(headerName); + if (expectedChecksum != default) { - Assert.Fail($"{headerName} expected on request but was not found."); + Assert.AreEqual(Convert.ToBase64String(expectedChecksum), checksum); } }; @@ -222,10 +216,10 @@ void AssertChecksum(RequestHeaders headers, string headerName) switch (algorithm.ResolveAuto()) { case StorageChecksumAlgorithm.MD5: - AssertChecksum(request.Headers, "Content-MD5"); + AssertChecksum(request, "Content-MD5"); break; case StorageChecksumAlgorithm.StorageCrc64: - AssertChecksum(request.Headers, "x-ms-content-crc64"); + AssertChecksum(request, Constants.StructuredMessage.CrcStructuredMessageHeader); break; default: throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumHeaderAssertion)}."); @@ -308,7 +302,7 @@ void AssertChecksum(ResponseHeaders headers, string headerName) AssertChecksum(response.Headers, "Content-MD5"); break; case StorageChecksumAlgorithm.StorageCrc64: - AssertChecksum(response.Headers, "x-ms-content-crc64"); + AssertChecksum(response.Headers, Constants.StructuredMessage.CrcStructuredMessageHeader); break; default: throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumHeaderAssertion)}."); @@ -355,7 +349,7 @@ internal static void AssertWriteChecksumMismatch( var exception = ThrowsOrInconclusiveAsync(writeAction); if (expectStructuredMessage) { - Assert.That(exception.ErrorCode, Is.EqualTo("InvalidStructuredMessage")); + Assert.That(exception.ErrorCode, Is.EqualTo("Crc64Mismatch")); } else { @@ -481,7 +475,9 @@ public virtual async Task UploadPartitionUsePrecalculatedHash(StorageChecksumAlg // make pipeline assertion for checking precalculated checksum was present on upload // precalculated partition upload will never use structured message. always check header - var assertion = GetRequestChecksumHeaderAssertion(algorithm, expectedChecksum: precalculatedChecksum); + var assertion = GetRequestChecksumHeaderAssertion( + algorithm, + expectedChecksum: algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 ? default : precalculatedChecksum); var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: assertion); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); @@ -499,7 +495,7 @@ public virtual async Task UploadPartitionUsePrecalculatedHash(StorageChecksumAlg AsyncTestDelegate operation = async () => await UploadPartitionAsync(client, stream, validationOptions); // Assert - AssertWriteChecksumMismatch(operation, algorithm); + AssertWriteChecksumMismatch(operation, algorithm, algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64); } } @@ -517,7 +513,7 @@ public virtual async Task UploadPartitionTamperedStreamThrows(StorageChecksumAlg }; // Tamper with stream contents in the pipeline to simulate silent failure in the transit layer - var streamTamperPolicy = new TamperStreamContentsPolicy(); + var streamTamperPolicy = TamperStreamContentsPolicy.TamperByteAt(100); var clientOptions = ClientBuilder.GetOptions(); clientOptions.AddPolicy(streamTamperPolicy, HttpPipelinePosition.PerCall); @@ -650,7 +646,7 @@ public virtual async Task UploadPartitionDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (request.Headers.Contains("x-ms-content-crc64")) + if (request.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -702,7 +698,9 @@ public virtual async Task OpenWriteSuccessfulHashComputation( // make pipeline assertion for checking checksum was present on upload var checksumPipelineAssertion = new AssertMessageContentsPolicy(checkRequest: GetRequestChecksumHeaderAssertion(algorithm)); var clientOptions = ClientBuilder.GetOptions(); + //ObserveStructuredMessagePolicy observe = new(); clientOptions.AddPolicy(checksumPipelineAssertion, HttpPipelinePosition.PerCall); + //clientOptions.AddPolicy(observe, HttpPipelinePosition.BeforeTransport); var client = await GetResourceClientAsync( disposingContainer.Container, @@ -715,6 +713,7 @@ public virtual async Task OpenWriteSuccessfulHashComputation( using var writeStream = await OpenWriteAsync(client, validationOptions, streamBufferSize); // Assert + //using var obsv = observe.CheckRequestScope(); using (checksumPipelineAssertion.CheckRequestScope()) { foreach (var _ in Enumerable.Range(0, streamWrites)) @@ -743,7 +742,7 @@ public virtual async Task OpenWriteMismatchedHashThrows(StorageChecksumAlgorithm // Tamper with stream contents in the pipeline to simulate silent failure in the transit layer var clientOptions = ClientBuilder.GetOptions(); - var tamperPolicy = new TamperStreamContentsPolicy(); + var tamperPolicy = TamperStreamContentsPolicy.TamperByteAt(100); clientOptions.AddPolicy(tamperPolicy, HttpPipelinePosition.PerCall); var client = await GetResourceClientAsync( @@ -873,7 +872,7 @@ public virtual async Task OpenWriteDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (request.Headers.Contains("x-ms-content-crc64")) + if (request.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1236,7 +1235,7 @@ public virtual async Task ParallelUploadDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (request.Headers.Contains("x-ms-content-crc64")) + if (request.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1301,15 +1300,17 @@ public virtual async Task ParallelDownloadSuccessfulHashVerification( }; // Act - var dest = new MemoryStream(); + byte[] dest; + using (MemoryStream ms = new()) using (checksumPipelineAssertion.CheckRequestScope()) { - await ParallelDownloadAsync(client, dest, validationOptions, transferOptions); + await ParallelDownloadAsync(client, ms, validationOptions, transferOptions); + dest = ms.ToArray(); } // Assert // Assertion was in the pipeline and the SDK not throwing means the checksum was validated - Assert.IsTrue(dest.ToArray().SequenceEqual(data)); + Assert.IsTrue(dest.SequenceEqual(data)); } [Test] @@ -1474,7 +1475,7 @@ public virtual async Task ParallelDownloadDisablesDefaultClientValidationOptions { Assert.Fail($"Hash found when none expected."); } - if (response.Headers.Contains("x-ms-content-crc64")) + if (response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1682,7 +1683,7 @@ public virtual async Task OpenReadDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (response.Headers.Contains("x-ms-content-crc64")) + if (response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -1743,7 +1744,7 @@ public virtual async Task DownloadSuccessfulHashVerification(StorageChecksumAlgo Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains("x-ms-content-crc64")); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -1834,7 +1835,9 @@ public virtual async Task DownloadHashMismatchThrows( // alter response contents in pipeline, forcing a checksum mismatch on verification step var clientOptions = ClientBuilder.GetOptions(); - clientOptions.AddPolicy(new TamperStreamContentsPolicy() { TransformResponseBody = true }, HttpPipelinePosition.PerCall); + var tamperPolicy = TamperStreamContentsPolicy.TamperByteAt(50); + tamperPolicy.TransformResponseBody = true; + clientOptions.AddPolicy(tamperPolicy, HttpPipelinePosition.PerCall); client = await GetResourceClientAsync( disposingContainer.Container, createResource: false, @@ -1846,7 +1849,7 @@ public virtual async Task DownloadHashMismatchThrows( AsyncTestDelegate operation = async () => await DownloadPartitionAsync(client, dest, validationOptions, new HttpRange(length: data.Length)); // Assert - if (validate) + if (validate || algorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64) { // SDK responsible for finding bad checksum. Throw. ThrowsOrInconclusiveAsync(operation); @@ -1904,7 +1907,7 @@ public virtual async Task DownloadUsesDefaultClientValidationOptions( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains("x-ms-content-crc64")); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -1964,7 +1967,7 @@ public virtual async Task DownloadOverwritesDefaultClientValidationOptions( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains("x-ms-content-crc64")); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -2003,7 +2006,7 @@ public virtual async Task DownloadDisablesDefaultClientValidationOptions( { Assert.Fail($"Hash found when none expected."); } - if (response.Headers.Contains("x-ms-content-crc64")) + if (response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)) { Assert.Fail($"Hash found when none expected."); } @@ -2026,7 +2029,7 @@ public virtual async Task DownloadDisablesDefaultClientValidationOptions( // Assert // no policies this time; just check response headers Assert.False(response.Headers.Contains("Content-MD5")); - Assert.False(response.Headers.Contains("x-ms-content-crc64")); + Assert.False(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessage)); Assert.IsTrue(dest.ToArray().SequenceEqual(data)); } diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs index 666933e546189..39d2a5566b5ff 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -39,7 +39,7 @@ public async ValueTask UninterruptedStream() // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream using (Stream src = new MemoryStream(data)) - using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, 1)) + using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, default, 1)) using (Stream dst = new MemoryStream(dest)) { await retriableSrc.CopyToInternal(dst, Async, default); @@ -89,6 +89,7 @@ public async Task Interrupt_DataIntact([Values(true, false)] bool multipleInterr initialDecodedData, offset => Factory(offset, multipleInterrupts), offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + null, AllExceptionsRetry().Object, int.MaxValue)) using (Stream dst = new MemoryStream(dest)) @@ -159,6 +160,7 @@ public async Task Interrupt_AppropriateRewind() initialDecodedData, offset => (mock.Object, new()), offset => new(Task.FromResult((mock.Object, new StructuredMessageDecodingStream.DecodedData()))), + null, AllExceptionsRetry().Object, 1); @@ -215,6 +217,7 @@ public async Task Interrupt_ProperDecode([Values(true, false)] bool multipleInte decodedData, offset => Factory(offset, multipleInterrupts), offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + null, AllExceptionsRetry().Object, int.MaxValue); using Stream dst = new MemoryStream(dest); diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs index d4c667b937bf5..e0f91dee7de3a 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageEncodingStreamTests.cs @@ -218,6 +218,31 @@ public void NotSupportsFastForwardBeyondLatestRead() Assert.That(() => encodingStream.Position = 123, Throws.TypeOf()); } + [Test] + [Pairwise] + public async Task WrapperStreamCorrectData( + [Values(2048, 2005)] int dataLength, + [Values(8 * Constants.KB, 512, 530, 3)] int readLen) + { + int segmentContentLength = dataLength; + Flags flags = Flags.StorageCrc64; + + byte[] originalData = new byte[dataLength]; + new Random().NextBytes(originalData); + byte[] crc = CrcInline(originalData); + byte[] expectedEncodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags); + + Stream encodingStream = new StructuredMessagePrecalculatedCrcWrapperStream(new MemoryStream(originalData), crc); + byte[] encodedData; + using (MemoryStream dest = new()) + { + await CopyStream(encodingStream, dest, readLen); + encodedData = dest.ToArray(); + } + + Assert.That(new Span(encodedData).SequenceEqual(expectedEncodedData)); + } + private static void AssertExpectedStreamHeader(ReadOnlySpan actual, int originalDataLength, Flags flags, int expectedSegments) { int expectedFooterLen = flags.HasFlag(Flags.StorageCrc64) ? Crc64Length : 0; diff --git a/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj b/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj index 6098dcd8ba33d..93e7432f186e3 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Blobs/src/Azure.Storage.DataMovement.Blobs.csproj @@ -37,6 +37,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj index f8b62d0b947e2..214903eb5f9c4 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Blobs/tests/Azure.Storage.DataMovement.Blobs.Tests.csproj @@ -22,11 +22,15 @@ + + + + @@ -40,6 +44,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj index a6abde432473f..66a9fea0861a2 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/BlobToFileSharesTests/Azure.Storage.DataMovement.Blobs.Files.Shares.Tests.csproj @@ -35,6 +35,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj index 8e574bca36a48..d75775beceafd 100644 --- a/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement.Files.Shares/tests/Azure.Storage.DataMovement.Files.Shares.Tests.csproj @@ -27,6 +27,7 @@ + diff --git a/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj b/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj index 8afd7735a0168..21a1ea45f92a0 100644 --- a/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj +++ b/sdk/storage/Azure.Storage.DataMovement/tests/Azure.Storage.DataMovement.Tests.csproj @@ -34,6 +34,7 @@ + diff --git a/sdk/storage/Azure.Storage.Files.DataLake/assets.json b/sdk/storage/Azure.Storage.Files.DataLake/assets.json index 7329a98a34f40..e1cf4a6fa1c07 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/assets.json +++ b/sdk/storage/Azure.Storage.Files.DataLake/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/storage/Azure.Storage.Files.DataLake", - "Tag": "net/storage/Azure.Storage.Files.DataLake_9c23b9b180" + "Tag": "net/storage/Azure.Storage.Files.DataLake_c7efac8f52" } diff --git a/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj b/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj index 7467863a61be0..8a2e0bcd97d46 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj +++ b/sdk/storage/Azure.Storage.Files.DataLake/src/Azure.Storage.Files.DataLake.csproj @@ -84,6 +84,7 @@ + diff --git a/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs b/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs index 6088e970ec9e0..93ca4c3f9a1fd 100644 --- a/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs +++ b/sdk/storage/Azure.Storage.Files.DataLake/src/DataLakeFileClient.cs @@ -2339,17 +2339,20 @@ internal virtual async Task AppendInternal( string structuredBodyType = null; if (content != null && validationOptions != null && - validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64 && - validationOptions.PrecalculatedChecksum.IsEmpty) + validationOptions.ChecksumAlgorithm.ResolveAuto() == StorageChecksumAlgorithm.StorageCrc64) { // report progress in terms of caller bytes, not encoded bytes structuredContentLength = contentLength; structuredBodyType = Constants.StructuredMessage.CrcStructuredMessage; content = content.WithNoDispose().WithProgress(progressHandler); - content = new StructuredMessageEncodingStream( - content, - Constants.StructuredMessage.DefaultSegmentContentLength, - StructuredMessage.Flags.StorageCrc64); + content = validationOptions.PrecalculatedChecksum.IsEmpty + ? new StructuredMessageEncodingStream( + content, + Constants.StructuredMessage.DefaultSegmentContentLength, + StructuredMessage.Flags.StorageCrc64) + : new StructuredMessagePrecalculatedCrcWrapperStream( + content, + validationOptions.PrecalculatedChecksum.Span); contentLength = content.Length - content.Position; } else diff --git a/sdk/storage/Azure.Storage.Files.Shares/assets.json b/sdk/storage/Azure.Storage.Files.Shares/assets.json index d4df7130a51d0..7a34db9bef740 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/assets.json +++ b/sdk/storage/Azure.Storage.Files.Shares/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/storage/Azure.Storage.Files.Shares", - "Tag": "net/storage/Azure.Storage.Files.Shares_5e5b51e54d" + "Tag": "net/storage/Azure.Storage.Files.Shares_4ec8e7e485" } diff --git a/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj b/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj index 398a4b6367489..d09dd8fe8949f 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj +++ b/sdk/storage/Azure.Storage.Files.Shares/tests/Azure.Storage.Files.Shares.Tests.csproj @@ -17,6 +17,7 @@ + PreserveNewest diff --git a/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj b/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj index e0a6fab3c753b..4d0334255f041 100644 --- a/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj +++ b/sdk/storage/Azure.Storage.Queues/tests/Azure.Storage.Queues.Tests.csproj @@ -21,6 +21,7 @@ +