Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove GetByBean method because sometimes it's danger when query condition parameter is zero and also introduce new generic methods #28220

Merged
merged 19 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions models/asymkey/ssh_key_deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,22 @@ func AddDeployKey(ctx context.Context, repoID int64, name, content string, readO
}
defer committer.Close()

pkey, err := db.Get[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
if err != nil && !db.IsErrNotExist(err) {
pkey, exist, err := db.Get[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
if err != nil {
return nil, err
}

if err == nil {
} else if exist {
if pkey.Type != KeyTypeDeploy {
return nil, ErrKeyAlreadyExist{0, fingerprint, ""}
}
} else {
// First time use this deploy key.
pkey.Mode = accessMode
pkey.Type = KeyTypeDeploy
pkey.Content = content
pkey.Name = name
pkey = &PublicKey{
Fingerprint: fingerprint,
Mode: accessMode,
Type: KeyTypeDeploy,
Content: content,
Name: name,
}
if err = addKey(ctx, pkey); err != nil {
return nil, fmt.Errorf("addKey: %w", err)
}
Expand All @@ -161,24 +162,22 @@ func AddDeployKey(ctx context.Context, repoID int64, name, content string, readO

// GetDeployKeyByID returns deploy key by given ID.
func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
key, err := db.GetByID[DeployKey](ctx, id)
key, exist, err := db.GetByID[DeployKey](ctx, id)
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrDeployKeyNotExist{0, 0, id}
}
return nil, err
} else if !exist {
return nil, ErrDeployKeyNotExist{id, 0, 0}
}
return key, nil
}

// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
key, err := db.Get[DeployKey](ctx, builder.Eq{"key_id": keyID, "repo_id": repoID})
key, exist, err := db.Get[DeployKey](ctx, builder.Eq{"key_id": keyID, "repo_id": repoID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrDeployKeyNotExist{0, keyID, repoID}
}
return nil, err
} else if !exist {
return nil, ErrDeployKeyNotExist{0, keyID, repoID}
}
return key, nil
}
Expand Down
15 changes: 7 additions & 8 deletions models/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@ func ReadSession(ctx context.Context, key string) (*Session, error) {
}
defer committer.Close()

session, err := db.Get[Session](ctx, builder.Eq{"key": key})
session, exist, err := db.Get[Session](ctx, builder.Eq{"key": key})
if err != nil {
if db.IsErrNotExist(err) {
session.Expiry = timeutil.TimeStampNow()
if err := db.Insert(ctx, &session); err != nil {
return nil, err
}
}
return nil, err
} else if !exist {
session.Expiry = timeutil.TimeStampNow()
if err := db.Insert(ctx, &session); err != nil {
return nil, err
}
}

return session, committer.Commit()
Expand Down Expand Up @@ -97,7 +96,7 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er
return nil, err
}

s, err := db.Get[Session](ctx, builder.Eq{"key": newKey})
s, _, err := db.Get[Session](ctx, builder.Eq{"key": newKey})
if err != nil {
// is not exist, it should be impossible
return nil, err
Expand Down
24 changes: 16 additions & 8 deletions models/db/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,29 +173,37 @@ func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
return GetEngine(ctx).Exec(sqlAndArgs...)
}

func Get[T any](ctx context.Context, cond builder.Cond) (*T, error) {
func Get[T any](ctx context.Context, cond builder.Cond) (*T, bool, error) {
lunny marked this conversation as resolved.
Show resolved Hide resolved
if !cond.IsValid() {
return nil, false, ErrConditionRequired{}
}

var bean T
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
lunny marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
return nil, false, err
} else if !has {
return nil, ErrNotExist{Resource: TableName(bean)}
return nil, false, nil
}
return &bean, nil
return &bean, true, nil
}

