Skip to content

Commit

Permalink
*: fix a bug that the pessimistic lock doesn't work on a partition (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored Mar 9, 2020
1 parent 11ec003 commit bd70523
Show file tree
Hide file tree
Showing 14 changed files with 147 additions and 23 deletions.
5 changes: 3 additions & 2 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,9 @@ func (b *executorBuilder) buildSelectLock(v *plannercore.PhysicalLock) Executor
return src
}
e := &SelectLockExec{
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src),
Lock: v.Lock,
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src),
Lock: v.Lock,
partitionedTable: v.PartitionedTable,
}
return e
}
Expand Down
38 changes: 33 additions & 5 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ type SelectLockExec struct {

Lock ast.SelectLockType
keys []kv.Key

partitionedTable []table.PartitionedTable

// tblID2Table is cached to reduce cost.
tblID2Table map[int64]table.PartitionedTable
}

// Open implements the Executor Open interface.
Expand All @@ -755,6 +760,18 @@ func (e *SelectLockExec) Open(ctx context.Context) error {
// This operation is only for schema validator check.
txnCtx.UpdateDeltaForTable(id, 0, 0, map[int64]int64{})
}

if len(e.Schema().TblID2Handle) > 0 && len(e.partitionedTable) > 0 {
e.tblID2Table = make(map[int64]table.PartitionedTable, len(e.partitionedTable))
for id := range e.Schema().TblID2Handle {
for _, p := range e.partitionedTable {
if id == p.Meta().ID {
e.tblID2Table[id] = p
}
}
}
}

return nil
}

