Skip to content

Commit

Permalink
Dont execute mutations concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Feb 11, 2018
1 parent 3900a41 commit 75a3a05
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 65 deletions.
24 changes: 10 additions & 14 deletions example/starwars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,21 +425,17 @@ func (ec *executionContext) _mutation(sel []query.Selection, it *interface{}) js
ec.Error(err)
continue
}
ec.wg.Add(1)
go func(i int, field collectedField) {
defer ec.wg.Done()
res, err := ec.resolvers.Mutation_createReview(ec.ctx, arg0, arg1)
if err != nil {
ec.Error(err)
return
}
res, err := ec.resolvers.Mutation_createReview(ec.ctx, arg0, arg1)
if err != nil {
ec.Error(err)
continue
}

if res == nil {
out.Values[i] = jsonw.Null
} else {
out.Values[i] = ec._review(field.Selections, res)
}
}(i, field)
if res == nil {
out.Values[i] = jsonw.Null
} else {
out.Values[i] = ec._review(field.Selections, res)
}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
Expand Down
23 changes: 23 additions & 0 deletions example/starwars/starwars_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,29 @@ func TestStarwars(t *testing.T) {
require.Equal(t, "Leia Organa", resp.Droid.FriendsConnection.Edges[2].Node.Name)
})

t.Run("mutations must be run in sequence", func(t *testing.T) {
var resp struct {
A struct{ Time string }
B struct{ Time string }
C struct{ Time string }
}

c.MustPost(`mutation f{
a:createReview(episode: NEWHOPE, review:{stars:1, commentary:"Blah blah"}) {
time
}
b:createReview(episode: NEWHOPE, review:{stars:1, commentary:"Blah blah"}) {
time
}
c:createReview(episode: NEWHOPE, review:{stars:1, commentary:"Blah blah"}) {
time
}
}`, &resp)

require.NotEqual(t, resp.A.Time, resp.B.Time)
require.NotEqual(t, resp.C.Time, resp.B.Time)
})

