diff --git a/example/todo/generated.go b/example/todo/generated.go index eea9f41ce4a..1815ba619dd 100644 --- a/example/todo/generated.go +++ b/example/todo/generated.go @@ -40,6 +40,7 @@ type MyMutationResolver interface { } type MyQueryResolver interface { Todo(ctx context.Context, id int) (*Todo, error) + AuthenticatedTodo(ctx context.Context, id int) (*Todo, error) LastTodo(ctx context.Context) (*Todo, error) Todos(ctx context.Context) ([]Todo, error) } @@ -211,6 +212,8 @@ func (ec *executionContext) _MyQuery(ctx context.Context, sel ast.SelectionSet) out.Values[i] = graphql.MarshalString("MyQuery") case "todo": out.Values[i] = ec._MyQuery_todo(ctx, field) + case "authenticatedTodo": + out.Values[i] = ec._MyQuery_authenticatedTodo(ctx, field) case "lastTodo": out.Values[i] = ec._MyQuery_lastTodo(ctx, field) case "todos": @@ -267,6 +270,46 @@ func (ec *executionContext) _MyQuery_todo(ctx context.Context, field graphql.Col }) } +func (ec *executionContext) _MyQuery_authenticatedTodo(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + args := map[string]interface{}{} + var arg0 int + if tmp, ok := field.Args["id"]; ok { + var err error + arg0, err = graphql.UnmarshalInt(tmp) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + } + args["id"] = arg0 + ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "MyQuery", + Args: args, + Field: field, + }) + return graphql.Defer(func() (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + userErr := ec.Recover(ctx, r) + ec.Error(ctx, userErr) + ret = graphql.Null + } + }() + + resTmp := ec.FieldMiddleware(ctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.MyQuery().AuthenticatedTodo(ctx, args["id"].(int)) + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*Todo) + if res == nil { + return graphql.Null + } + return ec._Todo(ctx, field.Selections, res) + }) +} + func (ec *executionContext) _MyQuery_lastTodo(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ Object: "MyQuery", @@ -1355,13 +1398,12 @@ func UnmarshalTodoInput(v interface{}) (TodoInput, error) { func (ec *executionContext) FieldMiddleware(ctx context.Context, next graphql.Resolver) interface{} { rctx := graphql.GetResolverContext(ctx) - if len(rctx.Field.Directives) != 0 { - for _, d := range rctx.Field.Directives { - switch d.Name { - case "isAuthenticated": - next = func(ctx context.Context) (interface{}, error) { - return ec.directives.IsAuthenticated(ctx, next) - } + for _, d := range rctx.Field.Definition.Directives { + switch d.Name { + case "isAuthenticated": + n := next + next = func(ctx context.Context) (interface{}, error) { + return ec.directives.IsAuthenticated(ctx, n) } } } @@ -1389,6 +1431,7 @@ var parsedSchema = gqlparser.MustLoadSchema( type MyQuery { todo(id: Int!): Todo + authenticatedTodo(id: Int!): Todo @isAuthenticated lastTodo: Todo todos: [Todo!]! } diff --git a/example/todo/schema.graphql b/example/todo/schema.graphql index c98e09b328a..bdc631b016c 100644 --- a/example/todo/schema.graphql +++ b/example/todo/schema.graphql @@ -5,6 +5,7 @@ schema { type MyQuery { todo(id: Int!): Todo + authenticatedTodo(id: Int!): Todo @isAuthenticated lastTodo: Todo todos: [Todo!]! } diff --git a/example/todo/todo.go b/example/todo/todo.go index 8ebdbfec01e..d9aa4ea7f7f 100644 --- a/example/todo/todo.go +++ b/example/todo/todo.go @@ -8,18 +8,30 @@ import ( "time" "github.com/mitchellh/mapstructure" + graphql "github.com/vektah/gqlgen/graphql" ) func New() Config { - r := &resolvers{ - todos: []Todo{ - {ID: 1, Text: "A todo not to forget", Done: false}, - {ID: 2, Text: "This is the most important", Done: false}, - {ID: 3, Text: "Please do this or else", Done: false}, + c := Config{ + Resolvers: &resolvers{ + todos: []Todo{ + {ID: 1, Text: "A todo not to forget", Done: false}, + {ID: 2, Text: "This is the most important", Done: false}, + {ID: 3, Text: "Please do this or else", Done: false}, + }, + lastID: 3, }, - lastID: 3, } - return Config{Resolvers: r} + c.Directives.IsAuthenticated = func(ctx context.Context, next graphql.Resolver) (interface{}, error) { + rctx := graphql.GetResolverContext(ctx) + idVal := rctx.Field.Arguments.ForName("id").Value + id, _ := idVal.Value(make(map[string]interface{})) + if id.(int64) == 1 { + return nil, nil + } + return next(ctx) + } + return c } type resolvers struct { @@ -63,6 +75,10 @@ func (r *QueryResolver) Todos(ctx context.Context) ([]Todo, error) { return r.todos, nil } +func (r *QueryResolver) AuthenticatedTodo(ctx context.Context, id int) (*Todo, error) { + return r.Todo(ctx, id) +} + type MutationResolver resolvers func (r *MutationResolver) CreateTodo(ctx context.Context, todo TodoInput) (Todo, error) { diff --git a/example/todo/todo_test.go b/example/todo/todo_test.go index e5c4fdddfe3..ca244274ad4 100644 --- a/example/todo/todo_test.go +++ b/example/todo/todo_test.go @@ -127,6 +127,19 @@ func TestTodo(t *testing.T) { require.Equal(t, "Completed todo", resp.CreateTodo.Text) }) + + t.Run("isAuthenticated directive middleware", func(t *testing.T) { + var resp map[string]interface{} + c.MustPost(`{ authenticatedTodo(id: 1) { __typename } }`, &resp) + val, ok := resp["authenticatedTodo"] + require.True(t, ok) + require.Nil(t, val) + + c.MustPost(`{ authenticatedTodo(id: 2) { __typename } }`, &resp) + val, ok = resp["authenticatedTodo"] + require.True(t, ok) + require.NotNil(t, val) + }) } func TestSkipAndIncludeDirectives(t *testing.T) {