Skip to content

Commit

Permalink
table: standalone implement for table.MutateContext and `table.Allo…
Browse files Browse the repository at this point in the history
…catorContext` (#51262)

close #51259
  • Loading branch information
lcwangchao authored Feb 23, 2024
1 parent e06dc99 commit fc36864
Show file tree
Hide file tree
Showing 51 changed files with 453 additions and 264 deletions.
2 changes: 2 additions & 0 deletions br/pkg/lightning/backend/kv/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ go_library(
"//pkg/sessionctx",
"//pkg/sessionctx/variable",
"//pkg/table",
"//pkg/table/context",
"//pkg/table/contextimpl",
"//pkg/table/tables",
"//pkg/tablecodec",
"//pkg/types",
Expand Down
6 changes: 3 additions & 3 deletions br/pkg/lightning/backend/kv/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (e *BaseKVEncoder) GetOrCreateRecord() []types.Datum {

// Record2KV converts a row into a KV pair.
func (e *BaseKVEncoder) Record2KV(record, originalRow []types.Datum, rowID int64) (*Pairs, error) {
_, err := e.Table.AddRecord(e.SessionCtx, record)
_, err := e.Table.AddRecord(e.SessionCtx.GetTableCtx(), record)
if err != nil {
e.logger.Error("kv encode failed",
zap.Array("originalRow", RowArrayMarshaller(originalRow)),
Expand Down Expand Up @@ -218,14 +218,14 @@ func (e *BaseKVEncoder) ProcessColDatum(col *table.Column, rowID int64, inputDat
meta := e.Table.Meta()
shardFmt := autoid.NewShardIDFormat(&col.FieldType, meta.AutoRandomBits, meta.AutoRandomRangeBits)
// this allocator is the same as the allocator in table importer, i.e. PanickingAllocators. below too.
alloc := e.Table.Allocators(e.SessionCtx.GetSessionVars()).Get(autoid.AutoRandomType)
alloc := e.Table.Allocators(e.SessionCtx.GetTableCtx()).Get(autoid.AutoRandomType)
if err := alloc.Rebase(context.Background(), value.GetInt64()&shardFmt.IncrementalMask(), false); err != nil {
return value, errors.Trace(err)
}
}
if IsAutoIncCol(col.ToInfo()) {
// same as RowIDAllocType, since SepAutoInc is always false when initializing allocators of Table.
alloc := e.Table.Allocators(e.SessionCtx.GetSessionVars()).Get(autoid.AutoIncrementType)
alloc := e.Table.Allocators(e.SessionCtx.GetTableCtx()).Get(autoid.AutoIncrementType)
if err := alloc.Rebase(context.Background(), GetAutoRecordID(value, &col.FieldType), false); err != nil {
return value, errors.Trace(err)
}
Expand Down
4 changes: 2 additions & 2 deletions br/pkg/lightning/backend/kv/kv2sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestIterRawIndexKeysClusteredPK(t *testing.T) {
require.NoError(t, err)

sctx := kv.NewSession(sessionOpts, log.L())
handle, err := tbl.AddRecord(sctx, []types.Datum{types.NewIntDatum(1), types.NewIntDatum(2)})
handle, err := tbl.AddRecord(sctx.GetTableCtx(), []types.Datum{types.NewIntDatum(1), types.NewIntDatum(2)})
require.NoError(t, err)
paris := sctx.TakeKvPairs()
require.Len(t, paris.Pairs, 2)
Expand Down Expand Up @@ -92,7 +92,7 @@ func TestIterRawIndexKeysIntPK(t *testing.T) {
require.NoError(t, err)

sctx := kv.NewSession(sessionOpts, log.L())
handle, err := tbl.AddRecord(sctx, []types.Datum{types.NewIntDatum(1), types.NewIntDatum(2)})
handle, err := tbl.AddRecord(sctx.GetTableCtx(), []types.Datum{types.NewIntDatum(1), types.NewIntDatum(2)})
require.NoError(t, err)
paris := sctx.TakeKvPairs()
require.Len(t, paris.Pairs, 2)
Expand Down
13 changes: 11 additions & 2 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import (
planctx "github.com/pingcap/tidb/pkg/planner/context"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
tbctx "github.com/pingcap/tidb/pkg/table/context"
tbctximpl "github.com/pingcap/tidb/pkg/table/contextimpl"
"github.com/pingcap/tidb/pkg/util/topsql/stmtstats"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -268,8 +270,9 @@ func (*transaction) SetAssertion(_ []byte, _ ...kv.FlagsOp) error {
type Session struct {
sessionctx.Context
planctx.EmptyPlanContextExtended
txn transaction
Vars *variable.SessionVars
txn transaction
Vars *variable.SessionVars
tblctx *tbctximpl.TableContextImpl
// currently, we only set `CommonAddRecordCtx`
values map[fmt.Stringer]any
}
Expand Down Expand Up @@ -330,6 +333,7 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
vars.TxnCtx = nil
s.Vars = vars
s.tblctx = tbctximpl.NewTableContextImpl(s)
s.txn.kvPairs = &Pairs{}

return s
Expand Down Expand Up @@ -362,6 +366,11 @@ func (se *Session) GetPlanCtx() planctx.PlanContext {
return se
}

// GetTableCtx returns the table.MutateContext
func (se *Session) GetTableCtx() tbctx.MutateContext {
return se.tblctx
}

// SetValue saves a value associated with this context for key.
func (se *Session) SetValue(key fmt.Stringer, value any) {
se.values[key] = value
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (kvcodec *tableKVEncoder) Encode(row []types.Datum,
return nil, kvcodec.LogKVConvertFailed(row, j, ExtraHandleColumnInfo, err)
}
record = append(record, value)
alloc := kvcodec.Table.Allocators(kvcodec.SessionCtx.GetSessionVars()).Get(autoid.RowIDAllocType)
alloc := kvcodec.Table.Allocators(kvcodec.SessionCtx.GetTableCtx()).Get(autoid.RowIDAllocType)
if err := alloc.Rebase(context.Background(), rowValue, false); err != nil {
return nil, errors.Trace(err)
}
Expand Down
12 changes: 6 additions & 6 deletions br/pkg/lightning/backend/kv/sql2kv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ func TestEncodeDoubleAutoIncrement(t *testing.T) {
require.NoError(t, err)

require.Equal(t, pairsExpect, pairs)
require.Equal(t, tbl.Allocators(lkv.GetEncoderSe(encoder).GetSessionVars()).Get(autoid.AutoIncrementType).Base(), int64(70))
require.Equal(t, tbl.Allocators(lkv.GetEncoderSe(encoder).GetTableCtx()).Get(autoid.AutoIncrementType).Base(), int64(70))
}

func TestEncodeMissingAutoValue(t *testing.T) {
Expand Down Expand Up @@ -445,13 +445,13 @@ func TestEncodeMissingAutoValue(t *testing.T) {
}, rowID, []int{0}, 1234)
require.NoError(t, err)
require.Equalf(t, pairsExpect, pairs, "test table info: %+v", testTblInfo)
require.Equalf(t, rowID, tbl.Allocators(lkv.GetEncoderSe(encoder).GetSessionVars()).Get(testTblInfo.AllocType).Base(), "test table info: %+v", testTblInfo)
require.Equalf(t, rowID, tbl.Allocators(lkv.GetEncoderSe(encoder).GetTableCtx()).Get(testTblInfo.AllocType).Base(), "test table info: %+v", testTblInfo)

// test insert a row without specifying the auto_xxxx column
pairs, err = encoder.Encode([]types.Datum{}, rowID, []int{0}, 1234)
require.NoError(t, err)
require.Equalf(t, pairsExpect, pairs, "test table info: %+v", testTblInfo)
require.Equalf(t, rowID, tbl.Allocators(lkv.GetEncoderSe(encoder).GetSessionVars()).Get(testTblInfo.AllocType).Base(), "test table info: %+v", testTblInfo)
require.Equalf(t, rowID, tbl.Allocators(lkv.GetEncoderSe(encoder).GetTableCtx()).Get(testTblInfo.AllocType).Base(), "test table info: %+v", testTblInfo)
}
}

Expand Down Expand Up @@ -524,7 +524,7 @@ func TestDefaultAutoRandoms(t *testing.T) {
RowID: common.EncodeIntRowID(70),
},
}))
require.Equal(t, tbl.Allocators(lkv.GetSession4test(encoder).GetSessionVars()).Get(autoid.AutoRandomType).Base(), int64(70))
require.Equal(t, tbl.Allocators(lkv.GetSession4test(encoder).GetTableCtx()).Get(autoid.AutoRandomType).Base(), int64(70))

pairs, err = encoder.Encode([]types.Datum{types.NewStringDatum("")}, 71, []int{-1, 0}, 1234)
require.NoError(t, err)
Expand All @@ -535,7 +535,7 @@ func TestDefaultAutoRandoms(t *testing.T) {
RowID: common.EncodeIntRowID(71),
},
}))
require.Equal(t, tbl.Allocators(lkv.GetSession4test(encoder).GetSessionVars()).Get(autoid.AutoRandomType).Base(), int64(71))
require.Equal(t, tbl.Allocators(lkv.GetSession4test(encoder).GetTableCtx()).Get(autoid.AutoRandomType).Base(), int64(71))
}

func TestShardRowId(t *testing.T) {
Expand Down Expand Up @@ -566,7 +566,7 @@ func TestShardRowId(t *testing.T) {
keyMap[rowID>>60] = struct{}{}
}
require.Len(t, keyMap, 8)
require.Equal(t, tbl.Allocators(lkv.GetSession4test(encoder).GetSessionVars()).Get(autoid.RowIDAllocType).Base(), int64(32))
require.Equal(t, tbl.Allocators(lkv.GetSession4test(encoder).GetTableCtx()).Get(autoid.RowIDAllocType).Base(), int64(32))
}

func TestClassifyAndAppend(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions br/pkg/lightning/errormanager/errormanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ func (em *ErrorManager) ReplaceConflictKeys(
// for nonclustered PK, need to append handle to decodedData for AddRecord
decodedData = append(decodedData, types.NewIntDatum(overwrittenHandle.IntValue()))
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, decodedData)
_, err = encoder.Table.AddRecord(encoder.SessionCtx.GetTableCtx(), decodedData)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -675,7 +675,7 @@ func (em *ErrorManager) ReplaceConflictKeys(
// for nonclustered PK, need to append handle to decodedData for AddRecord
decodedData = append(decodedData, types.NewIntDatum(handle.IntValue()))
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, decodedData)
_, err = encoder.Table.AddRecord(encoder.SessionCtx.GetTableCtx(), decodedData)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -704,7 +704,7 @@ func (em *ErrorManager) ReplaceConflictKeys(
// for nonclustered PK, need to append handle to decodedData for AddRecord
decodedData = append(decodedData, types.NewIntDatum(handle.IntValue()))
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, decodedData)
_, err = encoder.Table.AddRecord(encoder.SessionCtx.GetTableCtx(), decodedData)
if err != nil {
return errors.Trace(err)
}
Expand Down
22 changes: 12 additions & 10 deletions br/pkg/lightning/errormanager/errormanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,16 @@ func TestReplaceConflictOneKey(t *testing.T) {
types.NewIntDatum(4),
types.NewStringDatum("5.csv"),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data4)
_, err = encoder.Table.AddRecord(tctx, data4)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data5)
_, err = encoder.Table.AddRecord(tctx, data5)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down Expand Up @@ -433,15 +434,16 @@ func TestReplaceConflictOneUniqueKey(t *testing.T) {
types.NewIntDatum(4),
types.NewStringDatum("5.csv"),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data4)
_, err = encoder.Table.AddRecord(tctx, data4)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data5)
_, err = encoder.Table.AddRecord(tctx, data5)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down
55 changes: 30 additions & 25 deletions br/pkg/lightning/errormanager/resolveconflict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,20 @@ func TestReplaceConflictMultipleKeysNonclusteredPk(t *testing.T) {
types.NewStringDatum("5.csv"),
types.NewIntDatum(7),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data4)
_, err = encoder.Table.AddRecord(tctx, data4)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data5)
_, err = encoder.Table.AddRecord(tctx, data5)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data6)
_, err = encoder.Table.AddRecord(tctx, data6)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data7)
_, err = encoder.Table.AddRecord(tctx, data7)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down Expand Up @@ -310,15 +311,16 @@ func TestReplaceConflictOneKeyNonclusteredPk(t *testing.T) {
types.NewStringDatum("5.csv"),
types.NewIntDatum(5),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data4)
_, err = encoder.Table.AddRecord(tctx, data4)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data5)
_, err = encoder.Table.AddRecord(tctx, data5)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down Expand Up @@ -462,15 +464,16 @@ func TestReplaceConflictOneUniqueKeyNonclusteredPk(t *testing.T) {
types.NewStringDatum("5.csv"),
types.NewIntDatum(5),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data4)
_, err = encoder.Table.AddRecord(tctx, data4)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data5)
_, err = encoder.Table.AddRecord(tctx, data5)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down Expand Up @@ -662,15 +665,16 @@ func TestReplaceConflictOneUniqueKeyNonclusteredVarcharPk(t *testing.T) {
types.NewStringDatum("5.csv"),
types.NewIntDatum(5),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data4)
_, err = encoder.Table.AddRecord(tctx, data4)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data5)
_, err = encoder.Table.AddRecord(tctx, data5)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down Expand Up @@ -852,11 +856,12 @@ func TestResolveConflictKeysError(t *testing.T) {
types.NewStringDatum("3.csv"),
types.NewIntDatum(3),
}
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data1)
tctx := encoder.SessionCtx.GetTableCtx()
_, err = encoder.Table.AddRecord(tctx, data1)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data2)
_, err = encoder.Table.AddRecord(tctx, data2)
require.NoError(t, err)
_, err = encoder.Table.AddRecord(encoder.SessionCtx, data3)
_, err = encoder.Table.AddRecord(tctx, data3)
require.NoError(t, err)
kvPairs := encoder.SessionCtx.TakeKvPairs()

