diff --git a/pgxmock_test.go b/pgxmock_test.go index 63517ba..7ac6033 100644 --- a/pgxmock_test.go +++ b/pgxmock_test.go @@ -215,6 +215,38 @@ func TestTransactionExpectations(t *testing.T) { t.Errorf("an error '%s' was not expected when committing a transaction", err) } + // beginTx and commit + mock.ExpectBeginTx(pgx.TxOptions{}) + mock.ExpectCommit() + + tx, err = mock.BeginTx(context.Background(), pgx.TxOptions{}) + if err != nil { + t.Errorf("an error '%s' was not expected when beginning a transaction", err) + } + + err = tx.Commit(context.Background()) + if err != nil { + t.Errorf("an error '%s' was not expected when committing a transaction", err) + } + + // beginTxFunc and commit + mock.ExpectBeginTx(pgx.TxOptions{}) + mock.ExpectCommit() + + err = mock.BeginFunc(context.Background(), func(tx pgx.Tx) error { return nil }) + if err != nil { + t.Errorf("an error '%s' was not expected when beginning a transaction", err) + } + + // beginTxFunc and rollback + mock.ExpectBeginTx(pgx.TxOptions{}) + mock.ExpectRollback() + + err = mock.BeginFunc(context.Background(), func(tx pgx.Tx) error { return errors.New("smth wrong") }) + if err == nil { + t.Error("an error was expected whithin a transaction, but got none") + } + // begin and rollback mock.ExpectBegin() mock.ExpectRollback()