From a68b5ff2da8517cf8b6c8c0e473de6bf47b43ed1 Mon Sep 17 00:00:00 2001 From: Sergey Petrunin Date: Wed, 28 Jan 2015 14:18:28 +0200 Subject: [PATCH] Rewrite stmt cache for transactions --- stmtcacher.go | 51 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/stmtcacher.go b/stmtcacher.go index c2dc2208..4afdaa23 100644 --- a/stmtcacher.go +++ b/stmtcacher.go @@ -71,20 +71,55 @@ func (sc *stmtCacher) QueryRow(query string, args ...interface{}) RowScanner { return stmt.QueryRow(args...) } -type DBProxyBeginner interface { +// DBTransactionProxy wraps transaction and includes DBProxy interface +type DBTransactionProxy interface { DBProxy - Begin() (*sql.Tx, error) + Begin() error + Commit() error + Rollback() error } -type stmtCacheProxy struct { +type stmtCacheTransactionProxy struct { DBProxy - db *sql.DB + db *sql.DB + transaction *sql.Tx } -func NewStmtCacheProxy(db *sql.DB) DBProxyBeginner { - return &stmtCacheProxy{DBProxy: NewStmtCacher(db), db: db} +// NewStmtCacheTransactionProxy returns a DBTransactionProxy +// wrapping an open transaction in stmtCacher. +// You should use Begin() each time you want a new transaction and +// cache will be valid only for that transaction. +// +// Usage example: +// proxy := sq.NewStmtCacheTransactionProxy(db) +// mydb := sq.StatementBuilder.RunWith(proxy) +// insertUsers := mydb.Insert("users").Columns("name").Values("username") +// proxy.Commit() +// proxy.Begin() +// insertPets := mydb.Insert("pets").Columns("name", "username").Values("petname", "username") +// proxy.Commit() +func NewStmtCacheTransactionProxy(db *sql.DB) (proxy DBTransactionProxy, err error) { + proxy = &stmtCacheTransactionProxy{db: db} + return proxy, proxy.Begin() +} + +func (sp *stmtCacheTransactionProxy) Begin() (err error) { + tr, err := sp.db.Begin() + + if err != nil { + return + } + + sp.DBProxy = NewStmtCacher(tr) + sp.transaction = tr + + return +} + +func (sp *stmtCacheTransactionProxy) Commit() error { + return sp.transaction.Commit() } -func (sp *stmtCacheProxy) Begin() (*sql.Tx, error) { - return sp.db.Begin() +func (sp *stmtCacheTransactionProxy) Rollback() error { + return sp.transaction.Rollback() }