From d518754423e9ab6740a37ce7354f78048976e863 Mon Sep 17 00:00:00 2001 From: Carl Dunham Date: Fri, 22 Oct 2021 14:05:54 -0700 Subject: [PATCH 1/2] Support for multiple @key directives in federation --- .../accounts/graph/generated/federation.go | 38 +++-- .../products/graph/entity.resolvers.go | 9 + .../products/graph/generated/federation.go | 53 ++++-- .../products/graph/generated/generated.go | 97 ++++++++++- .../federation/products/graph/schema.graphqls | 2 +- .../reviews/graph/generated/federation.go | 44 +++-- plugin/federation/federation.go | 161 +++++++++++------- plugin/federation/federation.gotpl | 34 ++-- plugin/federation/federation_test.go | 73 ++++---- plugin/federation/test_data/schema.graphql | 2 +- 10 files changed, 350 insertions(+), 163 deletions(-) diff --git a/example/federation/accounts/graph/generated/federation.go b/example/federation/accounts/graph/generated/federation.go index 1d7ee7ca52a..b987aa8909a 100644 --- a/example/federation/accounts/graph/generated/federation.go +++ b/example/federation/accounts/graph/generated/federation.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + "github.com/99designs/gqlgen/example/federation/accounts/graph/model" "github.com/99designs/gqlgen/plugin/federation/fedruntime" ) @@ -47,32 +48,39 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati return errors.New("__typename must be an existing string") } switch typeName { - case "EmailHost": - id0, err := ec.unmarshalNString2string(ctx, rep["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "id")) - } + entity, err := func() (*model.EmailHost, error) { + id0, err := ec.unmarshalNString2string(ctx, rep["id"]) + if err == nil { + return ec.resolvers.Entity().FindEmailHostByID(ctx, id0) + } + return nil, nil + }() - entity, err := ec.resolvers.Entity().FindEmailHostByID(ctx, - id0) if err != nil { - return err + return fmt.Errorf(`resolving Entity "EmailHost": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "EmailHost"`) } list[i] = entity return nil case "User": - id0, err := ec.unmarshalNID2string(ctx, rep["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "id")) - } + entity, err := func() (*model.User, error) { + id0, err := ec.unmarshalNID2string(ctx, rep["id"]) + if err == nil { + return ec.resolvers.Entity().FindUserByID(ctx, id0) + } + return nil, nil + }() - entity, err := ec.resolvers.Entity().FindUserByID(ctx, - id0) if err != nil { - return err + return fmt.Errorf(`resolving Entity "User": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "User"`) } list[i] = entity diff --git a/example/federation/products/graph/entity.resolvers.go b/example/federation/products/graph/entity.resolvers.go index 1af00473615..29e682ef570 100644 --- a/example/federation/products/graph/entity.resolvers.go +++ b/example/federation/products/graph/entity.resolvers.go @@ -26,6 +26,15 @@ func (r *entityResolver) FindProductByManufacturerIDAndID(ctx context.Context, m return nil, nil } +func (r *entityResolver) FindProductByUpc(ctx context.Context, upc string) (*model.Product, error) { + for _, hat := range hats { + if hat.Upc == upc { + return hat, nil + } + } + return nil, nil +} + // Entity returns generated.EntityResolver implementation. func (r *Resolver) Entity() generated.EntityResolver { return &entityResolver{r} } diff --git a/example/federation/products/graph/generated/federation.go b/example/federation/products/graph/generated/federation.go index 410d467a82d..c0dd5b58b58 100644 --- a/example/federation/products/graph/generated/federation.go +++ b/example/federation/products/graph/generated/federation.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + "github.com/99designs/gqlgen/example/federation/products/graph/model" "github.com/99designs/gqlgen/plugin/federation/fedruntime" ) @@ -47,36 +48,52 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati return errors.New("__typename must be an existing string") } switch typeName { - case "Manufacturer": - id0, err := ec.unmarshalNString2string(ctx, rep["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "id")) - } + entity, err := func() (*model.Manufacturer, error) { + id0, err := ec.unmarshalNString2string(ctx, rep["id"]) + if err == nil { + return ec.resolvers.Entity().FindManufacturerByID(ctx, id0) + } + return nil, nil + }() - entity, err := ec.resolvers.Entity().FindManufacturerByID(ctx, - id0) if err != nil { - return err + return fmt.Errorf(`resolving Entity "Manufacturer": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "Manufacturer"`) } list[i] = entity return nil case "Product": - id0, err := ec.unmarshalNString2string(ctx, rep["manufacturer"].(map[string]interface{})["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "manufacturerID")) - } - id1, err := ec.unmarshalNString2string(ctx, rep["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "id")) + entity, err := func() (*model.Product, error) { + id0, err := ec.unmarshalNString2string(ctx, rep["manufacturer"].(map[string]interface{})["id"]) + if err == nil { + id1, err := ec.unmarshalNString2string(ctx, rep["id"]) + if err == nil { + return ec.resolvers.Entity().FindProductByManufacturerIDAndID(ctx, id0, id1) + } + } + return nil, nil + }() + + if entity == nil { + entity, err = func() (*model.Product, error) { + id0, err := ec.unmarshalNString2string(ctx, rep["upc"]) + if err == nil { + return ec.resolvers.Entity().FindProductByUpc(ctx, id0) + } + return nil, nil + }() } - entity, err := ec.resolvers.Entity().FindProductByManufacturerIDAndID(ctx, - id0, id1) if err != nil { - return err + return fmt.Errorf(`resolving Entity "Product": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "Product"`) } list[i] = entity diff --git a/example/federation/products/graph/generated/generated.go b/example/federation/products/graph/generated/generated.go index db9b6e3a572..1a5957f5408 100644 --- a/example/federation/products/graph/generated/generated.go +++ b/example/federation/products/graph/generated/generated.go @@ -48,6 +48,7 @@ type ComplexityRoot struct { Entity struct { FindManufacturerByID func(childComplexity int, id string) int FindProductByManufacturerIDAndID func(childComplexity int, manufacturerID string, id string) int + FindProductByUpc func(childComplexity int, upc string) int } Manufacturer struct { @@ -77,6 +78,7 @@ type ComplexityRoot struct { type EntityResolver interface { FindManufacturerByID(ctx context.Context, id string) (*model.Manufacturer, error) FindProductByManufacturerIDAndID(ctx context.Context, manufacturerID string, id string) (*model.Product, error) + FindProductByUpc(ctx context.Context, upc string) (*model.Product, error) } type QueryResolver interface { TopProducts(ctx context.Context, first *int) ([]*model.Product, error) @@ -121,6 +123,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Entity.FindProductByManufacturerIDAndID(childComplexity, args["manufacturerID"].(string), args["id"].(string)), true + case "Entity.findProductByUpc": + if e.complexity.Entity.FindProductByUpc == nil { + break + } + + args, err := ec.field_Entity_findProductByUpc_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Entity.FindProductByUpc(childComplexity, args["upc"].(string)), true + case "Manufacturer.id": if e.complexity.Manufacturer.ID == nil { break @@ -267,7 +281,7 @@ type Manufacturer @key(fields: "id") { name: String! } -type Product @key(fields: "manufacturer { id } id") { +type Product @key(fields: "manufacturer { id } id") @key(fields: "upc") { id: String! manufacturer: Manufacturer! upc: String! @@ -293,6 +307,7 @@ union _Entity = Manufacturer | Product type Entity { findManufacturerByID(id: String!,): Manufacturer! findProductByManufacturerIDAndID(manufacturerID: String!,id: String!,): Product! + findProductByUpc(upc: String!,): Product! } @@ -351,6 +366,21 @@ func (ec *executionContext) field_Entity_findProductByManufacturerIDAndID_args(c return args, nil } +func (ec *executionContext) field_Entity_findProductByUpc_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["upc"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("upc")) + arg0, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["upc"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -518,6 +548,48 @@ func (ec *executionContext) _Entity_findProductByManufacturerIDAndID(ctx context return ec.marshalNProduct2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋexampleᚋfederationᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) } +func (ec *executionContext) _Entity_findProductByUpc(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Entity", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Entity_findProductByUpc_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Entity().FindProductByUpc(rctx, args["upc"].(string)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.Product) + fc.Result = res + return ec.marshalNProduct2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋexampleᚋfederationᚋproductsᚋgraphᚋmodelᚐProduct(ctx, field.Selections, res) +} + func (ec *executionContext) _Manufacturer_id(ctx context.Context, field graphql.CollectedField, obj *model.Manufacturer) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2197,6 +2269,29 @@ func (ec *executionContext) _Entity(ctx context.Context, sel ast.SelectionSet) g return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) } + out.Concurrently(i, func() graphql.Marshaler { + return rrm(innerCtx) + }) + case "findProductByUpc": + field := field + + innerFunc := func(ctx context.Context) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Entity_findProductByUpc(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) + } + out.Concurrently(i, func() graphql.Marshaler { return rrm(innerCtx) }) diff --git a/example/federation/products/graph/schema.graphqls b/example/federation/products/graph/schema.graphqls index 7e55ac996eb..8c1f0e7cf0c 100644 --- a/example/federation/products/graph/schema.graphqls +++ b/example/federation/products/graph/schema.graphqls @@ -7,7 +7,7 @@ type Manufacturer @key(fields: "id") { name: String! } -type Product @key(fields: "manufacturer { id } id") { +type Product @key(fields: "manufacturer { id } id") @key(fields: "upc") { id: String! manufacturer: Manufacturer! upc: String! diff --git a/example/federation/reviews/graph/generated/federation.go b/example/federation/reviews/graph/generated/federation.go index 55391a46eae..fe1532e488d 100644 --- a/example/federation/reviews/graph/generated/federation.go +++ b/example/federation/reviews/graph/generated/federation.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + "github.com/99designs/gqlgen/example/federation/reviews/graph/model" "github.com/99designs/gqlgen/plugin/federation/fedruntime" ) @@ -49,34 +50,41 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati switch typeName { case "Product": - id0, err := ec.unmarshalNString2string(ctx, rep["manufacturer"].(map[string]interface{})["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "manufacturerID")) - } - id1, err := ec.unmarshalNString2string(ctx, rep["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "id")) - } + entity, err := func() (*model.Product, error) { + id0, err := ec.unmarshalNString2string(ctx, rep["manufacturer"].(map[string]interface{})["id"]) + if err == nil { + id1, err := ec.unmarshalNString2string(ctx, rep["id"]) + if err == nil { + return ec.resolvers.Entity().FindProductByManufacturerIDAndID(ctx, id0, id1) + } + } + return nil, nil + }() - entity, err := ec.resolvers.Entity().FindProductByManufacturerIDAndID(ctx, - id0, id1) if err != nil { - return err + return fmt.Errorf(`resolving Entity "Product": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "Product"`) } list[i] = entity return nil case "User": - id0, err := ec.unmarshalNID2string(ctx, rep["id"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "id")) - } + entity, err := func() (*model.User, error) { + id0, err := ec.unmarshalNID2string(ctx, rep["id"]) + if err == nil { + return ec.resolvers.Entity().FindUserByID(ctx, id0) + } + return nil, nil + }() - entity, err := ec.resolvers.Entity().FindUserByID(ctx, - id0) if err != nil { - return err + return fmt.Errorf(`resolving Entity "User": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "User"`) } entity.Host.ID, err = ec.unmarshalNString2string(ctx, rep["hostID"]) diff --git a/plugin/federation/federation.go b/plugin/federation/federation.go index dd1b8ca38ed..e12d8d05002 100644 --- a/plugin/federation/federation.go +++ b/plugin/federation/federation.go @@ -2,6 +2,7 @@ package federation import ( "fmt" + "go/types" "sort" "github.com/vektah/gqlparser/v2/ast" @@ -15,6 +16,8 @@ import ( type federation struct { Entities []*Entity + + imports []string } // New returns a federation plugin that injects @@ -95,12 +98,14 @@ func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source { } entities += e.Name - if e.ResolverName != "" { - resolverArgs := "" - for _, keyField := range e.KeyFields { - resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String()) + for _, r := range e.Resolvers { + if r.ResolverName != "" { + resolverArgs := "" + for _, keyField := range r.KeyFields { + resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String()) + } + resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Def.Name) } - resolvers += fmt.Sprintf("\t%s(%s): %s!\n", e.ResolverName, resolverArgs, e.Def.Name) } } @@ -143,11 +148,17 @@ extend type Query { // Entity represents a federated type // that was declared in the GQL schema. type Entity struct { - Name string // The same name as the type declaration - KeyFields []*KeyField // The fields declared in @key. + Name string // The same name as the type declaration + Def *ast.Definition + Resolvers []*EntityResolver + Requires []*Requires + // Type *config.TypeReference // The Go representation of that field type + Type string // The Go representation of that field type +} + +type EntityResolver struct { ResolverName string // The resolver name, such as FindUserByID - Def *ast.Definition - Requires []*Requires + KeyFields []*KeyField // The fields declared in @key. } type KeyField struct { @@ -179,16 +190,25 @@ func (f *federation) GenerateCode(data *codegen.Data) error { } for _, e := range f.Entities { obj := data.Objects.ByName(e.Def.Name) + e.Type = types.TypeString(obj.Type, func(i *types.Package) string { + f.imports = append(f.imports, i.Path()) + return data.Config.Packages.NameForPackage(i.Path()) + }) + if e.Type == "" { + panic("unknown type " + obj.Type.String()) + } - // fill in types for key fields - // - for _, keyField := range e.KeyFields { - if len(keyField.Field) == 0 { - fmt.Println("skipping key field " + keyField.Definition.Name + " in " + e.Def.Name) - continue + for _, r := range e.Resolvers { + // fill in types for key fields + // + for _, keyField := range r.KeyFields { + if len(keyField.Field) == 0 { + fmt.Println("skipping field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name) + continue + } + cgField := keyField.Field.TypeReference(obj, data.Objects) + keyField.Type = cgField.TypeReference } - cgField := keyField.Field.TypeReference(obj, data.Objects) - keyField.Type = cgField.TypeReference } // fill in types for requires fields @@ -213,6 +233,13 @@ func (f *federation) GenerateCode(data *codegen.Data) error { }) } +func (f *federation) Imports() string { + for _, path := range f.imports { + _, _ = templates.CurrentImports.Reserve(path) + } + return "" +} + func (f *federation) setEntities(schema *ast.Schema) { for _, schemaType := range schema.Types { if schemaType.Kind == ast.Interface { @@ -227,35 +254,18 @@ func (f *federation) setEntities(schema *ast.Schema) { } if schemaType.Kind == ast.Object { keys := schemaType.Directives.ForNames("key") - if len(keys) > 1 { - // TODO: support multiple keys -- multiple resolvers per Entity - panic("only one @key directive currently supported") + if len(keys) == 0 { + continue } + resolvers := []*EntityResolver{} - if len(keys) > 0 { - dir := keys[0] + for _, dir := range keys { if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" { panic("Exactly one `fields` argument needed for @key declaration.") } arg := dir.Arguments[0] keyFieldSet := fieldset.New(arg.Value.Raw, nil) - // TODO: why is this nested inside the @key handling? -- because it's per-Entity, and we make one per @key - requires := []*Requires{} - for _, f := range schemaType.Fields { - dir := f.Directives.ForName("requires") - if dir == nil { - continue - } - requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil) - for _, field := range requiresFieldSet { - requires = append(requires, &Requires{ - Name: field.ToGoPrivate(), - Field: field, - }) - } - } - keyFields := make([]*KeyField, len(keyFieldSet)) resolverName := fmt.Sprintf("find%sBy", schemaType.Name) for i, field := range keyFieldSet { @@ -271,38 +281,59 @@ func (f *federation) setEntities(schema *ast.Schema) { } resolverName += field.ToGo() } - - e := &Entity{ - Name: schemaType.Name, - KeyFields: keyFields, - Def: schemaType, + resolvers = append(resolvers, &EntityResolver{ ResolverName: resolverName, - Requires: requires, + KeyFields: keyFields, + }) + } + + requires := []*Requires{} + for _, f := range schemaType.Fields { + dir := f.Directives.ForName("requires") + if dir == nil { + continue + } + if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" { + panic("Exactly one `fields` argument needed for @requires declaration.") } - // If our schema has a field with a type defined in - // another service, then we need to define an "empty - // extend" of that type in this service, so this service - // knows what the type is like. But the graphql-server - // will never ask us to actually resolve this "empty - // extend", so we don't require a resolver function for - // it. (Well, it will never ask in practice; it's - // unclear whether the spec guarantees this. See - // https://github.com/apollographql/apollo-server/issues/3852 - // ). Example: - // type MyType { - // myvar: TypeDefinedInOtherService - // } - // // Federation needs this type, but - // // it doesn't need a resolver for it! - // extend TypeDefinedInOtherService @key(fields: "id") { - // id: ID @external - // } - if e.allFieldsAreExternal() { - e.ResolverName = "" + requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil) + for _, field := range requiresFieldSet { + requires = append(requires, &Requires{ + Name: field.ToGoPrivate(), + Field: field, + }) } + } - f.Entities = append(f.Entities, e) + e := &Entity{ + Name: schemaType.Name, + Def: schemaType, + Resolvers: resolvers, + Requires: requires, + } + // If our schema has a field with a type defined in + // another service, then we need to define an "empty + // extend" of that type in this service, so this service + // knows what the type is like. But the graphql-server + // will never ask us to actually resolve this "empty + // extend", so we don't require a resolver function for + // it. (Well, it will never ask in practice; it's + // unclear whether the spec guarantees this. See + // https://github.com/apollographql/apollo-server/issues/3852 + // ). Example: + // type MyType { + // myvar: TypeDefinedInOtherService + // } + // // Federation needs this type, but + // // it doesn't need a resolver for it! + // extend TypeDefinedInOtherService @key(fields: "id") { + // id: ID @external + // } + if e.allFieldsAreExternal() { + e.Resolvers = nil } + + f.Entities = append(f.Entities, e) } } diff --git a/plugin/federation/federation.gotpl b/plugin/federation/federation.gotpl index 7dbb00535d7..e66d77e7429 100644 --- a/plugin/federation/federation.gotpl +++ b/plugin/federation/federation.gotpl @@ -6,6 +6,8 @@ {{ reserveImport "github.com/99designs/gqlgen/plugin/federation/fedruntime" }} +{{ .Imports }} + func (ec *executionContext) __resolve__service(ctx context.Context) (fedruntime.Service, error) { if ec.DisableIntrospection { return fedruntime.Service{}, errors.New("federated introspection disabled") @@ -42,22 +44,30 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati return errors.New("__typename must be an existing string") } switch typeName { - {{ range .Entities }} - {{ if .ResolverName }} + {{ range $_, $entity := .Entities }} + {{- if .Resolvers -}} case "{{.Def.Name}}": - {{ range $i, $keyField := .KeyFields -}} - id{{$i}}, err := ec.{{.Type.UnmarshalFunc}}(ctx, rep["{{.Field.Join `"].(map[string]interface{})["`}}"]) - if err != nil { - return errors.New(fmt.Sprintf("Field %s undefined in schema.", "{{.Definition.Name}}")) - } + {{range $i, $_ := .Resolvers -}} + {{- if ne $i 0 -}}if entity == nil { {{- end -}} + entity, err {{- if eq $i 0 -}}:{{- end -}}= func() (*{{$entity.Type}}, error) { + {{- range $j, $keyField := .KeyFields -}} + id{{$j}}, err := ec.{{.Type.UnmarshalFunc}}(ctx, rep["{{.Field.Join `"].(map[string]interface{})["`}}"]) + if err == nil { + {{- end}} + return ec.resolvers.Entity().{{.ResolverName | go}}(ctx, {{- range $j, $_ := .KeyFields -}} id{{$j}}, {{end}}) + {{- range .KeyFields -}} + } + {{- end}} + return nil, nil + }() + {{ if ne $i 0 -}} } {{- end}} {{end}} - - entity, err := ec.resolvers.Entity().{{.ResolverName | go}}(ctx, - {{ range $i, $_ := .KeyFields -}} id{{$i}}, {{end}}) if err != nil { - return err + return fmt.Errorf(`resolving Entity "{{.Def.Name}}": %w`, err) + } + if entity == nil { + return errors.New(`unable to resolve Entity "{{.Def.Name}}"`) } - {{ range .Requires }} entity.{{.Field.JoinGo `.`}}, err = ec.{{.Type.UnmarshalFunc}}(ctx, rep["{{.Name}}"]) if err != nil { diff --git a/plugin/federation/federation_test.go b/plugin/federation/federation_test.go index 20578cf7a7a..d4844d3c89e 100644 --- a/plugin/federation/federation_test.go +++ b/plugin/federation/federation_test.go @@ -12,61 +12,70 @@ func TestWithEntities(t *testing.T) { require.Equal(t, []string{"ExternalExtension", "Hello", "MoreNesting", "NestedKey", "VeryNestedKey", "World"}, cfg.Schema.Types["_Entity"].Types) - require.Len(t, cfg.Schema.Types["Entity"].Fields, 5) + require.Len(t, cfg.Schema.Types["Entity"].Fields, 6) require.Equal(t, "findExternalExtensionByUpc", cfg.Schema.Types["Entity"].Fields[0].Name) require.Equal(t, "findHelloByName", cfg.Schema.Types["Entity"].Fields[1].Name) + // missing on purpose: all @external fields: + // require.Equal(t, "findMoreNestingByID", cfg.Schema.Types["Entity"].Fields[2].Name) require.Equal(t, "findNestedKeyByIDAndHelloName", cfg.Schema.Types["Entity"].Fields[2].Name) require.Equal(t, "findVeryNestedKeyByIDAndHelloNameAndWorldFooAndWorldBarAndMoreWorldFoo", cfg.Schema.Types["Entity"].Fields[3].Name) - require.Equal(t, "findWorldByFooAndBar", cfg.Schema.Types["Entity"].Fields[4].Name) + require.Equal(t, "findWorldByFoo", cfg.Schema.Types["Entity"].Fields[4].Name) + require.Equal(t, "findWorldByBar", cfg.Schema.Types["Entity"].Fields[5].Name) require.NoError(t, f.MutateConfig(cfg)) + require.Len(t, f.Entities, 6) + require.Equal(t, "ExternalExtension", f.Entities[0].Name) - require.Len(t, f.Entities[0].KeyFields, 1) - require.Equal(t, "upc", f.Entities[0].KeyFields[0].Definition.Name) - require.Equal(t, "String", f.Entities[0].KeyFields[0].Definition.Type.Name()) + require.Len(t, f.Entities[0].Resolvers, 1) + require.Len(t, f.Entities[0].Resolvers[0].KeyFields, 1) + require.Equal(t, "upc", f.Entities[0].Resolvers[0].KeyFields[0].Definition.Name) + require.Equal(t, "String", f.Entities[0].Resolvers[0].KeyFields[0].Definition.Type.Name()) require.Equal(t, "Hello", f.Entities[1].Name) - require.Len(t, f.Entities[1].KeyFields, 1) - require.Equal(t, "name", f.Entities[1].KeyFields[0].Definition.Name) - require.Equal(t, "String", f.Entities[1].KeyFields[0].Definition.Type.Name()) + require.Len(t, f.Entities[1].Resolvers, 1) + require.Len(t, f.Entities[1].Resolvers[0].KeyFields, 1) + require.Equal(t, "name", f.Entities[1].Resolvers[0].KeyFields[0].Definition.Name) + require.Equal(t, "String", f.Entities[1].Resolvers[0].KeyFields[0].Definition.Type.Name()) require.Equal(t, "MoreNesting", f.Entities[2].Name) - require.Len(t, f.Entities[2].KeyFields, 1) - require.Equal(t, "id", f.Entities[2].KeyFields[0].Definition.Name) - require.Equal(t, "String", f.Entities[2].KeyFields[0].Definition.Type.Name()) + require.Len(t, f.Entities[2].Resolvers, 0) require.Equal(t, "NestedKey", f.Entities[3].Name) - require.Len(t, f.Entities[3].KeyFields, 2) - require.Equal(t, "id", f.Entities[3].KeyFields[0].Definition.Name) - require.Equal(t, "String", f.Entities[3].KeyFields[0].Definition.Type.Name()) - require.Equal(t, "helloName", f.Entities[3].KeyFields[1].Definition.Name) - require.Equal(t, "String", f.Entities[3].KeyFields[1].Definition.Type.Name()) + require.Len(t, f.Entities[3].Resolvers, 1) + require.Len(t, f.Entities[3].Resolvers[0].KeyFields, 2) + require.Equal(t, "id", f.Entities[3].Resolvers[0].KeyFields[0].Definition.Name) + require.Equal(t, "String", f.Entities[3].Resolvers[0].KeyFields[0].Definition.Type.Name()) + require.Equal(t, "helloName", f.Entities[3].Resolvers[0].KeyFields[1].Definition.Name) + require.Equal(t, "String", f.Entities[3].Resolvers[0].KeyFields[1].Definition.Type.Name()) require.Equal(t, "VeryNestedKey", f.Entities[4].Name) - require.Len(t, f.Entities[4].KeyFields, 5) - require.Equal(t, "id", f.Entities[4].KeyFields[0].Definition.Name) - require.Equal(t, "String", f.Entities[4].KeyFields[0].Definition.Type.Name()) - require.Equal(t, "helloName", f.Entities[4].KeyFields[1].Definition.Name) - require.Equal(t, "String", f.Entities[4].KeyFields[1].Definition.Type.Name()) - require.Equal(t, "worldFoo", f.Entities[4].KeyFields[2].Definition.Name) - require.Equal(t, "String", f.Entities[4].KeyFields[2].Definition.Type.Name()) - require.Equal(t, "worldBar", f.Entities[4].KeyFields[3].Definition.Name) - require.Equal(t, "Int", f.Entities[4].KeyFields[3].Definition.Type.Name()) - require.Equal(t, "moreWorldFoo", f.Entities[4].KeyFields[4].Definition.Name) - require.Equal(t, "String", f.Entities[4].KeyFields[4].Definition.Type.Name()) + require.Len(t, f.Entities[4].Resolvers, 1) + require.Len(t, f.Entities[4].Resolvers[0].KeyFields, 5) + require.Equal(t, "id", f.Entities[4].Resolvers[0].KeyFields[0].Definition.Name) + require.Equal(t, "String", f.Entities[4].Resolvers[0].KeyFields[0].Definition.Type.Name()) + require.Equal(t, "helloName", f.Entities[4].Resolvers[0].KeyFields[1].Definition.Name) + require.Equal(t, "String", f.Entities[4].Resolvers[0].KeyFields[1].Definition.Type.Name()) + require.Equal(t, "worldFoo", f.Entities[4].Resolvers[0].KeyFields[2].Definition.Name) + require.Equal(t, "String", f.Entities[4].Resolvers[0].KeyFields[2].Definition.Type.Name()) + require.Equal(t, "worldBar", f.Entities[4].Resolvers[0].KeyFields[3].Definition.Name) + require.Equal(t, "Int", f.Entities[4].Resolvers[0].KeyFields[3].Definition.Type.Name()) + require.Equal(t, "moreWorldFoo", f.Entities[4].Resolvers[0].KeyFields[4].Definition.Name) + require.Equal(t, "String", f.Entities[4].Resolvers[0].KeyFields[4].Definition.Type.Name()) require.Len(t, f.Entities[4].Requires, 2) require.Equal(t, f.Entities[4].Requires[0].Name, "id") require.Equal(t, f.Entities[4].Requires[1].Name, "helloSecondary") require.Equal(t, "World", f.Entities[5].Name) - require.Len(t, f.Entities[5].KeyFields, 2) - require.Equal(t, "foo", f.Entities[5].KeyFields[0].Definition.Name) - require.Equal(t, "String", f.Entities[5].KeyFields[0].Definition.Type.Name()) - require.Equal(t, "bar", f.Entities[5].KeyFields[1].Definition.Name) - require.Equal(t, "Int", f.Entities[5].KeyFields[1].Definition.Type.Name()) + require.Len(t, f.Entities[5].Resolvers, 2) + require.Len(t, f.Entities[5].Resolvers[0].KeyFields, 1) + require.Equal(t, "foo", f.Entities[5].Resolvers[0].KeyFields[0].Definition.Name) + require.Equal(t, "String", f.Entities[5].Resolvers[0].KeyFields[0].Definition.Type.Name()) + require.Len(t, f.Entities[5].Resolvers[1].KeyFields, 1) + require.Equal(t, "bar", f.Entities[5].Resolvers[1].KeyFields[0].Definition.Name) + require.Equal(t, "Int", f.Entities[5].Resolvers[1].KeyFields[0].Definition.Type.Name()) } func TestNoEntities(t *testing.T) { diff --git a/plugin/federation/test_data/schema.graphql b/plugin/federation/test_data/schema.graphql index cf5a72004e5..5a379154baa 100644 --- a/plugin/federation/test_data/schema.graphql +++ b/plugin/federation/test_data/schema.graphql @@ -3,7 +3,7 @@ type Hello @key(fields: "name") { secondary: String! } -type World @key(fields: " foo bar ") { +type World @key(fields: " foo ") @key(fields: "bar") { foo: String! bar: Int! } From 85beda0579ade29b8d54baeb601f731d9a854825 Mon Sep 17 00:00:00 2001 From: Carl Dunham Date: Tue, 2 Nov 2021 18:40:59 -0700 Subject: [PATCH 2/2] add more unit test coverage to plugin/federation --- plugin/federation/.gitignore | 1 + plugin/federation/federation_test.go | 22 +++++++++ plugin/federation/test_data/gqlgen.yml | 3 ++ .../federation/test_data/interfaces.graphqls | 9 ++++ plugin/federation/test_data/interfaces.yml | 6 +++ .../federation/test_data/model/federation.go | 48 +++++++++++++++++++ 6 files changed, 89 insertions(+) create mode 100644 plugin/federation/.gitignore create mode 100644 plugin/federation/test_data/interfaces.graphqls create mode 100644 plugin/federation/test_data/interfaces.yml create mode 100644 plugin/federation/test_data/model/federation.go diff --git a/plugin/federation/.gitignore b/plugin/federation/.gitignore new file mode 100644 index 00000000000..52fdfdda111 --- /dev/null +++ b/plugin/federation/.gitignore @@ -0,0 +1 @@ +graph diff --git a/plugin/federation/federation_test.go b/plugin/federation/federation_test.go index d4844d3c89e..69a1be62b0a 100644 --- a/plugin/federation/federation_test.go +++ b/plugin/federation/federation_test.go @@ -3,6 +3,7 @@ package federation import ( "testing" + "github.com/99designs/gqlgen/codegen" "github.com/99designs/gqlgen/codegen/config" "github.com/stretchr/testify/require" ) @@ -78,6 +79,21 @@ func TestWithEntities(t *testing.T) { require.Equal(t, "Int", f.Entities[5].Resolvers[1].KeyFields[0].Definition.Type.Name()) } +func TestCodeGeneration(t *testing.T) { + f, cfg := load(t, "test_data/gqlgen.yml") + + require.Len(t, cfg.Schema.Types["_Entity"].Types, 6) + require.Len(t, f.Entities, 6) + + require.NoError(t, f.MutateConfig(cfg)) + + data, err := codegen.BuildData(cfg) + if err != nil { + panic(err) + } + require.NoError(t, f.GenerateCode(data)) +} + func TestNoEntities(t *testing.T) { f, cfg := load(t, "test_data/nokey.yml") @@ -85,6 +101,12 @@ func TestNoEntities(t *testing.T) { require.NoError(t, err) } +func TestInterfaces(t *testing.T) { + require.Panics(t, func() { + load(t, "test_data/interfaces.yml") + }) +} + func load(t *testing.T, name string) (*federation, *config.Config) { t.Helper() diff --git a/plugin/federation/test_data/gqlgen.yml b/plugin/federation/test_data/gqlgen.yml index 51e7fa628b8..34e515a7aa7 100644 --- a/plugin/federation/test_data/gqlgen.yml +++ b/plugin/federation/test_data/gqlgen.yml @@ -4,3 +4,6 @@ exec: filename: graph/generated/exec.go federation: filename: graph/generated/federation.go + +autobind: + - "github.com/99designs/gqlgen/plugin/federation/test_data/model" diff --git a/plugin/federation/test_data/interfaces.graphqls b/plugin/federation/test_data/interfaces.graphqls new file mode 100644 index 00000000000..87a5c75ce85 --- /dev/null +++ b/plugin/federation/test_data/interfaces.graphqls @@ -0,0 +1,9 @@ +interface Hello @key(fields: "name") { + name: String! + secondary: String! +} + +extend interface World { + foo: String! @external + bar: Int! +} diff --git a/plugin/federation/test_data/interfaces.yml b/plugin/federation/test_data/interfaces.yml new file mode 100644 index 00000000000..674e361e056 --- /dev/null +++ b/plugin/federation/test_data/interfaces.yml @@ -0,0 +1,6 @@ +schema: + - "test_data/interfaces.graphqls" +exec: + filename: graph/generated/exec.go +federation: + filename: graph/generated/federation.go diff --git a/plugin/federation/test_data/model/federation.go b/plugin/federation/test_data/model/federation.go new file mode 100644 index 00000000000..9aba4388fcd --- /dev/null +++ b/plugin/federation/test_data/model/federation.go @@ -0,0 +1,48 @@ +package model + +type _FieldSet string //nolint:deadcode,unused + +type Hello struct { + Name string + Secondary string +} + +func (Hello) IsEntity() {} + +type World struct { + Foo string + Bar int +} + +func (World) IsEntity() {} + +type ExternalExtension struct { + UPC string + Reviews []*World +} + +func (ExternalExtension) IsEntity() {} + +type NestedKey struct { + ID string + Hello *Hello +} + +func (NestedKey) IsEntity() {} + +type MoreNesting struct { + ID string + World *World +} + +func (MoreNesting) IsEntity() {} + +type VeryNestedKey struct { + ID string + Hello *Hello + World *World + Nested *NestedKey + More *MoreNesting +} + +func (VeryNestedKey) IsEntity() {}