Skip to content

Commit

Permalink
Move almost all functions' parameter db.Engine to context.Context (#1…
Browse files Browse the repository at this point in the history
…9748)

* Move almost all functions' parameter db.Engine to context.Context
* remove some unnecessary wrap functions
  • Loading branch information
lunny authored May 20, 2022
1 parent d81e31a commit fd7d83a
Show file tree
Hide file tree
Showing 232 changed files with 1,448 additions and 2,093 deletions.
6 changes: 3 additions & 3 deletions cmd/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func runChangePassword(c *cli.Context) error {
return errors.New("The password you chose is on a list of stolen passwords previously exposed in public data breaches. Please try again with a different password.\nFor more details, see https://haveibeenpwned.com/Passwords")
}
uname := c.String("username")
user, err := user_model.GetUserByName(uname)
user, err := user_model.GetUserByName(ctx, uname)
if err != nil {
return err
}
Expand Down Expand Up @@ -659,7 +659,7 @@ func runDeleteUser(c *cli.Context) error {
if c.IsSet("email") {
user, err = user_model.GetUserByEmail(c.String("email"))
} else if c.IsSet("username") {
user, err = user_model.GetUserByName(c.String("username"))
user, err = user_model.GetUserByName(ctx, c.String("username"))
} else {
user, err = user_model.GetUserByID(c.Int64("id"))
}
Expand Down Expand Up @@ -689,7 +689,7 @@ func runGenerateAccessToken(c *cli.Context) error {
return err
}

user, err := user_model.GetUserByName(c.String("username"))
user, err := user_model.GetUserByName(ctx, c.String("username"))
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions integrations/api_issue_tracked_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestAPIGetTrackedTimes(t *testing.T) {
resp := session.MakeRequest(t, req, http.StatusOK)
var apiTimes api.TrackedTimeList
DecodeJSON(t, resp, &apiTimes)
expect, err := models.GetTrackedTimes(&models.FindTrackedTimesOptions{IssueID: issue2.ID})
expect, err := models.GetTrackedTimes(db.DefaultContext, &models.FindTrackedTimesOptions{IssueID: issue2.ID})
assert.NoError(t, err)
assert.Len(t, apiTimes, 3)

Expand Down Expand Up @@ -83,15 +83,15 @@ func TestAPIDeleteTrackedTime(t *testing.T) {
session.MakeRequest(t, req, http.StatusNotFound)

// Reset time of user 2 on issue 2
trackedSeconds, err := models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
trackedSeconds, err := models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(3661), trackedSeconds)

req = NewRequestf(t, "DELETE", "/api/v1/repos/%s/%s/issues/%d/times?token=%s", user2.Name, issue2.Repo.Name, issue2.Index, token)
session.MakeRequest(t, req, http.StatusNoContent)
session.MakeRequest(t, req, http.StatusNotFound)

trackedSeconds, err = models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
trackedSeconds, err = models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(0), trackedSeconds)
}
Expand Down
2 changes: 1 addition & 1 deletion integrations/api_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func testAPIRepoMigrateConflict(t *testing.T, u *url.URL) {
defer util.RemoveAll(dstPath)
t.Run("CreateRepo", doAPICreateRepository(httpContext, false))

user, err := user_model.GetUserByName(httpContext.Username)
user, err := user_model.GetUserByName(db.DefaultContext, httpContext.Username)
assert.NoError(t, err)
userID := user.ID

Expand Down
4 changes: 2 additions & 2 deletions integrations/auth_ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func TestLDAPGroupTeamSyncAddMember(t *testing.T) {
addAuthSourceLDAP(t, "", "on", `{"cn=ship_crew,ou=people,dc=planetexpress,dc=com":{"org26": ["team11"]},"cn=admin_staff,ou=people,dc=planetexpress,dc=com": {"non-existent": ["non-existent"]}}`)
org, err := organization.GetOrgByName("org26")
assert.NoError(t, err)
team, err := organization.GetTeam(org.ID, "team11")
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
assert.NoError(t, err)
auth.SyncExternalUsers(context.Background(), true)
for _, gitLDAPUser := range gitLDAPUsers {
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestLDAPGroupTeamSyncRemoveMember(t *testing.T) {
addAuthSourceLDAP(t, "", "on", `{"cn=dispatch,ou=people,dc=planetexpress,dc=com": {"org26": ["team11"]}}`)
org, err := organization.GetOrgByName("org26")
assert.NoError(t, err)
team, err := organization.GetTeam(org.ID, "team11")
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
assert.NoError(t, err)
loginUserWithPassword(t, gitLDAPUsers[0].UserName, gitLDAPUsers[0].Password)
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{
Expand Down
3 changes: 2 additions & 1 deletion integrations/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
Expand Down Expand Up @@ -438,7 +439,7 @@ func doProtectBranch(ctx APITestContext, branch, userToWhitelist, unprotectedFil
})
ctx.Session.MakeRequest(t, req, http.StatusSeeOther)
} else {
user, err := user_model.GetUserByName(userToWhitelist)
user, err := user_model.GetUserByName(db.DefaultContext, userToWhitelist)
assert.NoError(t, err)
// Change branch to protected
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings/branches/%s", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame), url.PathEscape(branch)), map[string]string{
Expand Down
2 changes: 1 addition & 1 deletion integrations/mirror_pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestMirrorPull(t *testing.T) {
IsTag: true,
}, nil, ""))

_, err = repo_model.GetMirrorByRepoID(mirror.ID)
_, err = repo_model.GetMirrorByRepoID(ctx, mirror.ID)
assert.NoError(t, err)

ok := mirror_service.SyncPullMirror(ctx, mirror.ID)
Expand Down
3 changes: 2 additions & 1 deletion integrations/pull_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
Expand Down Expand Up @@ -407,7 +408,7 @@ func TestConflictChecking(t *testing.T) {
assert.NoError(t, err)

issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "PR with conflict!"}).(*models.Issue)
conflictingPR, err := models.GetPullRequestByIssueID(issue.ID)
conflictingPR, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
assert.NoError(t, err)

// Ensure conflictedFiles is populated.
Expand Down
3 changes: 2 additions & 1 deletion integrations/pull_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
Expand Down Expand Up @@ -165,7 +166,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *models.Pul
assert.NoError(t, err)

issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "Test Pull -to-update-"}).(*models.Issue)
pr, err := models.GetPullRequestByIssueID(issue.ID)
pr, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
assert.NoError(t, err)

