diff --git a/pkg/controller.v1alpha3/consts/const.go b/pkg/controller.v1alpha3/consts/const.go index ba97e9763ad..d2608000d25 100644 --- a/pkg/controller.v1alpha3/consts/const.go +++ b/pkg/controller.v1alpha3/consts/const.go @@ -1,8 +1,6 @@ package consts -import ( - "os" -) +import "github.com/kubeflow/katib/pkg/util/v1alpha3/env" const ( ConfigExperimentSuggestionName = "experiment-suggestion-name" @@ -31,12 +29,5 @@ const ( ) var ( - DefaultKatibNamespace = getEnvOrDefault(DefaultKatibNamespaceEnvName, "kubeflow") + DefaultKatibNamespace = env.GetEnvOrDefault(DefaultKatibNamespaceEnvName, "kubeflow") ) - -func getEnvOrDefault(key string, fallback string) string { - if value, ok := os.LookupEnv(key); ok { - return value - } - return fallback -} diff --git a/pkg/db/v1alpha3/common/const.go b/pkg/db/v1alpha3/common/const.go index 818d8bab284..2defdaff849 100644 --- a/pkg/db/v1alpha3/common/const.go +++ b/pkg/db/v1alpha3/common/const.go @@ -6,4 +6,10 @@ const ( MySqlDBNameEnvValue = "mysql" DBPasswordEnvName = "DB_PASSWORD" + + MySQLDBHostEnvName = "MYSQL_HOST" + MySQLDBPortEnvName = "MYSQL_PORT" + + DefaultMySQLHost = "katib-db" + DefaultMySQLPort = "3306" ) diff --git a/pkg/db/v1alpha3/mysql/mysql.go b/pkg/db/v1alpha3/mysql/mysql.go index dbdb87ec3de..5b7c738933d 100644 --- a/pkg/db/v1alpha3/mysql/mysql.go +++ b/pkg/db/v1alpha3/mysql/mysql.go @@ -14,11 +14,12 @@ import ( v1alpha3 "github.com/kubeflow/katib/pkg/apis/manager/v1alpha3" "github.com/kubeflow/katib/pkg/db/v1alpha3/common" + "github.com/kubeflow/katib/pkg/util/v1alpha3/env" ) const ( dbDriver = "mysql" - dbNameTmpl = "root:%s@tcp(katib-db:3306)/katib?timeout=5s" + dbNameTmpl = "root:%s@tcp(%s:%s)/katib?timeout=5s" mysqlTimeFmt = "2006-01-02 15:04:05.999999" connectInterval = 5 * time.Second @@ -32,7 +33,11 @@ type dbConn struct { func getDbName() string { dbPassEnvName := common.DBPasswordEnvName dbPass := os.Getenv(dbPassEnvName) - return fmt.Sprintf(dbNameTmpl, dbPass) + dbHost := env.GetEnvOrDefault( + common.MySQLDBHostEnvName, common.DefaultMySQLHost) + dbPort := env.GetEnvOrDefault( + common.MySQLDBPortEnvName, common.DefaultMySQLPort) + return fmt.Sprintf(dbNameTmpl, dbPass, dbHost, dbPort) } func openSQLConn(driverName string, dataSourceName string, interval time.Duration, diff --git a/pkg/util/v1alpha3/env/env.go b/pkg/util/v1alpha3/env/env.go new file mode 100644 index 00000000000..3de45f1865f --- /dev/null +++ b/pkg/util/v1alpha3/env/env.go @@ -0,0 +1,10 @@ +package env + +import "os" + +func GetEnvOrDefault(key string, fallback string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + return fallback +} diff --git a/pkg/util/v1alpha3/env/env_test.go b/pkg/util/v1alpha3/env/env_test.go new file mode 100644 index 00000000000..b1f5d0caf74 --- /dev/null +++ b/pkg/util/v1alpha3/env/env_test.go @@ -0,0 +1,21 @@ +package env + +import ( + "os" + "testing" +) + +func TestGetEnvWithDefault(t *testing.T) { + expected := "FAKE" + key := "TEST" + v := GetEnvOrDefault(key, expected) + if v != expected { + t.Errorf("Expected %s, got %s", expected, v) + } + expected = "FAKE1" + os.Setenv(key, expected) + v = GetEnvOrDefault(key, "") + if v != expected { + t.Errorf("Expected %s, got %s", expected, v) + } +}