From 5f01fdde266ce8448600f4a510bb1c05a8d288f1 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Wed, 12 Aug 2015 16:19:44 -0700 Subject: [PATCH] Add dial options --- internal/redistest/testdb.go | 3 + redis/conn.go | 108 ++++++++++++++++++++++++----------- 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/internal/redistest/testdb.go b/internal/redistest/testdb.go index 5f955c42..b6f205b7 100644 --- a/internal/redistest/testdb.go +++ b/internal/redistest/testdb.go @@ -49,15 +49,18 @@ func Dial() (redis.Conn, error) { _, err = c.Do("SELECT", "9") if err != nil { + c.Close() return nil, err } n, err := redis.Int(c.Do("DBSIZE")) if err != nil { + c.Close() return nil, err } if n != 0 { + c.Close() return nil, errors.New("database #9 is not empty, test can not continue") } diff --git a/redis/conn.go b/redis/conn.go index e277bc75..f09f9938 100644 --- a/redis/conn.go +++ b/redis/conn.go @@ -51,56 +51,96 @@ type conn struct { numScratch [40]byte } -// Dial connects to the Redis server at the given network and address. -func Dial(network, address string) (Conn, error) { - dialer := xDialer{} - return dialer.Dial(network, address) -} - // DialTimeout acts like Dial but takes timeouts for establishing the // connection to the server, writing a command and reading a reply. +// +// DialTimeout is deprecated. func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { - netDialer := net.Dialer{Timeout: connectTimeout} - dialer := xDialer{ - NetDial: netDialer.Dial, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - } - return dialer.Dial(network, address) + return Dial(network, address, + DialConnectTimeout(connectTimeout), + DialReadTimeout(readTimeout), + DialWriteTimeout(writeTimeout)) } -// A Dialer specifies options for connecting to a Redis server. -type xDialer struct { - // NetDial specifies the dial function for creating TCP connections. If - // NetDial is nil, then net.Dial is used. - NetDial func(network, addr string) (net.Conn, error) +// DialOption specifies an option for dialing a Redis server. +type DialOption struct { + f func(*dialOptions) +} - // ReadTimeout specifies the timeout for reading a single command - // reply. If ReadTimeout is zero, then no timeout is used. - ReadTimeout time.Duration +type dialOptions struct { + readTimeout time.Duration + writeTimeout time.Duration + dial func(network, addr string) (net.Conn, error) + db int + password string +} + +// DialReadTimeout specifies the timeout for reading a single command reply. +func DialReadTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + do.readTimeout = d + }} +} + +// DialWriteTimeout specifies the timeout for writing a single command. +func DialWriteTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + do.writeTimeout = d + }} +} + +// DialConnectTimeout specifies the timeout for connecting to the Redis server. +func DialConnectTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + dialer := net.Dialer{Timeout: d} + do.dial = dialer.Dial + }} +} - // WriteTimeout specifies the timeout for writing a single command. If - // WriteTimeout is zero, then no timeout is used. - WriteTimeout time.Duration +// DialDatabase specifies the database to select when dialing a connection. +func DialDatabase(db int) DialOption { + return DialOption{func(do *dialOptions) { + do.db = db + }} } -// Dial connects to the Redis server at address on the named network. -func (d *xDialer) Dial(network, address string) (Conn, error) { - dial := d.NetDial - if dial == nil { - dial = net.Dial +// Dial connects to the Redis server at the given network and +// address using the specified options. +func Dial(network, address string, options ...DialOption) (Conn, error) { + do := dialOptions{ + dial: net.Dial, } - netConn, err := dial(network, address) + for _, option := range options { + option.f(&do) + } + + netConn, err := do.dial(network, address) if err != nil { return nil, err } - return &conn{ + c := &conn{ conn: netConn, bw: bufio.NewWriter(netConn), br: bufio.NewReader(netConn), - readTimeout: d.ReadTimeout, - writeTimeout: d.WriteTimeout, - }, nil + readTimeout: do.readTimeout, + writeTimeout: do.writeTimeout, + } + + if do.password != "" { + if _, err := c.Do("AUTH", do.password); err != nil { + netConn.Close() + return nil, err + } + } + + if do.db != 0 { + if _, err := c.Do("SELECT", do.db); err != nil { + netConn.Close() + return nil, err + } + } + + return c, nil } // NewConn returns a new Redigo connection for the given net connection.