diff --git a/codegen/directive_build.go b/codegen/directive_build.go index 18ece308635..8124afb3207 100644 --- a/codegen/directive_build.go +++ b/codegen/directive_build.go @@ -19,12 +19,12 @@ func (cfg *Config) buildDirectives(types NamedTypes) (map[string]*Directive, err var args []FieldArgument for _, arg := range dir.Arguments { newArg := FieldArgument{ - GQLName: arg.Name, - Type: types.getType(arg.Type), - GoVarName: sanitizeArgName(arg.Name), + GQLName: arg.Name, + TypeReference: types.getType(arg.Type), + GoVarName: sanitizeArgName(arg.Name), } - if !newArg.Type.IsInput && !newArg.Type.IsScalar { + if !newArg.TypeReference.IsInput && !newArg.TypeReference.IsScalar { return nil, errors.Errorf("%s cannot be used as argument of directive %s(%s) only input and scalar types are allowed", arg.Type, dir.Name, arg.Name) } @@ -69,10 +69,10 @@ func (cfg *Config) getDirectives(list ast.DirectiveList) ([]*Directive, error) { value = argValue } args = append(args, FieldArgument{ - GQLName: a.GQLName, - Value: value, - GoVarName: a.GoVarName, - Type: a.Type, + GQLName: a.GQLName, + Value: value, + GoVarName: a.GoVarName, + TypeReference: a.TypeReference, }) } dirs[i] = &Directive{ diff --git a/codegen/enum.go b/codegen/enum.go index 7804971c02a..0fc497eea1f 100644 --- a/codegen/enum.go +++ b/codegen/enum.go @@ -1,7 +1,7 @@ package codegen type Enum struct { - *NamedType + *TypeDefinition Description string Values []EnumValue } diff --git a/codegen/enum_build.go b/codegen/enum_build.go index 457d923f22c..9c55e4f3433 100644 --- a/codegen/enum_build.go +++ b/codegen/enum_build.go @@ -23,9 +23,9 @@ func (cfg *Config) buildEnums(types NamedTypes) []Enum { } enum := Enum{ - NamedType: namedType, - Values: values, - Description: typ.Description, + TypeDefinition: namedType, + Values: values, + Description: typ.Description, } enum.GoType = templates.ToCamel(enum.GQLType) enums = append(enums, enum) diff --git a/codegen/input_build.go b/codegen/input_build.go index 5eac32e5dd3..c0ac31ee020 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -44,7 +44,7 @@ func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program) (Obj } func (cfg *Config) buildInput(types NamedTypes, typ *ast.Definition) (*Object, error) { - obj := &Object{NamedType: types[typ.Name]} + obj := &Object{TypeDefinition: types[typ.Name]} typeEntry, entryExists := cfg.Models[typ.Name] for _, field := range typ.Fields { @@ -53,10 +53,10 @@ func (cfg *Config) buildInput(types NamedTypes, typ *ast.Definition) (*Object, e return nil, err } newField := Field{ - GQLName: field.Name, - Type: types.getType(field.Type), - Object: obj, - Directives: dirs, + GQLName: field.Name, + TypeReference: types.getType(field.Type), + Object: obj, + Directives: dirs, } if entryExists { @@ -73,7 +73,7 @@ func (cfg *Config) buildInput(types NamedTypes, typ *ast.Definition) (*Object, e } } - if !newField.Type.IsInput && !newField.Type.IsScalar { + if !newField.TypeReference.IsInput && !newField.TypeReference.IsScalar { return nil, errors.Errorf("%s cannot be used as a field of %s. only input and scalar types are allowed", newField.GQLType, obj.GQLType) } @@ -91,7 +91,7 @@ func (cfg *Config) buildInput(types NamedTypes, typ *ast.Definition) (*Object, e // if user has implemented an UnmarshalGQL method on the input type manually, use it // otherwise we will generate one. -func buildInputMarshaler(typ *ast.Definition, def types.Object) *Ref { +func buildInputMarshaler(typ *ast.Definition, def types.Object) *TypeImplementation { switch def := def.(type) { case *types.TypeName: namedType := def.Type().(*types.Named) @@ -103,5 +103,5 @@ func buildInputMarshaler(typ *ast.Definition, def types.Object) *Ref { } } - return &Ref{GoType: typ.Name} + return &TypeImplementation{GoType: typ.Name} } diff --git a/codegen/interface.go b/codegen/interface.go index 2de0c88a9b3..e18e849d20a 100644 --- a/codegen/interface.go +++ b/codegen/interface.go @@ -1,7 +1,7 @@ package codegen type Interface struct { - *NamedType + *TypeDefinition Implementors []InterfaceImplementor } @@ -9,5 +9,5 @@ type Interface struct { type InterfaceImplementor struct { ValueReceiver bool - *NamedType + *TypeDefinition } diff --git a/codegen/interface_build.go b/codegen/interface_build.go index 92052ba6b99..47152d577c6 100644 --- a/codegen/interface_build.go +++ b/codegen/interface_build.go @@ -24,21 +24,21 @@ func (cfg *Config) buildInterfaces(types NamedTypes, prog *loader.Program) []*In } func (cfg *Config) buildInterface(types NamedTypes, typ *ast.Definition, prog *loader.Program) *Interface { - i := &Interface{NamedType: types[typ.Name]} + i := &Interface{TypeDefinition: types[typ.Name]} for _, implementor := range cfg.schema.GetPossibleTypes(typ) { t := types[implementor.Name] i.Implementors = append(i.Implementors, InterfaceImplementor{ - NamedType: t, - ValueReceiver: cfg.isValueReceiver(types[typ.Name], t, prog), + TypeDefinition: t, + ValueReceiver: cfg.isValueReceiver(types[typ.Name], t, prog), }) } return i } -func (cfg *Config) isValueReceiver(intf *NamedType, implementor *NamedType, prog *loader.Program) bool { +func (cfg *Config) isValueReceiver(intf *TypeDefinition, implementor *TypeDefinition, prog *loader.Program) bool { interfaceType, err := findGoInterface(prog, intf.Package, intf.GoType) if interfaceType == nil || err != nil { return true diff --git a/codegen/model.go b/codegen/model.go index bcdc8703a6c..8c3acba1986 100644 --- a/codegen/model.go +++ b/codegen/model.go @@ -1,14 +1,14 @@ package codegen type Model struct { - *NamedType + *TypeDefinition Description string Fields []ModelField - Implements []*NamedType + Implements []*TypeDefinition } type ModelField struct { - *Type + *TypeReference GQLName string GoFieldName string GoFKName string diff --git a/codegen/models_build.go b/codegen/models_build.go index 56d2ff1fae1..f718e590a77 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -54,17 +54,17 @@ func (cfg *Config) buildModels(types NamedTypes, prog *loader.Program) ([]Model, func (cfg *Config) obj2Model(obj *Object) Model { model := Model{ - NamedType: obj.NamedType, - Implements: obj.Implements, - Fields: []ModelField{}, + TypeDefinition: obj.TypeDefinition, + Implements: obj.Implements, + Fields: []ModelField{}, } model.GoType = ucFirst(obj.GQLType) - model.Marshaler = &Ref{GoType: obj.GoType} + model.Marshaler = &TypeImplementation{GoType: obj.GoType} for i := range obj.Fields { field := &obj.Fields[i] - mf := ModelField{Type: field.Type, GQLName: field.GQLName} + mf := ModelField{TypeReference: field.TypeReference, GQLName: field.GQLName} if field.GoFieldName != "" { mf.GoFieldName = field.GoFieldName @@ -80,12 +80,12 @@ func (cfg *Config) obj2Model(obj *Object) Model { func int2Model(obj *Interface) Model { model := Model{ - NamedType: obj.NamedType, - Fields: []ModelField{}, + TypeDefinition: obj.TypeDefinition, + Fields: []ModelField{}, } model.GoType = ucFirst(obj.GQLType) - model.Marshaler = &Ref{GoType: obj.GoType} + model.Marshaler = &TypeImplementation{GoType: obj.GoType} return model } diff --git a/codegen/object.go b/codegen/object.go index 484a49f918e..eb477680d57 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -20,12 +20,12 @@ const ( ) type Object struct { - *NamedType + *TypeDefinition Fields []Field Satisfies []string - Implements []*NamedType - ResolverInterface *Ref + Implements []*TypeDefinition + ResolverInterface *TypeImplementation Root bool DisableConcurrency bool Stream bool @@ -33,7 +33,7 @@ type Object struct { } type Field struct { - *Type + *TypeReference Description string // Description of a field GQLName string // The name of the field in graphql GoFieldType GoFieldType // The field type in go, if any @@ -49,7 +49,7 @@ type Field struct { } type FieldArgument struct { - *Type + *TypeReference GQLName string // The name of the argument in graphql GoVarName string // The name of the var in go @@ -248,7 +248,7 @@ func (f *Field) CallArgs() string { // should be in the template, but its recursive and has a bunch of args func (f *Field) WriteJson() string { - return f.doWriteJson("res", f.Type.Modifiers, f.ASTType, false, 1) + return f.doWriteJson("res", f.TypeReference.Modifiers, f.ASTType, false, 1) } func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string { diff --git a/codegen/object_build.go b/codegen/object_build.go index c1ef9bad5e4..e3d067a3f1e 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -82,10 +82,10 @@ func sanitizeArgName(name string) string { } func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, error) { - obj := &Object{NamedType: types[typ.Name]} + obj := &Object{TypeDefinition: types[typ.Name]} typeEntry, entryExists := cfg.Models[typ.Name] - obj.ResolverInterface = &Ref{GoType: obj.GQLType + "Resolver"} + obj.ResolverInterface = &TypeImplementation{GoType: obj.GQLType + "Resolver"} if typ == cfg.schema.Query { obj.Root = true @@ -110,7 +110,7 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, for _, field := range typ.Fields { if typ == cfg.schema.Query && field.Name == "__type" { obj.Fields = append(obj.Fields, Field{ - Type: &Type{types["__Schema"], []string{modPtr}, ast.NamedType("__Schema", nil), nil}, + TypeReference: &TypeReference{types["__Schema"], []string{modPtr}, ast.NamedType("__Schema", nil), nil}, GQLName: "__schema", GoFieldType: GoFieldMethod, GoReceiverName: "ec", @@ -122,13 +122,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, } if typ == cfg.schema.Query && field.Name == "__schema" { obj.Fields = append(obj.Fields, Field{ - Type: &Type{types["__Type"], []string{modPtr}, ast.NamedType("__Schema", nil), nil}, + TypeReference: &TypeReference{types["__Type"], []string{modPtr}, ast.NamedType("__Schema", nil), nil}, GQLName: "__type", GoFieldType: GoFieldMethod, GoReceiverName: "ec", GoFieldName: "introspectType", Args: []FieldArgument{ - {GQLName: "name", Type: &Type{types["String"], []string{}, ast.NamedType("String", nil), nil}, Object: &Object{}}, + {GQLName: "name", TypeReference: &TypeReference{types["String"], []string{}, ast.NamedType("String", nil), nil}, Object: &Object{}}, }, Object: obj, }) @@ -151,14 +151,14 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, return nil, err } newArg := FieldArgument{ - GQLName: arg.Name, - Type: types.getType(arg.Type), - Object: obj, - GoVarName: sanitizeArgName(arg.Name), - Directives: dirs, + GQLName: arg.Name, + TypeReference: types.getType(arg.Type), + Object: obj, + GoVarName: sanitizeArgName(arg.Name), + Directives: dirs, } - if !newArg.Type.IsInput && !newArg.Type.IsScalar { + if !newArg.TypeReference.IsInput && !newArg.TypeReference.IsScalar { return nil, errors.Errorf("%s cannot be used as argument of %s.%s. only input and scalar types are allowed", arg.Type, obj.GQLType, field.Name) } @@ -174,7 +174,7 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, obj.Fields = append(obj.Fields, Field{ GQLName: field.Name, - Type: types.getType(field.Type), + TypeReference: types.getType(field.Type), Args: args, Object: obj, GoFieldName: goName, diff --git a/codegen/type_build.go b/codegen/type_build.go index 586b0db2386..096fc447d09 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -8,9 +8,9 @@ import ( "golang.org/x/tools/go/loader" ) -// namedTypeFromSchema objects for every graphql type, including scalars. There should only be one instance of Type for each thing +// namedTypeFromSchema objects for every graphql type, including scalars. There should only be one instance of TypeReference for each thing func (cfg *Config) buildNamedTypes() NamedTypes { - types := map[string]*NamedType{} + types := map[string]*TypeDefinition{} for _, schemaType := range cfg.schema.Types { t := namedTypeFromSchema(schemaType) @@ -37,7 +37,7 @@ func (cfg *Config) bindTypes(namedTypes NamedTypes, destDir string, prog *loader switch def := def.(type) { case *types.Func: sig := def.Type().(*types.Signature) - cpy := t.Ref + cpy := t.TypeImplementation t.Marshaler = &cpy t.Package, t.GoType = pkgAndType(sig.Params().At(0).Type().String()) @@ -47,20 +47,20 @@ func (cfg *Config) bindTypes(namedTypes NamedTypes, destDir string, prog *loader // namedTypeFromSchema objects for every graphql type, including primitives. // don't recurse into object fields or interfaces yet, lets make sure we have collected everything first. -func namedTypeFromSchema(schemaType *ast.Definition) *NamedType { +func namedTypeFromSchema(schemaType *ast.Definition) *TypeDefinition { switch schemaType.Kind { case ast.Scalar, ast.Enum: - return &NamedType{GQLType: schemaType.Name, IsScalar: true} + return &TypeDefinition{GQLType: schemaType.Name, IsScalar: true} case ast.Interface, ast.Union: - return &NamedType{GQLType: schemaType.Name, IsInterface: true} + return &TypeDefinition{GQLType: schemaType.Name, IsInterface: true} case ast.InputObject: - return &NamedType{GQLType: schemaType.Name, IsInput: true} + return &TypeDefinition{GQLType: schemaType.Name, IsInput: true} default: - return &NamedType{GQLType: schemaType.Name} + return &TypeDefinition{GQLType: schemaType.Name} } } -// take a string in the form github.com/package/blah.Type and split it into package and type +// take a string in the form github.com/package/blah.TypeReference and split it into package and type func pkgAndType(name string) (string, string) { parts := strings.Split(name, ".") if len(parts) == 1 { @@ -70,7 +70,7 @@ func pkgAndType(name string) (string, string) { return normalizeVendor(strings.Join(parts[:len(parts)-1], ".")), parts[len(parts)-1] } -func (n NamedTypes) getType(t *ast.Type) *Type { +func (n NamedTypes) getType(t *ast.Type) *TypeReference { orig := t var modifiers []string for { @@ -84,10 +84,10 @@ func (n NamedTypes) getType(t *ast.Type) *Type { if n[t.NamedType] == nil { panic("missing type " + t.NamedType) } - res := &Type{ - NamedType: n[t.NamedType], - Modifiers: modifiers, - ASTType: orig, + res := &TypeReference{ + TypeDefinition: n[t.NamedType], + Modifiers: modifiers, + ASTType: orig, } if res.IsInterface { diff --git a/codegen/type_definition.go b/codegen/type_definition.go new file mode 100644 index 00000000000..810cddcf33e --- /dev/null +++ b/codegen/type_definition.go @@ -0,0 +1,43 @@ +package codegen + +import "github.com/99designs/gqlgen/codegen/templates" + +type NamedTypes map[string]*TypeDefinition + +type TypeDefinition struct { + TypeImplementation + IsScalar bool + IsInterface bool + IsInput bool + GQLType string // Name of the graphql type + Marshaler *TypeImplementation // If this type has an external marshaler this will be set +} + +type TypeImplementation struct { + GoType string // Name of the go type + Package string // the package the go type lives in + IsUserDefined bool // does the type exist in the typemap +} + +const ( + modList = "[]" + modPtr = "*" +) + +func (t TypeImplementation) FullName() string { + return t.PkgDot() + t.GoType +} + +func (t TypeImplementation) PkgDot() string { + name := templates.CurrentImports.Lookup(t.Package) + if name == "" { + return "" + + } + + return name + "." +} + +func (t TypeDefinition) IsMarshaled() bool { + return t.Marshaler != nil +} diff --git a/codegen/type.go b/codegen/type_reference.go similarity index 74% rename from codegen/type.go rename to codegen/type_reference.go index 5309c1fbe1f..67b1d033b87 100644 --- a/codegen/type.go +++ b/codegen/type_reference.go @@ -4,63 +4,26 @@ import ( "strconv" "strings" - "github.com/99designs/gqlgen/codegen/templates" - "github.com/vektah/gqlparser/ast" ) -type NamedTypes map[string]*NamedType - -type NamedType struct { - Ref - IsScalar bool - IsInterface bool - IsInput bool - GQLType string // Name of the graphql type - Marshaler *Ref // If this type has an external marshaler this will be set -} - -type Ref struct { - GoType string // Name of the go type - Package string // the package the go type lives in - IsUserDefined bool // does the type exist in the typemap -} - -type Type struct { - *NamedType +// TypeReference represents the type of a field or arg, referencing an underlying TypeDefinition (type, input, scalar) +type TypeReference struct { + *TypeDefinition Modifiers []string ASTType *ast.Type - AliasedType *Ref + AliasedType *TypeImplementation } -const ( - modList = "[]" - modPtr = "*" -) - -func (t Ref) FullName() string { - return t.PkgDot() + t.GoType -} - -func (t Ref) PkgDot() string { - name := templates.CurrentImports.Lookup(t.Package) - if name == "" { - return "" - - } - - return name + "." -} - -func (t Type) Signature() string { +func (t TypeReference) Signature() string { if t.AliasedType != nil { return strings.Join(t.Modifiers, "") + t.AliasedType.FullName() } return strings.Join(t.Modifiers, "") + t.FullName() } -func (t Type) FullSignature() string { +func (t TypeReference) FullSignature() string { pkg := "" if t.Package != "" { pkg = t.Package + "." @@ -69,31 +32,27 @@ func (t Type) FullSignature() string { return strings.Join(t.Modifiers, "") + pkg + t.GoType } -func (t Type) IsPtr() bool { +func (t TypeReference) IsPtr() bool { return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr } -func (t *Type) StripPtr() { +func (t *TypeReference) StripPtr() { if !t.IsPtr() { return } t.Modifiers = t.Modifiers[0 : len(t.Modifiers)-1] } -func (t Type) IsSlice() bool { +func (t TypeReference) IsSlice() bool { return len(t.Modifiers) > 0 && t.Modifiers[0] == modList || len(t.Modifiers) > 1 && t.Modifiers[0] == modPtr && t.Modifiers[1] == modList } -func (t NamedType) IsMarshaled() bool { - return t.Marshaler != nil -} - -func (t Type) Unmarshal(result, raw string) string { +func (t TypeReference) Unmarshal(result, raw string) string { return t.unmarshal(result, raw, t.Modifiers, 1) } -func (t Type) unmarshal(result, raw string, remainingMods []string, depth int) string { +func (t TypeReference) unmarshal(result, raw string, remainingMods []string, depth int) string { switch { case len(remainingMods) > 0 && remainingMods[0] == modPtr: ptr := "ptr" + strconv.Itoa(depth) @@ -131,7 +90,7 @@ func (t Type) unmarshal(result, raw string, remainingMods []string, depth int) s "rawSlice": rawIf, "index": index, "result": result, - "type": strings.Join(remainingMods, "") + t.NamedType.FullName(), + "type": strings.Join(remainingMods, "") + t.TypeDefinition.FullName(), "next": t.unmarshal(result+"["+index+"]", rawIf+"["+index+"]", remainingMods[1:], depth+1), }) } @@ -161,11 +120,11 @@ func (t Type) unmarshal(result, raw string, remainingMods []string, depth int) s }) } -func (t Type) Middleware(result, raw string) string { +func (t TypeReference) Middleware(result, raw string) string { return t.middleware(result, raw, t.Modifiers, 1) } -func (t Type) middleware(result, raw string, remainingMods []string, depth int) string { +func (t TypeReference) middleware(result, raw string, remainingMods []string, depth int) string { if len(remainingMods) == 1 && remainingMods[0] == modPtr { return tpl(`{{- if .t.Marshaler }} if {{.raw}} != nil { @@ -202,7 +161,7 @@ func (t Type) middleware(result, raw string, remainingMods []string, depth int) "raw": raw, "index": index, "result": result, - "type": strings.Join(remainingMods, "") + t.NamedType.FullName(), + "type": strings.Join(remainingMods, "") + t.TypeDefinition.FullName(), "next": t.middleware(result+"["+index+"]", raw+"["+index+"]", remainingMods[1:], depth+1), }) } @@ -222,7 +181,7 @@ func (t Type) middleware(result, raw string, remainingMods []string, depth int) }) } -func (t Type) Marshal(val string) string { +func (t TypeReference) Marshal(val string) string { if t.AliasedType != nil { val = t.GoType + "(" + val + ")" } diff --git a/codegen/util.go b/codegen/util.go index cc6246fdb74..41cce5f8bff 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -320,7 +320,7 @@ nextArg: for _, oldArg := range field.Args { if strings.EqualFold(oldArg.GQLName, param.Name()) { if !field.ForceResolver { - oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) + oldArg.TypeReference.Modifiers = modifiersFromGoType(param.Type()) } newArgs = append(newArgs, oldArg) continue nextArg @@ -334,20 +334,20 @@ nextArg: } func validateTypeBinding(field *Field, goType types.Type) error { - gqlType := normalizeVendor(field.Type.FullSignature()) + gqlType := normalizeVendor(field.TypeReference.FullSignature()) goTypeStr := normalizeVendor(goType.String()) if equalTypes(goTypeStr, gqlType) { - field.Type.Modifiers = modifiersFromGoType(goType) + field.TypeReference.Modifiers = modifiersFromGoType(goType) return nil } // deal with type aliases underlyingStr := normalizeVendor(goType.Underlying().String()) if equalTypes(underlyingStr, gqlType) { - field.Type.Modifiers = modifiersFromGoType(goType) + field.TypeReference.Modifiers = modifiersFromGoType(goType) pkg, typ := pkgAndType(goType.String()) - field.AliasedType = &Ref{GoType: typ, Package: pkg} + field.AliasedType = &TypeImplementation{GoType: typ, Package: pkg} return nil }