Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add worker token TTL #621

Merged
merged 7 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/k8s/pkg/k8sd/api/cluster_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ func (e *Endpoints) postClusterJoinTokens(s state.State, r *http.Request) respon
}

var token string

ttl := req.TTL
if ttl == 0 {
// Set the default token lifetime to 24 hours.
ttl = 24 * time.Hour
}

if req.Worker {
token, err = getOrCreateWorkerToken(r.Context(), s, hostname)
token, err = getOrCreateWorkerToken(r.Context(), s, hostname, ttl)
} else {
token, err = getOrCreateJoinToken(r.Context(), e.provider.MicroCluster(), hostname, req.TTL)
token, err = getOrCreateJoinToken(r.Context(), e.provider.MicroCluster(), hostname, ttl)
}
if err != nil {
return response.InternalError(fmt.Errorf("failed to create token: %w", err))
Expand All @@ -54,22 +61,18 @@ func getOrCreateJoinToken(ctx context.Context, m *microcluster.MicroCluster, tok
fmt.Println("No token exists yet. Creating a new token.")
}

if ttl == 0 {
ttl = 24 * time.Hour
}

token, err := m.NewJoinToken(ctx, tokenName, ttl)
if err != nil {
return "", fmt.Errorf("failed to generate a new microcluster join token: %w", err)
}
return token, nil
}