func GetByID[T any](ctx context.Context, id int64) (*T, error) {
func GetByID[T any](ctx context.Context, id int64) (*T, bool, error) {
lunny marked this conversation as resolved.
Show resolved Hide resolved
var bean T
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
if err != nil {
return nil, err
return nil, false, err
} else if !has {
return nil, ErrNotExist{Resource: TableName(bean), ID: id}
return nil, false, nil
}
return &bean, nil
return &bean, true, nil
}

func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
if !cond.IsValid() {
return false, ErrConditionRequired{}
}

var bean T
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
}
Expand Down
18 changes: 18 additions & 0 deletions models/db/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,21 @@ func (err ErrNotExist) Error() string {
func (err ErrNotExist) Unwrap() error {
return util.ErrNotExist
}

// ErrConditionRequired represents an error which require condition.
type ErrConditionRequired struct{}

// IsErrConditionRequired checks if an error is an ErrConditionRequired
func IsErrConditionRequired(err error) bool {
_, ok := err.(ErrConditionRequired)
return ok
}

func (err ErrConditionRequired) Error() string {
return "condition is required"
}

// Unwrap unwraps this as a ErrNotExist err
func (err ErrConditionRequired) Unwrap() error {
return util.ErrInvalidArgument
}
8 changes: 3 additions & 5 deletions models/git/lfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,10 @@ func NewLFSMetaObject(ctx context.Context, repoID int64, p lfs.Pointer) (*LFSMet
}
defer committer.Close()

m, err := db.Get[LFSMetaObject](ctx, builder.Eq{"repository_id": repoID, "oid": p.Oid})
if err != nil && !db.IsErrNotExist(err) {
m, exist, err := db.Get[LFSMetaObject](ctx, builder.Eq{"repository_id": repoID, "oid": p.Oid})
if err != nil {
return nil, err
}

if err == nil {
} else if exist {
m.Existing = true
return m, committer.Commit()
}
Expand Down
15 changes: 7 additions & 8 deletions models/git/protected_branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,24 +275,23 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa

// GetProtectedBranchRuleByName getting protected branch rule by name
func GetProtectedBranchRuleByName(ctx context.Context, repoID int64, ruleName string) (*ProtectedBranch, error) {
rel, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "rule_name": ruleName})
// branch_name is legacy name, it actually is rule name
rel, exist, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "branch_name": ruleName})
if err != nil {
if db.IsErrNotExist(err) {
return nil, nil
}
return nil, err
} else if !exist {
return nil, nil
}
return rel, nil
}

// GetProtectedBranchRuleByID getting protected branch rule by rule ID
func GetProtectedBranchRuleByID(ctx context.Context, repoID, ruleID int64) (*ProtectedBranch, error) {
rel, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "id": ruleID})
rel, exist, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "id": ruleID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, nil
}
return nil, err
} else if !exist {
return nil, nil
}
return rel, nil
}
Expand Down
28 changes: 12 additions & 16 deletions models/issues/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,11 @@ func GetLabelInRepoByName(ctx context.Context, repoID int64, labelName string) (
return nil, ErrRepoLabelNotExist{0, repoID}
}

l, err := db.Get[Label](ctx, builder.Eq{"name": labelName, "repo_id": repoID})
l, exist, err := db.Get[Label](ctx, builder.Eq{"name": labelName, "repo_id": repoID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrRepoLabelNotExist{0, repoID}
}
return nil, err
} else if !exist {
return nil, ErrRepoLabelNotExist{0, repoID}
}
return l, nil
}
Expand All @@ -320,12 +319,11 @@ func GetLabelInRepoByID(ctx context.Context, repoID, labelID int64) (*Label, err
return nil, ErrRepoLabelNotExist{labelID, repoID}
}

l, err := db.Get[Label](ctx, builder.Eq{"id": labelID, "repo_id": repoID})
l, exist, err := db.Get[Label](ctx, builder.Eq{"id": labelID, "repo_id": repoID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrRepoLabelNotExist{l.ID, l.RepoID}
}
return nil, err
} else if !exist {
return nil, ErrRepoLabelNotExist{labelID, repoID}
}
return l, nil
}
Expand Down Expand Up @@ -402,12 +400,11 @@ func GetLabelInOrgByName(ctx context.Context, orgID int64, labelName string) (*L
return nil, ErrOrgLabelNotExist{0, orgID}
}

l, err := db.Get[Label](ctx, builder.Eq{"name": labelName, "org_id": orgID})
l, exist, err := db.Get[Label](ctx, builder.Eq{"name": labelName, "org_id": orgID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrOrgLabelNotExist{0, orgID}
}
return nil, err
} else if !exist {
return nil, ErrOrgLabelNotExist{0, orgID}
}
return l, nil
}
Expand All @@ -418,12 +415,11 @@ func GetLabelInOrgByID(ctx context.Context, orgID, labelID int64) (*Label, error
return nil, ErrOrgLabelNotExist{labelID, orgID}
}

