Skip to content

Commit

Permalink
Fix underlying connection not being closed on protocol error (#64)
Browse files Browse the repository at this point in the history
Before this commit, the underlying connection of `Conn` was not being
closed when a protocol error was encountered. This behavior contradicted
with `Conn.DisconnectNotify()` because it reported that the underlying
connection was being closed. Additionally, the underlying connection was
now orphaned because `Conn` was no longer processing any of the
subsequent requests.

With this commit, the underlying connection is now being closed when a
protocol error is encountered, matching what `Conn.DisconnectNotify()`
has already been reporting.
  • Loading branch information
samherrmann authored Feb 7, 2023
1 parent 78a3d79 commit 028a50b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 98 deletions.
58 changes: 26 additions & 32 deletions jsonrpc2.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,10 @@ type Conn struct {

h Handler

mu sync.Mutex
shutdown bool
closing bool
seq uint64
pending map[ID]*call
mu sync.Mutex
closed bool
seq uint64
pending map[ID]*call

sending sync.Mutex

Expand Down Expand Up @@ -417,13 +416,29 @@ func NewConn(ctx context.Context, stream ObjectStream, h Handler, opts ...ConnOp
// Close closes the JSON-RPC connection. The connection may not be
// used after it has been closed.
func (c *Conn) Close() error {
return c.close(nil)
}

func (c *Conn) close(cause error) error {
c.sending.Lock()
c.mu.Lock()
if c.shutdown || c.closing {
c.mu.Unlock()
defer c.sending.Unlock()
defer c.mu.Unlock()

if c.closed {
return ErrClosed
}
c.closing = true
c.mu.Unlock()

for _, call := range c.pending {
close(call.done)
}

if cause != nil && cause != io.EOF && cause != io.ErrUnexpectedEOF {
c.logger.Printf("jsonrpc2: protocol error: %v\n", cause)
}

close(c.disconnect)
c.closed = true
return c.stream.Close()
}

Expand All @@ -436,7 +451,7 @@ func (c *Conn) send(_ context.Context, m *anyMessage, wait bool) (cc *call, err
var id ID

c.mu.Lock()
if c.shutdown || c.closing {
if c.closed {
c.mu.Unlock()
return nil, ErrClosed
}
Expand Down Expand Up @@ -675,28 +690,7 @@ func (c *Conn) readMessages(ctx context.Context) {
}
}
}

c.sending.Lock()
c.mu.Lock()
c.shutdown = true
closing := c.closing
if err == io.EOF {
if closing {
err = ErrClosed
} else {
err = io.ErrUnexpectedEOF
}
}
for _, call := range c.pending {
call.done <- err
close(call.done)
}
c.mu.Unlock()
c.sending.Unlock()
if err != io.ErrUnexpectedEOF && !closing {
c.logger.Printf("jsonrpc2: protocol error: %v\n", err)
}
close(c.disconnect)
c.close(err)
}

// call represents a JSON-RPC call over its entire lifecycle.
Expand Down
148 changes: 82 additions & 66 deletions jsonrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,80 +314,82 @@ type noopHandler struct{}

func (noopHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {}

type readWriteCloser struct {
read, write func(p []byte) (n int, err error)
}

func (x readWriteCloser) Read(p []byte) (n int, err error) {
return x.read(p)
}

func (x readWriteCloser) Write(p []byte) (n int, err error) {
return x.write(p)
}
func TestConn_DisconnectNotify(t *testing.T) {

func (readWriteCloser) Close() error { return nil }
t.Run("EOF", func(t *testing.T) {
connA, connB := net.Pipe()
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil)
// By closing connA, connB receives io.EOF
if err := connA.Close(); err != nil {
t.Error(err)
}
assertDisconnect(t, c, connB)
})

func eof(p []byte) (n int, err error) {
return 0, io.EOF
}
t.Run("Close", func(t *testing.T) {
_, connB := net.Pipe()
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil)
if err := c.Close(); err != nil {
t.Error(err)
}
assertDisconnect(t, c, connB)
})

func TestConn_DisconnectNotify_EOF(t *testing.T) {
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil)
select {
case <-c.DisconnectNotify():
case <-time.After(200 * time.Millisecond):
t.Fatal("no disconnect notification")
}
}
t.Run("Close async", func(t *testing.T) {
done := make(chan struct{})
_, connB := net.Pipe()
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil)
go func() {
if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed {
t.Error(err)
}
close(done)
}()
assertDisconnect(t, c, connB)
<-done
})

func TestConn_DisconnectNotify_Close(t *testing.T) {
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil)
if err := c.Close(); err != nil {
t.Error(err)
}
select {
case <-c.DisconnectNotify():
case <-time.After(200 * time.Millisecond):
t.Fatal("no disconnect notification")
}
t.Run("protocol error", func(t *testing.T) {
connA, connB := net.Pipe()
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewPlainObjectStream(connB), nil)
connA.Write([]byte("invalid json"))
assertDisconnect(t, c, connB)
})
}

