Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix memory leaks in PrepareStatementDB #7142

Merged
merged 4 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@ type Stmt struct {
}

type PreparedStmtDB struct {
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
Stmts map[string]*Stmt
Mux *sync.RWMutex
ConnPool
}

func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
}
}

Expand All @@ -48,22 +46,32 @@ func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
a631807682 marked this conversation as resolved.
Show resolved Hide resolved

for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
}
for _, stmt := range db.Stmts {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
a631807682 marked this conversation as resolved.
Show resolved Hide resolved
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
}
// setting db.Stmts to nil to avoid further using
db.Stmts = nil
}

func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()

for _, stmt := range sdb.Stmts {
go stmt.Close()
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
}
ivila marked this conversation as resolved.
Show resolved Hide resolved
sdb.PreparedSQL = make([]string, 0, 100)
sdb.Stmts = make(map[string]*Stmt)
}

Expand Down Expand Up @@ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact

return *stmt, nil
}

// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
db.Mux.Unlock()
return Stmt{}, ErrInvalidDB
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
Expand All @@ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact

db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()

return cacheStmt, nil
Expand Down
147 changes: 147 additions & 0 deletions tests/prepared_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) {
t.Fatalf("prepared stmt should be empty")
}
}

func isUsingClosedConnError(err error) bool {
// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
return err.Error() == "sql: statement is closed"
}

// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
func TestPreparedStmtConcurrentReset(t *testing.T) {
name := "prepared_stmt_concurrent_reset"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}

// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}

loopCount := 100
var wg sync.WaitGroup
var unexpectedError bool
writerFinish := make(chan struct{})

wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(writerFinish)

for j := 0; j < loopCount; j++ {
var tmp User
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil || isUsingClosedConnError(err) {
continue
}
t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
unexpectedError = true
break
}
}(user.ID)

wg.Add(1)
go func() {
defer wg.Done()
<-writerFinish
pdb.Reset()
}()

wg.Wait()

if unexpectedError {
t.Fatalf("should is a unexpected error")
}
}

// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// for example: one goroutine found error and just close the database, and others are executing SQL
// this test making sure that the gorm would not get a Segmentation Fault,
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
// and all of the goroutine must got gorm.ErrInvalidDB after database close
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}

// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}

loopCount := 100
var wg sync.WaitGroup
var lastErr error
closeValid := make(chan struct{}, loopCount)
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
var lastRunIndex int
var closeFinishedAt int64

wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(closeValid)
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
if lastRunIndex == closeStartIdx {
closeValid <- struct{}{}
}
var tmp User
now := time.Now().UnixNano()
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil {
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
lastErr = errors.New("must got error after database closed")
break
}
continue
}
lastErr = err
break
}
}(user.ID)

wg.Add(1)
go func() {
defer wg.Done()
for range closeValid {
for i := 0; i < loopCount; i++ {
pdb.Close() // the Close method must can be call multiple times
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
}
}
}()

wg.Wait()
var tmp User
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
if err != gorm.ErrInvalidDB {
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
}

// must be error
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
}
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
}
if pdb.Stmts != nil {
t.Fatalf("stmts must be nil")
}
}
Loading