From e5b4be77716e02d32cdaf98a2982cfd6c0dd680c Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Tue, 20 Aug 2019 18:59:15 +0200 Subject: [PATCH] Let database.Open() use schemeFromURL as well (#271) * Let database.Open() use schemeFromURL as well Otherwise it will fail on MySQL DSNs. Moved schemeFromURL into the database package. Also removed databaseSchemeFromURL and sourceSchemeFromURL as they were just calling schemeFromURL. Fixes https://github.com/golang-migrate/migrate/pull/265#issuecomment-522301237 * Moved url functions into internal/url Also merged the test cases. * Add some database tests to improve coverage * Fix suggestions --- database/driver.go | 16 +++--- database/driver_test.go | 107 +++++++++++++++++++++++++++++++++++++++ internal/url/url.go | 25 +++++++++ internal/url/url_test.go | 48 ++++++++++++++++++ migrate.go | 9 ++-- util.go | 36 ------------- util_test.go | 102 ------------------------------------- 7 files changed, 191 insertions(+), 152 deletions(-) create mode 100644 internal/url/url.go create mode 100644 internal/url/url_test.go diff --git a/database/driver.go b/database/driver.go index 901e5dd66..2c673caaf 100644 --- a/database/driver.go +++ b/database/driver.go @@ -7,8 +7,9 @@ package database import ( "fmt" "io" - nurl "net/url" "sync" + + iurl "github.com/golang-migrate/migrate/v4/internal/url" ) var ( @@ -81,21 +82,16 @@ type Driver interface { // Open returns a new driver instance. func Open(url string) (Driver, error) { - u, err := nurl.Parse(url) + scheme, err := iurl.SchemeFromURL(url) if err != nil { - return nil, fmt.Errorf("Unable to parse URL. Did you escape all reserved URL characters? "+ - "See: https://github.com/golang-migrate/migrate#database-urls Error: %v", err) - } - - if u.Scheme == "" { - return nil, fmt.Errorf("database driver: invalid URL scheme") + return nil, err } driversMu.RLock() - d, ok := drivers[u.Scheme] + d, ok := drivers[scheme] driversMu.RUnlock() if !ok { - return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", u.Scheme) + return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme) } return d.Open(url) diff --git a/database/driver_test.go b/database/driver_test.go index c0a29304f..7880f3208 100644 --- a/database/driver_test.go +++ b/database/driver_test.go @@ -1,8 +1,115 @@ package database +import ( + "io" + "testing" +) + func ExampleDriver() { // see database/stub for an example // database/stub/stub.go has the driver implementation // database/stub/stub_test.go runs database/testing/test.go:Test } + +// Using database/stub here is not possible as it +// results in an import cycle. +type mockDriver struct { + url string +} + +func (m *mockDriver) Open(url string) (Driver, error) { + return &mockDriver{ + url: url, + }, nil +} + +func (m *mockDriver) Close() error { + return nil +} + +func (m *mockDriver) Lock() error { + return nil +} + +func (m *mockDriver) Unlock() error { + return nil +} + +func (m *mockDriver) Run(migration io.Reader) error { + return nil +} + +func (m *mockDriver) SetVersion(version int, dirty bool) error { + return nil +} + +func (m *mockDriver) Version() (version int, dirty bool, err error) { + return 0, false, nil +} + +func (m *mockDriver) Drop() error { + return nil +} + +func TestRegisterTwice(t *testing.T) { + Register("mock", &mockDriver{}) + + var err interface{} + func() { + defer func() { + err = recover() + }() + Register("mock", &mockDriver{}) + }() + + if err == nil { + t.Fatal("expected a panic when calling Register twice") + } +} + +func TestOpen(t *testing.T) { + // Make sure the driver is registered. + // But if the previous test already registered it just ignore the panic. + // If we don't do this it will be impossible to run this test standalone. + func() { + defer func() { + _ = recover() + }() + Register("mock", &mockDriver{}) + }() + + cases := []struct { + url string + err bool + }{ + { + "mock://user:pass@tcp(host:1337)/db", + false, + }, + { + "unknown://bla", + true, + }, + } + + for _, c := range cases { + t.Run(c.url, func(t *testing.T) { + d, err := Open(c.url) + + if err == nil { + if c.err { + t.Fatal("expected an error for an unknown driver") + } else { + if md, ok := d.(*mockDriver); !ok { + t.Fatalf("expected *mockDriver got %T", d) + } else if md.url != c.url { + t.Fatalf("expected %q got %q", c.url, md.url) + } + } + } else if !c.err { + t.Fatalf("did not expect %q", err) + } + }) + } +} diff --git a/internal/url/url.go b/internal/url/url.go new file mode 100644 index 000000000..e793fa828 --- /dev/null +++ b/internal/url/url.go @@ -0,0 +1,25 @@ +package url + +import ( + "errors" + "strings" +) + +var errNoScheme = errors.New("no scheme") +var errEmptyURL = errors.New("URL cannot be empty") + +// schemeFromURL returns the scheme from a URL string +func SchemeFromURL(url string) (string, error) { + if url == "" { + return "", errEmptyURL + } + + i := strings.Index(url, ":") + + // No : or : is the first character. + if i < 1 { + return "", errNoScheme + } + + return url[0:i], nil +} diff --git a/internal/url/url_test.go b/internal/url/url_test.go new file mode 100644 index 000000000..de338e76b --- /dev/null +++ b/internal/url/url_test.go @@ -0,0 +1,48 @@ +package url + +import ( + "testing" +) + +func TestSchemeFromUrl(t *testing.T) { + cases := []struct { + name string + urlStr string + expected string + expectErr error + }{ + { + name: "Simple", + urlStr: "protocol://path", + expected: "protocol", + }, + { + // See issue #264 + name: "MySQLWithPort", + urlStr: "mysql://user:pass@tcp(host:1337)/db", + expected: "mysql", + }, + { + name: "Empty", + urlStr: "", + expectErr: errEmptyURL, + }, + { + name: "NoScheme", + urlStr: "hello", + expectErr: errNoScheme, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s, err := SchemeFromURL(tc.urlStr) + if err != tc.expectErr { + t.Fatalf("expected %q, but received %q", tc.expectErr, err) + } + if s != tc.expected { + t.Fatalf("expected %q, but received %q", tc.expected, s) + } + }) + } +} diff --git a/migrate.go b/migrate.go index 3ede504e5..f692d6f9e 100644 --- a/migrate.go +++ b/migrate.go @@ -13,6 +13,7 @@ import ( "time" "github.com/golang-migrate/migrate/v4/database" + iurl "github.com/golang-migrate/migrate/v4/internal/url" "github.com/golang-migrate/migrate/v4/source" ) @@ -85,13 +86,13 @@ type Migrate struct { func New(sourceURL, databaseURL string) (*Migrate, error) { m := newCommon() - sourceName, err := sourceSchemeFromURL(sourceURL) + sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { return nil, err } m.sourceName = sourceName - databaseName, err := databaseSchemeFromURL(databaseURL) + databaseName, err := iurl.SchemeFromURL(databaseURL) if err != nil { return nil, err } @@ -119,7 +120,7 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { m := newCommon() - sourceName, err := schemeFromURL(sourceURL) + sourceName, err := iurl.SchemeFromURL(sourceURL) if err != nil { return nil, err } @@ -145,7 +146,7 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { m := newCommon() - databaseName, err := schemeFromURL(databaseURL) + databaseName, err := iurl.SchemeFromURL(databaseURL) if err != nil { return nil, err } diff --git a/util.go b/util.go index ecf377391..26131a3ff 100644 --- a/util.go +++ b/util.go @@ -1,7 +1,6 @@ package migrate import ( - "errors" "fmt" nurl "net/url" "strings" @@ -49,41 +48,6 @@ func suint(n int) uint { return uint(n) } -var errNoScheme = errors.New("no scheme") -var errEmptyURL = errors.New("URL cannot be empty") - -func sourceSchemeFromURL(url string) (string, error) { - u, err := schemeFromURL(url) - if err != nil { - return "", fmt.Errorf("source: %v", err) - } - return u, nil -} - -func databaseSchemeFromURL(url string) (string, error) { - u, err := schemeFromURL(url) - if err != nil { - return "", fmt.Errorf("database: %v", err) - } - return u, nil -} - -// schemeFromURL returns the scheme from a URL string -func schemeFromURL(url string) (string, error) { - if url == "" { - return "", errEmptyURL - } - - i := strings.Index(url, ":") - - // No : or : is the first character. - if i < 1 { - return "", errNoScheme - } - - return url[0:i], nil -} - // FilterCustomQuery filters all query values starting with `x-` func FilterCustomQuery(u *nurl.URL) *nurl.URL { ux := *u diff --git a/util_test.go b/util_test.go index ef395e84f..1ad234473 100644 --- a/util_test.go +++ b/util_test.go @@ -1,7 +1,6 @@ package migrate import ( - "errors" nurl "net/url" "testing" ) @@ -31,104 +30,3 @@ func TestFilterCustomQuery(t *testing.T) { t.Fatalf("didn't expect x-custom") } } - -func TestSourceSchemeFromUrlSuccess(t *testing.T) { - urlStr := "protocol://path" - expected := "protocol" - - u, err := sourceSchemeFromURL(urlStr) - if err != nil { - t.Fatalf("expected no error, but received %q", err) - } - if u != expected { - t.Fatalf("expected %q, but received %q", expected, u) - } -} - -func TestSourceSchemeFromUrlFailure(t *testing.T) { - cases := []struct { - name string - urlStr string - expectErr error - }{ - { - name: "Empty", - urlStr: "", - expectErr: errors.New("source: URL cannot be empty"), - }, - { - name: "NoScheme", - urlStr: "hello", - expectErr: errors.New("source: no scheme"), - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - _, err := sourceSchemeFromURL(tc.urlStr) - if err.Error() != tc.expectErr.Error() { - t.Fatalf("expected %q, but received %q", tc.expectErr, err) - } - }) - } -} - -func TestDatabaseSchemeFromUrlSuccess(t *testing.T) { - cases := []struct { - name string - urlStr string - expected string - }{ - { - name: "Simple", - urlStr: "protocol://path", - expected: "protocol", - }, - { - // See issue #264 - name: "MySQLWithPort", - urlStr: "mysql://user:pass@tcp(host:1337)/db", - expected: "mysql", - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - u, err := databaseSchemeFromURL(tc.urlStr) - if err != nil { - t.Fatalf("expected no error, but received %q", err) - } - if u != tc.expected { - t.Fatalf("expected %q, but received %q", tc.expected, u) - } - }) - } -} - -func TestDatabaseSchemeFromUrlFailure(t *testing.T) { - cases := []struct { - name string - urlStr string - expectErr error - }{ - { - name: "Empty", - urlStr: "", - expectErr: errors.New("database: URL cannot be empty"), - }, - { - name: "NoScheme", - urlStr: "hello", - expectErr: errors.New("database: no scheme"), - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - _, err := databaseSchemeFromURL(tc.urlStr) - if err.Error() != tc.expectErr.Error() { - t.Fatalf("expected %q, but received %q", tc.expectErr, err) - } - }) - } -}