func getOrCreateWorkerToken(ctx context.Context, s state.State, nodeName string) (string, error) {
func getOrCreateWorkerToken(ctx context.Context, s state.State, nodeName string, ttl time.Duration) (string, error) {
var token string
if err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
var err error
token, err = database.GetOrCreateWorkerNodeToken(ctx, tx, nodeName)
token, err = database.GetOrCreateWorkerNodeToken(ctx, tx, nodeName, time.Now().Add(ttl))
if err != nil {
return fmt.Errorf("failed to create worker node token: %w", err)
}
Expand Down
6 changes: 6 additions & 0 deletions src/k8s/pkg/k8sd/database/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@ import (
)

var (
// SchemaExtensions defines the schema updates for the database.
// SchemaExtensions are apply only.
// Note(ben): Never change the order or remove a migration as this would break the internal microcluster counter!
bschimke95 marked this conversation as resolved.
Show resolved Hide resolved
SchemaExtensions = []schema.Update{
schemaApplyMigration("kubernetes-auth-tokens", "000-create.sql"),
schemaApplyMigration("cluster-configs", "000-create.sql"),

schemaApplyMigration("worker-tokens", "000-create.sql"),
schemaApplyMigration("worker-tokens", "001-add-expiry.sql"),

schemaApplyMigration("feature-status", "000-feature-status.sql"),
}
Comment on lines +17 to 26
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the addition of schemaApplyMigration("worker-tokens", "001-add-expiry.sql") did not obey the comment at the top of this function:

// Note(ben): Never change the order or remove a migration as this would break the internal microcluster counter!


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE worker_tokens
ADD COLUMN expiry DATETIME DEFAULT '2100-01-01 23:59:59';
Comment on lines +1 to +2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bschimke95 what if the table worker_tokens has a column named expiry already. Won't this fail?

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
INSERT INTO
worker_tokens(name, token)
worker_tokens(name, token, expiry)
VALUES
( ?, ? )
( ?, ?, ? )
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
t.name
t.name, t.expiry
FROM
worker_tokens AS t
WHERE
Expand Down
14 changes: 10 additions & 4 deletions src/k8s/pkg/k8sd/database/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql"
"encoding/hex"
"fmt"
"time"

"github.com/canonical/microcluster/v2/cluster"
)
Expand All @@ -20,21 +21,26 @@ var (
)

// CheckWorkerNodeToken returns true if the specified token can be used to join the specified node on the cluster.
// CheckWorkerNodeToken will return true if the token is empty or if the token is associated with the specified node
// and has not expired.
func CheckWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string, token string) (bool, error) {
selectTxStmt, err := cluster.Stmt(tx, workerStmts["select-token"])
if err != nil {
return false, fmt.Errorf("failed to prepare select statement: %w", err)
}
var tokenNodeName string
if selectTxStmt.QueryRowContext(ctx, token).Scan(&tokenNodeName) == nil {
return tokenNodeName == "" || subtle.ConstantTimeCompare([]byte(nodeName), []byte(tokenNodeName)) == 1, nil
var expiry time.Time
if selectTxStmt.QueryRowContext(ctx, token).Scan(&tokenNodeName, &expiry) == nil {
isValidToken := tokenNodeName == "" || subtle.ConstantTimeCompare([]byte(nodeName), []byte(tokenNodeName)) == 1
notExpired := time.Now().Before(expiry)
return isValidToken && notExpired, nil
}
return false, nil
}

// GetOrCreateWorkerNodeToken returns a token that can be used to join a worker node on the cluster.
// GetOrCreateWorkerNodeToken will return the existing token, if one already exists for the node.
func GetOrCreateWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string) (string, error) {
func GetOrCreateWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string, expiry time.Time) (string, error) {
insertTxStmt, err := cluster.Stmt(tx, workerStmts["insert-token"])
if err != nil {
return "", fmt.Errorf("failed to prepare insert statement: %w", err)
Expand All @@ -46,7 +52,7 @@ func GetOrCreateWorkerNodeToken(ctx context.Context, tx *sql.Tx, nodeName string
return "", fmt.Errorf("is the system entropy low? failed to get random bytes: %w", err)
}
token := fmt.Sprintf("worker::%s", hex.EncodeToString(b))
if _, err := insertTxStmt.ExecContext(ctx, nodeName, token); err != nil {
if _, err := insertTxStmt.ExecContext(ctx, nodeName, token, expiry); err != nil {
return "", fmt.Errorf("insert token query failed: %w", err)
}
return token, nil
Expand Down
34 changes: 30 additions & 4 deletions src/k8s/pkg/k8sd/database/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"testing"
"time"

"github.com/canonical/k8s/pkg/k8sd/database"
. "github.com/onsi/gomega"
Expand All @@ -12,17 +13,18 @@ import (
func TestWorkerNodeToken(t *testing.T) {
WithDB(t, func(ctx context.Context, db DB) {
_ = db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
tokenExpiry := time.Now().Add(time.Hour)
t.Run("Default", func(t *testing.T) {
g := NewWithT(t)
exists, err := database.CheckWorkerNodeToken(ctx, tx, "somenode", "sometoken")
g.Expect(err).To(BeNil())
g.Expect(exists).To(BeFalse())

token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode")
token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode", tokenExpiry)
g.Expect(err).To(BeNil())
g.Expect(token).To(HaveLen(48))

othertoken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "someothernode")
othertoken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "someothernode", tokenExpiry)
g.Expect(err).To(BeNil())
g.Expect(othertoken).To(HaveLen(48))
g.Expect(othertoken).NotTo(Equal(token))
Expand All @@ -46,15 +48,39 @@ func TestWorkerNodeToken(t *testing.T) {
g.Expect(err).To(BeNil())
g.Expect(valid).To(BeFalse())

newToken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode")
newToken, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "somenode", tokenExpiry)
g.Expect(err).To(BeNil())
g.Expect(newToken).To(HaveLen(48))
g.Expect(newToken).ToNot(Equal(token))
})

t.Run("Expiry", func(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
g := NewWithT(t)
token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "nodeExpiry1", time.Now().Add(time.Hour))
g.Expect(err).To(BeNil())
g.Expect(token).To(HaveLen(48))

valid, err := database.CheckWorkerNodeToken(ctx, tx, "nodeExpiry1", token)
g.Expect(err).To(BeNil())
g.Expect(valid).To(BeTrue())
})

t.Run("Expired", func(t *testing.T) {
g := NewWithT(t)
token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "nodeExpiry2", time.Now().Add(-time.Hour))
g.Expect(err).To(BeNil())
g.Expect(token).To(HaveLen(48))

valid, err := database.CheckWorkerNodeToken(ctx, tx, "nodeExpiry2", token)
g.Expect(err).To(BeNil())
g.Expect(valid).To(BeFalse())
})
})

t.Run("AnyNodeName", func(t *testing.T) {
g := NewWithT(t)
token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "")
token, err := database.GetOrCreateWorkerNodeToken(ctx, tx, "", tokenExpiry)
g.Expect(err).To(BeNil())
g.Expect(token).To(HaveLen(48))

Expand Down
Loading