Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TestOnBorrowContext #660

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions redis/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,24 @@ type Pool struct {
// DialContext is an application supplied function for creating and configuring a
// connection with the given context.
//
// The connection returned from Dial must not be in a special state
// The connection returned from DialContext must not be in a special state
// (subscribed to pubsub channel, transaction started, ...).
DialContext func(ctx context.Context) (Conn, error)

// TestOnBorrow is an optional application supplied function for checking
// the health of an idle connection before the connection is used again by
// the application. Argument t is the time that the connection was returned
// the application. Argument lastUsed is the time when the connection was returned
// to the pool. If the function returns an error, then the connection is
// closed.
TestOnBorrow func(c Conn, lastUsed time.Time) error

// TestOnBorrowContext is an optional application supplied function
// for checking the health of an idle connection with the given context
// before the connection is used again by the application.
// Argument lastUsed is the time when the connection was returned
// to the pool. If the function returns an error, then the connection is
// closed.
TestOnBorrow func(c Conn, t time.Time) error
TestOnBorrowContext func(ctx context.Context, c Conn, lastUsed time.Time) error

// Maximum number of idle connections in the pool.
MaxIdle int
Expand Down Expand Up @@ -228,6 +236,7 @@ func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
p.idle.popFront()
p.mu.Unlock()
if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) &&
(p.TestOnBorrowContext == nil || p.TestOnBorrowContext(ctx, pc.c, pc.t) == nil) &&
(p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) {
return &activeConn{p: p, pc: pc}, nil
}
Expand Down
189 changes: 170 additions & 19 deletions redis/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,48 @@ func (c *poolTestConn) Close() error {
func (c *poolTestConn) Err() error { return c.err }

func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, error) {
return c.do(c.Conn.Do, commandName, args...)
}

func (c *poolTestConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) {
cwc, ok := c.Conn.(redis.ConnWithContext)
if !ok {
return nil, errors.New("redis: connection does not support ConnWithContext")
}
return c.do(
func(c string, a ...interface{}) (interface{}, error) {
return cwc.DoContext(ctx, c, a...)
},
commandName, args)
}

func (c *poolTestConn) do(
fn func(commandName string, args ...interface{}) (interface{}, error),
commandName string, args ...interface{},
) (interface{}, error) {
if commandName == "ERR" {
c.err = args[0].(error)
commandName = "PING"
}
if commandName != "" {
c.d.commands = append(c.d.commands, commandName)
}
return c.Conn.Do(commandName, args...)
return fn(commandName, args...)
}

func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
c.d.commands = append(c.d.commands, commandName)
return c.Conn.Send(commandName, args...)
}

func (c *poolTestConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) {
cwc, ok := c.Conn.(redis.ConnWithContext)
if !ok {
return nil, errors.New("redis: connection does not support ConnWithContext")
}
return cwc.ReceiveContext(ctx)
}

type poolDialer struct {
mu sync.Mutex
t *testing.T
Expand All @@ -73,14 +100,18 @@ type poolDialer struct {
}

func (d *poolDialer) dial() (redis.Conn, error) {
return d.dialContext(context.Background())
}

func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) {
d.mu.Lock()
d.dialed += 1
dialErr := d.dialErr
d.mu.Unlock()
if dialErr != nil {
return nil, d.dialErr
}
c, err := redis.DialDefaultServer()
c, err := redis.DialDefaultServerContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -90,15 +121,14 @@ func (d *poolDialer) dial() (redis.Conn, error) {
return &poolTestConn{d: d, Conn: c}, nil
}

func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) {
return d.dial()
}

func (d *poolDialer) check(message string, p *redis.Pool, dialed, open, inuse int) {
d.t.Helper()
vasayxtx marked this conversation as resolved.
Show resolved Hide resolved
d.checkAll(message, p, dialed, open, inuse, 0, 0)
}

