Skip to content

Commit

Permalink
This is an automated cherry-pick of pingcap#52720
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
YangKeao authored and ti-chi-bot committed May 7, 2024
1 parent 1e19051 commit 81c397e
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 0 deletions.
147 changes: 147 additions & 0 deletions pkg/server/conn_stmt_params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
"github.com/pingcap/tidb/pkg/errno"
"github.com/pingcap/tidb/pkg/param"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
util2 "github.com/pingcap/tidb/pkg/server/internal/util"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/dbterror"
)

var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)

// parseBinaryParams decodes the binary params according to the protocol
func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (err error) {
pos := 0
if enc == nil {
enc = util2.NewInputDecoder(charset.CharsetUTF8)
}

for i := 0; i < len(params); i++ {
// if params had received via ComStmtSendLongData, use them directly.
// ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
// see clientConn#handleStmtSendLongData
if boundParams[i] != nil {
params[i] = param.BinaryParam{
Tp: mysql.TypeBlob,
Val: boundParams[i],
}

// The legacy logic is kept: if the `paramTypes` somehow didn't contain the type information, it will be treated as
// BLOB type. We didn't return `mysql.ErrMalformPacket` to keep compatibility with older versions, though it's
// meaningless if every clients work properly.
if (i<<1)+1 < len(paramTypes) {
// Only TEXT or BLOB type will be sent through `SEND_LONG_DATA`.
tp := paramTypes[i<<1]

switch tp {
case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBit:
params[i].Tp = tp
params[i].Val = enc.DecodeInput(boundParams[i])
case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
params[i].Tp = tp
params[i].Val = boundParams[i]
}
}
continue
}

// check nullBitMap to determine the NULL arguments.
// ref https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
// notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData,
// so this check need place after boundParam's check.
if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
var nilDatum types.Datum
nilDatum.SetNull()
params[i] = param.BinaryParam{
Tp: mysql.TypeNull,
}
continue
}

if (i<<1)+1 >= len(paramTypes) {
return mysql.ErrMalformPacket
}

tp := paramTypes[i<<1]
isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
isNull := false

decodeWithDecoder := false

var length uint64
switch tp {
case mysql.TypeNull:
length = 0
isNull = true
case mysql.TypeTiny:
length = 1
case mysql.TypeShort, mysql.TypeYear:
length = 2
case mysql.TypeInt24, mysql.TypeLong, mysql.TypeFloat:
length = 4
case mysql.TypeLonglong, mysql.TypeDouble:
length = 8
case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration:
if len(paramValues) < (pos + 1) {
err = mysql.ErrMalformPacket
return
}
length = uint64(paramValues[pos])
pos++
case mysql.TypeNewDecimal, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
if len(paramValues) < (pos + 1) {
err = mysql.ErrMalformPacket
return
}
var n int
length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:])
pos += n
case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString,
mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit:
if len(paramValues) < (pos + 1) {
err = mysql.ErrMalformPacket
return
}
var n int
length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:])
pos += n
decodeWithDecoder = true
default:
err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp)
return
}

if len(paramValues) < (pos + int(length)) {
err = mysql.ErrMalformPacket
return
}
params[i] = param.BinaryParam{
Tp: tp,
IsUnsigned: isUnsigned,
IsNull: isNull,
Val: paramValues[pos : pos+int(length)],
}
if decodeWithDecoder {
params[i].Val = enc.DecodeInput(params[i].Val)
}
pos += int(length)
}
return
}
1 change: 1 addition & 0 deletions pkg/server/internal/testserverclient/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ go_library(
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_log//:log",
"@com_github_stretchr_testify//require",
"@org_golang_x_text//encoding/simplifiedchinese",
"@org_uber_go_zap//:zap",
],
)
61 changes: 61 additions & 0 deletions pkg/server/internal/testserverclient/server_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/pingcap/tidb/pkg/util/versioninfo"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"golang.org/x/text/encoding/simplifiedchinese"
)

//revive:disable:exported
Expand Down Expand Up @@ -2532,4 +2533,64 @@ func (cli *TestServerClient) RunTestStmtCountLimit(t *testing.T) {
})
}

func (cli *TestServerClient) RunTestTypeAndCharsetOfSendLongData(t *testing.T) {
cli.RunTests(t, func(config *mysql.Config) {
config.MaxAllowedPacket = 1024
}, func(dbt *testkit.DBTestKit) {
ctx := context.Background()

conn, err := dbt.GetDB().Conn(ctx)
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "CREATE TABLE t (j JSON);")
require.NoError(t, err)

str := `"` + strings.Repeat("a", 1024) + `"`
stmt, err := conn.PrepareContext(ctx, "INSERT INTO t VALUES (cast(? as JSON));")
require.NoError(t, err)
_, err = stmt.ExecContext(ctx, str)
require.NoError(t, err)
result, err := conn.QueryContext(ctx, "SELECT j FROM t;")
require.NoError(t, err)

for result.Next() {
var j string
require.NoError(t, result.Scan(&j))
require.Equal(t, str, j)
}
})

str := strings.Repeat("你好", 1024)
enc := simplifiedchinese.GBK.NewEncoder()
gbkStr, err := enc.String(str)
require.NoError(t, err)

cli.RunTests(t, func(config *mysql.Config) {
config.MaxAllowedPacket = 1024
config.Params["charset"] = "gbk"
}, func(dbt *testkit.DBTestKit) {
ctx := context.Background()

conn, err := dbt.GetDB().Conn(ctx)
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "drop table t")
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "CREATE TABLE t (t TEXT);")
require.NoError(t, err)

