Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: statement prepared in TX is closed with TX #117

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 28 additions & 35 deletions package_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,27 @@ func (s *PackageSuite) TestTransactions(c *C) {
c.Assert(err, IsNil)
}

// Test that when preparing a statement inside a transaction it can still be prepared on the db directly and that it is
// not closed along with the transaction.
func (s *PackageSuite) TestStatementTXReuse(c *C) {
sqldb, err := setupDB(c.TestName())
c.Assert(err, IsNil)

db := sqlair.NewDB(sqldb)

// Create a statement and run it on a transaction.
selectStmt := sqlair.MustPrepare(`SELECT 'hello'`)
tx, err := db.Begin(nil, nil)
c.Assert(err, IsNil)
q := tx.Query(nil, selectStmt)
c.Assert(q.Run(), IsNil)
c.Assert(tx.Commit(), IsNil)

// Run the same existing statement outside the transaction.
q = db.Query(nil, selectStmt)
c.Assert(q.Run(), IsNil)
}

func (s *PackageSuite) TestTransactionErrors(c *C) {
tables, sqldb, err := personAndAddressDB(c)
c.Assert(err, IsNil)
Expand Down Expand Up @@ -1183,40 +1204,6 @@ func (s *PackageSuite) TestPreparedStmtCaching(c *C) {
checkCacheEmpty()
}

func (s *PackageSuite) TestTransactionWithOneConn(c *C) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have some protection against this still. I have a few ideas but I'm not sure about any of them.

  1. In NewDB we return an error if the max connections is 1. The problem with this is that a deadlock can still if the number of in progress transactions matches the number connections in the pool. Also the user can still increase the max connections.
  2. We could track the number of connections ourselves and have a sqlair version of SetMaxOpenConns. Then we could not prepare on the DB if we know it would block. Again, the problem with this is that the user can still use the "plain" DB and we would not control it. A solution for this would be to also override the Open method and remove PlainDB however this would be a breaking change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding 2. I don't think that is the responsibility of SQLair, it should be done at the sql std library instead. Essentially we would be having a parallel "pool" implementation/accounting.

tables, sqldb, err := personAndAddressDB(c)
c.Assert(err, IsNil)
sqldb.SetMaxOpenConns(1)
ctx := context.Background()

db := sqlair.NewDB(sqldb)
defer dropTables(c, db, tables...)

// This test sets the maximum number of connections to the DB to one. The
// database/sql library makes use of a pool of connections to communicate
// with the DB. Certain operations require a dedicated connection to run,
// such as transactions.
// This test ensures that we do not enter a deadlock when doing a behind
// the scenes prepare for a transaction.
selectStmt := sqlair.MustPrepare("SELECT &Person.* FROM person WHERE name = 'Mark'", Person{})
mark := Person{20, "Mark", 1500}

tx, err := db.Begin(ctx, nil)
c.Assert(err, IsNil)

q := tx.Query(ctx, selectStmt)
defer func() {
c.Assert(tx.Commit(), IsNil)
}()
iter := q.Iter()
c.Assert(iter.Next(), Equals, true)
p := Person{}
c.Assert(iter.Get(&p), IsNil)
c.Assert(mark, Equals, p)
c.Assert(iter.Next(), Equals, false)
c.Assert(iter.Close(), IsNil)
}

