Skip to content

Commit

Permalink
chore(db): move db transactions into own file (#1703)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven-urbanski-freiheit-com authored Jun 26, 2024
1 parent 4ee0695 commit e0d7741
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 63 deletions.
63 changes: 0 additions & 63 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ func SqliteToPostgresQuery(query string) string {
return q
}

type DBFunction func(ctx context.Context, transaction *sql.Tx) error

func Remove(s []string, r string) []string {
for i, v := range s {
if v == r {
Expand All @@ -211,67 +209,6 @@ func Remove(s []string, r string) []string {
return s
}

// WithTransaction opens a transaction, runs `f` and then calls either Commit or Rollback.
// Use this if the only thing to return from `f` is an error.
func (h *DBHandler) WithTransaction(ctx context.Context, f DBFunction) error {
_, err := WithTransactionT(h, ctx, func(ctx context.Context, transaction *sql.Tx) (*interface{}, error) {
err2 := f(ctx, transaction)
if err2 != nil {
return nil, err2
}
return nil, nil
})
if err != nil {
return err
}
return nil
}

type DBFunctionT[T any] func(ctx context.Context, transaction *sql.Tx) (*T, error)

// WithTransactionT is the same as WithTransaction, but you can also return data, not just the error.
func WithTransactionT[T any](h *DBHandler, ctx context.Context, f DBFunctionT[T]) (*T, error) {
res, err := WithTransactionMultipleEntriesT(h, ctx, func(ctx context.Context, transaction *sql.Tx) ([]T, error) {
fRes, err2 := f(ctx, transaction)
if err2 != nil {
return nil, err2
}
if fRes == nil {
return make([]T, 0), nil
}
return []T{*fRes}, nil
})
if err != nil || len(res) == 0 {
return nil, err
}
return &res[0], err
}

type DBFunctionMultipleEntriesT[T any] func(ctx context.Context, transaction *sql.Tx) ([]T, error)

// WithTransactionMultipleEntriesT is the same as WithTransaction, but you can also return and array of data, not just the error.
func WithTransactionMultipleEntriesT[T any](h *DBHandler, ctx context.Context, f DBFunctionMultipleEntriesT[T]) ([]T, error) {
tx, err := h.DB.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func(tx *sql.Tx) {
_ = tx.Rollback()
// we ignore the error returned from Rollback() here,
// because it is always set when Commit() was successful
}(tx)

result, err := f(ctx, tx)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
return result, nil
}

func closeRows(rows *sql.Rows) error {
err := rows.Close()
if err != nil {
Expand Down
83 changes: 83 additions & 0 deletions pkg/db/transactions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*This file is part of kuberpult.
Kuberpult is free software: you can redistribute it and/or modify
it under the terms of the Expat(MIT) License as published by
the Free Software Foundation.
Kuberpult is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
MIT License for more details.
You should have received a copy of the MIT License
along with kuberpult. If not, see <https://directory.fsf.org/wiki/License:Expat>.
Copyright freiheit.com*/

package db

import (
"context"
"database/sql"
)

type DBFunction func(ctx context.Context, transaction *sql.Tx) error
type DBFunctionT[T any] func(ctx context.Context, transaction *sql.Tx) (*T, error)
type DBFunctionMultipleEntriesT[T any] func(ctx context.Context, transaction *sql.Tx) ([]T, error)

// WithTransaction opens a transaction, runs `f` and then calls either Commit or Rollback.
// Use this if the only thing to return from `f` is an error.
func (h *DBHandler) WithTransaction(ctx context.Context, f DBFunction) error {
_, err := WithTransactionT(h, ctx, func(ctx context.Context, transaction *sql.Tx) (*interface{}, error) {
err2 := f(ctx, transaction)
if err2 != nil {
return nil, err2
}
return nil, nil
})
if err != nil {
return err
}
return nil
}

// WithTransactionT is the same as WithTransaction, but you can also return data, not just the error.
func WithTransactionT[T any](h *DBHandler, ctx context.Context, f DBFunctionT[T]) (*T, error) {
res, err := WithTransactionMultipleEntriesT(h, ctx, func(ctx context.Context, transaction *sql.Tx) ([]T, error) {
fRes, err2 := f(ctx, transaction)
if err2 != nil {
return nil, err2
}
if fRes == nil {
return make([]T, 0), nil
}
return []T{*fRes}, nil
})
if err != nil || len(res) == 0 {
return nil, err
}
return &res[0], err
}

// WithTransactionMultipleEntriesT is the same as WithTransaction, but you can also return and array of data, not just the error.
func WithTransactionMultipleEntriesT[T any](h *DBHandler, ctx context.Context, f DBFunctionMultipleEntriesT[T]) ([]T, error) {
tx, err := h.DB.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func(tx *sql.Tx) {
_ = tx.Rollback()
// we ignore the error returned from Rollback() here,
// because it is always set when Commit() was successful
}(tx)

result, err := f(ctx, tx)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
return result, nil
}

0 comments on commit e0d7741

Please sign in to comment.