Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More refactoring of db.DefaultContext #27083

Merged
merged 4 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/admin_user_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func runCreateUser(c *cli.Context) error {
UID: u.ID,
}

if err := auth_model.NewAccessToken(t); err != nil {
if err := auth_model.NewAccessToken(ctx, t); err != nil {
return err
}

Expand Down
4 changes: 2 additions & 2 deletions cmd/admin_user_generate_access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func runGenerateAccessToken(c *cli.Context) error {
UID: user.ID,
}

exist, err := auth_model.AccessTokenByNameExists(t)
exist, err := auth_model.AccessTokenByNameExists(ctx, t)
if err != nil {
return err
}
Expand All @@ -79,7 +79,7 @@ func runGenerateAccessToken(c *cli.Context) error {
t.Scope = accessTokenScope

// create the token
if err := auth_model.NewAccessToken(t); err != nil {
if err := auth_model.NewAccessToken(ctx, t); err != nil {
return err
}

Expand Down
31 changes: 16 additions & 15 deletions models/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package auth

import (
"context"
"crypto/subtle"
"encoding/hex"
"fmt"
Expand Down Expand Up @@ -95,7 +96,7 @@ func init() {
}

// NewAccessToken creates new access token.
func NewAccessToken(t *AccessToken) error {
func NewAccessToken(ctx context.Context, t *AccessToken) error {
salt, err := util.CryptoRandomString(10)
if err != nil {
return err
Expand All @@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error {
t.Token = hex.EncodeToString(token)
t.TokenHash = HashToken(t.Token, t.TokenSalt)
t.TokenLastEight = t.Token[len(t.Token)-8:]
_, err = db.GetEngine(db.DefaultContext).Insert(t)
_, err = db.GetEngine(ctx).Insert(t)
return err
}

Expand Down Expand Up @@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 {
}

// GetAccessTokenBySHA returns access token by given token value
func GetAccessTokenBySHA(token string) (*AccessToken, error) {
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
if token == "" {
return nil, ErrAccessTokenEmpty{}
}
Expand All @@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
TokenLastEight: lastEight,
}
// Re-get the token from the db in case it has been deleted in the intervening period
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken)
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
if err != nil {
return nil, err
}
Expand All @@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}

var tokens []AccessToken
err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
if err != nil {
return nil, err
} else if len(tokens) == 0 {
Expand All @@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}

// AccessTokenByNameExists checks if a token name has been used already by a user.
func AccessTokenByNameExists(token *AccessToken) (bool, error) {
return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
}

// ListAccessTokensOptions contain filter options
Expand All @@ -201,8 +202,8 @@ type ListAccessTokensOptions struct {
}

// ListAccessTokens returns a list of access tokens belongs to given user.
func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) {
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)

if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
Expand All @@ -222,23 +223,23 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
}

// UpdateAccessToken updates information of access token.
func UpdateAccessToken(t *AccessToken) error {
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}

// CountAccessTokens count access tokens belongs to given user by options
func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) {
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
}
return sess.Count(&AccessToken{})
}

// DeleteAccessTokenByID deletes access token by given ID.
func DeleteAccessTokenByID(id, userID int64) error {
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
UID: userID,
})
if err != nil {
Expand Down
35 changes: 18 additions & 17 deletions models/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
Expand All @@ -18,15 +19,15 @@ func TestNewAccessToken(t *testing.T) {
UID: 3,
Name: "Token C",
}
assert.NoError(t, auth_model.NewAccessToken(token))
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)

invalidToken := &auth_model.AccessToken{
ID: token.ID, // duplicate
UID: 2,
Name: "Token F",
}
assert.Error(t, auth_model.NewAccessToken(invalidToken))
assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
}

func TestAccessTokenByNameExists(t *testing.T) {
Expand All @@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) {
}

// Check to make sure it doesn't exists already
exist, err := auth_model.AccessTokenByNameExists(token)
exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.False(t, exist)

// Save it to the database
assert.NoError(t, auth_model.NewAccessToken(token))
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)

// This token must be found by name in the DB now
exist, err = auth_model.AccessTokenByNameExists(token)
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.True(t, exist)

Expand All @@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) {

// Name matches but different user ID, this shouldn't exists in the
// database
exist, err = auth_model.AccessTokenByNameExists(user4Token)
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
assert.NoError(t, err)
assert.False(t, exist)
}

func TestGetAccessTokenBySHA(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.Equal(t, "Token A", token.Name)
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
assert.Equal(t, "e4efbf36", token.TokenLastEight)

_, err = auth_model.GetAccessTokenBySHA("notahash")
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))

