
Commit

Merge pull request #510 from xibz/master
Addressing issue #419
xibz committed Jan 18, 2016
2 parents 3a0be65 + dd0b428 commit c5550b8
Showing 2 changed files with 216 additions and 66 deletions.
163 changes: 98 additions & 65 deletions service/s3/s3manager/download.go
@@ -3,11 +3,12 @@ package s3manager
 import (
 	"fmt"
 	"io"
+	"net/http"
 	"strconv"
 	"strings"
 	"sync"
-	"time"
 
+	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/aws/awsutil"
 	"github.com/aws/aws-sdk-go/aws/client"
 	"github.com/aws/aws-sdk-go/aws/request"
@@ -154,39 +155,48 @@ func (d *downloader) init() {
 func (d *downloader) download() (n int64, err error) {
 	d.init()
 
-	// Spin up workers
-	ch := make(chan dlchunk, d.ctx.Concurrency)
-	for i := 0; i < d.ctx.Concurrency; i++ {
-		d.wg.Add(1)
-		go d.downloadPart(ch)
-	}
-
-	// Assign work
-	for d.geterr() == nil {
-		if d.pos != 0 {
-			// This is not the first chunk, let's wait until we know the total
-			// size of the payload so we can see if we have read the entire
-			// object.
-			total := d.getTotalBytes()
-
-			if total < 0 {
-				// Total has not yet been set, so sleep and loop around while
-				// waiting for our first worker to resolve this value.
-				time.Sleep(10 * time.Millisecond)
-				continue
-			} else if d.pos >= total {
-				break // We're finished queueing chunks
-			}
-		}
-
-		// Queue the next range of bytes to read.
-		ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
-		d.pos += d.ctx.PartSize
-	}
-
-	// Wait for completion
-	close(ch)
-	d.wg.Wait()
+	// Spin off first worker to check additional header information
+	d.getChunk()
+
+	if total := d.getTotalBytes(); total >= 0 {
+		// Spin up workers
+		ch := make(chan dlchunk, d.ctx.Concurrency)
+
+		for i := 0; i < d.ctx.Concurrency; i++ {
+			d.wg.Add(1)
+			go d.downloadPart(ch)
+		}
+
+		// Assign work
+		for d.getErr() == nil {
+			if d.pos >= total {
+				break // We're finished queueing chunks
+			}
+
+			// Queue the next range of bytes to read.
+			ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
+			d.pos += d.ctx.PartSize
+		}
+
+		// Wait for completion
+		close(ch)
+		d.wg.Wait()
+	} else {
+		// Checking if we read anything new
+		for d.err == nil {
+			d.getChunk()
+		}
+
+		// We expect a 416 error letting us know we are done downloading the
+		// total bytes. Since we do not know the content's length, this will
+		// keep grabbing chunks of data until the range of bytes specified in
+		// the request is out of range of the content. Once this happens, a
+		// 416 should occur.
+		e, ok := d.err.(awserr.RequestFailure)
+		if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
+			d.err = nil
+		}
+	}
 
 	// Return error
 	return d.written, d.err
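The known-total branch above is a classic bounded worker pool: a buffered channel of chunk descriptors drained by Concurrency goroutines. A minimal standalone sketch of the same pattern, with generic names and sizes that are not part of this commit:

```go
package main

import (
	"fmt"
	"sync"
)

type chunk struct{ start, size int64 }

func main() {
	const total, partSize, concurrency = 100, 30, 3

	ch := make(chan chunk, concurrency)
	var wg sync.WaitGroup

	// Spin up workers, as download() does once the total size is known.
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for c := range ch {
				fmt.Printf("worker %d: bytes %d-%d\n", id, c.start, c.start+c.size-1)
			}
		}(i)
	}

	// Assign work: queue byte ranges until the position passes the total.
	for pos := int64(0); pos < total; pos += partSize {
		ch <- chunk{start: pos, size: partSize}
	}

	// Wait for completion.
	close(ch)
	wg.Wait()
}
```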
@@ -199,39 +209,51 @@ func (d *downloader) download() (n int64, err error) {
 // of bytes to be read so that the worker manager knows when it is finished.
 func (d *downloader) downloadPart(ch chan dlchunk) {
 	defer d.wg.Done()
+
 	for {
 		chunk, ok := <-ch
 		if !ok {
 			break
 		}
+		d.downloadChunk(chunk)
+	}
+}
 
-		if d.geterr() == nil {
-			// Get the next byte range of data
-			in := &s3.GetObjectInput{}
-			awsutil.Copy(in, d.in)
-			rng := fmt.Sprintf("bytes=%d-%d",
-				chunk.start, chunk.start+chunk.size-1)
-			in.Range = &rng
+// getChunk grabs a chunk of data from the body.
+// Not thread safe. Should only be used when grabbing data on a single thread.
+func (d *downloader) getChunk() {
+	chunk := dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
+	d.pos += d.ctx.PartSize
+	d.downloadChunk(chunk)
+}
 
-			req, resp := d.ctx.S3.GetObjectRequest(in)
-			req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
-			err := req.Send()
+// downloadChunk downloads the chunk from S3
+func (d *downloader) downloadChunk(chunk dlchunk) {
+	if d.getErr() != nil {
+		return
+	}
+	// Get the next byte range of data
+	in := &s3.GetObjectInput{}
+	awsutil.Copy(in, d.in)
+	rng := fmt.Sprintf("bytes=%d-%d",
+		chunk.start, chunk.start+chunk.size-1)
+	in.Range = &rng
 
-			if err != nil {
-				d.seterr(err)
-			} else {
-				d.setTotalBytes(resp) // Set total if not yet set.
+	req, resp := d.ctx.S3.GetObjectRequest(in)
+	req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
+	err := req.Send()
 
-				n, err := io.Copy(&chunk, resp.Body)
-				resp.Body.Close()
+	if err != nil {
+		d.setErr(err)
+	} else {
+		d.setTotalBytes(resp) // Set total if not yet set.
 
-				if err != nil {
-					d.seterr(err)
-				}
-				d.incrwritten(n)
-			}
-		}
+		n, err := io.Copy(&chunk, resp.Body)
+		resp.Body.Close()
+
+		if err != nil {
+			d.setErr(err)
+		}
+		d.incrWritten(n)
 	}
 }

@@ -259,37 +281,48 @@ func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
 	if resp.ContentRange == nil {
 		// ContentRange is nil when the full file contents is provided, and
 		// is not chunked. Use ContentLength instead.
-		d.totalBytes = *resp.ContentLength
-		return
-	}
+		if resp.ContentLength != nil {
+			d.totalBytes = *resp.ContentLength
+			return
+		}
+	} else {
+		parts := strings.Split(*resp.ContentRange, "/")
 
-	parts := strings.Split(*resp.ContentRange, "/")
-	total, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
-	if err != nil {
-		d.err = err
-		return
+		total := int64(-1)
+		var err error
+		// Checking for whether or not a numbered total exists
+		// If one does not exist, we will assume the total to be -1, undefined,
+		// and sequentially download each chunk until hitting a 416 error
+		totalStr := parts[len(parts)-1]
+		if totalStr != "*" {
+			total, err = strconv.ParseInt(totalStr, 10, 64)
+			if err != nil {
+				d.err = err
+				return
+			}
+		}
+
+		d.totalBytes = total
 	}
-
-	d.totalBytes = total
 }

-func (d *downloader) incrwritten(n int64) {
+func (d *downloader) incrWritten(n int64) {
 	d.m.Lock()
 	defer d.m.Unlock()
 
 	d.written += n
 }
 
-// geterr is a thread-safe getter for the error object
-func (d *downloader) geterr() error {
+// getErr is a thread-safe getter for the error object
+func (d *downloader) getErr() error {
 	d.m.Lock()
 	defer d.m.Unlock()
 
 	return d.err
 }
 
-// seterr is a thread-safe setter for the error object
-func (d *downloader) seterr(e error) {
+// setErr is a thread-safe setter for the error object
+func (d *downloader) setErr(e error) {
 	d.m.Lock()
 	defer d.m.Unlock()
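The unknown-length path above amounts to issuing ranged GETs until the server answers 416 Requested Range Not Satisfiable. A rough standalone sketch of that strategy against plain net/http follows; fetchRange and the URL are hypothetical, and the real code goes through the SDK's request pipeline rather than http.DefaultClient:

```go
package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
)

// fetchRange asks for [start, start+size-1] and reports a 416 response as
// "past end of object", mirroring how download() interprets that status.
func fetchRange(url string, start, size int64) (data []byte, done bool, err error) {
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, false, err
	}
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, start+size-1))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, false, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
		return nil, true, nil // the expected "we are done" signal
	}

	data, err = ioutil.ReadAll(resp.Body)
	return data, false, err
}

func main() {
	var out []byte
	for pos := int64(0); ; {
		b, done, err := fetchRange("http://example.com/object", pos, 5*1024*1024)
		if err != nil {
			panic(err)
		}
		if done {
			break // 416: the requested range starts past the end of the object
		}
		out = append(out, b...)
		pos += int64(len(b))
	}
	fmt.Println("downloaded", len(out), "bytes")
}
```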
119 changes: 118 additions & 1 deletion service/s3/s3manager/download_test.go
@@ -50,7 +50,7 @@ func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
 			Header:     http.Header{},
 		}
 		r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
-			start, fin, len(data)))
+			start, fin-1, len(data)))
 		r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
 	})

@@ -80,6 +80,77 @@ func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) {
return svc, &names
}

func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
var index int = 0

svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()

names = append(names, r.Operation.Name)

r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(data[:])),
Header: http.Header{},
}
index++
})

return svc, &names
}

func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
var index int = 0

svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()

names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)

rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++

if fin >= int64(len(data)) {
fin = int64(len(data))
}

		// Setting start and finish to 0 because this final state is supposed
		// to be an error state of 416
if index == len(states)-1 {
start = 0
fin = 0
}

bodyBytes := data[start:fin]

r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*",
start, fin-1))
index++
})

return svc, &names
}

func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)

@@ -190,3 +261,49 @@ func TestDownloadNonChunk(t *testing.T) {
}
assert.Equal(t, 0, count)
}

func TestDownloadNoContentRangeLength(t *testing.T) {
s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416})

d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})

assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)

count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}

func TestDownloadContentRangeTotalAny(t *testing.T) {
s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416})

d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})

assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)

count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
}
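These tests exercise the Content-Range grammar, where the total after the slash is either a byte count or `*` when the length is unknown. A small self-contained sketch of the parse that setTotalBytes now performs; the -1 convention for "unknown" mirrors the code above:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseTotal returns the total object size from a Content-Range header
// value, or -1 when the server reports the total as "*" (unknown).
func parseTotal(contentRange string) (int64, error) {
	parts := strings.Split(contentRange, "/")
	totalStr := parts[len(parts)-1]
	if totalStr == "*" {
		return -1, nil // unknown; caller falls back to chunk-until-416
	}
	return strconv.ParseInt(totalStr, 10, 64)
}

func main() {
	for _, h := range []string{"bytes 0-5242879/12582912", "bytes 0-5242879/*"} {
		total, err := parseTotal(h)
		fmt.Println(h, "=>", total, err)
	}
}
```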
