Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multihost connstring support #714

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions .travis.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,38 @@ pgdg_repository() {
}

postgresql_configure() {
sudo tee /etc/postgresql/$PGVERSION/main/pg_hba.conf > /dev/null <<-config
local instance=$1
case $instance in
primary)
sudo pg_createcluster -p 5432 $PGVERSION $instance
;;
secondary)
sudo pg_createcluster -p 54321 $PGVERSION $instance
sudo rm -rf /var/lib/postgresql/$PGVERSION/$instance
sudo -u $PGUSER pg_basebackup -D /var/lib/postgresql/$PGVERSION/$instance -R -Xs -P -d "host=${PGHOST} port=5432"
;;
*)
echo "first argument to postgresql_configure must be 'primary' or 'secondary'"
;;
esac

sudo tee /etc/postgresql/$PGVERSION/$instance/pg_hba.conf > /dev/null <<-config
local all all trust
hostnossl all pqgossltest 127.0.0.1/32 reject
hostnossl all pqgosslcert 127.0.0.1/32 reject
hostssl all pqgossltest 127.0.0.1/32 trust
hostssl all pqgosslcert 127.0.0.1/32 cert
host all all 127.0.0.1/32 trust
host replication all 127.0.0.1/32 trust
hostnossl all pqgossltest ::1/128 reject
hostnossl all pqgosslcert ::1/128 reject
hostssl all pqgossltest ::1/128 trust
hostssl all pqgosslcert ::1/128 cert
host all all ::1/128 trust
host replication all ::1/128 trust
config

xargs sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ <<-certificates
xargs sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/$instance/ <<-certificates
certs/root.crt
certs/server.crt
certs/server.key
Expand All @@ -39,15 +56,18 @@ postgresql_configure() {
$PGVERSION
9.2
versions
sudo tee -a /etc/postgresql/$PGVERSION/main/postgresql.conf > /dev/null <<-config
ssl_ca_file = 'root.crt'
ssl_cert_file = 'server.crt'
ssl_key_file = 'server.key'
sudo tee -a /etc/postgresql/$PGVERSION/$instance/postgresql.conf > /dev/null <<-config
ssl_ca_file = 'root.crt'
ssl_cert_file = 'server.crt'
ssl_key_file = 'server.key'
wal_level = hot_standby
hot_standby = on
max_wal_senders = 2
config

echo 127.0.0.1 postgres | sudo tee -a /etc/hosts > /dev/null

sudo service postgresql restart
sudo pg_ctlcluster $PGVERSION $instance start
}

