diff --git a/codegen/config/binder.go b/codegen/config/binder.go index cea904ada33..72956de4d1e 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -14,7 +14,7 @@ import ( // Binder connects graphql types to golang types using static analysis type Binder struct { - pkgs []*packages.Package + pkgs map[string]*packages.Package schema *ast.Schema cfg *Config References []*TypeReference @@ -26,7 +26,9 @@ func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) { return nil, err } + mp := map[string]*packages.Package{} for _, p := range pkgs { + populatePkg(mp, p) for _, e := range p.Errors { if e.Kind == packages.ListError { return nil, p.Errors[0] @@ -35,12 +37,23 @@ func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) { } return &Binder{ - pkgs: pkgs, + pkgs: mp, schema: s, cfg: c, }, nil } +func populatePkg(mp map[string]*packages.Package, p *packages.Package) { + imp := code.NormalizeVendor(p.PkgPath) + if _, ok := mp[imp]; ok { + return + } + mp[imp] = p + for _, p := range p.Imports { + populatePkg(mp, p) + } +} + func (b *Binder) TypePosition(typ types.Type) token.Position { named, isNamed := typ.(*types.Named) if !isNamed { @@ -75,10 +88,9 @@ func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) { } func (b *Binder) getPkg(find string) *packages.Package { - for _, p := range b.pkgs { - if code.NormalizeVendor(find) == code.NormalizeVendor(p.PkgPath) { - return p - } + imp := code.NormalizeVendor(find) + if p, ok := b.pkgs[imp]; ok { + return p } return nil } diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index d7a074445a8..3d5f4b6c6fe 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -221,6 +221,8 @@ type ComplexityRoot struct { User func(childComplexity int, id int) int Valid func(childComplexity int) int ValidType func(childComplexity int) int + WrappedScalar func(childComplexity int) int + WrappedStruct func(childComplexity int) int } Rectangle struct { @@ -255,6 +257,10 @@ type ComplexityRoot struct { ValidInputKeywords func(childComplexity int, input *ValidInput) int } + WrappedStruct struct { + Name func(childComplexity int) int + } + XXIt struct { ID func(childComplexity int) int } @@ -335,6 +341,8 @@ type QueryResolver interface { Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) + WrappedStruct(ctx context.Context) (*WrappedStruct, error) + WrappedScalar(ctx context.Context) (WrappedScalar, error) } type SubscriptionResolver interface { Updated(ctx context.Context) (<-chan string, error) @@ -1013,6 +1021,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.ValidType(childComplexity), true + case "Query.wrappedScalar": + if e.complexity.Query.WrappedScalar == nil { + break + } + + return e.complexity.Query.WrappedScalar(childComplexity), true + + case "Query.wrappedStruct": + if e.complexity.Query.WrappedStruct == nil { + break + } + + return e.complexity.Query.WrappedStruct(childComplexity), true + case "Rectangle.area": if e.complexity.Rectangle.Area == nil { break @@ -1142,6 +1164,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ValidType.ValidInputKeywords(childComplexity, args["input"].(*ValidInput)), true + case "WrappedStruct.name": + if e.complexity.WrappedStruct.Name == nil { + break + } + + return e.complexity.WrappedStruct.Name(childComplexity), true + case "XXIt.id": if e.complexity.XXIt.ID == nil { break @@ -1659,6 +1688,16 @@ type AIt { id: ID! } type XXIt { id: ID! } type AbIt { id: ID! } type XxIt { id: ID! } +`}, + &ast.Source{Name: "wrapped_type.graphql", Input: `# regression test for https://github.com/99designs/gqlgen/issues/721 + +extend type Query { + wrappedStruct: WrappedStruct! + wrappedScalar: WrappedScalar! +} + +type WrappedStruct { name: String! } +scalar WrappedScalar `}, ) @@ -4561,6 +4600,60 @@ func (ec *executionContext) _Query_validType(ctx context.Context, field graphql. return ec.marshalOValidType2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐValidType(ctx, field.Selections, res) } +func (ec *executionContext) _Query_wrappedStruct(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().WrappedStruct(rctx) + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*WrappedStruct) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNWrappedStruct2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedStruct(ctx, field.Selections, res) +} + +func (ec *executionContext) _Query_wrappedScalar(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().WrappedScalar(rctx) + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(WrappedScalar) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx, field.Selections, res) +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -5073,6 +5166,33 @@ func (ec *executionContext) _ValidType_validArgs(ctx context.Context, field grap return ec.marshalNBoolean2bool(ctx, field.Selections, res) } +func (ec *executionContext) _WrappedStruct_name(ctx context.Context, field graphql.CollectedField, obj *WrappedStruct) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "WrappedStruct", + Field: field, + Args: nil, + IsMethod: false, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Name, nil + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNString2string(ctx, field.Selections, res) +} + func (ec *executionContext) _XXIt_id(ctx context.Context, field graphql.CollectedField, obj *XXIt) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -7730,6 +7850,34 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr res = ec._Query_validType(ctx, field) return res }) + case "wrappedStruct": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_wrappedStruct(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) + case "wrappedScalar": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_wrappedScalar(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "__type": out.Values[i] = ec._Query___type(ctx, field) case "__schema": @@ -7921,6 +8069,33 @@ func (ec *executionContext) _ValidType(ctx context.Context, sel ast.SelectionSet return out } +var wrappedStructImplementors = []string{"WrappedStruct"} + +func (ec *executionContext) _WrappedStruct(ctx context.Context, sel ast.SelectionSet, obj *WrappedStruct) graphql.Marshaler { + fields := graphql.CollectFields(ec.RequestContext, sel, wrappedStructImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("WrappedStruct") + case "name": + out.Values[i] = ec._WrappedStruct_name(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var xXItImplementors = []string{"XXIt"} func (ec *executionContext) _XXIt(ctx context.Context, sel ast.SelectionSet, obj *XXIt) graphql.Marshaler { @@ -8768,6 +8943,35 @@ func (ec *executionContext) marshalNUser2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋ return ec._User(ctx, sel, v) } +func (ec *executionContext) unmarshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, v interface{}) (WrappedScalar, error) { + tmp, err := graphql.UnmarshalString(v) + return WrappedScalar(tmp), err +} + +func (ec *executionContext) marshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, sel ast.SelectionSet, v WrappedScalar) graphql.Marshaler { + res := graphql.MarshalString(string(v)) + if res == graphql.Null { + if !ec.HasError(graphql.GetResolverContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + } + return res +} + +func (ec *executionContext) marshalNWrappedStruct2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedStruct(ctx context.Context, sel ast.SelectionSet, v WrappedStruct) graphql.Marshaler { + return ec._WrappedStruct(ctx, sel, &v) +} + +func (ec *executionContext) marshalNWrappedStruct2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedStruct(ctx context.Context, sel ast.SelectionSet, v *WrappedStruct) graphql.Marshaler { + if v == nil { + if !ec.HasError(graphql.GetResolverContext(ctx)) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + return ec._WrappedStruct(ctx, sel, v) +} + func (ec *executionContext) marshalN__Directive2githubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚋintrospectionᚐDirective(ctx context.Context, sel ast.SelectionSet, v introspection.Directive) graphql.Marshaler { return ec.___Directive(ctx, sel, &v) } diff --git a/codegen/testserver/gqlgen.yml b/codegen/testserver/gqlgen.yml index 411e8277976..4a8292fb5bd 100644 --- a/codegen/testserver/gqlgen.yml +++ b/codegen/testserver/gqlgen.yml @@ -76,3 +76,7 @@ models: model: "github.com/99designs/gqlgen/codegen/testserver.FallbackToStringEncoding" Bytes: model: "github.com/99designs/gqlgen/codegen/testserver.Bytes" + WrappedStruct: + model: "github.com/99designs/gqlgen/codegen/testserver.WrappedStruct" + WrappedScalar: + model: "github.com/99designs/gqlgen/codegen/testserver.WrappedScalar" diff --git a/codegen/testserver/otherpkg/model.go b/codegen/testserver/otherpkg/model.go new file mode 100644 index 00000000000..b3f856b6e57 --- /dev/null +++ b/codegen/testserver/otherpkg/model.go @@ -0,0 +1,7 @@ +package otherpkg + +type Scalar string + +type Struct struct { + Name string +} diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index 012b2103850..e82ec74efae 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -203,6 +203,12 @@ func (r *queryResolver) OptionalUnion(ctx context.Context) (TestUnion, error) { func (r *queryResolver) ValidType(ctx context.Context) (*ValidType, error) { panic("not implemented") } +func (r *queryResolver) WrappedStruct(ctx context.Context) (*WrappedStruct, error) { + panic("not implemented") +} +func (r *queryResolver) WrappedScalar(ctx context.Context) (WrappedScalar, error) { + panic("not implemented") +} type subscriptionResolver struct{ *Resolver } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index ee69e0c7b37..2537853f499 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -71,6 +71,8 @@ type Stub struct { Fallback func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) + WrappedStruct func(ctx context.Context) (*WrappedStruct, error) + WrappedScalar func(ctx context.Context) (WrappedScalar, error) } SubscriptionResolver struct { Updated func(ctx context.Context) (<-chan string, error) @@ -273,6 +275,12 @@ func (r *stubQuery) OptionalUnion(ctx context.Context) (TestUnion, error) { func (r *stubQuery) ValidType(ctx context.Context) (*ValidType, error) { return r.QueryResolver.ValidType(ctx) } +func (r *stubQuery) WrappedStruct(ctx context.Context) (*WrappedStruct, error) { + return r.QueryResolver.WrappedStruct(ctx) +} +func (r *stubQuery) WrappedScalar(ctx context.Context) (WrappedScalar, error) { + return r.QueryResolver.WrappedScalar(ctx) +} type stubSubscription struct{ *Stub } diff --git a/codegen/testserver/wrapped_type.go b/codegen/testserver/wrapped_type.go new file mode 100644 index 00000000000..d9d318f403b --- /dev/null +++ b/codegen/testserver/wrapped_type.go @@ -0,0 +1,6 @@ +package testserver + +import "github.com/99designs/gqlgen/codegen/testserver/otherpkg" + +type WrappedScalar otherpkg.Scalar +type WrappedStruct otherpkg.Struct diff --git a/codegen/testserver/wrapped_type.graphql b/codegen/testserver/wrapped_type.graphql new file mode 100644 index 00000000000..f50e334e4e2 --- /dev/null +++ b/codegen/testserver/wrapped_type.graphql @@ -0,0 +1,9 @@ +# regression test for https://github.com/99designs/gqlgen/issues/721 + +extend type Query { + wrappedStruct: WrappedStruct! + wrappedScalar: WrappedScalar! +} + +type WrappedStruct { name: String! } +scalar WrappedScalar diff --git a/codegen/testserver/wrapped_type_test.go b/codegen/testserver/wrapped_type_test.go new file mode 100644 index 00000000000..0842fa3ba03 --- /dev/null +++ b/codegen/testserver/wrapped_type_test.go @@ -0,0 +1,54 @@ +package testserver + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/codegen/testserver/otherpkg" + "github.com/99designs/gqlgen/handler" + "github.com/stretchr/testify/require" +) + +func TestWrappedTypes(t *testing.T) { + resolvers := &Stub{} + + srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) + c := client.New(srv.URL) + + resolvers.QueryResolver.WrappedScalar = func(ctx context.Context) (scalar WrappedScalar, e error) { + return WrappedScalar("hello"), nil + } + + resolvers.QueryResolver.WrappedStruct = func(ctx context.Context) (wrappedStruct *WrappedStruct, e error) { + wrapped := WrappedStruct(otherpkg.Struct{ + Name: "hello", + }) + return &wrapped, nil + } + + t.Run("wrapped struct", func(t *testing.T) { + var resp struct { + WrappedStruct struct { + Name string + } + } + + err := c.Post(`query { wrappedStruct { name } }`, &resp) + require.NoError(t, err) + + require.Equal(t, "hello", resp.WrappedStruct.Name) + }) + + t.Run("wrapped scalar", func(t *testing.T) { + var resp struct { + WrappedScalar string + } + + err := c.Post(`query { wrappedScalar }`, &resp) + require.NoError(t, err) + + require.Equal(t, "hello", resp.WrappedScalar) + }) +}