Skip to content

Commit

Permalink
Fix race in parallel uploads in file/readat
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmeste committed Oct 25, 2016
1 parent 507c792 commit caf3a80
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 37 deletions.
9 changes: 4 additions & 5 deletions api-put-object-common.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,17 @@ func isReadAt(reader io.Reader) (ok bool) {
}

// shouldUploadPart - verify if part should be uploaded.
func shouldUploadPart(objPart objectPart, objectParts map[int]objectPart) bool {
func shouldUploadPart(objPart objectPart, uploadReq uploadPartReq) bool {
// If part not found should upload the part.
uploadedPart, found := objectParts[objPart.PartNumber]
if !found {
if uploadReq.Part == nil {
return true
}
// if size mismatches should upload the part.
if objPart.Size != uploadedPart.Size {
if objPart.Size != uploadReq.Part.Size {
return true
}
// if md5sum mismatches should upload the part.
if objPart.ETag != uploadedPart.ETag {
if objPart.ETag != uploadReq.Part.ETag {
return true
}
return false
Expand Down
35 changes: 22 additions & 13 deletions api-put-object-file.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,22 +178,27 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe

// Create a channel to communicate which part to upload.
// Buffer this to 10000, the maximum number of parts allowed by S3.
uploadPartsCh := make(chan int, 10000)
uploadPartsCh := make(chan uploadPartReq, 10000)

// Just for readability.
lastPartNumber := totalPartsCount

// Send each part through the partUploadCh to be uploaded.
for p := 1; p <= totalPartsCount; p++ {
uploadPartsCh <- p
part, ok := partsInfo[p]
if ok {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: &part}
} else {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
}
close(uploadPartsCh)

// Use three 'workers' to upload parts in parallel.
for w := 1; w <= 3; w++ {
go func() {
// Deal with each part as it comes through the channel.
for partNumber := range uploadPartsCh {
for uploadReq := range uploadPartsCh {
// Add hash algorithms that need to be calculated by computeHash()
// In case of a non-v4 signature or https connection, sha256 is not needed.
hashAlgos := make(map[string]hash.Hash)
Expand All @@ -206,19 +211,21 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(partNumber-1) * partSize
readOffset := int64(uploadReq.PartNum-1) * partSize
missingPartSize := partSize

// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if partNumber == lastPartNumber {
if uploadReq.PartNum == lastPartNumber {
readOffset = (fileSize - lastPartSize)
missingPartSize = lastPartSize
}

// Get a section reader on a particular offset.
sectionReader := io.NewSectionReader(fileReader, readOffset, missingPartSize)
var prtSize int64
var err error

prtSize, err = computeHash(hashAlgos, hashSums, sectionReader)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Expand All @@ -231,19 +238,20 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
// Create the part to be uploaded.
verifyObjPart := objectPart{
ETag: hex.EncodeToString(hashSums["md5"]),
PartNumber: partNumber,
PartNumber: uploadReq.PartNum,
Size: partSize,
}

// If this is the last part do not give it the full part size.
if partNumber == lastPartNumber {
if uploadReq.PartNum == lastPartNumber {
verifyObjPart.Size = lastPartSize
}

// Verify if part should be uploaded.
if shouldUploadPart(verifyObjPart, partsInfo) {
if shouldUploadPart(verifyObjPart, uploadReq) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, sectionReader, partNumber, hashSums["md5"], hashSums["sha256"], prtSize)
objPart, err = c.uploadPart(bucketName, objectName, uploadID, sectionReader, uploadReq.PartNum, hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Error: err,
Expand All @@ -252,12 +260,13 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
return
}
// Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart
uploadReq.Part = &objPart
}
// Return through the channel the part size.
uploadedPartsCh <- uploadedPartRes{
Size: verifyObjPart.Size,
PartNum: partNumber,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
Expand All @@ -271,8 +280,8 @@ func (c Client) putObjectMultipartFromFile(bucketName, objectName string, fileRe
return totalUploadedSize, uploadRes.Error
}
// Retrieve each uploaded part and store it to be completed.
part, ok := partsInfo[uploadRes.PartNum]
if !ok {
part := uploadRes.Part
if part == nil {
return totalUploadedSize, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the total uploaded size.
Expand Down
6 changes: 4 additions & 2 deletions api-put-object-multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,14 @@ func (c Client) putObjectMultipartStream(bucketName, objectName string, reader i
// as we read from the source.
reader = newHook(tmpBuffer, progress)

part, ok := partsInfo[partNumber]

// Verify if part should be uploaded.
if shouldUploadPart(objectPart{
if ok && shouldUploadPart(objectPart{
ETag: hex.EncodeToString(hashSums["md5"]),
PartNumber: partNumber,
Size: prtSize,
}, partsInfo) {
}, uploadPartReq{PartNum: partNumber, Part: &part}) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, reader, partNumber, hashSums["md5"], hashSums["sha256"], prtSize)
Expand Down
47 changes: 30 additions & 17 deletions api-put-object-readat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,22 @@ type uploadedPartRes struct {
Error error // Any error encountered while uploading the part.
PartNum int // Number of the part uploaded.
Size int64 // Size of the part uploaded.
Part *objectPart
}

type uploadPartReq struct {
PartNum int // Number of the part uploaded.
Part *objectPart // Size of the part uploaded.
}

// shouldUploadPartReadAt - verify if part should be uploaded.
func shouldUploadPartReadAt(objPart objectPart, objectParts map[int]objectPart) bool {
func shouldUploadPartReadAt(objPart objectPart, uploadReq uploadPartReq) bool {
// If part not found part should be uploaded.
uploadedPart, found := objectParts[objPart.PartNumber]
if !found {
if uploadReq.Part == nil {
return true
}
// if size mismatches part should be uploaded.
if uploadedPart.Size != objPart.Size {
if uploadReq.Part.Size != objPart.Size {
return true
}
return false
Expand Down Expand Up @@ -103,7 +108,7 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
// Declare a channel that sends the next part number to be uploaded.
// Buffered to 10000 because thats the maximum number of parts allowed
// by S3.
uploadPartsCh := make(chan int, 10000)
uploadPartsCh := make(chan uploadPartReq, 10000)

// Declare a channel that sends back the response of a part upload.
// Buffered to 10000 because thats the maximum number of parts allowed
Expand All @@ -112,7 +117,12 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read

// Send each part number to the channel to be processed.
for p := 1; p <= totalPartsCount; p++ {
uploadPartsCh <- p
part, ok := partsInfo[p]
if ok {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: &part}
} else {
uploadPartsCh <- uploadPartReq{PartNum: p, Part: nil}
}
}
close(uploadPartsCh)

Expand All @@ -123,19 +133,19 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
readAtBuffer := make([]byte, optimalReadBufferSize)

// Each worker will draw from the part channel and upload in parallel.
for partNumber := range uploadPartsCh {
for uploadReq := range uploadPartsCh {
// Declare a new tmpBuffer.
tmpBuffer := new(bytes.Buffer)

// If partNumber was not uploaded we calculate the missing
// part offset and size. For all other part numbers we
// calculate offset based on multiples of partSize.
readOffset := int64(partNumber-1) * partSize
readOffset := int64(uploadReq.PartNum-1) * partSize
missingPartSize := partSize

// As a special case if partNumber is lastPartNumber, we
// calculate the offset based on the last part size.
if partNumber == lastPartNumber {
if uploadReq.PartNum == lastPartNumber {
readOffset = (size - lastPartSize)
missingPartSize = lastPartSize
}
Expand All @@ -153,6 +163,7 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
}

var prtSize int64
var err error
prtSize, err = hashCopyBuffer(hashAlgos, hashSums, tmpBuffer, sectionReader, readAtBuffer)
if err != nil {
// Send the error back through the channel.
Expand All @@ -166,21 +177,21 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read

// Verify object if its uploaded.
verifyObjPart := objectPart{
PartNumber: partNumber,
PartNumber: uploadReq.PartNum,
Size: partSize,
}
// Special case if we see a last part number, save last part
// size as the proper part size.
if partNumber == lastPartNumber {
if uploadReq.PartNum == lastPartNumber {
verifyObjPart.Size = lastPartSize
}

// Only upload the necessary parts. Otherwise return size through channel
// to update any progress bar.
if shouldUploadPartReadAt(verifyObjPart, partsInfo) {
if shouldUploadPartReadAt(verifyObjPart, uploadReq) {
// Proceed to upload the part.
var objPart objectPart
objPart, err = c.uploadPart(bucketName, objectName, uploadID, tmpBuffer, partNumber, hashSums["md5"], hashSums["sha256"], prtSize)
objPart, err = c.uploadPart(bucketName, objectName, uploadID, tmpBuffer, uploadReq.PartNum, hashSums["md5"], hashSums["sha256"], prtSize)
if err != nil {
uploadedPartsCh <- uploadedPartRes{
Size: 0,
Expand All @@ -190,12 +201,13 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
return
}
// Save successfully uploaded part metadata.
partsInfo[partNumber] = objPart
uploadReq.Part = &objPart
}
// Send successful part info through the channel.
uploadedPartsCh <- uploadedPartRes{
Size: verifyObjPart.Size,
PartNum: partNumber,
PartNum: uploadReq.PartNum,
Part: uploadReq.Part,
Error: nil,
}
}
Expand All @@ -210,8 +222,9 @@ func (c Client) putObjectMultipartFromReadAt(bucketName, objectName string, read
return totalUploadedSize, uploadRes.Error
}
// Retrieve each uploaded part and store it to be completed.
part, ok := partsInfo[uploadRes.PartNum]
if !ok {
// part, ok := partsInfo[uploadRes.PartNum]
part := uploadRes.Part
if part == nil {
return 0, ErrInvalidArgument(fmt.Sprintf("Missing part number %d", uploadRes.PartNum))
}
// Update the totalUploadedSize.
Expand Down

0 comments on commit caf3a80

Please sign in to comment.