Expand Down
12 changes: 6 additions & 6 deletions pkg/ddl/column_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestColumnAdd(t *testing.T) {
sess := testNewContext(store)
err := sessiontxn.NewTxn(context.Background(), sess)
require.NoError(t, err)
_, err = writeOnlyTable.AddRecord(sess, types.MakeDatums(10, 10))
_, err = writeOnlyTable.AddRecord(sess.GetTableCtx(), types.MakeDatums(10, 10))
require.NoError(t, err)
}
}
Expand Down Expand Up @@ -219,7 +219,7 @@ func checkAddWriteOnly(ctx sessionctx.Context, deleteOnlyTable, writeOnlyTable t
if err != nil {
return errors.Trace(err)
}
_, err = writeOnlyTable.AddRecord(ctx, types.MakeDatums(2, 3))
_, err = writeOnlyTable.AddRecord(ctx.GetTableCtx(), types.MakeDatums(2, 3))
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func checkAddWriteOnly(ctx sessionctx.Context, deleteOnlyTable, writeOnlyTable t
if err != nil {
return errors.Trace(err)
}
err = writeOnlyTable.UpdateRecord(context.Background(), ctx, h, types.MakeDatums(1, 2, 3), types.MakeDatums(2, 2, 3), touchedSlice(writeOnlyTable))
err = writeOnlyTable.UpdateRecord(context.Background(), ctx.GetTableCtx(), h, types.MakeDatums(1, 2, 3), types.MakeDatums(2, 2, 3), touchedSlice(writeOnlyTable))
if err != nil {
return errors.Trace(err)
}
Expand All @@ -274,7 +274,7 @@ func checkAddWriteOnly(ctx sessionctx.Context, deleteOnlyTable, writeOnlyTable t
return errors.Trace(err)
}
// DeleteOnlyTable: delete from t where c2 = 2
err = deleteOnlyTable.RemoveRecord(ctx, h, types.MakeDatums(2, 2))
err = deleteOnlyTable.RemoveRecord(ctx.GetTableCtx(), h, types.MakeDatums(2, 2))
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func checkAddPublic(sctx sessionctx.Context, writeOnlyTable, publicTable table.T
if err != nil {
return errors.Trace(err)
}
h, err := publicTable.AddRecord(sctx, types.MakeDatums(4, 4, 4))
h, err := publicTable.AddRecord(sctx.GetTableCtx(), types.MakeDatums(4, 4, 4))
if err != nil {
return errors.Trace(err)
}
Expand All @@ -321,7 +321,7 @@ func checkAddPublic(sctx sessionctx.Context, writeOnlyTable, publicTable table.T
return errors.Errorf("%v", oldRow)
}
newRow := types.MakeDatums(3, 4, oldRow[2].GetValue())
err = writeOnlyTable.UpdateRecord(context.Background(), sctx, h, oldRow, newRow, touchedSlice(writeOnlyTable))
err = writeOnlyTable.UpdateRecord(context.Background(), sctx.GetTableCtx(), h, oldRow, newRow, touchedSlice(writeOnlyTable))
if err != nil {
return errors.Trace(err)
}
Expand Down
Loading

0 comments on commit fc36864

Please sign in to comment.