From 384f92027ec0769d876e9eabbf054798d1a85afe Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Mon, 4 Dec 2023 14:58:54 +0700 Subject: [PATCH] get sql state from interface (#38) --- error.go | 15 ++++++++------- tx.go | 5 +---- tx_test.go | 2 -- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/error.go b/error.go index 274c166..3d8ab98 100644 --- a/error.go +++ b/error.go @@ -16,13 +16,14 @@ func contains(xs []string, x string) bool { return false } +type sqlState interface { + SQLState() string +} + // IsErrorCode checks is error has given code func IsErrorCode(err error, code string) bool { - var pqErr *pq.Error - if errors.As(err, &pqErr) && string(pqErr.Code) == code { - return true - } - return false + sErr, ok := err.(sqlState) + return ok && sErr.SQLState() == code } // IsErrorClass checks is error has given class @@ -52,7 +53,7 @@ func IsInvalidTextRepresentation(err error) bool { return IsErrorCode(err, "22P02") } -// IsCharacterNotInRepertoire checks is error an character_not_in_repertoire +// IsCharacterNotInRepertoire checks is error a character_not_in_repertoire func IsCharacterNotInRepertoire(err error) bool { return IsErrorCode(err, "22021") } @@ -75,7 +76,7 @@ func IsQueryCanceled(err error) bool { return IsErrorCode(err, "57014") } -// IsSerializationFailure checks is error an serialization_failure error +// IsSerializationFailure checks is error a serialization_failure error // (pq: could not serialize access due to read/write dependencies among transactions) func IsSerializationFailure(err error) bool { return IsErrorCode(err, "40001") diff --git a/tx.go b/tx.go index 9827828..e7615ef 100644 --- a/tx.go +++ b/tx.go @@ -4,8 +4,6 @@ import ( "context" "database/sql" "errors" - - "github.com/lib/pq" ) // ErrAbortTx rollbacks transaction and return nil error @@ -79,8 +77,7 @@ func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func( if err == nil || errors.Is(err, ErrAbortTx) { return nil } - var pqErr *pq.Error - if retryable := errors.As(err, &pqErr) && (pqErr.Code == "40001"); !retryable { + if !IsSerializationFailure(err) { return err } } diff --git a/tx_test.go b/tx_test.go index e74d200..be6536c 100644 --- a/tx_test.go +++ b/tx_test.go @@ -7,7 +7,6 @@ import ( "math/rand" "sync" "testing" - "time" "github.com/acoshift/pgsql" ) @@ -118,7 +117,6 @@ func TestTx(t *testing.T) { } wg := sync.WaitGroup{} - rand.Seed(time.Now().Unix()) for i := 0; i < 1000; i++ { wg.Add(1) go func() {