Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dbnode] Client borrow connection API #3019

Merged
merged 11 commits into from
Dec 22, 2020
43 changes: 37 additions & 6 deletions src/dbnode/client/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 16 additions & 10 deletions src/dbnode/client/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ import (
"sync/atomic"
"time"

"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
"github.com/m3db/m3/src/dbnode/topology"
xresource "github.com/m3db/m3/src/x/resource"
murmur3 "github.com/m3db/stackmurmur3/v2"

"github.com/uber-go/tally"
"github.com/uber/tchannel-go"
"github.com/uber/tchannel-go/thrift"
"go.uber.org/zap"

"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
"github.com/m3db/m3/src/dbnode/topology"
)

const (
Expand Down Expand Up @@ -67,15 +67,21 @@ type connPool struct {
healthStatus tally.Gauge
}

// PooledChannel is a tchannel.Channel for a pooled connection.
type PooledChannel interface {
GetSubChannel(serviceName string, opts ...tchannel.SubChannelOption) *tchannel.SubChannel
Close()
}

type conn struct {
channel xresource.SimpleCloser
channel PooledChannel
client rpc.TChanNode
}

// NewConnectionFn is a function that creates a connection.
type NewConnectionFn func(
channelName string, addr string, opts Options,
) (xresource.SimpleCloser, rpc.TChanNode, error)
) (PooledChannel, rpc.TChanNode, error)

type healthCheckFn func(client rpc.TChanNode, opts Options) error

Expand Down Expand Up @@ -134,20 +140,20 @@ func (p *connPool) ConnectionCount() int {
return int(poolLen)
}

