diff --git a/storage/transfermanager/downloader.go b/storage/transfermanager/downloader.go index 3f6107627dda..45d765e0b8fb 100644 --- a/storage/transfermanager/downloader.go +++ b/storage/transfermanager/downloader.go @@ -18,6 +18,8 @@ import ( "context" "errors" "fmt" + "hash" + "hash/crc32" "io" "io/fs" "math" @@ -31,6 +33,12 @@ import ( "google.golang.org/api/iterator" ) +// maxChecksumZeroArraySize is the maximum amount of memory to allocate for +// updating the checksum. A larger size will occupy more memory but will require +// fewer updates when computing the crc32c of a full object. +// TODO: test the performance of smaller values for this. +const maxChecksumZeroArraySize = 4 * 1024 * 1024 + // Downloader manages a set of parallelized downloads. type Downloader struct { client *storage.Client @@ -288,7 +296,7 @@ func (d *Downloader) addNewInputs(inputs []DownloadObjectInput) { } func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutput) { - copiedResult := *result // make a copy so that callbacks do not affect the result + copiedResult := *result // make a copy so that callbacks do not affect the result if input.directory { f := input.Destination.(*os.File) @@ -305,7 +313,6 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu input.directoryObjectOutputs <- copiedResult } } - // TODO: check checksum if full object if d.config.asynchronous || input.directory { input.Callback(result) @@ -337,27 +344,10 @@ func (d *Downloader) downloadWorker() { break // no more work; exit } - out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize) - if input.shard == 0 { - if out.Err != nil { - // Don't queue more shards if the first failed. - d.addResult(input, out) - } else { - numShards := numShards(out.Attrs, input.Range, d.config.partSize) - - if numShards <= 1 { - // Download completed with a single shard. - d.addResult(input, out) - } else { - // Queue more shards. 
- outs := d.queueShards(input, out.Attrs.Generation, numShards) - // Start a goroutine that gathers shards sent to the output - // channel and adds the result once it has received all shards. - go d.gatherShards(input, outs, numShards) - } - } + d.startDownload(input) } else { + out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize) // If this isn't the first shard, send to the output channel specific to the object. // This should never block since the channel is buffered to exactly the number of shards. input.shardOutputs <- out @@ -366,6 +356,47 @@ func (d *Downloader) downloadWorker() { d.workers.Done() } +// startDownload downloads the first shard and schedules subsequent shards +// if necessary. +func (d *Downloader) startDownload(input *DownloadObjectInput) { + var out *DownloadOutput + + // Full object read. Request the full object and only read partSize bytes + // (or the full object, if smaller than partSize), so that we can avoid a + // metadata call to grab the CRC32C for JSON downloads. + if fullObjectRead(input.Range) { + input.checkCRC = true + out = input.downloadFirstShard(d.client, d.config.perOperationTimeout, d.config.partSize) + } else { + out = input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize) + } + + if out.Err != nil { + // Don't queue more shards if the first failed. + d.addResult(input, out) + return + } + + numShards := numShards(out.Attrs, input.Range, d.config.partSize) + input.checkCRC = input.checkCRC && !out.Attrs.Decompressed // do not checksum if the object was decompressed + + if numShards > 1 { + outs := d.queueShards(input, out.Attrs.Generation, numShards) + // Start a goroutine that gathers shards sent to the output + // channel and adds the result once it has received all shards. + go d.gatherShards(input, out, outs, numShards, out.crc32c) + + } else { + // Download completed with a single shard. 
+ if input.checkCRC { + if err := checksumObject(out.crc32c, out.Attrs.CRC32C); err != nil { + out.Err = err + } + } + d.addResult(input, out) + } +} + // queueShards queues all subsequent shards of an object after the first. // The results should be forwarded to the returned channel. func (d *Downloader) queueShards(in *DownloadObjectInput, gen int64, shards int) <-chan *DownloadOutput { @@ -397,12 +428,12 @@ var errCancelAllShards = errors.New("cancelled because another shard failed") // It will add the result to the Downloader once it has received all shards. // gatherShards cancels remaining shards if any shard errored. // It does not do any checking to verify that shards are for the same object. -func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *DownloadOutput, shards int) { +func (d *Downloader) gatherShards(in *DownloadObjectInput, out *DownloadOutput, outs <-chan *DownloadOutput, shards int, firstPieceCRC uint32) { errs := []error{} - var shardOut *DownloadOutput + orderedChecksums := make([]crc32cPiece, shards-1) + for i := 1; i < shards; i++ { - // Add monitoring here? This could hang if any individual piece does. - shardOut = <-outs + shardOut := <-outs // We can ignore errors that resulted from a previous error. // Note that we may still get some cancel errors if they @@ -412,20 +443,30 @@ func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *Download errs = append(errs, shardOut.Err) in.cancelCtx(errCancelAllShards) } + + orderedChecksums[shardOut.shard-1] = crc32cPiece{sum: shardOut.crc32c, length: shardOut.shardLength} + } + + // All pieces gathered. + if len(errs) == 0 && in.checkCRC && out.Attrs != nil { + fullCrc := joinCRC32C(firstPieceCRC, orderedChecksums) + if err := checksumObject(fullCrc, out.Attrs.CRC32C); err != nil { + errs = append(errs, err) + } } - // All pieces gathered; return output. Any shard output will do. - shardOut.Range = in.Range + // Prepare output. 
+ out.Range = in.Range if len(errs) != 0 { - shardOut.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...)) + out.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...)) } - if shardOut.Attrs != nil { - shardOut.Attrs.StartOffset = 0 + if out.Attrs != nil { + out.Attrs.StartOffset = 0 if in.Range != nil { - shardOut.Attrs.StartOffset = in.Range.Offset + out.Attrs.StartOffset = in.Range.Offset } } - d.addResult(in, shardOut) + d.addResult(in, out) } // gatherObjectOutputs receives from the given channel exactly numObjects times. @@ -563,6 +604,7 @@ type DownloadObjectInput struct { shardOutputs chan<- *DownloadOutput directory bool // input was queued by calling DownloadDirectory directoryObjectOutputs chan<- DownloadOutput + checkCRC bool } // downloadShard will read a specific object piece into in.Destination. @@ -570,38 +612,10 @@ type DownloadObjectInput struct { func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) { out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range} - // Set timeout. - ctx := in.ctx - if timeout > 0 { - c, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - ctx = c - } - - // The first shard will be sent as download many, since we do not know yet - // if it will be sharded. - method := downloadMany - if in.shard != 0 { - method = downloadSharded - } - ctx = setUsageMetricHeader(ctx, method) - - // Set options on the object. - o := client.Bucket(in.Bucket).Object(in.Object) - - if in.Conditions != nil { - o = o.If(*in.Conditions) - } - if in.Generation != nil { - o = o.Generation(*in.Generation) - } - if len(in.EncryptionKey) > 0 { - o = o.Key(in.EncryptionKey) - } - objRange := shardRange(in.Range, partSize, in.shard) + ctx := in.setOptionsOnContext(timeout) + o := in.setOptionsOnObject(client) - // Read. 
r, err := o.NewRangeReader(ctx, objRange.Offset, objRange.Length) if err != nil { out.Err = err @@ -618,9 +632,63 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim } } - w := io.NewOffsetWriter(in.Destination, offset) - _, err = io.Copy(w, r) + var w io.Writer + w = io.NewOffsetWriter(in.Destination, offset) + + var crcHash hash.Hash32 + if in.checkCRC { + crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli)) + w = io.MultiWriter(w, crcHash) + } + + n, err := io.Copy(w, r) + if err != nil { + out.Err = err + r.Close() + return + } + + if err = r.Close(); err != nil { + out.Err = err + return + } + + out.Attrs = &r.Attrs + out.shard = in.shard + out.shardLength = n + if in.checkCRC { + out.crc32c = crcHash.Sum32() + } + return +} + +// downloadFirstShard will read the first object piece into in.Destination. +// If timeout is less than or equal to 0, no timeout is set. +func (in *DownloadObjectInput) downloadFirstShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) { + out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range} + + ctx := in.setOptionsOnContext(timeout) + o := in.setOptionsOnObject(client) + + r, err := o.NewReader(ctx) if err != nil { + out.Err = err + return + } + + var w io.Writer + w = io.NewOffsetWriter(in.Destination, 0) + + var crcHash hash.Hash32 + if in.checkCRC { + crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli)) + w = io.MultiWriter(w, crcHash) + } + + // Copy only the first partSize bytes before closing the reader. + // If we encounter an EOF, the file was smaller than partSize.
+ n, err := io.CopyN(w, r, partSize) + if err != nil && err != io.EOF { out.Err = err r.Close() return @@ -632,9 +700,45 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim } out.Attrs = &r.Attrs + out.shard = in.shard + out.shardLength = n + if in.checkCRC { + out.crc32c = crcHash.Sum32() + } return } +func (in *DownloadObjectInput) setOptionsOnContext(timeout time.Duration) context.Context { + ctx := in.ctx + if timeout > 0 { + c, cancel := context.WithTimeout(ctx, timeout) + _ = cancel // not deferred: cancelling on return would hand the caller an already-cancelled context; the deadline (or the parent ctx) releases it + ctx = c + } + + // The first shard will be sent as download many, since we do not know yet + // if it will be sharded. + method := downloadMany + if in.shard != 0 { + method = downloadSharded + } + return setUsageMetricHeader(ctx, method) +} + +func (in *DownloadObjectInput) setOptionsOnObject(client *storage.Client) *storage.ObjectHandle { + o := client.Bucket(in.Bucket).Object(in.Object) + if in.Conditions != nil { + o = o.If(*in.Conditions) + } + if in.Generation != nil { + o = o.Generation(*in.Generation) + } + if len(in.EncryptionKey) > 0 { + o = o.Key(in.EncryptionKey) + } + return o +} + // DownloadDirectoryInput is the input for a directory to download. type DownloadDirectoryInput struct { // Bucket is the bucket in GCS to download from. Required. @@ -686,6 +790,10 @@ type DownloadOutput struct { Range *DownloadRange // requested range, if it was specified Err error // error occurring during download Attrs *storage.ReaderObjectAttrs // attributes of downloaded object, if successful + + shard int + shardLength int64 + crc32c uint32 } // TODO: use built-in after go < 1.21 is dropped.
@@ -784,3 +892,48 @@ func setUsageMetricHeader(ctx context.Context, method string) context.Context { header := fmt.Sprintf("%s/%s", usageMetricKey, method) return callctx.SetHeaders(ctx, xGoogHeaderKey, header) } + +type crc32cPiece struct { + sum uint32 // crc32c checksum of the piece + length int64 // number of bytes in this piece +} + +// joinCRC32C pieces together the initial checksum with the orderedChecksums +// provided to calculate the checksum of the whole. +func joinCRC32C(initialChecksum uint32, orderedChecksums []crc32cPiece) uint32 { + base := initialChecksum + + zeroes := make([]byte, maxChecksumZeroArraySize) + for _, part := range orderedChecksums { + // Precondition Base (flip every bit) + base ^= 0xFFFFFFFF + + // Zero pad base crc32c. To conserve memory, do so with only maxChecksumZeroArraySize + // at a time. Reuse the zeroes array where possible. + var padded int64 = 0 + for padded < part.length { + desiredZeroes := min(part.length-padded, maxChecksumZeroArraySize) + base = crc32.Update(base, crc32.MakeTable(crc32.Castagnoli), zeroes[:desiredZeroes]) + padded += desiredZeroes + } + + // Postcondition Base (same as precondition, this switches the bits back) + base ^= 0xFFFFFFFF + + // Bitwise XOR between Base and Part to produce a new Base + base ^= part.sum + } + return base +} + +func fullObjectRead(r *DownloadRange) bool { + return r == nil || (r.Offset == 0 && r.Length < 0) +} + +func checksumObject(got, want uint32) error { + // Only checksum the object if we have a valid CRC32C.
+ if want != 0 && want != got { + return fmt.Errorf("bad CRC on read: got %d, want %d", got, want) + } + return nil +} diff --git a/storage/transfermanager/downloader_test.go b/storage/transfermanager/downloader_test.go index 2c469e5ef99b..0f6a9b33a2e5 100644 --- a/storage/transfermanager/downloader_test.go +++ b/storage/transfermanager/downloader_test.go @@ -15,6 +15,7 @@ package transfermanager import ( + "bytes" "context" "errors" "strings" @@ -411,11 +412,12 @@ func TestGatherShards(t *testing.T) { Offset: 20, Length: 120, } + firstOut := &DownloadOutput{Object: object, Range: &DownloadRange{Offset: 20, Length: 30}, shard: 0} outChan := make(chan *DownloadOutput, shards) outs := []*DownloadOutput{ - {Object: object, Range: &DownloadRange{Offset: 50, Length: 30}}, - {Object: object, Range: &DownloadRange{Offset: 80, Length: 30}}, - {Object: object, Range: &DownloadRange{Offset: 110, Length: 30}}, + {Object: object, Range: &DownloadRange{Offset: 50, Length: 30}, shard: 1}, + {Object: object, Range: &DownloadRange{Offset: 80, Length: 30}, shard: 2}, + {Object: object, Range: &DownloadRange{Offset: 110, Length: 30}, shard: 3}, } in := &DownloadObjectInput{ @@ -441,7 +443,7 @@ func TestGatherShards(t *testing.T) { d.downloadsInProgress.Add(1) go func() { - d.gatherShards(in, outChan, shards) + d.gatherShards(in, firstOut, outChan, shards, 0) wg.Done() }() @@ -475,7 +477,7 @@ func TestGatherShards(t *testing.T) { d.downloadsInProgress.Add(1) go func() { - d.gatherShards(in, outChan, shards) + d.gatherShards(in, firstOut, outChan, shards, 0) wg.Done() }() @@ -507,3 +509,43 @@ func TestGatherShards(t *testing.T) { t.Errorf("error in DownloadOutput should not contain error %q; got: %v", errCancelAllShards, err) } } + +func TestCalculateCRC32C(t *testing.T) { + t.Parallel() + for _, test := range []struct { + desc string + pieces []string + }{ + { + desc: "equal sized pieces", + pieces: []string{"he", "ll", "o ", "wo", "rl", "d!"}, + }, + { + desc: "uneven pieces", + 
pieces: []string{"hello", " ", "world!"}, + }, + { + desc: "large pieces", + pieces: []string{string(bytes.Repeat([]byte("a"), 1024*1024*32)), + string(bytes.Repeat([]byte("b"), 1024*1024*32)), + string(bytes.Repeat([]byte("c"), 1024*1024*32)), + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + initialChecksum := crc32c([]byte(test.pieces[0])) + + remainingChecksums := make([]crc32cPiece, len(test.pieces)-1) + for i, piece := range test.pieces[1:] { + remainingChecksums[i] = crc32cPiece{sum: crc32c([]byte(piece)), length: int64(len(piece))} + } + + got := joinCRC32C(initialChecksum, remainingChecksums) + want := crc32c([]byte(strings.Join(test.pieces, ""))) + + if got != want { + t.Errorf("crc32c not calculated correctly - want %v, got %v", want, got) + } + }) + } +} diff --git a/storage/transfermanager/integration_test.go b/storage/transfermanager/integration_test.go index 59f48d91b908..1d06b5fc67bb 100644 --- a/storage/transfermanager/integration_test.go +++ b/storage/transfermanager/integration_test.go @@ -46,6 +46,7 @@ const ( testPrefix = "go-integration-test-tm" grpcTestPrefix = "golang-grpc-test-tm" bucketExpiryAge = 24 * time.Hour + minObjectSize = 1024 maxObjectSize = 1024 * 1024 ) @@ -721,7 +722,7 @@ func TestIntegration_DownloadShard(t *testing.T) { o := c.Bucket(tb.bucket).Object(objectName) r, err := o.NewReader(ctx) if err != nil { - t.Fatalf("o.Attrs: %v", err) + t.Fatalf("o.NewReader: %v", err) } incorrectGen := r.Attrs.Generation - 1 @@ -1030,7 +1031,7 @@ func (tb *downloadTestBucket) Create(prefix string) error { // Write objects. for _, obj := range tb.objects { - size := randomInt64(1000, maxObjectSize) + size := randomInt64(minObjectSize, maxObjectSize) crc, err := generateFileInGCS(ctx, b.Object(obj), size) if err != nil { return fmt.Errorf("generateFileInGCS: %v", err)