Skip to content

Commit

Permalink
Remove GetByBean method because sometimes it's danger when query cond…
Browse files Browse the repository at this point in the history
…ition parameter is zero and also introduce new generic methods (go-gitea#28220)

The function `GetByBean` has an obvious defect that when the fields are
empty values, it will be ignored. Then users will get a wrong result
which is possibly used to make a security problem.

To avoid the possibility, this PR removed function `GetByBean` and all
references.
And some new generic functions have been introduced to be used.

The recommand usage like below.

```go
// if query an object according id
obj, err := db.GetByID[Object](ctx, id)
// query with other conditions
obj, err := db.Get[Object](ctx, builder.Eq{"a": a, "b":b})
```
  • Loading branch information
lunny authored Dec 7, 2023
1 parent beb71f5 commit dd30d9d
Show file tree
Hide file tree
Showing 28 changed files with 189 additions and 174 deletions.
33 changes: 13 additions & 20 deletions models/asymkey/ssh_key_deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,24 +131,22 @@ func AddDeployKey(ctx context.Context, repoID int64, name, content string, readO
}
defer committer.Close()

pkey := &PublicKey{
Fingerprint: fingerprint,
}
has, err := db.GetByBean(ctx, pkey)
pkey, exist, err := db.Get[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
if err != nil {
return nil, err
}

if has {
} else if exist {
if pkey.Type != KeyTypeDeploy {
return nil, ErrKeyAlreadyExist{0, fingerprint, ""}
}
} else {
// First time use this deploy key.
pkey.Mode = accessMode
pkey.Type = KeyTypeDeploy
pkey.Content = content
pkey.Name = name
pkey = &PublicKey{
Fingerprint: fingerprint,
Mode: accessMode,
Type: KeyTypeDeploy,
Content: content,
Name: name,
}
if err = addKey(ctx, pkey); err != nil {
return nil, fmt.Errorf("addKey: %w", err)
}
Expand All @@ -164,26 +162,21 @@ func AddDeployKey(ctx context.Context, repoID int64, name, content string, readO

// GetDeployKeyByID returns deploy key by given ID.
func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
key := new(DeployKey)
has, err := db.GetEngine(ctx).ID(id).Get(key)
key, exist, err := db.GetByID[DeployKey](ctx, id)
if err != nil {
return nil, err
} else if !has {
} else if !exist {
return nil, ErrDeployKeyNotExist{id, 0, 0}
}
return key, nil
}

// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
key := &DeployKey{
KeyID: keyID,
RepoID: repoID,
}
has, err := db.GetByBean(ctx, key)
key, exist, err := db.Get[DeployKey](ctx, builder.Eq{"key_id": keyID, "repo_id": repoID})
if err != nil {
return nil, err
} else if !has {
} else if !exist {
return nil, ErrDeployKeyNotExist{0, keyID, repoID}
}
return key, nil
Expand Down
5 changes: 2 additions & 3 deletions models/asymkey/ssh_key_fingerprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"code.gitea.io/gitea/modules/util"

"golang.org/x/crypto/ssh"
"xorm.io/builder"
)

// ___________.__ .__ __
Expand All @@ -31,9 +32,7 @@ import (
// checkKeyFingerprint only checks if key fingerprint has been used as public key,
// it is OK to use same key as deploy key for multiple repositories/users.
func checkKeyFingerprint(ctx context.Context, fingerprint string) error {
has, err := db.GetByBean(ctx, &PublicKey{
Fingerprint: fingerprint,
})
has, err := db.Exist[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
if err != nil {
return err
} else if has {
Expand Down
35 changes: 13 additions & 22 deletions models/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"

"xorm.io/builder"
)

// Session represents a session compatible for go-chi session
Expand All @@ -33,34 +35,28 @@ func UpdateSession(ctx context.Context, key string, data []byte) error {

// ReadSession reads the data for the provided session
func ReadSession(ctx context.Context, key string) (*Session, error) {
session := Session{
Key: key,
}

ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()

if has, err := db.GetByBean(ctx, &session); err != nil {
session, exist, err := db.Get[Session](ctx, builder.Eq{"key": key})
if err != nil {
return nil, err
} else if !has {
} else if !exist {
session.Expiry = timeutil.TimeStampNow()
if err := db.Insert(ctx, &session); err != nil {
return nil, err
}
}

return &session, committer.Commit()
return session, committer.Commit()
}

// ExistSession checks if a session exists
func ExistSession(ctx context.Context, key string) (bool, error) {
session := Session{
Key: key,
}
return db.GetEngine(ctx).Get(&session)
return db.Exist[Session](ctx, builder.Eq{"key": key})
}

// DestroySession destroys a session
Expand All @@ -79,17 +75,13 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er
}
defer committer.Close()

if has, err := db.GetByBean(ctx, &Session{
Key: newKey,
}); err != nil {
if has, err := db.Exist[Session](ctx, builder.Eq{"key": newKey}); err != nil {
return nil, err
} else if has {
return nil, fmt.Errorf("session Key: %s already exists", newKey)
}

if has, err := db.GetByBean(ctx, &Session{
Key: oldKey,
}); err != nil {
if has, err := db.Exist[Session](ctx, builder.Eq{"key": oldKey}); err != nil {
return nil, err
} else if !has {
if err := db.Insert(ctx, &Session{
Expand All @@ -104,14 +96,13 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er
return nil, err
}

s := Session{
Key: newKey,
}
if _, err := db.GetByBean(ctx, &s); err != nil {
s, _, err := db.Get[Session](ctx, builder.Eq{"key": newKey})
if err != nil {
// is not exist, it should be impossible
return nil, err
}

return &s, committer.Commit()
return s, committer.Commit()
}

// CountSessions returns the number of sessions
Expand Down
4 changes: 2 additions & 2 deletions models/auth/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ func IsSSPIEnabled(ctx context.Context) bool {
return false
}

exist, err := db.Exists[Source](ctx, FindSourcesOptions{
exist, err := db.Exist[Source](ctx, FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
LoginType: SSPI,
})
}.ToConds())
if err != nil {
log.Error("Active SSPI Sources: %v", err)
return false
Expand Down
46 changes: 38 additions & 8 deletions models/db/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,44 @@ func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
return GetEngine(ctx).Exec(sqlAndArgs...)
}

// GetByBean filled empty fields of the bean according non-empty fields to query in database.
func GetByBean(ctx context.Context, bean any) (bool, error) {
return GetEngine(ctx).Get(bean)
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
if !cond.IsValid() {
return nil, false, ErrConditionRequired{}
}

var bean T
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
if err != nil {
return nil, false, err
} else if !has {
return nil, false, nil
}
return &bean, true, nil
}

func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
var bean T
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
if err != nil {
return nil, false, err
} else if !has {
return nil, false, nil
}
return &bean, true, nil
}

func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
if !cond.IsValid() {
return false, ErrConditionRequired{}
}

var bean T
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
}

func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
var bean T
return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
}

// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
Expand Down Expand Up @@ -264,8 +299,3 @@ func inTransaction(ctx context.Context) (*xorm.Session, bool) {
return nil, false
}
}

func Exists[T any](ctx context.Context, opts FindOptions) (bool, error) {
var bean T
return GetEngine(ctx).Where(opts.ToConds()).Exist(&bean)
}
18 changes: 18 additions & 0 deletions models/db/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,21 @@ func (err ErrNotExist) Error() string {
func (err ErrNotExist) Unwrap() error {
return util.ErrNotExist
}

// ErrConditionRequired represents an error which require condition.
type ErrConditionRequired struct{}

// IsErrConditionRequired checks if an error is an ErrConditionRequired
func IsErrConditionRequired(err error) bool {
_, ok := err.(ErrConditionRequired)
return ok
}

func (err ErrConditionRequired) Error() string {
return "condition is required"
}

// Unwrap unwraps this as a ErrNotExist err
func (err ErrConditionRequired) Unwrap() error {
return util.ErrInvalidArgument
}
6 changes: 3 additions & 3 deletions models/db/iterate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ func TestIterate(t *testing.T) {
assert.EqualValues(t, cnt, repoUnitCnt)

err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
reopUnit2 := repo_model.RepoUnit{ID: repoUnit.ID}
has, err := db.GetByBean(ctx, &reopUnit2)
has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
if err != nil {
return err
} else if !has {
}
if !has {
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
}
assert.EqualValues(t, repoUnit.RepoID, repoUnit.RepoID)
Expand Down
9 changes: 4 additions & 5 deletions models/git/lfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ var ErrLFSObjectNotExist = db.ErrNotExist{Resource: "LFS Meta object"}

// NewLFSMetaObject stores a given populated LFSMetaObject structure in the database
// if it is not already present.
func NewLFSMetaObject(ctx context.Context, m *LFSMetaObject) (*LFSMetaObject, error) {
func NewLFSMetaObject(ctx context.Context, repoID int64, p lfs.Pointer) (*LFSMetaObject, error) {
var err error

ctx, committer, err := db.TxContext(ctx)
Expand All @@ -144,16 +144,15 @@ func NewLFSMetaObject(ctx context.Context, m *LFSMetaObject) (*LFSMetaObject, er
}
defer committer.Close()

has, err := db.GetByBean(ctx, m)
m, exist, err := db.Get[LFSMetaObject](ctx, builder.Eq{"repository_id": repoID, "oid": p.Oid})
if err != nil {
return nil, err
}

if has {
} else if exist {
m.Existing = true
return m, committer.Commit()
}

m = &LFSMetaObject{Pointer: p, RepositoryID: repoID}
if err = db.Insert(ctx, m); err != nil {
return nil, err
}
Expand Down
14 changes: 6 additions & 8 deletions models/git/protected_branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/gobwas/glob"
"github.com/gobwas/glob/syntax"
"xorm.io/builder"
)

var ErrBranchIsProtected = errors.New("branch is protected")
Expand Down Expand Up @@ -274,25 +275,22 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa

// GetProtectedBranchRuleByName getting protected branch rule by name
func GetProtectedBranchRuleByName(ctx context.Context, repoID int64, ruleName string) (*ProtectedBranch, error) {
rel := &ProtectedBranch{RepoID: repoID, RuleName: ruleName}
has, err := db.GetByBean(ctx, rel)
// branch_name is legacy name, it actually is rule name
rel, exist, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "branch_name": ruleName})
if err != nil {
return nil, err
}
if !has {
} else if !exist {
return nil, nil
}
return rel, nil
}

// GetProtectedBranchRuleByID getting protected branch rule by rule ID
func GetProtectedBranchRuleByID(ctx context.Context, repoID, ruleID int64) (*ProtectedBranch, error) {
rel := &ProtectedBranch{ID: ruleID, RepoID: repoID}
has, err := db.GetByBean(ctx, rel)
rel, exist, err := db.Get[ProtectedBranch](ctx, builder.Eq{"repo_id": repoID, "id": ruleID})
if err != nil {
return nil, err
}
if !has {
} else if !exist {
return nil, nil
}
return rel, nil
Expand Down
4 changes: 3 additions & 1 deletion models/issues/assignees.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/util"

"xorm.io/builder"
)

// IssueAssignees saves all issue assignees
Expand Down Expand Up @@ -59,7 +61,7 @@ func GetAssigneeIDsByIssue(ctx context.Context, issueID int64) ([]int64, error)

// IsUserAssignedToIssue returns true when the user is assigned to the issue
func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.User) (isAssigned bool, err error) {
return db.GetByBean(ctx, &IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
return db.Exist[IssueAssignees](ctx, builder.Eq{"assignee_id": user.ID, "issue_id": issue.ID})
}

// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
Expand Down
Loading

0 comments on commit dd30d9d

Please sign in to comment.