Skip to content

Commit

Permalink
Add read replica host config and connection (#231)
Browse files Browse the repository at this point in the history
## Overview
- Add a new field to the postgres db config struct, `readReplicaHost`.
- Add a new endpoint in the `database` package to enable establishing a connection with a db without creating it if it doesn't exist

## Test Plan
- Unit tests and staging tests for artifacts db read replica connection
- Validate the new host name is correct: https://buildkite.com/unionai/artifacts-staging-deploy/builds/505#018f11f5-ea7f-47d0-b145-ca433159e3b4

## Rollout Plan (if applicable)
Staging -> canary -> prod

## Upstream Changes
Should this change be upstreamed to OSS (flyteorg/flyte)? If so, please check this box for auditing. Note, this is the responsibility of each developer. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [x] To be upstreamed

## Jira Issue
https://unionai.atlassian.net/browse/UNION-406

## Checklist
* [ ] Added tests
* [ ] Ran a deploy dry run and shared the terraform plan
* [ ] Added logging and metrics
* [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list)
* [ ] Updated documentation
  • Loading branch information
squiishyy committed Apr 24, 2024
1 parent 2af4bd1 commit 60d2dd4
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 10 deletions.
22 changes: 12 additions & 10 deletions flytestdlib/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ var defaultConfig = &DbConfig{
ConnMaxLifeTime: config.Duration{Duration: time.Hour},
Postgres: PostgresConfig{
// These values are suitable for local sandbox development
Host: "localhost",
Port: 30001,
DbName: postgresStr,
User: postgresStr,
Password: postgresStr,
ExtraOptions: "sslmode=disable",
Host: "localhost",
ReadReplicaHost: "localhost",
Port: 30001,
DbName: postgresStr,
User: postgresStr,
Password: postgresStr,
ExtraOptions: "sslmode=disable",
},
}
var configSection = config.MustRegisterSection(database, defaultConfig)
Expand Down Expand Up @@ -64,10 +65,11 @@ type SQLiteConfig struct {

// PostgresConfig includes specific config options for opening a connection to a postgres database.
type PostgresConfig struct {
Host string `json:"host" pflag:",The host name of the database server"`
Port int `json:"port" pflag:",The port name of the database server"`
DbName string `json:"dbname" pflag:",The database name"`
User string `json:"username" pflag:",The database user who is connecting to the server."`
Host string `json:"host" pflag:",The host name of the database server"`
ReadReplicaHost string `json:"readReplicaHost" pflag:",The host name of the read replica database server"`
Port int `json:"port" pflag:",The port name of the database server"`
DbName string `json:"dbname" pflag:",The database name"`
User string `json:"username" pflag:",The database user who is connecting to the server."`
// Either Password or PasswordPath must be set.
Password string `json:"password" pflag:",The database password."`
PasswordPath string `json:"passwordPath" pflag:",Points to the file containing the database password."`
Expand Down
25 changes: 25 additions & 0 deletions flytestdlib/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ func GetDB(ctx context.Context, dbConfig *DbConfig, logConfig *logger.Config) (
return gormDb, setupDbConnectionPool(ctx, gormDb, dbConfig)
}

// GetReadOnlyDB uses the dbConfig to create gorm DB object for the read replica passed via the config
func GetReadOnlyDB(ctx context.Context, dbConfig *DbConfig, logConfig *logger.Config) (*gorm.DB, error) {
if dbConfig == nil {
panic("Cannot initialize database repository from empty db config")
}

if dbConfig.Postgres.IsEmpty() || dbConfig.Postgres.ReadReplicaHost == "" {
return nil, fmt.Errorf("read replica host not provided in db config")
}

gormConfig := &gorm.Config{
Logger: GetGormLogger(ctx, logConfig),
DisableForeignKeyConstraintWhenMigrating: false,
}

var gormDb *gorm.DB
var err error
gormDb, err = CreatePostgresReadOnlyDbConnection(ctx, gormConfig, dbConfig.Postgres)
if err != nil {
return nil, err
}

return gormDb, nil
}

func setupDbConnectionPool(ctx context.Context, gormDb *gorm.DB, dbConfig *DbConfig) error {
genericDb, err := gormDb.DB()
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions flytestdlib/database/dbconfig_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions flytestdlib/database/dbconfig_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions flytestdlib/database/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ func getPostgresDsn(ctx context.Context, pgConfig PostgresConfig) string {
pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

// Produces the DSN (data source name) for the read replica for opening a postgres db connection.
func getPostgresReadDsn(ctx context.Context, pgConfig PostgresConfig) string {
password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath)
if len(password) == 0 {
// The password-less case is included for development environments.
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable",
pgConfig.ReadReplicaHost, pgConfig.Port, pgConfig.DbName, pgConfig.User)
}
return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s %s",
pgConfig.ReadReplicaHost, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions)
}

// CreatePostgresDbIfNotExists creates DB if it doesn't exist for the passed in config
func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) {
dialector := postgres.Open(getPostgresDsn(ctx, pgConfig))
Expand Down Expand Up @@ -94,6 +106,12 @@ func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p
return gorm.Open(dialector, gormConfig)
}

// CreatePostgresDbConnection creates DB connection and returns the gorm.DB object and error
func CreatePostgresReadOnlyDbConnection(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) {
dialector := postgres.Open(getPostgresReadDsn(ctx, pgConfig))
return gorm.Open(dialector, gormConfig)
}

func IsPgErrorWithCode(err error, code string) bool {
// Newer versions of the gorm postgres driver seem to use
// "github.com/jackc/pgx/v5/pgconn"
Expand Down
41 changes: 41 additions & 0 deletions flytestdlib/database/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,47 @@ func TestGetPostgresDsn(t *testing.T) {
})
}

func TestGetPostgresReadDsn(t *testing.T) {
pgConfig := PostgresConfig{
Host: "localhost",
ReadReplicaHost: "readReplicaHost",
Port: 5432,
DbName: "postgres",
User: "postgres",
ExtraOptions: "sslmode=disable",
}
t.Run("no password", func(t *testing.T) {
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres sslmode=disable", dsn)
})
t.Run("with password", func(t *testing.T) {
pgConfig.Password = "passw"
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=passw sslmode=disable", dsn)

})
t.Run("with password, no extra", func(t *testing.T) {
pgConfig.Password = "passwo"
pgConfig.ExtraOptions = ""
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=passwo ", dsn)
})
t.Run("with password path", func(t *testing.T) {
password := "1234abc"
tmpFile, err := ioutil.TempFile("", "prefix")
if err != nil {
t.Errorf("Couldn't open temp file: %v", err)
}
defer tmpFile.Close()
if _, err = tmpFile.WriteString(password); err != nil {
t.Errorf("Couldn't write to temp file: %v", err)
}
pgConfig.PasswordPath = tmpFile.Name()
dsn := getPostgresReadDsn(context.TODO(), pgConfig)
assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=1234abc ", dsn)
})
}

type wrappedError struct {
err error
}
Expand Down

0 comments on commit 60d2dd4

Please sign in to comment.