Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(storage/transfermanager): checksum full object downloads #10569

Merged
merged 4 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 217 additions & 64 deletions storage/transfermanager/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"errors"
"fmt"
"hash"
"hash/crc32"
"io"
"io/fs"
"math"
Expand All @@ -31,6 +33,12 @@ import (
"google.golang.org/api/iterator"
)

// maxChecksumZeroArraySize is the maximum amount of memory to allocate for
// updating the checksum. A larger size will occupy more memory but will require
// fewer updates when computing the crc32c of a full object. It bounds the size
// of the reusable zero buffer used by joinCRC32C when zero-padding checksums.
// TODO: test the performance of smaller values for this.
const maxChecksumZeroArraySize = 4 * 1024 * 1024

// Downloader manages a set of parallelized downloads.
type Downloader struct {
client *storage.Client
Expand Down Expand Up @@ -288,7 +296,7 @@ func (d *Downloader) addNewInputs(inputs []DownloadObjectInput) {
}

func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutput) {
copiedResult := *result // make a copy so that callbacks do not affect the result
copiedResult := *result // make a copy so that callbacks do not affect the result

if input.directory {
f := input.Destination.(*os.File)
Expand All @@ -305,7 +313,6 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu
input.directoryObjectOutputs <- copiedResult
}
}
// TODO: check checksum if full object

if d.config.asynchronous || input.directory {
input.Callback(result)
Expand Down Expand Up @@ -337,27 +344,10 @@ func (d *Downloader) downloadWorker() {
break // no more work; exit
}

out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)

if input.shard == 0 {
if out.Err != nil {
// Don't queue more shards if the first failed.
d.addResult(input, out)
} else {
numShards := numShards(out.Attrs, input.Range, d.config.partSize)

if numShards <= 1 {
// Download completed with a single shard.
d.addResult(input, out)
} else {
// Queue more shards.
outs := d.queueShards(input, out.Attrs.Generation, numShards)
// Start a goroutine that gathers shards sent to the output
// channel and adds the result once it has received all shards.
go d.gatherShards(input, outs, numShards)
}
}
d.startDownload(input)
} else {
out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
// If this isn't the first shard, send to the output channel specific to the object.
// This should never block since the channel is buffered to exactly the number of shards.
input.shardOutputs <- out
Expand All @@ -366,6 +356,47 @@ func (d *Downloader) downloadWorker() {
d.workers.Done()
}

// startDownload downloads the first shard and schedules subsequent shards
// if necessary. If the object fits in a single shard, the result is finalized
// here; otherwise gatherShards finalizes it once all shards complete.
func (d *Downloader) startDownload(input *DownloadObjectInput) {
	var out *DownloadOutput

	// Full object read. Request the full object and only read partSize bytes
	// (or the full object, if smaller than partSize), so that we can avoid a
	// metadata call to grab the CRC32C for JSON downloads.
	if fullObjectRead(input.Range) {
		// Checksumming is only possible for full-object reads, since the
		// stored CRC32C covers the entire object.
		input.checkCRC = true
		out = input.downloadFirstShard(d.client, d.config.perOperationTimeout, d.config.partSize)
	} else {
		out = input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
	}

	if out.Err != nil {
		// Don't queue more shards if the first failed.
		d.addResult(input, out)
		return
	}

	numShards := numShards(out.Attrs, input.Range, d.config.partSize)
	input.checkCRC = input.checkCRC && !out.Attrs.Decompressed // do not checksum if the object was decompressed

	if numShards > 1 {
		outs := d.queueShards(input, out.Attrs.Generation, numShards)
		// Start a goroutine that gathers shards sent to the output
		// channel and adds the result once it has received all shards.
		// The first shard's output and its crc32c are handed off so the
		// gatherer can stitch the full-object checksum together.
		go d.gatherShards(input, out, outs, numShards, out.crc32c)

	} else {
		// Download completed with a single shard.
		if input.checkCRC {
			if err := checksumObject(out.crc32c, out.Attrs.CRC32C); err != nil {
				out.Err = err
			}
		}
		d.addResult(input, out)
	}
}

// queueShards queues all subsequent shards of an object after the first.
// The results should be forwarded to the returned channel.
func (d *Downloader) queueShards(in *DownloadObjectInput, gen int64, shards int) <-chan *DownloadOutput {
Expand Down Expand Up @@ -397,12 +428,12 @@ var errCancelAllShards = errors.New("cancelled because another shard failed")
// It will add the result to the Downloader once it has received all shards.
// gatherShards cancels remaining shards if any shard errored.
// It does not do any checking to verify that shards are for the same object.
func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *DownloadOutput, shards int) {
func (d *Downloader) gatherShards(in *DownloadObjectInput, out *DownloadOutput, outs <-chan *DownloadOutput, shards int, firstPieceCRC uint32) {
errs := []error{}
var shardOut *DownloadOutput
orderedChecksums := make([]crc32cPiece, shards-1)

for i := 1; i < shards; i++ {
// Add monitoring here? This could hang if any individual piece does.
shardOut = <-outs
shardOut := <-outs

// We can ignore errors that resulted from a previous error.
// Note that we may still get some cancel errors if they
Expand All @@ -412,20 +443,30 @@ func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *Download
errs = append(errs, shardOut.Err)
in.cancelCtx(errCancelAllShards)
}

orderedChecksums[shardOut.shard-1] = crc32cPiece{sum: shardOut.crc32c, length: shardOut.shardLength}
}

// All pieces gathered.
if len(errs) == 0 && in.checkCRC && out.Attrs != nil {
fullCrc := joinCRC32C(firstPieceCRC, orderedChecksums)
if err := checksumObject(fullCrc, out.Attrs.CRC32C); err != nil {
errs = append(errs, err)
}
}

// All pieces gathered; return output. Any shard output will do.
shardOut.Range = in.Range
// Prepare output.
out.Range = in.Range
if len(errs) != 0 {
shardOut.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...))
out.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...))
}
if shardOut.Attrs != nil {
shardOut.Attrs.StartOffset = 0
if out.Attrs != nil {
out.Attrs.StartOffset = 0
if in.Range != nil {
shardOut.Attrs.StartOffset = in.Range.Offset
out.Attrs.StartOffset = in.Range.Offset
}
}
d.addResult(in, shardOut)
d.addResult(in, out)
}