return pr
Expand Down
12 changes: 5 additions & 7 deletions models/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,8 @@ func (a *Action) getCommentLink(ctx context.Context) string {
if a == nil {
return "#"
}
e := db.GetEngine(ctx)
if a.Comment == nil && a.CommentID != 0 {
a.Comment, _ = getCommentByID(e, a.CommentID)
a.Comment, _ = GetCommentByID(ctx, a.CommentID)
}
if a.Comment != nil {
return a.Comment.HTMLURL()
Expand All @@ -239,7 +238,7 @@ func (a *Action) getCommentLink(ctx context.Context) string {
return "#"
}

issue, err := getIssueByID(e, issueID)
issue, err := getIssueByID(ctx, issueID)
if err != nil {
return "#"
}
Expand Down Expand Up @@ -340,8 +339,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
return nil, err
}

e := db.GetEngine(ctx)
sess := e.Where(cond).
sess := db.GetEngine(ctx).Where(cond).
Select("`action`.*"). // this line will avoid select other joined table's columns
Join("INNER", "repository", "`repository`.id = `action`.repo_id")

Expand All @@ -354,7 +352,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
return nil, fmt.Errorf("Find: %v", err)
}

if err := ActionList(actions).loadAttributes(e); err != nil {
if err := ActionList(actions).loadAttributes(ctx); err != nil {
return nil, fmt.Errorf("LoadAttributes: %v", err)
}

Expand Down Expand Up @@ -504,7 +502,7 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error {
permIssue = make([]bool, len(watchers))
permPR = make([]bool, len(watchers))
for i, watcher := range watchers {
user, err := user_model.GetUserByIDEngine(e, watcher.UserID)
user, err := user_model.GetUserByIDCtx(ctx, watcher.UserID)
if err != nil {
permCode[i] = false
permIssue[i] = false
Expand Down
21 changes: 11 additions & 10 deletions models/action_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package models

import (
"context"
"fmt"

"code.gitea.io/gitea/models/db"
Expand All @@ -26,14 +27,14 @@ func (actions ActionList) getUserIDs() []int64 {
return container.KeysInt64(userIDs)
}

func (actions ActionList) loadUsers(e db.Engine) (map[int64]*user_model.User, error) {
func (actions ActionList) loadUsers(ctx context.Context) (map[int64]*user_model.User, error) {
if len(actions) == 0 {
return nil, nil
}

userIDs := actions.getUserIDs()
userMaps := make(map[int64]*user_model.User, len(userIDs))
err := e.
err := db.GetEngine(ctx).
In("id", userIDs).
Find(&userMaps)
if err != nil {
Expand All @@ -56,14 +57,14 @@ func (actions ActionList) getRepoIDs() []int64 {
return container.KeysInt64(repoIDs)
}

func (actions ActionList) loadRepositories(e db.Engine) error {
func (actions ActionList) loadRepositories(ctx context.Context) error {
if len(actions) == 0 {
return nil
}

repoIDs := actions.getRepoIDs()
repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs))
err := e.In("id", repoIDs).Find(&repoMaps)
err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps)
if err != nil {
return fmt.Errorf("find repository: %v", err)
}
Expand All @@ -74,7 +75,7 @@ func (actions ActionList) loadRepositories(e db.Engine) error {
return nil
}

func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_model.User) (err error) {
func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) {
if userMap == nil {
userMap = make(map[int64]*user_model.User)
}
Expand All @@ -85,7 +86,7 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
}
repoOwner, ok := userMap[action.Repo.OwnerID]
if !ok {
repoOwner, err = user_model.GetUserByID(action.Repo.OwnerID)
repoOwner, err = user_model.GetUserByIDCtx(ctx, action.Repo.OwnerID)
if err != nil {
if user_model.IsErrUserNotExist(err) {
continue
Expand All @@ -101,15 +102,15 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
}

// loadAttributes loads all attributes
func (actions ActionList) loadAttributes(e db.Engine) error {
userMap, err := actions.loadUsers(e)
func (actions ActionList) loadAttributes(ctx context.Context) error {
userMap, err := actions.loadUsers(ctx)
if err != nil {
return err
}

if err := actions.loadRepositories(e); err != nil {
if err := actions.loadRepositories(ctx); err != nil {
return err
}

return actions.loadRepoOwner(e, userMap)
return actions.loadRepoOwner(ctx, userMap)
}
8 changes: 4 additions & 4 deletions models/asymkey/gpg_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,16 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro
}

// deleteGPGKey does the actual key deletion
func deleteGPGKey(e db.Engine, keyID string) (int64, error) {
func deleteGPGKey(ctx context.Context, keyID string) (int64, error) {
if keyID == "" {
return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure
}
// Delete imported key
n, err := e.Where("key_id=?", keyID).Delete(new(GPGKeyImport))
n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport))
if err != nil {
return n, err
}
return e.Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
}

// DeleteGPGKey deletes GPG key information in database.
Expand All @@ -231,7 +231,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
}
defer committer.Close()

if _, err = deleteGPGKey(db.GetEngine(ctx), key.KeyID); err != nil {
if _, err = deleteGPGKey(ctx, key.KeyID); err != nil {
return err
}

Expand Down
17 changes: 9 additions & 8 deletions models/asymkey/gpg_key_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package asymkey

import (
"context"
"strings"

"code.gitea.io/gitea/models/db"
Expand All @@ -29,36 +30,36 @@ import (
// This file contains functions relating to adding GPG Keys

// addGPGKey add key, import and subkeys to database
func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) {
// Add GPGKeyImport
if _, err = e.Insert(GPGKeyImport{
if err = db.Insert(ctx, &GPGKeyImport{
KeyID: key.KeyID,
Content: content,
}); err != nil {
return err
}
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subs key.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(e, subkey); err != nil {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
return nil
}

// addGPGSubKey add subkeys to database
func addGPGSubKey(e db.Engine, key *GPGKey) (err error) {
func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) {
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subs key.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(e, subkey); err != nil {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
Expand Down Expand Up @@ -158,7 +159,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}

if err = addGPGKey(db.GetEngine(ctx), key, content); err != nil {
if err = addGPGKey(ctx, key, content); err != nil {
return nil, err
}
keys = append(keys, key)
Expand Down
Loading

0 comments on commit fd7d83a

Please sign in to comment.