diff --git a/drivers/base/util.go b/drivers/base/util.go index 09040fcc43b..22f11114481 100644 --- a/drivers/base/util.go +++ b/drivers/base/util.go @@ -1,19 +1 @@ package base - -import "io" - -type Closers struct { - closers []io.Closer -} - -func (c *Closers) Close() (err error) { - for _, closer := range c.closers { - if closer != nil { - _ = closer.Close() - } - } - return nil -} -func (c *Closers) Add(closer io.Closer) { - c.closers = append(c.closers, closer) -} diff --git a/drivers/crypt/driver.go b/drivers/crypt/driver.go index 8fe84a25417..a2d5a17c3e9 100644 --- a/drivers/crypt/driver.go +++ b/drivers/crypt/driver.go @@ -236,7 +236,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( return nil, fmt.Errorf("the remote storage driver need to be enhanced to support encrytion") } remoteFileSize := remoteFile.GetSize() - var remoteCloser io.Closer + remoteClosers := utils.NewClosers() rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) { length := underlyingLength if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize { @@ -245,6 +245,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if remoteLink.RangeReadCloser.RangeReader != nil { //remoteRangeReader, err := remoteReader, err := remoteLink.RangeReadCloser.RangeReader(http_range.Range{Start: underlyingOffset, Length: length}) + remoteClosers.Add(remoteLink.RangeReadCloser.Closers) if err != nil { return nil, err } @@ -255,8 +256,8 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( if err != nil { return nil, err } + //remoteClosers.Add(remoteLink.ReadSeekCloser) //keep reuse same ReadSeekCloser and close at last. - remoteCloser = remoteLink.ReadSeekCloser return io.NopCloser(remoteLink.ReadSeekCloser), nil } if len(remoteLink.URL) > 0 { @@ -265,6 +266,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( Header: remoteLink.Header, } response, err := RequestRangedHttp(args.HttpReq, rangedRemoteLink, underlyingOffset, length) + //remoteClosers.Add(response.Body) if err != nil { return nil, fmt.Errorf("remote storage http request failure,status: %d err:%s", response.StatusCode, err) } @@ -301,7 +303,7 @@ func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( return readSeeker, nil } - resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closer: remoteCloser} + resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: remoteClosers} resultLink := &model.Link{ Header: remoteLink.Header, RangeReadCloser: *resultRangeReadCloser, diff --git a/drivers/mega/driver.go b/drivers/mega/driver.go index 1af0a7e564f..b329d4873b4 100644 --- a/drivers/mega/driver.go +++ b/drivers/mega/driver.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/pkg/http_range" "github.com/rclone/rclone/lib/readers" "io" @@ -75,7 +74,7 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* //} size := file.GetSize() - var finalClosers base.Closers + var finalClosers utils.Closers resultRangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) { length := httpRange.Length if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size { @@ -98,7 +97,7 @@ func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (* return readers.NewLimitedReadCloser(oo, length), nil } - resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closer: &finalClosers} + resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: &finalClosers} resultLink := &model.Link{ RangeReadCloser: *resultRangeReadCloser, } diff --git a/internal/model/args.go b/internal/model/args.go index 014b371a8ac..41ac826ed13 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -2,6 +2,7 @@ package model import ( "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" "io" "net/http" "time" @@ -45,7 +46,7 @@ type FsOtherArgs struct { } type RangeReadCloser struct { RangeReader RangeReaderFunc - Closer io.Closer + Closers *utils.Closers } type WriterFunc func(w io.Writer) error diff --git a/internal/net/request.go b/internal/net/request.go index 522c771616d..bb5b68bb1ee 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -86,8 +86,9 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo // downloader is the implementation structure used internally by Downloader. type downloader struct { - ctx context.Context - cfg Downloader + ctx context.Context + cancel context.CancelFunc + cfg Downloader params *HttpRequestParams //http request params chunkChannel chan chunk //chunk chanel @@ -107,6 +108,7 @@ type downloader struct { // download performs the implementation of the object download across ranged GETs. func (d *downloader) download() (*io.ReadCloser, error) { + d.ctx, d.cancel = context.WithCancel(d.ctx) pos := d.params.Range.Start maxPos := d.params.Range.Start + d.params.Range.Length @@ -138,7 +140,7 @@ func (d *downloader) download() (*io.ReadCloser, error) { d.chunkChannel = make(chan chunk, d.cfg.Concurrency) for i := 0; i < d.cfg.Concurrency; i++ { - buf := NewBuf(d.cfg.PartSize, i) + buf := NewBuf(d.ctx, d.cfg.PartSize, i) d.bufs = append(d.bufs, buf) go d.downloadPart() } @@ -163,6 +165,7 @@ func (d *downloader) sendChunkTask() *chunk { // when the final reader Close, we interrupt func (d *downloader) interrupt() error { + d.cancel() if d.written != d.params.Range.Length { log.Debugf("Downloader interrupt before finish") if d.getErr() == nil { @@ -520,15 +523,16 @@ func (buf *Buffer) waitTillNewWrite(pos int) error { type Buf struct { buffer *Buffer // Buffer we read from size int //expected size + ctx context.Context } // NewBuf is a buffer that can have 1 read & 1 write at the same time. // when read is faster write, immediately feed data to read after written -func NewBuf(maxSize int, id int) *Buf { +func NewBuf(ctx context.Context, maxSize int, id int) *Buf { d := make([]byte, maxSize) buffer := &Buffer{data: d, id: id, notify: make(chan int)} buffer.reset() - return &Buf{buffer: buffer, size: maxSize} + return &Buf{ctx: ctx, buffer: buffer, size: maxSize} } func (br *Buf) Reset(size int) { @@ -540,6 +544,9 @@ func (br *Buf) GetId() int { } func (br *Buf) Read(p []byte) (n int, err error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } if len(p) == 0 { return 0, nil } @@ -580,6 +587,9 @@ func (br *Buf) waitTillNewWrite(pos int) error { } func (br *Buf) Write(p []byte) (n int, err error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } return br.buffer.Write(p) } func (br *Buf) Close() { diff --git a/pkg/utils/io.go b/pkg/utils/io.go index 3e1c81a4ab6..7af7136aa40 100644 --- a/pkg/utils/io.go +++ b/pkg/utils/io.go @@ -164,3 +164,24 @@ func Retry(attempts int, sleep time.Duration, f func() error) (err error) { } return fmt.Errorf("after %d attempts, last error: %s", attempts, err) } + +type Closers struct { + closers []*io.Closer +} + +func (c *Closers) Close() (err error) { + for _, closer := range c.closers { + if closer != nil { + _ = (*closer).Close() + } + } + return nil +} +func (c *Closers) Add(closer io.Closer) { + if closer != nil { + c.closers = append(c.closers, &closer) + } +} +func NewClosers() *Closers { + return &Closers{[]*io.Closer{}} +} diff --git a/server/common/proxy.go b/server/common/proxy.go index ec2229fa20c..f6148860717 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -7,6 +7,7 @@ import ( "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/net" "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/pkg/errors" "io" "net/http" @@ -33,22 +34,24 @@ var httpClient *http.Client func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { if link.ReadSeekCloser != nil { - filename := file.GetName() - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, filename, url.PathEscape(filename))) + attachFileName(w, file) http.ServeContent(w, r, file.GetName(), file.ModTime(), link.ReadSeekCloser) defer link.ReadSeekCloser.Close() return nil } else if link.RangeReadCloser.RangeReader != nil { + attachFileName(w, file) net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser.RangeReader) defer func() { - if link.RangeReadCloser.Closer != nil { - link.RangeReadCloser.Closer.Close() + if link.RangeReadCloser.Closers != nil { + link.RangeReadCloser.Closers.Close() } }() return nil } else if link.Concurrency != 0 || link.PartSize != 0 { + attachFileName(w, file) size := file.GetSize() //var finalClosers model.Closers + finalClosers := utils.NewClosers() header := net.ProcessHeader(&r.Header, &link.Header) rangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) { down := net.NewDownloader(func(d *net.Downloader) { @@ -62,9 +65,11 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. HeaderRef: header, } rc, err := down.Download(context.Background(), req) + finalClosers.Add(*rc) return *rc, err } net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader) + defer finalClosers.Close() return nil } else { //transparent proxy @@ -89,3 +94,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. return nil } } +func attachFileName(w http.ResponseWriter, file model.Obj) { + fileName := file.GetName() + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, fileName, url.PathEscape(fileName))) +}