Skip to content

Commit

Permalink
Properly filter out custom query params in MySQL DB driver
Browse files Browse the repository at this point in the history
  • Loading branch information
dhui committed Aug 22, 2019
1 parent a354c6d commit 0064ee8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
26 changes: 25 additions & 1 deletion database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return mx, nil
}

// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
if c == nil {
return nil, ErrNilConfig
}
customQueryParams := map[string]string{}

for k, v := range c.Params {
if strings.HasPrefix(k, "x-") {
customQueryParams[k] = v
delete(c.Params, k)
}
}
return customQueryParams, nil
}

func urlToMySQLConfig(url string) (*mysql.Config, error) {
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
Expand Down Expand Up @@ -174,6 +191,13 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
if err != nil {
return nil, err
}
fmt.Printf("config: %+v\n", config)

customParams, err := extractCustomQueryParams(config)
if err != nil {
return nil, err
}
fmt.Printf("config: %+v\n", config)

db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
Expand All @@ -182,7 +206,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {

mx, err := WithInstance(db, &Config{
DatabaseName: config.DBName,
MigrationsTable: config.Params["x-migrations-table"],
MigrationsTable: customParams["x-migrations-table"],
})
if err != nil {
return nil, err
Expand Down
60 changes: 58 additions & 2 deletions database/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ import (
sqldriver "database/sql/driver"
"fmt"
"log"

"github.com/golang-migrate/migrate/v4"
"testing"
)

import (
"github.com/dhui/dktest"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)

import (
"github.com/golang-migrate/migrate/v4"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
Expand Down Expand Up @@ -175,6 +175,62 @@ func TestLockWorks(t *testing.T) {
})
}

func TestExtractCustomQueryParams(t *testing.T) {
testcases := []struct {
name string
config *mysql.Config
expectedParams map[string]string
expectedCustomParams map[string]string
expectedErr error
}{
{name: "nil config", expectedErr: ErrNilConfig},
{
name: "no params",
config: mysql.NewConfig(),
expectedCustomParams: map[string]string{},
},
{
name: "no custom params",
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{},
},
{
name: "one param, one custom param",
config: &mysql.Config{
Params: map[string]string{"hello": "world", "x-foo": "bar"},
},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{"x-foo": "bar"},
},
{
name: "multiple params, multiple custom params",
config: &mysql.Config{
Params: map[string]string{
"hello": "world",
"x-foo": "bar",
"dead": "beef",
"x-cat": "hat",
},
},
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
customParams, err := extractCustomQueryParams(tc.config)
if tc.config != nil {
assert.Equal(t, tc.expectedParams, tc.config.Params,
"Expected config params have custom params properly removed")
}
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
assert.Equal(t, tc.expectedCustomParams, customParams,
"Expected custom params to be properly extracted")
})
}
}

func TestURLToMySQLConfig(t *testing.T) {
testcases := []struct {
name string
Expand Down

0 comments on commit 0064ee8

Please sign in to comment.