Skip to content

Commit

Permalink
fix(vault): correcting how we renew vault secrets (#153)
Browse files Browse the repository at this point in the history
* Correcting how we renew db secrets

* Correcting how we renew db secrets

* Updating tests

* Correcting api tests
  • Loading branch information
Jacobbrewer1 authored Aug 2, 2024
1 parent e4acd83 commit 50f6ba3
Show file tree
Hide file tree
Showing 29 changed files with 3,957 additions and 217 deletions.
45 changes: 34 additions & 11 deletions cmd/summary/cmd_serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,35 +155,30 @@ func (s *serveCmd) generateConfig(ctx context.Context) error {
func (s *serveCmd) setup(ctx context.Context, r *mux.Router) {
v := viper.New()
v.SetConfigFile(s.configLocation)
if err := v.ReadInConfig(); err != nil {
err := v.ReadInConfig()
if err != nil {
slog.Error("Error reading config file", slog.String(logging.KeyError, err.Error()))
os.Exit(1)
}

var vc vault.Client
dbSec := new(vault.Secrets)
if s.vaultEnabled {
// Set up the vault client
vc, err := vault.NewClient(v)
vc, err = vault.NewClientUserPass(v)
if err != nil {
slog.Error("Error creating vault client", slog.String(logging.KeyError, err.Error()))
os.Exit(1)
}

dbSec, err := vc.GetSecrets(v.GetString("vault.db_path"))
dbSec, err = vc.GetSecrets(v.GetString("database.credentials_path"))
if err != nil {
slog.Error("Error getting database secrets", slog.String(logging.KeyError, err.Error()))
os.Exit(1)
}

slog.Debug("Database credentials retrieved from vault")

go func() {
err = vc.RenewLease(ctx, v.GetString("vault.db_path"), dbSec.Secret, func() (*vault2.Secret, error) {
slog.Warn("Vault lease expired, restarting application")
os.Exit(1) // Restart the application to get new secrets
return nil, nil
})
}()

dbConnStr := dataaccess.GenerateConnectionStr(v, *dbSec)
v.Set("db.conn_str", dbConnStr)
} else {
Expand All @@ -201,6 +196,34 @@ func (s *serveCmd) setup(ctx context.Context, r *mux.Router) {
os.Exit(1)
}

if s.vaultEnabled {
go func() {
err = vc.RenewLease(ctx, v.GetString("database.credentials_path"), dbSec.Secret, func() (*vault2.Secret, error) {
slog.Warn("Vault lease expired, reconnecting to database")

vs, err := vc.GetSecrets(v.GetString("database.credentials_path"))
if err != nil {
return nil, fmt.Errorf("error getting secrets from vault: %w", err)
}

dbConnectionString := dataaccess.GenerateConnectionStr(v, *vs)
v.Set("database.connection_string", dbConnectionString)

if err := db.Reconnect(ctx, dbConnectionString); err != nil {
return nil, fmt.Errorf("error reconnecting to database: %w", err)
}

slog.Info("Database reconnected")

return vs.Secret, nil
})
if err != nil {
slog.Error("Error renewing vault lease", slog.String(logging.KeyError, err.Error()))
os.Exit(1) // Forces new credentials to be fetched
}
}()
}

purgeSvc := purge.NewService(db)

// Set up the purge routine
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ require (
github.com/gorilla/mux v1.8.1
github.com/hashicorp/vault/api v1.14.0
github.com/hashicorp/vault/api/auth/approle v0.7.0
github.com/hashicorp/vault/api/auth/userpass v0.7.0
github.com/jmoiron/sqlx v1.4.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/oapi-codegen/runtime v1.1.1
github.com/prometheus/client_golang v1.19.1
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ github.com/hashicorp/vault/api v1.14.0 h1:Ah3CFLixD5jmjusOgm8grfN9M0d+Y8fVR2SW0K
github.com/hashicorp/vault/api v1.14.0/go.mod h1:pV9YLxBGSz+cItFDd8Ii4G17waWOQ32zVjMWHe/cOqk=
github.com/hashicorp/vault/api/auth/approle v0.7.0 h1:R5IRVuFA5JSdG3UdGVcGysi0StrL1lPmyJnrawiV0Ss=
github.com/hashicorp/vault/api/auth/approle v0.7.0/go.mod h1:B+WaC6VR+aSXiUxykpaPUoFiiZAhic53tDLbGjWZmRA=
github.com/hashicorp/vault/api/auth/userpass v0.7.0 h1:7Fk0qtF2NYSJyQ6EOO+Kt93dEobI30AqBrrC5wE6e+8=
github.com/hashicorp/vault/api/auth/userpass v0.7.0/go.mod h1:3tZ2KAAui23OKlo5PZ+sBycoJ4wdurY6oZdQWJ0UStg=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
Expand All @@ -139,6 +143,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
Expand Down
91 changes: 89 additions & 2 deletions pkg/dataaccess/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package dataaccess
import (
"context"
"fmt"
"log/slog"
"strings"
"sync"
"time"

"github.com/Jacobbrewer1/puppet-summary/pkg/codegen/apis/summary"
"github.com/Jacobbrewer1/puppet-summary/pkg/entities"
"github.com/Jacobbrewer1/puppet-summary/pkg/vault"
"github.com/jmoiron/sqlx"
"github.com/spf13/viper"
)

Expand All @@ -21,6 +24,9 @@ type Database interface {
// Close closes the database connection.
Close(ctx context.Context) error

// Reconnect will be called periodically to refresh the database connection
Reconnect(ctx context.Context, connStr string) error

// SaveRun saves a PuppetRun to the database.
SaveRun(ctx context.Context, run *entities.PuppetReport) error

Expand Down Expand Up @@ -83,7 +89,88 @@ func GenerateConnectionStr(v *viper.Viper, vs vault.Secrets) string {
return fmt.Sprintf("%s:%s@tcp(%s)/%s?timeout=90s&multiStatements=true&parseTime=true",
vs.Data["username"],
vs.Data["password"],
v.GetString("db.host"),
v.GetString("db.schema"),
v.GetString("database.host"),
v.GetString("database.schema"),
)
}

type Db struct {
*sqlx.DB
*sync.RWMutex
}

// NewDb establishes a database connection with the given Vault credentials
func NewDb(db *sqlx.DB) *Db {
return &Db{
DB: db,
RWMutex: new(sync.RWMutex),
}
}

// Reconnect will be called periodically to refresh the database connection
// since the dynamic credentials expire after some time, it will:
// 1. construct a connection string using the given credentials
// 2. establish a database connection
// 3. close & replace the existing connection with the new one behind a mutex
func (d *Db) Reconnect(ctx context.Context, db *sqlx.DB) error {
ctx, cancelContextFunc := context.WithTimeout(ctx, 7*time.Second)
defer cancelContextFunc()

slog.Debug("Reconnecting to database")

// wait until the database is ready or timeout expires
for {
err := db.PingContext(ctx)
if err == nil {
break
}
select {
case <-time.After(500 * time.Millisecond):
slog.Debug("Database ping failed, retrying...")
continue
case <-ctx.Done():
return fmt.Errorf("failed to successfully ping database before context timeout: %w", err)
}
}

slog.Info("New database connection established")

d.closeReplaceConnection(db)

return nil
}

func (d *Db) closeReplaceConnection(newDb *sqlx.DB) {
slog.Debug("Replacing database connection")

// close the existing connection, if exists
if d.DB != nil {
_ = d.Close()
}

d.DB = newDb

slog.Debug("Database connection replaced")
}

func (d *Db) Close() error {
slog.Debug("Acquiring lock to close database connection")

d.Lock()
defer d.Unlock()

slog.Debug("Lock acquired to close database connection")

if d.DB != nil {
return d.DB.Close()
}

return nil
}

func (d *Db) PingContext(ctx context.Context) error {
d.RLock()
defer d.RUnlock()

return d.DB.PingContext(ctx)
}
5 changes: 5 additions & 0 deletions pkg/dataaccess/db_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ func (m *MockDb) Ping(ctx context.Context) error {
return args.Error(0)
}

func (m *MockDb) Reconnect(ctx context.Context, connStr string) error {
args := m.Called(ctx)
return args.Error(0)
}

func (m *MockDb) Close(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
Expand Down
30 changes: 27 additions & 3 deletions pkg/dataaccess/db_mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
"sort"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"

"github.com/Jacobbrewer1/puppet-summary/pkg/codegen/apis/summary"
"github.com/Jacobbrewer1/puppet-summary/pkg/entities"
"github.com/Jacobbrewer1/puppet-summary/pkg/logging"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

const mongoDatabase = "puppet-summary"
Expand All @@ -25,6 +26,29 @@ type mongodbImpl struct {
client *mongo.Client
}

func (m *mongodbImpl) Reconnect(ctx context.Context, connStr string) error {
if m.client != nil {
if err := m.client.Disconnect(ctx); err != nil {
return fmt.Errorf("error disconnecting from database: %w", err)
}
}

serverAPI := options.ServerAPI(options.ServerAPIVersion1)
opts := options.Client().ApplyURI(connStr).SetServerAPIOptions(serverAPI)
opts.SetAppName(mongoDatabase)

client, err := mongo.Connect(ctx, opts)
if err != nil {
return fmt.Errorf("connect to MongoDB: %w", err)
} else if client == nil {
return errors.New("nil MongoDB client")
}

m.client = client

return nil
}

func (m *mongodbImpl) Close(ctx context.Context) error {
return m.client.Disconnect(ctx)
}
Expand Down
24 changes: 21 additions & 3 deletions pkg/dataaccess/db_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,29 @@ import (
"github.com/Jacobbrewer1/puppet-summary/pkg/entities"
"github.com/Jacobbrewer1/puppet-summary/pkg/logging"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/viper"
)

type mysqlImpl struct {
// client is the database.
client *sql.DB
client *Db
}

func (m *mysqlImpl) Reconnect(ctx context.Context, connStr string) error {
// Create a new database connection.
newDb, err := sqlx.Open("mysql", connStr)
if err != nil {
return fmt.Errorf("error opening mysql: %w", err)
}

err = m.client.Reconnect(ctx, newDb)
if err != nil {
return fmt.Errorf("error reconnecting: %w", err)
}

return nil
}

func (m *mysqlImpl) Close(_ context.Context) error {
Expand Down Expand Up @@ -538,13 +554,15 @@ func NewMySQL(v *viper.Viper) (Database, error) {
return nil, fmt.Errorf("no %s environment variable provided", EnvDbConnStr)
}

d, err := sql.Open("mysql", connectionString)
d, err := sqlx.Open("mysql", connectionString)
if err != nil {
return nil, fmt.Errorf("error opening mysql: %w", err)
}

newDb := NewDb(d)

impl := &mysqlImpl{
client: d,
client: newDb,
}

if err := impl.setup(); err != nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/dataaccess/db_mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/Jacobbrewer1/puppet-summary/pkg/codegen/apis/summary"
"github.com/Jacobbrewer1/puppet-summary/pkg/entities"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -39,9 +40,12 @@ func (s *mysqlSuite) SetupTest() {
s.db = db
s.mockDB = mock

newDb := sqlx.NewDb(s.db, "mysql")
customDb := NewDb(newDb)

// Create a new database object.
s.dbObject = &mysqlImpl{
client: db,
client: customDb,
}
}

Expand Down
Loading

0 comments on commit 50f6ba3

Please sign in to comment.