diff --git a/codegen/generated!.gotpl b/codegen/generated!.gotpl index 9d8915c8d77..0117b952ffb 100644 --- a/codegen/generated!.gotpl +++ b/codegen/generated!.gotpl @@ -166,7 +166,7 @@ func (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDe {{- if .SubscriptionRoot }} ec := executionContext{graphql.GetRequestContext(ctx), e} - {{ if .MutationDirectives -}} + {{ if .SubscriptionDirectives -}} next := ec._{{.SubscriptionRoot.Name}}Middleware(ctx, op) {{- else -}} next := ec._{{.SubscriptionRoot.Name}}(ctx, op.SelectionSet) @@ -247,7 +247,7 @@ func (ec *executionContext) _{{.QueryRoot.Name}}Middleware(ctx context.Context, {{end}} {{ if and .SubscriptionDirectives .SubscriptionRoot }} -func (ec *executionContext) _{{.SubscriptionRoot.Name}}Middleware(ctx context.Context, obj *ast.OperationDefinition) graphql.Marshaler { +func (ec *executionContext) _{{.SubscriptionRoot.Name}}Middleware(ctx context.Context, obj *ast.OperationDefinition) func() graphql.Marshaler { next := func(ctx context.Context) (interface{}, error){ return ec._{{.SubscriptionRoot.Name}}(ctx, obj.SelectionSet),nil @@ -261,7 +261,9 @@ func (ec *executionContext) _{{.SubscriptionRoot.Name}}Middleware(ctx context.Co args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs) if err != nil { ec.Error(ctx, err) - return graphql.Null + return func() graphql.Marshaler { + return graphql.Null + } } {{- end }} n := next @@ -274,13 +276,17 @@ func (ec *executionContext) _{{.SubscriptionRoot.Name}}Middleware(ctx context.Co tmp, err := next(ctx) if err != nil { ec.Error(ctx, err) - return graphql.Null + return func() graphql.Marshaler { + return graphql.Null + } } - if data, ok := tmp.(graphql.Marshaler); ok { + if data, ok := tmp.(func() graphql.Marshaler); ok { return data } ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp) - return graphql.Null + return func() graphql.Marshaler { + return graphql.Null + } } {{end}} diff --git a/example/chat/chat_test.go b/example/chat/chat_test.go index a4245f4862d..23c6673ef66 100644 --- a/example/chat/chat_test.go +++ b/example/chat/chat_test.go @@ -15,7 +15,7 @@ func TestChatSubscriptions(t *testing.T) { srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(New()))) c := client.New(srv.URL) - sub := c.Websocket(`subscription { messageAdded(roomName:"#gophers") { text createdBy } }`) + sub := c.Websocket(`subscription @user(username:"vektah") { messageAdded(roomName:"#gophers") { text createdBy } }`) defer sub.Close() go func() { @@ -23,7 +23,8 @@ func TestChatSubscriptions(t *testing.T) { time.Sleep(10 * time.Millisecond) err := c.Post(`mutation { a:post(text:"Hello!", roomName:"#gophers", username:"vektah") { id } - b:post(text:"Whats up?", roomName:"#gophers", username:"vektah") { id } + b:post(text:"Hello Vektah!", roomName:"#gophers", username:"andrey") { id } + c:post(text:"Whats up?", roomName:"#gophers", username:"vektah") { id } }`, &resp) assert.NoError(t, err) }() diff --git a/example/chat/generated.go b/example/chat/generated.go index f13879d03cc..15ffd28ce48 100644 --- a/example/chat/generated.go +++ b/example/chat/generated.go @@ -41,6 +41,7 @@ type ResolverRoot interface { } type DirectiveRoot struct { + User func(ctx context.Context, obj interface{}, next graphql.Resolver, username string) (res interface{}, err error) } type ComplexityRoot struct { @@ -213,7 +214,7 @@ func (e *executableSchema) Mutation(ctx context.Context, op *ast.OperationDefini func (e *executableSchema) Subscription(ctx context.Context, op *ast.OperationDefinition) func() *graphql.Response { ec := executionContext{graphql.GetRequestContext(ctx), e} - next := ec._Subscription(ctx, op.SelectionSet) + next := ec._SubscriptionMiddleware(ctx, op) if ec.Errors != nil { return graphql.OneShot(&graphql.Response{Data: []byte("null"), Errors: ec.Errors}) } @@ -248,6 +249,44 @@ type executionContext struct { *executableSchema } +func (ec *executionContext) _SubscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition) func() graphql.Marshaler { + + next := func(ctx context.Context) (interface{}, error) { + return ec._Subscription(ctx, obj.SelectionSet), nil + } + for _, d := range obj.Directives { + switch d.Name { + case "user": + rawArgs := d.ArgumentMap(ec.Variables) + args, err := ec.dir_user_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return func() graphql.Marshaler { + return graphql.Null + } + } + n := next + next = func(ctx context.Context) (interface{}, error) { + return ec.directives.User(ctx, obj, n, args["username"].(string)) + } + } + } + tmp, err := next(ctx) + if err != nil { + ec.Error(ctx, err) + return func() graphql.Marshaler { + return graphql.Null + } + } + if data, ok := tmp.(func() graphql.Marshaler); ok { + return data + } + ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp) + return func() graphql.Marshaler { + return graphql.Null + } +} + func (ec *executionContext) FieldMiddleware(ctx context.Context, obj interface{}, next graphql.Resolver) (ret interface{}) { defer func() { if r := recover(); r != nil { @@ -255,6 +294,24 @@ func (ec *executionContext) FieldMiddleware(ctx context.Context, obj interface{} ret = nil } }() + rctx := graphql.GetResolverContext(ctx) + for _, d := range rctx.Field.Definition.Directives { + switch d.Name { + case "user": + if ec.directives.User != nil { + rawArgs := d.ArgumentMap(ec.Variables) + args, err := ec.dir_user_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return nil + } + n := next + next = func(ctx context.Context) (interface{}, error) { + return ec.directives.User(ctx, obj, n, args["username"].(string)) + } + } + } + } res, err := ec.ResolverMiddleware(ctx, next) if err != nil { ec.Error(ctx, err) @@ -303,6 +360,8 @@ type Subscription { } scalar Time + +directive @user(username: String!) on SUBSCRIPTION `}, ) @@ -310,6 +369,20 @@ scalar Time // region ***************************** args.gotpl ***************************** +func (ec *executionContext) dir_user_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["username"]; ok { + arg0, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["username"] = arg0 + return args, nil +} + func (ec *executionContext) field_Mutation_post_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} diff --git a/example/chat/resolvers.go b/example/chat/resolvers.go index 5c0de7b0602..023b88f664e 100644 --- a/example/chat/resolvers.go +++ b/example/chat/resolvers.go @@ -3,10 +3,12 @@ package chat import ( - context "context" + "context" "math/rand" "sync" "time" + + "github.com/99designs/gqlgen/graphql" ) type resolver struct { @@ -31,13 +33,28 @@ func New() Config { Resolvers: &resolver{ Rooms: map[string]*Chatroom{}, }, + Directives: DirectiveRoot{ + User: func(ctx context.Context, obj interface{}, next graphql.Resolver, username string) (res interface{}, err error) { + return next(context.WithValue(ctx, "username", username)) + }, + }, + } +} + +func getUsername(ctx context.Context) string { + if username, ok := ctx.Value("username").(string); ok { + return username } + return "" } type Chatroom struct { Name string Messages []Message - Observers map[string]chan *Message + Observers map[string]struct { + Username string + Message chan *Message + } } type mutationResolver struct{ *resolver } @@ -46,7 +63,13 @@ func (r *mutationResolver) Post(ctx context.Context, text string, username strin r.mu.Lock() room := r.Rooms[roomName] if room == nil { - room = &Chatroom{Name: roomName, Observers: map[string]chan *Message{}} + room = &Chatroom{ + Name: roomName, + Observers: map[string]struct { + Username string + Message chan *Message + }{}, + } r.Rooms[roomName] = room } r.mu.Unlock() @@ -61,7 +84,9 @@ func (r *mutationResolver) Post(ctx context.Context, text string, username strin room.Messages = append(room.Messages, message) r.mu.Lock() for _, observer := range room.Observers { - observer <- &message + if observer.Username == "" || observer.Username == message.CreatedBy { + observer.Message <- &message + } } r.mu.Unlock() return &message, nil @@ -73,7 +98,13 @@ func (r *queryResolver) Room(ctx context.Context, name string) (*Chatroom, error r.mu.Lock() room := r.Rooms[name] if room == nil { - room = &Chatroom{Name: name, Observers: map[string]chan *Message{}} + room = &Chatroom{ + Name: name, + Observers: map[string]struct { + Username string + Message chan *Message + }{}, + } r.Rooms[name] = room } r.mu.Unlock() @@ -87,7 +118,13 @@ func (r *subscriptionResolver) MessageAdded(ctx context.Context, roomName string r.mu.Lock() room := r.Rooms[roomName] if room == nil { - room = &Chatroom{Name: roomName, Observers: map[string]chan *Message{}} + room = &Chatroom{ + Name: roomName, + Observers: map[string]struct { + Username string + Message chan *Message + }{}, + } r.Rooms[roomName] = room } r.mu.Unlock() @@ -103,7 +140,10 @@ func (r *subscriptionResolver) MessageAdded(ctx context.Context, roomName string }() r.mu.Lock() - room.Observers[id] = events + room.Observers[id] = struct { + Username string + Message chan *Message + }{Username: getUsername(ctx), Message: events} r.mu.Unlock() return events, nil diff --git a/example/chat/schema.graphql b/example/chat/schema.graphql index 85a46768edf..18bfcae121b 100644 --- a/example/chat/schema.graphql +++ b/example/chat/schema.graphql @@ -23,3 +23,5 @@ type Subscription { } scalar Time + +directive @user(username: String!) on SUBSCRIPTION