diff --git a/example/starwars/generated.go b/example/starwars/generated.go index aed19d90e60..34e45563a7b 100644 --- a/example/starwars/generated.go +++ b/example/starwars/generated.go @@ -425,21 +425,17 @@ func (ec *executionContext) _mutation(sel []query.Selection, it *interface{}) js ec.Error(err) continue } - ec.wg.Add(1) - go func(i int, field collectedField) { - defer ec.wg.Done() - res, err := ec.resolvers.Mutation_createReview(ec.ctx, arg0, arg1) - if err != nil { - ec.Error(err) - return - } + res, err := ec.resolvers.Mutation_createReview(ec.ctx, arg0, arg1) + if err != nil { + ec.Error(err) + continue + } - if res == nil { - out.Values[i] = jsonw.Null - } else { - out.Values[i] = ec._review(field.Selections, res) - } - }(i, field) + if res == nil { + out.Values[i] = jsonw.Null + } else { + out.Values[i] = ec._review(field.Selections, res) + } default: panic("unknown field " + strconv.Quote(field.Name)) } diff --git a/example/starwars/starwars_test.go b/example/starwars/starwars_test.go index 7af1c28e7bb..65bcea0d08e 100644 --- a/example/starwars/starwars_test.go +++ b/example/starwars/starwars_test.go @@ -133,6 +133,29 @@ func TestStarwars(t *testing.T) { require.Equal(t, "Leia Organa", resp.Droid.FriendsConnection.Edges[2].Node.Name) }) + t.Run("mutations must be run in sequence", func(t *testing.T) { + var resp struct { + A struct{ Time string } + B struct{ Time string } + C struct{ Time string } + } + + c.MustPost(`mutation f{ + a:createReview(episode: NEWHOPE, review:{stars:1, commentary:"Blah blah"}) { + time + } + b:createReview(episode: NEWHOPE, review:{stars:1, commentary:"Blah blah"}) { + time + } + c:createReview(episode: NEWHOPE, review:{stars:1, commentary:"Blah blah"}) { + time + } + }`, &resp) + + require.NotEqual(t, resp.A.Time, resp.B.Time) + require.NotEqual(t, resp.C.Time, resp.B.Time) + }) + t.Run("introspection", func(t *testing.T) { // Make sure we can run the graphiql introspection query without errors c.MustPost(introspection.Query, nil) diff --git a/example/todo/generated.go b/example/todo/generated.go index 6e75ca3f48b..5605a449305 100644 --- a/example/todo/generated.go +++ b/example/todo/generated.go @@ -106,17 +106,13 @@ func (ec *executionContext) _myMutation(sel []query.Selection, it *interface{}) } arg0 = tmp2 } - ec.wg.Add(1) - go func(i int, field collectedField) { - defer ec.wg.Done() - res, err := ec.resolvers.MyMutation_createTodo(ec.ctx, arg0) - if err != nil { - ec.Error(err) - return - } + res, err := ec.resolvers.MyMutation_createTodo(ec.ctx, arg0) + if err != nil { + ec.Error(err) + continue + } - out.Values[i] = ec._todo(field.Selections, &res) - }(i, field) + out.Values[i] = ec._todo(field.Selections, &res) case "updateTodo": var arg0 int if tmp, ok := field.Args["id"]; ok { @@ -131,21 +127,17 @@ func (ec *executionContext) _myMutation(sel []query.Selection, it *interface{}) if tmp, ok := field.Args["changes"]; ok { arg1 = tmp.(map[string]interface{}) } - ec.wg.Add(1) - go func(i int, field collectedField) { - defer ec.wg.Done() - res, err := ec.resolvers.MyMutation_updateTodo(ec.ctx, arg0, arg1) - if err != nil { - ec.Error(err) - return - } + res, err := ec.resolvers.MyMutation_updateTodo(ec.ctx, arg0, arg1) + if err != nil { + ec.Error(err) + continue + } - if res == nil { - out.Values[i] = jsonw.Null - } else { - out.Values[i] = ec._todo(field.Selections, res) - } - }(i, field) + if res == nil { + out.Values[i] = jsonw.Null + } else { + out.Values[i] = ec._todo(field.Selections, res) + } default: panic("unknown field " + strconv.Quote(field.Name)) } diff --git a/extractor.go b/extractor.go index 98f04be9cee..8cea02e0072 100644 --- a/extractor.go +++ b/extractor.go @@ -18,8 +18,8 @@ import ( type extractor struct { Errors []string PackageName string - Objects []object - Interfaces []object + Objects []*object + Interfaces []*object goTypeMap map[string]string Imports map[string]string // local -> full path schema *schema.Schema @@ -32,7 +32,7 @@ func (e *extractor) extract() { for _, typ := range e.schema.Types { switch typ := typ.(type) { case *schema.Object: - obj := object{ + obj := &object{ Name: typ.Name, Type: e.getType(typ.Name), } @@ -54,18 +54,19 @@ func (e *extractor) extract() { GraphQLName: field.Name, Type: e.buildType(field.Type), Args: args, + Object: obj, }) } e.Objects = append(e.Objects, obj) case *schema.Union: - obj := object{ + obj := &object{ Name: typ.Name, Type: e.buildType(typ), } e.Interfaces = append(e.Interfaces, obj) case *schema.Interface: - obj := object{ + obj := &object{ Name: typ.Name, Type: e.buildType(typ), } @@ -82,6 +83,7 @@ func (e *extractor) extract() { } if name == "mutation" { e.MutationRoot = obj.Name + e.GetObject(obj.Name).DisableConcurrency = true } } @@ -316,7 +318,7 @@ func (e *extractor) modifiersFromGoType(t types.Type) []string { } } -func (e *extractor) findBindTargets(t types.Type, object object) bool { +func (e *extractor) findBindTargets(t types.Type, object *object) bool { switch t := t.(type) { case *types.Named: for i := 0; i < t.NumMethods(); i++ { diff --git a/jsonw/jsonw.go b/jsonw/jsonw.go index e9e2ecdf1d6..234a1ba2b64 100644 --- a/jsonw/jsonw.go +++ b/jsonw/jsonw.go @@ -110,6 +110,6 @@ func Bool(b bool) Writer { func Time(t time.Time) Writer { return writerFunc(func(w io.Writer) { - io.WriteString(w, t.Format(time.RFC3339)) + io.WriteString(w, strconv.Quote(t.Format(time.RFC3339))) }) } diff --git a/main.go b/main.go index bfb016a9ea7..ea29d71bb53 100644 --- a/main.go +++ b/main.go @@ -78,6 +78,7 @@ func main() { GraphQLName: "__schema", NoErr: true, MethodName: "ec.introspectSchema", + Object: q, }) q.Fields = append(q.Fields, Field{ Type: e.getType("__Type").Ptr(), @@ -85,6 +86,7 @@ func main() { NoErr: true, MethodName: "ec.introspectType", Args: []FieldArgument{{Name: "name", Type: kind{Scalar: true, Name: "string"}}}, + Object: q, }) if len(e.Errors) != 0 { diff --git a/templates/file.gotpl b/templates/file.gotpl index 60dc1dcc781..cf3aa3b9fe0 100644 --- a/templates/file.gotpl +++ b/templates/file.gotpl @@ -12,7 +12,7 @@ import ( type Resolvers interface { {{- range $object := .Objects -}} {{ range $field := $object.Fields -}} - {{ $field.ResolverDeclaration $object }} + {{ $field.ResolverDeclaration }} {{ end }} {{- end }} } diff --git a/templates/object.gotpl b/templates/object.gotpl index b7eb5459f22..8a5ce90c96d 100644 --- a/templates/object.gotpl +++ b/templates/object.gotpl @@ -16,7 +16,7 @@ func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Se case "{{$field.GraphQLName}}": {{- template "args" $field.Args }} - {{- if $field.IsResolver }} + {{- if $field.IsConcurrent }} ec.wg.Add(1) go func(i int, field collectedField) { defer ec.wg.Done() @@ -26,25 +26,25 @@ func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Se res := {{$field.VarName}} {{- else if $field.MethodName }} {{- if $field.NoErr }} - res := {{$field.MethodName}}({{ $field.CallArgs $object }}) + res := {{$field.MethodName}}({{ $field.CallArgs }}) {{- else }} - res, err := {{$field.MethodName}}({{ $field.CallArgs $object }}) + res, err := {{$field.MethodName}}({{ $field.CallArgs }}) if err != nil { ec.Error(err) - continue + {{ if $field.IsConcurrent }}return{{ else }}continue{{end}} } {{- end }} {{- else }} - res, err := ec.resolvers.{{ $object.Name }}_{{ $field.GraphQLName }}({{ $field.CallArgs $object }}) + res, err := ec.resolvers.{{ $object.Name }}_{{ $field.GraphQLName }}({{ $field.CallArgs }}) if err != nil { ec.Error(err) - return + {{ if $field.IsConcurrent }}return{{ else }}continue{{end}} } {{- end }} {{ $field.WriteJson "out.Values[i]" }} - {{- if $field.IsResolver }} + {{- if $field.IsConcurrent }} }(i, field) {{- end }} {{- end }} diff --git a/types.go b/types.go index 732203d1933..f53df9e0325 100644 --- a/types.go +++ b/types.go @@ -70,11 +70,12 @@ func (t kind) FullName() string { } type object struct { - Name string - Fields []Field - Type kind - satisfies []string - Root bool + Name string + Fields []Field + Type kind + satisfies []string + Root bool + DisableConcurrency bool } type Field struct { @@ -84,20 +85,25 @@ type Field struct { Type kind Args []FieldArgument NoErr bool + Object *object } func (f *Field) IsResolver() bool { return f.MethodName == "" && f.VarName == "" } -func (f *Field) ResolverDeclaration(o object) string { +func (f *Field) IsConcurrent() bool { + return f.IsResolver() && !f.Object.DisableConcurrency +} + +func (f *Field) ResolverDeclaration() string { if !f.IsResolver() { return "" } - res := fmt.Sprintf("%s_%s(ctx context.Context", o.Name, f.GraphQLName) + res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.Name, f.GraphQLName) - if !o.Root { - res += fmt.Sprintf(", it *%s", o.Type.Local()) + if !f.Object.Root { + res += fmt.Sprintf(", it *%s", f.Object.Type.Local()) } for _, arg := range f.Args { res += fmt.Sprintf(", %s %s", arg.Name, arg.Type.Local()) @@ -107,13 +113,13 @@ func (f *Field) ResolverDeclaration(o object) string { return res } -func (f *Field) CallArgs(object object) string { +func (f *Field) CallArgs() string { var args []string if f.MethodName == "" { args = append(args, "ec.ctx") - if !object.Root { + if !f.Object.Root { args = append(args, "it") } } @@ -202,7 +208,7 @@ func (o *object) Implementors() string { func (e *extractor) GetObject(name string) *object { for i, o := range e.Objects { if strings.EqualFold(o.Name, name) { - return &e.Objects[i] + return e.Objects[i] } } return nil