// gatherObjectOutputs receives from the given channel exactly numObjects times.
Expand Down Expand Up @@ -563,45 +604,18 @@ type DownloadObjectInput struct {
shardOutputs chan<- *DownloadOutput
directory bool // input was queued by calling DownloadDirectory
directoryObjectOutputs chan<- DownloadOutput
checkCRC bool
}

// downloadShard will read a specific object piece into in.Destination.
// If timeout is less than 0, no timeout is set.
func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) {
out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range}

// Set timeout.
ctx := in.ctx
if timeout > 0 {
c, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ctx = c
}

// The first shard will be sent as download many, since we do not know yet
// if it will be sharded.
method := downloadMany
if in.shard != 0 {
method = downloadSharded
}
ctx = setUsageMetricHeader(ctx, method)

// Set options on the object.
o := client.Bucket(in.Bucket).Object(in.Object)

if in.Conditions != nil {
o = o.If(*in.Conditions)
}
if in.Generation != nil {
o = o.Generation(*in.Generation)
}
if len(in.EncryptionKey) > 0 {
o = o.Key(in.EncryptionKey)
}

objRange := shardRange(in.Range, partSize, in.shard)
ctx := in.setOptionsOnContext(timeout)
o := in.setOptionsOnObject(client)

// Read.
r, err := o.NewRangeReader(ctx, objRange.Offset, objRange.Length)
if err != nil {
out.Err = err
Expand All @@ -618,9 +632,63 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim
}
}

w := io.NewOffsetWriter(in.Destination, offset)
_, err = io.Copy(w, r)
var w io.Writer
w = io.NewOffsetWriter(in.Destination, offset)

var crcHash hash.Hash32
if in.checkCRC {
crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli))
w = io.MultiWriter(w, crcHash)
tritone marked this conversation as resolved.
Show resolved Hide resolved
}

n, err := io.Copy(w, r)
if err != nil {
out.Err = err
r.Close()
return
}

if err = r.Close(); err != nil {
out.Err = err
return
}

out.Attrs = &r.Attrs
out.shard = in.shard
out.shardLength = n
if in.checkCRC {
out.crc32c = crcHash.Sum32()
}
return
}

// downloadFirstShard will read the first object piece into in.Destination.
// If timeout is less than 0, no timeout is set.
func (in *DownloadObjectInput) downloadFirstShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) {
out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range}

ctx := in.setOptionsOnContext(timeout)
o := in.setOptionsOnObject(client)

r, err := o.NewReader(ctx)
if err != nil {
out.Err = err
return
}

var w io.Writer
w = io.NewOffsetWriter(in.Destination, 0)

var crcHash hash.Hash32
if in.checkCRC {
crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli))
w = io.MultiWriter(w, crcHash)
}

// Copy only the first partSize bytes before closing the reader.
// If we encounter an EOF, the file was smaller than partSize.
n, err := io.CopyN(w, r, partSize)
if err != nil && err != io.EOF {
out.Err = err
r.Close()
return
Expand All @@ -632,9 +700,45 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim
}

