diff --git a/AUTHORS b/AUTHORS index 29e08b0ca..dec27daca 100644 --- a/AUTHORS +++ b/AUTHORS @@ -77,6 +77,7 @@ Maciej Zimnoch Michael Woolnough Nathanial Murphy Nicola Peduzzi +Oliver Bone Olivier Mengué oscarzhao Paul Bonser diff --git a/packets.go b/packets.go index 4e27004aa..0994d41a3 100644 --- a/packets.go +++ b/packets.go @@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { + mc.Close() if data[3] > mc.sequence { return nil, ErrPktSyncMul } diff --git a/packets_test.go b/packets_test.go index f429087e9..56c455188 100644 --- a/packets_test.go +++ b/packets_test.go @@ -133,30 +133,34 @@ func TestReadPacketSingleByte(t *testing.T) { } func TestReadPacketWrongSequenceID(t *testing.T) { - conn := new(mockConn) - mc := &mysqlConn{ - buf: newBuffer(conn), - } - - // too low sequence id - conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} - conn.maxReads = 1 - mc.sequence = 1 - _, err := mc.readPacket() - if err != ErrPktSync { - t.Errorf("expected ErrPktSync, got %v", err) - } - - // reset - conn.reads = 0 - mc.sequence = 0 - mc.buf = newBuffer(conn) - - // too high sequence id - conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} - _, err = mc.readPacket() - if err != ErrPktSyncMul { - t.Errorf("expected ErrPktSyncMul, got %v", err) + for _, testCase := range []struct { + ClientSequenceID byte + ServerSequenceID byte + ExpectedErr error + }{ + { + ClientSequenceID: 1, + ServerSequenceID: 0, + ExpectedErr: ErrPktSync, + }, + { + ClientSequenceID: 0, + ServerSequenceID: 0x42, + ExpectedErr: ErrPktSyncMul, + }, + } { + conn, mc := newRWMockConn(testCase.ClientSequenceID) + + conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} + _, err := mc.readPacket() + if err != testCase.ExpectedErr { + t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) + } + + // connection should not be returned to the pool in this state + if mc.IsValid() { + t.Errorf("expected IsValid() to be false") + } } }