Skip to content

Commit

Permalink
Merge pull request #659 from AndreasKl/add-with-connection-to-postgres
Browse files Browse the repository at this point in the history
Add WithConnection to Postgres similar to MySQL.
  • Loading branch information
dhui authored Feb 24, 2022
2 parents e1d604b + 3dfae0d commit 57aead3
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 14 deletions.
46 changes: 32 additions & 14 deletions database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"database/sql"
"fmt"
"go.uber.org/atomic"
"io"
"io/ioutil"
nurl "net/url"
Expand All @@ -16,10 +15,12 @@ import (
"strings"
"time"

"go.uber.org/atomic"

"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-multierror"
"github.com/lib/pq"
)

Expand Down Expand Up @@ -65,19 +66,19 @@ type Postgres struct {
config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := conn.PingContext(ctx); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -91,7 +92,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand Down Expand Up @@ -119,15 +120,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
}
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Postgres{
conn: conn,
db: instance,
config: config,
}

Expand All @@ -138,6 +132,26 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return px, nil
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()

if err := instance.Ping(); err != nil {
return nil, err
}

conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}

px, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}
px.db = instance
return px, nil
}

func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
Expand Down Expand Up @@ -207,7 +221,11 @@ func (p *Postgres) Open(url string) (database.Driver, error) {

func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
var dbErr error
if p.db != nil {
dbErr = p.db.Close()
}

if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down
38 changes: 38 additions & 0 deletions database/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,44 @@ func TestWithInstance_Concurrent(t *testing.T) {
}
})
}

func TestWithConnection(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("postgres", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()

ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}

p, err := WithConnection(ctx, conn, &Config{})
if err != nil {
t.Fatal(err)
}

defer func() {
if err := p.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, p, []byte("SELECT 1"))
})
}

func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
Expand Down

0 comments on commit 57aead3

Please sign in to comment.