diff --git a/server/auth/store.go b/server/auth/store.go index b4723c6ec0a..3085b498448 100644 --- a/server/auth/store.go +++ b/server/auth/store.go @@ -18,7 +18,6 @@ import ( "bytes" "context" "encoding/base64" - "encoding/binary" "errors" "sort" "strings" @@ -76,8 +75,6 @@ const ( tokenTypeSimple = "simple" tokenTypeJWT = "jwt" - - revBytesLen = 8 ) type AuthInfo struct { @@ -237,7 +234,7 @@ func (as *authStore) AuthEnable() error { return ErrRootRoleNotExist } - tx.UnsafePut(buckets.Auth, buckets.AuthEnabledKeyName, authEnabled) + buckets.UnsafeSaveAuthEnabled(tx, true) as.enabled = true as.tokenProvider.enable() @@ -259,7 +256,7 @@ func (as *authStore) AuthDisable() { b := as.be tx := b.BatchTx() tx.Lock() - tx.UnsafePut(buckets.Auth, buckets.AuthEnabledKeyName, authDisabled) + buckets.UnsafeSaveAuthEnabled(tx, false) as.commitRevision(tx) tx.Unlock() b.ForceCommit() @@ -350,17 +347,11 @@ func (as *authStore) CheckPassword(username, password string) (uint64, error) { } func (as *authStore) Recover(be backend.Backend) { - enabled := false as.be = be tx := be.BatchTx() tx.Lock() - _, vs := tx.UnsafeRange(buckets.Auth, buckets.AuthEnabledKeyName, nil, 0) - if len(vs) == 1 { - if bytes.Equal(vs[0], authEnabled) { - enabled = true - } - } + enabled := buckets.UnsafeReadAuthEnabled(tx) as.setRevision(getRevision(tx)) tx.Unlock() @@ -936,17 +927,11 @@ func NewAuthStore(lg *zap.Logger, be backend.Backend, tp TokenProvider, bcryptCo tx := be.BatchTx() tx.Lock() - tx.UnsafeCreateBucket(buckets.Auth) + buckets.UnsafeCreateAuthBucket(tx) tx.UnsafeCreateBucket(buckets.AuthUsers) tx.UnsafeCreateBucket(buckets.AuthRoles) - enabled := false - _, vs := tx.UnsafeRange(buckets.Auth, buckets.AuthEnabledKeyName, nil, 0) - if len(vs) == 1 { - if bytes.Equal(vs[0], authEnabled) { - enabled = true - } - } + enabled := buckets.UnsafeReadAuthEnabled(tx) as := &authStore{ revision: getRevision(tx), @@ -982,18 +967,11 @@ func hasRootRole(u *authpb.User) bool { func (as *authStore) commitRevision(tx backend.BatchTx) { atomic.AddUint64(&as.revision, 1) - revBytes := make([]byte, revBytesLen) - binary.BigEndian.PutUint64(revBytes, as.Revision()) - tx.UnsafePut(buckets.Auth, buckets.AuthRevisionKeyName, revBytes) + buckets.UnsafeSaveAuthRevision(tx, as.Revision()) } func getRevision(tx backend.BatchTx) uint64 { - _, vs := tx.UnsafeRange(buckets.Auth, buckets.AuthRevisionKeyName, nil, 0) - if len(vs) != 1 { - // this can happen in the initialization phase - return 0 - } - return binary.BigEndian.Uint64(vs[0]) + return buckets.UnsafeReadAuthRevision(tx) } func (as *authStore) setRevision(rev uint64) { diff --git a/server/mvcc/buckets/auth.go b/server/mvcc/buckets/auth.go new file mode 100644 index 00000000000..099298a95c2 --- /dev/null +++ b/server/mvcc/buckets/auth.go @@ -0,0 +1,67 @@ +// Copyright 2021 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buckets + +import ( + "bytes" + "encoding/binary" + "go.etcd.io/etcd/server/v3/mvcc/backend" +) + +const ( + revBytesLen = 8 +) + +var ( + authEnabled = []byte{1} + authDisabled = []byte{0} +) + +func UnsafeCreateAuthBucket(tx backend.BatchTx) { + tx.UnsafeCreateBucket(Auth) +} + +func UnsafeSaveAuthEnabled(tx backend.BatchTx, enabled bool) { + if enabled { + tx.UnsafePut(Auth, AuthEnabledKeyName, authEnabled) + } else { + tx.UnsafePut(Auth, AuthEnabledKeyName, authDisabled) + } +} + +func UnsafeReadAuthEnabled(tx backend.ReadTx) bool { + _, vs := tx.UnsafeRange(Auth, AuthEnabledKeyName, nil, 0) + if len(vs) == 1 { + if bytes.Equal(vs[0], authEnabled) { + return true + } + } + return false +} + +func UnsafeSaveAuthRevision(tx backend.BatchTx, rev uint64) { + revBytes := make([]byte, revBytesLen) + binary.BigEndian.PutUint64(revBytes, rev) + tx.UnsafePut(Auth, AuthRevisionKeyName, revBytes) +} + +func UnsafeReadAuthRevision(tx backend.ReadTx) uint64 { + _, vs := tx.UnsafeRange(Auth, AuthRevisionKeyName, nil, 0) + if len(vs) != 1 { + // this can happen in the initialization phase + return 0 + } + return binary.BigEndian.Uint64(vs[0]) +} diff --git a/server/mvcc/buckets/auth_test.go b/server/mvcc/buckets/auth_test.go new file mode 100644 index 00000000000..b2386d25c54 --- /dev/null +++ b/server/mvcc/buckets/auth_test.go @@ -0,0 +1,99 @@ +// Copyright 2021 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buckets + +import ( + "fmt" + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/server/v3/mvcc/backend" + betesting "go.etcd.io/etcd/server/v3/mvcc/backend/testing" +) + +// TestAuthEnabled ensures that UnsafeSaveAuthEnabled&UnsafeReadAuthEnabled work well together. +func TestAuthEnabled(t *testing.T) { + tcs := []struct { + enabled bool + }{ + { + enabled: true, + }, + { + enabled: false, + }, + } + for _, tc := range tcs { + t.Run(fmt.Sprint(tc.enabled), func(t *testing.T) { + be, tmpPath := betesting.NewTmpBackend(t, time.Microsecond, 10) + tx := be.BatchTx() + if tx == nil { + t.Fatal("batch tx is nil") + } + tx.Lock() + UnsafeCreateAuthBucket(tx) + UnsafeSaveAuthEnabled(tx, tc.enabled) + tx.Unlock() + be.ForceCommit() + be.Close() + + b := backend.NewDefaultBackend(tmpPath) + defer b.Close() + v := UnsafeReadAuthEnabled(b.BatchTx()) + + assert.Equal(t, tc.enabled, v) + }) + } +} + +// TestAuthRevision ensures that UnsafeSaveAuthRevision&UnsafeReadAuthRevision work well together. +func TestAuthRevision(t *testing.T) { + tcs := []struct { + revision uint64 + }{ + { + revision: 0, + }, + { + revision: 1, + }, + { + revision: math.MaxUint64, + }, + } + for _, tc := range tcs { + t.Run(fmt.Sprint(tc.revision), func(t *testing.T) { + be, tmpPath := betesting.NewTmpBackend(t, time.Microsecond, 10) + tx := be.BatchTx() + if tx == nil { + t.Fatal("batch tx is nil") + } + tx.Lock() + UnsafeCreateAuthBucket(tx) + UnsafeSaveAuthRevision(tx, tc.revision) + tx.Unlock() + be.ForceCommit() + be.Close() + + b := backend.NewDefaultBackend(tmpPath) + defer b.Close() + v := UnsafeReadAuthRevision(b.BatchTx()) + + assert.Equal(t, tc.revision, v) + }) + } +}