diff --git a/codegen/args.go b/codegen/args.go index e8c1bfeeb9c..c684dbf0d74 100644 --- a/codegen/args.go +++ b/codegen/args.go @@ -55,7 +55,7 @@ func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgu argDirs, err := b.getDirectives(arg.Directives) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: %w", arg.Name, err) } newArg := FieldArgument{ ArgumentDefinition: arg, diff --git a/codegen/concurrent.go b/codegen/concurrent.go new file mode 100644 index 00000000000..947aebbd8ea --- /dev/null +++ b/codegen/concurrent.go @@ -0,0 +1,50 @@ +package codegen + +import "github.com/vektah/gqlparser/v2/ast" + +const concurrentDirectiveName = "concurrent" + +func makeConcurrentObjectAndField(obj *Object, f *Field) { + var hasConcurrentDirective bool + for _, dir := range obj.Directives { + if dir.Name == concurrentDirectiveName { + hasConcurrentDirective = true + break + } + } + + if !hasConcurrentDirective { + obj.Directives = append(obj.Directives, &Directive{ + DirectiveDefinition: &ast.DirectiveDefinition{ + Name: concurrentDirectiveName, + }, + Name: concurrentDirectiveName, + Builtin: true, + }) + obj.DisableConcurrency = false + } + + if obj.Definition != nil && obj.Definition.Directives.ForName(concurrentDirectiveName) == nil { + obj.Definition.Directives = append(obj.Definition.Directives, &ast.Directive{ + Name: concurrentDirectiveName, + Definition: &ast.DirectiveDefinition{ + Name: concurrentDirectiveName, + }, + }) + } + + if f.TypeReference != nil && f.TypeReference.Definition != nil { + for _, dir := range f.TypeReference.Definition.Directives { + if dir.Name == concurrentDirectiveName { + hasConcurrentDirective = true + break + } + } + + if !hasConcurrentDirective { + f.TypeReference.Definition.Directives = append(f.TypeReference.Definition.Directives, &ast.Directive{ + Name: concurrentDirectiveName, + }) + } + } +} diff --git a/codegen/config/config.go b/codegen/config/config.go index 59c2beb0338..7a900cebcce 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -299,6 +299,10 @@ func (c *Config) injectTypesFromSchema() error { SkipRuntime: true, } + c.Directives["concurrent"] = DirectiveConfig{ + SkipRuntime: true, + } + for _, schemaType := range c.Schema.Types { if c.IsRoot(schemaType) { continue diff --git a/codegen/data.go b/codegen/data.go index 7110de2f925..d53a49ad716 100644 --- a/codegen/data.go +++ b/codegen/data.go @@ -1,6 +1,7 @@ package codegen import ( + "container/list" "errors" "fmt" "os" @@ -153,6 +154,8 @@ func BuildData(cfg *config.Config, plugins ...any) (*Data, error) { return nil, err } + handleConcurrent(s.Objects) + s.ReferencedTypes = b.buildTypes() sort.Slice(s.Objects, func(i, j int) bool { @@ -234,3 +237,42 @@ func (b *builder) injectIntrospectionRoots(s *Data) error { return nil } + +func handleConcurrent(objects Objects) { + concurrentObjects := make([]*Object, 0) + for _, obj := range objects { + for _, dir := range obj.Directives { + if dir.Name == concurrentDirectiveName { + concurrentObjects = append(concurrentObjects, obj) + break + } + } + } + + queue := list.New() + for _, obj := range concurrentObjects { + queue.PushBack(obj) + } + + concurrentObjectsCache := make(map[string]struct{}, 0) + + for queue.Len() > 0 { + v := queue.Front() + concurrentObject := v.Value.(*Object) + for _, obj := range objects { + if _, ok := concurrentObjectsCache[obj.Name]; ok { + continue + } + + for _, f := range obj.Fields { + if f.TypeReference.Definition == concurrentObject.Definition { + makeConcurrentObjectAndField(obj, f) + + queue.PushBack(obj) + concurrentObjectsCache[obj.Name] = struct{}{} + } + } + } + queue.Remove(v) + } +} diff --git a/codegen/field.go b/codegen/field.go index 883f57a8c83..ae5c5ee4803 100644 --- a/codegen/field.go +++ b/codegen/field.go @@ -40,7 +40,7 @@ type Field struct { func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { dirs, err := b.getDirectives(field.Directives) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: %w", field.Name, err) } f := Field{ @@ -95,7 +95,7 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) { if f.TypeReference != nil { dirs, err := b.getDirectives(f.TypeReference.Definition.Directives) if err != nil { - errret = err + errret = fmt.Errorf("%s: %w", f.Name, err) } for _, dir := range obj.Directives { if dir.IsLocation(ast.LocationInputObject) { @@ -137,6 +137,7 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) { return nil case b.Config.Models[obj.Name].Fields[f.Name].Resolver: f.IsResolver = true + makeConcurrentObjectAndField(obj, f) return nil case obj.Type == config.MapType: f.GoFieldType = GoFieldMap diff --git a/codegen/object.go b/codegen/object.go index 1b780bd0c1d..dc525514ae8 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -44,9 +44,10 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) { } caser := cases.Title(language.English, cases.NoLower) obj := &Object{ - Definition: typ, - Root: b.Config.IsRoot(typ), - DisableConcurrency: typ == b.Schema.Mutation || typ.Directives.ForName("concurrent") == nil, + Definition: typ, + Root: b.Config.IsRoot(typ), + DisableConcurrency: typ == b.Schema.Mutation || + typ.Directives.ForName(concurrentDirectiveName) == nil, Stream: typ == b.Schema.Subscription, Directives: dirs, PointersInUnmarshalInput: b.Config.ReturnPointersInUnmarshalInput,