_, err = auth_model.GetAccessTokenBySHA("")
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
}

func TestListAccessTokens(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1})
tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
assert.NoError(t, err)
if assert.Len(t, tokens, 2) {
assert.Equal(t, int64(1), tokens[0].UID)
Expand All @@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) {
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
}

tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2})
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
assert.NoError(t, err)
if assert.Len(t, tokens, 1) {
assert.Equal(t, int64(2), tokens[0].UID)
assert.Equal(t, "Token A", tokens[0].Name)
}

tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100})
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
assert.NoError(t, err)
assert.Empty(t, tokens)
}

func TestUpdateAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
token.Name = "Token Z"

assert.NoError(t, auth_model.UpdateAccessToken(token))
assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
}

func TestDeleteAccessTokenByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())

token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)

assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1))
assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
unittest.AssertNotExistsBean(t, token)

err = auth_model.DeleteAccessTokenByID(100, 100)
err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
}
21 changes: 11 additions & 10 deletions models/auth/twofactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package auth

import (
"context"
"crypto/md5"
"crypto/subtle"
"encoding/base32"
Expand Down Expand Up @@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
}

// NewTwoFactor creates a new two-factor authentication token.
func NewTwoFactor(t *TwoFactor) error {
_, err := db.GetEngine(db.DefaultContext).Insert(t)
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}

// UpdateTwoFactor updates a two-factor authentication token.
func UpdateTwoFactor(t *TwoFactor) error {
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}

// GetTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
twofa := &TwoFactor{}
has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
if err != nil {
return nil, err
} else if !has {
Expand All @@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {

// HasTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func HasTwoFactorByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
}

// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
func DeleteTwoFactorByID(id, userID int64) error {
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
UID: userID,
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions models/issues/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,12 @@ func (c *Comment) LoadPoster(ctx context.Context) (err error) {
}

// AfterDelete is invoked from XORM after the object is deleted.
func (c *Comment) AfterDelete() {
func (c *Comment) AfterDelete(ctx context.Context) {
if c.ID <= 0 {
return
}

_, err := repo_model.DeleteAttachmentsByComment(c.ID, true)
_, err := repo_model.DeleteAttachmentsByComment(ctx, c.ID, true)
if err != nil {
log.Info("Could not delete files for comment %d on issue #%d: %s", c.ID, c.IssueID, err)
}
Expand Down
14 changes: 7 additions & 7 deletions models/issues/pull_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ type PullRequestsOptions struct {
MilestoneID int64
}

func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", baseRepoID)
func listPullRequestStatement(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", baseRepoID)

sess.Join("INNER", "issue", "pull_request.issue_id = issue.id")
switch opts.State {
Expand Down Expand Up @@ -115,21 +115,21 @@ func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch
}

// GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status.
func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) {
func GetPullRequestIDsByCheckStatus(ctx context.Context, status PullRequestStatus) ([]int64, error) {
prs := make([]int64, 0, 10)
return prs, db.GetEngine(db.DefaultContext).Table("pull_request").
return prs, db.GetEngine(ctx).Table("pull_request").
Where("status=?", status).
Cols("pull_request.id").
Find(&prs)
}

// PullRequests returns all pull requests for a base Repo by the given conditions
func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
func PullRequests(ctx context.Context, baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
if opts.Page <= 0 {
opts.Page = 1
}

countSession, err := listPullRequestStatement(baseRepoID, opts)
countSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
if err != nil {
log.Error("listPullRequestStatement: %v", err)
return nil, 0, err
Expand All @@ -140,7 +140,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
return nil, maxResults, err
}

findSession, err := listPullRequestStatement(baseRepoID, opts)
findSession, err := listPullRequestStatement(ctx, baseRepoID, opts)
applySorts(findSession, opts.SortType, 0)
if err != nil {
log.Error("listPullRequestStatement: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions models/issues/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestPullRequest_LoadHeadRepo(t *testing.T) {

func TestPullRequestsNewest(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
ListOptions: db.ListOptions{
Page: 1,
},
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestLoadRequestedReviewers(t *testing.T) {

func TestPullRequestsOldest(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
prs, count, err := issues_model.PullRequests(1, &issues_model.PullRequestsOptions{
prs, count, err := issues_model.PullRequests(db.DefaultContext, 1, &issues_model.PullRequestsOptions{
ListOptions: db.ListOptions{
Page: 1,
},
Expand Down
Loading