Skip to content

Commit

Permalink
Penultimate round of db.DefaultContext refactor (#27414)
Browse files Browse the repository at this point in the history
Part of #27065

---------

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
  • Loading branch information
JakobDev and lunny authored Oct 11, 2023
1 parent 50166d1 commit ebe803e
Show file tree
Hide file tree
Showing 136 changed files with 428 additions and 421 deletions.
4 changes: 2 additions & 2 deletions cmd/admin_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error {
return err
}

authSources, err := auth_model.Sources()
authSources, err := auth_model.Sources(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -100,7 +100,7 @@ func runDeleteAuth(c *cli.Context) error {
return err
}

source, err := auth_model.GetSourceByID(c.Int64("id"))
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil {
return err
}
Expand Down
22 changes: 11 additions & 11 deletions cmd/admin_auth_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
type (
authService struct {
initDB func(ctx context.Context) error
createAuthSource func(*auth.Source) error
updateAuthSource func(*auth.Source) error
getAuthSourceByID func(id int64) (*auth.Source, error)
createAuthSource func(context.Context, *auth.Source) error
updateAuthSource func(context.Context, *auth.Source) error
getAuthSourceByID func(ctx context.Context, id int64) (*auth.Source, error)
}
)

Expand Down Expand Up @@ -289,12 +289,12 @@ func findLdapSecurityProtocolByName(name string) (ldap.SecurityProtocol, bool) {

// getAuthSource gets the login source by its id defined in the command line flags.
// It returns an error if the id is not set, does not match any source or if the source is not of expected type.
func (a *authService) getAuthSource(c *cli.Context, authType auth.Type) (*auth.Source, error) {
func (a *authService) getAuthSource(ctx context.Context, c *cli.Context, authType auth.Type) (*auth.Source, error) {
if err := argsSet(c, "id"); err != nil {
return nil, err
}

authSource, err := a.getAuthSourceByID(c.Int64("id"))
authSource, err := a.getAuthSourceByID(ctx, c.Int64("id"))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -332,7 +332,7 @@ func (a *authService) addLdapBindDn(c *cli.Context) error {
return err
}

return a.createAuthSource(authSource)
return a.createAuthSource(ctx, authSource)
}

// updateLdapBindDn updates a new LDAP via Bind DN authentication source.
Expand All @@ -344,7 +344,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
return err
}

authSource, err := a.getAuthSource(c, auth.LDAP)
authSource, err := a.getAuthSource(ctx, c, auth.LDAP)
if err != nil {
return err
}
Expand All @@ -354,7 +354,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error {
return err
}

return a.updateAuthSource(authSource)
return a.updateAuthSource(ctx, authSource)
}

// addLdapSimpleAuth adds a new LDAP (simple auth) authentication source.
Expand Down Expand Up @@ -383,7 +383,7 @@ func (a *authService) addLdapSimpleAuth(c *cli.Context) error {
return err
}

return a.createAuthSource(authSource)
return a.createAuthSource(ctx, authSource)
}

// updateLdapBindDn updates a new LDAP (simple auth) authentication source.
Expand All @@ -395,7 +395,7 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
return err
}

authSource, err := a.getAuthSource(c, auth.DLDAP)
authSource, err := a.getAuthSource(ctx, c, auth.DLDAP)
if err != nil {
return err
}
Expand All @@ -405,5 +405,5 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error {
return err
}

