From e60f20b9f4ee502bec4cebd84de18a25fefeebb4 Mon Sep 17 00:00:00 2001
From: xufei
Date: Wed, 22 Sep 2021 17:36:45 +0800
Subject: [PATCH 01/13] distsql: avoid false positive error log about `invalid cop task execution summaries length` (#28188)

---
 distsql/select_result.go | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/distsql/select_result.go b/distsql/select_result.go
index 5c87a73f99c7c..e0d4e3c8cec98 100644
--- a/distsql/select_result.go
+++ b/distsql/select_result.go
@@ -395,9 +395,14 @@ func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr
 	} else {
 		// For cop task cases, we still need this protection.
 		if len(r.selectResp.GetExecutionSummaries()) != len(r.copPlanIDs) {
-			logutil.Logger(ctx).Error("invalid cop task execution summaries length",
-				zap.Int("expected", len(r.copPlanIDs)),
-				zap.Int("received", len(r.selectResp.GetExecutionSummaries())))
+			// For TiFlash streaming calls (BatchCop and MPP), it is by design that only the last response
+			// carries the execution summaries, so it is OK if some responses have no execution summaries,
+			// and no error log should be triggered in this case.
+			if !(r.storeType == kv.TiFlash && len(r.selectResp.GetExecutionSummaries()) == 0) {
+				logutil.Logger(ctx).Error("invalid cop task execution summaries length",
+					zap.Int("expected", len(r.copPlanIDs)),
+					zap.Int("received", len(r.selectResp.GetExecutionSummaries())))
+			}
 			return
 		}
 		for i, detail := range r.selectResp.GetExecutionSummaries() {

From c01f1a3c5fb783969b288636535a0e54fbbf702d Mon Sep 17 00:00:00 2001
From: crazycs
Date: Wed, 22 Sep 2021 17:48:45 +0800
Subject: [PATCH 02/13] config: change tidb_top_sql_agent_address to config top-sql.receiver-address (#28135)

---
 config/config.go                      |  7 +++++++
 config/config_test.go                 |  3 +++
 executor/executor_test.go             |  4 +++-
 executor/set_test.go                  |  7 -------
 server/tidb_test.go                   | 30 +++++++++++++++++----------
 sessionctx/variable/sysvar.go         |  7 -------
 sessionctx/variable/tidb_vars.go      |  9 +-------
 util/topsql/main_test.go              |  5 ++++-
 util/topsql/reporter/reporter.go      |  3 ++-
 util/topsql/reporter/reporter_test.go |  5 ++++-
 util/topsql/topsql_test.go            |  5 ++++-
 11 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/config/config.go b/config/config.go
index 4a0fa3dc01620..6740055034f07 100644
--- a/config/config.go
+++ b/config/config.go
@@ -138,6 +138,7 @@ type Config struct {
 	DelayCleanTableLock uint64      `toml:"delay-clean-table-lock" json:"delay-clean-table-lock"`
 	SplitRegionMaxNum   uint64      `toml:"split-region-max-num" json:"split-region-max-num"`
 	StmtSummary         StmtSummary `toml:"stmt-summary" json:"stmt-summary"`
+	TopSQL              TopSQL      `toml:"top-sql" json:"top-sql"`
 	// RepairMode indicates that the TiDB is in the repair mode for table meta.
 	RepairMode      bool     `toml:"repair-mode" json:"repair-mode"`
 	RepairTableList []string `toml:"repair-table-list" json:"repair-table-list"`
@@ -537,6 +538,12 @@ type StmtSummary struct {
 	HistorySize int `toml:"history-size" json:"history-size"`
 }

+// TopSQL is the config for TopSQL.
+type TopSQL struct {
+	// ReceiverAddress is the address of the Top SQL data receiver.
+	ReceiverAddress string `toml:"receiver-address" json:"receiver-address"`
+}
+
 // IsolationRead is the config for isolation read.
 type IsolationRead struct {
 	// Engines filters tidb-server access paths by engine type.
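A minimal usage sketch for the new configuration item, using only the config APIs that already appear in this patch (config.UpdateGlobal and config.GetGlobalConfig); the receiver address "127.0.0.1:10100" is hypothetical, and the same setting is the receiver-address key of the [top-sql] section in the config file, as the config_test.go change below shows:

	// Point Top SQL reporting at a receiver by mutating the global config in place.
	config.UpdateGlobal(func(conf *config.Config) {
		conf.TopSQL.ReceiverAddress = "127.0.0.1:10100" // hypothetical address
	})
	// Readers fetch the address from the same global config.
	addr := config.GetGlobalConfig().TopSQL.ReceiverAddress
	_ = addr

This makes the receiver address process-global configuration rather than a per-session system variable, which is why the session-variable plumbing is deleted in the files that follow.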
diff --git a/config/config_test.go b/config/config_test.go index 252fb47d10ea4..064bac593b3c4 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -234,6 +234,8 @@ spilled-file-encryption-method = "plaintext" [pessimistic-txn] deadlock-history-capacity = 123 deadlock-history-collect-retryable = true +[top-sql] +receiver-address = "127.0.0.1:10100" `) require.NoError(t, err) @@ -289,6 +291,7 @@ deadlock-history-collect-retryable = true require.Equal(t, uint(123), conf.PessimisticTxn.DeadlockHistoryCapacity) require.True(t, conf.PessimisticTxn.DeadlockHistoryCollectRetryable) require.False(t, conf.Experimental.EnableNewCharset) + require.Equal(t, "127.0.0.1:10100", conf.TopSQL.ReceiverAddress) _, err = f.WriteString(` [log.file] diff --git a/executor/executor_test.go b/executor/executor_test.go index 1b713b46bb6de..b8214911a79a8 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -8620,7 +8620,9 @@ func (s *testResourceTagSuite) TestResourceGroupTag(c *C) { // Enable Top SQL variable.TopSQLVariable.Enable.Store(true) - variable.TopSQLVariable.AgentAddress.Store("mock-agent") + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = "mock-agent" + }) c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/unistoreRPCClientSendHook", `return(true)`), IsNil) defer failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/unistoreRPCClientSendHook") diff --git a/executor/set_test.go b/executor/set_test.go index b93987acc2787..167698d4564a7 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -1423,13 +1423,6 @@ func (s *testSerialSuite) TestSetTopSQLVariables(c *C) { tk.MustQuery("select @@global.tidb_enable_top_sql;").Check(testkit.Rows("0")) c.Assert(variable.TopSQLVariable.Enable.Load(), IsFalse) - tk.MustExec("set @@tidb_top_sql_agent_address='127.0.0.1:4001';") - tk.MustQuery("select @@tidb_top_sql_agent_address;").Check(testkit.Rows("127.0.0.1:4001")) - c.Assert(variable.TopSQLVariable.AgentAddress.Load(), Equals, "127.0.0.1:4001") - tk.MustExec("set @@tidb_top_sql_agent_address='';") - tk.MustQuery("select @@tidb_top_sql_agent_address;").Check(testkit.Rows("")) - c.Assert(variable.TopSQLVariable.AgentAddress.Load(), Equals, "") - tk.MustExec("set @@global.tidb_top_sql_precision_seconds=2;") tk.MustQuery("select @@global.tidb_top_sql_precision_seconds;").Check(testkit.Rows("2")) c.Assert(variable.TopSQLVariable.PrecisionSeconds.Load(), Equals, int64(2)) diff --git a/server/tidb_test.go b/server/tidb_test.go index 1b49b6774a684..a9e34bb24fe99 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1627,7 +1627,9 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLCPUProfile(c *C) { dbt.mustExec("create table t1 (a int auto_increment, b int, unique index idx(a));") dbt.mustExec("create table t2 (a int auto_increment, b int, unique index idx(a));") dbt.mustExec("set @@global.tidb_enable_top_sql='On';") - dbt.mustExec("set @@tidb_top_sql_agent_address='127.0.0.1:4001';") + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = "127.0.0.1:4001" + }) dbt.mustExec("set @@global.tidb_top_sql_precision_seconds=1;") dbt.mustExec("set @@global.tidb_txn_mode = 'pessimistic'") @@ -1856,8 +1858,13 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { dbt.mustExec(fmt.Sprintf("insert into t%v (b) values (%v);", i, j)) } } + setTopSQLReceiverAddress := func(addr string) { + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = addr + }) + } 
dbt.mustExec("set @@global.tidb_enable_top_sql='On';") - dbt.mustExec("set @@tidb_top_sql_agent_address='';") + setTopSQLReceiverAddress("") dbt.mustExec("set @@global.tidb_top_sql_precision_seconds=1;") dbt.mustExec("set @@global.tidb_top_sql_report_interval_seconds=2;") dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") @@ -1900,21 +1907,22 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { // case 1: dynamically change agent endpoint cancel := runWorkload(0, 10) // Test with null agent address, the agent server can't receive any record. - dbt.mustExec("set @@tidb_top_sql_agent_address='';") + setTopSQLReceiverAddress("") agentServer.WaitCollectCnt(1, time.Second*4) checkFn(0) // Test after set agent address and the evict take effect. dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") - dbt.mustExec(fmt.Sprintf("set @@tidb_top_sql_agent_address='%v';", agentServer.Address())) + setTopSQLReceiverAddress(agentServer.Address()) agentServer.WaitCollectCnt(1, time.Second*4) checkFn(5) // Test with wrong agent address, the agent server can't receive any record. dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=8;") - dbt.mustExec("set @@tidb_top_sql_agent_address='127.0.0.1:65530';") + setTopSQLReceiverAddress("127.0.0.1:65530") + agentServer.WaitCollectCnt(1, time.Second*4) checkFn(0) // Test after set agent address and the evict take effect. - dbt.mustExec(fmt.Sprintf("set @@tidb_top_sql_agent_address='%v';", agentServer.Address())) + setTopSQLReceiverAddress(agentServer.Address()) agentServer.WaitCollectCnt(1, time.Second*4) checkFn(8) cancel() // cancel case 1 @@ -1923,11 +1931,11 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { cancel2 := runWorkload(0, 10) // empty agent address, should not collect records dbt.mustExec("set @@global.tidb_top_sql_max_statement_count=5;") - dbt.mustExec("set @@tidb_top_sql_agent_address='';") + setTopSQLReceiverAddress("") agentServer.WaitCollectCnt(1, time.Second*4) checkFn(0) // set correct address, should collect records - dbt.mustExec(fmt.Sprintf("set @@tidb_top_sql_agent_address='%v';", agentServer.Address())) + setTopSQLReceiverAddress(agentServer.Address()) agentServer.WaitCollectCnt(1, time.Second*4) checkFn(5) // agent server hangs for a while @@ -1943,11 +1951,11 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { // case 3: agent restart cancel4 := runWorkload(0, 10) // empty agent address, should not collect records - dbt.mustExec("set @@tidb_top_sql_agent_address='';") + setTopSQLReceiverAddress("") agentServer.WaitCollectCnt(1, time.Second*4) checkFn(0) // set correct address, should collect records - dbt.mustExec(fmt.Sprintf("set @@tidb_top_sql_agent_address='%v';", agentServer.Address())) + setTopSQLReceiverAddress(agentServer.Address()) agentServer.WaitCollectCnt(1, time.Second*8) checkFn(5) // run another set of SQL queries @@ -1959,7 +1967,7 @@ func (ts *tidbTestTopSQLSuite) TestTopSQLAgent(c *C) { // agent server restart agentServer, err = mockTopSQLReporter.StartMockAgentServer() c.Assert(err, IsNil) - dbt.mustExec(fmt.Sprintf("set @@tidb_top_sql_agent_address='%v';", agentServer.Address())) + setTopSQLReceiverAddress(agentServer.Address()) // check result agentServer.WaitCollectCnt(2, time.Second*8) checkFn(5) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 761f5dc3fc10a..5b470123e5fc1 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -1785,13 +1785,6 @@ var defaultSysVars = []*SysVar{ 
TopSQLVariable.Enable.Store(TiDBOptOn(s)) return nil }}, - // TODO(crazycs520): Add validation - {Scope: ScopeSession, Name: TiDBTopSQLAgentAddress, Value: DefTiDBTopSQLAgentAddress, Type: TypeStr, Hidden: true, skipInit: true, AllowEmpty: true, GetSession: func(s *SessionVars) (string, error) { - return TopSQLVariable.AgentAddress.Load(), nil - }, SetSession: func(vars *SessionVars, s string) error { - TopSQLVariable.AgentAddress.Store(s) - return nil - }}, {Scope: ScopeGlobal, Name: TiDBTopSQLPrecisionSeconds, Value: strconv.Itoa(DefTiDBTopSQLPrecisionSeconds), Type: TypeInt, Hidden: true, MinValue: 1, MaxValue: math.MaxInt64, AllowEmpty: true, GetGlobal: func(s *SessionVars) (string, error) { return strconv.FormatInt(TopSQLVariable.PrecisionSeconds.Load(), 10), nil }, SetGlobal: func(vars *SessionVars, s string) error { diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 7e9ac30369e4e..eb955d2b11a6c 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -566,9 +566,6 @@ const ( // TiDBEnableTopSQL indicates whether the top SQL is enabled. TiDBEnableTopSQL = "tidb_enable_top_sql" - // TiDBTopSQLAgentAddress indicates the top SQL agent address. - TiDBTopSQLAgentAddress = "tidb_top_sql_agent_address" - // TiDBTopSQLPrecisionSeconds indicates the top SQL precision seconds. TiDBTopSQLPrecisionSeconds = "tidb_top_sql_precision_seconds" @@ -739,7 +736,6 @@ const ( DefTiDBEnableExchangePartition = false DefCTEMaxRecursionDepth = 1000 DefTiDBTopSQLEnable = false - DefTiDBTopSQLAgentAddress = "" DefTiDBTopSQLPrecisionSeconds = 1 DefTiDBTopSQLMaxStatementCount = 200 DefTiDBTopSQLMaxCollect = 10000 @@ -774,7 +770,6 @@ var ( MemoryUsageAlarmRatio = atomic.NewFloat64(config.GetGlobalConfig().Performance.MemoryUsageAlarmRatio) TopSQLVariable = TopSQL{ Enable: atomic.NewBool(DefTiDBTopSQLEnable), - AgentAddress: atomic.NewString(DefTiDBTopSQLAgentAddress), PrecisionSeconds: atomic.NewInt64(DefTiDBTopSQLPrecisionSeconds), MaxStatementCount: atomic.NewInt64(DefTiDBTopSQLMaxStatementCount), MaxCollect: atomic.NewInt64(DefTiDBTopSQLMaxCollect), @@ -788,8 +783,6 @@ var ( type TopSQL struct { // Enable top-sql or not. Enable *atomic.Bool - // AgentAddress indicate the collect agent address. - AgentAddress *atomic.String // The refresh interval of top-sql. PrecisionSeconds *atomic.Int64 // The maximum number of statements kept in memory. @@ -802,5 +795,5 @@ type TopSQL struct { // TopSQLEnabled uses to check whether enabled the top SQL feature. 
func TopSQLEnabled() bool { - return TopSQLVariable.Enable.Load() && TopSQLVariable.AgentAddress.Load() != "" + return TopSQLVariable.Enable.Load() && config.GetGlobalConfig().TopSQL.ReceiverAddress != "" } diff --git a/util/topsql/main_test.go b/util/topsql/main_test.go index 8b2fb2640a028..f5e3dc3f7d0cf 100644 --- a/util/topsql/main_test.go +++ b/util/topsql/main_test.go @@ -17,6 +17,7 @@ package topsql import ( "testing" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/testbridge" "github.com/pingcap/tidb/util/topsql/tracecpu" @@ -28,7 +29,9 @@ func TestMain(m *testing.M) { // set up variable.TopSQLVariable.Enable.Store(true) - variable.TopSQLVariable.AgentAddress.Store("mock") + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = "mock" + }) variable.TopSQLVariable.PrecisionSeconds.Store(1) tracecpu.GlobalSQLCPUProfiler.Run() diff --git a/util/topsql/reporter/reporter.go b/util/topsql/reporter/reporter.go index 4449d2d4a7297..08503610734da 100644 --- a/util/topsql/reporter/reporter.go +++ b/util/topsql/reporter/reporter.go @@ -23,6 +23,7 @@ import ( "time" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" @@ -552,7 +553,7 @@ func (tsr *RemoteTopSQLReporter) doReport(data reportData) { return } - agentAddr := variable.TopSQLVariable.AgentAddress.Load() + agentAddr := config.GetGlobalConfig().TopSQL.ReceiverAddress timeout := reportTimeout failpoint.Inject("resetTimeoutForTest", func(val failpoint.Value) { if val.(bool) { diff --git a/util/topsql/reporter/reporter_test.go b/util/topsql/reporter/reporter_test.go index 6001212a766a2..d4be70450e596 100644 --- a/util/topsql/reporter/reporter_test.go +++ b/util/topsql/reporter/reporter_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/topsql/reporter/mock" "github.com/pingcap/tidb/util/topsql/tracecpu" @@ -66,7 +67,9 @@ func mockPlanBinaryDecoderFunc(plan string) (string, error) { func setupRemoteTopSQLReporter(maxStatementsNum, interval int, addr string) *RemoteTopSQLReporter { variable.TopSQLVariable.MaxStatementCount.Store(int64(maxStatementsNum)) variable.TopSQLVariable.ReportIntervalSeconds.Store(int64(interval)) - variable.TopSQLVariable.AgentAddress.Store(addr) + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = addr + }) rc := NewGRPCReportClient(mockPlanBinaryDecoderFunc) ts := NewRemoteTopSQLReporter(rc) diff --git a/util/topsql/topsql_test.go b/util/topsql/topsql_test.go index bb4562a98258d..3451ca3e70347 100644 --- a/util/topsql/topsql_test.go +++ b/util/topsql/topsql_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/pprof/profile" "github.com/pingcap/parser" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/topsql" "github.com/pingcap/tidb/util/topsql/reporter" @@ -110,7 +111,9 @@ func TestTopSQLReporter(t *testing.T) { require.NoError(t, err) variable.TopSQLVariable.MaxStatementCount.Store(200) variable.TopSQLVariable.ReportIntervalSeconds.Store(1) - variable.TopSQLVariable.AgentAddress.Store(server.Address()) + config.UpdateGlobal(func(conf *config.Config) { + conf.TopSQL.ReceiverAddress = server.Address() + }) client := reporter.NewGRPCReportClient(mockPlanBinaryDecoderFunc) report := 
reporter.NewRemoteTopSQLReporter(client) From 69d73019ac6107f88a68f53ccf50c00b4d3a61f6 Mon Sep 17 00:00:00 2001 From: tison Date: Wed, 22 Sep 2021 18:16:46 +0800 Subject: [PATCH 03/13] issue-28004: migrate test-infra to testify for types/field_type_test.go (#28258) --- types/field_type_test.go | 237 ++++++++++++++++++++------------------- 1 file changed, 122 insertions(+), 115 deletions(-) diff --git a/types/field_type_test.go b/types/field_type_test.go index f9ab97b847e0c..f4f3f218f34b9 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -15,152 +15,151 @@ package types import ( - . "github.com/pingcap/check" + "testing" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/util/testleak" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testFieldTypeSuite{}) - -type testFieldTypeSuite struct { -} +func TestFieldType(t *testing.T) { + t.Parallel() -func (s *testFieldTypeSuite) TestFieldType(c *C) { - defer testleak.AfterTest(c)() ft := NewFieldType(mysql.TypeDuration) - c.Assert(ft.Flen, Equals, UnspecifiedLength) - c.Assert(ft.Decimal, Equals, UnspecifiedLength) + require.Equal(t, UnspecifiedLength, ft.Flen) + require.Equal(t, UnspecifiedLength, ft.Decimal) + ft.Decimal = 5 - c.Assert(ft.String(), Equals, "time(5)") + require.Equal(t, "time(5)", ft.String()) ft = NewFieldType(mysql.TypeLong) ft.Flen = 5 ft.Flag = mysql.UnsignedFlag | mysql.ZerofillFlag - c.Assert(ft.String(), Equals, "int(5) UNSIGNED ZEROFILL") - c.Assert(ft.InfoSchemaStr(), Equals, "int(5) unsigned") + require.Equal(t, "int(5) UNSIGNED ZEROFILL", ft.String()) + require.Equal(t, "int(5) unsigned", ft.InfoSchemaStr()) ft = NewFieldType(mysql.TypeFloat) ft.Flen = 12 // Default ft.Decimal = 3 // Not Default - c.Assert(ft.String(), Equals, "float(12,3)") + require.Equal(t, "float(12,3)", ft.String()) ft = NewFieldType(mysql.TypeFloat) ft.Flen = 12 // Default ft.Decimal = -1 // Default - c.Assert(ft.String(), Equals, "float") + require.Equal(t, "float", ft.String()) ft = NewFieldType(mysql.TypeFloat) ft.Flen = 5 // Not Default ft.Decimal = -1 // Default - c.Assert(ft.String(), Equals, "float") + require.Equal(t, "float", ft.String()) ft = NewFieldType(mysql.TypeFloat) ft.Flen = 7 // Not Default ft.Decimal = 3 // Not Default - c.Assert(ft.String(), Equals, "float(7,3)") + require.Equal(t, "float(7,3)", ft.String()) ft = NewFieldType(mysql.TypeDouble) ft.Flen = 22 // Default ft.Decimal = 3 // Not Default - c.Assert(ft.String(), Equals, "double(22,3)") + require.Equal(t, "double(22,3)", ft.String()) ft = NewFieldType(mysql.TypeDouble) ft.Flen = 22 // Default ft.Decimal = -1 // Default - c.Assert(ft.String(), Equals, "double") + require.Equal(t, "double", ft.String()) ft = NewFieldType(mysql.TypeDouble) ft.Flen = 5 // Not Default ft.Decimal = -1 // Default - c.Assert(ft.String(), Equals, "double") + require.Equal(t, "double", ft.String()) ft = NewFieldType(mysql.TypeDouble) ft.Flen = 7 // Not Default ft.Decimal = 3 // Not Default - c.Assert(ft.String(), Equals, "double(7,3)") + require.Equal(t, "double(7,3)", ft.String()) ft = NewFieldType(mysql.TypeBlob) ft.Flen = 10 ft.Charset = "UTF8" ft.Collate = "UTF8_UNICODE_GI" - c.Assert(ft.String(), Equals, "text CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI") + require.Equal(t, "text CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI", ft.String()) ft = NewFieldType(mysql.TypeVarchar) ft.Flen = 10 ft.Flag |= mysql.BinaryFlag - c.Assert(ft.String(), Equals, "varchar(10) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin") + 
require.Equal(t, "varchar(10) BINARY CHARACTER SET utf8mb4 COLLATE utf8mb4_bin", ft.String()) ft = NewFieldType(mysql.TypeString) ft.Charset = charset.CollationBin ft.Flag |= mysql.BinaryFlag - c.Assert(ft.String(), Equals, "binary(1) COLLATE utf8mb4_bin") + require.Equal(t, "binary(1) COLLATE utf8mb4_bin", ft.String()) ft = NewFieldType(mysql.TypeEnum) ft.Elems = []string{"a", "b"} - c.Assert(ft.String(), Equals, "enum('a','b')") + require.Equal(t, "enum('a','b')", ft.String()) ft = NewFieldType(mysql.TypeEnum) ft.Elems = []string{"'a'", "'b'"} - c.Assert(ft.String(), Equals, "enum('''a''','''b''')") + require.Equal(t, "enum('''a''','''b''')", ft.String()) ft = NewFieldType(mysql.TypeEnum) ft.Elems = []string{"a\nb", "a\tb", "a\rb"} - c.Assert(ft.String(), Equals, "enum('a\\nb','a\tb','a\\rb')") + require.Equal(t, "enum('a\\nb','a\tb','a\\rb')", ft.String()) ft = NewFieldType(mysql.TypeEnum) ft.Elems = []string{"a\nb", "a'\t\r\nb", "a\rb"} - c.Assert(ft.String(), Equals, "enum('a\\nb','a'' \\r\\nb','a\\rb')") + require.Equal(t, "enum('a\\nb','a'' \\r\\nb','a\\rb')", ft.String()) ft = NewFieldType(mysql.TypeSet) ft.Elems = []string{"a", "b"} - c.Assert(ft.String(), Equals, "set('a','b')") + require.Equal(t, "set('a','b')", ft.String()) ft = NewFieldType(mysql.TypeSet) ft.Elems = []string{"'a'", "'b'"} - c.Assert(ft.String(), Equals, "set('''a''','''b''')") + require.Equal(t, "set('''a''','''b''')", ft.String()) ft = NewFieldType(mysql.TypeSet) ft.Elems = []string{"a\nb", "a'\t\r\nb", "a\rb"} - c.Assert(ft.String(), Equals, "set('a\\nb','a'' \\r\\nb','a\\rb')") + require.Equal(t, "set('a\\nb','a'' \\r\\nb','a\\rb')", ft.String()) ft = NewFieldType(mysql.TypeSet) ft.Elems = []string{"a'\nb", "a'b\tc"} - c.Assert(ft.String(), Equals, "set('a''\\nb','a''b c')") + require.Equal(t, "set('a''\\nb','a''b c')", ft.String()) ft = NewFieldType(mysql.TypeTimestamp) ft.Flen = 8 ft.Decimal = 2 - c.Assert(ft.String(), Equals, "timestamp(2)") + require.Equal(t, "timestamp(2)", ft.String()) ft = NewFieldType(mysql.TypeTimestamp) ft.Flen = 8 ft.Decimal = 0 - c.Assert(ft.String(), Equals, "timestamp") + require.Equal(t, "timestamp", ft.String()) ft = NewFieldType(mysql.TypeDatetime) ft.Flen = 8 ft.Decimal = 2 - c.Assert(ft.String(), Equals, "datetime(2)") + require.Equal(t, "datetime(2)", ft.String()) ft = NewFieldType(mysql.TypeDatetime) ft.Flen = 8 ft.Decimal = 0 - c.Assert(ft.String(), Equals, "datetime") + require.Equal(t, "datetime", ft.String()) ft = NewFieldType(mysql.TypeDate) ft.Flen = 8 ft.Decimal = 2 - c.Assert(ft.String(), Equals, "date") + require.Equal(t, "date", ft.String()) ft = NewFieldType(mysql.TypeDate) ft.Flen = 8 ft.Decimal = 0 - c.Assert(ft.String(), Equals, "date") + require.Equal(t, "date", ft.String()) ft = NewFieldType(mysql.TypeYear) ft.Flen = 4 ft.Decimal = 0 - c.Assert(ft.String(), Equals, "year(4)") + require.Equal(t, "year(4)", ft.String()) ft = NewFieldType(mysql.TypeYear) ft.Flen = 2 ft.Decimal = 2 - c.Assert(ft.String(), Equals, "year(2)") // Note: Invalid year. + require.Equal(t, "year(2)", ft.String()) // Note: Invalid year. 
} -func (s *testFieldTypeSuite) TestDefaultTypeForValue(c *C) { - defer testleak.AfterTest(c)() +func TestDefaultTypeForValue(t *testing.T) { + t.Parallel() + tests := []struct { value interface{} tp byte @@ -197,20 +196,22 @@ func (s *testFieldTypeSuite) TestDefaultTypeForValue(c *C) { {Enum{Name: "a", Value: 1}, mysql.TypeEnum, 1, UnspecifiedLength, charset.CharsetBin, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag}, {Set{Name: "a", Value: 1}, mysql.TypeSet, 1, UnspecifiedLength, charset.CharsetBin, charset.CharsetBin, mysql.BinaryFlag | mysql.NotNullFlag}, } + for i, tt := range tests { var ft FieldType DefaultTypeForValue(tt.value, &ft, mysql.DefaultCharset, mysql.DefaultCollationName) - c.Assert(ft.Tp, Equals, tt.tp, Commentf("%v %v %v", i, ft.Tp, tt.tp)) - c.Assert(ft.Flen, Equals, tt.flen, Commentf("%v %v %v", i, ft.Flen, tt.flen)) - c.Assert(ft.Charset, Equals, tt.charset, Commentf("%v %v %v", i, ft.Charset, tt.charset)) - c.Assert(ft.Decimal, Equals, tt.decimal, Commentf("%v %v %v", i, ft.Decimal, tt.decimal)) - c.Assert(ft.Collate, Equals, tt.collation, Commentf("%v %v %v", i, ft.Collate, tt.collation)) - c.Assert(ft.Flag, Equals, tt.flag, Commentf("%v %v %v", i, ft.Flag, tt.flag)) + require.Equalf(t, tt.tp, ft.Tp, "%v %v %v", i, ft.Tp, tt.tp) + require.Equalf(t, tt.flen, ft.Flen, "%v %v %v", i, ft.Flen, tt.flen) + require.Equalf(t, tt.charset, ft.Charset, "%v %v %v", i, ft.Charset, tt.charset) + require.Equalf(t, tt.decimal, ft.Decimal, "%v %v %v", i, ft.Decimal, tt.decimal) + require.Equalf(t, tt.collation, ft.Collate, "%v %v %v", i, ft.Collate, tt.collation) + require.Equalf(t, tt.flag, ft.Flag, "%v %v %v", i, ft.Flag, tt.flag) } } -func (s *testFieldTypeSuite) TestAggFieldType(c *C) { - defer testleak.AfterTest(c)() +func TestAggFieldType(t *testing.T) { + t.Parallel() + fts := []*FieldType{ NewFieldType(mysql.TypeUnspecified), NewFieldType(mysql.TypeTiny), @@ -244,93 +245,98 @@ func (s *testFieldTypeSuite) TestAggFieldType(c *C) { for i := range fts { aggTp := AggFieldType(fts[i : i+1]) - c.Assert(aggTp.Tp, Equals, fts[i].Tp) + require.Equal(t, fts[i].Tp, aggTp.Tp) aggTp = AggFieldType([]*FieldType{fts[i], fts[i]}) switch fts[i].Tp { case mysql.TypeDate: - c.Assert(aggTp.Tp, Equals, mysql.TypeDate) + require.Equal(t, mysql.TypeDate, aggTp.Tp) case mysql.TypeJSON: - c.Assert(aggTp.Tp, Equals, mysql.TypeJSON) + require.Equal(t, mysql.TypeJSON, aggTp.Tp) case mysql.TypeEnum, mysql.TypeSet, mysql.TypeVarString: - c.Assert(aggTp.Tp, Equals, mysql.TypeVarchar) + require.Equal(t, mysql.TypeVarchar, aggTp.Tp) case mysql.TypeUnspecified: - c.Assert(aggTp.Tp, Equals, mysql.TypeNewDecimal) + require.Equal(t, mysql.TypeNewDecimal, aggTp.Tp) default: - c.Assert(aggTp.Tp, Equals, fts[i].Tp) + require.Equal(t, fts[i].Tp, aggTp.Tp) } aggTp = AggFieldType([]*FieldType{fts[i], NewFieldType(mysql.TypeLong)}) switch fts[i].Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeYear, mysql.TypeInt24, mysql.TypeNull: - c.Assert(aggTp.Tp, Equals, mysql.TypeLong) + require.Equal(t, mysql.TypeLong, aggTp.Tp) case mysql.TypeLonglong: - c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + require.Equal(t, mysql.TypeLonglong, aggTp.Tp) case mysql.TypeFloat, mysql.TypeDouble: - c.Assert(aggTp.Tp, Equals, mysql.TypeDouble) + require.Equal(t, mysql.TypeDouble, aggTp.Tp) case mysql.TypeTimestamp, mysql.TypeDate, mysql.TypeDuration, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeVarchar, mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeVarString, mysql.TypeGeometry: - 
c.Assert(aggTp.Tp, Equals, mysql.TypeVarchar) + require.Equal(t, mysql.TypeVarchar, aggTp.Tp) case mysql.TypeBit: - c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + require.Equal(t, mysql.TypeLonglong, aggTp.Tp) case mysql.TypeString: - c.Assert(aggTp.Tp, Equals, mysql.TypeString) + require.Equal(t, mysql.TypeString, aggTp.Tp) case mysql.TypeUnspecified, mysql.TypeNewDecimal: - c.Assert(aggTp.Tp, Equals, mysql.TypeNewDecimal) + require.Equal(t, mysql.TypeNewDecimal, aggTp.Tp) case mysql.TypeTinyBlob: - c.Assert(aggTp.Tp, Equals, mysql.TypeTinyBlob) + require.Equal(t, mysql.TypeTinyBlob, aggTp.Tp) case mysql.TypeBlob: - c.Assert(aggTp.Tp, Equals, mysql.TypeBlob) + require.Equal(t, mysql.TypeBlob, aggTp.Tp) case mysql.TypeMediumBlob: - c.Assert(aggTp.Tp, Equals, mysql.TypeMediumBlob) + require.Equal(t, mysql.TypeMediumBlob, aggTp.Tp) case mysql.TypeLongBlob: - c.Assert(aggTp.Tp, Equals, mysql.TypeLongBlob) + require.Equal(t, mysql.TypeLongBlob, aggTp.Tp) } aggTp = AggFieldType([]*FieldType{fts[i], NewFieldType(mysql.TypeJSON)}) switch fts[i].Tp { case mysql.TypeJSON, mysql.TypeNull: - c.Assert(aggTp.Tp, Equals, mysql.TypeJSON) + require.Equal(t, mysql.TypeJSON, aggTp.Tp) case mysql.TypeLongBlob, mysql.TypeMediumBlob, mysql.TypeTinyBlob, mysql.TypeBlob: - c.Assert(aggTp.Tp, Equals, mysql.TypeLongBlob) + require.Equal(t, mysql.TypeLongBlob, aggTp.Tp) case mysql.TypeString: - c.Assert(aggTp.Tp, Equals, mysql.TypeString) + require.Equal(t, mysql.TypeString, aggTp.Tp) default: - c.Assert(aggTp.Tp, Equals, mysql.TypeVarchar) + require.Equal(t, mysql.TypeVarchar, aggTp.Tp) } } } -func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) { + +func TestAggFieldTypeForTypeFlag(t *testing.T) { + t.Parallel() + types := []*FieldType{ NewFieldType(mysql.TypeLonglong), NewFieldType(mysql.TypeLonglong), } aggTp := AggFieldType(types) - c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) - c.Assert(aggTp.Flag, Equals, uint(0)) + require.Equal(t, mysql.TypeLonglong, aggTp.Tp) + require.Equal(t, uint(0), aggTp.Flag) types[0].Flag = mysql.NotNullFlag aggTp = AggFieldType(types) - c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) - c.Assert(aggTp.Flag, Equals, uint(0)) + require.Equal(t, mysql.TypeLonglong, aggTp.Tp) + require.Equal(t, uint(0), aggTp.Flag) types[0].Flag = 0 types[1].Flag = mysql.NotNullFlag aggTp = AggFieldType(types) - c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) - c.Assert(aggTp.Flag, Equals, uint(0)) + require.Equal(t, mysql.TypeLonglong, aggTp.Tp) + require.Equal(t, uint(0), aggTp.Flag) types[0].Flag = mysql.NotNullFlag aggTp = AggFieldType(types) - c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) - c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag) + require.Equal(t, mysql.TypeLonglong, aggTp.Tp) + require.Equal(t, mysql.NotNullFlag, aggTp.Flag) } -func (s testFieldTypeSuite) TestAggFieldTypeForIntegralPromotion(c *C) { +func TestAggFieldTypeForIntegralPromotion(t *testing.T) { + t.Parallel() + fts := []*FieldType{ NewFieldType(mysql.TypeTiny), NewFieldType(mysql.TypeShort), @@ -346,30 +352,31 @@ func (s testFieldTypeSuite) TestAggFieldTypeForIntegralPromotion(c *C) { tps[0].Flag = 0 tps[1].Flag = 0 aggTp := AggFieldType(tps) - c.Assert(aggTp.Tp, Equals, fts[i].Tp) - c.Assert(aggTp.Flag, Equals, uint(0)) + require.Equal(t, fts[i].Tp, aggTp.Tp) + require.Equal(t, uint(0), aggTp.Flag) tps[0].Flag = mysql.UnsignedFlag aggTp = AggFieldType(tps) - c.Assert(aggTp.Tp, Equals, fts[i].Tp) - c.Assert(aggTp.Flag, Equals, uint(0)) + require.Equal(t, fts[i].Tp, aggTp.Tp) + require.Equal(t, uint(0), 
aggTp.Flag) tps[0].Flag = mysql.UnsignedFlag tps[1].Flag = mysql.UnsignedFlag aggTp = AggFieldType(tps) - c.Assert(aggTp.Tp, Equals, fts[i].Tp) - c.Assert(aggTp.Flag, Equals, mysql.UnsignedFlag) + require.Equal(t, fts[i].Tp, aggTp.Tp) + require.Equal(t, mysql.UnsignedFlag, aggTp.Flag) tps[0].Flag = 0 tps[1].Flag = mysql.UnsignedFlag aggTp = AggFieldType(tps) - c.Assert(aggTp.Tp, Equals, fts[i+1].Tp) - c.Assert(aggTp.Flag, Equals, uint(0)) + require.Equal(t, fts[i+1].Tp, aggTp.Tp) + require.Equal(t, uint(0), aggTp.Flag) } } -func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { - defer testleak.AfterTest(c)() +func TestAggregateEvalType(t *testing.T) { + t.Parallel() + fts := []*FieldType{ NewFieldType(mysql.TypeUnspecified), NewFieldType(mysql.TypeTiny), @@ -410,18 +417,18 @@ func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob, mysql.TypeVarString, mysql.TypeString, mysql.TypeGeometry: - c.Assert(aggregatedEvalType.IsStringKind(), IsTrue) - c.Assert(flag, Equals, uint(0)) + require.True(t, aggregatedEvalType.IsStringKind()) + require.Equal(t, uint(0), flag) case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeBit, mysql.TypeInt24, mysql.TypeYear: - c.Assert(aggregatedEvalType, Equals, ETInt) - c.Assert(flag, Equals, mysql.BinaryFlag) + require.Equal(t, ETInt, aggregatedEvalType) + require.Equal(t, mysql.BinaryFlag, flag) case mysql.TypeFloat, mysql.TypeDouble: - c.Assert(aggregatedEvalType, Equals, ETReal) - c.Assert(flag, Equals, mysql.BinaryFlag) + require.Equal(t, ETReal, aggregatedEvalType) + require.Equal(t, mysql.BinaryFlag, flag) case mysql.TypeNewDecimal: - c.Assert(aggregatedEvalType, Equals, ETDecimal) - c.Assert(flag, Equals, mysql.BinaryFlag) + require.Equal(t, ETDecimal, aggregatedEvalType) + require.Equal(t, mysql.BinaryFlag, flag) } flag = 0 @@ -432,18 +439,18 @@ func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { mysql.TypeJSON, mysql.TypeEnum, mysql.TypeSet, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob, mysql.TypeVarString, mysql.TypeString, mysql.TypeGeometry: - c.Assert(aggregatedEvalType.IsStringKind(), IsTrue) - c.Assert(flag, Equals, uint(0)) + require.True(t, aggregatedEvalType.IsStringKind()) + require.Equal(t, uint(0), flag) case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeBit, mysql.TypeInt24, mysql.TypeYear: - c.Assert(aggregatedEvalType, Equals, ETInt) - c.Assert(flag, Equals, mysql.BinaryFlag) + require.Equal(t, ETInt, aggregatedEvalType) + require.Equal(t, mysql.BinaryFlag, flag) case mysql.TypeFloat, mysql.TypeDouble: - c.Assert(aggregatedEvalType, Equals, ETReal) - c.Assert(flag, Equals, mysql.BinaryFlag) + require.Equal(t, ETReal, aggregatedEvalType) + require.Equal(t, mysql.BinaryFlag, flag) case mysql.TypeNewDecimal: - c.Assert(aggregatedEvalType, Equals, ETDecimal) - c.Assert(flag, Equals, mysql.BinaryFlag) + require.Equal(t, ETDecimal, aggregatedEvalType) + require.Equal(t, mysql.BinaryFlag, flag) } flag = 0 aggregatedEvalType = AggregateEvalType([]*FieldType{fts[i], NewFieldType(mysql.TypeLong)}, &flag) @@ -453,18 +460,18 @@ func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { mysql.TypeEnum, mysql.TypeSet, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob, mysql.TypeVarString, mysql.TypeString, mysql.TypeGeometry: - c.Assert(aggregatedEvalType.IsStringKind(), IsTrue) - 
c.Assert(flag, Equals, uint(0))
+			require.True(t, aggregatedEvalType.IsStringKind())
+			require.Equal(t, uint(0), flag)
 		case mysql.TypeUnspecified, mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong,
 			mysql.TypeNull, mysql.TypeBit, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeInt24:
-			c.Assert(aggregatedEvalType, Equals, ETInt)
-			c.Assert(flag, Equals, mysql.BinaryFlag)
+			require.Equal(t, ETInt, aggregatedEvalType)
+			require.Equal(t, mysql.BinaryFlag, flag)
 		case mysql.TypeFloat, mysql.TypeDouble:
-			c.Assert(aggregatedEvalType, Equals, ETReal)
-			c.Assert(flag, Equals, mysql.BinaryFlag)
+			require.Equal(t, ETReal, aggregatedEvalType)
+			require.Equal(t, mysql.BinaryFlag, flag)
 		case mysql.TypeNewDecimal:
-			c.Assert(aggregatedEvalType, Equals, ETDecimal)
-			c.Assert(flag, Equals, mysql.BinaryFlag)
+			require.Equal(t, ETDecimal, aggregatedEvalType)
+			require.Equal(t, mysql.BinaryFlag, flag)
 		}
 	}
 }

From 1e8b9337e3ef3cabba819fb562328fdca273c9ea Mon Sep 17 00:00:00 2001
From: Wallace
Date: Wed, 22 Sep 2021 19:28:45 +0800
Subject: [PATCH 04/13] lightning: do not output pre-check info when disable check-requirements (#27934)

---
 br/pkg/lightning/config/const.go           |  1 -
 br/pkg/lightning/lightning.go              | 11 +++-
 br/pkg/lightning/log/log.go                |  5 ++
 br/pkg/lightning/metric/metric.go          | 11 ++--
 br/pkg/lightning/restore/check_info.go     | 51 +++++++++--------
 br/pkg/lightning/restore/check_template.go |  2 +-
 br/pkg/lightning/restore/restore.go        | 64 ++++++++++++++--------
 br/pkg/lightning/restore/restore_test.go   |  6 +-
 br/pkg/lightning/restore/table_restore.go  |  9 +++
 9 files changed, 99 insertions(+), 61 deletions(-)

diff --git a/br/pkg/lightning/config/const.go b/br/pkg/lightning/config/const.go
index 4f262eaddbcca..bf807f2fe759a 100644
--- a/br/pkg/lightning/config/const.go
+++ b/br/pkg/lightning/config/const.go
@@ -21,7 +21,6 @@ import (
 const (
 	// mydumper
 	ReadBlockSize           ByteSize = 64 * units.KiB
-	MinRegionSize           ByteSize = 256 * units.MiB
 	MaxRegionSize           ByteSize = 256 * units.MiB
 	SplitRegionSize         ByteSize = 96 * units.MiB
 	MaxSplitRegionSizeRatio int      = 10
diff --git a/br/pkg/lightning/lightning.go b/br/pkg/lightning/lightning.go
index 06ddf30ed9d91..8164f27f551d8 100644
--- a/br/pkg/lightning/lightning.go
+++ b/br/pkg/lightning/lightning.go
@@ -61,6 +61,7 @@ type Lightning struct {
 	server     http.Server
 	serverAddr net.Addr
 	serverLock sync.Mutex
+	status     restore.LightningStatus

 	cancelLock sync.Mutex
 	curTask    *config.Config
@@ -310,7 +311,8 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, g glue.
 	web.BroadcastInitProgress(dbMetas)

 	var procedure *restore.Controller
-	procedure, err = restore.NewRestoreController(ctx, dbMetas, taskCfg, s, g)
+
+	procedure, err = restore.NewRestoreController(ctx, dbMetas, taskCfg, &l.status, s, g)
 	if err != nil {
 		log.L().Error("restore failed", log.ShortError(err))
 		return errors.Trace(err)
@@ -333,6 +335,13 @@ func (l *Lightning) Stop() {
 	l.shutdown()
 }

+// Status returns the size of the source data already imported into TiKV and the total size of the source files.
+func (l *Lightning) Status() (finished int64, total int64) {
+	finished = l.status.FinishedFileSize.Load()
+	total = l.status.TotalFileSize.Load()
+	return
+}
+
 func writeJSONError(w http.ResponseWriter, code int, prefix string, err error) {
 	type errorResponse struct {
 		Error string `json:"error"`
diff --git a/br/pkg/lightning/log/log.go b/br/pkg/lightning/log/log.go
index e91ed3fb47a56..2dc24acac1541 100644
--- a/br/pkg/lightning/log/log.go
+++ b/br/pkg/lightning/log/log.go
@@ -116,6 +116,11 @@ func InitLogger(cfg *Config, tidbLoglevel string) error {
 	return nil
 }

+// SetAppLogger replaces the default logger in this package with the given one.
+func SetAppLogger(l *zap.Logger) {
+	appLogger = Logger{l.WithOptions(zap.AddStacktrace(zap.DPanicLevel))}
+}
+
 // L returns the current logger for Lightning.
 func L() Logger {
 	return appLogger
diff --git a/br/pkg/lightning/metric/metric.go b/br/pkg/lightning/metric/metric.go
index a06c8355266cf..984b14c8460f6 100644
--- a/br/pkg/lightning/metric/metric.go
+++ b/br/pkg/lightning/metric/metric.go
@@ -23,13 +23,10 @@ import (

 const (
 	// states used for the TableCounter labels
-	TableStatePending        = "pending"
-	TableStateWritten        = "written"
-	TableStateClosed         = "closed"
-	TableStateImported       = "imported"
-	TableStateAlteredAutoInc = "altered_auto_inc"
-	TableStateChecksum       = "checksum"
-	TableStateCompleted      = "completed"
+	TableStatePending   = "pending"
+	TableStateWritten   = "written"
+	TableStateImported  = "imported"
+	TableStateCompleted = "completed"

 	// results used for the TableCounter labels
 	TableResultSuccess = "success"
diff --git a/br/pkg/lightning/restore/check_info.go b/br/pkg/lightning/restore/check_info.go
index e2905ed1369ac..648f421574fa2 100644
--- a/br/pkg/lightning/restore/check_info.go
+++ b/br/pkg/lightning/restore/check_info.go
@@ -92,8 +92,8 @@ func (rc *Controller) getClusterAvail(ctx context.Context) (uint64, error) {
 	return clusterAvail, nil
 }

-// ClusterResource check cluster has enough resource to import data. this test can by skipped.
-func (rc *Controller) ClusterResource(ctx context.Context, localSource int64) error {
+// clusterResource checks whether the cluster has enough resources to import data. This check can be skipped.
+func (rc *Controller) clusterResource(ctx context.Context, localSource int64) error {
 	passed := true
 	message := "Cluster resources are rich for this import task"
 	defer func() {
@@ -167,11 +167,6 @@ func (rc *Controller) ClusterIsAvailable(ctx context.Context) error {
 	defer func() {
 		rc.checkTemplate.Collect(Critical, passed, message)
 	}()
-	// skip requirement check if explicitly turned off
-	if !rc.cfg.App.CheckRequirements {
-		message = "Cluster's available check is skipped by user requirement"
-		return nil
-	}
 	checkCtx := &backend.CheckCtx{
 		DBMetas: rc.dbMetas,
 	}
@@ -314,8 +309,8 @@ func (rc *Controller) checkRegionDistribution(ctx context.Context) error {
 	return nil
 }

-// CheckClusterRegion checks cluster if there are too many empty regions or region distribution is unbalanced.
-func (rc *Controller) CheckClusterRegion(ctx context.Context) error {
+// checkClusterRegion checks whether the cluster has too many empty regions or an unbalanced region distribution.
+func (rc *Controller) checkClusterRegion(ctx context.Context) error {
 	err := rc.taskMgr.CheckTasksExclusively(ctx, func(tasks []taskMeta) ([]taskMeta, error) {
 		restoreStarted := false
 		for _, task := range tasks {
@@ -390,7 +385,7 @@ func (rc *Controller) HasLargeCSV(dbMetas []*mydump.MDDatabaseMeta) error {
 	return nil
 }

-func (rc *Controller) EstimateSourceData(ctx context.Context) (int64, error) {
+func (rc *Controller) estimateSourceData(ctx context.Context) (int64, error) {
 	sourceSize := int64(0)
 	originSource := int64(0)
 	bigTableCount := 0
@@ -404,21 +399,32 @@
 		for _, tbl := range db.Tables {
 			tableInfo, ok := info.Tables[tbl.Name]
 			if ok {
-				if err := rc.SampleDataFromTable(ctx, db.Name, tbl, tableInfo.Core); err != nil {
-					return sourceSize, errors.Trace(err)
-				}
-				sourceSize += int64(float64(tbl.TotalSize) * tbl.IndexRatio)
 				originSource += tbl.TotalSize
-				if tbl.TotalSize > int64(config.DefaultBatchSize)*2 {
-					bigTableCount += 1
-					if !tbl.IsRowOrdered {
-						unSortedTableCount += 1
+				// Do not sample small tables, because there may be a large number of them and it would
+				// take a long time to sample data for all of them.
+				if tbl.TotalSize < int64(config.SplitRegionSize) {
+					sourceSize += tbl.TotalSize
+					tbl.IndexRatio = 1.0
+					tbl.IsRowOrdered = false
+				} else {
+					if err := rc.sampleDataFromTable(ctx, db.Name, tbl, tableInfo.Core); err != nil {
+						return sourceSize, errors.Trace(err)
+					}
+					sourceSize += int64(float64(tbl.TotalSize) * tbl.IndexRatio)
+					if tbl.TotalSize > int64(config.DefaultBatchSize)*2 {
+						bigTableCount += 1
+						if !tbl.IsRowOrdered {
+							unSortedTableCount += 1
+						}
 					}
 				}
 				tableCount += 1
 			}
 		}
 	}
+	if rc.status != nil {
+		rc.status.TotalFileSize.Store(originSource)
+	}
 	// Do not import with too large concurrency because these data may be all unsorted.
if bigTableCount > 0 && unSortedTableCount > 0 {
@@ -429,8 +435,8 @@
 	return sourceSize, nil
 }

-// LocalResource checks the local node has enough resources for this import when local backend enabled;
-func (rc *Controller) LocalResource(ctx context.Context, sourceSize int64) error {
+// localResource checks that the local node has enough resources for this import when the local backend is enabled.
+func (rc *Controller) localResource(sourceSize int64) error {
 	if rc.isSourceInLocal() {
 		sourceDir := strings.TrimPrefix(rc.cfg.Mydumper.SourceDir, storage.LocalURIPrefix)
 		same, err := common.SameDisk(sourceDir, rc.cfg.TikvImporter.SortedKVDir)
@@ -449,9 +455,6 @@
 		return errors.Trace(err)
 	}
 	localAvailable := storageSize.Available
-	if err = rc.taskMgr.InitTask(ctx, sourceSize); err != nil {
-		return errors.Trace(err)
-	}

 	var message string
 	var passed bool
@@ -732,7 +735,7 @@ func (rc *Controller) SchemaIsValid(ctx context.Context, tableInfo *mydump.MDTab
 	return msgs, nil
 }

-func (rc *Controller) SampleDataFromTable(ctx context.Context, dbName string, tableMeta *mydump.MDTableMeta, tableInfo *model.TableInfo) error {
+func (rc *Controller) sampleDataFromTable(ctx context.Context, dbName string, tableMeta *mydump.MDTableMeta, tableInfo *model.TableInfo) error {
 	if len(tableMeta.DataFiles) == 0 {
 		return nil
 	}
diff --git a/br/pkg/lightning/restore/check_template.go b/br/pkg/lightning/restore/check_template.go
index ba63c505a45f8..2b7d2a405cde0 100644
--- a/br/pkg/lightning/restore/check_template.go
+++ b/br/pkg/lightning/restore/check_template.go
@@ -82,7 +82,7 @@ func (c *SimpleTemplate) Collect(t CheckType, passed bool, msg string) {
 }

 func (c *SimpleTemplate) Success() bool {
-	return c.warnFailedCount+c.criticalFailedCount == 0
+	return c.criticalFailedCount == 0
 }

 func (c *SimpleTemplate) FailedCount(t CheckType) int {
diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go
index c98c43ae1c784..6969e61acfc1d 100644
--- a/br/pkg/lightning/restore/restore.go
+++ b/br/pkg/lightning/restore/restore.go
@@ -257,22 +257,30 @@ type Controller struct {
 	diskQuotaLock  *diskQuotaLock
 	diskQuotaState atomic.Int32
 	compactState   atomic.Int32
+	status         *LightningStatus
+}
+
+// LightningStatus tracks the finished and total source file size of the current import task.
+type LightningStatus struct {
+	FinishedFileSize atomic.Int64
+	TotalFileSize    atomic.Int64
 }

 func NewRestoreController(
 	ctx context.Context,
 	dbMetas []*mydump.MDDatabaseMeta,
 	cfg *config.Config,
+	status *LightningStatus,
 	s storage.ExternalStorage,
 	g glue.Glue,
 ) (*Controller, error) {
-	return NewRestoreControllerWithPauser(ctx, dbMetas, cfg, s, DeliverPauser, g)
+	return NewRestoreControllerWithPauser(ctx, dbMetas, cfg, status, s, DeliverPauser, g)
 }

 func NewRestoreControllerWithPauser(
 	ctx context.Context,
 	dbMetas []*mydump.MDDatabaseMeta,
 	cfg *config.Config,
+	status *LightningStatus,
 	s storage.ExternalStorage,
 	pauser *common.Pauser,
 	g glue.Glue,
@@ -379,6 +387,7 @@ func NewRestoreControllerWithPauser(
 		store:          s,
 		metaMgrBuilder: metaBuilder,
 		diskQuotaLock:  newDiskQuotaLock(),
+		status:         status,
 		taskMgr:        nil,
 	}

@@ -1716,18 +1725,26 @@ func (rc *Controller) isLocalBackend() bool {
 // 4. Lightning configuration
 // before restore tables start.
func (rc *Controller) preCheckRequirements(ctx context.Context) error {
-	if err := rc.ClusterIsAvailable(ctx); err != nil {
-		return errors.Trace(err)
-	}
+	if rc.cfg.App.CheckRequirements {
+		if err := rc.ClusterIsAvailable(ctx); err != nil {
+			return errors.Trace(err)
+		}

-	if err := rc.StoragePermission(ctx); err != nil {
-		return errors.Trace(err)
+		if err := rc.StoragePermission(ctx); err != nil {
+			return errors.Trace(err)
+		}
 	}
+
 	if err := rc.metaMgrBuilder.Init(ctx); err != nil {
 		return err
 	}
 	taskExist := false
-
+	// We still need to sample the source data even if this task already exists, because we need to judge
+	// whether the source is ordered by row key to decide how to sort the local data.
+	source, err := rc.estimateSourceData(ctx)
+	if err != nil {
+		return errors.Trace(err)
+	}
 	if rc.isLocalBackend() {
 		pdController, err := pdutil.NewPdController(ctx, rc.cfg.TiDB.PdAddr,
 			rc.tls.TLSConfig(), rc.tls.ToPDSecurityOption())
@@ -1742,29 +1759,28 @@
 			return errors.Trace(err)
 		}
 		if !taskExist {
-			source, err := rc.EstimateSourceData(ctx)
-			if err != nil {
+			if err = rc.taskMgr.InitTask(ctx, source); err != nil {
 				return errors.Trace(err)
 			}
-			err = rc.LocalResource(ctx, source)
-			if err != nil {
-				rc.taskMgr.CleanupTask(ctx)
-				return errors.Trace(err)
-			}
-			if err := rc.ClusterResource(ctx, source); err != nil {
-				rc.taskMgr.CleanupTask(ctx)
-				return errors.Trace(err)
-			}
-			if err := rc.CheckClusterRegion(ctx); err != nil {
-				return errors.Trace(err)
+			if rc.cfg.App.CheckRequirements {
+				err = rc.localResource(source)
+				if err != nil {
+					return errors.Trace(err)
+				}
+				if err := rc.clusterResource(ctx, source); err != nil {
+					rc.taskMgr.CleanupTask(ctx)
+					return errors.Trace(err)
+				}
+				if err := rc.checkClusterRegion(ctx); err != nil {
+					return errors.Trace(err)
+				}
 			}
 		}
 	}
-	if rc.tidbGlue.OwnsSQLExecutor() {
-		// print check info at any time.
+
+	if rc.tidbGlue.OwnsSQLExecutor() && rc.cfg.App.CheckRequirements {
 		fmt.Print(rc.checkTemplate.Output())
-		if rc.cfg.App.CheckRequirements && !rc.checkTemplate.Success() {
-			// if check requirements is true, return error.
+ if !rc.checkTemplate.Success() { if !taskExist && rc.taskMgr != nil { rc.taskMgr.CleanupTask(ctx) } diff --git a/br/pkg/lightning/restore/restore_test.go b/br/pkg/lightning/restore/restore_test.go index 02a38b431a8c1..15fceff51b275 100644 --- a/br/pkg/lightning/restore/restore_test.go +++ b/br/pkg/lightning/restore/restore_test.go @@ -1713,7 +1713,7 @@ func (s *tableRestoreSuite) TestCheckClusterResource(c *C) { sourceSize += size return nil }) - err = rc.ClusterResource(ctx, sourceSize) + err = rc.clusterResource(ctx, sourceSize) c.Assert(err, IsNil) c.Assert(template.FailedCount(Critical), Equals, ca.expectErrorCount) @@ -1835,7 +1835,7 @@ func (s *tableRestoreSuite) TestCheckClusterRegion(c *C) { cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}} rc := &Controller{cfg: cfg, tls: tls, taskMgr: mockTaskMetaMgr{}, checkTemplate: template} - err := rc.CheckClusterRegion(context.Background()) + err := rc.checkClusterRegion(context.Background()) c.Assert(err, IsNil) c.Assert(template.FailedCount(Critical), Equals, ca.expectErrorCnt) c.Assert(template.Success(), Equals, ca.expectResult) @@ -1887,7 +1887,7 @@ func (s *tableRestoreSuite) TestCheckHasLargeCSV(c *C) { { false, "(.*)large csv: /testPath file exists(.*)", - false, + true, 1, []*mydump.MDDatabaseMeta{ { diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go index 0ead1c1eb36c6..90834c482cff8 100644 --- a/br/pkg/lightning/restore/table_restore.go +++ b/br/pkg/lightning/restore/table_restore.go @@ -314,11 +314,20 @@ func (tr *TableRestore) restoreEngines(pCtx context.Context, rc *Controller, cp dataWorker := rc.closedEngineLimit.Apply() defer rc.closedEngineLimit.Recycle(dataWorker) err = tr.importEngine(ctx, dataClosedEngine, rc, eid, ecp) + if rc.status != nil { + for _, chunk := range ecp.Chunks { + rc.status.FinishedFileSize.Add(chunk.Chunk.EndOffset - chunk.Key.Offset) + } + } } if err != nil { setError(err) } }(restoreWorker, engineID, engine) + } else { + for _, chunk := range engine.Chunks { + rc.status.FinishedFileSize.Add(chunk.Chunk.EndOffset - chunk.Key.Offset) + } } } From bfbea9c3ef4232d76296a9c8390eb8b7da5bf45d Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Wed, 22 Sep 2021 12:10:46 -0600 Subject: [PATCH 05/13] variable: Add support for mock globalvars (#26990) --- expression/bench_test.go | 2 +- sessionctx/variable/mock_globalaccessor.go | 48 +++++++++++++++++++--- sessionctx/variable/sysvar_test.go | 33 ++++++++++++--- 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/expression/bench_test.go b/expression/bench_test.go index fef13319d95c3..a58f74f867667 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -1806,7 +1806,7 @@ func (s *testVectorizeSuite2) TestVecEvalBool(c *C) { it := chunk.NewIterator4Chunk(input) i := 0 for row := it.Begin(); row != it.End(); row = it.Next() { - ok, null, err := EvalBool(mock.NewContext(), exprs, row) + ok, null, err := EvalBool(ctx, exprs, row) c.Assert(err, IsNil) c.Assert(null, Equals, nulls[i]) c.Assert(ok, Equals, selected[i]) diff --git a/sessionctx/variable/mock_globalaccessor.go b/sessionctx/variable/mock_globalaccessor.go index e48c44945949e..339b9c40bdac1 100644 --- a/sessionctx/variable/mock_globalaccessor.go +++ b/sessionctx/variable/mock_globalaccessor.go @@ -16,30 +16,66 @@ package variable // MockGlobalAccessor implements GlobalVarAccessor interface. it's used in tests type MockGlobalAccessor struct { + SessionVars *SessionVars // can be overwritten if needed for correctness. 
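+	// vals stores the mocked global system variable values, keyed by variable name.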
+ vals map[string]string } // NewMockGlobalAccessor implements GlobalVarAccessor interface. func NewMockGlobalAccessor() *MockGlobalAccessor { - return new(MockGlobalAccessor) + tmp := new(MockGlobalAccessor) + tmp.vals = make(map[string]string) + + // There's technically a test bug here where the sessionVars won't match + // the session vars in the test which this MockGlobalAccessor is assigned to. + // But if the test requires accurate sessionVars, it can do the following: + // + // vars := NewSessionVars() + // mock := NewMockGlobalAccessor() + // mock.SessionVars = vars + // vars.GlobalVarsAccessor = mock + + tmp.SessionVars = NewSessionVars() + + // Set all sysvars to the default value + for k, sv := range GetSysVars() { + tmp.vals[k] = sv.Value + } + return tmp } // GetGlobalSysVar implements GlobalVarAccessor.GetGlobalSysVar interface. func (m *MockGlobalAccessor) GetGlobalSysVar(name string) (string, error) { - v, ok := sysVars[name] + v, ok := m.vals[name] if ok { - return v.Value, nil + return v, nil } return "", nil } // SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface. -func (m *MockGlobalAccessor) SetGlobalSysVar(name string, value string) error { - panic("not supported") +func (m *MockGlobalAccessor) SetGlobalSysVar(name string, value string) (err error) { + sv := GetSysVar(name) + if sv == nil { + return ErrUnknownSystemVar.GenWithStackByArgs(name) + } + if value, err = sv.Validate(m.SessionVars, value, ScopeGlobal); err != nil { + return err + } + if err = sv.SetGlobalFromHook(m.SessionVars, value, false); err != nil { + return err + } + m.vals[name] = value + return nil } // SetGlobalSysVarOnly implements GlobalVarAccessor.SetGlobalSysVarOnly interface. func (m *MockGlobalAccessor) SetGlobalSysVarOnly(name string, value string) error { - panic("not supported") + sv := GetSysVar(name) + if sv == nil { + return ErrUnknownSystemVar.GenWithStackByArgs(name) + } + m.vals[name] = value + return nil } // GetTiDBTableValue implements GlobalVarAccessor.GetTiDBTableValue interface. diff --git a/sessionctx/variable/sysvar_test.go b/sessionctx/variable/sysvar_test.go index 6a4f925b12044..029761467989a 100644 --- a/sessionctx/variable/sysvar_test.go +++ b/sessionctx/variable/sysvar_test.go @@ -464,11 +464,38 @@ func TestTiDBMultiStatementMode(t *testing.T) { func TestReadOnlyNoop(t *testing.T) { vars := NewSessionVars() + mock := NewMockGlobalAccessor() + mock.SessionVars = vars + vars.GlobalVarsAccessor = mock + noopFuncs := GetSysVar(TiDBEnableNoopFuncs) + + // For session scope for _, name := range []string{TxReadOnly, TransactionReadOnly} { sv := GetSysVar(name) val, err := sv.Validate(vars, "on", ScopeSession) require.Equal(t, "[variable:1235]function READ ONLY has only noop implementation in tidb now, use tidb_enable_noop_functions to enable these functions", err.Error()) require.Equal(t, "OFF", val) + + require.NoError(t, noopFuncs.SetSessionFromHook(vars, "ON")) + _, err = sv.Validate(vars, "on", ScopeSession) + require.NoError(t, err) + require.NoError(t, noopFuncs.SetSessionFromHook(vars, "OFF")) // restore default. 
+ } + + // For global scope + for _, name := range []string{TxReadOnly, TransactionReadOnly, OfflineMode, SuperReadOnly, ReadOnly} { + sv := GetSysVar(name) + val, err := sv.Validate(vars, "on", ScopeGlobal) + if name == OfflineMode { + require.Equal(t, "[variable:1235]function OFFLINE MODE has only noop implementation in tidb now, use tidb_enable_noop_functions to enable these functions", err.Error()) + } else { + require.Equal(t, "[variable:1235]function READ ONLY has only noop implementation in tidb now, use tidb_enable_noop_functions to enable these functions", err.Error()) + } + require.Equal(t, "OFF", val) + require.NoError(t, vars.GlobalVarsAccessor.SetGlobalSysVar(TiDBEnableNoopFuncs, "ON")) + _, err = sv.Validate(vars, "on", ScopeGlobal) + require.NoError(t, err) + require.NoError(t, vars.GlobalVarsAccessor.SetGlobalSysVar(TiDBEnableNoopFuncs, "OFF")) } } @@ -605,6 +632,7 @@ func TestInstanceScopedVars(t *testing.T) { // The default values should also be normalized for consistency. func TestDefaultValuesAreSettable(t *testing.T) { vars := NewSessionVars() + vars.GlobalVarsAccessor = NewMockGlobalAccessor() for _, sv := range GetSysVars() { if sv.HasSessionScope() && !sv.ReadOnly { val, err := sv.Validate(vars, sv.Value, ScopeSession) @@ -613,11 +641,6 @@ func TestDefaultValuesAreSettable(t *testing.T) { } if sv.HasGlobalScope() && !sv.ReadOnly { - if sv.Name == TiDBEnableNoopFuncs { - // TODO: this requires access to the global var accessor, - // which is not available in this test. - continue - } val, err := sv.Validate(vars, sv.Value, ScopeGlobal) require.Equal(t, val, sv.Value) require.NoError(t, err) From b8d85beb0ddaa55b417a9af66b3facce44af61d4 Mon Sep 17 00:00:00 2001 From: Mattias Jonsson Date: Thu, 23 Sep 2021 00:52:46 +0200 Subject: [PATCH 06/13] ddl: add schema placement rules (#27969) --- ddl/db_partition_test.go | 2 +- ddl/ddl.go | 2 +- ddl/ddl_api.go | 173 +++++++++++++++++++------------------- ddl/placement_sql_test.go | 85 +++++++++++++++++++ domain/domain_test.go | 6 +- executor/ddl.go | 31 +++++-- executor/show.go | 11 ++- session/session_test.go | 2 +- 8 files changed, 208 insertions(+), 104 deletions(-) diff --git a/ddl/db_partition_test.go b/ddl/db_partition_test.go index a7722398efc89..afbbe9e60ed7e 100644 --- a/ddl/db_partition_test.go +++ b/ddl/db_partition_test.go @@ -1203,7 +1203,7 @@ func (s *testIntegrationSuite5) TestAlterTableDropPartitionByListColumns(c *C) { );`) tk.MustExec(`insert into t values (1,'a'),(3,'a'),(5,'a'),(null,null)`) tk.MustExec(`alter table t drop partition p1`) - tk.MustQuery("select * from t").Check(testkit.Rows("1 a", "5 a", " ")) + tk.MustQuery("select * from t").Sort().Check(testkit.Rows("1 a", "5 a", " ")) ctx := tk.Se.(sessionctx.Context) is := domain.GetDomain(ctx).InfoSchema() tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) diff --git a/ddl/ddl.go b/ddl/ddl.go index cb2f2dba92586..0511f8b278481 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -93,7 +93,7 @@ var ( // DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache. 
type DDL interface { - CreateSchema(ctx sessionctx.Context, name model.CIStr, charsetInfo *ast.CharsetOpt) error + CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt, directPlacementOpts *model.PlacementSettings, placementPolicyRef *model.PolicyRefInfo) error AlterSchema(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) error DropSchema(ctx sessionctx.Context, schema model.CIStr) error CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) error diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 4b47459203343..1361be646c442 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -72,8 +72,26 @@ const ( longBlobMaxLength = 4294967295 ) -func (d *ddl) CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt) error { +func (d *ddl) CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt, directPlacementOpts *model.PlacementSettings, placementPolicyRef *model.PolicyRefInfo) error { dbInfo := &model.DBInfo{Name: schema} + if directPlacementOpts != nil && placementPolicyRef != nil { + return errors.Trace(ErrPlacementPolicyWithDirectOption.GenWithStackByArgs(placementPolicyRef.Name)) + } + if directPlacementOpts != nil { + // check the direct placement option compatibility. + if err := checkPolicyValidation(directPlacementOpts); err != nil { + return errors.Trace(err) + } + dbInfo.DirectPlacementOpts = directPlacementOpts + } + if placementPolicyRef != nil { + // placement policy reference will override the direct placement options. + policy, ok := ctx.GetInfoSchema().(infoschema.InfoSchema).PolicyByName(placementPolicyRef.Name) + if !ok { + return errors.Trace(infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(placementPolicyRef.Name)) + } + dbInfo.PlacementPolicyRef = &model.PolicyRefInfo{ID: policy.ID, Name: placementPolicyRef.Name} + } if charsetInfo != nil { chs, coll, err := ResolveCharsetCollation(ast.CharsetOpt{Chs: charsetInfo.Chs, Col: charsetInfo.Col}) if err != nil { @@ -157,6 +175,7 @@ func (d *ddl) AlterSchema(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) ( return ErrConflictingDeclarations.GenWithStackByArgs(toCharset, info.CharsetName) } toCollate = info.Name + } } if toCollate == "" { @@ -1705,13 +1724,13 @@ func buildTableInfoWithLike(ctx sessionctx.Context, ident ast.Ident, referTblInf // BuildTableInfoFromAST builds model.TableInfo from a SQL statement. // Note: TableID and PartitionID are left as uninitialized value. func BuildTableInfoFromAST(s *ast.CreateTableStmt) (*model.TableInfo, error) { - return buildTableInfoWithCheck(mock.NewContext(), s, mysql.DefaultCharset, "") + return buildTableInfoWithCheck(mock.NewContext(), s, mysql.DefaultCharset, "", nil, nil) } // buildTableInfoWithCheck builds model.TableInfo from a SQL statement. // Note: TableID and PartitionIDs are left as uninitialized value. 
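With this change every CreateSchema caller passes placement explicitly. A minimal sketch of the new call shape, assuming `dom` is a *domain.Domain and `ctx` a session context as in the tests further down; the schema and policy names are illustrative:

	cs := &ast.CharsetOpt{Chs: "utf8mb4", Col: "utf8mb4_bin"}
	// Direct placement options and a policy reference are mutually exclusive;
	// pass nil for whichever is unused (nil, nil keeps the old behaviour).
	ref := &model.PolicyRefInfo{Name: model.NewCIStr("p1")}
	err := dom.DDL().CreateSchema(ctx, model.NewCIStr("db1"), cs, nil, ref)

CreateSchema resolves the referenced policy by name and stores its ID, so callers only fill in Name, exactly as executeCreateDatabase does later in this patch.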
-func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string) (*model.TableInfo, error) {
-	tbInfo, err := buildTableInfoWithStmt(ctx, s, dbCharset, dbCollate)
+func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo, directPlacementOpts *model.PlacementSettings) (*model.TableInfo, error) {
+	tbInfo, err := buildTableInfoWithStmt(ctx, s, dbCharset, dbCollate, placementPolicyRef, directPlacementOpts)
 	if err != nil {
 		return nil, err
 	}
@@ -1728,7 +1747,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbC
 }
 
 // BuildSessionTemporaryTableInfo builds model.TableInfo from a SQL statement.
-func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSchema, s *ast.CreateTableStmt, dbCharset, dbCollate string) (*model.TableInfo, error) {
+func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSchema, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo, directPlacementOpts *model.PlacementSettings) (*model.TableInfo, error) {
 	ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name}
 	//build tableInfo
 	var tbInfo *model.TableInfo
@@ -1746,13 +1765,13 @@ func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSc
 		}
 		tbInfo, err = buildTableInfoWithLike(ctx, ident, referTbl.Meta(), s)
 	} else {
-		tbInfo, err = buildTableInfoWithCheck(ctx, s, dbCharset, dbCollate)
+		tbInfo, err = buildTableInfoWithCheck(ctx, s, dbCharset, dbCollate, placementPolicyRef, directPlacementOpts)
 	}
 	return tbInfo, err
 }
 
 // buildTableInfoWithStmt builds model.TableInfo from a SQL statement without validity check
-func buildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string) (*model.TableInfo, error) {
+func buildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo, directPlacementOpts *model.PlacementSettings) (*model.TableInfo, error) {
 	colDefs := s.Cols
 	tableCharset, tableCollate, err := getCharsetAndCollateInTableOption(0, s.Options)
 	if err != nil {
@@ -1789,14 +1808,25 @@ func buildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCh
 		return nil, errors.Trace(err)
 	}
 
-	err = buildTablePartitionInfo(ctx, s.Partition, tbInfo)
-	if err != nil {
+	if err = handleTableOptions(s.Options, tbInfo); err != nil {
 		return nil, errors.Trace(err)
 	}
 
-	if err = handleTableOptions(s.Options, tbInfo); err != nil {
+	if tbInfo.PlacementPolicyRef == nil && tbInfo.DirectPlacementOpts == nil {
+		// Set the defaults from Schema. Note: they are mutually exclusive!
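+		// Explicit table-level placement collected by handleTableOptions above is
+		// kept as-is; the schema defaults apply only when the table sets neither.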
+ if placementPolicyRef != nil { + tbInfo.PlacementPolicyRef = placementPolicyRef + } else if directPlacementOpts != nil { + tbInfo.DirectPlacementOpts = directPlacementOpts + } + } + + // After handleTableOptions, so the partitions can get defaults from Table level + err = buildTablePartitionInfo(ctx, s.Partition, tbInfo) + if err != nil { return nil, errors.Trace(err) } + return tbInfo, nil } @@ -1846,7 +1876,7 @@ func (d *ddl) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err e if s.ReferTable != nil { tbInfo, err = buildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) } else { - tbInfo, err = buildTableInfoWithStmt(ctx, s, schema.Charset, schema.Collate) + tbInfo, err = buildTableInfoWithStmt(ctx, s, schema.Charset, schema.Collate, schema.PlacementPolicyRef, schema.DirectPlacementOpts) } if err != nil { return errors.Trace(err) @@ -2252,6 +2282,37 @@ func (d *ddl) handleAutoIncID(tbInfo *model.TableInfo, schemaID int64, newEnd in return nil } +// SetDirectPlacementOpt tries to make the PlacementSettings assignments generic for Schema/Table/Partition +func SetDirectPlacementOpt(placementSettings *model.PlacementSettings, placementOptionType ast.PlacementOptionType, stringVal string, uintVal uint64) error { + switch placementOptionType { + case ast.PlacementOptionPrimaryRegion: + placementSettings.PrimaryRegion = stringVal + case ast.PlacementOptionRegions: + placementSettings.Regions = stringVal + case ast.PlacementOptionFollowerCount: + placementSettings.Followers = uintVal + case ast.PlacementOptionVoterCount: + placementSettings.Voters = uintVal + case ast.PlacementOptionLearnerCount: + placementSettings.Learners = uintVal + case ast.PlacementOptionSchedule: + placementSettings.Schedule = stringVal + case ast.PlacementOptionConstraints: + placementSettings.Constraints = stringVal + case ast.PlacementOptionLeaderConstraints: + placementSettings.LeaderConstraints = stringVal + case ast.PlacementOptionLearnerConstraints: + placementSettings.LearnerConstraints = stringVal + case ast.PlacementOptionFollowerConstraints: + placementSettings.FollowerConstraints = stringVal + case ast.PlacementOptionVoterConstraints: + placementSettings.VoterConstraints = stringVal + default: + return errors.Trace(errors.New("unknown placement policy option")) + } + return nil +} + // handleTableOptions updates tableInfo according to table options. 
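Both handleTableOptions below and buildPolicyInfo further down now funnel every placement option through SetDirectPlacementOpt, so the per-option switch lives in one place. A short sketch of the calling convention (values illustrative): string options consume stringVal, count options consume uintVal, and the other argument is ignored:

	opts := &model.PlacementSettings{}
	// PRIMARY_REGION="nl": a string option, so uintVal is ignored.
	if err := SetDirectPlacementOpt(opts, ast.PlacementOptionPrimaryRegion, "nl", 0); err != nil {
		return err
	}
	// FOLLOWERS=3: a count option, so stringVal is ignored.
	if err := SetDirectPlacementOpt(opts, ast.PlacementOptionFollowerCount, "", 3); err != nil {
		return err
	}

An unrecognized option type yields the "unknown placement policy option" error, which both callers simply propagate.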
func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) error { for _, op := range options { @@ -2296,61 +2357,19 @@ func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) err tbInfo.PlacementPolicyRef = &model.PolicyRefInfo{ Name: model.NewCIStr(op.StrValue), } - case ast.TableOptionPlacementPrimaryRegion: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.PrimaryRegion = op.StrValue - case ast.TableOptionPlacementRegions: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.Regions = op.StrValue - case ast.TableOptionPlacementFollowerCount: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.Followers = op.UintValue - case ast.TableOptionPlacementVoterCount: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.Voters = op.UintValue - case ast.TableOptionPlacementLearnerCount: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.Learners = op.UintValue - case ast.TableOptionPlacementSchedule: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.Schedule = op.StrValue - case ast.TableOptionPlacementConstraints: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.Constraints = op.StrValue - case ast.TableOptionPlacementLeaderConstraints: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.LeaderConstraints = op.StrValue - case ast.TableOptionPlacementLearnerConstraints: + case ast.TableOptionPlacementPrimaryRegion, ast.TableOptionPlacementRegions, + ast.TableOptionPlacementFollowerCount, ast.TableOptionPlacementVoterCount, + ast.TableOptionPlacementLearnerCount, ast.TableOptionPlacementSchedule, + ast.TableOptionPlacementConstraints, ast.TableOptionPlacementLeaderConstraints, + ast.TableOptionPlacementLearnerConstraints, ast.TableOptionPlacementFollowerConstraints, + ast.TableOptionPlacementVoterConstraints: if tbInfo.DirectPlacementOpts == nil { tbInfo.DirectPlacementOpts = &model.PlacementSettings{} } - tbInfo.DirectPlacementOpts.LearnerConstraints = op.StrValue - case ast.TableOptionPlacementFollowerConstraints: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} - } - tbInfo.DirectPlacementOpts.FollowerConstraints = op.StrValue - case ast.TableOptionPlacementVoterConstraints: - if tbInfo.DirectPlacementOpts == nil { - tbInfo.DirectPlacementOpts = &model.PlacementSettings{} + err := SetDirectPlacementOpt(tbInfo.DirectPlacementOpts, ast.PlacementOptionType(op.Tp), op.StrValue, op.UintValue) + if err != nil { + return err } - tbInfo.DirectPlacementOpts.VoterConstraints = op.StrValue } } shardingBits := shardingBits(tbInfo) @@ -5951,7 +5970,7 @@ func (d *ddl) RepairTable(ctx sessionctx.Context, table *ast.TableName, createSt } // It is necessary to specify the table.ID and partition.ID manually. 
- newTableInfo, err := buildTableInfoWithCheck(ctx, createStmt, oldTableInfo.Charset, oldTableInfo.Collate) + newTableInfo, err := buildTableInfoWithCheck(ctx, createStmt, oldTableInfo.Charset, oldTableInfo.Collate, oldTableInfo.PlacementPolicyRef, oldTableInfo.DirectPlacementOpts) if err != nil { return errors.Trace(err) } @@ -6302,31 +6321,9 @@ func buildPolicyInfo(name model.CIStr, options []*ast.PlacementOption) (*model.P policyInfo := &model.PolicyInfo{PlacementSettings: &model.PlacementSettings{}} policyInfo.Name = name for _, opt := range options { - switch opt.Tp { - case ast.PlacementOptionPrimaryRegion: - policyInfo.PrimaryRegion = opt.StrValue - case ast.PlacementOptionRegions: - policyInfo.Regions = opt.StrValue - case ast.PlacementOptionFollowerCount: - policyInfo.Followers = opt.UintValue - case ast.PlacementOptionVoterCount: - policyInfo.Voters = opt.UintValue - case ast.PlacementOptionLearnerCount: - policyInfo.Learners = opt.UintValue - case ast.PlacementOptionSchedule: - policyInfo.Schedule = opt.StrValue - case ast.PlacementOptionConstraints: - policyInfo.Constraints = opt.StrValue - case ast.PlacementOptionLearnerConstraints: - policyInfo.LearnerConstraints = opt.StrValue - case ast.PlacementOptionFollowerConstraints: - policyInfo.FollowerConstraints = opt.StrValue - case ast.PlacementOptionVoterConstraints: - policyInfo.VoterConstraints = opt.StrValue - case ast.PlacementOptionLeaderConstraints: - policyInfo.LeaderConstraints = opt.StrValue - default: - return nil, errors.Trace(errors.New("unknown placement policy option")) + err := SetDirectPlacementOpt(policyInfo.PlacementSettings, opt.Tp, opt.StrValue, opt.UintValue) + if err != nil { + return nil, err } } return policyInfo, nil diff --git a/ddl/placement_sql_test.go b/ddl/placement_sql_test.go index 8fc2bb00cfa6d..c3ef8e1b3aab3 100644 --- a/ddl/placement_sql_test.go +++ b/ddl/placement_sql_test.go @@ -717,3 +717,88 @@ alter placement policy testFunc(testcase.name, testcase.hook, testcase.expectErr) } } + +func (s *testDBSuite6) TestCreateSchemaWithPlacement(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("drop schema if exists SchemaDirectPlacementTest") + tk.MustExec("drop schema if exists SchemaPolicyPlacementTest") + tk.Se.GetSessionVars().EnableAlterPlacement = true + defer func() { + tk.MustExec("drop schema if exists SchemaDirectPlacementTest") + tk.MustExec("drop schema if exists SchemaPolicyPlacementTest") + tk.MustExec("drop placement policy if exists PolicySchemaTest") + tk.MustExec("drop placement policy if exists PolicyTableTest") + tk.Se.GetSessionVars().EnableAlterPlacement = false + }() + + tk.MustExec(`CREATE SCHEMA SchemaDirectPlacementTest PRIMARY_REGION='nl' REGIONS = "se,nz" FOLLOWERS=3`) + tk.MustQuery("SHOW CREATE SCHEMA schemadirectplacementtest").Check(testkit.Rows("SchemaDirectPlacementTest CREATE DATABASE `SchemaDirectPlacementTest` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ PRIMARY_REGION=\"nl\" REGIONS=\"se,nz\" FOLLOWERS=3")) + + tk.MustExec(`CREATE PLACEMENT POLICY PolicySchemaTest LEADER_CONSTRAINTS = "[+region=nl]" FOLLOWER_CONSTRAINTS="[+region=se]" FOLLOWERS=4 LEARNER_CONSTRAINTS="[+region=be]" LEARNERS=4`) + tk.MustExec(`CREATE PLACEMENT POLICY PolicyTableTest LEADER_CONSTRAINTS = "[+region=tl]" FOLLOWER_CONSTRAINTS="[+region=tf]" FOLLOWERS=2 LEARNER_CONSTRAINTS="[+region=tle]" LEARNERS=1`) + tk.MustQuery("SHOW PLACEMENT like 'POLICY %PolicySchemaTest%'").Check(testkit.Rows("POLICY PolicySchemaTest LEADER_CONSTRAINTS=\"[+region=nl]\" FOLLOWERS=4 
FOLLOWER_CONSTRAINTS=\"[+region=se]\" LEARNERS=4 LEARNER_CONSTRAINTS=\"[+region=be]\" SCHEDULED")) + tk.MustQuery("SHOW PLACEMENT like 'POLICY %PolicyTableTest%'").Check(testkit.Rows("POLICY PolicyTableTest LEADER_CONSTRAINTS=\"[+region=tl]\" FOLLOWERS=2 FOLLOWER_CONSTRAINTS=\"[+region=tf]\" LEARNERS=1 LEARNER_CONSTRAINTS=\"[+region=tle]\" SCHEDULED")) + tk.MustExec("CREATE SCHEMA SchemaPolicyPlacementTest PLACEMENT POLICY = `PolicySchemaTest`") + tk.MustQuery("SHOW CREATE SCHEMA SCHEMAPOLICYPLACEMENTTEST").Check(testkit.Rows("SchemaPolicyPlacementTest CREATE DATABASE `SchemaPolicyPlacementTest` /*!40100 DEFAULT CHARACTER SET utf8mb4 */ PLACEMENT POLICY = `PolicySchemaTest`")) + + tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UseSchemaDefault (a int unsigned primary key, b varchar(255))`) + tk.MustQuery(`SHOW CREATE TABLE SchemaDirectPlacementTest.UseSchemaDefault`).Check(testkit.Rows( + "UseSchemaDefault CREATE TABLE `UseSchemaDefault` (\n" + + " `a` int(10) unsigned NOT NULL,\n" + + " `b` varchar(255) DEFAULT NULL,\n" + + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"nl\" REGIONS=\"se,nz\" FOLLOWERS=3 */")) + tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UseDirectPlacement (a int unsigned primary key, b varchar(255)) PRIMARY_REGION="se"`) + tk.MustQuery(`SHOW CREATE TABLE SchemaDirectPlacementTest.UseDirectPlacement`).Check(testkit.Rows( + "UseDirectPlacement CREATE TABLE `UseDirectPlacement` (\n" + + " `a` int(10) unsigned NOT NULL,\n" + + " `b` varchar(255) DEFAULT NULL,\n" + + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"se\" */")) + tk.MustExec(`CREATE TABLE SchemaDirectPlacementTest.UsePolicy (a int unsigned primary key, b varchar(255)) PLACEMENT POLICY = "PolicyTableTest"`) + tk.MustQuery(`SHOW CREATE TABLE SchemaDirectPlacementTest.UsePolicy`).Check(testkit.Rows( + "UsePolicy CREATE TABLE `UsePolicy` (\n" + + " `a` int(10) unsigned NOT NULL,\n" + + " `b` varchar(255) DEFAULT NULL,\n" + + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`PolicyTableTest` */")) + + tk.MustExec(`CREATE TABLE SchemaPolicyPlacementTest.UseSchemaDefault (a int unsigned primary key, b varchar(255))`) + tk.MustQuery(`SHOW CREATE TABLE SchemaPolicyPlacementTest.UseSchemaDefault`).Check(testkit.Rows( + "UseSchemaDefault CREATE TABLE `UseSchemaDefault` (\n" + + " `a` int(10) unsigned NOT NULL,\n" + + " `b` varchar(255) DEFAULT NULL,\n" + + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`PolicySchemaTest` */")) + tk.MustExec(`CREATE TABLE SchemaPolicyPlacementTest.UseDirectPlacement (a int unsigned primary key, b varchar(255)) PRIMARY_REGION="se"`) + tk.MustQuery(`SHOW CREATE TABLE SchemaPolicyPlacementTest.UseDirectPlacement`).Check(testkit.Rows( + "UseDirectPlacement CREATE TABLE `UseDirectPlacement` (\n" + + " `a` int(10) unsigned NOT NULL,\n" + + " `b` varchar(255) DEFAULT NULL,\n" + + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PRIMARY_REGION=\"se\" */")) + tk.MustExec(`CREATE TABLE SchemaPolicyPlacementTest.UsePolicy (a int unsigned primary key, b varchar(255)) 
PLACEMENT POLICY = "PolicyTableTest"`) + tk.MustQuery(`SHOW CREATE TABLE SchemaPolicyPlacementTest.UsePolicy`).Check(testkit.Rows( + "UsePolicy CREATE TABLE `UsePolicy` (\n" + + " `a` int(10) unsigned NOT NULL,\n" + + " `b` varchar(255) DEFAULT NULL,\n" + + " PRIMARY KEY (`a`) /*T![clustered_index] CLUSTERED */\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`PolicyTableTest` */")) + + is := s.dom.InfoSchema() + + db, ok := is.SchemaByName(model.NewCIStr("SchemaDirectPlacementTest")) + c.Assert(ok, IsTrue) + c.Assert(db.PlacementPolicyRef, IsNil) + c.Assert(db.DirectPlacementOpts, NotNil) + c.Assert(db.DirectPlacementOpts.PrimaryRegion, Matches, "nl") + c.Assert(db.DirectPlacementOpts.Regions, Matches, "se,nz") + c.Assert(db.DirectPlacementOpts.Followers, Equals, uint64(3)) + c.Assert(db.DirectPlacementOpts.Learners, Equals, uint64(0)) + + db, ok = is.SchemaByName(model.NewCIStr("SchemaPolicyPlacementTest")) + c.Assert(ok, IsTrue) + c.Assert(db.PlacementPolicyRef, NotNil) + c.Assert(db.DirectPlacementOpts, IsNil) + c.Assert(db.PlacementPolicyRef.Name.O, Equals, "PolicySchemaTest") +} diff --git a/domain/domain_test.go b/domain/domain_test.go index c4107bd7a9601..77fa19f5fe3ac 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -133,7 +133,7 @@ func SubTestInfo(t *testing.T) { Col: "utf8_bin", } ctx := mock.NewContext() - require.NoError(t, dom.ddl.CreateSchema(ctx, model.NewCIStr("aaa"), cs)) + require.NoError(t, dom.ddl.CreateSchema(ctx, model.NewCIStr("aaa"), cs, nil, nil)) require.NoError(t, dom.Reload()) require.Equal(t, int64(1), dom.InfoSchema().SchemaMetaVersion()) @@ -175,7 +175,7 @@ func SubTestDomain(t *testing.T) { Chs: "utf8", Col: "utf8_bin", } - err = dd.CreateSchema(ctx, model.NewCIStr("aaa"), cs) + err = dd.CreateSchema(ctx, model.NewCIStr("aaa"), cs, nil, nil) require.NoError(t, err) // Test for fetchSchemasWithTables when "tables" isn't nil. @@ -232,7 +232,7 @@ func SubTestDomain(t *testing.T) { require.Equal(t, tblInfo2, tbl.Meta()) // Test for tryLoadSchemaDiffs when "isTooOldSchema" is false. 
- err = dd.CreateSchema(ctx, model.NewCIStr("bbb"), cs) + err = dd.CreateSchema(ctx, model.NewCIStr("bbb"), cs, nil, nil) require.NoError(t, err) err = dom.Reload() diff --git a/executor/ddl.go b/executor/ddl.go index 975db3aa23322..0e94e697d9b43 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -247,19 +247,38 @@ func (e *DDLExec) executeRenameTable(s *ast.RenameTableStmt) error { } func (e *DDLExec) executeCreateDatabase(s *ast.CreateDatabaseStmt) error { - var opt *ast.CharsetOpt + var charOpt *ast.CharsetOpt + var directPlacementOpts *model.PlacementSettings + var placementPolicyRef *model.PolicyRefInfo if len(s.Options) != 0 { - opt = &ast.CharsetOpt{} + charOpt = &ast.CharsetOpt{} for _, val := range s.Options { switch val.Tp { case ast.DatabaseOptionCharset: - opt.Chs = val.Value + charOpt.Chs = val.Value case ast.DatabaseOptionCollate: - opt.Col = val.Value + charOpt.Col = val.Value + case ast.DatabaseOptionPlacementPrimaryRegion, ast.DatabaseOptionPlacementRegions, + ast.DatabaseOptionPlacementFollowerCount, ast.DatabaseOptionPlacementLeaderConstraints, + ast.DatabaseOptionPlacementLearnerCount, ast.DatabaseOptionPlacementVoterCount, + ast.DatabaseOptionPlacementSchedule, ast.DatabaseOptionPlacementConstraints, + ast.DatabaseOptionPlacementFollowerConstraints, ast.DatabaseOptionPlacementVoterConstraints, + ast.DatabaseOptionPlacementLearnerConstraints: + if directPlacementOpts == nil { + directPlacementOpts = &model.PlacementSettings{} + } + err := ddl.SetDirectPlacementOpt(directPlacementOpts, ast.PlacementOptionType(val.Tp), val.Value, val.UintValue) + if err != nil { + return err + } + case ast.DatabaseOptionPlacementPolicy: + placementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(val.Value), + } } } } - err := domain.GetDomain(e.ctx).DDL().CreateSchema(e.ctx, model.NewCIStr(s.Name), opt) + err := domain.GetDomain(e.ctx).DDL().CreateSchema(e.ctx, model.NewCIStr(s.Name), charOpt, directPlacementOpts, placementPolicyRef) if err != nil { if infoschema.ErrDatabaseExists.Equal(err) && s.IfNotExists { err = nil @@ -295,7 +314,7 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error { return err } - tbInfo, err := ddl.BuildSessionTemporaryTableInfo(e.ctx, is, s, dbInfo.Charset, dbInfo.Collate) + tbInfo, err := ddl.BuildSessionTemporaryTableInfo(e.ctx, is, s, dbInfo.Charset, dbInfo.Collate, dbInfo.PlacementPolicyRef, dbInfo.DirectPlacementOpts) if err != nil { return err } diff --git a/executor/show.go b/executor/show.go index 824c07e545ec5..88ad555306274 100644 --- a/executor/show.go +++ b/executor/show.go @@ -1314,9 +1314,7 @@ func ConstructResultOfShowCreateDatabase(ctx sessionctx.Context, dbInfo *model.D fmt.Fprintf(buf, "COLLATE %s ", dbInfo.Collate) } fmt.Fprint(buf, "*/") - return nil - } - if dbInfo.Collate != "" { + } else if dbInfo.Collate != "" { collInfo, err := collate.GetCollationByName(dbInfo.Collate) if err != nil { return errors.Trace(err) @@ -1326,10 +1324,15 @@ func ConstructResultOfShowCreateDatabase(ctx sessionctx.Context, dbInfo *model.D fmt.Fprintf(buf, "COLLATE %s ", dbInfo.Collate) } fmt.Fprint(buf, "*/") - return nil } // MySQL 5.7 always show the charset info but TiDB may ignore it, which makes a slight difference. We keep this // behavior unchanged because it is trivial enough. 
+ if dbInfo.DirectPlacementOpts != nil { + fmt.Fprintf(buf, " %s", dbInfo.DirectPlacementOpts) + } + if dbInfo.PlacementPolicyRef != nil { + fmt.Fprintf(buf, " PLACEMENT POLICY = %s", stringutil.Escape(dbInfo.PlacementPolicyRef.Name.O, sqlMode)) + } return nil } diff --git a/session/session_test.go b/session/session_test.go index 736bb1abf5250..4a147cb65f843 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -3974,7 +3974,7 @@ func (s *testSessionSerialSuite) TestDoDDLJobQuit(c *C) { defer failpoint.Disable("github.com/pingcap/tidb/ddl/storeCloseInLoop") // this DDL call will enter deadloop before this fix - err = dom.DDL().CreateSchema(se, model.NewCIStr("testschema"), nil) + err = dom.DDL().CreateSchema(se, model.NewCIStr("testschema"), nil, nil, nil) c.Assert(err.Error(), Equals, "context canceled") } From e5bf177cd91ad689176c5f647e1e37d5cd47899c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Thu, 23 Sep 2021 10:00:46 +0800 Subject: [PATCH 07/13] *: Use snapshot interceptor to handle temporary table's insert/update/delete (#28218) --- executor/batch_checker.go | 4 +-- executor/insert.go | 14 ++++---- executor/insert_common.go | 14 ++------ executor/replace.go | 17 +++++---- sessionctx/variable/session.go | 64 ---------------------------------- table/tables/index.go | 2 +- table/tables/tables.go | 2 +- 7 files changed, 20 insertions(+), 97 deletions(-) diff --git a/executor/batch_checker.go b/executor/batch_checker.go index ed0af152a2560..a5ef7efde964c 100644 --- a/executor/batch_checker.go +++ b/executor/batch_checker.go @@ -235,9 +235,9 @@ func formatDataForDupError(data []types.Datum) (string, error) { // getOldRow gets the table record row from storage for batch check. // t could be a normal table or a partition, but it must not be a PartitionedTable. -func getOldRow(ctx context.Context, sctx sessionctx.Context, kvGetter kv.Getter, t table.Table, handle kv.Handle, +func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, t table.Table, handle kv.Handle, genExprs []expression.Expression) ([]types.Datum, error) { - oldValue, err := kvGetter.Get(ctx, tablecodec.EncodeRecordKey(t.RecordPrefix(), handle)) + oldValue, err := txn.Get(ctx, tablecodec.EncodeRecordKey(t.RecordPrefix(), handle)) if err != nil { return nil, err } diff --git a/executor/insert.go b/executor/insert.go index c598232e895f1..f5b443387dd75 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -21,9 +21,8 @@ import ( "runtime/trace" "time" - "github.com/pingcap/parser/model" - "github.com/opentracing/opentracing-go" + "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" @@ -188,8 +187,8 @@ func (e *InsertValues) prefetchDataCache(ctx context.Context, txn kv.Transaction } // updateDupRow updates a duplicate row to a new row. 
-func (e *InsertExec) updateDupRow(ctx context.Context, idxInBatch int, kvGetter kv.Getter, row toBeCheckedRow, handle kv.Handle, onDuplicate []*expression.Assignment) error { - oldRow, err := getOldRow(ctx, e.ctx, kvGetter, row.t, handle, e.GenExprs) +func (e *InsertExec) updateDupRow(ctx context.Context, idxInBatch int, txn kv.Transaction, row toBeCheckedRow, handle kv.Handle, onDuplicate []*expression.Assignment) error { + oldRow, err := getOldRow(ctx, e.ctx, txn, row.t, handle, e.GenExprs) if err != nil { return err } @@ -237,7 +236,6 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D e.stats.Prefetch += time.Since(prefetchStart) } - txnValueGetter := e.txnValueGetter(txn) for i, r := range toBeCheckedRows { if r.handleKey != nil { handle, err := tablecodec.DecodeRowKey(r.handleKey.newKey) @@ -245,7 +243,7 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D return err } - err = e.updateDupRow(ctx, i, txnValueGetter, r, handle, e.OnDuplicate) + err = e.updateDupRow(ctx, i, txn, r, handle, e.OnDuplicate) if err == nil { continue } @@ -255,7 +253,7 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D } for _, uk := range r.uniqueKeys { - val, err := txnValueGetter.Get(ctx, uk.newKey) + val, err := txn.Get(ctx, uk.newKey) if err != nil { if kv.IsErrNotFound(err) { continue @@ -267,7 +265,7 @@ func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.D return err } - err = e.updateDupRow(ctx, i, txnValueGetter, r, handle, e.OnDuplicate) + err = e.updateDupRow(ctx, i, txn, r, handle, e.OnDuplicate) if err != nil { if kv.IsErrNotFound(err) { // Data index inconsistent? A unique key provide the handle information, but the diff --git a/executor/insert_common.go b/executor/insert_common.go index e77f1a7f2912c..72ed1f51b584f 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -1065,7 +1065,6 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D e.stats.Prefetch += time.Since(prefetchStart) } - txnValueGetter := e.txnValueGetter(txn) // append warnings and get no duplicated error rows for i, r := range toBeCheckedRows { if r.ignored { @@ -1073,7 +1072,7 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } skip := false if r.handleKey != nil { - _, err := txnValueGetter.Get(ctx, r.handleKey.newKey) + _, err := txn.Get(ctx, r.handleKey.newKey) if err == nil { e.ctx.GetSessionVars().StmtCtx.AppendWarning(r.handleKey.dupErr) continue @@ -1083,7 +1082,7 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } } for _, uk := range r.uniqueKeys { - _, err := txnValueGetter.Get(ctx, uk.newKey) + _, err := txn.Get(ctx, uk.newKey) if err == nil { // If duplicate keys were found in BatchGet, mark row = nil. 
e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) @@ -1112,15 +1111,6 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D return nil } -func (e *InsertValues) txnValueGetter(txn kv.Transaction) kv.Getter { - tblInfo := e.Table.Meta() - if tblInfo.TempTableType == model.TempTableNone { - return txn - } - - return e.ctx.GetSessionVars().TemporaryTableTxnReader(txn, tblInfo) -} - func (e *InsertValues) addRecord(ctx context.Context, row []types.Datum) error { return e.addRecordWithAutoIDHint(ctx, row, 0) } diff --git a/executor/replace.go b/executor/replace.go index 78fb519049d84..cf96ec99320bd 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -62,9 +62,9 @@ func (e *ReplaceExec) Open(ctx context.Context) error { // removeRow removes the duplicate row and cleanup its keys in the key-value map, // but if the to-be-removed row equals to the to-be-added row, no remove or add things to do. -func (e *ReplaceExec) removeRow(ctx context.Context, kvGetter kv.Getter, handle kv.Handle, r toBeCheckedRow) (bool, error) { +func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle kv.Handle, r toBeCheckedRow) (bool, error) { newRow := r.row - oldRow, err := getOldRow(ctx, e.ctx, kvGetter, r.t, handle, e.GenExprs) + oldRow, err := getOldRow(ctx, e.ctx, txn, r.t, handle, e.GenExprs) if err != nil { logutil.BgLogger().Error("get old row failed when replace", zap.String("handle", handle.String()), @@ -120,15 +120,14 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { return err } - txnValueGetter := e.txnValueGetter(txn) if r.handleKey != nil { handle, err := tablecodec.DecodeRowKey(r.handleKey.newKey) if err != nil { return err } - if _, err := txnValueGetter.Get(ctx, r.handleKey.newKey); err == nil { - rowUnchanged, err := e.removeRow(ctx, txnValueGetter, handle, r) + if _, err := txn.Get(ctx, r.handleKey.newKey); err == nil { + rowUnchanged, err := e.removeRow(ctx, txn, handle, r) if err != nil { return err } @@ -144,7 +143,7 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { // Keep on removing duplicated rows. for { - rowUnchanged, foundDupKey, err := e.removeIndexRow(ctx, txnValueGetter, r) + rowUnchanged, foundDupKey, err := e.removeIndexRow(ctx, txn, r) if err != nil { return err } @@ -171,9 +170,9 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error { // 2. bool: true when found the duplicated key. This only means that duplicated key was found, // and the row was removed. // 3. error: the error. 
-func (e *ReplaceExec) removeIndexRow(ctx context.Context, kvGetter kv.Getter, r toBeCheckedRow) (bool, bool, error) { +func (e *ReplaceExec) removeIndexRow(ctx context.Context, txn kv.Transaction, r toBeCheckedRow) (bool, bool, error) { for _, uk := range r.uniqueKeys { - val, err := kvGetter.Get(ctx, uk.newKey) + val, err := txn.Get(ctx, uk.newKey) if err != nil { if kv.IsErrNotFound(err) { continue @@ -184,7 +183,7 @@ func (e *ReplaceExec) removeIndexRow(ctx context.Context, kvGetter kv.Getter, r if err != nil { return false, true, err } - rowUnchanged, err := e.removeRow(ctx, kvGetter, handle, r) + rowUnchanged, err := e.removeRow(ctx, txn, handle, r) if err != nil { return false, true, err } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 7ddccefbc1686..72c6df7d22174 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -16,7 +16,6 @@ package variable import ( "bytes" - "context" "crypto/tls" "encoding/binary" "fmt" @@ -2304,66 +2303,3 @@ func (s *SessionVars) GetSeekFactor(tbl *model.TableInfo) float64 { } return s.seekFactor } - -// TemporaryTableSnapshotReader can read the temporary table snapshot data -type TemporaryTableSnapshotReader struct { - temporaryTableData TemporaryTableData -} - -// Get gets the value for key k from snapshot. -func (s *TemporaryTableSnapshotReader) Get(ctx context.Context, k kv.Key) ([]byte, error) { - if s.temporaryTableData == nil { - return nil, kv.ErrNotExist - } - - v, err := s.temporaryTableData.Get(ctx, k) - if err != nil { - return v, err - } - - if len(v) == 0 { - return nil, kv.ErrNotExist - } - - return v, nil -} - -// TemporaryTableSnapshotReader can read the temporary table snapshot data -func (s *SessionVars) TemporaryTableSnapshotReader(tblInfo *model.TableInfo) *TemporaryTableSnapshotReader { - if tblInfo.TempTableType == model.TempTableGlobal { - return &TemporaryTableSnapshotReader{nil} - } - return &TemporaryTableSnapshotReader{s.TemporaryTableData} -} - -// TemporaryTableTxnReader can read the temporary table txn data -type TemporaryTableTxnReader struct { - memBuffer kv.MemBuffer - snapshot *TemporaryTableSnapshotReader -} - -// Get gets the value for key k from txn. 
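The readers removed in this hunk merged the transaction's membuffer with the session's temporary-table data by hand at every call site. The commit title says a snapshot interceptor now provides the same view underneath txn.Get; the interceptor itself is not part of these hunks, so the following is only a sketch of the semantics being relied on, with assumed names, where only the empty-value-as-tombstone behaviour is taken from the deleted code:

	// package temptable (assumed); imports: context, github.com/pingcap/tidb/kv
	//
	// tempTableSnapshot stands in for the real interceptor: keys belonging to a
	// temporary table are served from session-local data, everything else goes
	// to the intercepted TiKV snapshot. txn.Get keeps layering the membuffer on
	// top, as it always did.
	type tempTableSnapshot struct {
		tempData kv.Retriever // session temporary-table storage (assumed field)
		snapshot kv.Snapshot  // the wrapped TiKV snapshot
	}

	func (s *tempTableSnapshot) Get(ctx context.Context, k kv.Key) ([]byte, error) {
		if !isTemporaryTableKey(k) { // assumed helper
			return s.snapshot.Get(ctx, k)
		}
		v, err := s.tempData.Get(ctx, k)
		if err != nil {
			return nil, err // includes kv.ErrNotExist for missing keys
		}
		if len(v) == 0 {
			return nil, kv.ErrNotExist // an empty value marks a deleted key
		}
		return v, nil
	}

With the snapshot patched like this, updateDupRow, removeRow and friends can simply use the transaction, which is what the rest of this patch does.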
-func (s *TemporaryTableTxnReader) Get(ctx context.Context, k kv.Key) ([]byte, error) { - v, err := s.memBuffer.Get(ctx, k) - if err == nil { - if len(v) == 0 { - return nil, kv.ErrNotExist - } - - return v, nil - } - - if !kv.IsErrNotFound(err) { - return v, err - } - - return s.snapshot.Get(ctx, k) -} - -// TemporaryTableTxnReader can read the temporary table txn data -func (s *SessionVars) TemporaryTableTxnReader(txn kv.Transaction, tblInfo *model.TableInfo) *TemporaryTableTxnReader { - return &TemporaryTableTxnReader{ - memBuffer: txn.GetMemBuffer(), - snapshot: s.TemporaryTableSnapshotReader(tblInfo), - } -} diff --git a/table/tables/index.go b/table/tables/index.go index b592c894f74ad..652a4572be57d 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -202,7 +202,7 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue var value []byte if c.tblInfo.TempTableType != model.TempTableNone { // Always check key for temporary table because it does not write to TiKV - value, err = sctx.GetSessionVars().TemporaryTableTxnReader(txn, c.tblInfo).Get(ctx, key) + value, err = txn.Get(ctx, key) } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { value, err = txn.GetMemBuffer().Get(ctx, key) } else { diff --git a/table/tables/tables.go b/table/tables/tables.go index 1bbcdd4f3590f..20504dd21a4f1 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -791,7 +791,7 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . if (t.meta.IsCommonHandle || t.meta.PKIsHandle) && !skipCheck && !opt.SkipHandleCheck { if t.meta.TempTableType != model.TempTableNone { // Always check key for temporary table because it does not write to TiKV - _, err = sctx.GetSessionVars().TemporaryTableTxnReader(txn, t.meta).Get(ctx, key) + _, err = txn.Get(ctx, key) } else if sctx.GetSessionVars().LazyCheckKeyNotExists() { var v []byte v, err = txn.GetMemBuffer().Get(ctx, key) From c82c9d7068a7583fd538b9535d50a2b79372f9f8 Mon Sep 17 00:00:00 2001 From: Mattias Jonsson Date: Thu, 23 Sep 2021 04:14:45 +0200 Subject: [PATCH 08/13] *: Added server_debug to have a better experience when using a debugger (#26670) --- Makefile | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Makefile b/Makefile index f50128cabee39..9157e8d9dbcd5 100644 --- a/Makefile +++ b/Makefile @@ -134,6 +134,13 @@ else CGO_ENABLED=1 $(GOBUILD) $(RACE_FLAG) -ldflags '$(LDFLAGS) $(CHECK_FLAG)' -o '$(TARGET)' tidb-server/main.go endif +server_debug: +ifeq ($(TARGET), "") + CGO_ENABLED=1 $(GOBUILD) -gcflags="all=-N -l" $(RACE_FLAG) -ldflags '$(LDFLAGS) $(CHECK_FLAG)' -o bin/tidb-server-debug tidb-server/main.go +else + CGO_ENABLED=1 $(GOBUILD) -gcflags="all=-N -l" $(RACE_FLAG) -ldflags '$(LDFLAGS) $(CHECK_FLAG)' -o '$(TARGET)' tidb-server/main.go +endif + server_check: ifeq ($(TARGET), "") $(GOBUILD) $(RACE_FLAG) -ldflags '$(CHECK_LDFLAGS)' -o bin/tidb-server tidb-server/main.go From e72dc49361202e50ad031b56837aa43aa9db567a Mon Sep 17 00:00:00 2001 From: tison Date: Thu, 23 Sep 2021 10:38:46 +0800 Subject: [PATCH 09/13] workflow: backport-x.y.z to fixes-x.y.z (#28115) --- .github/workflows/bug-closed.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/bug-closed.yml b/.github/workflows/bug-closed.yml index 3458315e01816..36e8574f00e19 100644 --- a/.github/workflows/bug-closed.yml +++ b/.github/workflows/bug-closed.yml @@ -10,7 +10,7 @@ jobs: if: | contains(github.event.issue.labels.*.name, 'type/bug') && 
          !(contains(join(github.event.issue.labels.*.name, ', '), 'affects-') &&
-          contains(join(github.event.issue.labels.*.name, ', '), 'backport-'))
+          contains(join(github.event.issue.labels.*.name, ', '), 'fixes-'))
     runs-on: ubuntu-latest
     permissions:
       issues: write
@@ -25,5 +25,4 @@ jobs:
         with:
           issue-number: ${{ github.event.issue.number }}
           body: |
-            Please check whether the issue should be labeled with 'affects-x.y' or 'backport-x.y.z',
-            and then remove 'needs-more-info' label.
+            Please check whether the issue should be labeled with 'affects-x.y' or 'fixes-x.y.z', and then remove 'needs-more-info' label.

From 452d34c1ee2764768365f610f39711e9791505bb Mon Sep 17 00:00:00 2001
From: xufei
Date: Thu, 23 Sep 2021 11:42:46 +0800
Subject: [PATCH 10/13] planner/core: fix index out of bounds bug when empty dual table is removed for mpp query (#28251)

---
 executor/tiflash_test.go | 22 ++++++++++++++++++++++
 planner/core/stringer.go | 19 +++++++++++++++++--
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go
index b03f6f82ef3a2..12cc111b9a561 100644
--- a/executor/tiflash_test.go
+++ b/executor/tiflash_test.go
@@ -697,6 +697,28 @@ func (s *tiflashTestSuite) TestMppUnionAll(c *C) {
 
 }
 
+func (s *tiflashTestSuite) TestUnionWithEmptyDualTable(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t (a int not null, b int, c varchar(20))")
+	tk.MustExec("create table t1 (a int, b int not null, c double)")
+	tk.MustExec("alter table t set tiflash replica 1")
+	tk.MustExec("alter table t1 set tiflash replica 1")
+	tb := testGetTableByName(c, tk.Se, "test", "t")
+	err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
+	c.Assert(err, IsNil)
+	tb = testGetTableByName(c, tk.Se, "test", "t1")
+	err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
+	c.Assert(err, IsNil)
+	tk.MustExec("insert into t values(1,2,3)")
+	tk.MustExec("insert into t1 values(1,2,3)")
+	tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
+	tk.MustExec("set @@session.tidb_enforce_mpp=ON")
+	tk.MustQuery("select count(*) from (select a , b from t union all select a , c from t1 where false) tt").Check(testkit.Rows("1"))
+}
+
 func (s *tiflashTestSuite) TestMppApply(c *C) {
 	tk := testkit.NewTestKit(c, s.store)
 	tk.MustExec("use test")
diff --git a/planner/core/stringer.go b/planner/core/stringer.go
index 914cdc1dfe6c2..ce41a53bf66c8 100644
--- a/planner/core/stringer.go
+++ b/planner/core/stringer.go
@@ -28,10 +28,25 @@ func ToString(p Plan) string {
 	return strings.Join(strs, "->")
 }
 
+func needIncludeChildrenString(plan Plan) bool {
+	switch x := plan.(type) {
+	case *LogicalUnionAll, *PhysicalUnionAll, *LogicalPartitionUnionAll:
+		// after https://github.com/pingcap/tidb/pull/25218, the union may contain less than 2 children,
+		// but we still want to include its child plan's information when calling `toString` on union.
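+		// For example, in `select a, b from t union all select a, c from t1 where false`
+		// (TestUnionWithEmptyDualTable above), pruning the always-false branch can
+		// leave the union with a single child that should still be printed.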
+		return true
+	case LogicalPlan:
+		return len(x.Children()) > 1
+	case PhysicalPlan:
+		return len(x.Children()) > 1
+	default:
+		return false
+	}
+}
+
 func toString(in Plan, strs []string, idxs []int) ([]string, []int) {
 	switch x := in.(type) {
 	case LogicalPlan:
-		if len(x.Children()) > 1 {
+		if needIncludeChildrenString(in) {
 			idxs = append(idxs, len(strs))
 		}
 
@@ -40,7 +55,7 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) {
 		}
 	case *PhysicalExchangeReceiver: // do nothing
 	case PhysicalPlan:
-		if len(x.Children()) > 1 {
+		if needIncludeChildrenString(in) {
 			idxs = append(idxs, len(strs))
 		}
 
From 624f7cab3b4ee8f06e4a03159f8a4e2cac05a61a Mon Sep 17 00:00:00 2001
From: xufei
Date: Thu, 23 Sep 2021 13:22:46 +0800
Subject: [PATCH 11/13] copr: Fix bug that mpp node availability detection does not work in some corner cases (#28201)

---
 store/copr/batch_coprocessor.go | 106 ++++++++++++++++----------------
 1 file changed, 53 insertions(+), 53 deletions(-)

diff --git a/store/copr/batch_coprocessor.go b/store/copr/batch_coprocessor.go
index b00bfa738fd31..c50dd3f68fa5b 100644
--- a/store/copr/batch_coprocessor.go
+++ b/store/copr/batch_coprocessor.go
@@ -105,10 +105,11 @@ func (rs *batchCopResponse) RespTime() time.Duration {
 // if there is only 1 available store, then put the region to the related store
 // otherwise, use a greedy algorithm to put it into the store with highest weight
 func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks []*batchCopTask, mppStoreLastFailTime map[string]time.Time, ttl time.Duration) []*batchCopTask {
-	if len(originalTasks) <= 1 {
+	isMPP := mppStoreLastFailTime != nil
+	// for mpp, we still need to detect the store availability
+	if len(originalTasks) <= 1 && !isMPP {
 		return originalTasks
 	}
-	isMPP := mppStoreLastFailTime != nil
 	cache := kvStore.GetRegionCache()
 	storeTaskMap := make(map[uint64]*batchCopTask)
 	// storeCandidateRegionMap stores all the possible store->region map.
Its content is @@ -232,16 +233,28 @@ func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks [] } } } - if totalRemainingRegionNum == 0 { - return originalTasks - } - avgStorePerRegion := float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) - findNextStore := func(candidateStores []uint64) uint64 { - store := uint64(math.MaxUint64) - weightedRegionNum := math.MaxFloat64 - if candidateStores != nil { - for _, storeID := range candidateStores { + if totalRemainingRegionNum > 0 { + avgStorePerRegion := float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + findNextStore := func(candidateStores []uint64) uint64 { + store := uint64(math.MaxUint64) + weightedRegionNum := math.MaxFloat64 + if candidateStores != nil { + for _, storeID := range candidateStores { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + if store != uint64(math.MaxUint64) { + return store + } + } + for storeID := range storeTaskMap { if _, validStore := storeCandidateRegionMap[storeID]; !validStore { continue } @@ -251,57 +264,44 @@ func balanceBatchCopTask(ctx context.Context, kvStore *kvStore, originalTasks [] weightedRegionNum = num } } - if store != uint64(math.MaxUint64) { - return store - } + return store } - for storeID := range storeTaskMap { - if _, validStore := storeCandidateRegionMap[storeID]; !validStore { - continue + + store := findNextStore(nil) + for totalRemainingRegionNum > 0 { + if store == uint64(math.MaxUint64) { + break } - num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) - if num < weightedRegionNum { - store = storeID - weightedRegionNum = num + var key string + var ri RegionInfo + for key, ri = range storeCandidateRegionMap[store] { + // get the first region + break } - } - return store - } - - store := findNextStore(nil) - for totalRemainingRegionNum > 0 { - if store == uint64(math.MaxUint64) { - break - } - var key string - var ri RegionInfo - for key, ri = range storeCandidateRegionMap[store] { - // get the first region - break - } - storeTaskMap[store].regionInfos = append(storeTaskMap[store].regionInfos, ri) - totalRemainingRegionNum-- - for _, id := range ri.AllStores { - if _, ok := storeCandidateRegionMap[id]; ok { - delete(storeCandidateRegionMap[id], key) - totalRegionCandidateNum-- - if len(storeCandidateRegionMap[id]) == 0 { - delete(storeCandidateRegionMap, id) + storeTaskMap[store].regionInfos = append(storeTaskMap[store].regionInfos, ri) + totalRemainingRegionNum-- + for _, id := range ri.AllStores { + if _, ok := storeCandidateRegionMap[id]; ok { + delete(storeCandidateRegionMap[id], key) + totalRegionCandidateNum-- + if len(storeCandidateRegionMap[id]) == 0 { + delete(storeCandidateRegionMap, id) + } } } + if totalRemainingRegionNum > 0 { + avgStorePerRegion = float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + // it is not optimal because we only check the stores that affected by this region, in fact in order + // to find out the store with the lowest weightedRegionNum, all stores should be checked, but I think + // check only the affected stores is more simple and will get a good enough result + store = findNextStore(ri.AllStores) + } } if totalRemainingRegionNum > 0 { - avgStorePerRegion = 
float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) - // it is not optimal because we only check the stores that affected by this region, in fact in order - // to find out the store with the lowest weightedRegionNum, all stores should be checked, but I think - // check only the affected stores is more simple and will get a good enough result - store = findNextStore(ri.AllStores) + logutil.BgLogger().Warn("Some regions are not used when trying to balance batch cop task, give up balancing") + return originalTasks } } - if totalRemainingRegionNum > 0 { - logutil.BgLogger().Warn("Some regions are not used when trying to balance batch cop task, give up balancing") - return originalTasks - } var ret []*batchCopTask for _, task := range storeTaskMap { From ad9ec4e93ce815e774932e856bbede7aa4611b70 Mon Sep 17 00:00:00 2001 From: Mattias Jonsson Date: Thu, 23 Sep 2021 07:36:46 +0200 Subject: [PATCH 12/13] *: Add https/http depending on config for pdapi (#27695) --- domain/infosync/info.go | 21 ++++----------------- util/misc.go | 8 ++++++++ util/misc_test.go | 11 +++++++++++ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/domain/infosync/info.go b/domain/infosync/info.go index 266946ae1e51b..08180ece782de 100644 --- a/domain/infosync/info.go +++ b/domain/infosync/info.go @@ -24,7 +24,6 @@ import ( "os" "path" "strconv" - "strings" "sync/atomic" "time" @@ -325,13 +324,7 @@ func doRequest(ctx context.Context, addrs []string, route, method string, body i var req *http.Request var res *http.Response for _, addr := range addrs { - var url string - if strings.HasPrefix(addr, "http") { - url = fmt.Sprintf("%s%s", addr, route) - } else { - url = fmt.Sprintf("%s://%s%s", util2.InternalHTTPSchema(), addr, route) - } - + url := util2.ComposeURL(addr, route) req, err = http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err @@ -697,15 +690,9 @@ func (is *InfoSyncer) getPrometheusAddr() (string, error) { if !clientAvailable || len(pdAddrs) == 0 { return "", errors.Errorf("pd unavailable") } - // Get prometheus address from pdApi. - var url, res string - if strings.HasPrefix(pdAddrs[0], "http://") { - url = fmt.Sprintf("%s%s", pdAddrs[0], pdapi.Config) - } else { - url = fmt.Sprintf("http://%s%s", pdAddrs[0], pdapi.Config) - } - resp, err := http.Get(url) // #nosec G107 + url := util2.ComposeURL(pdAddrs[0], pdapi.Config) + resp, err := util2.InternalHTTPClient().Get(url) if err != nil { return "", err } @@ -715,7 +702,7 @@ func (is *InfoSyncer) getPrometheusAddr() (string, error) { if err != nil { return "", err } - res = metricStorage.PDServer.MetricStorage + res := metricStorage.PDServer.MetricStorage // Get prometheus address from etcdApi. 
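getPrometheusAddr above now goes through ComposeURL and the shared internal HTTP client, so the scheme follows the cluster's TLS configuration instead of being hard-coded to http. The helper is added to util/misc.go in the next hunk; a quick sketch of its behaviour (address and path illustrative), assuming no TLS is configured so that InternalHTTPSchema() returns "http":

	ComposeURL("pd0:2379", "/pd/api/v1/config")         // "http://pd0:2379/pd/api/v1/config"
	ComposeURL("https://pd0:2379", "/pd/api/v1/config") // scheme kept: "https://pd0:2379/pd/api/v1/config"

The new TestComposeURL cases below pin down the exact strings.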
if res == "" { diff --git a/util/misc.go b/util/misc.go index d877f259d4347..efc2068ed0f03 100644 --- a/util/misc.go +++ b/util/misc.go @@ -584,6 +584,14 @@ func initInternalClient() { } } +// ComposeURL adds HTTP schema if missing and concats address with path +func ComposeURL(address, path string) string { + if strings.HasPrefix(address, "http://") || strings.HasPrefix(address, "https://") { + return fmt.Sprintf("%s%s", address, path) + } + return fmt.Sprintf("%s://%s%s", InternalHTTPSchema(), address, path) +} + // GetLocalIP will return a local IP(non-loopback, non 0.0.0.0), if there is one func GetLocalIP() string { addrs, err := net.InterfaceAddrs() diff --git a/util/misc_test.go b/util/misc_test.go index fdd6e2b5ca341..1bc1a46e4e808 100644 --- a/util/misc_test.go +++ b/util/misc_test.go @@ -193,3 +193,14 @@ func TestToPB(t *testing.T) { assert.Equal(t, "column_id:1 collation:45 columnLen:-1 decimal:-1 ", ColumnToProto(column).String()) assert.Equal(t, "column_id:1 collation:45 columnLen:-1 decimal:-1 ", ColumnsToProto([]*model.ColumnInfo{column, column2}, false)[0].String()) } + +func TestComposeURL(t *testing.T) { + t.Parallel() + // TODO Setup config for TLS and verify https protocol output + assert.Equal(t, ComposeURL("server.example.com", ""), "http://server.example.com") + assert.Equal(t, ComposeURL("httpserver.example.com", ""), "http://httpserver.example.com") + assert.Equal(t, ComposeURL("http://httpserver.example.com", "/"), "http://httpserver.example.com/") + assert.Equal(t, ComposeURL("https://httpserver.example.com", "/api/test"), "https://httpserver.example.com/api/test") + assert.Equal(t, ComposeURL("http://server.example.com", ""), "http://server.example.com") + assert.Equal(t, ComposeURL("https://server.example.com", ""), "https://server.example.com") +} From f2cf4cc7ca0c90f8fc4162d5b83e65a35ab25d98 Mon Sep 17 00:00:00 2001 From: lysu Date: Thu, 23 Sep 2021 14:52:45 +0800 Subject: [PATCH 13/13] table: check duplicate row_id in insert stmt (#27455) --- ddl/db_integration_test.go | 7 ++++--- table/index.go | 10 ++-------- table/tables/tables.go | 3 +-- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index bb045d6ff80b0..acf26a2cb0e68 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -2649,16 +2649,17 @@ func (s *testIntegrationSuite3) TestAutoIncrementForce(c *C) { } // Rebase _tidb_row_id. tk.MustExec("create table t (a int);") + tk.MustExec("alter table t force auto_increment = 2;") tk.MustExec("insert into t values (1),(2);") - tk.MustQuery("select a, _tidb_rowid from t;").Check(testkit.Rows("1 1", "2 2")) + tk.MustQuery("select a, _tidb_rowid from t;").Check(testkit.Rows("1 2", "2 3")) // Cannot set next global ID to 0. tk.MustGetErrCode("alter table t force auto_increment = 0;", errno.ErrAutoincReadFailed) tk.MustExec("alter table t force auto_increment = 1;") c.Assert(getNextGlobalID(), Equals, uint64(1)) // inserting new rows can overwrite the existing data. tk.MustExec("insert into t values (3);") - tk.MustExec("insert into t values (3);") - tk.MustQuery("select a, _tidb_rowid from t;").Check(testkit.Rows("3 1", "3 2")) + c.Assert(tk.ExecToErr("insert into t values (3);").Error(), Equals, "[kv:1062]Duplicate entry '2' for key 'PRIMARY'") + tk.MustQuery("select a, _tidb_rowid from t;").Check(testkit.Rows("3 1", "1 2", "2 3")) // Rebase auto_increment. 
tk.MustExec("drop table if exists t;") diff --git a/table/index.go b/table/index.go index 0b64adb484446..7974974d879fa 100644 --- a/table/index.go +++ b/table/index.go @@ -32,9 +32,8 @@ type IndexIterator interface { // CreateIdxOpt contains the options will be used when creating an index. type CreateIdxOpt struct { - Ctx context.Context - SkipHandleCheck bool // If true, skip the handle constraint check. - Untouched bool // If true, the index key/value is no need to commit. + Ctx context.Context + Untouched bool // If true, the index key/value is no need to commit. } // CreateIdxOptFunc is defined for the Create() method of Index interface. @@ -42,11 +41,6 @@ type CreateIdxOpt struct { // https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis type CreateIdxOptFunc func(*CreateIdxOpt) -// SkipHandleCheck is a defined value of CreateIdxFunc. -var SkipHandleCheck CreateIdxOptFunc = func(opt *CreateIdxOpt) { - opt.SkipHandleCheck = true -} - // IndexIsUntouched uses to indicate the index kv is untouched. var IndexIsUntouched CreateIdxOptFunc = func(opt *CreateIdxOpt) { opt.Untouched = true diff --git a/table/tables/tables.go b/table/tables/tables.go index 20504dd21a4f1..3ef849ca91391 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -787,8 +787,7 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . value := writeBufs.RowValBuf var setPresume bool - skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck - if (t.meta.IsCommonHandle || t.meta.PKIsHandle) && !skipCheck && !opt.SkipHandleCheck { + if !sctx.GetSessionVars().StmtCtx.BatchCheck { if t.meta.TempTableType != model.TempTableNone { // Always check key for temporary table because it does not write to TiKV _, err = txn.Get(ctx, key)