diff --git a/pkg/server/conn_stmt_params.go b/pkg/server/conn_stmt_params.go index 5b625c683faac..29fbfd9b058c4 100644 --- a/pkg/server/conn_stmt_params.go +++ b/pkg/server/conn_stmt_params.go @@ -40,7 +40,24 @@ func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBit if boundParams[i] != nil { params[i] = param.BinaryParam{ Tp: mysql.TypeBlob, - Val: enc.DecodeInput(boundParams[i]), + Val: boundParams[i], + } + + // The legacy logic is kept: if the `paramTypes` somehow didn't contain the type information, it will be treated as + // BLOB type. We didn't return `mysql.ErrMalformPacket` to keep compatibility with older versions, though it's + // meaningless if every clients work properly. + if (i<<1)+1 < len(paramTypes) { + // Only TEXT or BLOB type will be sent through `SEND_LONG_DATA`. + tp := paramTypes[i<<1] + + switch tp { + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBit: + params[i].Tp = tp + params[i].Val = enc.DecodeInput(boundParams[i]) + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + params[i].Tp = tp + params[i].Val = boundParams[i] + } } continue } diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index f3e205eb37912..62b37880a46cf 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -48,6 +48,7 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/require" "go.uber.org/zap" + "golang.org/x/text/encoding/simplifiedchinese" ) //revive:disable:exported @@ -2727,4 +2728,64 @@ func (cli *TestServerClient) RunTestConnectionCount(t *testing.T) { }) } +func (cli *TestServerClient) RunTestTypeOfSendLongData(t *testing.T) { + cli.RunTests(t, func(config *mysql.Config) { + config.MaxAllowedPacket = 1024 + }, func(dbt *testkit.DBTestKit) { + ctx := context.Background() + + conn, err := dbt.GetDB().Conn(ctx) + require.NoError(t, err) + _, err = conn.ExecContext(ctx, "CREATE TABLE t (j JSON);") + require.NoError(t, err) + + str := `"` + strings.Repeat("a", 1024) + `"` + stmt, err := conn.PrepareContext(ctx, "INSERT INTO t VALUES (cast(? as JSON));") + require.NoError(t, err) + _, err = stmt.ExecContext(ctx, str) + require.NoError(t, err) + result, err := conn.QueryContext(ctx, "SELECT j FROM t;") + require.NoError(t, err) + + for result.Next() { + var j string + require.NoError(t, result.Scan(&j)) + require.Equal(t, str, j) + } + }) +} + +func (cli *TestServerClient) RunTestCharsetOfSendLongData(t *testing.T) { + str := strings.Repeat("你好", 1024) + enc := simplifiedchinese.GBK.NewEncoder() + gbkStr, err := enc.String(str) + require.NoError(t, err) + + cli.RunTests(t, func(config *mysql.Config) { + config.MaxAllowedPacket = 1024 + config.Params["charset"] = "gbk" + }, func(dbt *testkit.DBTestKit) { + ctx := context.Background() + + conn, err := dbt.GetDB().Conn(ctx) + require.NoError(t, err) + _, err = conn.ExecContext(ctx, "CREATE TABLE t (t TEXT);") + require.NoError(t, err) + + stmt, err := conn.PrepareContext(ctx, "INSERT INTO t VALUES (?);") + require.NoError(t, err) + _, err = stmt.ExecContext(ctx, gbkStr) + require.NoError(t, err) + + result, err := conn.QueryContext(ctx, "SELECT * FROM t;") + require.NoError(t, err) + + for result.Next() { + var txt string + require.NoError(t, result.Scan(&txt)) + require.Equal(t, gbkStr, txt) + } + }) +} + //revive:enable:exported diff --git a/pkg/server/tests/commontest/BUILD.bazel b/pkg/server/tests/commontest/BUILD.bazel index d2b4d8ed82490..8e1baedf9abbf 100644 --- a/pkg/server/tests/commontest/BUILD.bazel +++ b/pkg/server/tests/commontest/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "tidb_test.go", ], flaky = True, - shard_count = 49, + shard_count = 50, deps = [ "//pkg/config", "//pkg/ddl/util", diff --git a/pkg/server/tests/commontest/tidb_test.go b/pkg/server/tests/commontest/tidb_test.go index 5e8b0f431b3c3..ace03c656a872 100644 --- a/pkg/server/tests/commontest/tidb_test.go +++ b/pkg/server/tests/commontest/tidb_test.go @@ -3082,3 +3082,13 @@ func TestConnectionCount(t *testing.T) { ts := servertestkit.CreateTidbTestSuite(t) ts.RunTestConnectionCount(t) } + +func TestTypeOfSendLongData(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + ts.RunTestTypeOfSendLongData(t) +} + +func TestCharsetOfSendLongData(t *testing.T) { + ts := servertestkit.CreateTidbTestSuite(t) + ts.RunTestCharsetOfSendLongData(t) +}