diff --git a/pkg/executor/read_write_splitting.go b/pkg/executor/read_write_splitting.go index 6e851d5..7013f1a 100644 --- a/pkg/executor/read_write_splitting.go +++ b/pkg/executor/read_write_splitting.go @@ -230,6 +230,24 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery( return nil, 0, err } return result, 0, err + case *ast.XAStartStmt: + tx, result, err = executor.dbGroup.XAStart(spanCtx, sqlText) + if err != nil { + return nil, 0, err + } + executor.localTransactionMap.Store(connectionID, tx) + return result, 0, nil + case *ast.XAPrepareStmt: + txi, ok := executor.localTransactionMap.Load(connectionID) + if !ok { + return nil, 0, errors.New("there is no transaction") + } + defer executor.localTransactionMap.Delete(connectionID) + tx = txi.(proto.Tx) + if result, err = tx.XAPrepare(ctx, sqlText); err != nil { + return nil, 0, err + } + return result, 0, err case *ast.InsertStmt, *ast.DeleteStmt, *ast.UpdateStmt: txi, ok := executor.localTransactionMap.Load(connectionID) if ok { diff --git a/pkg/executor/single_db.go b/pkg/executor/single_db.go index 45b2afe..127e96d 100644 --- a/pkg/executor/single_db.go +++ b/pkg/executor/single_db.go @@ -225,6 +225,24 @@ func (executor *SingleDBExecutor) ExecutorComQuery( return nil, 0, err } return result, 0, err + case *ast.XAStartStmt: + tx, result, err = db.XAStart(spanCtx, sqlText) + if err != nil { + return nil, 0, err + } + executor.localTransactionMap.Store(connectionID, tx) + return result, 0, nil + case *ast.XAPrepareStmt: + txi, ok := executor.localTransactionMap.Load(connectionID) + if !ok { + return nil, 0, errors.New("there is no transaction") + } + defer executor.localTransactionMap.Delete(connectionID) + tx = txi.(proto.Tx) + if result, err = tx.XAPrepare(ctx, sqlText); err != nil { + return nil, 0, err + } + return result, 0, err default: txi, ok := executor.localTransactionMap.Load(connectionID) if ok { diff --git a/pkg/group/group.go b/pkg/group/group.go index ce13da2..2867c39 100644 --- a/pkg/group/group.go +++ b/pkg/group/group.go @@ -81,6 +81,11 @@ func (group *DBGroup) Begin(ctx context.Context) (proto.Tx, proto.Result, error) return dbs[0].Begin(ctx) } +func (group *DBGroup) XAStart(ctx context.Context, sql string) (proto.Tx, proto.Result, error) { + dbs := group.getAvailableMasters() + return dbs[0].XAStart(ctx, sql) +} + func (group *DBGroup) Query(ctx context.Context, query string) (proto.Result, uint16, error) { db := group.pick(ctx) return db.Query(ctx, query) diff --git a/pkg/proto/interface.go b/pkg/proto/interface.go index 2d75e9b..cfb0c33 100644 --- a/pkg/proto/interface.go +++ b/pkg/proto/interface.go @@ -163,6 +163,7 @@ type ( ExecuteSql(ctx context.Context, sql string, args ...interface{}) (Result, uint16, error) ExecuteSqlDirectly(sql string, args ...interface{}) (Result, uint16, error) Begin(ctx context.Context) (Tx, Result, error) + XAStart(ctx context.Context, sql string) (Tx, Result, error) } Tx interface { @@ -174,6 +175,7 @@ type ( Commit(ctx context.Context) (Result, error) Rollback(ctx context.Context, stmt *ast.RollbackStmt) (Result, error) ReleaseSavepoint(ctx context.Context, savepoint string) (result Result, err error) + XAPrepare(ctx context.Context, sql string) (Result, error) } DBManager interface { @@ -190,6 +192,7 @@ type ( PrepareQuery(ctx context.Context, query string, args ...interface{}) (Result, uint16, error) PrepareExecute(ctx context.Context, query string, args ...interface{}) (Result, uint16, error) PrepareExecuteStmt(ctx context.Context, stmt *Stmt) (Result, uint16, error) + XAStart(ctx context.Context, sql string) (Tx, Result, error) } DBGroupTx interface { diff --git a/pkg/sql/db.go b/pkg/sql/db.go index 150db7d..f4adddf 100644 --- a/pkg/sql/db.go +++ b/pkg/sql/db.go @@ -464,6 +464,36 @@ func (db *DB) Begin(ctx context.Context) (proto.Tx, proto.Result, error) { }, result, nil } +func (db *DB) XAStart(ctx context.Context, sql string) (proto.Tx, proto.Result, error) { + var ( + result proto.Result + conn *driver.BackendConnection + err error + ) + + spanCtx, span := tracing.GetTraceSpan(ctx, tracing.DBXAStart) + span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(db.name)}) + defer span.End() + + r, err := db.pool.Get(spanCtx) + if err != nil { + err = errors.WithStack(err) + return nil, nil, err + } + conn = r.(*driver.BackendConnection) + + if result, err = conn.Execute(ctx, sql, false); err != nil { + db.pool.Put(r) + return nil, nil, err + } + + return &Tx{ + closed: atomic.NewBool(false), + db: db, + conn: conn, + }, result, nil +} + func (db *DB) SetConnectionPreFilters(filters []proto.DBConnectionPreFilter) { db.connectionPreFilters = filters } diff --git a/pkg/sql/tx.go b/pkg/sql/tx.go index 1cfb104..427c654 100644 --- a/pkg/sql/tx.go +++ b/pkg/sql/tx.go @@ -171,6 +171,23 @@ func (tx *Tx) Rollback(ctx context.Context, stmt *ast.RollbackStmt) (result prot return } +func (tx *Tx) XAPrepare(ctx context.Context, sql string) (result proto.Result, err error) { + _, span := tracing.GetTraceSpan(ctx, tracing.TxXAPrepare) + span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(tx.db.name)}) + defer span.End() + + if tx.closed.Load() { + return nil, nil + } + if tx.db == nil || tx.db.IsClosed() { + return nil, err2.ErrInvalidConn + } + result, err = tx.conn.Execute(ctx, sql, false) + tx.db.pool.Put(tx.conn) + tx.Close() + return +} + func (tx *Tx) ReleaseSavepoint(ctx context.Context, savepoint string) (result proto.Result, err error) { _, span := tracing.GetTraceSpan(ctx, tracing.TxReleaseSavePoint) span.SetAttributes(attribute.KeyValue{Key: "db", Value: attribute.StringValue(tx.db.name)}) diff --git a/pkg/tracing/constant.go b/pkg/tracing/constant.go index 2180ecd..c0e3efa 100644 --- a/pkg/tracing/constant.go +++ b/pkg/tracing/constant.go @@ -66,19 +66,21 @@ const ( DBExecStmt = "db_exec_stmt" DBExecFieldList = "db_exec_field_list" DBLocalTransactionBegin = "db_tx_begin" + DBXAStart = "db_xa_start" // group GroupQuery = "group_query" GroupExecute = "group_execute" - GroupTransactionBegin = "group_transaction_begin" + GroupTransactionBegin = "group_tx_begin" // tx TxQuery = "tx_query" TxExecSQL = "tx_exec_sql" TxExecStmt = "tx_exec_stmt" - TxCommit = "db_local_transaction_commit" - TxRollback = "db_local_transaction_rollback" - TxReleaseSavePoint = "db_local_transaction_release_savepoint" + TxCommit = "db_tx_commit" + TxRollback = "db_tx_rollback" + TxReleaseSavePoint = "db_tx_release_savepoint" + TxXAPrepare = "db_xa_prepare" // group tx GroupTxQuery = "group_tx_query" diff --git a/test/rws/read_write_splitting_test.go b/test/rws/read_write_splitting_test.go index d3103e8..9fe3f6c 100644 --- a/test/rws/read_write_splitting_test.go +++ b/test/rws/read_write_splitting_test.go @@ -17,11 +17,13 @@ package rws import ( + "context" "database/sql" "testing" "time" _ "github.com/go-sql-driver/mysql" // register mysql + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -79,6 +81,14 @@ func (suite *_ReadWriteSplittingSuite) SetupSuite() { suite.Equal(int64(1), affected) } } + + result, err = masterDB.Exec(insertEmployee, 100005, "1992-05-03", "jane", "lewis", "M", "2014-09-01") + if suite.NoErrorf(err, "insert row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "insert row error: %v", err) { + suite.Equal(int64(1), affected) + } + } } } @@ -230,5 +240,26 @@ func (suite *_ReadWriteSplittingSuite) TestUpdateEncryption() { } } +func (suite *_ReadWriteSplittingSuite) TestXATransaction() { + ctx := context.Background() + conn, err := suite.db.Conn(ctx) + assert.Nil(suite.T(), err) + _, err = conn.ExecContext(ctx, "XA START 'abc'") + assert.Nil(suite.T(), err) + result, err := conn.ExecContext(ctx, deleteEmployee, 100005) + if suite.NoErrorf(err, "delete row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "delete row error: %v", err) { + suite.Equal(int64(1), affected) + } + } + _, err = conn.ExecContext(ctx, "XA END 'abc'") + assert.Nil(suite.T(), err) + _, err = conn.ExecContext(ctx, "XA PREPARE 'abc'") + assert.Nil(suite.T(), err) + _, err = conn.ExecContext(ctx, "XA COMMIT 'abc'") + assert.Nil(suite.T(), err) +} + func (suite *_ReadWriteSplittingSuite) TearDownSuite() { } diff --git a/test/sdb/crud_test.go b/test/sdb/crud_test.go index fa47aa8..7ec1b0b 100644 --- a/test/sdb/crud_test.go +++ b/test/sdb/crud_test.go @@ -17,11 +17,13 @@ package sdb import ( + "context" "database/sql" "testing" "time" _ "github.com/go-sql-driver/mysql" // register mysql + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -62,6 +64,13 @@ func (suite *_CRUDSuite) SetupSuite() { suite.Equal(int64(1), affected) } } + result, err = suite.db.Exec(insertEmployee, 100005, "1992-05-03", "jane", "lewis", "M", "2014-09-01") + if suite.NoErrorf(err, "insert row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "insert row error: %v", err) { + suite.Equal(int64(1), affected) + } + } } func (suite *_CRUDSuite) TestDelete() { @@ -157,6 +166,27 @@ func (suite *_CRUDSuite) TestUpdateEncryption() { } } +func (suite *_CRUDSuite) TestXATransaction() { + ctx := context.Background() + conn, err := suite.db.Conn(ctx) + assert.Nil(suite.T(), err) + _, err = conn.ExecContext(ctx, "XA START 'abc'") + assert.Nil(suite.T(), err) + result, err := conn.ExecContext(ctx, deleteEmployee, 100005) + if suite.NoErrorf(err, "delete row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "delete row error: %v", err) { + suite.Equal(int64(1), affected) + } + } + _, err = conn.ExecContext(ctx, "XA END 'abc'") + assert.Nil(suite.T(), err) + _, err = conn.ExecContext(ctx, "XA PREPARE 'abc'") + assert.Nil(suite.T(), err) + _, err = conn.ExecContext(ctx, "XA COMMIT 'abc'") + assert.Nil(suite.T(), err) +} + func (suite *_CRUDSuite) TearDownSuite() { result, err := suite.db.Exec(deleteEmployee, 100001) if suite.NoErrorf(err, "delete row error: %v", err) { diff --git a/testdata/mock_db.go b/testdata/mock_db.go index d2def84..8f6e517 100644 --- a/testdata/mock_db.go +++ b/testdata/mock_db.go @@ -539,3 +539,19 @@ func (mr *MockDBMockRecorder) WriteWeight() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteWeight", reflect.TypeOf((*MockDB)(nil).WriteWeight)) } + +// XAStart mocks base method. +func (m *MockDB) XAStart(arg0 context.Context, arg1 string) (proto.Tx, proto.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "XAStart", arg0, arg1) + ret0, _ := ret[0].(proto.Tx) + ret1, _ := ret[1].(proto.Result) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// XAStart indicates an expected call of XAStart. +func (mr *MockDBMockRecorder) XAStart(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XAStart", reflect.TypeOf((*MockDB)(nil).XAStart), arg0, arg1) +} diff --git a/testdata/mock_tx.go b/testdata/mock_tx.go index 9fc07c2..619c99d 100644 --- a/testdata/mock_tx.go +++ b/testdata/mock_tx.go @@ -187,3 +187,18 @@ func (mr *MockTxMockRecorder) Rollback(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTx)(nil).Rollback), arg0, arg1) } + +// XAPrepare mocks base method. +func (m *MockTx) XAPrepare(arg0 context.Context, arg1 string) (proto.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "XAPrepare", arg0, arg1) + ret0, _ := ret[0].(proto.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// XAPrepare indicates an expected call of XAPrepare. +func (mr *MockTxMockRecorder) XAPrepare(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XAPrepare", reflect.TypeOf((*MockTx)(nil).XAPrepare), arg0, arg1) +}