From c0905a5628c23f563d15732b6ca84111ca964519 Mon Sep 17 00:00:00 2001 From: BrennaEpp Date: Sat, 20 Jul 2024 20:25:13 -0700 Subject: [PATCH 1/3] feat(storage/transfermanager): checksum full object downloads --- storage/transfermanager/downloader.go | 293 +++++++++++++++----- storage/transfermanager/downloader_test.go | 52 +++- storage/transfermanager/integration_test.go | 5 +- 3 files changed, 279 insertions(+), 71 deletions(-) diff --git a/storage/transfermanager/downloader.go b/storage/transfermanager/downloader.go index 3f6107627dda..3830f61aaa0e 100644 --- a/storage/transfermanager/downloader.go +++ b/storage/transfermanager/downloader.go @@ -18,6 +18,8 @@ import ( "context" "errors" "fmt" + "hash" + "hash/crc32" "io" "io/fs" "math" @@ -31,6 +33,12 @@ import ( "google.golang.org/api/iterator" ) +// maxChecksumZeroArraySize is the maximum amount of memory to allocate for +// updating the checksum. A larger size will occupy more memory but will require +// more updates when computing the crc32c of a full object. +// TODO: test the performance of smaller values for this. +const maxChecksumZeroArraySize = 4 * 1024 * 1024 + // Downloader manages a set of parallelized downloads. 
type Downloader struct { client *storage.Client @@ -288,7 +296,7 @@ func (d *Downloader) addNewInputs(inputs []DownloadObjectInput) { } func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutput) { - copiedResult := *result // make a copy so that callbacks do not affect the result + copiedResult := *result // make a copy so that callbacks do not affect the result if input.directory { f := input.Destination.(*os.File) @@ -305,7 +313,6 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu input.directoryObjectOutputs <- copiedResult } } - // TODO: check checksum if full object if d.config.asynchronous || input.directory { input.Callback(result) @@ -337,27 +344,10 @@ func (d *Downloader) downloadWorker() { break // no more work; exit } - out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize) - if input.shard == 0 { - if out.Err != nil { - // Don't queue more shards if the first failed. - d.addResult(input, out) - } else { - numShards := numShards(out.Attrs, input.Range, d.config.partSize) - - if numShards <= 1 { - // Download completed with a single shard. - d.addResult(input, out) - } else { - // Queue more shards. - outs := d.queueShards(input, out.Attrs.Generation, numShards) - // Start a goroutine that gathers shards sent to the output - // channel and adds the result once it has received all shards. - go d.gatherShards(input, outs, numShards) - } - } + d.startDownload(input) } else { + out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize) // If this isn't the first shard, send to the output channel specific to the object. // This should never block since the channel is buffered to exactly the number of shards. input.shardOutputs <- out @@ -366,6 +356,48 @@ func (d *Downloader) downloadWorker() { d.workers.Done() } +// startDownload downloads the first shard and schedules subsequent shards +// if necessary. 
+func (d *Downloader) startDownload(input *DownloadObjectInput) { + var out *DownloadOutput + + // Special-case: full object read. + // In this case, we want to request the full object and only read partSize bytes, + // so that we can avoid a metadata call to grab the CRC32C for JSON downloads. + if fullObjectRead(input.Range) { + input.checkCRC = true + out = input.downloadFirstShard(d.client, d.config.perOperationTimeout, d.config.partSize) + } else { + out = input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize) + } + + if out.Err != nil { + // Don't queue more shards if the first failed. + d.addResult(input, out) + return + } + + numShards := numShards(out.Attrs, input.Range, d.config.partSize) + input.checkCRC = input.checkCRC && !out.Attrs.Decompressed // do not checksum if the object was decompressed + + if numShards > 1 { + outs := d.queueShards(input, out.Attrs.Generation, numShards) + // Start a goroutine that gathers shards sent to the output + // channel and adds the result once it has received all shards. + go d.gatherShards(input, out, outs, numShards, out.crc32c) + + } else { + // Download completed with a single shard. + if input.checkCRC { + o := d.client.Bucket(input.Bucket).Object(input.Object).Generation(out.Attrs.Generation) + if err := checksumObject(input.ctx, o, out.crc32c, out.Attrs.CRC32C); err != nil { + out.Err = err + } + } + d.addResult(input, out) + } +} + // queueShards queues all subsequent shards of an object after the first. // The results should be forwarded to the returned channel. func (d *Downloader) queueShards(in *DownloadObjectInput, gen int64, shards int) <-chan *DownloadOutput { @@ -397,12 +429,12 @@ var errCancelAllShards = errors.New("cancelled because another shard failed") // It will add the result to the Downloader once it has received all shards. // gatherShards cancels remaining shards if any shard errored. // It does not do any checking to verify that shards are for the same object. 
-func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *DownloadOutput, shards int) { +func (d *Downloader) gatherShards(in *DownloadObjectInput, out *DownloadOutput, outs <-chan *DownloadOutput, shards int, firstPieceCRC uint32) { errs := []error{} - var shardOut *DownloadOutput + orderedChecksums := make([]crc32cPiece, shards-1) + for i := 1; i < shards; i++ { - // Add monitoring here? This could hang if any individual piece does. - shardOut = <-outs + shardOut := <-outs // We can ignore errors that resulted from a previous error. // Note that we may still get some cancel errors if they @@ -412,20 +444,31 @@ func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *Download errs = append(errs, shardOut.Err) in.cancelCtx(errCancelAllShards) } + + orderedChecksums[shardOut.shard-1] = crc32cPiece{sum: shardOut.crc32c, length: shardOut.shardLength} + } + + // All pieces gathered. + if len(errs) == 0 && in.checkCRC && out.Attrs != nil { + fullCrc := joinCRC32C(firstPieceCRC, orderedChecksums) + o := d.client.Bucket(in.Bucket).Object(in.Object).Generation(out.Attrs.Generation) + if err := checksumObject(in.ctx, o, fullCrc, out.Attrs.CRC32C); err != nil { + errs = append(errs, err) + } } - // All pieces gathered; return output. Any shard output will do. - shardOut.Range = in.Range + // Prepare output. + out.Range = in.Range if len(errs) != 0 { - shardOut.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...)) + out.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...)) } - if shardOut.Attrs != nil { - shardOut.Attrs.StartOffset = 0 + if out.Attrs != nil { + out.Attrs.StartOffset = 0 if in.Range != nil { - shardOut.Attrs.StartOffset = in.Range.Offset + out.Attrs.StartOffset = in.Range.Offset } } - d.addResult(in, shardOut) + d.addResult(in, out) } // gatherObjectOutputs receives from the given channel exactly numObjects times. 
@@ -563,6 +606,7 @@ type DownloadObjectInput struct { shardOutputs chan<- *DownloadOutput directory bool // input was queued by calling DownloadDirectory directoryObjectOutputs chan<- DownloadOutput + checkCRC bool } // downloadShard will read a specific object piece into in.Destination. @@ -570,38 +614,10 @@ type DownloadObjectInput struct { func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) { out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range} - // Set timeout. - ctx := in.ctx - if timeout > 0 { - c, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - ctx = c - } - - // The first shard will be sent as download many, since we do not know yet - // if it will be sharded. - method := downloadMany - if in.shard != 0 { - method = downloadSharded - } - ctx = setUsageMetricHeader(ctx, method) - - // Set options on the object. - o := client.Bucket(in.Bucket).Object(in.Object) - - if in.Conditions != nil { - o = o.If(*in.Conditions) - } - if in.Generation != nil { - o = o.Generation(*in.Generation) - } - if len(in.EncryptionKey) > 0 { - o = o.Key(in.EncryptionKey) - } - objRange := shardRange(in.Range, partSize, in.shard) + ctx := in.setOptionsOnContext(timeout) + o := in.setOptionsOnObject(client) - // Read. 
r, err := o.NewRangeReader(ctx, objRange.Offset, objRange.Length) if err != nil { out.Err = err @@ -618,8 +634,16 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim } } - w := io.NewOffsetWriter(in.Destination, offset) - _, err = io.Copy(w, r) + var w io.Writer + w = io.NewOffsetWriter(in.Destination, offset) + + var crcHash hash.Hash32 + if in.checkCRC { + crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli)) + w = io.MultiWriter(w, crcHash) + } + + n, err := io.Copy(w, r) if err != nil { out.Err = err r.Close() @@ -632,9 +656,91 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim } out.Attrs = &r.Attrs + out.shard = in.shard + out.shardLength = n + if in.checkCRC { + out.crc32c = crcHash.Sum32() + } return } +// downloadFirstShard will read the first object piece into in.Destination. +// If timeout is less than 0, no timeout is set. +func (in *DownloadObjectInput) downloadFirstShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) { + out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range} + + ctx := in.setOptionsOnContext(timeout) + o := in.setOptionsOnObject(client) + + r, err := o.NewReader(ctx) + if err != nil { + out.Err = err + return + } + + var w io.Writer + w = io.NewOffsetWriter(in.Destination, 0) + + var crcHash hash.Hash32 + if in.checkCRC { + crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli)) + w = io.MultiWriter(w, crcHash) + } + + // Copy only the first partSize bytes before closing the reader. + // If we encounter an EOF, the file was smaller than partSize. 
+ n, err := io.CopyN(w, r, partSize) + if err != nil && err != io.EOF { + out.Err = err + r.Close() + return + } + + if err = r.Close(); err != nil { + out.Err = err + return + } + + out.Attrs = &r.Attrs + out.shard = in.shard + out.shardLength = n + if in.checkCRC { + out.crc32c = crcHash.Sum32() + } + return +} + +func (in *DownloadObjectInput) setOptionsOnContext(timeout time.Duration) context.Context { + ctx := in.ctx + if timeout > 0 { + c, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + ctx = c + } + + // The first shard will be sent as download many, since we do not know yet + // if it will be sharded. + method := downloadMany + if in.shard != 0 { + method = downloadSharded + } + return setUsageMetricHeader(ctx, method) +} + +func (in *DownloadObjectInput) setOptionsOnObject(client *storage.Client) *storage.ObjectHandle { + o := client.Bucket(in.Bucket).Object(in.Object) + if in.Conditions != nil { + o = o.If(*in.Conditions) + } + if in.Generation != nil { + o = o.Generation(*in.Generation) + } + if len(in.EncryptionKey) > 0 { + o = o.Key(in.EncryptionKey) + } + return o +} + // DownloadDirectoryInput is the input for a directory to download. type DownloadDirectoryInput struct { // Bucket is the bucket in GCS to download from. Required. @@ -686,6 +792,10 @@ type DownloadOutput struct { Range *DownloadRange // requested range, if it was specified Err error // error occurring during download Attrs *storage.ReaderObjectAttrs // attributes of downloaded object, if successful + + shard int + shardLength int64 + crc32c uint32 } // TODO: use built-in after go < 1.21 is dropped. 
@@ -784,3 +894,58 @@ func setUsageMetricHeader(ctx context.Context, method string) context.Context { header := fmt.Sprintf("%s/%s", usageMetricKey, method) return callctx.SetHeaders(ctx, xGoogHeaderKey, header) } + +type crc32cPiece struct { + sum uint32 // crc32c checksum of the piece + length int64 // number of bytes in this piece +} + +// joinCRC32C pieces together the initial checksum with the orderedChecksums +// provided to calculate the checksum of the whole. +func joinCRC32C(initialChecksum uint32, orderedChecksums []crc32cPiece) uint32 { + base := initialChecksum + + zeroes := make([]byte, maxChecksumZeroArraySize) + for _, part := range orderedChecksums { + // Precondition Base (flip every bit) + base ^= 0xFFFFFFFF + + // Zero pad base crc32c. To conserve memory, do so with only maxChecksumZeroArraySize + // at a time. Reuse the zeroes array where possible. + var padded int64 = 0 + for padded < part.length { + desiredZeroes := min(part.length-padded, maxChecksumZeroArraySize) + base = crc32.Update(base, crc32.MakeTable(crc32.Castagnoli), zeroes[:desiredZeroes]) + padded += desiredZeroes + } + + // Postcondition Base (same as precondition, this switches the bits back) + base ^= 0xFFFFFFFF + + // XOR Base with Part to produce a new Base + base ^= part.sum + } + return base +} + +func fullObjectRead(r *DownloadRange) bool { + return r == nil || (r.Offset == 0 && r.Length < 0) +} + +func checksumObject(ctx context.Context, object *storage.ObjectHandle, got, want uint32) error { + // If we don't have a CRC32C from the reader, do a metadata call to get it. + if want == 0 { + attrs, err := object.Attrs(ctx) + if err != nil { + return fmt.Errorf("error getting ObjectAttrs for checksumming: %w", err) + } else { + want = attrs.CRC32C + } + } + + // Only checksum the object if we have a valid CRC32C. 
+ if want != 0 && want != got { + return fmt.Errorf("bad CRC on read: got %d, want %d", got, want) + } + return nil +} diff --git a/storage/transfermanager/downloader_test.go b/storage/transfermanager/downloader_test.go index 2c469e5ef99b..0f6a9b33a2e5 100644 --- a/storage/transfermanager/downloader_test.go +++ b/storage/transfermanager/downloader_test.go @@ -15,6 +15,7 @@ package transfermanager import ( + "bytes" "context" "errors" "strings" @@ -411,11 +412,12 @@ func TestGatherShards(t *testing.T) { Offset: 20, Length: 120, } + firstOut := &DownloadOutput{Object: object, Range: &DownloadRange{Offset: 20, Length: 30}, shard: 0} outChan := make(chan *DownloadOutput, shards) outs := []*DownloadOutput{ - {Object: object, Range: &DownloadRange{Offset: 50, Length: 30}}, - {Object: object, Range: &DownloadRange{Offset: 80, Length: 30}}, - {Object: object, Range: &DownloadRange{Offset: 110, Length: 30}}, + {Object: object, Range: &DownloadRange{Offset: 50, Length: 30}, shard: 1}, + {Object: object, Range: &DownloadRange{Offset: 80, Length: 30}, shard: 2}, + {Object: object, Range: &DownloadRange{Offset: 110, Length: 30}, shard: 3}, } in := &DownloadObjectInput{ @@ -441,7 +443,7 @@ func TestGatherShards(t *testing.T) { d.downloadsInProgress.Add(1) go func() { - d.gatherShards(in, outChan, shards) + d.gatherShards(in, firstOut, outChan, shards, 0) wg.Done() }() @@ -475,7 +477,7 @@ func TestGatherShards(t *testing.T) { d.downloadsInProgress.Add(1) go func() { - d.gatherShards(in, outChan, shards) + d.gatherShards(in, firstOut, outChan, shards, 0) wg.Done() }() @@ -507,3 +509,43 @@ func TestGatherShards(t *testing.T) { t.Errorf("error in DownloadOutput should not contain error %q; got: %v", errCancelAllShards, err) } } + +func TestCalculateCRC32C(t *testing.T) { + t.Parallel() + for _, test := range []struct { + desc string + pieces []string + }{ + { + desc: "equal sized pieces", + pieces: []string{"he", "ll", "o ", "wo", "rl", "d!"}, + }, + { + desc: "uneven pieces", + 
pieces: []string{"hello", " ", "world!"}, + }, + { + desc: "large pieces", + pieces: []string{string(bytes.Repeat([]byte("a"), 1024*1024*32)), + string(bytes.Repeat([]byte("b"), 1024*1024*32)), + string(bytes.Repeat([]byte("c"), 1024*1024*32)), + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + initialChecksum := crc32c([]byte(test.pieces[0])) + + remainingChecksums := make([]crc32cPiece, len(test.pieces)-1) + for i, piece := range test.pieces[1:] { + remainingChecksums[i] = crc32cPiece{sum: crc32c([]byte(piece)), length: int64(len(piece))} + } + + got := joinCRC32C(initialChecksum, remainingChecksums) + want := crc32c([]byte(strings.Join(test.pieces, ""))) + + if got != want { + t.Errorf("crc32c not calculated correctly - want %v, got %v", want, got) + } + }) + } +} diff --git a/storage/transfermanager/integration_test.go b/storage/transfermanager/integration_test.go index 59f48d91b908..1d06b5fc67bb 100644 --- a/storage/transfermanager/integration_test.go +++ b/storage/transfermanager/integration_test.go @@ -46,6 +46,7 @@ const ( testPrefix = "go-integration-test-tm" grpcTestPrefix = "golang-grpc-test-tm" bucketExpiryAge = 24 * time.Hour + minObjectSize = 1024 maxObjectSize = 1024 * 1024 ) @@ -721,7 +722,7 @@ func TestIntegration_DownloadShard(t *testing.T) { o := c.Bucket(tb.bucket).Object(objectName) r, err := o.NewReader(ctx) if err != nil { - t.Fatalf("o.Attrs: %v", err) + t.Fatalf("o.NewReader: %v", err) } incorrectGen := r.Attrs.Generation - 1 @@ -1030,7 +1031,7 @@ func (tb *downloadTestBucket) Create(prefix string) error { // Write objects. 
for _, obj := range tb.objects { - size := randomInt64(1000, maxObjectSize) + size := randomInt64(minObjectSize, maxObjectSize) crc, err := generateFileInGCS(ctx, b.Object(obj), size) if err != nil { return fmt.Errorf("generateFileInGCS: %v", err) From db7e4e684444c8a0c940718663295656e22f9269 Mon Sep 17 00:00:00 2001 From: BrennaEpp Date: Sat, 20 Jul 2024 20:31:58 -0700 Subject: [PATCH 2/3] fix vet --- storage/transfermanager/downloader.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/storage/transfermanager/downloader.go b/storage/transfermanager/downloader.go index 3830f61aaa0e..be0e863bfcc1 100644 --- a/storage/transfermanager/downloader.go +++ b/storage/transfermanager/downloader.go @@ -938,9 +938,8 @@ func checksumObject(ctx context.Context, object *storage.ObjectHandle, got, want attrs, err := object.Attrs(ctx) if err != nil { return fmt.Errorf("error getting ObjectAttrs for checksumming: %w", err) - } else { - want = attrs.CRC32C } + want = attrs.CRC32C } // Only checksum the object if we have a valid CRC32C. From 480792348ef3bef340892c6ff828b3cd6bcf751a Mon Sep 17 00:00:00 2001 From: BrennaEpp Date: Mon, 29 Jul 2024 17:00:39 -0700 Subject: [PATCH 3/3] comment suggestions --- storage/transfermanager/downloader.go | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/storage/transfermanager/downloader.go b/storage/transfermanager/downloader.go index be0e863bfcc1..45d765e0b8fb 100644 --- a/storage/transfermanager/downloader.go +++ b/storage/transfermanager/downloader.go @@ -35,7 +35,7 @@ import ( // maxChecksumZeroArraySize is the maximum amount of memory to allocate for // updating the checksum. A larger size will occupy more memory but will require -// more updates when computing the crc32c of a full object. +// fewer updates when computing the crc32c of a full object. // TODO: test the performance of smaller values for this. 
const maxChecksumZeroArraySize = 4 * 1024 * 1024 @@ -361,9 +361,9 @@ func (d *Downloader) downloadWorker() { func (d *Downloader) startDownload(input *DownloadObjectInput) { var out *DownloadOutput - // Special-case: full object read. - // In this case, we want to request the full object and only read partSize bytes, - // so that we can avoid a metadata call to grab the CRC32C for JSON downloads. + // Full object read. Request the full object and only read partSize bytes + // (or the full object, if smaller than partSize), so that we can avoid a + // metadata call to grab the CRC32C for JSON downloads. if fullObjectRead(input.Range) { input.checkCRC = true out = input.downloadFirstShard(d.client, d.config.perOperationTimeout, d.config.partSize) @@ -389,8 +389,7 @@ func (d *Downloader) startDownload(input *DownloadObjectInput) { } else { // Download completed with a single shard. if input.checkCRC { - o := d.client.Bucket(input.Bucket).Object(input.Object).Generation(out.Attrs.Generation) - if err := checksumObject(input.ctx, o, out.crc32c, out.Attrs.CRC32C); err != nil { + if err := checksumObject(out.crc32c, out.Attrs.CRC32C); err != nil { out.Err = err } } @@ -451,8 +450,7 @@ func (d *Downloader) gatherShards(in *DownloadObjectInput, out *DownloadOutput, // All pieces gathered. if len(errs) == 0 && in.checkCRC && out.Attrs != nil { fullCrc := joinCRC32C(firstPieceCRC, orderedChecksums) - o := d.client.Bucket(in.Bucket).Object(in.Object).Generation(out.Attrs.Generation) - if err := checksumObject(in.ctx, o, fullCrc, out.Attrs.CRC32C); err != nil { + if err := checksumObject(fullCrc, out.Attrs.CRC32C); err != nil { errs = append(errs, err) } } @@ -932,16 +930,7 @@ func fullObjectRead(r *DownloadRange) bool { return r == nil || (r.Offset == 0 && r.Length < 0) } -func checksumObject(ctx context.Context, object *storage.ObjectHandle, got, want uint32) error { - // If we don't have a CRC32C from the reader, do a metadata call to get it. 
- if want == 0 { - attrs, err := object.Attrs(ctx) - if err != nil { - return fmt.Errorf("error getting ObjectAttrs for checksumming: %w", err) - } - want = attrs.CRC32C - } - +func checksumObject(got, want uint32) error { // Only checksum the object if we have a valid CRC32C. if want != 0 && want != got { return fmt.Errorf("bad CRC on read: got %d, want %d", got, want)