diff --git a/docs/content/reference/scalars.md b/docs/content/reference/scalars.md index 7b098b06896..36f035b9f32 100644 --- a/docs/content/reference/scalars.md +++ b/docs/content/reference/scalars.md @@ -35,7 +35,7 @@ scalar Upload Maps a `Upload` GraphQL scalar to a `graphql.Upload` struct, defined as follows: ``` type Upload struct { - File multipart.File + File io.ReadCloser Filename string Size int64 } diff --git a/example/fileupload/fileupload_test.go b/example/fileupload/fileupload_test.go index e91029375f9..c0d735aaa10 100644 --- a/example/fileupload/fileupload_test.go +++ b/example/fileupload/fileupload_test.go @@ -51,14 +51,13 @@ func TestFileUpload(t *testing.T) { resp, err := client.Do(req) require.Nil(t, err) - defer func() { - _ = resp.Body.Close() - }() require.Equal(t, http.StatusOK, resp.StatusCode) responseBody, err := ioutil.ReadAll(resp.Body) require.Nil(t, err) responseString := string(responseBody) require.Equal(t, `{"data":{"singleUpload":{"id":1,"name":"a.txt","content":"test"}}}`, responseString) + err = resp.Body.Close() + require.Nil(t, err) }) t.Run("valid single file upload with payload", func(t *testing.T) { @@ -94,13 +93,12 @@ func TestFileUpload(t *testing.T) { resp, err := client.Do(req) require.Nil(t, err) - defer func() { - _ = resp.Body.Close() - }() require.Equal(t, http.StatusOK, resp.StatusCode) responseBody, err := ioutil.ReadAll(resp.Body) require.Nil(t, err) require.Equal(t, `{"data":{"singleUploadWithPayload":{"id":1,"name":"a.txt","content":"test"}}}`, string(responseBody)) + err = resp.Body.Close() + require.Nil(t, err) }) t.Run("valid file list upload", func(t *testing.T) { @@ -145,13 +143,12 @@ func TestFileUpload(t *testing.T) { resp, err := client.Do(req) require.Nil(t, err) - defer func() { - _ = resp.Body.Close() - }() require.Equal(t, http.StatusOK, resp.StatusCode) responseBody, err := ioutil.ReadAll(resp.Body) require.Nil(t, err) require.Equal(t, `{"data":{"multipleUpload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"b.txt","content":"test2"}]}}`, string(responseBody)) + err = resp.Body.Close() + require.Nil(t, err) }) t.Run("valid file list upload with payload", func(t *testing.T) { @@ -200,13 +197,12 @@ func TestFileUpload(t *testing.T) { resp, err := client.Do(req) require.Nil(t, err) - defer func() { - _ = resp.Body.Close() - }() require.Equal(t, http.StatusOK, resp.StatusCode) responseBody, err := ioutil.ReadAll(resp.Body) require.Nil(t, err) require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"b.txt","content":"test2"}]}}`, string(responseBody)) + err = resp.Body.Close() + require.Nil(t, err) }) t.Run("valid file list upload with payload and file reuse", func(t *testing.T) { @@ -220,7 +216,6 @@ func TestFileUpload(t *testing.T) { require.NotNil(t, req[i].File) require.NotNil(t, req[i].File.File) ids = append(ids, req[i].ID) - req[i].File.File.Seek(0, 0) content, err := ioutil.ReadAll(req[i].File.File) require.Nil(t, err) contents = append(contents, string(content)) @@ -235,8 +230,6 @@ func TestFileUpload(t *testing.T) { return resp, nil }, } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolver}), handler.UploadMaxMemory(2))) - defer srv.Close() operations := `{ "query": "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id, name, content } }", "variables": { "req": [ { "id": 1, "file": null }, { "id": 2, "file": null } ] } }` mapData := `{ "0": ["variables.req.0.file", "variables.req.1.file"] }` @@ -247,17 +240,29 @@ func TestFileUpload(t *testing.T) { content: "test1", }, } - req := createUploadRequest(t, srv.URL, operations, mapData, files) - resp, err := client.Do(req) - require.Nil(t, err) - defer func() { - _ = resp.Body.Close() - }() - require.Equal(t, http.StatusOK, resp.StatusCode) - responseBody, err := ioutil.ReadAll(resp.Body) - require.Nil(t, err) - require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"a.txt","content":"test1"}]}}`, string(responseBody)) + test := func(uploadMaxMemory int64) { + memory := handler.UploadMaxMemory(uploadMaxMemory) + srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolver}), memory)) + defer srv.Close() + req := createUploadRequest(t, srv.URL, operations, mapData, files) + resp, err := client.Do(req) + require.Nil(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + responseBody, err := ioutil.ReadAll(resp.Body) + require.Nil(t, err) + require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"a.txt","content":"test1"}]}}`, string(responseBody)) + err = resp.Body.Close() + require.Nil(t, err) + } + + t.Run("payload smaller than UploadMaxMemory, stored in memory", func(t *testing.T){ + test(5000) + }) + + t.Run("payload bigger than UploadMaxMemory, persisted to disk", func(t *testing.T){ + test(2) + }) }) } diff --git a/graphql/upload.go b/graphql/upload.go index e7919f1b415..22d61031495 100644 --- a/graphql/upload.go +++ b/graphql/upload.go @@ -3,11 +3,10 @@ package graphql import ( "fmt" "io" - "mime/multipart" ) type Upload struct { - File multipart.File + File io.Reader Filename string Size int64 } diff --git a/handler/graphql.go b/handler/graphql.go index e34329a8f05..df99efc2cfa 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -1,12 +1,15 @@ package handler import ( + "bytes" "context" "encoding/json" "errors" "fmt" "io" + "io/ioutil" "net/http" + "os" "strconv" "strings" "time" @@ -369,7 +372,17 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case http.MethodPost: contentType := strings.SplitN(r.Header.Get("Content-Type"), ";", 2)[0] if contentType == "multipart/form-data" { - if err := processMultipart(w, r, &reqParams, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil { + var closers []io.Closer + var tmpFiles []string + defer func() { + for i := len(closers) - 1; 0 <= i; i-- { + _ = closers[i].Close() + } + for _, tmpFile := range tmpFiles { + _ = os.Remove(tmpFile) + } + }() + if err := processMultipart(w, r, &reqParams, &closers, &tmpFiles, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil { sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error()) return } @@ -534,7 +547,7 @@ func sendErrorf(w http.ResponseWriter, code int, format string, args ...interfac sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) } -func processMultipart(w http.ResponseWriter, r *http.Request, request *params, uploadMaxSize, uploadMaxMemory int64) error { +func processMultipart(w http.ResponseWriter, r *http.Request, request *params, closers *[]io.Closer, tmpFiles *[]string, uploadMaxSize, uploadMaxMemory int64) error { var err error if r.ContentLength > uploadMaxSize { return errors.New("failed to parse multipart form, request body too large") @@ -546,6 +559,7 @@ func processMultipart(w http.ResponseWriter, r *http.Request, request *params, u } return errors.New("failed to parse multipart form") } + *closers = append(*closers, r.Body) if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &request); err != nil { return errors.New("operations form field could not be decoded") @@ -558,38 +572,86 @@ func processMultipart(w http.ResponseWriter, r *http.Request, request *params, u var upload graphql.Upload for key, paths := range uploadsMap { - err = func() error { - file, header, err := r.FormFile(key) - if err != nil { - return fmt.Errorf("failed to get key %s from form", key) - } - if len(paths) == 0 { - return fmt.Errorf("invalid empty operations paths list for key %s", key) - } + if len(paths) == 0 { + return fmt.Errorf("invalid empty operations paths list for key %s", key) + } + file, header, err := r.FormFile(key) + if err != nil { + return fmt.Errorf("failed to get key %s from form", key) + } + *closers = append(*closers, file) + + if len(paths) == 1 { upload = graphql.Upload{ File: file, Size: header.Size, Filename: header.Filename, } - for _, path := range paths { - if !strings.HasPrefix(path, "variables.") { - return fmt.Errorf("invalid operations paths for key %s", key) + err = addUploadToOperations(request, upload, key, paths[0]) + if err != nil { + return err + } + } else { + if r.ContentLength < uploadMaxMemory { + fileContent, err := ioutil.ReadAll(file) + if err != nil { + return fmt.Errorf("failed to read file for key %s", key) + } + for _, path := range paths { + upload = graphql.Upload{ + File: ioutil.NopCloser(bytes.NewReader(fileContent)), + Size: header.Size, + Filename: header.Filename, + } + err = addUploadToOperations(request, upload, key, path) + if err != nil { + return err + } + } + } else { + tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-") + if err != nil { + return fmt.Errorf("failed to create temp file for key %s", key) } - err = addUploadToOperations(request, upload, path) + tmpName := tmpFile.Name() + *tmpFiles = append(*tmpFiles, tmpName) + _, err = io.Copy(tmpFile, file) if err != nil { - return err + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to copy to temp file and close temp file for key %s", key) + } + return fmt.Errorf("failed to copy to temp file for key %s", key) + } + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file for key %s", key) + } + for _, path := range paths { + pathTmpFile, err := os.Open(tmpName) + if err != nil { + return fmt.Errorf("failed to open temp file for key %s", key) + } + *closers = append(*closers, pathTmpFile) + upload = graphql.Upload{ + File: pathTmpFile, + Size: header.Size, + Filename: header.Filename, + } + err = addUploadToOperations(request, upload, key, path) + if err != nil { + return err + } } } - return nil - }() - if err != nil { - return err } } return nil } -func addUploadToOperations(request *params, upload graphql.Upload, path string) error { +func addUploadToOperations(request *params, upload graphql.Upload, key, path string) error { + if !strings.HasPrefix(path, "variables.") { + return fmt.Errorf("invalid operations paths for key %s", key) + } + var ptr interface{} = request.Variables parts := strings.Split(path, ".") @@ -597,7 +659,7 @@ func addUploadToOperations(request *params, upload graphql.Upload, path string) for i, p := range parts[1:] { last := i == len(parts)-2 if ptr == nil { - return fmt.Errorf("variables is missing, path: %s", path) + return fmt.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path) } if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil { if last { diff --git a/handler/graphql_test.go b/handler/graphql_test.go index fd353cff1b8..82fd1606e21 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "io/ioutil" "mime/multipart" "net/http" @@ -304,30 +305,41 @@ func TestFileUpload(t *testing.T) { }) t.Run("valid file list upload with payload and file reuse", func(t *testing.T) { - mock := &executableSchemaMock{ - MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - require.Equal(t, len(op.VariableDefinitions), 1) - require.Equal(t, op.VariableDefinitions[0].Variable, "req") - return &graphql.Response{Data: []byte(`{"multipleUploadWithPayload":[{"id":1},{"id":2}]}`)} - }, - } - handler := GraphQL(mock) + test := func (uploadMaxMemory int64) { + mock := &executableSchemaMock{ + MutationFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + require.Equal(t, len(op.VariableDefinitions), 1) + require.Equal(t, op.VariableDefinitions[0].Variable, "req") + return &graphql.Response{Data: []byte(`{"multipleUploadWithPayload":[{"id":1},{"id":2}]}`)} + }, + } + maxMemory := UploadMaxMemory(5000) + handler := GraphQL(mock, maxMemory) + + operations := `{ "query": "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id } }", "variables": { "req": [ { "id": 1, "file": null }, { "id": 2, "file": null } ] } }` + mapData := `{ "0": ["variables.req.0.file", "variables.req.1.file"] }` + files := []file{ + { + mapKey: "0", + name: "a.txt", + content: "test1", + }, + } + req := createUploadRequest(t, operations, mapData, files) - operations := `{ "query": "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id } }", "variables": { "req": [ { "id": 1, "file": null }, { "id": 2, "file": null } ] } }` - mapData := `{ "0": ["variables.req.0.file", "variables.req.1.file"] }` - files := []file{ - { - mapKey: "0", - name: "a.txt", - content: "test1", - }, + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1},{"id":2}]}}`, resp.Body.String()) } - req := createUploadRequest(t, operations, mapData, files) - resp := httptest.NewRecorder() - handler.ServeHTTP(resp, req) - require.Equal(t, http.StatusOK, resp.Code) - require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1},{"id":2}]}}`, resp.Body.String()) + t.Run("payload smaller than UploadMaxMemory, stored in memory", func(t *testing.T){ + test(5000) + }) + + t.Run("payload bigger than UploadMaxMemory, persisted to disk", func(t *testing.T){ + test(2) + }) }) } @@ -342,6 +354,17 @@ func TestProcessMultipart(t *testing.T) { }, } + cleanUp := func(t *testing.T, closers []io.Closer, tmpFiles []string) { + for i := len(closers) - 1; 0 <= i; i-- { + err := closers[i].Close() + require.Nil(t, err) + } + for _, tmpFiles := range tmpFiles { + err := os.Remove(tmpFiles) + require.Nil(t, err) + } + } + t.Run("fail to parse multipart", func(t *testing.T) { req := &http.Request{ Method: "POST", @@ -349,10 +372,13 @@ func TestProcessMultipart(t *testing.T) { Body: ioutil.NopCloser(new(bytes.Buffer)), } var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, DefaultUploadMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "failed to parse multipart form") + cleanUp(t, closers, tmpFiles) }) t.Run("fail parse operation", func(t *testing.T) { @@ -360,10 +386,13 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, operations, validMap, validFiles) var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, DefaultUploadMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "operations form field could not be decoded") + cleanUp(t, closers, tmpFiles) }) t.Run("fail parse map", func(t *testing.T) { @@ -371,10 +400,13 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, mapData, validFiles) var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, DefaultUploadMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "map form field could not be decoded") + cleanUp(t, closers, tmpFiles) }) t.Run("fail missing file", func(t *testing.T) { @@ -382,10 +414,13 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, validMap, files) var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, DefaultUploadMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "failed to get key 0 from form") + cleanUp(t, closers, tmpFiles) }) t.Run("fail map entry with invalid operations paths prefix", func(t *testing.T) { @@ -393,29 +428,37 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, mapData, validFiles) var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, DefaultUploadMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "invalid operations paths for key 0") + cleanUp(t, closers, tmpFiles) }) t.Run("fail parse request big body", func(t *testing.T) { req := createUploadRequest(t, validOperations, validMap, validFiles) var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() var smallMaxSize int64 = 2 - err := processMultipart(w, req, &reqParams, smallMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, smallMaxSize, DefaultUploadMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "failed to parse multipart form, request body too large") + cleanUp(t, closers, tmpFiles) }) t.Run("valid request", func(t *testing.T) { req := createUploadRequest(t, validOperations, validMap, validFiles) var reqParams params + var closers []io.Closer + var tmpFiles []string w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, DefaultUploadMaxMemory) + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, DefaultUploadMaxMemory) require.Nil(t, err) require.Equal(t, "mutation ($file: Upload!) { singleUpload(file: $file) { id } }", reqParams.Query) require.Equal(t, "", reqParams.OperationName) @@ -428,9 +471,10 @@ func TestProcessMultipart(t *testing.T) { content, err := ioutil.ReadAll(reqParamsFile.File) require.Nil(t, err) require.Equal(t, "test1", string(content)) + cleanUp(t, closers, tmpFiles) }) - t.Run("valid request with two values", func(t *testing.T) { + t.Run("valid file list upload with payload and file reuse", func(t *testing.T) { operations := `{ "query": "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id } }", "variables": { "req": [ { "id": 1, "file": null }, { "id": 2, "file": null } ] } }` mapData := `{ "0": ["variables.req.0.file", "variables.req.1.file"] }` files := []file{ @@ -442,33 +486,46 @@ func TestProcessMultipart(t *testing.T) { } req := createUploadRequest(t, operations, mapData, files) - var reqParams params - w := httptest.NewRecorder() - err := processMultipart(w, req, &reqParams, DefaultUploadMaxSize, 2) - require.Nil(t, err) - require.Equal(t, "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id } }", reqParams.Query) - require.Equal(t, "", reqParams.OperationName) - require.Equal(t, 1, len(reqParams.Variables)) - require.NotNil(t, reqParams.Variables["req"]) - reqParamsFile, ok := reqParams.Variables["req"].([]interface{}) - require.True(t, ok) - require.Equal(t, 2, len(reqParamsFile)) - for i, item := range reqParamsFile { - itemMap := item.(map[string]interface{}) - require.Equal(t, fmt.Sprint(itemMap["id"]), fmt.Sprint(i+1)) - file := itemMap["file"].(graphql.Upload) - require.Equal(t, "a.txt", file.Filename) - require.Equal(t, int64(len("test1")), file.Size) - _, err = file.File.Seek(0, 0) - require.Nil(t, err) - content, err := ioutil.ReadAll(file.File) + test := func(uploadMaxMemory int64) { + var reqParams params + var closers []io.Closer + var tmpFiles []string + w := httptest.NewRecorder() + err := processMultipart(w, req, &reqParams, &closers, &tmpFiles, DefaultUploadMaxSize, uploadMaxMemory) require.Nil(t, err) - require.Equal(t, "test1", string(content)) + require.Equal(t, "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id } }", reqParams.Query) + require.Equal(t, "", reqParams.OperationName) + require.Equal(t, 1, len(reqParams.Variables)) + require.NotNil(t, reqParams.Variables["req"]) + reqParamsFile, ok := reqParams.Variables["req"].([]interface{}) + require.True(t, ok) + require.Equal(t, 2, len(reqParamsFile)) + for i, item := range reqParamsFile { + itemMap := item.(map[string]interface{}) + require.Equal(t, fmt.Sprint(itemMap["id"]), fmt.Sprint(i+1)) + file := itemMap["file"].(graphql.Upload) + require.Equal(t, "a.txt", file.Filename) + require.Equal(t, int64(len("test1")), file.Size) + require.Nil(t, err) + content, err := ioutil.ReadAll(file.File) + require.Nil(t, err) + require.Equal(t, "test1", string(content)) + } + cleanUp(t, closers, tmpFiles) } + + t.Run("payload smaller than UploadMaxMemory, stored in memory", func(t *testing.T){ + test(5000) + }) + + t.Run("payload bigger than UploadMaxMemory, persisted to disk", func(t *testing.T){ + test(2) + }) }) } func TestAddUploadToOperations(t *testing.T) { + key := "0" t.Run("fail missing all variables", func(t *testing.T) { file, _ := os.Open("path/to/file") @@ -480,9 +537,9 @@ func TestAddUploadToOperations(t *testing.T) { Size: int64(5), } path := "variables.req.0.file" - err := addUploadToOperations(request, upload, path) + err := addUploadToOperations(request, upload, key, path) require.NotNil(t, err) - require.Equal(t, "variables is missing, path: variables.req.0.file", err.Error()) + require.Equal(t, "path is missing \"variables.\" prefix, key: 0, path: variables.req.0.file", err.Error()) }) t.Run("valid variable", func(t *testing.T) { @@ -506,7 +563,7 @@ func TestAddUploadToOperations(t *testing.T) { } path := "variables.file" - err := addUploadToOperations(request, upload, path) + err := addUploadToOperations(request, upload, key, path) require.Nil(t, err) require.Equal(t, request, expected) @@ -541,7 +598,7 @@ func TestAddUploadToOperations(t *testing.T) { } path := "variables.req.0.file" - err := addUploadToOperations(request, upload, path) + err := addUploadToOperations(request, upload, key, path) require.Nil(t, err) require.Equal(t, request, expected)