diff --git a/plugin/federation/federation.go b/plugin/federation/federation.go index 37994f5b35c..c361cf8fdb1 100644 --- a/plugin/federation/federation.go +++ b/plugin/federation/federation.go @@ -91,7 +91,11 @@ func (f *federation) InjectSources(cfg *config.Config) { s := "type Entity {\n" for _, e := range f.Entities { - s += fmt.Sprintf("\t%s(%s: %s): %s!\n", e.ResolverName, e.Field.Name, e.Field.Type.String(), e.Def.Name) + resolverArgs := "" + for _, field := range e.KeyFields { + resolverArgs += fmt.Sprintf("%s: %s,", field.Field.Name, field.Field.Type.String()) + } + s += fmt.Sprintf("\t%s(%s): %s!\n", e.ResolverName, resolverArgs, e.Def.Name) } s += "}" cfg.AdditionalSources = append(cfg.AdditionalSources, &ast.Source{Name: "entity.graphql", Input: s, BuiltIn: true}) @@ -192,13 +196,18 @@ directive @extends on OBJECT // Entity represents a federated type // that was declared in the GQL schema. type Entity struct { - Field *ast.FieldDefinition - FieldTypeGo string // The Go representation of that field type - ResolverName string // The resolver name, such as FindUserByID + Name string // The same name as the type declaration + KeyFields []*KeyField // The fields declared in @key. + ResolverName string // The resolver name, such as FindUserByID Def *ast.Definition Requires []*Requires } +type KeyField struct { + Field *ast.FieldDefinition + TypeReference *config.TypeReference // The Go representation of that field type +} + // Requires represents an @requires clause type Requires struct { Name string // the name of the field @@ -223,15 +232,19 @@ func (f *federation) GenerateCode(data *codegen.Data) error { data.Objects.ByName("Entity").Root = true for _, e := range f.Entities { obj := data.Objects.ByName(e.Def.Name) - for _, f := range obj.Fields { - if f.Name == e.Field.Name { - e.FieldTypeGo = f.TypeReference.GO.String() + for _, field := range obj.Fields { + // Storing key fields in a slice rather than a map + // to preserve insertion order at the tradeoff of higher + // lookup complexity. + keyField := f.getKeyField(e.KeyFields, field.Name) + if keyField != nil { + keyField.TypeReference = field.TypeReference } for _, r := range e.Requires { for _, rf := range r.Fields { - if rf.Name == f.Name { - rf.TypeReference = f.TypeReference - rf.NameGo = f.GoFieldName + if rf.Name == field.Name { + rf.TypeReference = field.TypeReference + rf.NameGo = field.GoFieldName } } } @@ -248,6 +261,15 @@ func (f *federation) GenerateCode(data *codegen.Data) error { }) } +func (f *federation) getKeyField(keyFields []*KeyField, fieldName string) *KeyField { + for _, field := range keyFields { + if field.Field.Name == fieldName { + return field + } + } + return nil +} + func (f *federation) setEntities(cfg *config.Config) { // crazy hack to get our injected code in so everything compiles, so we can generate the entity map // so we can reload the full schema. @@ -259,11 +281,14 @@ func (f *federation) setEntities(cfg *config.Config) { if schemaType.Kind == ast.Object { dir := schemaType.Directives.ForName("key") // TODO: interfaces if dir != nil { - fieldName := dir.Arguments[0].Value.Raw // TODO: multiple arguments, and multiple keys - if strings.Contains(fieldName, " ") { - panic("only single fields are currently supported in @key declaration") + if len(dir.Arguments) > 1 { + panic("Multiple arguments are not currently supported in @key declaration.") } - field := schemaType.Fields.ForName(fieldName) + fieldName := dir.Arguments[0].Value.Raw // TODO: multiple arguments + if strings.Contains(fieldName, "{") { + panic("Nested fields are not currently supported in @key declaration.") + } + requires := []*Requires{} for _, f := range schemaType.Fields { dir := f.Directives.ForName("requires") @@ -282,10 +307,26 @@ func (f *federation) setEntities(cfg *config.Config) { Fields: requireFields, }) } + + fieldNames := strings.Split(fieldName, " ") + keyFields := make([]*KeyField, len(fieldNames)) + resolverName := fmt.Sprintf("find%sBy", schemaType.Name) + for i, f := range fieldNames { + field := schemaType.Fields.ForName(f) + + keyFields[i] = &KeyField{Field: field} + if i > 0 { + resolverName += "And" + } + resolverName += templates.ToGo(f) + + } + f.Entities = append(f.Entities, &Entity{ - Field: field, + Name: schemaType.Name, + KeyFields: keyFields, Def: schemaType, - ResolverName: fmt.Sprintf("find%sBy%s", schemaType.Name, templates.ToGo(fieldName)), + ResolverName: resolverName, Requires: requires, }) } diff --git a/plugin/federation/federation.gotpl b/plugin/federation/federation.gotpl index 4bb83db3094..3672b0b84c2 100644 --- a/plugin/federation/federation.gotpl +++ b/plugin/federation/federation.gotpl @@ -1,5 +1,6 @@ {{ reserveImport "context" }} {{ reserveImport "errors" }} +{{ reserveImport "fmt" }} {{ reserveImport "github.com/99designs/gqlgen/plugin/federation" }} @@ -23,23 +24,28 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati switch typeName { {{ range .Entities }} case "{{.Def.Name}}": - id, ok := rep["{{.Field.Name}}"].({{.FieldTypeGo}}) - if !ok { - return nil, errors.New("opsies") - } - resp, err := ec.resolvers.Entity().{{.ResolverName | go}}(ctx, id) + {{ range $i, $keyField := .KeyFields -}} + id{{$i}}, err := ec.{{.TypeReference.UnmarshalFunc}}(ctx, rep["{{$keyField.Field.Name}}"]) + if err != nil { + return nil, errors.New(fmt.Sprintf("Field %s undefined in schema.", "{{$keyField.Field.Name}}")) + } + {{end}} + + entity, err := ec.resolvers.Entity().{{.ResolverName | go}}(ctx, + {{ range $i, $_ := .KeyFields -}} id{{$i}}, {{end}}) if err != nil { return nil, err } + {{ range .Requires }} {{ range .Fields}} - resp.{{.NameGo}}, err = ec.{{.TypeReference.UnmarshalFunc}}(ctx, rep["{{.Name}}"]) + entity.{{.NameGo}}, err = ec.{{.TypeReference.UnmarshalFunc}}(ctx, rep["{{.Name}}"]) if err != nil { return nil, err } {{ end }} {{ end }} - list = append(list, resp) + list = append(list, entity) {{ end }} default: return nil, errors.New("unknown type: "+typeName) diff --git a/plugin/federation/federation_test.go b/plugin/federation/federation_test.go index c6b0b37bab4..0fa553fb126 100644 --- a/plugin/federation/federation_test.go +++ b/plugin/federation/federation_test.go @@ -26,6 +26,7 @@ func TestMutateSchema(t *testing.T) { Name: "schema.graphql", Input: `type Query { hello: String! + world: String! }`, }) if gqlErr != nil { diff --git a/plugin/federation/test_data/schema.graphql b/plugin/federation/test_data/schema.graphql index d7b0ce537f3..55e04d341e5 100644 --- a/plugin/federation/test_data/schema.graphql +++ b/plugin/federation/test_data/schema.graphql @@ -2,6 +2,12 @@ type Hello @key(fields: "name") { name: String! } +type World @key(fields: "foo bar") { + foo: String! + bar: Int! +} + type Query { hello: Hello! + world: World! }