func (d *poolDialer) checkAll(message string, p *redis.Pool, dialed, open, inuse int, waitCountMax int64, waitDurationMax time.Duration) {
d.t.Helper()

d.mu.Lock()
defer d.mu.Unlock()

Expand Down Expand Up @@ -368,21 +398,142 @@ func TestPoolConcurrenSendReceive(t *testing.T) {
}

func TestPoolBorrowCheck(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
pingN := func(ctx context.Context, p *redis.Pool, n int) {
for i := 0; i < n; i++ {
func() {
c, err := p.GetContext(ctx)
require.NoError(t, err)
defer func() {
require.NoError(t, c.Close())
}()
_, err = redis.DoContext(c, ctx, "PING")
require.NoError(t, err)
}()
}
}
defer p.Close()

for i := 0; i < 10; i++ {
c := p.Get()
_, err := c.Do("PING")
require.NoError(t, err)
c.Close()
checkLastUsedTimes := func(lastUsedTimes []time.Time, startTime, endTime time.Time, wantLen int) {
require.Len(t, lastUsedTimes, wantLen)
for i, lastUsed := range lastUsedTimes {
if i == 0 {
require.True(t, lastUsed.After(startTime))
} else {
require.True(t, lastUsed.After(lastUsedTimes[i-1]))
}
require.True(t, lastUsed.Before(endTime))
}
}
d.check("1", p, 10, 1, 0)

t.Run("TestOnBorrow-error", func(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
}
defer p.Close()
pingN(context.Background(), p, 10)
d.check("1", p, 10, 1, 0)
})

t.Run("TestOnBorrow-nil-error", func(t *testing.T) {
d := poolDialer{t: t}
var borrowErrs []error
var lastUsedTimes []time.Time
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrow: func(c redis.Conn, lastUsed time.Time) error {
lastUsedTimes = append(lastUsedTimes, lastUsed)
_, err := c.Do("PING")
if err != nil {
borrowErrs = append(borrowErrs, err)
}
return err
},
}
defer p.Close()

startTime := time.Now()
pingN(context.Background(), p, 10)
endTime := time.Now()

require.Empty(t, borrowErrs)
checkLastUsedTimes(lastUsedTimes, startTime, endTime, 9)
d.check("1", p, 1, 1, 0)
})

t.Run("TestOnBorrowContext-error", func(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrowContext: func(context.Context, redis.Conn, time.Time) error { return redis.Error("BLAH") },
}
defer p.Close()
pingN(context.Background(), p, 10)
d.check("1", p, 10, 1, 0)
})

t.Run("TestOnBorrowContext-nil-error", func(t *testing.T) {
d := poolDialer{t: t}
var borrowErrs []error
var lastUsedTimes []time.Time
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrowContext: func(ctx context.Context, c redis.Conn, lastUsed time.Time) error {
lastUsedTimes = append(lastUsedTimes, lastUsed)
_, err := redis.DoContext(c, ctx, "PING")
if err != nil {
borrowErrs = append(borrowErrs, err)
}
return err
},
}
defer p.Close()

startTime := time.Now()
pingN(context.Background(), p, 10)
endTime := time.Now()

require.Empty(t, borrowErrs)
checkLastUsedTimes(lastUsedTimes, startTime, endTime, 9)
d.check("1", p, 1, 1, 0)
})

t.Run("TestOnBorrowContext-context.Canceled", func(t *testing.T) {
d := poolDialer{t: t}
var borrowErrs []error
p := &redis.Pool{
MaxIdle: 2,
DialContext: d.dialContext,
TestOnBorrowContext: func(ctx context.Context, c redis.Conn, _ time.Time) error {
_, err := redis.DoContext(c, ctx, "PING")
if err != nil {
borrowErrs = append(borrowErrs, err)
}
return err
},
}
defer p.Close()

ctx, ctxCancel := context.WithCancel(context.Background())
defer ctxCancel()

pingN(ctx, p, 2)
d.check("1", p, 1, 1, 0)
require.Empty(t, borrowErrs)

ctxCancel()

_, err := p.GetContext(ctx)
require.ErrorIs(t, err, context.Canceled)

d.check("1", p, 2, 0, 0)
require.Len(t, borrowErrs, 1)
require.ErrorIs(t, borrowErrs[0], context.Canceled)
})
}

func TestPoolMaxActive(t *testing.T) {
Expand Down Expand Up @@ -757,7 +908,7 @@ func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
MaxIdle: count,
MaxActive: count,
Dial: d.dial,
TestOnBorrow: func(c redis.Conn, t time.Time) error {
TestOnBorrow: func(redis.Conn, time.Time) error {
return errors.New("No way back into the real world.")
},
}
Expand Down
11 changes: 9 additions & 2 deletions redis/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package redis

import (
"bufio"
"context"
"errors"
"flag"
"fmt"
Expand Down Expand Up @@ -197,15 +198,21 @@ func DefaultServerAddr() (string, error) {
// DialDefaultServer starts the test server if not already started and dials a
// connection to the server.
func DialDefaultServer(options ...DialOption) (Conn, error) {
return DialDefaultServerContext(context.Background(), options...)
}

// DialDefaultServerContext starts the test server if not already started and
// dials a connection to the server with the given context.
func DialDefaultServerContext(ctx context.Context, options ...DialOption) (Conn, error) {
addr, err := DefaultServerAddr()
if err != nil {
return nil, err
}
c, err := Dial("tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...)
c, err := DialContext(ctx, "tcp", addr, append([]DialOption{DialReadTimeout(1 * time.Second), DialWriteTimeout(1 * time.Second)}, options...)...)
if err != nil {
return nil, err
}
if _, err = c.Do("FLUSHDB"); err != nil {
if _, err = DoContext(c, ctx, "FLUSHDB"); err != nil {
return nil, err
}
return c, nil
Expand Down
Loading