From dfca8612298353f1c32a8bb23093a43ecf18f31b Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 2 Aug 2018 16:18:39 +1000 Subject: [PATCH] use json.Decoder.UseNumber() when unmarshalling vars --- Gopkg.lock | 4 ++-- example/todo/todo_test.go | 19 +++++++++++-------- graphql/float.go | 5 +++++ graphql/id.go | 3 +++ graphql/int.go | 5 +++-- handler/graphql.go | 11 +++++++++-- handler/websocket.go | 11 +++++++++-- 7 files changed, 42 insertions(+), 16 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index b61ca468f89..70a1e151a3f 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -176,7 +176,7 @@ [[projects]] branch = "master" - digest = "1:07b9c7ab78ec48004cf88b4d93c10cad32c7e21b49f99db418fb0ebee0f98b90" + digest = "1:7acc3d2f02aed0095986646435472af4c2e2db42ad730aa78cae780aba5b59f9" name = "github.com/vektah/gqlparser" packages = [ ".", @@ -188,7 +188,7 @@ "validator/rules", ] pruneopts = "UT" - revision = "8dd97c3c1c0357d7602ca9435a8a6cb5e57b4171" + revision = "6298a7d57d3de59879b323d6a8cb66280826906f" [[projects]] branch = "master" diff --git a/example/todo/todo_test.go b/example/todo/todo_test.go index ca244274ad4..368536d7164 100644 --- a/example/todo/todo_test.go +++ b/example/todo/todo_test.go @@ -14,20 +14,23 @@ func TestTodo(t *testing.T) { srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(New()))) c := client.New(srv.URL) - t.Run("create a new todo", func(t *testing.T) { - var resp struct { - CreateTodo struct{ ID int } - } - c.MustPost(`mutation { createTodo(todo:{text:"Fery important"}) { id } }`, &resp) + var resp struct { + CreateTodo struct{ ID int } + } + c.MustPost(`mutation { createTodo(todo:{text:"Fery important"}) { id } }`, &resp) - require.Equal(t, 4, resp.CreateTodo.ID) - }) + require.Equal(t, 4, resp.CreateTodo.ID) t.Run("update the todo text", func(t *testing.T) { var resp struct { UpdateTodo struct{ Text string } } - c.MustPost(`mutation { updateTodo(id: 4, changes:{text:"Very important"}) { text } }`, &resp) + c.MustPost( + `mutation($id: Int!, $text: String!) { updateTodo(id: $id, changes:{text:$text}) { text } }`, + &resp, + client.Var("id", 4), + client.Var("text", "Very important"), + ) require.Equal(t, "Very important", resp.UpdateTodo.Text) }) diff --git a/graphql/float.go b/graphql/float.go index c08b490a4f1..d204335c44b 100644 --- a/graphql/float.go +++ b/graphql/float.go @@ -1,6 +1,7 @@ package graphql import ( + "encoding/json" "fmt" "io" "strconv" @@ -18,8 +19,12 @@ func UnmarshalFloat(v interface{}) (float64, error) { return strconv.ParseFloat(v, 64) case int: return float64(v), nil + case int64: + return float64(v), nil case float64: return v, nil + case json.Number: + return strconv.ParseFloat(string(v), 64) default: return 0, fmt.Errorf("%T is not an float", v) } diff --git a/graphql/id.go b/graphql/id.go index 7958670cdc7..a5a7960f346 100644 --- a/graphql/id.go +++ b/graphql/id.go @@ -1,6 +1,7 @@ package graphql import ( + "encoding/json" "fmt" "io" "strconv" @@ -15,6 +16,8 @@ func UnmarshalID(v interface{}) (string, error) { switch v := v.(type) { case string: return v, nil + case json.Number: + return string(v), nil case int: return strconv.Itoa(v), nil case float64: diff --git a/graphql/int.go b/graphql/int.go index 6b2da63a1b8..ff87574cab7 100644 --- a/graphql/int.go +++ b/graphql/int.go @@ -1,6 +1,7 @@ package graphql import ( + "encoding/json" "fmt" "io" "strconv" @@ -20,8 +21,8 @@ func UnmarshalInt(v interface{}) (int, error) { return v, nil case int64: return int(v), nil - case float64: - return int(v), nil + case json.Number: + return strconv.Atoi(string(v)) default: return 0, fmt.Errorf("%T is not an int", v) } diff --git a/handler/graphql.go b/handler/graphql.go index fb943b68aaa..beee5c0865d 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "strings" @@ -140,13 +141,13 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc reqParams.OperationName = r.URL.Query().Get("operationName") if variables := r.URL.Query().Get("variables"); variables != "" { - if err := json.Unmarshal([]byte(variables), &reqParams.Variables); err != nil { + if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") return } } case http.MethodPost: - if err := json.NewDecoder(r.Body).Decode(&reqParams); err != nil { + if err := jsonDecode(r.Body, &reqParams); err != nil { sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) return } @@ -201,6 +202,12 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc }) } +func jsonDecode(r io.Reader, val interface{}) error { + dec := json.NewDecoder(r) + dec.UseNumber() + return dec.Decode(val) +} + func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { w.WriteHeader(code) b, err := json.Marshal(&graphql.Response{Errors: errors}) diff --git a/handler/websocket.go b/handler/websocket.go index 8af56ae7f6e..32555557175 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "encoding/json" "fmt" @@ -132,7 +133,7 @@ func (c *wsConnection) run() { func (c *wsConnection) subscribe(message *operationMessage) bool { var reqParams params - if err := json.Unmarshal(message.Payload, &reqParams); err != nil { + if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil { c.sendConnectionError("invalid json") return false } @@ -228,11 +229,17 @@ func (c *wsConnection) sendConnectionError(format string, args ...interface{}) { } func (c *wsConnection) readOp() *operationMessage { + _, r, err := c.conn.NextReader() + if err != nil { + c.sendConnectionError("invalid json") + return nil + } message := operationMessage{} - if err := c.conn.ReadJSON(&message); err != nil { + if err := jsonDecode(r, &message); err != nil { c.sendConnectionError("invalid json") return nil } + return &message }