return a.updateAuthSource(authSource)
return a.updateAuthSource(ctx, authSource)
}
24 changes: 12 additions & 12 deletions cmd/admin_auth_ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ func TestAddLdapBindDn(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
createdAuthSource = authSource
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
return nil, nil
},
Expand Down Expand Up @@ -441,15 +441,15 @@ func TestAddLdapSimpleAuth(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
createdAuthSource = authSource
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call updateAuthSource", n)
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
assert.FailNow(t, "case %d: should not call getAuthSourceByID", n)
return nil, nil
},
Expand Down Expand Up @@ -896,15 +896,15 @@ func TestUpdateLdapBindDn(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call createAuthSource", n)
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
updatedAuthSource = authSource
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
if c.id != 0 {
assert.Equal(t, c.id, id, "case %d: wrong id", n)
}
Expand Down Expand Up @@ -1286,15 +1286,15 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
initDB: func(context.Context) error {
return nil
},
createAuthSource: func(authSource *auth.Source) error {
createAuthSource: func(ctx context.Context, authSource *auth.Source) error {
assert.FailNow(t, "case %d: should not call createAuthSource", n)
return nil
},
updateAuthSource: func(authSource *auth.Source) error {
updateAuthSource: func(ctx context.Context, authSource *auth.Source) error {
updatedAuthSource = authSource
return nil
},
getAuthSourceByID: func(id int64) (*auth.Source, error) {
getAuthSourceByID: func(ctx context.Context, id int64) (*auth.Source, error) {
if c.id != 0 {
assert.Equal(t, c.id, id, "case %d: wrong id", n)
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/admin_auth_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func runAddOauth(c *cli.Context) error {
}
}

return auth_model.CreateSource(&auth_model.Source{
return auth_model.CreateSource(ctx, &auth_model.Source{
Type: auth_model.OAuth2,
Name: c.String("name"),
IsActive: true,
Expand All @@ -203,7 +203,7 @@ func runUpdateOauth(c *cli.Context) error {
return err
}

source, err := auth_model.GetSourceByID(c.Int64("id"))
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil {
return err
}
Expand Down Expand Up @@ -294,5 +294,5 @@ func runUpdateOauth(c *cli.Context) error {
oAuth2Config.CustomURLMapping = customURLMapping
source.Cfg = oAuth2Config

return auth_model.UpdateSource(source)
return auth_model.UpdateSource(ctx, source)
}
6 changes: 3 additions & 3 deletions cmd/admin_auth_stmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func runAddSMTP(c *cli.Context) error {
smtpConfig.Auth = "PLAIN"
}

return auth_model.CreateSource(&auth_model.Source{
return auth_model.CreateSource(ctx, &auth_model.Source{
Type: auth_model.SMTP,
Name: c.String("name"),
IsActive: active,
Expand All @@ -176,7 +176,7 @@ func runUpdateSMTP(c *cli.Context) error {
return err
}

source, err := auth_model.GetSourceByID(c.Int64("id"))
source, err := auth_model.GetSourceByID(ctx, c.Int64("id"))
if err != nil {
return err
}
Expand All @@ -197,5 +197,5 @@ func runUpdateSMTP(c *cli.Context) error {

source.Cfg = smtpConfig

return auth_model.UpdateSource(source)
return auth_model.UpdateSource(ctx, source)
}
2 changes: 1 addition & 1 deletion models/actions/run_job_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
for _, r := range runs {
runsList = append(runsList, r)
}
return runsList.LoadRepos()
return runsList.LoadRepos(ctx)
}
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions models/actions/run_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ func (runs RunList) LoadTriggerUser(ctx context.Context) error {
return nil
}

func (runs RunList) LoadRepos() error {
func (runs RunList) LoadRepos(ctx context.Context) error {
repoIDs := runs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/actions/schedule_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
return nil
}

func (schedules ScheduleList) LoadRepos() error {
func (schedules ScheduleList) LoadRepos(ctx context.Context) error {
repoIDs := schedules.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/actions/schedule_spec_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ func (specs SpecList) GetRepoIDs() []int64 {
return ids.Values()
}

func (specs SpecList) LoadRepos() error {
func (specs SpecList) LoadRepos(ctx context.Context) error {
repoIDs := specs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(repoIDs)
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/activities/statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource = auth.CountSources()
stats.Counter.AuthSource = auth.CountSources(ctx)
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
Expand Down
Loading

0 comments on commit ebe803e

Please sign in to comment.