Skip to content

Commit

Permalink
Gracefully return error when no callback defined for response (#494)
Browse files Browse the repository at this point in the history
Fixes #492 

Co-authored-by: Rouven Bauer <rouven.bauer@neo4j.com>
  • Loading branch information
fbiville and robsdedude authored May 24, 2023
1 parent bbb6f6d commit e1b42c9
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 8 deletions.
37 changes: 29 additions & 8 deletions neo4j/internal/bolt/message_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"container/list"
"context"
"errors"
"fmt"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/log"
"net"
Expand Down Expand Up @@ -153,19 +154,39 @@ func (q *messageQueue) receive(ctx context.Context) error {
if q.handlers.Len() == 0 {
return errors.New("no more response callback to apply")
}
callback := q.pop()
handler := q.pop()
switch message := res.(type) {
case *db.Record:
callback.onRecord(message)
onRecord := handler.onRecord
if onRecord == nil {
return errors.New("protocol violation: the server sent an unexpected RECORD response")
}
onRecord(message)
case *success:
callback.onSuccess(message)
onSuccess := handler.onSuccess
if onSuccess == nil {
return errors.New("protocol violation: the server sent an unexpected SUCCESS response")
}
onSuccess(message)
case *db.Neo4jError:
callback.onFailure(ctx, message)
onFailure := handler.onFailure
if onFailure == nil {
return errors.New("protocol violation: the server sent an unexpected FAILURE response")
}
onFailure(ctx, message)
return message
case *ignored:
callback.onIgnored(message)
onIgnored := handler.onIgnored
if onIgnored == nil {
return errors.New("protocol violation: the server sent an unexpected IGNORED response")
}
onIgnored(message)
default:
callback.onUnknown(message)
onUnknown := handler.onUnknown
if onUnknown == nil {
return fmt.Errorf("protocol violation: the server sent an unknown %v response", message)
}
onUnknown(message)
}
return nil
}
Expand Down Expand Up @@ -195,8 +216,8 @@ func (q *messageQueue) receiveMsg(ctx context.Context) any {
return msg
}

func (q *messageQueue) enqueueCallback(callbacks responseHandler) {
q.handlers.PushBack(callbacks)
func (q *messageQueue) enqueueCallback(handler responseHandler) {
q.handlers.PushBack(handler)
}

func (q *messageQueue) setLogId(logId string) {
Expand Down
64 changes: 64 additions & 0 deletions neo4j/internal/bolt/message_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package bolt
import (
"bytes"
"context"
"fmt"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
. "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil"
"net"
Expand Down Expand Up @@ -227,7 +228,70 @@ func TestMessageQueue(outer *testing.T) {
writer.send(ctx, server)
<-done
})

inner.Run("returns error when nil callback called for", func(inner *testing.T) {
inner.Parallel()

type testCase struct {
description string
writerWork func(*outgoing)
expectedErrorMsg string
}

testCases := []testCase{
{
description: "RECORD",
writerWork: func(o *outgoing) {
writer.appendX(msgRecord, []any{})
writer.send(ctx, server)
},
expectedErrorMsg: "protocol violation: the server sent an unexpected RECORD response",
},
{
description: "SUCCESS",
writerWork: func(o *outgoing) {
writer.appendX(msgSuccess, map[string]any{})
writer.send(ctx, server)
},
expectedErrorMsg: "protocol violation: the server sent an unexpected SUCCESS response",
},
{
description: "FAILURE",
writerWork: func(o *outgoing) {
writer.appendX(msgFailure, map[string]any{})
writer.send(ctx, server)
},
expectedErrorMsg: "protocol violation: the server sent an unexpected FAILURE response",
},
{
description: "IGNORED",
writerWork: func(o *outgoing) {
writer.appendX(msgIgnored)
writer.send(ctx, server)
},
expectedErrorMsg: "protocol violation: the server sent an unexpected IGNORED response",
},
}

for _, test := range testCases {
inner.Run(fmt.Sprintf("%s response", test.description), func(t *testing.T) {
done := make(chan any)
queue.enqueueCallback(responseHandler{})

go func() {
err := queue.receive(ctx)

AssertErrorMessageContains(t, err, test.expectedErrorMsg)
done <- struct{}{}
}()

test.writerWork(writer)
<-done
})
}
})
})

}

func assertEqualResponseHandlers(t *testing.T, handler1, handler2 responseHandler) {
Expand Down

0 comments on commit e1b42c9

Please sign in to comment.