Skip to content

Commit

Permalink
plugin/resolvergen: respect named return values
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Dec 28, 2022
1 parent c2b8eab commit 75550f2
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 46 deletions.
20 changes: 18 additions & 2 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package codegen
import (
"errors"
"fmt"
goast "go/ast"
"go/types"
"log"
"reflect"
Expand Down Expand Up @@ -503,6 +504,12 @@ func (f *Field) ResolverType() string {
}

func (f *Field) ShortResolverDeclaration() string {
return f.ShortResolverSignature(nil)
}

// ShortResolverSignature is identical to ShortResolverDeclaration,
// but respects previous naming (return) conventions, if any.
func (f *Field) ShortResolverSignature(ft *goast.FuncType) string {
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
templates.CurrentImports.LookupType(f.Object.Reference()),
Expand All @@ -523,8 +530,17 @@ func (f *Field) ShortResolverDeclaration() string {
if f.Object.Stream {
result = "<-chan " + result
}

res += fmt.Sprintf(") (%s, error)", result)
// Named return.
var namedV, namedE string
if ft != nil {
if ft.Results != nil && len(ft.Results.List) > 0 && ft.Results.List[0].Names != nil && len(ft.Results.List[0].Names) > 0 {
namedV = ft.Results.List[0].Names[0].Name
}
if ft.Results != nil && len(ft.Results.List) > 1 && ft.Results.List[1].Names != nil && len(ft.Results.List[1].Names) > 0 {
namedE = ft.Results.List[1].Names[0].Name
}
}
res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE)
return res
}

Expand Down
47 changes: 14 additions & 33 deletions internal/rewrite/rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (r *Rewriter) getFile(filename string) string {
return r.files[filename]
}

func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
func (r *Rewriter) GetPrevDecl(structname string, methodname string) *ast.FuncDecl {
for _, f := range r.pkg.Syntax {
for _, d := range f.Decls {
d, isFunc := d.(*ast.FuncDecl)
Expand All @@ -89,48 +89,29 @@ func (r *Rewriter) GetMethodComment(structname string, methodname string) string
if !ok {
continue
}

if ident.Name != structname {
continue
}
return d.Doc.Text()
r.copied[d] = true
return d
}
}
return nil
}

func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
d := r.GetPrevDecl(structname, methodname)
if d != nil {
return d.Doc.Text()
}
return ""
}
func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
for _, f := range r.pkg.Syntax {
for _, d := range f.Decls {
d, isFunc := d.(*ast.FuncDecl)
if !isFunc {
continue
}
if d.Name.Name != methodname {
continue
}
if d.Recv == nil || len(d.Recv.List) == 0 {
continue
}
recv := d.Recv.List[0].Type
if star, isStar := recv.(*ast.StarExpr); isStar {
recv = star.X
}
ident, ok := recv.(*ast.Ident)
if !ok {
continue
}

if ident.Name != structname {
continue
}

r.copied[d] = true

return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
}
func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
d := r.GetPrevDecl(structname, methodname)
if d != nil {
return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
}

return ""
}

Expand Down
14 changes: 8 additions & 6 deletions plugin/resolvergen/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
_ "embed"
"errors"
"fmt"
"go/ast"
"io/fs"
"os"
"path/filepath"
Expand Down Expand Up @@ -67,7 +68,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
continue
}

resolver := Resolver{o, f, "// foo", `panic("not implemented")`}
resolver := Resolver{o, f, nil, "// foo", `panic("not implemented")`}
file.Resolvers = append(file.Resolvers, &resolver)
}
}
Expand Down Expand Up @@ -119,16 +120,16 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
}

structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type)
implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`))
if implementation == "" {
implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
}
if comment == "" {
comment = fmt.Sprintf("%v is the resolver for the %v field.", f.GoFieldName, f.Name)
}
implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
if implementation == "" {
implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
}

resolver := Resolver{o, f, comment, implementation}
resolver := Resolver{o, f, rewriter.GetPrevDecl(structName, f.GoFieldName), comment, implementation}
fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
if files[fn] == nil {
files[fn] = &File{}
Expand Down Expand Up @@ -215,6 +216,7 @@ func (f *File) Imports() string {
type Resolver struct {
Object *codegen.Object
Field *codegen.Field
PrevDecl *ast.FuncDecl
Comment string
Implementation string
}
Expand Down
2 changes: 1 addition & 1 deletion plugin/resolvergen/resolver.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

{{ range $resolver := .Resolvers -}}
// {{ $resolver.Comment }}
func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ $resolver.Field.ShortResolverDeclaration }} {
func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ with $resolver.PrevDecl }}{{ $resolver.Field.ShortResolverSignature .Type }}{{ else }}{{ $resolver.Field.ShortResolverDeclaration }}{{ end }}{
{{ $resolver.Implementation }}
}

Expand Down
3 changes: 2 additions & 1 deletion plugin/resolvergen/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ func TestLayoutFollowSchema(t *testing.T) {
require.NoError(t, err)
source := string(b)

require.Contains(t, source, "// CustomerResolverType.Resolver implementation")
require.Contains(t, source, "(_ *customresolver.Resolver, err error)")
require.Contains(t, source, "// Named return values are supported.")
require.Contains(t, source, "// CustomerResolverType.Name implementation")
require.Contains(t, source, "// AUserHelperFunction implementation")
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 75550f2

Please sign in to comment.