From 9e157278c7464237783bf11adcedcd08d7f0d580 Mon Sep 17 00:00:00 2001 From: disksing Date: Tue, 27 Jul 2021 16:00:16 +0800 Subject: [PATCH] placement: return rules after clone (#3892) * placement: return rules after clone Signed-off-by: disksing * fix test Signed-off-by: disksing * move stringsEqual to typeutils Signed-off-by: disksing Co-authored-by: Ti Chi Robot --- pkg/typeutil/comparison.go | 13 ++++++++++++ server/api/cluster_test.go | 5 ++++- server/api/rule_test.go | 4 +++- server/schedule/placement/rule.go | 9 ++++++++ server/schedule/placement/rule_manager.go | 16 ++++++++++---- .../schedule/placement/rule_manager_test.go | 21 +++++++++++++++---- server/server.go | 6 ++---- 7 files changed, 60 insertions(+), 14 deletions(-) diff --git a/pkg/typeutil/comparison.go b/pkg/typeutil/comparison.go index 3f10fe51593..73b9aef05f9 100644 --- a/pkg/typeutil/comparison.go +++ b/pkg/typeutil/comparison.go @@ -38,3 +38,16 @@ func MinDuration(a, b time.Duration) time.Duration { } return b } + +// StringsEqual checks if two string slices are equal. Empyt slice and nil are considered equal. +func StringsEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/server/api/cluster_test.go b/server/api/cluster_test.go index 33e86c91de7..1914823178e 100644 --- a/server/api/cluster_test.go +++ b/server/api/cluster_test.go @@ -49,9 +49,12 @@ func (s *testClusterSuite) TestCluster(c *C) { s.testGetClusterStatus(c) s.svr.GetPersistOptions().SetPlacementRuleEnabled(true) s.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} - rule := s.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default") + rm := s.svr.GetRaftCluster().GetRuleManager() + rule := rm.GetRule("pd", "default") rule.LocationLabels = []string{"host"} rule.Count = 1 + rm.SetRule(rule) + // Test set the config url := fmt.Sprintf("%s/cluster", s.urlPrefix) c1 := &metapb.Cluster{} diff --git a/server/api/rule_test.go b/server/api/rule_test.go index 84755bc0551..51cafd893f4 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -239,7 +239,9 @@ func (s *testRuleSuite) TestSetAll(c *C) { rule6 := placement.Rule{GroupID: "pd", ID: "default", StartKeyHex: "", EndKeyHex: "", Role: "voter", Count: 3} s.svr.GetPersistOptions().GetReplicationConfig().LocationLabels = []string{"host"} - s.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default").LocationLabels = []string{"host"} + defaultRule := s.svr.GetRaftCluster().GetRuleManager().GetRule("pd", "default") + defaultRule.LocationLabels = []string{"host"} + s.svr.GetRaftCluster().GetRuleManager().SetRule(defaultRule) successData, err := json.Marshal([]*placement.Rule{&rule1, &rule2}) c.Assert(err, IsNil) diff --git a/server/schedule/placement/rule.go b/server/schedule/placement/rule.go index d5c167f90ae..3cfeb935011 100644 --- a/server/schedule/placement/rule.go +++ b/server/schedule/placement/rule.go @@ -73,6 +73,15 @@ func (r *Rule) String() string { return string(b) } +// Clone returns a copy of Rule. +func (r *Rule) Clone() *Rule { + var clone Rule + json.Unmarshal([]byte(r.String()), &clone) + clone.StartKey = append(r.StartKey[:0:0], r.StartKey...) + clone.EndKey = append(r.EndKey[:0:0], r.EndKey...) + return &clone +} + // Key returns (groupID, ID) as the global unique key of a rule. func (r *Rule) Key() [2]string { return [2]string{r.GroupID, r.ID} diff --git a/server/schedule/placement/rule_manager.go b/server/schedule/placement/rule_manager.go index 288329757bf..202111cbf72 100644 --- a/server/schedule/placement/rule_manager.go +++ b/server/schedule/placement/rule_manager.go @@ -216,7 +216,10 @@ func (m *RuleManager) adjustRule(r *Rule, groupID string) (err error) { func (m *RuleManager) GetRule(group, id string) *Rule { m.RLock() defer m.RUnlock() - return m.ruleConfig.getRule([2]string{group, id}) + if r := m.ruleConfig.getRule([2]string{group, id}); r != nil { + return r.Clone() + } + return nil } // SetRule inserts or updates a Rule. @@ -261,7 +264,7 @@ func (m *RuleManager) GetAllRules() []*Rule { defer m.RUnlock() rules := make([]*Rule, 0, len(m.ruleConfig.rules)) for _, r := range m.ruleConfig.rules { - rules = append(rules, r) + rules = append(rules, r.Clone()) } sortRules(rules) return rules @@ -274,7 +277,7 @@ func (m *RuleManager) GetRulesByGroup(group string) []*Rule { var rules []*Rule for _, r := range m.ruleConfig.rules { if r.GroupID == group { - rules = append(rules, r) + rules = append(rules, r.Clone()) } } sortRules(rules) @@ -285,7 +288,12 @@ func (m *RuleManager) GetRulesByGroup(group string) []*Rule { func (m *RuleManager) GetRulesByKey(key []byte) []*Rule { m.RLock() defer m.RUnlock() - return m.ruleList.getRulesByKey(key) + rules := m.ruleList.getRulesByKey(key) + ret := make([]*Rule, 0, len(rules)) + for _, r := range rules { + ret = append(ret, r.Clone()) + } + return ret } // GetRulesForApplyRegion returns the rules list that should be applied to a region. diff --git a/server/schedule/placement/rule_manager_test.go b/server/schedule/placement/rule_manager_test.go index 1d8fb51b75d..b11e0f793aa 100644 --- a/server/schedule/placement/rule_manager_test.go +++ b/server/schedule/placement/rule_manager_test.go @@ -107,16 +107,29 @@ func (s *testManagerSuite) TestSaveLoad(c *C) { {GroupID: "foo", ID: "bar", Role: "learner", Count: 1}, } for _, r := range rules { - c.Assert(s.manager.SetRule(r), IsNil) + c.Assert(s.manager.SetRule(r.Clone()), IsNil) } m2 := NewRuleManager(s.store, nil) err := m2.Initialize(3, []string{"no", "labels"}) c.Assert(err, IsNil) c.Assert(m2.GetAllRules(), HasLen, 3) - c.Assert(m2.GetRule("pd", "default"), DeepEquals, rules[0]) - c.Assert(m2.GetRule("foo", "baz"), DeepEquals, rules[1]) - c.Assert(m2.GetRule("foo", "bar"), DeepEquals, rules[2]) + c.Assert(m2.GetRule("pd", "default").String(), Equals, rules[0].String()) + c.Assert(m2.GetRule("foo", "baz").String(), Equals, rules[1].String()) + c.Assert(m2.GetRule("foo", "bar").String(), Equals, rules[2].String()) +} + +// https://github.com/tikv/pd/issues/3886 +func (s *testManagerSuite) TestSetAfterGet(c *C) { + rule := s.manager.GetRule("pd", "default") + rule.Count = 1 + s.manager.SetRule(rule) + + m2 := NewRuleManager(s.store, nil) + err := m2.Initialize(100, []string{}) + c.Assert(err, IsNil) + rule = m2.GetRule("pd", "default") + c.Assert(rule.Count, Equals, 1) } func (s *testManagerSuite) checkRules(c *C, rules []*Rule, expect [][2]string) { diff --git a/server/server.go b/server/server.go index f5375859517..b4ff35ed4c3 100644 --- a/server/server.go +++ b/server/server.go @@ -22,7 +22,6 @@ import ( "os" "path" "path/filepath" - "reflect" "strconv" "strings" "sync" @@ -853,15 +852,14 @@ func (s *Server) SetReplicationConfig(cfg config.ReplicationConfig) error { len(defaultRule.StartKey) == 0 && len(defaultRule.EndKey) == 0) { return errors.New("cannot update MaxReplicas or LocationLabels when placement rules feature is enabled and not only default rule exists, please update rule instead") } - rule = defaultRule - if !(rule.Count == int(old.MaxReplicas) && reflect.DeepEqual(rule.LocationLabels, []string(old.LocationLabels))) { + if !(defaultRule.Count == int(old.MaxReplicas) && typeutil.StringsEqual(defaultRule.LocationLabels, []string(old.LocationLabels))) { return errors.New("cannot to update replication config, the default rules do not consistent with replication config, please update rule instead") } return nil } - if !(cfg.MaxReplicas == old.MaxReplicas && reflect.DeepEqual(cfg.LocationLabels, old.LocationLabels)) { + if !(cfg.MaxReplicas == old.MaxReplicas && typeutil.StringsEqual(cfg.LocationLabels, old.LocationLabels)) { if err := CheckInDefaultRule(); err != nil { return err }