Skip to content

Commit

Permalink
Merge pull request arsmn#11 from vektah/subscriptions
Browse files Browse the repository at this point in the history
Add support for subscriptions
  • Loading branch information
vektah authored Feb 17, 2018
2 parents 26fac02 + d1b6fed commit 03734c6
Show file tree
Hide file tree
Showing 45 changed files with 3,273 additions and 413 deletions.
10 changes: 8 additions & 2 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (p *Client) MustPost(query string, response interface{}, options ...Option)
}
}

func (p *Client) Post(query string, response interface{}, options ...Option) error {
func (p *Client) mkRequest(query string, options ...Option) Request {
r := Request{
Query: query,
}
Expand All @@ -71,6 +71,11 @@ func (p *Client) Post(query string, response interface{}, options ...Option) err
option(&r)
}

return r
}

func (p *Client) Post(query string, response interface{}, options ...Option) error {
r := p.mkRequest(query, options...)
requestBody, err := json.Marshal(r)
if err != nil {
return fmt.Errorf("encode: %s", err.Error())
Expand Down Expand Up @@ -120,6 +125,7 @@ func unpack(data interface{}, into interface{}) error {
Result: into,
TagName: "json",
ErrorUnused: true,
ZeroFields: true,
})
if err != nil {
return fmt.Errorf("mapstructure: %s", err.Error())
Expand Down
104 changes: 104 additions & 0 deletions client/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package client

import (
"encoding/json"
"fmt"
"strings"

"github.com/gorilla/websocket"
"github.com/vektah/gqlgen/neelance/errors"
)

const (
connectionInitMsg = "connection_init" // Client -> Server
connectionTerminateMsg = "connection_terminate" // Client -> Server
startMsg = "start" // Client -> Server
stopMsg = "stop" // Client -> Server
connectionAckMsg = "connection_ack" // Server -> Client
connectionErrorMsg = "connection_error" // Server -> Client
connectionKeepAliveMsg = "ka" // Server -> Client
dataMsg = "data" // Server -> Client
errorMsg = "error" // Server -> Client
completeMsg = "complete" // Server -> Client
)

type operationMessage struct {
Payload json.RawMessage `json:"payload,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type"`
}

type Subscription struct {
Close func() error
Next func(response interface{}) error
}

func errorSubscription(err error) *Subscription {
return &Subscription{
Close: func() error { return nil },
Next: func(response interface{}) error {
return err
},
}
}

func (p *Client) Websocket(query string, options ...Option) *Subscription {
r := p.mkRequest(query, options...)
requestBody, err := json.Marshal(r)
if err != nil {
return errorSubscription(fmt.Errorf("encode: %s", err.Error()))
}

url := strings.Replace(p.url, "http://", "ws://", -1)
url = strings.Replace(url, "https://", "wss://", -1)

c, _, err := websocket.DefaultDialer.Dial(url, nil)
if err != nil {
return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
}

if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil {
return errorSubscription(fmt.Errorf("init: %s", err.Error()))
}

var ack operationMessage
if err := c.ReadJSON(&ack); err != nil {
return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
}
if ack.Type != connectionAckMsg {
return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
}

if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
return errorSubscription(fmt.Errorf("start: %s", err.Error()))
}

return &Subscription{
Close: c.Close,
Next: func(response interface{}) error {
var op operationMessage
c.ReadJSON(&op)
if op.Type != dataMsg {
return fmt.Errorf("expected data message, got %#v", op)
}

respDataRaw := map[string]interface{}{}
err = json.Unmarshal(op.Payload, &respDataRaw)
if err != nil {
return fmt.Errorf("decode: %s", err.Error())
}

if respDataRaw["errors"] != nil {
var errs []*errors.QueryError
if err := unpack(respDataRaw["errors"], errs); err != nil {
return err
}
if len(errs) > 0 {
return fmt.Errorf("errors: %s", errs)
}
}

return unpack(respDataRaw["data"], response)
},
}
}
21 changes: 13 additions & 8 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import (
)

type Build struct {
PackageName string
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
QueryRoot *Object
MutationRoot *Object
SchemaRaw string
PackageName string
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object
SchemaRaw string
}

// Bind a schema together with some code to generate a Build
Expand Down Expand Up @@ -48,6 +49,10 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
b.MutationRoot = b.Objects.ByName(mr.TypeName())
}

if sr, ok := schema.EntryPoints["subscription"]; ok {
b.SubscriptionRoot = b.Objects.ByName(sr.TypeName())
}

// Poke a few magic methods into query
q := b.Objects.ByName(b.QueryRoot.GQLType)
q.Fields = append(q.Fields, Field{
Expand Down
8 changes: 7 additions & 1 deletion codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Object struct {
Satisfies []string
Root bool
DisableConcurrency bool
Stream bool
}

type Field struct {
Expand Down Expand Up @@ -75,7 +76,12 @@ func (f *Field) ResolverDeclaration() string {
res += fmt.Sprintf(", %s %s", arg.GQLName, arg.Signature())
}

res += fmt.Sprintf(") (%s, error)", f.Signature())
result := f.Signature()
if f.Object.Stream {
result = "<-chan " + result
}

res += fmt.Sprintf(") (%s, error)", result)
return res
}

Expand Down
3 changes: 3 additions & 0 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
if name == "mutation" {
objects.ByName(obj.Name).DisableConcurrency = true
}
if name == "subscription" {
objects.ByName(obj.Name).Stream = true
}
}

sort.Slice(objects, func(i, j int) bool {
Expand Down
21 changes: 21 additions & 0 deletions example/chat/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# See https://help.github.com/ignore-files/ for more about ignoring files.

# dependencies
/node_modules

# testing
/coverage

# production
/build

# misc
.DS_Store
.env.local
.env.development.local
.env.test.local
.env.production.local

npm-debug.log*
yarn-debug.log*
yarn-error.log*
52 changes: 52 additions & 0 deletions example/chat/chat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package chat

import (
"net/http/httptest"
"sync"
"testing"

"github.com/stretchr/testify/require"
"github.com/vektah/gqlgen/client"
"github.com/vektah/gqlgen/handler"
)

func TestChat(t *testing.T) {
srv := httptest.NewServer(handler.GraphQL(MakeExecutableSchema(New())))
c := client.New(srv.URL)
var wg sync.WaitGroup
wg.Add(1)

t.Run("subscribe to chat events", func(t *testing.T) {
t.Parallel()

sub := c.Websocket(`subscription { messageAdded(roomName:"#gophers") { text createdBy } }`)
defer sub.Close()

wg.Done()
var resp struct {
MessageAdded struct {
Text string
CreatedBy string
}
}
require.NoError(t, sub.Next(&resp))
require.Equal(t, "Hello!", resp.MessageAdded.Text)
require.Equal(t, "vektah", resp.MessageAdded.CreatedBy)

require.NoError(t, sub.Next(&resp))
require.Equal(t, "Whats up?", resp.MessageAdded.Text)
require.Equal(t, "vektah", resp.MessageAdded.CreatedBy)
})

t.Run("post two messages", func(t *testing.T) {
t.Parallel()

wg.Wait()
var resp interface{}
c.MustPost(`mutation {
a:post(text:"Hello!", roomName:"#gophers", username:"vektah") { id }
b:post(text:"Whats up?", roomName:"#gophers", username:"vektah") { id }
}`, &resp)
})

}
Loading

0 comments on commit 03734c6

Please sign in to comment.