From b925960e72275f9cf944e8965ee18838a3b7fc6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Sun, 9 Dec 2018 16:40:06 +0100 Subject: [PATCH] Allow skipping entries in multipartIterator --- multifilereader_test.go | 31 +++++++++++++++++++++++++++ multipartfile.go | 47 +++++++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/multifilereader_test.go b/multifilereader_test.go index a7d6421..31ef5c0 100644 --- a/multifilereader_test.go +++ b/multifilereader_test.go @@ -72,6 +72,37 @@ func TestMultiFileReaderToMultiFile(t *testing.T) { } } +func TestMultiFileReaderToMultiFileSkip(t *testing.T) { + mfr := getTestMultiFileReader(t) + mpReader := multipart.NewReader(mfr, mfr.Boundary()) + mf, err := NewFileFromPartReader(mpReader, multipartFormdataType) + if err != nil { + t.Fatal(err) + } + + md, ok := mf.(Directory) + if !ok { + t.Fatal("Expected a directory") + } + it := md.Entries() + + if !it.Next() || it.Name() != "beep.txt" { + t.Fatal("iterator didn't work as expected") + } + + if !it.Next() || it.Name() != "boop" || DirFrom(it) == nil { + t.Fatal("iterator didn't work as expected") + } + + if !it.Next() || it.Name() != "file.txt" || DirFrom(it) != nil || it.Err() != nil { + t.Fatal("iterator didn't work as expected") + } + + if it.Next() || it.Err() != nil { + t.Fatal("iterator didn't work as expected") + } +} + func TestOutput(t *testing.T) { mfr := getTestMultiFileReader(t) mpReader := &peekReader{r: multipart.NewReader(mfr, mfr.Boundary())} diff --git a/multipartfile.go b/multipartfile.go index 14b4cba..7bb535f 100644 --- a/multipartfile.go +++ b/multipartfile.go @@ -8,6 +8,7 @@ import ( "mime/multipart" "net/url" "path" + "strings" ) const ( @@ -22,6 +23,7 @@ const ( ) var ErrPartOutsideParent = errors.New("file outside parent dir") +var ErrPartInChildTree = errors.New("file in child tree") // MultipartFile implements Node, and is created from a `multipart.Part`. // @@ -56,7 +58,19 @@ func newFileFromPart(parent string, part *multipart.Part, reader PartReader) (st } dir, base := path.Split(f.fileName()) - if path.Clean(dir) != path.Clean(parent) { + dir = path.Clean(dir) + parent = path.Clean(parent) + if dir == "." { + dir = "" + } + if parent == "." { + parent = "" + } + + if dir != parent { + if strings.HasPrefix(dir, parent) { + return "", nil, ErrPartInChildTree + } return "", nil, ErrPartOutsideParent } @@ -118,21 +132,28 @@ func (it *multipartIterator) Next() bool { if it.f.Reader == nil { return false } - part, err := it.f.Reader.NextPart() - if err != nil { - if err == io.EOF { + var part *multipart.Part + for { + var err error + part, err = it.f.Reader.NextPart() + if err != nil { + if err == io.EOF { + return false + } + it.err = err return false } - it.err = err - return false - } - name, cf, err := newFileFromPart(it.f.fileName(), part, it.f.Reader) - if err != ErrPartOutsideParent { - it.curFile = cf - it.curName = name - it.err = err - return err == nil + name, cf, err := newFileFromPart(it.f.fileName(), part, it.f.Reader) + if err == ErrPartOutsideParent { + break + } + if err != ErrPartInChildTree { + it.curFile = cf + it.curName = name + it.err = err + return err == nil + } } // we read too much, try to fix this