Skip to content

Commit

Permalink
handle the case the type is set for SEND_LONG_DATA param
Browse files Browse the repository at this point in the history
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
  • Loading branch information
YangKeao committed Apr 23, 2024
1 parent ecb8a9e commit af9fe1e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 2 deletions.
19 changes: 18 additions & 1 deletion pkg/server/conn_stmt_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,24 @@ func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBit
if boundParams[i] != nil {
params[i] = param.BinaryParam{
Tp: mysql.TypeBlob,
Val: enc.DecodeInput(boundParams[i]),
Val: boundParams[i],
}

// The legacy logic is kept: if the `paramTypes` somehow didn't contain the type information, it will be treated as
// BLOB type. We didn't return `mysql.ErrMalformPacket` to keep compatibility with older versions, though it's
// meaningless if every clients work properly.
if (i<<1)+1 < len(paramTypes) {
// Only TEXT or BLOB type will be sent through `SEND_LONG_DATA`.
tp := paramTypes[i<<1]

switch tp {
case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBit:
params[i].Tp = tp
params[i].Val = enc.DecodeInput(boundParams[i])
case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
params[i].Tp = tp
params[i].Val = boundParams[i]
}
}
continue
}
Expand Down
61 changes: 61 additions & 0 deletions pkg/server/internal/testserverclient/server_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import (
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"golang.org/x/text/encoding/simplifiedchinese"
)

//revive:disable:exported
Expand Down Expand Up @@ -2727,4 +2728,64 @@ func (cli *TestServerClient) RunTestConnectionCount(t *testing.T) {
})
}

func (cli *TestServerClient) RunTestTypeOfSendLongData(t *testing.T) {
cli.RunTests(t, func(config *mysql.Config) {
config.MaxAllowedPacket = 1024
}, func(dbt *testkit.DBTestKit) {
ctx := context.Background()

conn, err := dbt.GetDB().Conn(ctx)
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "CREATE TABLE t (j JSON);")
require.NoError(t, err)

str := `"` + strings.Repeat("a", 1024) + `"`
stmt, err := conn.PrepareContext(ctx, "INSERT INTO t VALUES (cast(? as JSON));")
require.NoError(t, err)
_, err = stmt.ExecContext(ctx, str)
require.NoError(t, err)
result, err := conn.QueryContext(ctx, "SELECT j FROM t;")
require.NoError(t, err)

for result.Next() {
var j string
require.NoError(t, result.Scan(&j))
require.Equal(t, str, j)
}
})
}

func (cli *TestServerClient) RunTestCharsetOfSendLongData(t *testing.T) {
str := strings.Repeat("你好", 1024)
enc := simplifiedchinese.GBK.NewEncoder()
gbkStr, err := enc.String(str)
require.NoError(t, err)

cli.RunTests(t, func(config *mysql.Config) {
config.MaxAllowedPacket = 1024
config.Params["charset"] = "gbk"
}, func(dbt *testkit.DBTestKit) {
ctx := context.Background()

conn, err := dbt.GetDB().Conn(ctx)
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "CREATE TABLE t (t TEXT);")
require.NoError(t, err)

stmt, err := conn.PrepareContext(ctx, "INSERT INTO t VALUES (?);")
require.NoError(t, err)
_, err = stmt.ExecContext(ctx, gbkStr)
require.NoError(t, err)

result, err := conn.QueryContext(ctx, "SELECT * FROM t;")
require.NoError(t, err)

for result.Next() {
var txt string
require.NoError(t, result.Scan(&txt))
require.Equal(t, gbkStr, txt)
}
})
}

//revive:enable:exported
2 changes: 1 addition & 1 deletion pkg/server/tests/commontest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"tidb_test.go",
],
flaky = True,
shard_count = 49,
shard_count = 50,
deps = [
"//pkg/config",
"//pkg/ddl/util",
Expand Down
10 changes: 10 additions & 0 deletions pkg/server/tests/commontest/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3082,3 +3082,13 @@ func TestConnectionCount(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestConnectionCount(t)
}

func TestTypeOfSendLongData(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestTypeOfSendLongData(t)
}

func TestCharsetOfSendLongData(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestCharsetOfSendLongData(t)
}

0 comments on commit af9fe1e

Please sign in to comment.