perf: multi-thread downloader, Content-Disposition (#4921)
general: enhance the multi-thread downloader with a cancelable context; all stream processing stops immediately when a download is canceled;
feat(crypt): improve stream closing;
general: fix a bug where, on modern browsers, downloading a file opened an inline preview stream instead of saving the file;

Co-authored-by: Sean He <866155+seanhe26@users.noreply.github.com>
Co-authored-by: Andy Hsu <i@nn.ci>
3 people authored Aug 4, 2023
1 parent 861948b commit 15b7169
Showing 7 changed files with 58 additions and 34 deletions.
18 changes: 0 additions & 18 deletions drivers/base/util.go
@@ -1,19 +1 @@
package base

import "io"

type Closers struct {
closers []io.Closer
}

func (c *Closers) Close() (err error) {
for _, closer := range c.closers {
if closer != nil {
_ = closer.Close()
}
}
return nil
}
func (c *Closers) Add(closer io.Closer) {
c.closers = append(c.closers, closer)
}
8 changes: 5 additions & 3 deletions drivers/crypt/driver.go
@@ -236,7 +236,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
return nil, fmt.Errorf("the remote storage driver need to be enhanced to support encrytion")
}
remoteFileSize := remoteFile.GetSize()
var remoteCloser io.Closer
remoteClosers := utils.NewClosers()
rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) {
length := underlyingLength
if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize {
@@ -245,6 +245,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
if remoteLink.RangeReadCloser.RangeReader != nil {
//remoteRangeReader, err :=
remoteReader, err := remoteLink.RangeReadCloser.RangeReader(http_range.Range{Start: underlyingOffset, Length: length})
remoteClosers.Add(remoteLink.RangeReadCloser.Closers)
if err != nil {
return nil, err
}
@@ -255,8 +256,8 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
if err != nil {
return nil, err
}
//remoteClosers.Add(remoteLink.ReadSeekCloser)
//keep reusing the same ReadSeekCloser and close it once at the end.
remoteCloser = remoteLink.ReadSeekCloser
return io.NopCloser(remoteLink.ReadSeekCloser), nil
}
if len(remoteLink.URL) > 0 {
@@ -265,6 +266,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
Header: remoteLink.Header,
}
response, err := RequestRangedHttp(args.HttpReq, rangedRemoteLink, underlyingOffset, length)
//remoteClosers.Add(response.Body)
if err != nil {
return nil, fmt.Errorf("remote storage http request failure,status: %d err:%s", response.StatusCode, err)
}
@@ -301,7 +303,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (
return readSeeker, nil
}

resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closer: remoteCloser}
resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: remoteClosers}
resultLink := &model.Link{
Header: remoteLink.Header,
RangeReadCloser: *resultRangeReadCloser,
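The crypt change hinges on one pattern: a single remote ReadSeekCloser is reused across range reads, so each read hands out an io.NopCloser wrapper while the real handle goes into remoteClosers and is closed only when the link is torn down. A minimal standalone sketch of that pattern (the reusable type here is illustrative, not from the repo):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// reusable stands in for the remote ReadSeekCloser the crypt driver keeps
// open across multiple range reads (illustrative type, not from the repo).
type reusable struct {
	*strings.Reader
}

func (r *reusable) Close() error {
	fmt.Println("closed once, at the end")
	return nil
}

func main() {
	rsc := &reusable{strings.NewReader("encrypted payload")}

	// Hand each range read a NopCloser view: callers can Close it freely
	// without tearing down the underlying seekable stream.
	view := io.NopCloser(rsc)
	buf := make([]byte, 9)
	n, _ := view.Read(buf)
	fmt.Printf("read %d bytes\n", n)
	_ = view.Close() // no-op: rsc stays usable

	_, _ = rsc.Seek(0, io.SeekStart) // still seekable for the next range
	_ = rsc.Close()                  // the real Close happens exactly once
}
```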
5 changes: 2 additions & 3 deletions drivers/mega/driver.go
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/rclone/rclone/lib/readers"
"io"
@@ -75,7 +74,7 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
//}

size := file.GetSize()
var finalClosers base.Closers
var finalClosers utils.Closers
resultRangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) {
length := httpRange.Length
if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
@@ -98,7 +97,7 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*

return readers.NewLimitedReadCloser(oo, length), nil
}
resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closer: &finalClosers}
resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: &finalClosers}
resultLink := &model.Link{
RangeReadCloser: *resultRangeReadCloser,
}
3 changes: 2 additions & 1 deletion internal/model/args.go
@@ -2,6 +2,7 @@ package model

import (
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils"
"io"
"net/http"
"time"
@@ -45,7 +46,7 @@ type FsOtherArgs struct {
}
type RangeReadCloser struct {
RangeReader RangeReaderFunc
Closer io.Closer
Closers *utils.Closers
}

type WriterFunc func(w io.Writer) error
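With this change a driver registers every per-range stream in one utils.Closers and hands it to the link, so the proxy can release everything with a single Close. A sketch of the intended wiring, assuming it compiles inside the alist module (internal/model is not importable from outside); the rangeReader body is illustrative:

```go
package main

import (
	"io"
	"strings"

	"github.com/alist-org/alist/v3/internal/model"
	"github.com/alist-org/alist/v3/pkg/http_range"
	"github.com/alist-org/alist/v3/pkg/utils"
)

func main() {
	closers := utils.NewClosers()
	rrc := &model.RangeReadCloser{
		RangeReader: func(r http_range.Range) (io.ReadCloser, error) {
			// Illustrative: a real driver would open a remote stream for
			// this range and register it for deferred teardown.
			rc := io.NopCloser(strings.NewReader("chunk"))
			closers.Add(rc)
			return rc, nil
		},
		Closers: closers,
	}
	defer rrc.Closers.Close() // one Close releases every registered stream

	_, _ = rrc.RangeReader(http_range.Range{Start: 0, Length: 5})
}
```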
20 changes: 15 additions & 5 deletions internal/net/request.go
@@ -86,8 +86,9 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo

// downloader is the implementation structure used internally by Downloader.
type downloader struct {
ctx context.Context
cfg Downloader
ctx context.Context
cancel context.CancelFunc
cfg Downloader

params *HttpRequestParams //http request params
chunkChannel chan chunk //chunk channel
@@ -107,6 +108,7 @@ type downloader struct {

// download performs the implementation of the object download across ranged GETs.
func (d *downloader) download() (*io.ReadCloser, error) {
d.ctx, d.cancel = context.WithCancel(d.ctx)

pos := d.params.Range.Start
maxPos := d.params.Range.Start + d.params.Range.Length
@@ -138,7 +140,7 @@ func (d *downloader) download() (*io.ReadCloser, error) {
d.chunkChannel = make(chan chunk, d.cfg.Concurrency)

for i := 0; i < d.cfg.Concurrency; i++ {
buf := NewBuf(d.cfg.PartSize, i)
buf := NewBuf(d.ctx, d.cfg.PartSize, i)
d.bufs = append(d.bufs, buf)
go d.downloadPart()
}
@@ -163,6 +165,7 @@ func (d *downloader) sendChunkTask() *chunk {

// when the final reader is closed, we interrupt
func (d *downloader) interrupt() error {
d.cancel()
if d.written != d.params.Range.Length {
log.Debugf("Downloader interrupt before finish")
if d.getErr() == nil {
@@ -520,15 +523,16 @@ func (buf *Buffer) waitTillNewWrite(pos int) error {
type Buf struct {
buffer *Buffer // Buffer we read from
size int //expected size
ctx context.Context
}

// NewBuf returns a buffer that allows one concurrent read and one concurrent write.
// When reads outpace writes, data is handed to the reader as soon as it is written.
func NewBuf(maxSize int, id int) *Buf {
func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
d := make([]byte, maxSize)
buffer := &Buffer{data: d, id: id, notify: make(chan int)}
buffer.reset()
return &Buf{buffer: buffer, size: maxSize}
return &Buf{ctx: ctx, buffer: buffer, size: maxSize}

}
func (br *Buf) Reset(size int) {
@@ -540,6 +544,9 @@ func (br *Buf) GetId() int {
}

func (br *Buf) Read(p []byte) (n int, err error) {
if err := br.ctx.Err(); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
@@ -580,6 +587,9 @@ func (br *Buf) waitTillNewWrite(pos int) error {
}

func (br *Buf) Write(p []byte) (n int, err error) {
if err := br.ctx.Err(); err != nil {
return 0, err
}
return br.buffer.Write(p)
}
func (br *Buf) Close() {
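The effect of the new ctx checks, distilled into a standalone sketch. Since internal/net cannot be imported from outside the module, ctxBuf here mirrors the pattern added to net.Buf rather than using it directly: once the downloader's cancel fires, every pending Read and Write fails fast with context.Canceled.

```go
package main

import (
	"bytes"
	"context"
	"fmt"
)

// ctxBuf mirrors the pattern this commit adds to net.Buf: every Read/Write
// first checks the context, so canceling stops all stream processing at once.
type ctxBuf struct {
	ctx context.Context
	buf bytes.Buffer
}

func (b *ctxBuf) Read(p []byte) (int, error) {
	if err := b.ctx.Err(); err != nil {
		return 0, err // canceled or deadline exceeded
	}
	return b.buf.Read(p)
}

func (b *ctxBuf) Write(p []byte) (int, error) {
	if err := b.ctx.Err(); err != nil {
		return 0, err
	}
	return b.buf.Write(p)
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	b := &ctxBuf{ctx: ctx}
	_, _ = b.Write([]byte("part 1"))

	cancel() // what downloader.interrupt() now does for every Buf it owns
	if _, err := b.Read(make([]byte, 6)); err != nil {
		fmt.Println("read stopped:", err) // read stopped: context canceled
	}
}
```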
21 changes: 21 additions & 0 deletions pkg/utils/io.go
@@ -164,3 +164,24 @@ func Retry(attempts int, sleep time.Duration, f func() error) (err error) {
}
return fmt.Errorf("after %d attempts, last error: %s", attempts, err)
}

type Closers struct {
closers []*io.Closer
}

func (c *Closers) Close() (err error) {
for _, closer := range c.closers {
if closer != nil {
_ = (*closer).Close()
}
}
return nil
}
func (c *Closers) Add(closer io.Closer) {
if closer != nil {
c.closers = append(c.closers, &closer)
}
}
func NewClosers() *Closers {
return &Closers{[]*io.Closer{}}
}
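A usage sketch of the new helper, assuming code running inside the alist module (pkg/utils is importable there). Two design shifts from the removed base.Closers are worth noting: Add rejects nil up front, and the slice stores pointers to the interface copies that Add receives.

```go
package main

import (
	"io"
	"strings"

	"github.com/alist-org/alist/v3/pkg/utils"
)

func main() {
	cs := utils.NewClosers()
	cs.Add(io.NopCloser(strings.NewReader("stream A")))
	cs.Add(io.NopCloser(strings.NewReader("stream B")))
	cs.Add(nil) // rejected at Add time; the old base.Closers only skipped nils at Close
	defer cs.Close() // closes every registered stream, ignoring per-closer errors
}
```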
17 changes: 13 additions & 4 deletions server/common/proxy.go
@@ -7,6 +7,7 @@ import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/net"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/pkg/errors"
"io"
"net/http"
@@ -33,22 +34,24 @@ var httpClient *http.Client

func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error {
if link.ReadSeekCloser != nil {
filename := file.GetName()
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, filename, url.PathEscape(filename)))
attachFileName(w, file)
http.ServeContent(w, r, file.GetName(), file.ModTime(), link.ReadSeekCloser)
defer link.ReadSeekCloser.Close()
return nil
} else if link.RangeReadCloser.RangeReader != nil {
attachFileName(w, file)
net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser.RangeReader)
defer func() {
if link.RangeReadCloser.Closer != nil {
link.RangeReadCloser.Closer.Close()
if link.RangeReadCloser.Closers != nil {
link.RangeReadCloser.Closers.Close()
}
}()
return nil
} else if link.Concurrency != 0 || link.PartSize != 0 {
attachFileName(w, file)
size := file.GetSize()
//var finalClosers model.Closers
finalClosers := utils.NewClosers()
header := net.ProcessHeader(&r.Header, &link.Header)
rangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) {
down := net.NewDownloader(func(d *net.Downloader) {
@@ -62,9 +65,11 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
HeaderRef: header,
}
rc, err := down.Download(context.Background(), req)
finalClosers.Add(*rc)
return *rc, err
}
net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader)
defer finalClosers.Close()
return nil
} else {
//transparent proxy
@@ -89,3 +94,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.
return nil
}
}
func attachFileName(w http.ResponseWriter, file model.Obj) {
fileName := file.GetName()
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, fileName, url.PathEscape(fileName)))
}
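This helper carries the Content-Disposition part of the commit title: proxied responses are now explicitly marked as attachments, so modern browsers save the file instead of previewing it inline. A standalone sketch of the same header in a bare net/http handler (the route, filename, and payload are made up):

```go
package main

import (
	"fmt"
	"net/http"
	"net/url"
)

// Without "attachment", modern browsers render previewable types (PDF,
// video, images) inline instead of downloading. The header shape matches
// the new attachFileName helper.
func main() {
	http.HandleFunc("/d/report.pdf", func(w http.ResponseWriter, r *http.Request) {
		name := "отчёт.pdf" // non-ASCII name: this is what filename* (RFC 5987) is for
		// The quoted filename is a legacy fallback; filename* carries the UTF-8 form.
		w.Header().Set("Content-Disposition",
			fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`,
				name, url.PathEscape(name)))
		_, _ = w.Write([]byte("%PDF-1.4 ...")) // stand-in for real file bytes
	})
	_ = http.ListenAndServe(":8080", nil)
}
```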