func TestConn_DisconnectNotify_Close_async(t *testing.T) {
done := make(chan struct{})
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), nil)
go func() {
if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed {
func TestConn_Close(t *testing.T) {
t.Run("waiting for response", func(t *testing.T) {
connA, connB := net.Pipe()
nodeA := jsonrpc2.NewConn(
context.Background(),
jsonrpc2.NewPlainObjectStream(connA), noopHandler{},
)
defer nodeA.Close()
nodeB := jsonrpc2.NewConn(
context.Background(),
jsonrpc2.NewPlainObjectStream(connB),
noopHandler{},
)
defer nodeB.Close()

ready := make(chan struct{})
done := make(chan struct{})
go func() {
close(ready)
err := nodeB.Call(context.Background(), "m", nil, nil)
if err != jsonrpc2.ErrClosed {
t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed)
}
close(done)
}()
// Wait for the request to be sent before we close the connection.
<-ready
if err := nodeB.Close(); err != nil && err != jsonrpc2.ErrClosed {
t.Error(err)
}
close(done)
}()
select {
case <-c.DisconnectNotify():
case <-time.After(200 * time.Millisecond):
t.Fatal("no disconnect notification")
}
<-done
}

func TestConn_Close_waitingForResponse(t *testing.T) {
c := jsonrpc2.NewConn(context.Background(), jsonrpc2.NewBufferedStream(&readWriteCloser{eof, eof}, jsonrpc2.VarintObjectCodec{}), noopHandler{})
done := make(chan struct{})
go func() {
if err := c.Call(context.Background(), "m", nil, nil); err != jsonrpc2.ErrClosed {
t.Errorf("got error %v, want %v", err, jsonrpc2.ErrClosed)
}
close(done)
}()
if err := c.Close(); err != nil && err != jsonrpc2.ErrClosed {
t.Error(err)
}
select {
case <-c.DisconnectNotify():
case <-time.After(200 * time.Millisecond):
t.Fatal("no disconnect notification")
}
<-done
assertDisconnect(t, nodeB, connB)
<-done
})
}

func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMaker streamMaker, opts ...jsonrpc2.ConnOpt) error {
Expand All @@ -399,3 +401,17 @@ func serve(ctx context.Context, lis net.Listener, h jsonrpc2.Handler, streamMake
jsonrpc2.NewConn(ctx, streamMaker(conn), h, opts...)
}
}

func assertDisconnect(t *testing.T, c *jsonrpc2.Conn, conn io.Writer) {
select {
case <-c.DisconnectNotify():
case <-time.After(200 * time.Millisecond):
t.Fatal("no disconnect notification")
}
// Assert that conn is closed by trying to write to it.
_, got := conn.Write(nil)
want := io.ErrClosedPipe
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
}

0 comments on commit 028a50b

Please sign in to comment.