postgresql_install() {
Expand All @@ -56,6 +76,9 @@ postgresql_install() {
postgresql-server-dev-$PGVERSION
postgresql-contrib-$PGVERSION
packages
# disable packaged default cluster; will add our own with postgresql_configure
sudo service postgresql stop
sudo pg_dropcluster $PGVERSION main
}

postgresql_uninstall() {
Expand Down Expand Up @@ -95,4 +118,4 @@ golint_install() {
go get github.com/golang/lint/golint
}

$1
$@
6 changes: 2 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ env:
- PGVERSION=9.5
- PGVERSION=9.4
- PGVERSION=9.3
- PGVERSION=9.2
- PGVERSION=9.1
- PGVERSION=9.0

before_install:
- ./.travis.sh postgresql_uninstall
- ./.travis.sh pgdg_repository
- ./.travis.sh postgresql_install
- ./.travis.sh postgresql_configure
- ./.travis.sh postgresql_configure primary
- ./.travis.sh postgresql_configure secondary
- ./.travis.sh client_configure
- ./.travis.sh megacheck_install
- ./.travis.sh golint_install
Expand Down
146 changes: 121 additions & 25 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ type conn struct {

// Handle driver-side settings in parsed connection string.
func (cn *conn) handleDriverSettings(o values) (err error) {
if targetSessionAttrs, ok := o["target_session_attrs"]; ok {
switch targetSessionAttrs {
case "any":
case "read-write":
default:
return fmt.Errorf("unrecognized value %q for target_session_attrs", targetSessionAttrs)
}
}

boolSetting := func(key string, val *bool) error {
if value, ok := o[key]; ok {
if value == "yes" {
Expand Down Expand Up @@ -251,6 +260,19 @@ func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
}

// errors is used to accumulate connection errors when attempting to connect to
// multiple hosts.
type errorSlice []error

func (es errorSlice) Error() string {
// use bytes.Buffer?
out := ""
for _, e := range es {
out += "; " + e.Error()
}
return out
}

// DialOpen opens a new connection to the database using a dialer.
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
// Handle any panics during connection initialization. Note that we
Expand Down Expand Up @@ -325,39 +347,56 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
o["user"] = u
}

cn := &conn{
opts: o,
dialer: d,
}
err = cn.handleDriverSettings(o)
if err != nil {
return nil, err
}
cn.handlePgpass(o)

cn.c, err = dial(d, o)
oo, err := buildHostOptions(o)
if err != nil {
return nil, err
}
errs := make(errorSlice, 0)
for _, o = range oo {
cn := &conn{
opts: o,
dialer: d,
}
err = cn.handleDriverSettings(o)
if err != nil {
return nil, err
}
cn.handlePgpass(o)

// cn.ssl and cn.startup panic on error. Make sure we don't leak cn.c.
panicking := true
defer func() {
if panicking {
cn.c.Close()
cn.c, err = dial(d, o)
if err != nil {
errs = append(errs, err)
continue
}
}()

cn.ssl(o)
cn.buf = bufio.NewReader(cn.c)
cn.startup(o)
// cn.ssl and cn.startup panic on error. Make sure we don't leak cn.c.
panicking := true
defer func() {
if panicking {
cn.c.Close()
}
}()

// reset the deadline, in case one was set (see dial)
if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
err = cn.c.SetDeadline(time.Time{})
cn.ssl(o)
cn.buf = bufio.NewReader(cn.c)
cn.startup(o)
if err = cn.checkWritable(o); err != nil {
errs = append(errs, err)
continue
}

// reset the deadline, in case one was set (see dial)
if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
err = cn.c.SetDeadline(time.Time{})
}
panicking = false
return cn, err
}
// try to avoid breakig clients that expect a single error
if len(errs) == 1 {
return nil, errs[0]
}
panicking = false
return cn, err
return nil, errs
}

func dial(d Dialer, o values) (net.Conn, error) {
Expand Down Expand Up @@ -400,6 +439,29 @@ func network(o values) (string, string) {
return "tcp", net.JoinHostPort(host, o["port"])
}

func buildHostOptions(o values) ([]values, error) {
hosts := strings.Split(o["host"], ",")
ports := strings.Split(o["port"], ",")
if len(ports) > 1 && len(hosts) != len(ports) {
return nil, fmt.Errorf("could not match %d port numbers to %d hosts", len(ports), len(hosts))
}
oo := make([]values, len(hosts))
for i, host := range hosts {
oCopy := make(values, len(o))
for key, val := range o {
oCopy[key] = val
}
oCopy["host"] = host
if len(ports) > 1 {
oCopy["port"] = ports[i]
} else {
oCopy["port"] = ports[0]
}
oo[i] = oCopy
}
return oo, nil
}

type values map[string]string

// scanner implements a tokenizer for libpq-style option strings.
Expand Down Expand Up @@ -1074,6 +1136,8 @@ func isDriverSetting(key string) bool {
return true
case "binary_parameters":
return true
case "target_session_attrs":
return true

default:
return false
Expand Down Expand Up @@ -1123,6 +1187,38 @@ func (cn *conn) startup(o values) {
}
}

func (cn *conn) checkWritable(o values) error {
tsa, ok := o["target_session_attrs"]
if !ok {
return nil
}
switch tsa {
case "any":
return nil
case "read-write":
res, err := cn.simpleQuery("SHOW transaction_read_only")
if err != nil {
return err
}
defer res.Close()
// ok to be optimistic here?
vs := make([]driver.Value, 1)
res.Next(vs)

readOnly, ok := vs[0].(string)
if !ok {
return errors.New("could not parse result of transaction_read_only as string")
}
if readOnly == "on" {
return errors.New("could not make a writable connection to server")
}
return nil
default:
// sanity check; should never happen because we handleDriverSettings() before connecting
panic("unrecognized value for target_session_attrs")
}
}

func (cn *conn) auth(r *readBuf, o values) {
switch code := r.int32(); code {
case 0:
Expand Down
36 changes: 29 additions & 7 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,44 @@ func TestCommitInFailedTransaction(t *testing.T) {
}

func TestOpenURL(t *testing.T) {
testURL := func(url string) {
maybeFatal := func(f func(string) error, url string, expectError bool) {
err := f(url)
if err != nil && !expectError {
t.Fatal(url, err)
} else if err == nil && expectError {
t.Fatal("expected failed connection")
}
}
testURL := func(url string) error {
db, err := openTestConnConninfo(url)
if err != nil {
t.Fatal(err)
return err
}
defer db.Close()
// database/sql might not call our Open at all unless we do something with
// the connection
txn, err := db.Begin()
_, err = db.Exec("SHOW server_version")
if err != nil {
t.Fatal(err)
return err
}
txn.Rollback()
return nil
}
testURL("postgres://")
testURL("postgresql://")
maybeFatal(testURL, "postgres://", false)
maybeFatal(testURL, "postgresql://", false)

// ensure we get an error for a non-running server
maybeFatal(testURL, "postgresql://:55555", true)

// ensure connection attempts where one server is down still work
maybeFatal(testURL, "postgresql://:5432,:55555/", false)
maybeFatal(testURL, "postgresql://:55555,:5432/", false)

// ensure target_session_attrs works
maybeFatal(testURL, "postgresql://:5432/?target_session_attrs=read-write", false)
maybeFatal(testURL, "postgresql://:54321/?target_session_attrs=any", false)
maybeFatal(testURL, "postgresql://:54321/?target_session_attrs=read-write", true)
maybeFatal(testURL, "postgresql://:5432,:54321/?target_session_attrs=read-write", false)
maybeFatal(testURL, "postgresql://:54321,:5432/?target_session_attrs=read-write", false)
}

const pgpassFile = "/tmp/pqgotest_pgpass"
Expand Down
19 changes: 14 additions & 5 deletions url.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,21 @@ func ParseURL(url string) (string, error) {
accrue("password", v)
}

if host, port, err := net.SplitHostPort(u.Host); err != nil {
accrue("host", u.Host)
} else {
accrue("host", host)
accrue("port", port)
// handle non-standard "host1:port1,host2:port2" format
// (for feature parity with libpq >= 10)
hostports := strings.Split(u.Host, ",")
hosts := make([]string, 0, len(hostports))
ports := make([]string, 0, len(hostports))
for _, hostport := range hostports {
if host, port, err := net.SplitHostPort(hostport); err != nil {
hosts = append(hosts, hostport)
} else {
hosts = append(hosts, host)
ports = append(ports, port)
}
}
accrue("host", strings.Join(hosts, ","))
accrue("port", strings.Join(ports, ","))

if u.Path != "" {
accrue("dbname", u.Path[1:])
Expand Down
Loading