t.Run("introspection", func(t *testing.T) {
// Make sure we can run the graphiql introspection query without errors
c.MustPost(introspection.Query, nil)
Expand Down
40 changes: 16 additions & 24 deletions example/todo/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,13 @@ func (ec *executionContext) _myMutation(sel []query.Selection, it *interface{})
}
arg0 = tmp2
}
ec.wg.Add(1)
go func(i int, field collectedField) {
defer ec.wg.Done()
res, err := ec.resolvers.MyMutation_createTodo(ec.ctx, arg0)
if err != nil {
ec.Error(err)
return
}
res, err := ec.resolvers.MyMutation_createTodo(ec.ctx, arg0)
if err != nil {
ec.Error(err)
continue
}

out.Values[i] = ec._todo(field.Selections, &res)
}(i, field)
out.Values[i] = ec._todo(field.Selections, &res)
case "updateTodo":
var arg0 int
if tmp, ok := field.Args["id"]; ok {
Expand All @@ -131,21 +127,17 @@ func (ec *executionContext) _myMutation(sel []query.Selection, it *interface{})
if tmp, ok := field.Args["changes"]; ok {
arg1 = tmp.(map[string]interface{})
}
ec.wg.Add(1)
go func(i int, field collectedField) {
defer ec.wg.Done()
res, err := ec.resolvers.MyMutation_updateTodo(ec.ctx, arg0, arg1)
if err != nil {
ec.Error(err)
return
}
res, err := ec.resolvers.MyMutation_updateTodo(ec.ctx, arg0, arg1)
if err != nil {
ec.Error(err)
continue
}

if res == nil {
out.Values[i] = jsonw.Null
} else {
out.Values[i] = ec._todo(field.Selections, res)
}
}(i, field)
if res == nil {
out.Values[i] = jsonw.Null
} else {
out.Values[i] = ec._todo(field.Selections, res)
}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
Expand Down
14 changes: 8 additions & 6 deletions extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import (
type extractor struct {
Errors []string
PackageName string
Objects []object
Interfaces []object
Objects []*object
Interfaces []*object
goTypeMap map[string]string
Imports map[string]string // local -> full path
schema *schema.Schema
Expand All @@ -32,7 +32,7 @@ func (e *extractor) extract() {
for _, typ := range e.schema.Types {
switch typ := typ.(type) {
case *schema.Object:
obj := object{
obj := &object{
Name: typ.Name,
Type: e.getType(typ.Name),
}
Expand All @@ -54,18 +54,19 @@ func (e *extractor) extract() {
GraphQLName: field.Name,
Type: e.buildType(field.Type),
Args: args,
Object: obj,
})
}
e.Objects = append(e.Objects, obj)
case *schema.Union:
obj := object{
obj := &object{
Name: typ.Name,
Type: e.buildType(typ),
}
e.Interfaces = append(e.Interfaces, obj)

case *schema.Interface:
obj := object{
obj := &object{
Name: typ.Name,
Type: e.buildType(typ),
}
Expand All @@ -82,6 +83,7 @@ func (e *extractor) extract() {
}
if name == "mutation" {
e.MutationRoot = obj.Name
e.GetObject(obj.Name).DisableConcurrency = true
}
}

Expand Down Expand Up @@ -316,7 +318,7 @@ func (e *extractor) modifiersFromGoType(t types.Type) []string {
}
}

func (e *extractor) findBindTargets(t types.Type, object object) bool {
func (e *extractor) findBindTargets(t types.Type, object *object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
Expand Down
2 changes: 1 addition & 1 deletion jsonw/jsonw.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ func Bool(b bool) Writer {

func Time(t time.Time) Writer {
return writerFunc(func(w io.Writer) {
io.WriteString(w, t.Format(time.RFC3339))
io.WriteString(w, strconv.Quote(t.Format(time.RFC3339)))
})
}
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ func main() {
GraphQLName: "__schema",
NoErr: true,
MethodName: "ec.introspectSchema",
Object: q,
})
q.Fields = append(q.Fields, Field{
Type: e.getType("__Type").Ptr(),
GraphQLName: "__type",
NoErr: true,
MethodName: "ec.introspectType",
Args: []FieldArgument{{Name: "name", Type: kind{Scalar: true, Name: "string"}}},
Object: q,
})

if len(e.Errors) != 0 {
Expand Down
2 changes: 1 addition & 1 deletion templates/file.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
type Resolvers interface {
{{- range $object := .Objects -}}
{{ range $field := $object.Fields -}}
{{ $field.ResolverDeclaration $object }}
{{ $field.ResolverDeclaration }}
{{ end }}
{{- end }}
}
Expand Down
14 changes: 7 additions & 7 deletions templates/object.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Se
case "{{$field.GraphQLName}}":
{{- template "args" $field.Args }}

{{- if $field.IsResolver }}
{{- if $field.IsConcurrent }}
ec.wg.Add(1)
go func(i int, field collectedField) {
defer ec.wg.Done()
Expand All @@ -26,25 +26,25 @@ func (ec *executionContext) _{{$object.Type.GraphQLName|lcFirst}}(sel []query.Se
res := {{$field.VarName}}
{{- else if $field.MethodName }}
{{- if $field.NoErr }}
res := {{$field.MethodName}}({{ $field.CallArgs $object }})
res := {{$field.MethodName}}({{ $field.CallArgs }})
{{- else }}
res, err := {{$field.MethodName}}({{ $field.CallArgs $object }})
res, err := {{$field.MethodName}}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
continue
{{ if $field.IsConcurrent }}return{{ else }}continue{{end}}
}
{{- end }}
{{- else }}
res, err := ec.resolvers.{{ $object.Name }}_{{ $field.GraphQLName }}({{ $field.CallArgs $object }})
res, err := ec.resolvers.{{ $object.Name }}_{{ $field.GraphQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
return
{{ if $field.IsConcurrent }}return{{ else }}continue{{end}}
}
{{- end }}

{{ $field.WriteJson "out.Values[i]" }}

{{- if $field.IsResolver }}
{{- if $field.IsConcurrent }}
}(i, field)
{{- end }}
{{- end }}
Expand Down
30 changes: 18 additions & 12 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ func (t kind) FullName() string {
}

type object struct {
Name string
Fields []Field
Type kind
satisfies []string
Root bool
Name string
Fields []Field
Type kind
satisfies []string
Root bool
DisableConcurrency bool
}

type Field struct {
Expand All @@ -84,20 +85,25 @@ type Field struct {
Type kind
Args []FieldArgument
NoErr bool
Object *object
}

func (f *Field) IsResolver() bool {
return f.MethodName == "" && f.VarName == ""
}

func (f *Field) ResolverDeclaration(o object) string {
func (f *Field) IsConcurrent() bool {
return f.IsResolver() && !f.Object.DisableConcurrency
}

func (f *Field) ResolverDeclaration() string {
if !f.IsResolver() {
return ""
}
res := fmt.Sprintf("%s_%s(ctx context.Context", o.Name, f.GraphQLName)
res := fmt.Sprintf("%s_%s(ctx context.Context", f.Object.Name, f.GraphQLName)

if !o.Root {
res += fmt.Sprintf(", it *%s", o.Type.Local())
if !f.Object.Root {
res += fmt.Sprintf(", it *%s", f.Object.Type.Local())
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.Name, arg.Type.Local())
Expand All @@ -107,13 +113,13 @@ func (f *Field) ResolverDeclaration(o object) string {
return res
}

func (f *Field) CallArgs(object object) string {
func (f *Field) CallArgs() string {
var args []string

if f.MethodName == "" {
args = append(args, "ec.ctx")

if !object.Root {
if !f.Object.Root {
args = append(args, "it")
}
}
Expand Down Expand Up @@ -202,7 +208,7 @@ func (o *object) Implementors() string {
func (e *extractor) GetObject(name string) *object {
for i, o := range e.Objects {
if strings.EqualFold(o.Name, name) {
return &e.Objects[i]
return e.Objects[i]
}
}
return nil
Expand Down

0 comments on commit 75a3a05

Please sign in to comment.