l, err := db.Get[Label](ctx, builder.Eq{"id": labelID, "org_id": orgID})
l, exist, err := db.Get[Label](ctx, builder.Eq{"id": labelID, "org_id": orgID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrOrgLabelNotExist{l.ID, l.OrgID}
}
return nil, err
} else if !exist {
return nil, ErrOrgLabelNotExist{labelID, orgID}
}
return l, nil
}
Expand Down
7 changes: 3 additions & 4 deletions models/issues/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,11 @@ func GetPullRequestByIssueIDWithNoAttributes(ctx context.Context, issueID int64)

// GetPullRequestByIssueID returns pull request by given issue ID.
func GetPullRequestByIssueID(ctx context.Context, issueID int64) (*PullRequest, error) {
pr, err := db.Get[PullRequest](ctx, builder.Eq{"issue_id": issueID})
pr, exist, err := db.Get[PullRequest](ctx, builder.Eq{"issue_id": issueID})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrPullRequestNotExist{0, issueID, 0, 0, "", ""}
}
return nil, err
} else if !exist {
return nil, ErrPullRequestNotExist{0, issueID, 0, 0, "", ""}
}
return pr, pr.LoadAttributes(ctx)
}
Expand Down
7 changes: 3 additions & 4 deletions models/organization/team.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,11 @@ func IsUsableTeamName(name string) error {

// GetTeam returns team by given team name and organization.
func GetTeam(ctx context.Context, orgID int64, name string) (*Team, error) {
t, err := db.Get[Team](ctx, builder.Eq{"org_id": orgID, "lower_name": strings.ToLower(name)})
t, exist, err := db.Get[Team](ctx, builder.Eq{"org_id": orgID, "lower_name": strings.ToLower(name)})
if err != nil {
if db.IsErrNotExist(err) {
return nil, ErrTeamNotExist{orgID, 0, name}
}
return nil, err
} else if !exist {
return nil, ErrTeamNotExist{orgID, 0, name}
}
return t, nil
}
Expand Down
7 changes: 3 additions & 4 deletions models/perm/access/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,11 @@ func accessLevel(ctx context.Context, user *user_model.User, repo *repo_model.Re
return perm.AccessModeOwner, nil
}

a, err := db.Get[Access](ctx, builder.Eq{"user_id": userID, "repo_id": repo.ID})
a, exist, err := db.Get[Access](ctx, builder.Eq{"user_id": userID, "repo_id": repo.ID})
if err != nil {
if db.IsErrNotExist(err) {
return mode, nil
}
return mode, err
} else if !exist {
return mode, nil
}
return a.Mode, nil
}
Expand Down
15 changes: 7 additions & 8 deletions models/system/setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,15 @@ func init() {
const keyRevision = "revision"

func GetRevision(ctx context.Context) int {
revision, err := db.Get[Setting](ctx, builder.Eq{"setting_key": keyRevision})
revision, exist, err := db.Get[Setting](ctx, builder.Eq{"setting_key": keyRevision})
if err != nil {
if db.IsErrNotExist(err) {
err = db.Insert(ctx, &Setting{SettingKey: keyRevision, Version: 1})
if err != nil {
return 0
}
return 1
}
return 0
} else if !exist {
err = db.Insert(ctx, &Setting{SettingKey: keyRevision, Version: 1})
if err != nil {
return 0
}
return 1
}
if revision.Version <= 0 || revision.Version >= math.MaxInt-1 {
_, err = db.Exec(ctx, "UPDATE system_setting SET version=1 WHERE setting_key=?", keyRevision)
Expand Down
11 changes: 4 additions & 7 deletions models/system/setting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@ func TestSettings(t *testing.T) {
assert.EqualValues(t, "false", settings[keyName])

// setting the same value should not trigger DuplicateKey error, and the "version" should be increased
setting := &system.Setting{SettingKey: keyName}
_, err = db.GetByBean(db.DefaultContext, setting)
assert.NoError(t, err)
assert.EqualValues(t, 2, setting.Version)
err = system.SetSettings(db.DefaultContext, map[string]string{keyName: "false"})
assert.NoError(t, err)
setting = &system.Setting{SettingKey: keyName}
_, err = db.GetByBean(db.DefaultContext, setting)

rev, settings, err = system.GetAllSettings(db.DefaultContext)
assert.NoError(t, err)
assert.EqualValues(t, 3, setting.Version)
assert.Len(t, settings, 2)
assert.EqualValues(t, 4, rev)
}
Loading