Skip to content

Commit

Permalink
Prevent connection pool from emptying upon panics (#6408)
Browse files Browse the repository at this point in the history
## Motivation

Ensure an active statement is reset even if `decoder` in `db.Exec` panics
  • Loading branch information
fasmat committed Oct 23, 2024
1 parent 55c3547 commit cc334e4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,11 +609,13 @@ func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx
if db.closed {
return nil, ErrClosed
}
conn := db.getConn(ctx)
conCtx, cancel := context.WithCancel(ctx)
conn := db.getConn(conCtx)
if conn == nil {
cancel()
return nil, ErrNoConnection
}
tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn}
tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn, freeConn: cancel}
if err := tx.begin(initstmt); err != nil {
return nil, err
}
Expand Down Expand Up @@ -998,6 +1000,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in
encoder(stmt)
}
defer stmt.ClearBindings()
defer stmt.Reset()

rows := 0
for {
Expand Down Expand Up @@ -1027,6 +1030,7 @@ type sqliteTx struct {
*queryCache
db *sqliteDatabase
conn *sqlite.Conn
freeConn func()
committed bool
err error
}
Expand Down Expand Up @@ -1055,10 +1059,12 @@ func (tx *sqliteTx) Commit() error {
func (tx *sqliteTx) Release() error {
defer tx.db.pool.Put(tx.conn)
if tx.committed {
tx.freeConn()
return nil
}
stmt := tx.conn.Prep("ROLLBACK")
_, tx.err = stmt.Step()
tx.freeConn()
return mapSqliteError(tx.err)
}

Expand Down
26 changes: 26 additions & 0 deletions sql/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand All @@ -16,6 +17,31 @@ import (
"go.uber.org/zap/zaptest/observer"
)

func Test_ConReturnedToPool(t *testing.T) {
db := InMemory(
WithLogger(zaptest.NewLogger(t)),
WithConnections(1),
WithDatabaseSchema(&Schema{
Script: `CREATE TABLE testing1 (
id varchar primary key,
field int
);`,
}),
WithNoCheckSchemaDrift(),
)

require.Panics(t, func() {
db.Exec("select 1", nil, func(stmt *Statement) bool {
panic("decoder panic")
})
})

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
con := db.pool.Get(ctx)
require.NotNil(t, con, "connection was not returned")
}

func Test_Transaction_Isolation(t *testing.T) {
db := InMemory(
WithLogger(zaptest.NewLogger(t)),
Expand Down

0 comments on commit cc334e4

Please sign in to comment.