Expand All @@ -774,12 +791,23 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error {
if len(e.Schema().TblID2Handle) == 0 || e.Lock != ast.SelectLockForUpdate {
return nil
}
if req.NumRows() != 0 {

if req.NumRows() > 0 {
iter := chunk.NewIterator4Chunk(req)
for id, cols := range e.Schema().TblID2Handle {
for _, col := range cols {
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(id, row.GetInt64(col.Index)))
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
for id, cols := range e.Schema().TblID2Handle {
physicalID := id
if pt, ok := e.tblID2Table[id]; ok {
// On a partitioned table, we have to use physical ID to encode the lock key!
p, err := pt.GetPartitionByRow(e.ctx, row.GetDatumRow(e.base().retFieldTypes))
if err != nil {
return err
}
physicalID = p.GetPhysicalID()
}

for _, col := range cols {
e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(physicalID, row.GetInt64(col.Index)))
}
}
}
Expand Down
12 changes: 11 additions & 1 deletion executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,17 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 {
sc.AddAffectedRows(1)
}
unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(t.Meta().ID, h)

physicalID := t.Meta().ID
if pt, ok := t.(table.PartitionedTable); ok {
p, err := pt.GetPartitionByRow(ctx, oldData)
if err != nil {
return false, false, 0, err
}
physicalID = p.GetPhysicalID()
}

unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h)
txnCtx := ctx.GetSessionVars().TxnCtx
if txnCtx.IsPessimistic {
txnCtx.AddUnchangedRowKey(unchangedRowKey)
Expand Down
20 changes: 10 additions & 10 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

sig := &builtinConcatSig{bf, maxAllowedPacket}
Expand Down Expand Up @@ -354,7 +354,7 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

sig := &builtinConcatWSSig{bf, maxAllowedPacket}
Expand Down Expand Up @@ -591,7 +591,7 @@ func (c *repeatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}
sig := &builtinRepeatSig{bf, maxAllowedPacket}
return sig, nil
Expand Down Expand Up @@ -758,7 +758,7 @@ func (c *spaceFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}
sig := &builtinSpaceSig{bf, maxAllowedPacket}
return sig, nil
Expand Down Expand Up @@ -1853,7 +1853,7 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
Expand Down Expand Up @@ -1981,7 +1981,7 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
Expand Down Expand Up @@ -2607,7 +2607,7 @@ func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) {
if err != nil {
numError, ok := err.(*strconv.NumError)
if !ok || numError.Err != strconv.ErrRange {
return "", true, err
return "", true, errors.Trace(err)
}
overflow = true
}
Expand Down Expand Up @@ -3161,7 +3161,7 @@ func (c *fromBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Exp
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

types.SetBinChsClnFlag(bf.tp)
Expand Down Expand Up @@ -3234,7 +3234,7 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

sig := &builtinToBase64Sig{bf, maxAllowedPacket}
Expand Down Expand Up @@ -3335,7 +3335,7 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, err
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) {
Expand Down
3 changes: 2 additions & 1 deletion planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,8 @@ func (p *LogicalLimit) exhaustPhysicalPlans(prop *property.PhysicalProperty) []P
func (p *LogicalLock) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan {
childProp := prop.Clone()
lock := PhysicalLock{
Lock: p.Lock,
Lock: p.Lock,
PartitionedTable: p.partitionedTable,
}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
return []PhysicalPlan{lock}
}
Expand Down
1 change: 1 addition & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2239,6 +2239,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName) (L

if tableInfo.GetPartitionInfo() != nil {
b.optFlag = b.optFlag | flagPartitionProcessor
b.partitionedTable = append(b.partitionedTable, tbl.(table.PartitionedTable))
// check partition by name.
for _, name := range tn.PartitionNames {
_, err = tables.FindPartitionByName(tableInfo, name.L)
Expand Down
3 changes: 2 additions & 1 deletion planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,8 @@ type LogicalLimit struct {
type LogicalLock struct {
baseLogicalPlan

Lock ast.SelectLockType
Lock ast.SelectLockType
partitionedTable []table.PartitionedTable
}

// WindowFrame represents a window function frame.
Expand Down
3 changes: 3 additions & 0 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/ranger"
)
Expand Down Expand Up @@ -280,6 +281,8 @@ type PhysicalLock struct {
basePhysicalPlan

Lock ast.SelectLockType

PartitionedTable []table.PartitionedTable
}

// PhysicalLimit is the physical operator of Limit.
Expand Down
5 changes: 4 additions & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ type PlanBuilder struct {
inStraightJoin bool

windowSpecs map[string]*ast.WindowSpec

// SelectLock need this information to locate the lock on partitions.
partitionedTable []table.PartitionedTable
}

// GetVisitInfo gets the visitInfo of the PlanBuilder.
Expand Down Expand Up @@ -575,7 +578,7 @@ func removeIgnoredPaths(paths, ignoredPaths []*accessPath, tblInfo *model.TableI
}

func (b *PlanBuilder) buildSelectLock(src LogicalPlan, lock ast.SelectLockType) *LogicalLock {
selectLock := LogicalLock{Lock: lock}.Init(b.ctx)
selectLock := LogicalLock{Lock: lock, partitionedTable: b.partitionedTable}.Init(b.ctx)
selectLock.SetChildren(src)
return selectLock
}
Expand Down
6 changes: 6 additions & 0 deletions planner/core/rule_column_pruning.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,12 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) error {
return p.baseLogicalPlan.PruneColumns(parentUsedCols)
}

if len(p.partitionedTable) > 0 {
// If the children include partitioned tables, do not prune columns.
// Because the executor needs the partitioned columns to calculate the lock key.
return p.children[0].PruneColumns(p.Schema().Columns)
}

for _, cols := range p.children[0].Schema().TblID2Handle {
parentUsedCols = append(parentUsedCols, cols...)
}
Expand Down
4 changes: 4 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,10 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, errors.Trace(err)
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxAllowedPacket, types.NewStringDatum("67108864"))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand Down
66 changes: 66 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2799,3 +2799,69 @@ func (s *testSessionSuite) TestGrantViewRelated(c *C) {
tkUser.MustQuery("select current_user();").Check(testkit.Rows("u_version29@%"))
tkUser.MustExec("create view v_version29_c as select * from v_version29;")
}

func (s *testSessionSuite) TestPessimisticLockOnPartition(c *C) {
// This test checks that 'select ... for update' locks the partition instead of the table.
// Cover a bug that table ID is used to encode the lock key mistakenly.
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`create table if not exists forupdate_on_partition (
age int not null primary key,
nickname varchar(20) not null,
gender int not null default 0,
first_name varchar(30) not null default '',
last_name varchar(20) not null default '',
full_name varchar(60) as (concat(first_name, ' ', last_name)),
index idx_nickname (nickname)
) partition by range (age) (
partition child values less than (18),
partition young values less than (30),
partition middle values less than (50),
partition old values less than (123)
);`)
tk.MustExec("insert into forupdate_on_partition (`age`, `nickname`) values (25, 'cosven');")

tk1 := testkit.NewTestKit(c, s.store)
tk1.MustExec("use test")

tk.MustExec("begin pessimistic")
tk.MustQuery("select * from forupdate_on_partition where age=25 for update").Check(testkit.Rows("25 cosven 0 "))
tk1.MustExec("begin pessimistic")

ch := make(chan int32, 5)
go func() {
tk1.MustExec("update forupdate_on_partition set first_name='sw' where age=25")
ch <- 0
tk1.MustExec("commit")
}()

// Leave 50ms for tk1 to run, tk1 should be blocked at the update operation.
time.Sleep(50 * time.Millisecond)
ch <- 1

tk.MustExec("commit")
// tk1 should be blocked until tk commit, check the order.
c.Assert(<-ch, Equals, int32(1))
c.Assert(<-ch, Equals, int32(0))

// Once again...
// This time, test for the update-update conflict.
tk.MustExec("begin pessimistic")
tk.MustExec("update forupdate_on_partition set first_name='sw' where age=25")
tk1.MustExec("begin pessimistic")

go func() {
tk1.MustExec("update forupdate_on_partition set first_name = 'xxx' where age=25")
ch <- 0
tk1.MustExec("commit")
}()

// Leave 50ms for tk1 to run, tk1 should be blocked at the update operation.
time.Sleep(50 * time.Millisecond)
ch <- 1

tk.MustExec("commit")
// tk1 should be blocked until tk commit, check the order.
c.Assert(<-ch, Equals, int32(1))
c.Assert(<-ch, Equals, int32(0))
}
2 changes: 1 addition & 1 deletion table/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ type PhysicalTable interface {
type PartitionedTable interface {
Table
GetPartition(physicalID int64) PhysicalTable
GetPartitionByRow(sessionctx.Context, []types.Datum) (Table, error)
GetPartitionByRow(sessionctx.Context, []types.Datum) (PhysicalTable, error)
}

// TableFromMeta builds a table.Table from *model.TableInfo.
Expand Down
2 changes: 1 addition & 1 deletion table/tables/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (t *partitionedTable) GetPartition(pid int64) table.PhysicalTable {
}

// GetPartitionByRow returns a Table, which is actually a Partition.
func (t *partitionedTable) GetPartitionByRow(ctx sessionctx.Context, r []types.Datum) (table.Table, error) {
func (t *partitionedTable) GetPartitionByRow(ctx sessionctx.Context, r []types.Datum) (table.PhysicalTable, error) {
pid, err := t.locatePartition(ctx, t.Meta().GetPartitionInfo(), r)
if err != nil {
return nil, errors.Trace(err)
Expand Down

0 comments on commit bd70523

Please sign in to comment.