Skip to content

Commit

Permalink
plan replayer: fix cannot load bindings when the statement contains i…
Browse files Browse the repository at this point in the history
…n (...) (pingcap#50762)

close pingcap#43192
  • Loading branch information
King-Dylan authored Jan 31, 2024
1 parent 333fc9a commit c76fe3f
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 20 deletions.
4 changes: 3 additions & 1 deletion pkg/executor/plan_replayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
Expand Down Expand Up @@ -345,13 +346,14 @@ func loadBindings(ctx sessionctx.Context, f *zip.File, isSession bool) error {
originSQL := cols[0]
bindingSQL := cols[1]
enabled := cols[3]
newNormalizedSQL := parser.NormalizeForBinding(originSQL, true)
if strings.Compare(enabled, "enabled") == 0 {
sql := fmt.Sprintf("CREATE %s BINDING FOR %s USING %s", func() string {
if isSession {
return "SESSION"
}
return "GLOBAL"
}(), originSQL, bindingSQL)
}(), newNormalizedSQL, bindingSQL)
c := context.Background()
_, err = ctx.(sqlexec.SQLExecutor).Execute(c, sql)
if err != nil {
Expand Down
43 changes: 30 additions & 13 deletions pkg/parser/digester.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ func Normalize(sql string) (result string) {
// which removes general property of a statement but keeps specific property.
//
// for example: NormalizeForBinding('select 1 from b where a = 1') => 'select ? from b where a = ?'
func NormalizeForBinding(sql string) (result string) {
func NormalizeForBinding(sql string, forPlanReplayerReload bool) (result string) {
d := digesterPool.Get().(*sqlDigester)
result = d.doNormalizeForBinding(sql, false)
result = d.doNormalizeForBinding(sql, false, forPlanReplayerReload)
digesterPool.Put(d)
return
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (d *sqlDigester) doDigestNormalized(normalized string) (digest *Digest) {
}

func (d *sqlDigester) doDigest(sql string) (digest *Digest) {
d.normalize(sql, false, false)
d.normalize(sql, false, false, false)
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
digest = NewDigest(d.hasher.Sum(nil))
Expand All @@ -170,21 +170,21 @@ func (d *sqlDigester) doDigest(sql string) (digest *Digest) {
}

func (d *sqlDigester) doNormalize(sql string, keepHint bool) (result string) {
d.normalize(sql, keepHint, false)
d.normalize(sql, keepHint, false, false)
result = d.buffer.String()
d.buffer.Reset()
return
}

func (d *sqlDigester) doNormalizeForBinding(sql string, keepHint bool) (result string) {
d.normalize(sql, keepHint, true)
func (d *sqlDigester) doNormalizeForBinding(sql string, keepHint bool, forPlanReplayerReload bool) (result string) {
d.normalize(sql, keepHint, true, forPlanReplayerReload)
result = d.buffer.String()
d.buffer.Reset()
return
}

func (d *sqlDigester) doNormalizeDigest(sql string) (normalized string, digest *Digest) {
d.normalize(sql, false, false)
d.normalize(sql, false, false, false)
normalized = d.buffer.String()
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
Expand All @@ -194,7 +194,7 @@ func (d *sqlDigester) doNormalizeDigest(sql string) (normalized string, digest *
}

func (d *sqlDigester) doNormalizeDigestForBinding(sql string) (normalized string, digest *Digest) {
d.normalize(sql, false, true)
d.normalize(sql, false, true, false)
normalized = d.buffer.String()
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
Expand All @@ -212,7 +212,7 @@ const (
genericSymbolList = -2
)

func (d *sqlDigester) normalize(sql string, keepHint bool, forBinding bool) {
func (d *sqlDigester) normalize(sql string, keepHint bool, forBinding bool, forPlanReplayerReload bool) {
d.lexer.reset(sql)
d.lexer.setKeepHint(keepHint)
for {
Expand All @@ -230,10 +230,11 @@ func (d *sqlDigester) normalize(sql string, keepHint bool, forBinding bool) {
}

d.reduceLit(&currTok)

// Apply binding matching specific rules
if forBinding {
// IN (?) => IN ( ... ) #44298
if forPlanReplayerReload {
// Apply for plan replayer to match specific rules, changing IN (...) to IN (?). This can avoid plan replayer load failures caused by parse errors.
d.replaceSingleLiteralWithInList(&currTok)
} else if forBinding {
// Apply binding matching specific rules, IN (?) => IN ( ... ) #44298
d.reduceInListWithSingleLiteral(&currTok)
}

Expand Down Expand Up @@ -377,6 +378,22 @@ func (d *sqlDigester) isGenericLists(last4 []token) bool {
return true
}

// IN (...) => IN (?) Issue: #43192
func (d *sqlDigester) replaceSingleLiteralWithInList(currTok *token) {
last5 := d.tokens.back(5)
if len(last5) == 5 &&
d.isInKeyword(last5[0]) &&
d.isLeftParen(last5[1]) &&
last5[2].lit == "." &&
last5[3].lit == "." &&
last5[4].lit == "." &&
d.isRightParen(*currTok) {
d.tokens.popBack(3)
d.tokens.pushBack(token{genericSymbol, "?"})
return
}
}

// IN (?) => IN (...) Issue: #44298
func (d *sqlDigester) reduceInListWithSingleLiteral(currTok *token) {
last3 := d.tokens.back(3)
Expand Down
2 changes: 1 addition & 1 deletion pkg/parser/digester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestNormalize(t *testing.T) {
{"select * from t where a in(1, 2, 3)", "select * from `t` where `a` in ( ... )"},
}
for _, test := range tests_for_binding_specific_rules {
normalized := parser.NormalizeForBinding(test.input)
normalized := parser.NormalizeForBinding(test.input, false)
digest := parser.DigestNormalized(normalized)
require.Equal(t, test.expect, normalized)

Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ func (b *PlanBuilder) buildSetBindingStatusPlan(v *ast.SetBindingStmt) (Plan, er
if v.OriginNode != nil {
p = &SQLBindPlan{
SQLBindOp: OpSetBindingStatus,
NormdOrigSQL: parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB, v.OriginNode.Text())),
NormdOrigSQL: parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB, v.OriginNode.Text()), false),
Db: utilparser.GetDefaultDB(v.OriginNode, b.ctx.GetSessionVars().CurrentDB),
}
} else if v.SQLDigest != "" {
Expand Down
4 changes: 2 additions & 2 deletions pkg/planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,8 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def
tn.DBInfo = dbInfo
}

originSQL := parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(originNode, defaultDB, originNode.Text()))
hintedSQL := parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB, hintedNode.Text()))
originSQL := parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(originNode, defaultDB, originNode.Text()), false)
hintedSQL := parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB, hintedNode.Text()), false)
if originSQL != hintedSQL {
p.err = errors.Errorf("hinted sql and origin sql don't match when hinted sql erase the hint info, after erase hint info, originSQL:%s, hintedSQL:%s", originSQL, hintedSQL)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/handler/optimizor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ go_test(
"statistics_handler_test.go",
],
flaky = True,
shard_count = 4,
shard_count = 5,
deps = [
":optimizor",
"//pkg/config",
Expand Down
95 changes: 95 additions & 0 deletions pkg/server/handler/optimizor/plan_replayer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,101 @@ func prepareData4PlanReplayer(t *testing.T, client *testserverclient.TestServerC
return filename, filename3
}

func TestIssue43192(t *testing.T) {
store := testkit.CreateMockStore(t)
dom, err := session.GetDomain(store)
require.NoError(t, err)
// 1. setup and prepare plan replayer files by manual command and capture
server, client := prepareServerAndClientForTest(t, store, dom)
defer server.Close()

filename := prepareData4Issue43192(t, client, dom)
defer os.RemoveAll(replayer.GetPlanReplayerDirName())

// 2. check the contents of the plan replayer zip files.
var filesInReplayer []string
collectFileNameAndAssertFileSize := func(f *zip.File) {
// collect file name
filesInReplayer = append(filesInReplayer, f.Name)
}

// 2-1. check the plan replayer file from manual command
resp0, err := client.FetchStatus(filepath.Join("/plan_replayer/dump/", filename))
require.NoError(t, err)
defer func() {
require.NoError(t, resp0.Body.Close())
}()
body, err := io.ReadAll(resp0.Body)
require.NoError(t, err)
forEachFileInZipBytes(t, body, collectFileNameAndAssertFileSize)
slices.Sort(filesInReplayer)
require.Equal(t, expectedFilesInReplayer, filesInReplayer)

// 3. check plan replayer load
// 3-1. write the plan replayer file from manual command to a file
path := "/tmp/plan_replayer.zip"
fp, err := os.Create(path)
require.NoError(t, err)
require.NotNil(t, fp)
defer func() {
require.NoError(t, fp.Close())
require.NoError(t, os.Remove(path))
}()

_, err = io.Copy(fp, bytes.NewReader(body))
require.NoError(t, err)
require.NoError(t, fp.Sync())

// 3-2. connect to tidb and use PLAN REPLAYER LOAD to load this file
db, err := sql.Open("mysql", client.GetDSN(func(config *mysql.Config) {
config.AllowAllFiles = true
}))
require.NoError(t, err, "Error connecting")
defer func() {
err := db.Close()
require.NoError(t, err)
}()
tk := testkit.NewDBTestKit(t, db)
tk.MustExec("use planReplayer")
tk.MustExec("drop table planReplayer.t")
tk.MustExec(`plan replayer load "/tmp/plan_replayer.zip"`)

// 3-3. check whether binding takes effect
tk.MustExec(`select a, b from t where a in (1, 2, 3)`)
rows := tk.MustQuery("select @@last_plan_from_binding")
require.True(t, rows.Next(), "unexpected data")
var count int64
err = rows.Scan(&count)
require.NoError(t, err)
require.Equal(t, int64(1), count)
}

func prepareData4Issue43192(t *testing.T, client *testserverclient.TestServerClient, dom *domain.Domain) string {
h := dom.StatsHandle()
db, err := sql.Open("mysql", client.GetDSN())
require.NoError(t, err, "Error connecting")
defer func() {
err := db.Close()
require.NoError(t, err)
}()
tk := testkit.NewDBTestKit(t, db)

tk.MustExec("create database planReplayer")
tk.MustExec("use planReplayer")
tk.MustExec("create table t(a int, b int, INDEX ia (a), INDEX ib (b));")
err = h.HandleDDLEvent(<-h.DDLEventCh())
require.NoError(t, err)
tk.MustExec("create global binding for select a, b from t where a in (1, 2, 3) using select a, b from t use index (ib) where a in (1, 2, 3)")
rows := tk.MustQuery("plan replayer dump explain select a, b from t where a in (1, 2, 3)")
require.True(t, rows.Next(), "unexpected data")
var filename string
require.NoError(t, rows.Scan(&filename))
require.NoError(t, rows.Close())
rows = tk.MustQuery("select @@tidb_last_plan_replayer_token")
require.True(t, rows.Next(), "unexpected data")
return filename
}

func forEachFileInZipBytes(t *testing.T, b []byte, fn func(file *zip.File)) {
br := bytes.NewReader(b)
z, err := zip.NewReader(br, int64(len(b)))
Expand Down
2 changes: 1 addition & 1 deletion pkg/session/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2868,7 +2868,7 @@ func upgradeToVer175(s sessiontypes.Session, ver int64) {
}
for i := 0; i < req.NumRows(); i++ {
originalNormalizedSQL, bindSQL := req.GetRow(i).GetString(0), req.GetRow(i).GetString(1)
newNormalizedSQL := parser.NormalizeForBinding(bindSQL)
newNormalizedSQL := parser.NormalizeForBinding(bindSQL, false)
// update `in (?)` to `in (...)`
if originalNormalizedSQL == newNormalizedSQL {
continue // no need to update
Expand Down

0 comments on commit c76fe3f

Please sign in to comment.