Addressing issue #419 (#510)

Merged: 14 commits, Jan 18, 2016
163 changes: 98 additions & 65 deletions service/s3/s3manager/download.go
@@ -3,11 +3,12 @@ package s3manager
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
@@ -154,39 +155,48 @@ func (d *downloader) init() {
func (d *downloader) download() (n int64, err error) {
d.init()

// Spin up workers
ch := make(chan dlchunk, d.ctx.Concurrency)
for i := 0; i < d.ctx.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}
// Spin off first worker to check additional header information
d.getChunk()

if total := d.getTotalBytes(); total >= 0 {
// Spin up workers
ch := make(chan dlchunk, d.ctx.Concurrency)

for i := 0; i < d.ctx.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}

// Assign work
for d.geterr() == nil {
if d.pos != 0 {
// This is not the first chunk, let's wait until we know the total
// size of the payload so we can see if we have read the entire
// object.
total := d.getTotalBytes()

if total < 0 {
// Total has not yet been set, so sleep and loop around while
// waiting for our first worker to resolve this value.
time.Sleep(10 * time.Millisecond)
continue
} else if d.pos >= total {
// Assign work
for d.getErr() == nil {
if d.pos >= total {
break // We're finished queueing chunks
}

// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
d.pos += d.ctx.PartSize
}

// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
d.pos += d.ctx.PartSize
}
// Wait for completion
close(ch)
d.wg.Wait()
} else {
// Content length is unknown; keep grabbing chunks until an error occurs.
for d.err == nil {
d.getChunk()
}

// Wait for completion
close(ch)
d.wg.Wait()
// We expect a 416 error letting us know we are done downloading the
// total bytes. Since we do not know the content's length, this will
// keep grabbing chunks of data until the range of bytes specified in
// the request is out of range of the content. Once this happens, a
// 416 should occur.

[Review comment from a contributor]: could you expand this comment stating why 416 might be received?
e, ok := d.err.(awserr.RequestFailure)
if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
d.err = nil
}
}

// Return error
return d.written, d.err
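For readers following the reviewer's question: an HTTP server responds 416 (Requested Range Not Satisfiable) when the Range header asks for bytes entirely past the end of the content. Since the object's length is unknown here, the loop keeps requesting the next PartSize range until the server answers 416, which is then swallowed as a normal end-of-object signal. A minimal standalone sketch of that behavior (hypothetical server, not part of this PR):

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"
	"net/http/httptest"
	"time"
)

func main() {
	// http.ServeContent implements Range semantics, including a 416
	// response when the requested range lies entirely past the end
	// of the content.
	content := []byte("0123456789") // a 10-byte object
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.ServeContent(w, r, "obj", time.Time{}, bytes.NewReader(content))
	}))
	defer srv.Close()

	req, err := http.NewRequest("GET", srv.URL, nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Range", "bytes=10-19") // starts at EOF: unsatisfiable
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Prints true; the downloader treats this status as "all chunks read".
	fmt.Println(resp.StatusCode == http.StatusRequestedRangeNotSatisfiable)
}
```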
@@ -199,39 +209,51 @@ func (d *downloader) download() (n int64, err error) {
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()

for {
chunk, ok := <-ch
if !ok {
break
}
d.downloadChunk(chunk)
}
}
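The worker shutdown above relies on Go's close-to-drain channel idiom: close(ch) in download() makes each receive return ok == false once the queue is empty, so every worker exits and d.wg.Wait() unblocks. A minimal standalone sketch of the same pattern (hypothetical names, not SDK code):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	ch := make(chan int, 4)
	var wg sync.WaitGroup

	// Workers exit when the channel is closed and drained,
	// mirroring downloadPart's `chunk, ok := <-ch` loop.
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for job := range ch { // equivalent to checking ok by hand
				fmt.Println("worker", id, "got job", job)
			}
		}(i)
	}

	for j := 0; j < 8; j++ {
		ch <- j
	}
	close(ch) // signal no more work
	wg.Wait() // wait for workers to finish, as download() does
}
```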

// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only be used when grabbing data on a single thread.
func (d *downloader) getChunk() {
[Review comment from a contributor]: Might make sense to move this logic into the download function. Being its own method accessing data like this could introduce data races if this method is called somewhere else unexpectedly.

chunk := dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
d.pos += d.ctx.PartSize
d.downloadChunk(chunk)
}
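On the reviewer's data-race concern: getChunk reads and advances d.pos without taking d.m, so it is only safe while exactly one goroutine drives it. A minimal standalone sketch (hypothetical names, not SDK code) of the duplicate-range race that concurrent callers could trigger; `go run -race` flags it:

```go
package main

import (
	"fmt"
	"sync"
)

// cursor mirrors downloader.pos: a position advanced without a lock.
type cursor struct {
	pos  int64
	size int64
}

// next mimics getChunk's unsynchronized read-then-advance. If two
// goroutines call it at once, both may observe the same pos and
// fetch the same byte range twice (a data race).
func (c *cursor) next() (start, end int64) {
	start = c.pos   // unsynchronized read
	c.pos += c.size // unsynchronized write
	return start, start + c.size - 1
}

func main() {
	c := &cursor{size: 5 * 1024 * 1024}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s, e := c.next() // racy if called from more than one goroutine
			fmt.Printf("bytes=%d-%d\n", s, e)
		}()
	}
	wg.Wait()
}
```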

if d.geterr() == nil {
// Get the next byte range of data
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
rng := fmt.Sprintf("bytes=%d-%d",
chunk.start, chunk.start+chunk.size-1)
in.Range = &rng
// downloadChunk downloads the chunk from S3
func (d *downloader) downloadChunk(chunk dlchunk) {
if d.getErr() != nil {
return
}
// Get the next byte range of data
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
rng := fmt.Sprintf("bytes=%d-%d",
chunk.start, chunk.start+chunk.size-1)
in.Range = &rng

req, resp := d.ctx.S3.GetObjectRequest(in)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
err := req.Send()
req, resp := d.ctx.S3.GetObjectRequest(in)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
err := req.Send()

if err != nil {
d.seterr(err)
} else {
d.setTotalBytes(resp) // Set total if not yet set.
if err != nil {
d.setErr(err)
} else {
d.setTotalBytes(resp) // Set total if not yet set.

n, err := io.Copy(&chunk, resp.Body)
resp.Body.Close()
n, err := io.Copy(&chunk, resp.Body)
resp.Body.Close()

if err != nil {
d.seterr(err)
}
d.incrwritten(n)
}
if err != nil {
d.setErr(err)
}
d.incrWritten(n)
}
}

@@ -259,37 +281,48 @@ func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
if resp.ContentRange == nil {
// ContentRange is nil when the full file contents are provided
// and not chunked. Use ContentLength instead.
d.totalBytes = *resp.ContentLength
return
}
if resp.ContentLength != nil {
d.totalBytes = *resp.ContentLength
return
}
} else {
parts := strings.Split(*resp.ContentRange, "/")

total := int64(-1)
var err error
// Check whether a numeric total exists. If it does not, assume the
// total is -1 (undefined) and sequentially download each chunk until
// hitting a 416 error.
totalStr := parts[len(parts)-1]
if totalStr != "*" {
total, err = strconv.ParseInt(totalStr, 10, 64)
if err != nil {
d.err = err
return
}
}

parts := strings.Split(*resp.ContentRange, "/")
total, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
if err != nil {
d.err = err
return
d.totalBytes = total
}

d.totalBytes = total
}
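For reference, the Content-Range header allows both a numeric complete-length and a `*` placeholder, which is why the parsing above special-cases the `*` form. A minimal sketch mirroring that logic (standalone, hypothetical helper name):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseTotal mirrors the ContentRange handling above: a header of the
// form "bytes 0-5242879/12345678" yields the object size, while
// "bytes 0-5242879/*" leaves the total undefined (-1).
func parseTotal(contentRange string) (int64, error) {
	parts := strings.Split(contentRange, "/")
	totalStr := parts[len(parts)-1]
	if totalStr == "*" {
		return -1, nil // size unknown; caller falls back to probing for 416
	}
	return strconv.ParseInt(totalStr, 10, 64)
}

func main() {
	for _, h := range []string{"bytes 0-5242879/12345678", "bytes 0-5242879/*"} {
		total, err := parseTotal(h)
		fmt.Println(h, "=>", total, err)
	}
}
```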

func (d *downloader) incrwritten(n int64) {
func (d *downloader) incrWritten(n int64) {
d.m.Lock()
defer d.m.Unlock()

d.written += n
}

// geterr is a thread-safe getter for the error object
func (d *downloader) geterr() error {
// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
d.m.Lock()
defer d.m.Unlock()

return d.err
}

// seterr is a thread-safe setter for the error object
func (d *downloader) seterr(e error) {
// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
d.m.Lock()
defer d.m.Unlock()

119 changes: 118 additions & 1 deletion service/s3/s3manager/download_test.go
@@ -50,7 +50,7 @@ func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
start, fin, len(data)))
start, fin-1, len(data)))
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
})

@@ -80,6 +80,77 @@ func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) {
return svc, &names
}

func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
index := 0

svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()

names = append(names, r.Operation.Name)

r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(data[:])),
Header: http.Header{},
}
index++
})

return svc, &names
}

func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
index := 0

svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()

names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)

rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++

if fin >= int64(len(data)) {
fin = int64(len(data))
}

// Set start and finish to 0 because the last state is supposed to
// be an error state of 416
if index == len(states)-1 {
start = 0
fin = 0
}

bodyBytes := data[start:fin]

r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*",
start, fin-1))
index++
})

return svc, &names
}

func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)

@@ -190,3 +261,49 @@ func TestDownloadNonChunk(t *testing.T) {
}
assert.Equal(t, 0, count)
}

func TestDownloadNoContentRangeLength(t *testing.T) {
s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416})

d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})

assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)

count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}

func TestDownloadContentRangeTotalAny(t *testing.T) {
s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416})

d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})

assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)

count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}