Skip to content

Commit

Permalink
spec-compliant @defer support on fragments.
Browse files Browse the repository at this point in the history
  • Loading branch information
fiatjaf committed Apr 22, 2023
1 parent 112e332 commit a2360d1
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 56 deletions.
47 changes: 33 additions & 14 deletions codegen/generated!.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
}

func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e, 0, nil}
ec := executionContext{nil, e, nil, 0, nil}
_ = ec
{{ if not .Config.OmitComplexity -}}
switch typeName + "." + field {
Expand Down Expand Up @@ -138,7 +138,7 @@

func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e, 0, make(chan graphql.DeferredResult)}
ec := executionContext{rc, e, nil, 0, make(chan graphql.DeferredResult)}
inputUnmarshalMap := graphql.BuildUnmarshalerMap(
{{- range $input := .Inputs -}}
{{ if not $input.HasUnmarshal }}
Expand All @@ -151,8 +151,8 @@
switch rc.Operation.Operation {
{{- if .QueryRoot }} case ast.Query:
return func(ctx context.Context) *graphql.Response {
var response graphql.Response
var data graphql.Marshaler
var path ast.Path
if first {
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
Expand All @@ -165,22 +165,40 @@
{{- end }}
} else {
if ec.pendingDeferred > 0 {
res := <-ec.deferredResults
path = res.Path
data = res.Result
ec.pendingDeferred--
result := <-ec.deferredResults
ec.pendingDeferred--
data = result.Result
response.Path = result.Path
response.Label = result.Label
} else {
return nil
}
}
var buf bytes.Buffer
data.MarshalGQL(&buf)
response.Data = buf.Bytes()
response.HasNext = ec.pendingDeferred+len(ec.deferredGroups) > 0

return &graphql.Response{
Data: buf.Bytes(),
Path: path,
HasNext: ec.pendingDeferred > 0,
}
// dispatch deferred calls
dg := ec.deferredGroups
ec.deferredGroups = nil
for _, deferred := range dg {
go func (deferred graphql.DeferredGroup) {
ec.pendingDeferred++
deferred.FieldSet.Dispatch()
ds := graphql.DeferredResult{
Path: deferred.Path,
Label: deferred.Label,
Result: deferred.FieldSet,
}
if deferred.FieldSet.Invalids > 0 {
ds.Result = graphql.Null
}
ec.deferredResults <- ds
}(deferred)
}

return &response
}
{{ end }}

Expand Down Expand Up @@ -237,8 +255,9 @@
type executionContext struct {
*graphql.OperationContext
*executableSchema
pendingDeferred int
deferredResults chan graphql.DeferredResult
deferredGroups []graphql.DeferredGroup
pendingDeferred int
deferredResults chan graphql.DeferredResult
}

func (ec *executionContext) introspectSchema() (*introspection.Schema, error) {
Expand Down
80 changes: 46 additions & 34 deletions codegen/object.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
Object: {{$object.Name|quote}},
})
{{end}}

out := graphql.NewFieldSet(fields)
var invalids uint32
deferred := make(map[string]*graphql.FieldSet) // deferred-labels=>fieldsets
for i, field := range fields {
{{- if $object.Root }}
innerCtx := graphql.WithRootFieldContext(ctx, &graphql.RootFieldContext{
Expand All @@ -47,28 +48,7 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
{{- if $field.IsConcurrent }}
field := field

{{if not $object.Root}}
deferred := false
for _, dir := range field.Directives {
if dir.Name == "defer" {
out.Values[i] = graphql.Null
deferred = true
ec.pendingDeferred++
go func(ctx context.Context, field graphql.CollectedField) {
deferredOut := graphql.NewFieldSet([]graphql.CollectedField{field})
deferredOut.Values[0] = ec._{{$object.Name}}_{{$field.Name}}(ctx, field, obj)
ec.deferredResults <- graphql.DeferredResult{
Result: deferredOut,
Path: graphql.GetPath(ctx),
}
}(ctx, field)
break
}
}
if deferred { break }
{{end}}

innerFunc := func(ctx context.Context) (res graphql.Marshaler) {
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
Expand All @@ -78,9 +58,9 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
{{- if $field.TypeReference.GQL.NonNull }}
if res == graphql.Null {
{{- if $object.IsConcurrent }}
atomic.AddUint32(&invalids, 1)
atomic.AddUint32(&fs.Invalids, 1)
{{- else }}
invalids++
fs.Invalids++
{{- end }}
}
{{- end }}
Expand All @@ -89,32 +69,54 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec

{{if $object.Root}}
rrm := func(ctx context.Context) graphql.Marshaler {
return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc)
return ec.OperationContext.RootResolverMiddleware(ctx,
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
}
{{end}}

{{if not $object.Root}}
if field.Deferrable != nil {
dfs, ok := deferred[field.Deferrable.Label]
di := 0
if ok {
dfs.AddField(field)
di = len(dfs.Values) - 1
} else {
dfs = graphql.NewFieldSet([]graphql.CollectedField{field})
deferred[field.Deferrable.Label] = dfs
}
dfs.Concurrently(di, func() graphql.Marshaler {
return innerFunc(ctx, dfs)
})

// don't run the out.Concurrently() call below
out.Values[i] = graphql.Null
continue
}
{{end}}

out.Concurrently(i, func() graphql.Marshaler {
{{- if $object.Root -}}
return rrm(innerCtx)
{{- else -}}
return innerFunc(ctx)
{{end}}
return innerFunc(ctx, out)
{{- end -}}
})
{{- else }}
{{if $object.Root}}
{{- if $object.Root -}}
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
return ec._{{$object.Name}}_{{$field.Name}}(ctx, field)
})
{{else}}
{{- else -}}
out.Values[i] = ec._{{$object.Name}}_{{$field.Name}}(ctx, field, obj)
{{end}}
{{- end -}}

{{- if $field.TypeReference.GQL.NonNull }}
if out.Values[i] == graphql.Null {
{{- if $object.IsConcurrent }}
atomic.AddUint32(&invalids, 1)
atomic.AddUint32(&out.Invalids, 1)
{{- else }}
invalids++
out.Invalids++
{{- end }}
}
{{- end }}
Expand All @@ -125,7 +127,17 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
}
}
out.Dispatch()
if invalids > 0 { return graphql.Null }
if out.Invalids > 0 { return graphql.Null }

// assign deferred groups to main executionContext
for label, dfs := range deferred {
ec.deferredGroups = append(ec.deferredGroups, graphql.DeferredGroup{
Label: label,
Path: graphql.GetPath(ctx),
FieldSet: dfs,
})
}

return out
}
{{- end }}
Expand Down
11 changes: 11 additions & 0 deletions graphql/deferred.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@ package graphql

