From 6b44c45f8353aac1545d488e60a5bc2bc257726d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Mon, 1 Aug 2022 12:36:25 +0200 Subject: [PATCH] server: support compressed MySQL protocol - https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression.html - https://dev.mysql.com/worklog/task/?id=12039 - https://github.com/pingcap/tidb/issues/22605 --- go.mod | 2 +- go.sum | 4 +- parser/mysql/const.go | 15 ++- server/conn.go | 21 ++++- server/packetio.go | 212 ++++++++++++++++++++++++++++++++++++++---- server/server.go | 8 +- 6 files changed, 230 insertions(+), 32 deletions(-) diff --git a/go.mod b/go.mod index 6608f2d635eb0..0d10a5bdf071f 100644 --- a/go.mod +++ b/go.mod @@ -59,7 +59,7 @@ require ( github.com/jingyugao/rowserrcheck v1.1.1 github.com/joho/sqltocsv v0.0.0-20210428211105-a6d6801d59df github.com/kisielk/errcheck v1.6.3 - github.com/klauspost/compress v1.15.13 + github.com/klauspost/compress v1.16.5 github.com/kyoh86/exportloopref v0.1.11 github.com/lestrrat-go/jwx/v2 v2.0.6 github.com/mgechev/revive v1.3.1 diff --git a/go.sum b/go.sum index 8c8e2a4c50dee..0df34de297cf2 100644 --- a/go.sum +++ b/go.sum @@ -604,8 +604,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.10.5/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.15.13 h1:NFn1Wr8cfnenSJSA46lLq4wHCcBzKTSjnBIexDMMOV0= -github.com/klauspost/compress v1.15.13/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= +github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= +github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s= github.com/klauspost/cpuid v1.3.1/go.mod h1:bYW4mA6ZgKPob1/Dlai2LviZJO7KGI3uoWLd42rAQw4= diff --git a/parser/mysql/const.go b/parser/mysql/const.go index 23c0b0a2d1547..6071367801070 100644 --- a/parser/mysql/const.go +++ b/parser/mysql/const.go @@ -140,7 +140,7 @@ const ( ClientLongFlag // CLIENT_LONG_FLAG ClientConnectWithDB // CLIENT_CONNECT_WITH_DB ClientNoSchema // CLIENT_NO_SCHEMA - ClientCompress // CLIENT_COMPRESS, Not supported: https://github.com/pingcap/tidb/issues/22605 + ClientCompress // CLIENT_COMPRESS ClientODBC // CLIENT_ODBC ClientLocalFiles // CLIENT_LOCAL_FILES ClientIgnoreSpace // CLIENT_IGNORE_SPACE @@ -160,8 +160,8 @@ const ( ClientHandleExpiredPasswords // CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS, Not supported: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_expired_passwords.html ClientSessionTrack // CLIENT_SESSION_TRACK, Not supported: https://github.com/pingcap/tidb/issues/35309 ClientDeprecateEOF // CLIENT_DEPRECATE_EOF - // 1 << 25 == CLIENT_OPTIONAL_RESULTSET_METADATA - // 1 << 26 == CLIENT_ZSTD_COMPRESSION_ALGORITHM + ClientOptionalResultsetMetadata // CLIENT_OPTIONAL_RESULTSET_METADATA, Not supported: https://dev.mysql.com/doc/c-api/8.0/en/c-api-optional-metadata.html + ClientZstdCompressionAlgorithm // CLIENT_ZSTD_COMPRESSION_ALGORITHM // 1 << 27 == CLIENT_QUERY_ATTRIBUTES // 1 << 28 == MULTI_FACTOR_AUTHENTICATION // 1 << 29 == CLIENT_CAPABILITY_EXTENSION @@ -629,3 +629,12 @@ const ( CursorTypeForUpdate CursorTypeScrollable ) + +const ( + // CompressionNone is no compression in use + CompressionNone = iota + // CompressionZlib is zlib/deflate + CompressionZlib + // CompressionZstd is Facebook's Zstandard + CompressionZstd +) diff --git a/server/conn.go b/server/conn.go index 639260d39fbe3..a55601578526d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -54,6 +54,7 @@ import ( "time" "unsafe" + "github.com/klauspost/compress/zstd" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" @@ -279,6 +280,14 @@ func (cc *clientConn) handshake(ctx context.Context) error { logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) return err } + + // With mysql --compression-algorithms=zlib,zstd both flags are set, the result is Zlib + if cc.capability&mysql.ClientCompress > 0 { + cc.pkt.SetCompressionAlgorithm(mysql.CompressionZlib) + } else if cc.capability&mysql.ClientZstdCompressionAlgorithm > 0 { + cc.pkt.SetCompressionAlgorithm(mysql.CompressionZstd) + } + return err } @@ -414,6 +423,7 @@ type handshakeResponse41 struct { Auth []byte AuthPlugin string Attrs map[string]string + ZstdLevel zstd.EncoderLevel } // parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41. @@ -502,8 +512,8 @@ func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41 // Defend some ill-formated packet, connection attribute is not important and can be ignored. return nil } - if num, null, off := parseLengthEncodedInt(data[offset:]); !null { - offset += off + if num, null, intOff := parseLengthEncodedInt(data[offset:]); !null { + offset += intOff // Length of variable length encoded integer itself in bytes row := data[offset : offset+int(num)] attrs, err := parseAttrs(row) if err != nil { @@ -511,9 +521,14 @@ func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41 return nil } packet.Attrs = attrs + offset += int(num) // Length of attributes } } + if packet.Capability&mysql.ClientZstdCompressionAlgorithm > 0 { + packet.ZstdLevel = zstd.EncoderLevelFromZstd(int(data[offset])) + } + return nil } @@ -625,6 +640,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con cc.dbname = resp.DBName cc.collation = resp.Collation cc.attrs = resp.Attrs + cc.pkt.zstdLevel = resp.ZstdLevel err = cc.handleAuthPlugin(ctx, &resp) if err != nil { @@ -1163,6 +1179,7 @@ func (cc *clientConn) Run(ctx context.Context) { } cc.addMetrics(data[0], startTime, err) cc.pkt.sequence = 0 + cc.pkt.compressedSequence = 0 } } diff --git a/server/packetio.go b/server/packetio.go index 736df5e241b3c..835a2357aa6af 100644 --- a/server/packetio.go +++ b/server/packetio.go @@ -37,9 +37,12 @@ package server import ( "bufio" + "bytes" + "compress/zlib" "io" "time" + "github.com/klauspost/compress/zstd" "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -59,16 +62,26 @@ type packetIO struct { // maxAllowedPacket is the maximum size of one packet in readPacket. maxAllowedPacket uint64 // accumulatedLength count the length of totally received 'payload' in readPacket. - accumulatedLength uint64 + accumulatedLength uint64 + compressionAlgorithm int + compressedSequence uint8 + zstdLevel zstd.EncoderLevel + compressedWriter *compressedWriter } func newPacketIO(bufReadConn *bufferedReadConn) *packetIO { - p := &packetIO{sequence: 0} + p := &packetIO{sequence: 0, compressionAlgorithm: mysql.CompressionNone, compressedSequence: 0, zstdLevel: 3} p.setBufferedReadConn(bufReadConn) p.setMaxAllowedPacket(variable.DefMaxAllowedPacket) return p } +func (p *packetIO) SetCompressionAlgorithm(ca int) { + p.compressionAlgorithm = ca + p.compressedWriter = newCompressedWriter(p.bufReadConn, ca) + p.bufWriter.Flush() +} + func (p *packetIO) setBufferedReadConn(bufReadConn *bufferedReadConn) { p.bufReadConn = bufReadConn p.bufWriter = bufio.NewWriterSize(bufReadConn, defaultWriterSize) @@ -80,12 +93,48 @@ func (p *packetIO) setReadTimeout(timeout time.Duration) { func (p *packetIO) readOnePacket() ([]byte, error) { var header [4]byte + r := io.NopCloser(p.bufReadConn) if p.readTimeout > 0 { if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil { return nil, err } } - if _, err := io.ReadFull(p.bufReadConn, header[:]); err != nil { + if p.compressionAlgorithm != mysql.CompressionNone { + var compressedHeader [7]byte + if _, err := io.ReadFull(p.bufReadConn, compressedHeader[:]); err != nil { + return nil, errors.Trace(err) + } + compressedSequence := compressedHeader[3] + if compressedSequence != p.compressedSequence { + return nil, errInvalidSequence.GenWithStack( + "invalid compressed sequence %d != %d", compressedSequence, p.compressedSequence) + } + p.compressedSequence++ + p.compressedWriter.compressedSequence = p.compressedSequence + uncompressedLength := int(uint32(compressedHeader[4]) | uint32(compressedHeader[5])<<8 | uint32(compressedHeader[6])<<16) + + if uncompressedLength > 0 { + var zr io.ReadCloser + switch p.compressionAlgorithm { + case mysql.CompressionZlib: + var err error + zr, err = zlib.NewReader(p.bufReadConn) + if err != nil { + return nil, errors.Trace(err) + } + case mysql.CompressionZstd: + zstdReader, err := zstd.NewReader(p.bufReadConn, zstd.WithDecoderConcurrency(1)) + if err != nil { + return nil, errors.Trace(err) + } + zr = zstdReader.IOReadCloser() + default: + return nil, errors.New("Unknown compression algorithm") + } + r = zr + } + } + if _, err := io.ReadFull(r, header[:]); err != nil { return nil, errors.Trace(err) } @@ -110,7 +159,11 @@ func (p *packetIO) readOnePacket() ([]byte, error) { return nil, err } } - if _, err := io.ReadFull(p.bufReadConn, data); err != nil { + if _, err := io.ReadFull(r, data); err != nil { + return nil, errors.Trace(err) + } + err := r.Close() + if err != nil { return nil, errors.Trace(err) } return data, nil @@ -160,42 +213,161 @@ func (p *packetIO) writePacket(data []byte) error { length := len(data) - 4 server_metrics.WritePacketBytes.Add(float64(len(data))) - for length >= mysql.MaxPayloadLen { + maxPayloadLen := mysql.MaxPayloadLen + if p.compressionAlgorithm != mysql.CompressionNone { + maxPayloadLen -= 4 + } + + for length >= maxPayloadLen { data[3] = p.sequence data[0] = 0xff data[1] = 0xff data[2] = 0xff - if n, err := p.bufWriter.Write(data[:4+mysql.MaxPayloadLen]); err != nil { - return errors.Trace(mysql.ErrBadConn) - } else if n != (4 + mysql.MaxPayloadLen) { - return errors.Trace(mysql.ErrBadConn) + if p.compressionAlgorithm != mysql.CompressionNone { + if n, err := p.compressedWriter.Write(data[:4+maxPayloadLen]); err != nil { + return errors.Trace(mysql.ErrBadConn) + } else if n != (4 + maxPayloadLen) { + return errors.Trace(mysql.ErrBadConn) + } } else { - p.sequence++ - length -= mysql.MaxPayloadLen - data = data[mysql.MaxPayloadLen:] + if n, err := p.bufWriter.Write(data[:4+maxPayloadLen]); err != nil { + return errors.Trace(mysql.ErrBadConn) + } else if n != (4 + maxPayloadLen) { + return errors.Trace(mysql.ErrBadConn) + } } + p.sequence++ + length -= maxPayloadLen + data = data[maxPayloadLen:] } data[3] = p.sequence data[0] = byte(length) data[1] = byte(length >> 8) data[2] = byte(length >> 16) - if n, err := p.bufWriter.Write(data); err != nil { - terror.Log(errors.Trace(err)) - return errors.Trace(mysql.ErrBadConn) - } else if n != len(data) { - return errors.Trace(mysql.ErrBadConn) + if p.compressionAlgorithm != mysql.CompressionNone { + if n, err := p.compressedWriter.Write(data); err != nil { + terror.Log(errors.Trace(err)) + return errors.Trace(mysql.ErrBadConn) + } else if n != len(data) { + return errors.Trace(mysql.ErrBadConn) + } else { + p.sequence++ + return nil + } } else { - p.sequence++ - return nil + if n, err := p.bufWriter.Write(data); err != nil { + terror.Log(errors.Trace(err)) + return errors.Trace(mysql.ErrBadConn) + } else if n != len(data) { + return errors.Trace(mysql.ErrBadConn) + } else { + p.sequence++ + return nil + } } } func (p *packetIO) flush() error { - err := p.bufWriter.Flush() + var err error + if p.compressionAlgorithm != mysql.CompressionNone { + err = p.compressedWriter.Flush() + } else { + err = p.bufWriter.Flush() + } if err != nil { return errors.Trace(err) } return err } + +func newCompressedWriter(w io.Writer, ca int) *compressedWriter { + return &compressedWriter{ + w, + new(bytes.Buffer), + ca, + 0, + 3, + } +} + +type compressedWriter struct { + w io.Writer + buf *bytes.Buffer + compressionAlgorithm int + compressedSequence uint8 + zstdLevel zstd.EncoderLevel +} + +func (cw *compressedWriter) Write(data []byte) (int, error) { + return cw.buf.Write(data) +} + +func (cw *compressedWriter) Flush() error { + var payload, compressedPacket bytes.Buffer + var w io.WriteCloser + var err error + + minCompressLength := 50 + data := cw.buf.Bytes() + cw.buf.Reset() + + switch cw.compressionAlgorithm { + case mysql.CompressionZlib: + w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly) + case mysql.CompressionZstd: + w, err = zstd.NewWriter(&payload, zstd.WithEncoderLevel(cw.zstdLevel)) + default: + return errors.New("Unknown compression algorithm") + } + if err != nil { + return errors.Trace(err) + } + + uncompressedLength := 0 + compressedHeader := make([]byte, 7) + + if len(data) > minCompressLength { + uncompressedLength = len(data) + _, err := w.Write(data) + if err != nil { + return errors.Trace(err) + } + err = w.Close() + if err != nil { + return errors.Trace(err) + } + } + + var compressedLength int + if len(data) > minCompressLength { + compressedLength = len(payload.Bytes()) + } else { + compressedLength = len(data) + } + compressedHeader[0] = byte(compressedLength) + compressedHeader[1] = byte(compressedLength >> 8) + compressedHeader[2] = byte(compressedLength >> 16) + compressedHeader[3] = cw.compressedSequence + compressedHeader[4] = byte(uncompressedLength) + compressedHeader[5] = byte(uncompressedLength >> 8) + compressedHeader[6] = byte(uncompressedLength >> 16) + _, err = compressedPacket.Write(compressedHeader) + if err != nil { + return errors.Trace(err) + } + cw.compressedSequence++ + + if len(data) > minCompressLength { + _, err = compressedPacket.Write(payload.Bytes()) + } else { + _, err = compressedPacket.Write(data) + } + if err != nil { + return errors.Trace(err) + } + w.Close() + cw.w.Write(compressedPacket.Bytes()) + return nil +} diff --git a/server/server.go b/server/server.go index c0c82f7612a9a..d13b957b99cc9 100644 --- a/server/server.go +++ b/server/server.go @@ -36,9 +36,8 @@ import ( "io" "math/rand" "net" - "net/http" //nolint:goimports - // For pprof - _ "net/http/pprof" // #nosec G108 + "net/http" //nolint:goimports + _ "net/http/pprof" // #nosec G108 for pprof "os" "os/user" "reflect" @@ -121,7 +120,8 @@ const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | mysql.ClientConnectWithDB | mysql.ClientProtocol41 | mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows | mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles | - mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive | mysql.ClientDeprecateEOF + mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive | + mysql.ClientDeprecateEOF | mysql.ClientCompress | mysql.ClientZstdCompressionAlgorithm // Server is the MySQL protocol server type Server struct {