out.Attrs = &r.Attrs
out.shard = in.shard
out.shardLength = n
if in.checkCRC {
out.crc32c = crcHash.Sum32()
}
return
}

// setOptionsOnContext derives the context to use for this shard's download
// from in.ctx, applying the per-operation timeout (if positive) and tagging
// the context with the usage-metric header for this download method.
func (in *DownloadObjectInput) setOptionsOnContext(timeout time.Duration) context.Context {
	ctx := in.ctx
	if timeout > 0 {
		c, cancel := context.WithTimeout(ctx, timeout)
		// BUG(review): this defer runs when setOptionsOnContext returns — not
		// when the download finishes — so the returned context is already
		// cancelled before the caller can use it for a read. The CancelFunc
		// needs to be returned to the caller (signature change) so the caller
		// can defer it for the lifetime of the download.
		defer cancel()
		ctx = c
	}

	// The first shard will be sent as download many, since we do not know yet
	// if it will be sharded.
	method := downloadMany
	if in.shard != 0 {
		method = downloadSharded
	}
	return setUsageMetricHeader(ctx, method)
}

// setOptionsOnObject builds the ObjectHandle for this input's bucket and
// object, applying any preconditions, generation pin, and encryption key
// that were set on in.
func (in *DownloadObjectInput) setOptionsOnObject(client *storage.Client) *storage.ObjectHandle {
	obj := client.Bucket(in.Bucket).Object(in.Object)
	if cond := in.Conditions; cond != nil {
		obj = obj.If(*cond)
	}
	if gen := in.Generation; gen != nil {
		obj = obj.Generation(*gen)
	}
	if key := in.EncryptionKey; len(key) > 0 {
		obj = obj.Key(key)
	}
	return obj
}

// DownloadDirectoryInput is the input for a directory to download.
type DownloadDirectoryInput struct {
// Bucket is the bucket in GCS to download from. Required.
Expand Down Expand Up @@ -686,6 +790,10 @@ type DownloadOutput struct {
Range *DownloadRange // requested range, if it was specified
Err error // error occurring during download
Attrs *storage.ReaderObjectAttrs // attributes of downloaded object, if successful

shard int
shardLength int64
crc32c uint32
}

// TODO: use built-in after go < 1.21 is dropped.
Expand Down Expand Up @@ -784,3 +892,48 @@ func setUsageMetricHeader(ctx context.Context, method string) context.Context {
header := fmt.Sprintf("%s/%s", usageMetricKey, method)
return callctx.SetHeaders(ctx, xGoogHeaderKey, header)
}

type crc32cPiece struct {
sum uint32 // crc32c checksum of the piece
length int64 // number of bytes in this piece
}

// joinCRC32C pieces together the initial checksum with the orderedChecksums
// provided to calculate the checksum of the whole.
func joinCRC32C(initialChecksum uint32, orderedChecksums []crc32cPiece) uint32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems surprising that this isn't available in hash/crc32 or elsewhere in a standard library, but I assume you looked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, there was a feature request for it a while back but they considered it not common enough a use-case to include in the standard library: golang/go#12297

base := initialChecksum

zeroes := make([]byte, maxChecksumZeroArraySize)
for _, part := range orderedChecksums {
// Precondition Base (flip every bit)
base ^= 0xFFFFFFFF

// Zero pad base crc32c. To conserve memory, do so with only maxChecksumZeroArraySize
// at a time. Reuse the zeroes array where possible.
var padded int64 = 0
for padded < part.length {
desiredZeroes := min(part.length-padded, maxChecksumZeroArraySize)
base = crc32.Update(base, crc32.MakeTable(crc32.Castagnoli), zeroes[:desiredZeroes])
padded += desiredZeroes
}

// Postcondition Base (same as precondition, this switches the bits back)
base ^= 0xFFFFFFFF

// Bitwise OR between Base and Part to produce a new Base
base ^= part.sum
}
return base
}

// fullObjectRead reports whether r requests the entire object: either no
// range was specified, or the range starts at offset 0 with a negative
// (read-to-end) length.
func fullObjectRead(r *DownloadRange) bool {
	if r == nil {
		return true
	}
	return r.Offset == 0 && r.Length < 0
}

// checksumObject compares a computed CRC32C (got) against the expected value
// from the object's metadata (want). It returns an error on mismatch, and nil
// when they match or when want is zero (no valid CRC32C to check against).
func checksumObject(got, want uint32) error {
	// A zero want means there is no valid CRC32C; skip the comparison.
	if want == 0 || got == want {
		return nil
	}
	return fmt.Errorf("bad CRC on read: got %d, want %d", got, want)
}
Loading
Loading