import "github.com/vektah/gqlparser/v2/ast"

type Deferrable struct {
Label string
}

type DeferredGroup struct {
Path ast.Path
Label string
FieldSet *FieldSet
}

type DeferredResult struct {
Path ast.Path
Label string
Result Marshaler
}
46 changes: 44 additions & 2 deletions graphql/executable_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,17 @@ func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies
if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
continue
}

shouldDefer, label := deferrable(sel.Directives, reqCtx.Variables)

for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField })
f := getOrCreateAndAppendField(
&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition,
func() CollectedField { return childField })
f.Selections = append(f.Selections, childField.Selections...)
if shouldDefer {
f.Deferrable = &Deferrable{Label: label}
}
}

case *ast.FragmentSpread:
Expand All @@ -70,9 +78,16 @@ func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies
continue
}

shouldDefer, label := deferrable(sel.Directives, reqCtx.Variables)

for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField })
f := getOrCreateAndAppendField(&groupedFields,
childField.Name, childField.Alias, childField.ObjectDefinition,
func() CollectedField { return childField })
f.Selections = append(f.Selections, childField.Selections...)
if shouldDefer {
f.Deferrable = &Deferrable{Label: label}
}
}

default:
Expand All @@ -87,6 +102,7 @@ type CollectedField struct {
*ast.Field

Selections ast.SelectionSet
Deferrable *Deferrable
}

func instanceOf(val string, satisfies []string) bool {
Expand Down Expand Up @@ -150,6 +166,32 @@ func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interf
return !skip && include
}

func deferrable(directives ast.DirectiveList, variables map[string]interface{}) (shouldDefer bool, label string) {
d := directives.ForName("defer")
if d == nil {
return false, ""
}

shouldDefer = true

for _, arg := range d.Arguments {
switch arg.Name {
case "if":
if value, err := arg.Value.Value(variables); err == nil {
shouldDefer, _ = value.(bool)
}
case "label":
if value, err := arg.Value.Value(variables); err == nil {
label, _ = value.(string)
}
default:
panic(fmt.Sprintf("defer: argument '%s' not supported", arg.Name))
}
}

return shouldDefer, label
}

func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
arg := d.Arguments.ForName("if")
if arg == nil {
Expand Down
14 changes: 10 additions & 4 deletions graphql/fieldset.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (
)

type FieldSet struct {
fields []CollectedField
Values []Marshaler
delayed []delayedResult
fields []CollectedField
Values []Marshaler
Invalids uint32
delayed []delayedResult
}

type delayedResult struct {
Expand All @@ -24,6 +25,11 @@ func NewFieldSet(fields []CollectedField) *FieldSet {
}
}

func (m *FieldSet) AddField(field CollectedField) {
m.fields = append(m.fields, field)
m.Values = append(m.Values, nil)
}

func (m *FieldSet) Concurrently(i int, f func() Marshaler) {
m.delayed = append(m.delayed, delayedResult{i: i, f: f})
}
Expand Down Expand Up @@ -58,7 +64,7 @@ func (m *FieldSet) MarshalGQL(writer io.Writer) {
}
writeQuotedString(writer, field.Alias)
writer.Write(colon)
fmt.Println(m.fields[i].Name, "=>", m.Values[i])
fmt.Println(i, m.fields[i].Name, "=>", m.Values[i])
m.Values[i].MarshalGQL(writer)
}
writer.Write(closeBrace)
Expand Down
5 changes: 3 additions & 2 deletions graphql/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import (
type Response struct {
Errors gqlerror.List `json:"errors,omitempty"`
Data json.RawMessage `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Label string `json:"label,omitempty"`
Path ast.Path `json:"path,omitempty"`
HasNext bool `json:"hasNext"`
Path ast.Path `json:"path"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
}

func ErrorResponse(ctx context.Context, messagef string, args ...interface{}) *Response {
Expand Down

0 comments on commit a2360d1

Please sign in to comment.