func (p *connPool) NextClient() (rpc.TChanNode, error) {
func (p *connPool) NextClient() (rpc.TChanNode, PooledChannel, error) {
p.RLock()
if p.status != statusOpen {
p.RUnlock()
return nil, errConnectionPoolClosed
return nil, nil, errConnectionPoolClosed
}
if p.poolLen < 1 {
p.RUnlock()
return nil, errConnectionPoolHasNoConnections
return nil, nil, errConnectionPoolHasNoConnections
}
n := atomic.AddInt64(&p.used, 1)
conn := p.pool[n%p.poolLen]
p.RUnlock()
return conn.client, nil
return conn.client, conn.channel, nil
}

func (p *connPool) Close() {
Expand Down
39 changes: 22 additions & 17 deletions src/dbnode/client/connection_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import (
"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
"github.com/m3db/m3/src/dbnode/topology"
xclock "github.com/m3db/m3/src/x/clock"
xresource "github.com/m3db/m3/src/x/resource"
"github.com/stretchr/testify/require"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/uber/tchannel-go"
)

const (
Expand All @@ -42,10 +42,19 @@ const (
)

var (
h = topology.NewHost(testHostStr, testHostAddr)
channelNone = &nullChannel{}
h = topology.NewHost(testHostStr, testHostAddr)
)

type noopPooledChannel struct{}

func (c *noopPooledChannel) Close() {}
func (c *noopPooledChannel) GetSubChannel(
serviceName string,
opts ...tchannel.SubChannelOption,
) *tchannel.SubChannel {
return nil
}

func newConnectionPoolTestOptions() Options {
return newSessionTestOptions().
SetBackgroundConnectInterval(5 * time.Millisecond).
Expand Down Expand Up @@ -85,12 +94,12 @@ func TestConnectionPoolConnectsAndRetriesConnects(t *testing.T) {

fn := func(
ch string, addr string, opts Options,
) (xresource.SimpleCloser, rpc.TChanNode, error) {
) (PooledChannel, rpc.TChanNode, error) {
attempt := int(atomic.AddInt32(&attempts, 1))
if attempt == 1 {
return nil, nil, fmt.Errorf("a connect error")
}
return channelNone, nil, nil
return &noopPooledChannel{}, nil, nil
}

opts = opts.SetNewConnectionFn(fn)
Expand Down Expand Up @@ -151,7 +160,7 @@ func TestConnectionPoolConnectsAndRetriesConnects(t *testing.T) {
conns.Close()
doneWg.Done()

nextClient, err := conns.NextClient()
nextClient, _, err := conns.NextClient()
require.Nil(t, nextClient)
require.Equal(t, errConnectionPoolClosed, err)
}
Expand Down Expand Up @@ -237,12 +246,12 @@ func TestConnectionPoolHealthChecks(t *testing.T) {

fn := func(
ch string, addr string, opts Options,
) (xresource.SimpleCloser, rpc.TChanNode, error) {
) (PooledChannel, rpc.TChanNode, error) {
attempt := atomic.AddInt32(&newConnAttempt, 1)
if attempt == 1 {
return channelNone, client1, nil
return &noopPooledChannel{}, client1, nil
} else if attempt == 2 {
return channelNone, client2, nil
return &noopPooledChannel{}, client2, nil
}
return nil, nil, fmt.Errorf("spawning only 2 connections")
}
Expand Down Expand Up @@ -307,7 +316,7 @@ func TestConnectionPoolHealthChecks(t *testing.T) {
return conns.ConnectionCount() == 1
}, 5*time.Second)
for i := 0; i < 2; i++ {
nextClient, err := conns.NextClient()
nextClient, _, err := conns.NextClient()
require.NoError(t, err)
require.Equal(t, client2, nextClient)
}
Expand All @@ -324,17 +333,13 @@ func TestConnectionPoolHealthChecks(t *testing.T) {
// and the connection actually being removed.
return conns.ConnectionCount() == 0
}, 5*time.Second)
nextClient, err := conns.NextClient()
nextClient, _, err := conns.NextClient()
require.Nil(t, nextClient)
require.Equal(t, errConnectionPoolHasNoConnections, err)

conns.Close()

nextClient, err = conns.NextClient()
nextClient, _, err = conns.NextClient()
require.Nil(t, nextClient)
require.Equal(t, errConnectionPoolClosed, err)
}

type nullChannel struct{}

func (*nullChannel) Close() {}
24 changes: 12 additions & 12 deletions src/dbnode/client/host_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ func (q *queue) asyncTaggedWrite(
// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -591,7 +591,7 @@ func (q *queue) asyncTaggedWriteV2(

// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to.
client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -656,7 +656,7 @@ func (q *queue) asyncWrite(
// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -715,7 +715,7 @@ func (q *queue) asyncWriteV2(

// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to.
client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available.
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -768,7 +768,7 @@ func (q *queue) asyncFetch(op *fetchBatchOp) {
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.completeAll(nil, err)
Expand Down Expand Up @@ -821,7 +821,7 @@ func (q *queue) asyncFetchV2(
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available.
callAllCompletionFns(ops, nil, err)
Expand Down Expand Up @@ -868,7 +868,7 @@ func (q *queue) asyncFetchTagged(op *fetchTaggedOp) {
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.CompletionFn()(fetchTaggedResultAccumulatorOpts{host: q.host}, err)
Expand Down Expand Up @@ -901,7 +901,7 @@ func (q *queue) asyncAggregate(op *aggregateOp) {
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.CompletionFn()(aggregateResultAccumulatorOpts{host: q.host}, err)
Expand Down Expand Up @@ -931,7 +931,7 @@ func (q *queue) asyncTruncate(op *truncateOp) {
q.workerPool.Go(func() {
cleanup := q.Done

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.completionFn(nil, err)
Expand Down Expand Up @@ -1003,7 +1003,7 @@ func (q *queue) ConnectionPool() connectionPool {
return q.connPool
}

func (q *queue) BorrowConnection(fn withConnectionFn) error {
func (q *queue) BorrowConnection(fn WithConnectionFn) error {
q.RLock()
if q.status != statusOpen {
q.RUnlock()
Expand All @@ -1014,12 +1014,12 @@ func (q *queue) BorrowConnection(fn withConnectionFn) error {
defer q.Done()
q.RUnlock()

conn, err := q.connPool.NextClient()
conn, ch, err := q.connPool.NextClient()
if err != nil {
return err
}

fn(conn)
fn(conn, ch)
return nil
}

Expand Down
Loading