Skip to content

Commit

Permalink
feat: add mutex lock to Trasaction (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
MuZhou233 authored Mar 31, 2024
1 parent c442abe commit 4e7c05a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 5 additions & 0 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"runtime"
"strings"
"sync"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
Expand Down Expand Up @@ -81,6 +82,7 @@ type Adapter struct {
dbSpecified bool
db *gorm.DB
isFiltered bool
transactionMu *sync.Mutex
}

// finalizer is the destructor for Adapter.
Expand Down Expand Up @@ -134,6 +136,7 @@ func NewAdapter(driverName string, dataSourceName string, params ...interface{})
a.tableName = defaultTableName
a.databaseName = defaultDatabaseName
a.dbSpecified = false
a.transactionMu = &sync.Mutex{}

if len(params) == 1 {
switch p1 := params[0].(type) {
Expand Down Expand Up @@ -665,6 +668,8 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error

// Transaction perform a set of operations within a transaction
func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) error, opts ...*sql.TxOptions) error {
a.transactionMu.Lock()
defer a.transactionMu.Unlock()
var err error
oriAdapter := a.db
// reload policy from database to sync with the transaction
Expand Down
6 changes: 5 additions & 1 deletion adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ func TestTransactionRace(t *testing.T) {
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)

concurrency := 10
concurrency := 100

var g errgroup.Group
for i := 0; i < concurrency; i++ {
Expand All @@ -721,4 +721,8 @@ func TestTransactionRace(t *testing.T) {
})
}
require.NoError(t, g.Wait())

for i := 0; i < concurrency; i++ {
require.True(t, e.HasPolicy("jack", fmt.Sprintf("data%d", i), "write"))
}
}

0 comments on commit 4e7c05a

Please sign in to comment.