From a705649f975220bcc018cf0232eab45252d852c2 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Fri, 28 Oct 2022 18:06:12 -0500 Subject: [PATCH 1/2] rpc: add failing TestClientBatchRequest_len --- rpc/client_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/rpc/client_test.go b/rpc/client_test.go index 51df76f7fe44..1a1b69dd9791 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -19,6 +19,7 @@ package rpc import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "net" @@ -148,6 +149,55 @@ func TestClientBatchRequest(t *testing.T) { } } +func TestClientBatchRequest_len(t *testing.T) { + b, err := json.Marshal([]jsonrpcMessage{ + {Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)}, + {Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)}, + }) + if err != nil { + t.Fatal("failed to encode jsonrpc message:", err) + } + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _, err := rw.Write(b) + if err != nil { + t.Error("failed to write reponse:", err) + } + })) + t.Cleanup(s.Close) + + client, err := Dial(s.URL) + if err != nil { + t.Fatal("failed to dial test server:", err) + } + defer client.Close() + + t.Run("too-few", func(t *testing.T) { + batch := []BatchElem{ + {Method: "foo"}, + {Method: "bar"}, + {Method: "baz"}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + err := client.BatchCallContext(ctx, batch) + if !errors.Is(err, ErrBadResult) { + t.Errorf("expected %q but got: %v", ErrBadResult, err) + } + }) + + t.Run("too-many", func(t *testing.T) { + batch := []BatchElem{ + {Method: "foo"}, + } + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + err := client.BatchCallContext(ctx, batch) + if !errors.Is(err, ErrBadResult) { + t.Errorf("expected %q but got: %v", ErrBadResult, err) + } + }) +} + func TestClientNotify(t *testing.T) { server := newTestServer() defer server.Stop() From 4e3d5e0d59752e6b2c4fae2d557aa798f633ba0b Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Fri, 28 Oct 2022 18:06:48 -0500 Subject: [PATCH 2/2] rpc: check batch response length to prevent deadlock --- rpc/client.go | 1 + rpc/client_test.go | 10 ++++------ rpc/http.go | 3 +++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/rpc/client.go b/rpc/client.go index 8288f976ebeb..d89aa69277c7 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -32,6 +32,7 @@ import ( ) var ( + ErrBadResult = errors.New("bad result in JSON-RPC response") ErrClientQuit = errors.New("client is closed") ErrNoResult = errors.New("no result in JSON-RPC response") ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") diff --git a/rpc/client_test.go b/rpc/client_test.go index 1a1b69dd9791..0a88ce40b2a8 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -160,7 +160,7 @@ func TestClientBatchRequest_len(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { _, err := rw.Write(b) if err != nil { - t.Error("failed to write reponse:", err) + t.Error("failed to write response:", err) } })) t.Cleanup(s.Close) @@ -170,7 +170,7 @@ func TestClientBatchRequest_len(t *testing.T) { t.Fatal("failed to dial test server:", err) } defer client.Close() - + t.Run("too-few", func(t *testing.T) { batch := []BatchElem{ {Method: "foo"}, @@ -179,8 +179,7 @@ func TestClientBatchRequest_len(t *testing.T) { } ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) defer cancelFn() - err := client.BatchCallContext(ctx, batch) - if !errors.Is(err, ErrBadResult) { + if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { t.Errorf("expected %q but got: %v", ErrBadResult, err) } }) @@ -191,8 +190,7 @@ func TestClientBatchRequest_len(t *testing.T) { } ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) defer cancelFn() - err := client.BatchCallContext(ctx, batch) - if !errors.Is(err, ErrBadResult) { + if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) { t.Errorf("expected %q but got: %v", ErrBadResult, err) } }) diff --git a/rpc/http.go b/rpc/http.go index 8595959afb66..e806ce98b09d 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -192,6 +192,9 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil { return err } + if len(respmsgs) != len(msgs) { + return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult) + } for i := 0; i < len(respmsgs); i++ { op.resp <- &respmsgs[i] }