stmt, err := conn.PrepareContext(ctx, "INSERT INTO t VALUES (?);")
require.NoError(t, err)
_, err = stmt.ExecContext(ctx, gbkStr)
require.NoError(t, err)

result, err := conn.QueryContext(ctx, "SELECT * FROM t;")
require.NoError(t, err)

for result.Next() {
var txt string
require.NoError(t, result.Scan(&txt))
require.Equal(t, gbkStr, txt)
}
})
}

//revive:enable:exported
50 changes: 50 additions & 0 deletions pkg/server/tests/commontest/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")

go_test(
name = "commontest_test",
timeout = "short",
srcs = [
"main_test.go",
"tidb_test.go",
],
flaky = True,
shard_count = 50,
deps = [
"//pkg/config",
"//pkg/ddl/util",
"//pkg/extension",
"//pkg/metrics",
"//pkg/parser",
"//pkg/parser/ast",
"//pkg/parser/auth",
"//pkg/parser/mysql",
"//pkg/server",
"//pkg/server/internal/column",
"//pkg/server/internal/resultset",
"//pkg/server/internal/testserverclient",
"//pkg/server/internal/testutil",
"//pkg/server/internal/util",
"//pkg/server/tests/servertestkit",
"//pkg/session",
"//pkg/sessionctx/variable",
"//pkg/store/mockstore/unistore",
"//pkg/testkit",
"//pkg/testkit/testsetup",
"//pkg/util",
"//pkg/util/plancodec",
"//pkg/util/resourcegrouptag",
"//pkg/util/topsql",
"//pkg/util/topsql/collector",
"//pkg/util/topsql/collector/mock",
"//pkg/util/topsql/state",
"//pkg/util/topsql/stmtstats",
"@com_github_go_sql_driver_mysql//:mysql",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_stretchr_testify//require",
"@com_github_tikv_client_go_v2//tikv",
"@com_github_tikv_client_go_v2//tikvrpc",
"@io_opencensus_go//stats/view",
"@org_uber_go_goleak//:goleak",
],
)
103 changes: 103 additions & 0 deletions pkg/server/tests/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3134,3 +3134,106 @@ func TestProxyProtocolWithIpNoFallbackable(t *testing.T) {
require.NotNil(t, err)
db.Close()
}
<<<<<<< HEAD:pkg/server/tests/tidb_test.go
=======

func TestConnectionWillNotLeak(t *testing.T) {
cfg := util2.NewTestConfig()
cfg.Port = 0
cfg.Status.ReportStatus = false
// Setup proxy protocol config
cfg.ProxyProtocol.Networks = "*"
cfg.ProxyProtocol.Fallbackable = false

ts := servertestkit.CreateTidbTestSuite(t)

cli := testserverclient.NewTestServerClient()
cli.Port = testutil.GetPortFromTCPAddr(ts.Server.ListenAddr())
dsn := cli.GetDSN(func(config *mysql.Config) {
config.User = "root"
config.DBName = "test"
})
db, err := sql.Open("mysql", dsn)
require.Nil(t, err)
db.SetMaxOpenConns(100)
db.SetMaxIdleConns(0)

// create 100 connections
conns := make([]*sql.Conn, 0, 100)
for len(conns) < 100 {
conn, err := db.Conn(context.Background())
require.NoError(t, err)
conns = append(conns, conn)
}
require.Eventually(t, func() bool {
runtime.GC()
return server2.ConnectionInMemCounterForTest.Load() == int64(100)
}, time.Minute, time.Millisecond*100)

// run a simple query on each connection and close it
// this cannot ensure the connection will not leak for any kinds of requests
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
conn := conn
go func() {
rows, err := conn.QueryContext(context.Background(), "SELECT 2023")
require.NoError(t, err)
var result int
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&result))
require.Equal(t, result, 2023)
require.NoError(t, rows.Close())
// `db.Close` will not close already grabbed connection, so it's still needed to close the connection here.
require.NoError(t, conn.Close())
wg.Done()
}()
}
wg.Wait()

require.NoError(t, db.Close())
require.Eventually(t, func() bool {
runtime.GC()
count := server2.ConnectionInMemCounterForTest.Load()
return count == 0
}, time.Minute, time.Millisecond*100)
}

func TestPrepareCount(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
require.NoError(t, err)
prepareCnt := atomic.LoadInt64(&variable.PreparedStmtCount)
ctx := context.Background()
_, err = Execute(ctx, qctx, "use test;")
require.NoError(t, err)
_, err = Execute(ctx, qctx, "drop table if exists t1")
require.NoError(t, err)
_, err = Execute(ctx, qctx, "create table t1 (id int)")
require.NoError(t, err)
stmt, _, _, err := qctx.Prepare("insert into t1 values (?)")
require.NoError(t, err)
require.Equal(t, prepareCnt+1, atomic.LoadInt64(&variable.PreparedStmtCount))
require.NoError(t, err)
err = qctx.GetStatement(stmt.ID()).Close()
require.NoError(t, err)
require.Equal(t, prepareCnt, atomic.LoadInt64(&variable.PreparedStmtCount))
require.NoError(t, qctx.Close())
}

func TestSQLModeIsLoadedBeforeQuery(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestSQLModeIsLoadedBeforeQuery(t)
}

func TestConnectionCount(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestConnectionCount(t)
}

func TestTypeAndCharsetOfSendLongData(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestTypeAndCharsetOfSendLongData(t)
}
>>>>>>> 24990b5ddd6 (server: handle the case the type of param is set for the param sent by `SEND_LONG_DATA` (#52720)):pkg/server/tests/commontest/tidb_test.go

0 comments on commit 81c397e

Please sign in to comment.