diff --git a/executor/prepared.go b/executor/prepared.go index f1c34c14032b3..e7c5d3d340180 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -237,6 +237,11 @@ func (e *DeallocateExec) Next(ctx context.Context, chk *chunk.Chunk) error { return errors.Trace(plannercore.ErrStmtNotFound) } delete(vars.PreparedStmtNameToID, e.Name) + if plannercore.PreparedPlanCacheEnabled() { + e.ctx.PreparedPlanCache().Delete(plannercore.NewPSTMTPlanCacheKey( + vars, id, vars.PreparedStmts[id].SchemaVersion, + )) + } delete(vars.PreparedStmts, id) return nil } diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 1f4771541361a..37b475d153eef 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -567,3 +567,37 @@ func (s *testSuite) TestPreparedDelete(c *C) { result.Check(nil) } } + +func (s *testSuite) TestPrepareDealloc(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + }() + plannercore.SetPreparedPlanCache(true) + plannercore.PreparedPlanCacheCapacity = 3 + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists prepare_test") + tk.MustExec("create table prepare_test (id int PRIMARY KEY, c1 int)") + + c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 0) + tk.MustExec(`prepare stmt1 from 'select * from prepare_test'`) + tk.MustExec("execute stmt1") + tk.MustExec(`prepare stmt2 from 'select * from prepare_test'`) + tk.MustExec("execute stmt2") + tk.MustExec(`prepare stmt3 from 'select * from prepare_test'`) + tk.MustExec("execute stmt3") + tk.MustExec(`prepare stmt4 from 'select * from prepare_test'`) + tk.MustExec("execute stmt4") + c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 3) + + tk.MustExec("deallocate prepare stmt1") + c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 3) + tk.MustExec("deallocate prepare stmt2") + tk.MustExec("deallocate prepare stmt3") + tk.MustExec("deallocate prepare stmt4") + c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 0) +} diff --git a/planner/core/cache.go b/planner/core/cache.go index b16a99992a6cf..6659d3dc6d97b 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -70,12 +70,14 @@ type pstmtPlanCacheKey struct { // Hash implements Key interface. func (key *pstmtPlanCacheKey) Hash() []byte { - if key.hash == nil { + if len(key.hash) == 0 { var ( dbBytes = hack.Slice(key.database) bufferSize = len(dbBytes) + 8*6 ) - key.hash = make([]byte, 0, bufferSize) + if key.hash == nil { + key.hash = make([]byte, 0, bufferSize) + } key.hash = append(key.hash, dbBytes...) key.hash = codec.EncodeInt(key.hash, int64(key.connID)) key.hash = codec.EncodeInt(key.hash, int64(key.pstmtID)) @@ -87,6 +89,18 @@ func (key *pstmtPlanCacheKey) Hash() []byte { return key.hash } +// SetPstmtIDSchemaVersion implements PstmtCacheKeyMutator interface to change pstmtID and schemaVersion of cacheKey. +// so we can reuse Key instead of new every time. +func SetPstmtIDSchemaVersion(key kvcache.Key, pstmtID uint32, schemaVersion int64) { + psStmtKey, isPsStmtKey := key.(*pstmtPlanCacheKey) + if !isPsStmtKey { + return + } + psStmtKey.pstmtID = pstmtID + psStmtKey.schemaVersion = schemaVersion + psStmtKey.hash = psStmtKey.hash[:0] +} + // NewPSTMTPlanCacheKey creates a new pstmtPlanCacheKey object. func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, schemaVersion int64) kvcache.Key { timezoneOffset := 0 diff --git a/session/session.go b/session/session.go index d51c97dd710e7..8d95b61b816ac 100644 --- a/session/session.go +++ b/session/session.go @@ -162,12 +162,32 @@ func (s *session) getMembufCap() int { } func (s *session) cleanRetryInfo() { - if !s.sessionVars.RetryInfo.Retrying { - retryInfo := s.sessionVars.RetryInfo - for _, stmtID := range retryInfo.DroppedPreparedStmtIDs { - delete(s.sessionVars.PreparedStmts, stmtID) + if s.sessionVars.RetryInfo.Retrying { + return + } + + retryInfo := s.sessionVars.RetryInfo + defer retryInfo.Clean() + if len(retryInfo.DroppedPreparedStmtIDs) == 0 { + return + } + + planCacheEnabled := plannercore.PreparedPlanCacheEnabled() + var cacheKey kvcache.Key + if planCacheEnabled { + firstStmtID := retryInfo.DroppedPreparedStmtIDs[0] + cacheKey = plannercore.NewPSTMTPlanCacheKey( + s.sessionVars, firstStmtID, s.sessionVars.PreparedStmts[firstStmtID].SchemaVersion, + ) + } + for i, stmtID := range retryInfo.DroppedPreparedStmtIDs { + if planCacheEnabled { + if i > 0 { + plannercore.SetPstmtIDSchemaVersion(cacheKey, stmtID, s.sessionVars.PreparedStmts[stmtID].SchemaVersion) + } + s.PreparedPlanCache().Delete(cacheKey) } - retryInfo.Clean() + delete(s.sessionVars.PreparedStmts, stmtID) } } diff --git a/util/kvcache/simple_lru.go b/util/kvcache/simple_lru.go index 4bf04808f7da7..a9304d378bc42 100644 --- a/util/kvcache/simple_lru.go +++ b/util/kvcache/simple_lru.go @@ -89,3 +89,20 @@ func (l *SimpleLRUCache) Put(key Key, value Value) { l.size-- } } + +// Delete deletes the key-value pair from the LRU Cache. +func (l *SimpleLRUCache) Delete(key Key) { + k := string(key.Hash()) + element := l.elements[k] + if element == nil { + return + } + l.cache.Remove(element) + delete(l.elements, k) + l.size-- +} + +// Size gets the current cache size. +func (l *SimpleLRUCache) Size() int { + return int(l.size) +} diff --git a/util/kvcache/simple_lru_test.go b/util/kvcache/simple_lru_test.go index e3d73ed1631fd..bf4b99a8c1cf2 100644 --- a/util/kvcache/simple_lru_test.go +++ b/util/kvcache/simple_lru_test.go @@ -143,3 +143,29 @@ func (s *testLRUCacheSuite) TestGet(c *C) { c.Assert(value, Equals, vals[i]) } } + +func (s *testLRUCacheSuite) TestDelete(c *C) { + lru := NewSimpleLRUCache(3) + + keys := make([]*mockCacheKey, 3) + vals := make([]int64, 3) + + for i := 0; i < 3; i++ { + keys[i] = newMockHashKey(int64(i)) + vals[i] = int64(i) + lru.Put(keys[i], vals[i]) + } + c.Assert(int(lru.size), Equals, 3) + + lru.Delete(keys[1]) + value, exists := lru.Get(keys[1]) + c.Assert(exists, IsFalse) + c.Assert(value, IsNil) + c.Assert(int(lru.size), Equals, 2) + + _, exists = lru.Get(keys[0]) + c.Assert(exists, IsTrue) + + _, exists = lru.Get(keys[2]) + c.Assert(exists, IsTrue) +}