type JujuLeaseKey struct {
Namespace string `db:"type"`
ModelUUID string `db:"model_uuid"`
Expand Down Expand Up @@ -1401,6 +1388,11 @@ func (s *PackageSuite) TestRaceConditionFinalizer(c *C) {
}
func (s *PackageSuite) TestRaceConditionFinalizerTX(c *C) {
var q *sqlair.Query
// Because of how the sql internal library works, we have to commit the transaction for the statement to be closed
// properly and the test to pass. Reason: the sql library keeps track of what connections are used for transactions
// and, when they are, it does not close statements, instead it marks them to be closed only when the connection
// itself is closed.
var tx *sqlair.TX
// Drop all the values except the query itself.
func() {
sqldb, err := setupDB(c.TestName())
Expand All @@ -1409,7 +1401,7 @@ func (s *PackageSuite) TestRaceConditionFinalizerTX(c *C) {
db := sqlair.NewDB(sqldb)

selectStmt := sqlair.MustPrepare(`SELECT 'hello'`)
tx, err := db.Begin(nil, nil)
tx, err = db.Begin(nil, nil)
c.Assert(err, IsNil)
q = tx.Query(nil, selectStmt)
}()
Expand All @@ -1422,4 +1414,5 @@ func (s *PackageSuite) TestRaceConditionFinalizerTX(c *C) {

// Assert that sql.Stmt was not closed early.
c.Assert(q.Run(), IsNil)
c.Assert(tx.Commit(), IsNil)
}
41 changes: 11 additions & 30 deletions sqlair.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (db *DB) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
ctx = context.Background()
}

sqlstmt, err := db.prepareStmt(ctx, db.sqldb, s)
sqlstmt, err := db.prepareStmt(ctx, s)
if err != nil {
return &Query{ctx: ctx, err: err}
}
Expand All @@ -180,25 +180,17 @@ func (db *DB) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
return &Query{sqlstmt: sqlstmt, stmt: s, db: db, pq: pq, ctx: ctx, err: nil}
}

// prepareSubstrate is an object that queries can be prepared on, e.g. a sql.DB
// or sql.Conn. It is used in prepareStmt.
type prepareSubstrate interface {
PrepareContext(context.Context, string) (*sql.Stmt, error)
}

// prepareStmt prepares a Statement on a prepareSubstrate. It first checks in
// the cache to see if it has already been prepared on the DB.
// The prepareSubstrate must be assosiated with the same DB that prepareStmt is
// a method of.
func (db *DB) prepareStmt(ctx context.Context, ps prepareSubstrate, s *Statement) (*sql.Stmt, error) {
// prepareStmt prepares a Statement on a DB. It first checks in the cache to
// see if it has already been prepared.
func (db *DB) prepareStmt(ctx context.Context, s *Statement) (*sql.Stmt, error) {
var err error
cacheMutex.RLock()
Copy link
Collaborator

@manadart manadart Dec 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to see this tucked inside a cache implementation that uses the singleton pattern discussed prior, but not in this patch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR #126 does this.

// The statement ID is only removed from the cache when the finalizer is
// run, so it is always in stmtDBCache.
sqlstmt, ok := stmtDBCache[s.cacheID][db.cacheID]
cacheMutex.RUnlock()
if !ok {
sqlstmt, err = ps.PrepareContext(ctx, s.te.SQL())
sqlstmt, err = db.sqldb.PrepareContext(ctx, s.te.SQL())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -479,10 +471,9 @@ func (q *Query) GetAll(sliceArgs ...any) (err error) {
}

type TX struct {
sqltx *sql.Tx
sqlconn *sql.Conn
db *DB
done int32
sqltx *sql.Tx
db *DB
done int32
}

func (tx *TX) isDone() bool {
Expand All @@ -501,15 +492,11 @@ func (db *DB) Begin(ctx context.Context, opts *TXOptions) (*TX, error) {
if ctx == nil {
ctx = context.Background()
}
sqlconn, err := db.sqldb.Conn(ctx)
if err != nil {
return nil, err
}
sqltx, err := sqlconn.BeginTx(ctx, opts.plainTXOptions())
sqltx, err := db.sqldb.BeginTx(ctx, opts.plainTXOptions())
if err != nil {
return nil, err
}
return &TX{sqltx: sqltx, sqlconn: sqlconn, db: db}, nil
return &TX{sqltx: sqltx, db: db}, nil
}

// Commit commits the transaction.
Expand All @@ -518,9 +505,6 @@ func (tx *TX) Commit() error {
if err == nil {
err = tx.sqltx.Commit()
}
if cerr := tx.sqlconn.Close(); err == nil {
err = cerr
}
return err
}

Expand All @@ -530,9 +514,6 @@ func (tx *TX) Rollback() error {
if err == nil {
err = tx.sqltx.Rollback()
}
if cerr := tx.sqlconn.Close(); err == nil {
err = cerr
}
return err
}

Expand Down Expand Up @@ -561,7 +542,7 @@ func (tx *TX) Query(ctx context.Context, s *Statement, inputArgs ...any) *Query
return &Query{ctx: ctx, err: ErrTXDone}
}

sqlstmt, err := tx.db.prepareStmt(ctx, tx.sqlconn, s)
sqlstmt, err := tx.db.prepareStmt(ctx, s)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about this?

  • prepareStmt checks dbStats to see if MaxOpenConnections is set to 1.
  • If it is, we return a typed error.
  • If we catch that error here, we log a warning and just prepare against the transaction instead of the db, skipping the cache.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That could work but only for that specific case. If we had MaxOpenConnections set to 2 and we created two transactions we are going to have the same problem. And the same applies to n connections and n transactions. My other concern is that MaxOpenConnections can be changed concurrently from another thread accessing the underlying sqldb which is going to lead to some bugs that are difficult to trace.

Another solution would be to check whether there is a max number of connections and issue a warning or even return an error. We could even require an environment variable to be set just to be explicit about the number of connections being a problem. The issue, once again, is the concurrency aspect.

if err != nil {
return &Query{ctx: ctx, err: err}
}
Expand Down