Skip to content

Commit

Permalink
fix: rollback when global session timeout (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Jun 11, 2022
1 parent 48a8f1a commit 1798cd1
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 30 deletions.
11 changes: 5 additions & 6 deletions pkg/dt/distributed_transaction_manger.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,11 @@ func (manager *DistributedTransactionManager) processNextGlobalSession(ctx conte
return true
}
if newGlobalSession.Status == api.Begin {
if isGlobalSessionTimeout(newGlobalSession) {
_, err = manager.Rollback(context.Background(), newGlobalSession.XID)
if err != nil {
log.Error(err)
}
_, err = manager.Rollback(context.Background(), newGlobalSession.XID)
if err != nil {
log.Error(err)
}
manager.globalSessionQueue.AddAfter(gs, time.Duration(gs.Timeout)*time.Millisecond)
}
if newGlobalSession.Status == api.Committing || newGlobalSession.Status == api.Rollbacking {
bsKeys, err := manager.storageDriver.GetBranchSessionKeys(context.Background(), newGlobalSession.XID)
Expand Down Expand Up @@ -469,7 +468,7 @@ func (manager *DistributedTransactionManager) recordGlobalTransactionMetric(tran
}

func isGlobalSessionTimeout(gs *api.GlobalSession) bool {
return (misc.CurrentTimeMillis() - uint64(gs.BeginTime)) > uint64(gs.Timeout)
return misc.CurrentTimeMillis()-uint64(gs.BeginTime) > uint64(gs.Timeout)
}

func (manager *DistributedTransactionManager) IsRollingBackDead(bs *api.BranchSession) bool {
Expand Down
57 changes: 33 additions & 24 deletions pkg/dt/storage/etcd/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,46 +101,55 @@ func (s *store) AddGlobalSession(ctx context.Context, globalSession *api.GlobalS
}

func (s *store) AddBranchSession(ctx context.Context, branchSession *api.BranchSession) error {
data, err := branchSession.Marshal()
if err != nil {
return err
}

gs, err := s.GetGlobalSession(ctx, branchSession.XID)
if err != nil {
if errors.Is(err, err2.CouldNotFoundGlobalTransaction) {
return err2.GlobalTransactionFinished
}
return err
}
if gs.Status > api.Begin {
return err2.GlobalTransactionFinished
}

txn := s.client.Txn(ctx)
ops := make([]clientv3.Op, 0)
ops = append(ops, clientv3.OpPut(branchSession.BranchID, string(data)))
// 全局事务关联的事务分支
globalBranchKey := fmt.Sprintf("bs/%s/%d", branchSession.XID, branchSession.BranchSessionID)
ops = append(ops, clientv3.OpPut(globalBranchKey, branchSession.BranchID))

if branchSession.Type == api.AT && branchSession.LockKey != "" {
rowKeys := misc.CollectRowKeys(branchSession.LockKey, branchSession.ResourceID)

txn := s.client.Txn(ctx)
var cmpSlice []clientv3.Cmp
for _, rowKey := range rowKeys {
cmpSlice = append(cmpSlice, notFound(rowKey))
}
txn = txn.If(cmpSlice...)

ops := make([]clientv3.Op, 0, 2*len(rowKeys))
for _, rowKey := range rowKeys {
lockKey := fmt.Sprintf("lk/%s/%s", branchSession.XID, rowKey)
ops = append(ops, clientv3.OpPut(lockKey, rowKey))
ops = append(ops, clientv3.OpPut(rowKey, lockKey))
}
txn.Then(ops...)

txnResp, err := txn.Commit()
if err != nil {
return err
}
if !txnResp.Succeeded {
return err2.BranchLockAcquireFailed
}
}

data, err := branchSession.Marshal()
txn.Then(ops...)

txnResp, err := txn.Commit()
if err != nil {
return err
}
_, err = s.client.Put(ctx, branchSession.BranchID, string(data))
if err != nil {
return err
if !txnResp.Succeeded {
return errors.Errorf("register branch session failed, xid: %s, resource id: %s", branchSession.XID, branchSession.ResourceID)
}

// 全局事务关联的事务分支
globalBranchKey := fmt.Sprintf("bs/%s/%d", branchSession.XID, branchSession.BranchSessionID)
_, err = s.client.Put(ctx, globalBranchKey, branchSession.BranchID)
return err
return nil
}

func (s *store) GlobalCommit(ctx context.Context, xid string) (api.GlobalSession_GlobalStatus, error) {
Expand All @@ -155,12 +164,12 @@ func (s *store) GlobalCommit(ctx context.Context, xid string) (api.GlobalSession
gs, err := s.GetGlobalSession(ctx, xid)
if err != nil {
if errors.Is(err, err2.CouldNotFoundGlobalTransaction) {
return api.Finished, nil
return api.Finished, err2.GlobalTransactionFinished
}
return api.Begin, err
}
if gs.Status > api.Begin {
return gs.Status, nil
return gs.Status, err2.GlobalTransactionFinished
}
gs.Status = api.Committing
data, err := gs.Marshal()
Expand Down Expand Up @@ -197,12 +206,12 @@ func (s *store) GlobalRollback(ctx context.Context, xid string) (api.GlobalSessi
gs, err := s.GetGlobalSession(ctx, xid)
if err != nil {
if errors.Is(err, err2.CouldNotFoundGlobalTransaction) {
return api.Finished, nil
return api.Finished, err2.GlobalTransactionFinished
}
return api.Begin, err
}
if gs.Status > api.Begin {
return gs.Status, nil
return gs.Status, err2.GlobalTransactionFinished
}
gs.Status = api.Rollbacking
data, err := gs.Marshal()
Expand Down
1 change: 1 addition & 0 deletions pkg/errors/dt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import "errors"

var (
GlobalTransactionNotActive = errors.New("global session not active")
GlobalTransactionFinished = errors.New("global session finished")
CouldNotFoundGlobalTransaction = errors.New("could not found global transaction")
CouldNotFoundBranchTransaction = errors.New("could not found branch transaction")
BranchLockAcquireFailed = errors.New("branch lock acquire failed")
Expand Down
5 changes: 5 additions & 0 deletions pkg/listener/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ func (l *HttpListener) Listen() {
}
if err := l.doPostFilter(ctx); err != nil {
log.Error(err)
ctx.Response.Reset()
ctx.SetStatusCode(500)
if _, err = ctx.WriteString(fmt.Sprintf(`{"success":false,"message":"%s"}`, err.Error())); err != nil {
log.Error(err)
}
}
}); err != nil {
log.Error(err)
Expand Down

0 comments on commit 1798cd1

Please sign in to comment.