Skip to content

Commit

Permalink
Refactor CSRF token (go-gitea#32216)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxiaoguang authored Oct 10, 2024
1 parent 368b088 commit dd83cfc
Show file tree
Hide file tree
Showing 29 changed files with 90 additions and 126 deletions.
8 changes: 5 additions & 3 deletions routers/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func autoSignIn(ctx *context.Context) (bool, error) {
return false, err
}

ctx.Csrf.DeleteCookie(ctx)
ctx.Csrf.PrepareForSessionUser(ctx)
return true, nil
}

Expand Down Expand Up @@ -359,8 +359,8 @@ func handleSignInFull(ctx *context.Context, u *user_model.User, remember, obeyRe
ctx.Locale = middleware.Locale(ctx.Resp, ctx.Req)
}

// Clear whatever CSRF cookie has right now, force to generate a new one
ctx.Csrf.DeleteCookie(ctx)
// force to generate a new CSRF token
ctx.Csrf.PrepareForSessionUser(ctx)

// Register last login
if err := user_service.UpdateUser(ctx, u, &user_service.UpdateOptions{SetLastLogin: true}); err != nil {
Expand Down Expand Up @@ -804,6 +804,8 @@ func handleAccountActivation(ctx *context.Context, user *user_model.User) {
return
}

ctx.Csrf.PrepareForSessionUser(ctx)

if err := resetLocale(ctx, user); err != nil {
ctx.ServerError("resetLocale", err)
return
Expand Down
4 changes: 2 additions & 2 deletions routers/web/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,8 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model
return
}

// Clear whatever CSRF cookie has right now, force to generate a new one
ctx.Csrf.DeleteCookie(ctx)
// force to generate a new CSRF token
ctx.Csrf.PrepareForSessionUser(ctx)

if err := resetLocale(ctx, u); err != nil {
ctx.ServerError("resetLocale", err)
Expand Down
4 changes: 2 additions & 2 deletions services/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore

middleware.SetLocaleCookie(resp, user.Language, 0)

// Clear whatever CSRF has right now, force to generate a new one
// force to generate a new CSRF token
if ctx := gitea_context.GetWebContext(req); ctx != nil {
ctx.Csrf.DeleteCookie(ctx)
ctx.Csrf.PrepareForSessionUser(ctx)
}
}
4 changes: 1 addition & 3 deletions services/context/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,8 @@ func (c *csrfProtector) PrepareForSessionUser(ctx *Context) {
}

if needsNew {
// FIXME: actionId.
c.token = GenerateCsrfToken(c.opt.Secret, c.id, "POST", time.Now())
cookie := newCsrfCookie(&c.opt, c.token)
ctx.Resp.Header().Add("Set-Cookie", cookie.String())
ctx.Resp.Header().Add("Set-Cookie", newCsrfCookie(&c.opt, c.token).String())
}

ctx.Data["CsrfToken"] = c.token
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/admin_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func testSuccessfullEdit(t *testing.T, formData user_model.User) {

func makeRequest(t *testing.T, formData user_model.User, headerCode int) {
session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/users/"+strconv.Itoa(int(formData.ID))+"/edit")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/users/"+strconv.Itoa(int(formData.ID))+"/edit", map[string]string{
"_csrf": csrf,
"user_name": formData.Name,
Expand All @@ -72,7 +72,7 @@ func TestAdminDeleteUser(t *testing.T) {

session := loginUser(t, "user1")

csrf := GetCSRF(t, session, "/admin/users/8/edit")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/users/8/delete", map[string]string{
"_csrf": csrf,
})
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/api_httpsig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestHTTPSigCert(t *testing.T) {
defer tests.PrepareTestEnv(t)()
session := loginUser(t, "user1")

csrf := GetCSRF(t, session, "/user/settings/keys")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/user/settings/keys", map[string]string{
"_csrf": csrf,
"content": "user1",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/api_packages_container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ func TestPackageContainer(t *testing.T) {
newOwnerName := "newUsername"

req := NewRequestWithValues(t, "POST", "/user/settings", map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"name": newOwnerName,
"email": "user2@example.com",
"language": "en-US",
Expand All @@ -794,7 +794,7 @@ func TestPackageContainer(t *testing.T) {
t.Run(fmt.Sprintf("Catalog[%s]", newOwnerName), checkCatalog(newOwnerName))

req = NewRequestWithValues(t, "POST", "/user/settings", map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"name": user.Name,
"email": "user2@example.com",
"language": "en-US",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/attachment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ func createAttachment(t *testing.T, session *TestSession, csrf, repoURL, filenam
func TestCreateAnonymousAttachment(t *testing.T) {
defer tests.PrepareTestEnv(t)()
session := emptyTestSession(t)
createAttachment(t, session, GetCSRF(t, session, "/user/login"), "user2/repo1", "image.png", generateImg(), http.StatusSeeOther)
createAttachment(t, session, GetAnonymousCSRFToken(t, session), "user2/repo1", "image.png", generateImg(), http.StatusSeeOther)
}

func TestCreateIssueAttachment(t *testing.T) {
defer tests.PrepareTestEnv(t)()
const repoURL = "user2/repo1"
session := loginUser(t, "user2")
uuid := createAttachment(t, session, GetCSRF(t, session, repoURL), repoURL, "image.png", generateImg(), http.StatusOK)
uuid := createAttachment(t, session, GetUserCSRFToken(t, session), repoURL, "image.png", generateImg(), http.StatusOK)

req := NewRequest(t, "GET", repoURL+"/issues/new")
resp := session.MakeRequest(t, req, http.StatusOK)
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/auth_ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func addAuthSourceLDAP(t *testing.T, sshKeyAttribute, groupFilter string, groupM
groupTeamMap = groupMapParams[1]
}
session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/auths/new")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/auths/new", buildAuthSourceLDAPPayload(csrf, sshKeyAttribute, groupFilter, groupTeamMap, groupTeamMapRemoval))
session.MakeRequest(t, req, http.StatusSeeOther)
}
Expand Down Expand Up @@ -252,7 +252,7 @@ func TestLDAPUserSyncWithEmptyUsernameAttribute(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/auths/new")
csrf := GetUserCSRFToken(t, session)
payload := buildAuthSourceLDAPPayload(csrf, "", "", "", "")
payload["attribute_username"] = ""
req := NewRequestWithValues(t, "POST", "/admin/auths/new", payload)
Expand Down Expand Up @@ -487,7 +487,7 @@ func TestLDAPPreventInvalidGroupTeamMap(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user1")
csrf := GetCSRF(t, session, "/admin/auths/new")
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", "/admin/auths/new", buildAuthSourceLDAPPayload(csrf, "", "", `{"NOT_A_VALID_JSON"["MISSING_DOUBLE_POINT"]}`, "off"))
session.MakeRequest(t, req, http.StatusOK) // StatusOK = failed, StatusSeeOther = ok
}
4 changes: 2 additions & 2 deletions tests/integration/change_default_branch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ func TestChangeDefaultBranch(t *testing.T) {
session := loginUser(t, owner.Name)
branchesURL := fmt.Sprintf("/%s/%s/settings/branches", owner.Name, repo.Name)

csrf := GetCSRF(t, session, branchesURL)
csrf := GetUserCSRFToken(t, session)
req := NewRequestWithValues(t, "POST", branchesURL, map[string]string{
"_csrf": csrf,
"action": "default_branch",
"branch": "DefaultBranch",
})
session.MakeRequest(t, req, http.StatusSeeOther)

csrf = GetCSRF(t, session, branchesURL)
csrf = GetUserCSRFToken(t, session)
req = NewRequestWithValues(t, "POST", branchesURL, map[string]string{
"_csrf": csrf,
"action": "default_branch",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/delete_user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestUserDeleteAccount(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user8")
csrf := GetCSRF(t, session, "/user/settings/account")
csrf := GetUserCSRFToken(t, session)
urlStr := fmt.Sprintf("/user/settings/account/delete?password=%s", userPassword)
req := NewRequestWithValues(t, "POST", urlStr, map[string]string{
"_csrf": csrf,
Expand All @@ -48,7 +48,7 @@ func TestUserDeleteAccountStillOwnRepos(t *testing.T) {
defer tests.PrepareTestEnv(t)()

session := loginUser(t, "user2")
csrf := GetCSRF(t, session, "/user/settings/account")
csrf := GetUserCSRFToken(t, session)
urlStr := fmt.Sprintf("/user/settings/account/delete?password=%s", userPassword)
req := NewRequestWithValues(t, "POST", urlStr, map[string]string{
"_csrf": csrf,
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/editor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestCreateFileOnProtectedBranch(t *testing.T) {
onGiteaRun(t, func(t *testing.T, u *url.URL) {
session := loginUser(t, "user2")

csrf := GetCSRF(t, session, "/user2/repo1/settings/branches")
csrf := GetUserCSRFToken(t, session)
// Change master branch to protected
req := NewRequestWithValues(t, "POST", "/user2/repo1/settings/branches/edit", map[string]string{
"_csrf": csrf,
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestCreateFileOnProtectedBranch(t *testing.T) {
assert.Contains(t, resp.Body.String(), "Cannot commit to protected branch "master".")

// remove the protected branch
csrf = GetCSRF(t, session, "/user2/repo1/settings/branches")
csrf = GetUserCSRFToken(t, session)

// Change master branch to protected
req = NewRequestWithValues(t, "POST", "/user2/repo1/settings/branches/1/delete", map[string]string{
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/empty_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
func testAPINewFile(t *testing.T, session *TestSession, user, repo, branch, treePath, content string) *httptest.ResponseRecorder {
url := fmt.Sprintf("/%s/%s/_new/%s", user, repo, branch)
req := NewRequestWithValues(t, "POST", url, map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"commit_choice": "direct",
"tree_path": treePath,
"content": content,
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestEmptyRepoAddFile(t *testing.T) {
doc := NewHTMLParser(t, resp.Body).Find(`input[name="commit_choice"]`)
assert.Empty(t, doc.AttrOr("checked", "_no_"))
req = NewRequestWithValues(t, "POST", "/user30/empty/_new/"+setting.Repository.DefaultBranch, map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"commit_choice": "direct",
"tree_path": "test-file.md",
"content": "newly-added-test-file",
Expand All @@ -89,7 +89,7 @@ func TestEmptyRepoUploadFile(t *testing.T) {

body := &bytes.Buffer{}
mpForm := multipart.NewWriter(body)
_ = mpForm.WriteField("_csrf", GetCSRF(t, session, "/user/settings"))
_ = mpForm.WriteField("_csrf", GetUserCSRFToken(t, session))
file, _ := mpForm.CreateFormFile("file", "uploaded-file.txt")
_, _ = io.Copy(file, bytes.NewBufferString("newly-uploaded-test-file"))
_ = mpForm.Close()
Expand All @@ -101,7 +101,7 @@ func TestEmptyRepoUploadFile(t *testing.T) {
assert.NoError(t, json.Unmarshal(resp.Body.Bytes(), &respMap))

req = NewRequestWithValues(t, "POST", "/user30/empty/_upload/"+setting.Repository.DefaultBranch, map[string]string{
"_csrf": GetCSRF(t, session, "/user/settings"),
"_csrf": GetUserCSRFToken(t, session),
"commit_choice": "direct",
"files": respMap["uuid"],
"tree_path": "",
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func doBranchProtectPRMerge(baseCtx *APITestContext, dstPath string) func(t *tes
func doProtectBranch(ctx APITestContext, branch, userToWhitelistPush, userToWhitelistForcePush, unprotectedFilePatterns string) func(t *testing.T) {
// We are going to just use the owner to set the protection.
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/settings/branches", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)))
csrf := GetUserCSRFToken(t, ctx.Session)

formData := map[string]string{
"_csrf": csrf,
Expand Down Expand Up @@ -644,7 +644,7 @@ func doPushCreate(ctx APITestContext, u *url.URL) func(t *testing.T) {

func doBranchDelete(ctx APITestContext, owner, repo, branch string) func(*testing.T) {
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/branches", url.PathEscape(owner), url.PathEscape(repo)))
csrf := GetUserCSRFToken(t, ctx.Session)

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/branches/delete?name=%s", url.PathEscape(owner), url.PathEscape(repo), url.QueryEscape(branch)), map[string]string{
"_csrf": csrf,
Expand Down
26 changes: 11 additions & 15 deletions tests/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,23 +486,19 @@ func VerifyJSONSchema(t testing.TB, resp *httptest.ResponseRecorder, schemaFile
assert.True(t, result.Valid())
}

// GetCSRF returns CSRF token from body
// If it fails, it means the CSRF token is not found in the response body returned by the url with the given session.
// In this case, you should find a better url to get it.
func GetCSRF(t testing.TB, session *TestSession, urlStr string) string {
// GetUserCSRFToken returns CSRF token for current user
func GetUserCSRFToken(t testing.TB, session *TestSession) string {
t.Helper()
req := NewRequest(t, "GET", urlStr)
resp := session.MakeRequest(t, req, http.StatusOK)
doc := NewHTMLParser(t, resp.Body)
csrf := doc.GetCSRF()
require.NotEmpty(t, csrf)
return csrf
cookie := session.GetCookie("_csrf")
require.NotEmpty(t, cookie)
return cookie.Value
}

// GetCSRFFrom returns CSRF token from body
func GetCSRFFromCookie(t testing.TB, session *TestSession, urlStr string) string {
// GetUserCSRFToken returns CSRF token for anonymous user (not logged in)
func GetAnonymousCSRFToken(t testing.TB, session *TestSession) string {
t.Helper()
req := NewRequest(t, "GET", urlStr)
session.MakeRequest(t, req, http.StatusOK)
return session.GetCookie("_csrf").Value
resp := session.MakeRequest(t, NewRequest(t, "GET", "/user/login"), http.StatusOK)
csrfToken := NewHTMLParser(t, resp.Body).GetCSRF()
require.NotEmpty(t, csrfToken)
return csrfToken
}
20 changes: 10 additions & 10 deletions tests/integration/issue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,21 @@ func TestEditIssue(t *testing.T) {
issueURL := testNewIssue(t, session, "user2", "repo1", "Title", "Description")

req := NewRequestWithValues(t, "POST", fmt.Sprintf("%s/content", issueURL), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": "modified content",
"context": fmt.Sprintf("/%s/%s", "user2", "repo1"),
})
session.MakeRequest(t, req, http.StatusOK)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("%s/content", issueURL), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": "modified content",
"context": fmt.Sprintf("/%s/%s", "user2", "repo1"),
})
session.MakeRequest(t, req, http.StatusBadRequest)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("%s/content", issueURL), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": "modified content",
"content_version": "1",
"context": fmt.Sprintf("/%s/%s", "user2", "repo1"),
Expand Down Expand Up @@ -246,11 +246,11 @@ func TestIssueCommentDelete(t *testing.T) {

// Using the ID of a comment that does not belong to the repository must fail
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d/delete", "user5", "repo4", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
})
session.MakeRequest(t, req, http.StatusNotFound)
req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d/delete", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
})
session.MakeRequest(t, req, http.StatusOK)
unittest.AssertNotExistsBean(t, &issues_model.Comment{ID: commentID})
Expand All @@ -270,13 +270,13 @@ func TestIssueCommentUpdate(t *testing.T) {

// Using the ID of a comment that does not belong to the repository must fail
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user5", "repo4", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusNotFound)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusOK)
Expand All @@ -298,21 +298,21 @@ func TestIssueCommentUpdateSimultaneously(t *testing.T) {
modifiedContent := comment.Content + "MODIFIED"

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusOK)

modifiedContent = comment.Content + "2"

req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
})
session.MakeRequest(t, req, http.StatusBadRequest)

req = NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/comments/%d", "user2", "repo1", commentID), map[string]string{
"_csrf": GetCSRF(t, session, issueURL),
"_csrf": GetUserCSRFToken(t, session),
"content": modifiedContent,
"content_version": "1",
})
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/mirror_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func testMirrorPush(t *testing.T, u *url.URL) {

func doCreatePushMirror(ctx APITestContext, address, username, password string) func(t *testing.T) {
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)))
csrf := GetUserCSRFToken(t, ctx.Session)

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)), map[string]string{
"_csrf": csrf,
Expand All @@ -101,7 +101,7 @@ func doCreatePushMirror(ctx APITestContext, address, username, password string)

func doRemovePushMirror(ctx APITestContext, address, username, password string, pushMirrorID int) func(t *testing.T) {
return func(t *testing.T) {
csrf := GetCSRF(t, ctx.Session, fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)))
csrf := GetUserCSRFToken(t, ctx.Session)

req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame)), map[string]string{
"_csrf": csrf,
Expand Down
Loading

0 comments on commit dd83cfc

Please sign in to comment.