From 605b3c0724f3a7f3007587d5cf2da9f7ba049448 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Wed, 28 Dec 2022 21:02:12 +0200 Subject: [PATCH] plugin/resolvergen: respect named return values --- codegen/field.go | 20 +++++++- internal/rewrite/rewriter.go | 47 ++++++------------- plugin/resolvergen/resolver.go | 14 +++--- plugin/resolvergen/resolver.gotpl | 2 +- plugin/resolvergen/resolver_test.go | 3 +- .../followschema/out/schema.resolvers.go | 6 +-- 6 files changed, 46 insertions(+), 46 deletions(-) diff --git a/codegen/field.go b/codegen/field.go index a33cc18e5a5..b9e71de731b 100644 --- a/codegen/field.go +++ b/codegen/field.go @@ -3,6 +3,7 @@ package codegen import ( "errors" "fmt" + goast "go/ast" "go/types" "log" "reflect" @@ -503,6 +504,12 @@ func (f *Field) ResolverType() string { } func (f *Field) ShortResolverDeclaration() string { + return f.ShortResolverSignature(nil) +} + +// ShortResolverSignature is identical to ShortResolverDeclaration, +// but respects previous naming (return) conventions, if any. +func (f *Field) ShortResolverSignature(ft *goast.FuncType) string { if f.Object.Kind == ast.InputObject { return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error", templates.CurrentImports.LookupType(f.Object.Reference()), @@ -523,8 +530,17 @@ func (f *Field) ShortResolverDeclaration() string { if f.Object.Stream { result = "<-chan " + result } - - res += fmt.Sprintf(") (%s, error)", result) + // Named return. + var namedV, namedE string + if ft != nil { + if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 { + namedV = ft.Results.List[0].Names[0].Name + } + if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 { + namedE = ft.Results.List[1].Names[0].Name + } + } + res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE) return res } diff --git a/internal/rewrite/rewriter.go b/internal/rewrite/rewriter.go index a8a6485cff7..07a5f04227d 100644 --- a/internal/rewrite/rewriter.go +++ b/internal/rewrite/rewriter.go @@ -68,7 +68,7 @@ func (r *Rewriter) getFile(filename string) string { return r.files[filename] } -func (r *Rewriter) GetMethodComment(structname string, methodname string) string { +func (r *Rewriter) GetPrevDecl(structname string, methodname string) *ast.FuncDecl { for _, f := range r.pkg.Syntax { for _, d := range f.Decls { d, isFunc := d.(*ast.FuncDecl) @@ -89,48 +89,29 @@ func (r *Rewriter) GetMethodComment(structname string, methodname string) string if !ok { continue } - if ident.Name != structname { continue } - return d.Doc.Text() + r.copied[d] = true + return d } } + return nil +} +func (r *Rewriter) GetMethodComment(structname string, methodname string) string { + d := r.GetPrevDecl(structname, methodname) + if d != nil { + return d.Doc.Text() + } return "" } -func (r *Rewriter) GetMethodBody(structname string, methodname string) string { - for _, f := range r.pkg.Syntax { - for _, d := range f.Decls { - d, isFunc := d.(*ast.FuncDecl) - if !isFunc { - continue - } - if d.Name.Name != methodname { - continue - } - if d.Recv == nil || len(d.Recv.List) == 0 { - continue - } - recv := d.Recv.List[0].Type - if star, isStar := recv.(*ast.StarExpr); isStar { - recv = star.X - } - ident, ok := recv.(*ast.Ident) - if !ok { - continue - } - if ident.Name != structname { - continue - } - - r.copied[d] = true - - return r.getSource(d.Body.Pos()+1, d.Body.End()-1) - } +func (r *Rewriter) GetMethodBody(structname string, methodname string) string { + d := r.GetPrevDecl(structname, methodname) + if d != nil { + return r.getSource(d.Body.Pos()+1, d.Body.End()-1) } - return "" } diff --git a/plugin/resolvergen/resolver.go b/plugin/resolvergen/resolver.go index 189a79fdcbd..ab1534155ce 100644 --- a/plugin/resolvergen/resolver.go +++ b/plugin/resolvergen/resolver.go @@ -4,6 +4,7 @@ import ( _ "embed" "errors" "fmt" + "go/ast" "io/fs" "os" "path/filepath" @@ -67,7 +68,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error { continue } - resolver := Resolver{o, f, "// foo", `panic("not implemented")`} + resolver := Resolver{o, f, nil, "// foo", `panic("not implemented")`} file.Resolvers = append(file.Resolvers, &resolver) } } @@ -119,16 +120,16 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error { } structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type) - implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName)) comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`)) - if implementation == "" { - implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name) - } if comment == "" { comment = fmt.Sprintf("%v is the resolver for the %v field.", f.GoFieldName, f.Name) } + implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName)) + if implementation == "" { + implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name) + } - resolver := Resolver{o, f, comment, implementation} + resolver := Resolver{o, f, rewriter.GetPrevDecl(structName, f.GoFieldName), comment, implementation} fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate) if files[fn] == nil { files[fn] = &File{} @@ -215,6 +216,7 @@ func (f *File) Imports() string { type Resolver struct { Object *codegen.Object Field *codegen.Field + PrevDecl *ast.FuncDecl Comment string Implementation string } diff --git a/plugin/resolvergen/resolver.gotpl b/plugin/resolvergen/resolver.gotpl index c5d716ff7b6..920617608e6 100644 --- a/plugin/resolvergen/resolver.gotpl +++ b/plugin/resolvergen/resolver.gotpl @@ -20,7 +20,7 @@ {{ range $resolver := .Resolvers -}} // {{ $resolver.Comment }} - func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ $resolver.Field.ShortResolverDeclaration }} { + func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ with $resolver.PrevDecl }}{{ $resolver.Field.ShortResolverSignature .Type }}{{ else }}{{ $resolver.Field.ShortResolverDeclaration }}{{ end }}{ {{ $resolver.Implementation }} } diff --git a/plugin/resolvergen/resolver_test.go b/plugin/resolvergen/resolver_test.go index 930f3cb72b1..ea66e50567f 100644 --- a/plugin/resolvergen/resolver_test.go +++ b/plugin/resolvergen/resolver_test.go @@ -37,7 +37,8 @@ func TestLayoutFollowSchema(t *testing.T) { require.NoError(t, err) source := string(b) - require.Contains(t, source, "// CustomerResolverType.Resolver implementation") + require.Contains(t, source, "(_ *customresolver.Resolver, err error)") + require.Contains(t, source, "// Named return values are supported.") require.Contains(t, source, "// CustomerResolverType.Name implementation") require.Contains(t, source, "// AUserHelperFunction implementation") } diff --git a/plugin/resolvergen/testdata/followschema/out/schema.resolvers.go b/plugin/resolvergen/testdata/followschema/out/schema.resolvers.go index 80eb3a54060..ed8d1c77475 100644 --- a/plugin/resolvergen/testdata/followschema/out/schema.resolvers.go +++ b/plugin/resolvergen/testdata/followschema/out/schema.resolvers.go @@ -11,9 +11,9 @@ import ( ) // Resolver is the resolver for the resolver field. -func (r *queryCustomResolverType) Resolver(ctx context.Context) (*customresolver.Resolver, error) { - // CustomerResolverType.Resolver implementation - return nil, nil +func (r *queryCustomResolverType) Resolver(ctx context.Context) (_ *customresolver.Resolver, err error) { + // Named return values are supported. + return } // Name is the resolver for the name field.