diff --git a/src/dbnode/client/client_mock.go b/src/dbnode/client/client_mock.go index 6ce3b00bc1..a40a47d848 100644 --- a/src/dbnode/client/client_mock.go +++ b/src/dbnode/client/client_mock.go @@ -1085,6 +1085,21 @@ func (mr *MockAdminSessionMockRecorder) FetchBlocksFromPeers(namespace, shard, c return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchBlocksFromPeers", reflect.TypeOf((*MockAdminSession)(nil).FetchBlocksFromPeers), namespace, shard, consistencyLevel, metadatas, opts) } +// BorrowConnections mocks base method +func (m *MockAdminSession) BorrowConnections(shardID uint32, fn WithBorrowConnectionFn, opts BorrowConnectionOptions) (BorrowConnectionsResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BorrowConnections", shardID, fn, opts) + ret0, _ := ret[0].(BorrowConnectionsResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BorrowConnections indicates an expected call of BorrowConnections +func (mr *MockAdminSessionMockRecorder) BorrowConnections(shardID, fn, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BorrowConnections", reflect.TypeOf((*MockAdminSession)(nil).BorrowConnections), shardID, fn, opts) +} + // MockOptions is a mock of Options interface type MockOptions struct { ctrl *gomock.Controller @@ -4812,6 +4827,21 @@ func (mr *MockclientSessionMockRecorder) FetchBlocksFromPeers(namespace, shard, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchBlocksFromPeers", reflect.TypeOf((*MockclientSession)(nil).FetchBlocksFromPeers), namespace, shard, consistencyLevel, metadatas, opts) } +// BorrowConnections mocks base method +func (m *MockclientSession) BorrowConnections(shardID uint32, fn WithBorrowConnectionFn, opts BorrowConnectionOptions) (BorrowConnectionsResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BorrowConnections", shardID, fn, opts) + ret0, _ := ret[0].(BorrowConnectionsResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BorrowConnections indicates an expected call of BorrowConnections +func (mr *MockclientSessionMockRecorder) BorrowConnections(shardID, fn, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BorrowConnections", reflect.TypeOf((*MockclientSession)(nil).BorrowConnections), shardID, fn, opts) +} + // Open mocks base method func (m *MockclientSession) Open() error { m.ctrl.T.Helper() @@ -4932,7 +4962,7 @@ func (mr *MockhostQueueMockRecorder) ConnectionPool() *gomock.Call { } // BorrowConnection mocks base method -func (m *MockhostQueue) BorrowConnection(fn withConnectionFn) error { +func (m *MockhostQueue) BorrowConnection(fn WithConnectionFn) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BorrowConnection", fn) ret0, _ := ret[0].(error) @@ -5007,12 +5037,13 @@ func (mr *MockconnectionPoolMockRecorder) ConnectionCount() *gomock.Call { } // NextClient mocks base method -func (m *MockconnectionPool) NextClient() (rpc.TChanNode, error) { +func (m *MockconnectionPool) NextClient() (rpc.TChanNode, PooledChannel, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NextClient") ret0, _ := ret[0].(rpc.TChanNode) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(PooledChannel) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // NextClient indicates an expected call of NextClient @@ -5057,7 +5088,7 @@ func (m *MockpeerSource) EXPECT() *MockpeerSourceMockRecorder { } // BorrowConnection mocks base method -func (m *MockpeerSource) BorrowConnection(hostID string, fn withConnectionFn) error { +func (m *MockpeerSource) BorrowConnection(hostID string, fn WithConnectionFn) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BorrowConnection", hostID, fn) ret0, _ := ret[0].(error) @@ -5108,7 +5139,7 @@ func (mr *MockpeerMockRecorder) Host() *gomock.Call { } // BorrowConnection mocks base method -func (m *Mockpeer) BorrowConnection(fn withConnectionFn) error { +func (m *Mockpeer) BorrowConnection(fn WithConnectionFn) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BorrowConnection", fn) ret0, _ := ret[0].(error) diff --git a/src/dbnode/client/connection_pool.go b/src/dbnode/client/connection_pool.go index 4d1e1af1d2..effe4405b9 100644 --- a/src/dbnode/client/connection_pool.go +++ b/src/dbnode/client/connection_pool.go @@ -29,14 +29,14 @@ import ( "sync/atomic" "time" - "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" - "github.com/m3db/m3/src/dbnode/topology" - xresource "github.com/m3db/m3/src/x/resource" murmur3 "github.com/m3db/stackmurmur3/v2" - "github.com/uber-go/tally" + "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/thrift" "go.uber.org/zap" + + "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" + "github.com/m3db/m3/src/dbnode/topology" ) const ( @@ -67,15 +67,21 @@ type connPool struct { healthStatus tally.Gauge } +// PooledChannel is a tchannel.Channel for a pooled connection. +type PooledChannel interface { + GetSubChannel(serviceName string, opts ...tchannel.SubChannelOption) *tchannel.SubChannel + Close() +} + type conn struct { - channel xresource.SimpleCloser + channel PooledChannel client rpc.TChanNode } // NewConnectionFn is a function that creates a connection. type NewConnectionFn func( channelName string, addr string, opts Options, -) (xresource.SimpleCloser, rpc.TChanNode, error) +) (PooledChannel, rpc.TChanNode, error) type healthCheckFn func(client rpc.TChanNode, opts Options) error @@ -134,20 +140,20 @@ func (p *connPool) ConnectionCount() int { return int(poolLen) } -func (p *connPool) NextClient() (rpc.TChanNode, error) { +func (p *connPool) NextClient() (rpc.TChanNode, PooledChannel, error) { p.RLock() if p.status != statusOpen { p.RUnlock() - return nil, errConnectionPoolClosed + return nil, nil, errConnectionPoolClosed } if p.poolLen < 1 { p.RUnlock() - return nil, errConnectionPoolHasNoConnections + return nil, nil, errConnectionPoolHasNoConnections } n := atomic.AddInt64(&p.used, 1) conn := p.pool[n%p.poolLen] p.RUnlock() - return conn.client, nil + return conn.client, conn.channel, nil } func (p *connPool) Close() { diff --git a/src/dbnode/client/connection_pool_test.go b/src/dbnode/client/connection_pool_test.go index f4d391180c..6c45038bdb 100644 --- a/src/dbnode/client/connection_pool_test.go +++ b/src/dbnode/client/connection_pool_test.go @@ -30,10 +30,10 @@ import ( "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" "github.com/m3db/m3/src/dbnode/topology" xclock "github.com/m3db/m3/src/x/clock" - xresource "github.com/m3db/m3/src/x/resource" - "github.com/stretchr/testify/require" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "github.com/uber/tchannel-go" ) const ( @@ -42,10 +42,19 @@ const ( ) var ( - h = topology.NewHost(testHostStr, testHostAddr) - channelNone = &nullChannel{} + h = topology.NewHost(testHostStr, testHostAddr) ) +type noopPooledChannel struct{} + +func (c *noopPooledChannel) Close() {} +func (c *noopPooledChannel) GetSubChannel( + serviceName string, + opts ...tchannel.SubChannelOption, +) *tchannel.SubChannel { + return nil +} + func newConnectionPoolTestOptions() Options { return newSessionTestOptions(). SetBackgroundConnectInterval(5 * time.Millisecond). @@ -85,12 +94,12 @@ func TestConnectionPoolConnectsAndRetriesConnects(t *testing.T) { fn := func( ch string, addr string, opts Options, - ) (xresource.SimpleCloser, rpc.TChanNode, error) { + ) (PooledChannel, rpc.TChanNode, error) { attempt := int(atomic.AddInt32(&attempts, 1)) if attempt == 1 { return nil, nil, fmt.Errorf("a connect error") } - return channelNone, nil, nil + return &noopPooledChannel{}, nil, nil } opts = opts.SetNewConnectionFn(fn) @@ -151,7 +160,7 @@ func TestConnectionPoolConnectsAndRetriesConnects(t *testing.T) { conns.Close() doneWg.Done() - nextClient, err := conns.NextClient() + nextClient, _, err := conns.NextClient() require.Nil(t, nextClient) require.Equal(t, errConnectionPoolClosed, err) } @@ -237,12 +246,12 @@ func TestConnectionPoolHealthChecks(t *testing.T) { fn := func( ch string, addr string, opts Options, - ) (xresource.SimpleCloser, rpc.TChanNode, error) { + ) (PooledChannel, rpc.TChanNode, error) { attempt := atomic.AddInt32(&newConnAttempt, 1) if attempt == 1 { - return channelNone, client1, nil + return &noopPooledChannel{}, client1, nil } else if attempt == 2 { - return channelNone, client2, nil + return &noopPooledChannel{}, client2, nil } return nil, nil, fmt.Errorf("spawning only 2 connections") } @@ -307,7 +316,7 @@ func TestConnectionPoolHealthChecks(t *testing.T) { return conns.ConnectionCount() == 1 }, 5*time.Second) for i := 0; i < 2; i++ { - nextClient, err := conns.NextClient() + nextClient, _, err := conns.NextClient() require.NoError(t, err) require.Equal(t, client2, nextClient) } @@ -324,17 +333,13 @@ func TestConnectionPoolHealthChecks(t *testing.T) { // and the connection actually being removed. return conns.ConnectionCount() == 0 }, 5*time.Second) - nextClient, err := conns.NextClient() + nextClient, _, err := conns.NextClient() require.Nil(t, nextClient) require.Equal(t, errConnectionPoolHasNoConnections, err) conns.Close() - nextClient, err = conns.NextClient() + nextClient, _, err = conns.NextClient() require.Nil(t, nextClient) require.Equal(t, errConnectionPoolClosed, err) } - -type nullChannel struct{} - -func (*nullChannel) Close() {} diff --git a/src/dbnode/client/host_queue.go b/src/dbnode/client/host_queue.go index 32a42dbb94..9d042ed99d 100644 --- a/src/dbnode/client/host_queue.go +++ b/src/dbnode/client/host_queue.go @@ -531,7 +531,7 @@ func (q *queue) asyncTaggedWrite( // NB(bl): host is passed to writeState to determine the state of the // shard on the node we're writing to - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available callAllCompletionFns(ops, q.host, err) @@ -591,7 +591,7 @@ func (q *queue) asyncTaggedWriteV2( // NB(bl): host is passed to writeState to determine the state of the // shard on the node we're writing to. - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available callAllCompletionFns(ops, q.host, err) @@ -656,7 +656,7 @@ func (q *queue) asyncWrite( // NB(bl): host is passed to writeState to determine the state of the // shard on the node we're writing to - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available callAllCompletionFns(ops, q.host, err) @@ -715,7 +715,7 @@ func (q *queue) asyncWriteV2( // NB(bl): host is passed to writeState to determine the state of the // shard on the node we're writing to. - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available. callAllCompletionFns(ops, q.host, err) @@ -768,7 +768,7 @@ func (q *queue) asyncFetch(op *fetchBatchOp) { q.Done() } - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available op.completeAll(nil, err) @@ -821,7 +821,7 @@ func (q *queue) asyncFetchV2( q.Done() } - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available. callAllCompletionFns(ops, nil, err) @@ -868,7 +868,7 @@ func (q *queue) asyncFetchTagged(op *fetchTaggedOp) { q.Done() } - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available op.CompletionFn()(fetchTaggedResultAccumulatorOpts{host: q.host}, err) @@ -901,7 +901,7 @@ func (q *queue) asyncAggregate(op *aggregateOp) { q.Done() } - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available op.CompletionFn()(aggregateResultAccumulatorOpts{host: q.host}, err) @@ -931,7 +931,7 @@ func (q *queue) asyncTruncate(op *truncateOp) { q.workerPool.Go(func() { cleanup := q.Done - client, err := q.connPool.NextClient() + client, _, err := q.connPool.NextClient() if err != nil { // No client available op.completionFn(nil, err) @@ -1003,7 +1003,7 @@ func (q *queue) ConnectionPool() connectionPool { return q.connPool } -func (q *queue) BorrowConnection(fn withConnectionFn) error { +func (q *queue) BorrowConnection(fn WithConnectionFn) error { q.RLock() if q.status != statusOpen { q.RUnlock() @@ -1014,12 +1014,12 @@ func (q *queue) BorrowConnection(fn withConnectionFn) error { defer q.Done() q.RUnlock() - conn, err := q.connPool.NextClient() + conn, ch, err := q.connPool.NextClient() if err != nil { return err } - fn(conn) + fn(conn, ch) return nil } diff --git a/src/dbnode/client/host_queue_aggregate_test.go b/src/dbnode/client/host_queue_aggregate_test.go index 5b57f93d49..e1b204fdc7 100644 --- a/src/dbnode/client/host_queue_aggregate_test.go +++ b/src/dbnode/client/host_queue_aggregate_test.go @@ -36,7 +36,7 @@ import ( ) func TestHostQueueDrainOnCloseAggregate(t *testing.T) { - ctrl := gomock.NewController(xtest.Reporter{t}) + ctrl := gomock.NewController(xtest.Reporter{T: t}) defer ctrl.Finish() mockConnPool := NewMockconnectionPool(ctrl) @@ -73,7 +73,7 @@ func TestHostQueueDrainOnCloseAggregate(t *testing.T) { assert.Equal(t, aggregate.request.NameSpace, req.NameSpace) } mockClient.EXPECT().AggregateRaw(gomock.Any(), gomock.Any()).Do(aggregateExec).Return(nil, nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) mockConnPool.EXPECT().Close().AnyTimes() // Close the queue should cause all writes to be flushed @@ -202,7 +202,7 @@ func testHostQueueAggregate( // Prepare mocks for flush mockClient := rpc.NewMockTChanNode(ctrl) if testOpts != nil && testOpts.nextClientErr != nil { - mockConnPool.EXPECT().NextClient().Return(nil, testOpts.nextClientErr) + mockConnPool.EXPECT().NextClient().Return(nil, nil, testOpts.nextClientErr) } else if testOpts != nil && testOpts.aggregateErr != nil { aggregateExec := func(ctx thrift.Context, req *rpc.AggregateQueryRawRequest) { require.NotNil(t, req) @@ -213,7 +213,7 @@ func testHostQueueAggregate( Do(aggregateExec). Return(nil, testOpts.aggregateErr) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) } else { aggregateExec := func(ctx thrift.Context, req *rpc.AggregateQueryRawRequest) { require.NotNil(t, req) @@ -224,7 +224,7 @@ func testHostQueueAggregate( Do(aggregateExec). Return(result, nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) } // Fetch diff --git a/src/dbnode/client/host_queue_fetch_batch_test.go b/src/dbnode/client/host_queue_fetch_batch_test.go index a485afaf46..ef3138b527 100644 --- a/src/dbnode/client/host_queue_fetch_batch_test.go +++ b/src/dbnode/client/host_queue_fetch_batch_test.go @@ -123,7 +123,7 @@ func TestHostQueueFetchBatchesV2MultiNS(t *testing.T) { Do(verifyFetchBatchRawV2). Return(result, nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) for _, fetchBatch := range fetchBatches { assert.NoError(t, queue.Enqueue(fetchBatch)) @@ -310,7 +310,7 @@ func testHostQueueFetchBatches( } } if testOpts != nil && testOpts.nextClientErr != nil { - mockConnPool.EXPECT().NextClient().Return(nil, testOpts.nextClientErr) + mockConnPool.EXPECT().NextClient().Return(nil, nil, testOpts.nextClientErr) } else if testOpts != nil && testOpts.fetchRawBatchErr != nil { if opts.UseV2BatchAPIs() { mockClient.EXPECT(). @@ -326,7 +326,7 @@ func testHostQueueFetchBatches( Do(fetchBatchRaw). Return(nil, testOpts.fetchRawBatchErr) } - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) } else { if opts.UseV2BatchAPIs() { mockClient.EXPECT(). @@ -343,7 +343,7 @@ func testHostQueueFetchBatches( Return(result, nil) } - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) } // Fetch diff --git a/src/dbnode/client/host_queue_fetch_tagged_test.go b/src/dbnode/client/host_queue_fetch_tagged_test.go index e48e83092c..df66a2815b 100644 --- a/src/dbnode/client/host_queue_fetch_tagged_test.go +++ b/src/dbnode/client/host_queue_fetch_tagged_test.go @@ -72,7 +72,7 @@ func TestHostQueueDrainOnCloseFetchTagged(t *testing.T) { assert.Equal(t, fetch.request.NameSpace, req.NameSpace) } mockClient.EXPECT().FetchTagged(gomock.Any(), gomock.Any()).Do(fetchTagged).Return(nil, nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) mockConnPool.EXPECT().Close().AnyTimes() // Close the queue should cause all writes to be flushed @@ -202,7 +202,7 @@ func testHostQueueFetchTagged( // Prepare mocks for flush mockClient := rpc.NewMockTChanNode(ctrl) if testOpts != nil && testOpts.nextClientErr != nil { - mockConnPool.EXPECT().NextClient().Return(nil, testOpts.nextClientErr) + mockConnPool.EXPECT().NextClient().Return(nil, nil, testOpts.nextClientErr) } else if testOpts != nil && testOpts.fetchTaggedErr != nil { fetchTaggedExec := func(ctx thrift.Context, req *rpc.FetchTaggedRequest) { require.NotNil(t, req) @@ -213,7 +213,7 @@ func testHostQueueFetchTagged( Do(fetchTaggedExec). Return(nil, testOpts.fetchTaggedErr) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) } else { fetchTaggedExec := func(ctx thrift.Context, req *rpc.FetchTaggedRequest) { require.NotNil(t, req) @@ -224,7 +224,7 @@ func testHostQueueFetchTagged( Do(fetchTaggedExec). Return(result, nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) } // Fetch diff --git a/src/dbnode/client/host_queue_write_batch_test.go b/src/dbnode/client/host_queue_write_batch_test.go index 0bcab39622..23ae55db34 100644 --- a/src/dbnode/client/host_queue_write_batch_test.go +++ b/src/dbnode/client/host_queue_write_batch_test.go @@ -115,7 +115,7 @@ func TestHostQueueWriteBatches(t *testing.T) { mockClient.EXPECT().WriteBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil) } - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) // Final write will flush assert.NoError(t, queue.Enqueue(writes[3])) @@ -203,7 +203,7 @@ func TestHostQueueWriteBatchesDifferentNamespaces(t *testing.T) { // Assert the writes will be handled in two batches mockClient.EXPECT().WriteBatchRawV2(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil).Times(1) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil).Times(1) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil).Times(1) } else { writeBatch := func(ctx thrift.Context, req *rpc.WriteBatchRawRequest) { var writesForNamespace []*writeOperation @@ -221,7 +221,7 @@ func TestHostQueueWriteBatchesDifferentNamespaces(t *testing.T) { // Assert the writes will be handled in two batches mockClient.EXPECT().WriteBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil).Times(2) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil).Times(2) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil).Times(2) } for _, write := range writes { @@ -267,7 +267,7 @@ func TestHostQueueWriteBatchesNoClientAvailable(t *testing.T) { // Prepare mocks for flush nextClientErr := fmt.Errorf("an error") - mockConnPool.EXPECT().NextClient().Return(nil, nextClientErr) + mockConnPool.EXPECT().NextClient().Return(nil, nil, nextClientErr) // Write var wg sync.WaitGroup @@ -357,7 +357,7 @@ func TestHostQueueWriteBatchesPartialBatchErrs(t *testing.T) { } mockClient.EXPECT().WriteBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(batchErrs) } - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) // Perform writes for _, write := range writes { @@ -418,7 +418,7 @@ func TestHostQueueWriteBatchesEntireBatchErr(t *testing.T) { } } mockClient.EXPECT().WriteBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(writeErr) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) // Perform writes for _, write := range writes { @@ -488,7 +488,7 @@ func TestHostQueueDrainOnClose(t *testing.T) { } mockClient.EXPECT().WriteBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) mockConnPool.EXPECT().Close().AnyTimes() diff --git a/src/dbnode/client/host_queue_write_tagged_test.go b/src/dbnode/client/host_queue_write_tagged_test.go index 46f8ae5d03..c99f3f05f4 100644 --- a/src/dbnode/client/host_queue_write_tagged_test.go +++ b/src/dbnode/client/host_queue_write_tagged_test.go @@ -129,7 +129,7 @@ func TestHostQueueWriteTaggedBatches(t *testing.T) { } mockClient.EXPECT().WriteTaggedBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil) } - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) // Final write will flush assert.NoError(t, queue.Enqueue(writes[3])) @@ -225,7 +225,7 @@ func TestHostQueueWriteTaggedBatchesDifferentNamespaces(t *testing.T) { } // Assert the writes will be handled in two batches. mockClient.EXPECT().WriteTaggedBatchRawV2(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil).Times(1) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil).Times(1) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil).Times(1) } else { writeBatch := func(ctx thrift.Context, req *rpc.WriteTaggedBatchRawRequest) { var writesForNamespace []*writeTaggedOperation @@ -243,7 +243,7 @@ func TestHostQueueWriteTaggedBatchesDifferentNamespaces(t *testing.T) { } // Assert the writes will be handled in two batches. mockClient.EXPECT().WriteTaggedBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil).Times(2) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil).Times(2) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil).Times(2) } for _, write := range writes { assert.NoError(t, queue.Enqueue(write)) @@ -288,7 +288,7 @@ func TestHostQueueWriteTaggedBatchesNoClientAvailable(t *testing.T) { // Prepare mocks for flush nextClientErr := fmt.Errorf("an error") - mockConnPool.EXPECT().NextClient().Return(nil, nextClientErr) + mockConnPool.EXPECT().NextClient().Return(nil, nil, nextClientErr) // Write var wg sync.WaitGroup @@ -368,7 +368,7 @@ func TestHostQueueWriteTaggedBatchesPartialBatchErrs(t *testing.T) { }}, }} mockClient.EXPECT().WriteTaggedBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(batchErrs) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) // Perform writes for _, write := range writes { @@ -428,7 +428,7 @@ func TestHostQueueWriteTaggedBatchesEntireBatchErr(t *testing.T) { } } mockClient.EXPECT().WriteTaggedBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(writeErr) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) // Perform writes for _, write := range writes { @@ -499,7 +499,7 @@ func TestHostQueueDrainOnCloseTaggedWrite(t *testing.T) { } mockClient.EXPECT().WriteTaggedBatchRaw(gomock.Any(), gomock.Any()).Do(writeBatch).Return(nil) - mockConnPool.EXPECT().NextClient().Return(mockClient, nil) + mockConnPool.EXPECT().NextClient().Return(mockClient, &noopPooledChannel{}, nil) mockConnPool.EXPECT().Close().AnyTimes() diff --git a/src/dbnode/client/options.go b/src/dbnode/client/options.go index 9ec1613fd8..5cb4ea3eb1 100644 --- a/src/dbnode/client/options.go +++ b/src/dbnode/client/options.go @@ -42,13 +42,12 @@ import ( "github.com/m3db/m3/src/x/ident" "github.com/m3db/m3/src/x/instrument" "github.com/m3db/m3/src/x/pool" - xresource "github.com/m3db/m3/src/x/resource" xretry "github.com/m3db/m3/src/x/retry" "github.com/m3db/m3/src/x/sampler" "github.com/m3db/m3/src/x/serialize" xsync "github.com/m3db/m3/src/x/sync" - tchannel "github.com/uber/tchannel-go" + "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/thrift" ) @@ -319,7 +318,7 @@ func NewOptionsForAsyncClusters(opts Options, topoInits []topology.Initializer, func defaultNewConnectionFn( channelName string, address string, clientOpts Options, -) (xresource.SimpleCloser, rpc.TChanNode, error) { +) (PooledChannel, rpc.TChanNode, error) { // NB(r): Keep ref to a local channel options since it's actually modified // by TChannel itself to set defaults. var opts *tchannel.ChannelOptions diff --git a/src/dbnode/client/peer.go b/src/dbnode/client/peer.go index b5bdc075b4..0128e58453 100644 --- a/src/dbnode/client/peer.go +++ b/src/dbnode/client/peer.go @@ -38,6 +38,6 @@ func (p *sessionPeer) Host() topology.Host { return p.host } -func (p *sessionPeer) BorrowConnection(fn withConnectionFn) error { +func (p *sessionPeer) BorrowConnection(fn WithConnectionFn) error { return p.source.BorrowConnection(p.host.ID(), fn) } diff --git a/src/dbnode/client/replicated_session.go b/src/dbnode/client/replicated_session.go index 8064a5df8e..5df3dfb506 100644 --- a/src/dbnode/client/replicated_session.go +++ b/src/dbnode/client/replicated_session.go @@ -24,6 +24,9 @@ import ( "fmt" "time" + "github.com/uber-go/tally" + "go.uber.org/zap" + "github.com/m3db/m3/src/dbnode/encoding" "github.com/m3db/m3/src/dbnode/namespace" "github.com/m3db/m3/src/dbnode/storage/block" @@ -33,8 +36,6 @@ import ( "github.com/m3db/m3/src/x/ident" m3sync "github.com/m3db/m3/src/x/sync" xtime "github.com/m3db/m3/src/x/time" - "github.com/uber-go/tally" - "go.uber.org/zap" ) type newSessionFn func(Options) (clientSession, error) @@ -111,8 +112,6 @@ func newReplicatedSession(opts Options, asyncOpts []Options, options ...replicat return &session, nil } -type writeFunc func(Session) error - func (s *replicatedSession) setSession(opts Options) error { if opts.TopologyInitializer() == nil { return nil @@ -343,6 +342,14 @@ func (s replicatedSession) FetchBlocksFromPeers( return s.session.FetchBlocksFromPeers(namespace, shard, consistencyLevel, metadatas, opts) } +func (s *replicatedSession) BorrowConnections( + shardID uint32, + fn WithBorrowConnectionFn, + opts BorrowConnectionOptions, +) (BorrowConnectionsResult, error) { + return s.session.BorrowConnections(shardID, fn, opts) +} + // Open the client session. func (s replicatedSession) Open() error { if err := s.session.Open(); err != nil { diff --git a/src/dbnode/client/session.go b/src/dbnode/client/session.go index be79aa8331..e1de530ebe 100644 --- a/src/dbnode/client/session.go +++ b/src/dbnode/client/session.go @@ -629,7 +629,73 @@ func (s *session) Open() error { return nil } -func (s *session) BorrowConnection(hostID string, fn withConnectionFn) error { +func (s *session) BorrowConnections( + shardID uint32, + fn WithBorrowConnectionFn, + opts BorrowConnectionOptions, +) (BorrowConnectionsResult, error) { + var result BorrowConnectionsResult + s.state.RLock() + topoMap, err := s.topologyMapWithStateRLock() + s.state.RUnlock() + if err != nil { + return result, err + } + + var ( + multiErr = xerrors.NewMultiError() + breakLoop bool + ) + err = topoMap.RouteShardForEach(shardID, func( + _ int, + shard shard.Shard, + host topology.Host, + ) { + if multiErr.NumErrors() > 0 || breakLoop { + // Error or has broken + return + } + + var ( + userResult WithBorrowConnectionResult + userErr error + ) + borrowErr := s.BorrowConnection(host.ID(), func( + client rpc.TChanNode, + channel PooledChannel, + ) { + userResult, userErr = fn(shard, host, client, channel) + }) + if borrowErr != nil { + // Wasn't able to even borrow, skip if don't want to error + // on down hosts or return the borrow error. + if !opts.ContinueOnBorrowError { + multiErr = multiErr.Add(borrowErr) + } + return + } + + // Track successful borrow. + result.Borrowed++ + + // Track whether has broken loop. + breakLoop = userResult.Break + + // Return whether user error occurred to break or not. + if userErr != nil { + multiErr = multiErr.Add(userErr) + } + }) + if err != nil { + // Route error. + return result, err + } + // Potentially a user error or borrow error, otherwise + // FinalError() will return nil. + return result, multiErr.FinalError() +} + +func (s *session) BorrowConnection(hostID string, fn WithConnectionFn) error { s.state.RLock() unlocked := false queue, ok := s.state.queuesByHostID[hostID] @@ -637,13 +703,13 @@ func (s *session) BorrowConnection(hostID string, fn withConnectionFn) error { s.state.RUnlock() return errSessionHasNoHostQueueForHost } - err := queue.BorrowConnection(func(c rpc.TChanNode) { + err := queue.BorrowConnection(func(client rpc.TChanNode, ch PooledChannel) { // Unlock early on success s.state.RUnlock() unlocked = true // Execute function with borrowed connection - fn(c) + fn(client, ch) }) if !unlocked { s.state.RUnlock() @@ -2557,7 +2623,7 @@ func (s *session) streamBlocksMetadataFromPeer( } var attemptErr error - checkedAttemptFn := func(client rpc.TChanNode) { + checkedAttemptFn := func(client rpc.TChanNode, _ PooledChannel) { attemptErr = attemptFn(client) } @@ -3074,7 +3140,7 @@ func (s *session) streamBlocksBatchFromPeer( // Attempt request if err := retrier.Attempt(func() error { var attemptErr error - borrowErr := peer.BorrowConnection(func(client rpc.TChanNode) { + borrowErr := peer.BorrowConnection(func(client rpc.TChanNode, _ PooledChannel) { tctx, _ := thrift.NewContext(s.streamBlocksBatchTimeout) result, attemptErr = client.FetchBlocksRaw(tctx, req) }) diff --git a/src/dbnode/client/session_fetch_bulk_blocks_test.go b/src/dbnode/client/session_fetch_bulk_blocks_test.go index db4d1afecd..0183fb4703 100644 --- a/src/dbnode/client/session_fetch_bulk_blocks_test.go +++ b/src/dbnode/client/session_fetch_bulk_blocks_test.go @@ -339,7 +339,7 @@ func TestFetchBootstrapBlocksDontRetryHostNotAvailableInRetrier(t *testing.T) { connectionPool := NewMockconnectionPool(ctrl) connectionPool.EXPECT(). NextClient(). - Return(nil, errConnectionPoolHasNoConnections). + Return(nil, nil, errConnectionPoolHasNoConnections). AnyTimes() hostQueue := NewMockhostQueue(ctrl) hostQueue.EXPECT().Open() @@ -2008,15 +2008,15 @@ func defaultHostAndClientWithExpect( ) (*MockhostQueue, *rpc.MockTChanNode) { client := rpc.NewMockTChanNode(ctrl) connectionPool := NewMockconnectionPool(ctrl) - connectionPool.EXPECT().NextClient().Return(client, nil).AnyTimes() + connectionPool.EXPECT().NextClient().Return(client, &noopPooledChannel{}, nil).AnyTimes() hostQueue := NewMockhostQueue(ctrl) hostQueue.EXPECT().Open() hostQueue.EXPECT().Host().Return(host).AnyTimes() hostQueue.EXPECT().ConnectionCount().Return(opts.MinConnectionCount()).Times(sessionTestShards) hostQueue.EXPECT().ConnectionPool().Return(connectionPool).AnyTimes() - hostQueue.EXPECT().BorrowConnection(gomock.Any()).Do(func(fn withConnectionFn) { - fn(client) + hostQueue.EXPECT().BorrowConnection(gomock.Any()).Do(func(fn WithConnectionFn) { + fn(client, &noopPooledChannel{}) }).Return(nil).AnyTimes() hostQueue.EXPECT().Close() diff --git a/src/dbnode/client/session_fetch_high_concurrency_test.go b/src/dbnode/client/session_fetch_high_concurrency_test.go index a4acd087c3..35a50e7677 100644 --- a/src/dbnode/client/session_fetch_high_concurrency_test.go +++ b/src/dbnode/client/session_fetch_high_concurrency_test.go @@ -28,6 +28,10 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/m3db/m3/src/cluster/shard" "github.com/m3db/m3/src/dbnode/encoding/m3tsz" "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" @@ -35,18 +39,9 @@ import ( "github.com/m3db/m3/src/dbnode/topology" "github.com/m3db/m3/src/dbnode/ts" "github.com/m3db/m3/src/x/ident" - xresource "github.com/m3db/m3/src/x/resource" xtime "github.com/m3db/m3/src/x/time" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -type noopCloser struct{} - -func (noopCloser) Close() {} - func TestSessionFetchIDsHighConcurrency(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -102,7 +97,7 @@ func TestSessionFetchIDsHighConcurrency(t *testing.T) { // to be able to mock the entire end to end pipeline newConnFn := func( _ string, addr string, _ Options, - ) (xresource.SimpleCloser, rpc.TChanNode, error) { + ) (PooledChannel, rpc.TChanNode, error) { mockClient := rpc.NewMockTChanNode(ctrl) mockClient.EXPECT().Health(gomock.Any()). Return(healthCheckResult, nil). @@ -110,7 +105,7 @@ func TestSessionFetchIDsHighConcurrency(t *testing.T) { mockClient.EXPECT().FetchBatchRaw(gomock.Any(), gomock.Any()). Return(respResult, nil). AnyTimes() - return noopCloser{}, mockClient, nil + return &noopPooledChannel{}, mockClient, nil } shards := make([]shard.Shard, numShards) for i := range shards { diff --git a/src/dbnode/client/types.go b/src/dbnode/client/types.go index 53399bf333..0de707fe23 100644 --- a/src/dbnode/client/types.go +++ b/src/dbnode/client/types.go @@ -23,6 +23,7 @@ package client import ( "time" + "github.com/m3db/m3/src/cluster/shard" "github.com/m3db/m3/src/dbnode/encoding" "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" "github.com/m3db/m3/src/dbnode/namespace" @@ -42,7 +43,7 @@ import ( xsync "github.com/m3db/m3/src/x/sync" xtime "github.com/m3db/m3/src/x/time" - tchannel "github.com/uber/tchannel-go" + "github.com/uber/tchannel-go" ) // Client can create sessions to write and read to a cluster. @@ -245,6 +246,39 @@ type AdminSession interface { metadatas []block.ReplicaMetadata, opts result.Options, ) (PeerBlocksIter, error) + + // BorrowConnections will borrow connection for hosts belonging to a shard. + BorrowConnections( + shardID uint32, + fn WithBorrowConnectionFn, + opts BorrowConnectionOptions, + ) (BorrowConnectionsResult, error) +} + +// BorrowConnectionOptions are options to use when borrowing a connection +type BorrowConnectionOptions struct { + // ContinueOnBorrowError allows skipping hosts that cannot borrow + // a connection for. + ContinueOnBorrowError bool +} + +// BorrowConnectionsResult is a result used when borrowing connections. +type BorrowConnectionsResult struct { + Borrowed int +} + +// WithBorrowConnectionFn is used to do work with a borrowed connection. +type WithBorrowConnectionFn func( + shard shard.Shard, + host topology.Host, + client rpc.TChanNode, + channel PooledChannel, +) (WithBorrowConnectionResult, error) + +// WithBorrowConnectionResult is returned from a borrow connection function. +type WithBorrowConnectionResult struct { + // Break will break the iteration. + Break bool } // Options is a set of client options. @@ -694,13 +728,14 @@ type hostQueue interface { ConnectionPool() connectionPool // BorrowConnection will borrow a connection and execute a user function. - BorrowConnection(fn withConnectionFn) error + BorrowConnection(fn WithConnectionFn) error // Close the host queue, will flush any operations still pending. Close() } -type withConnectionFn func(c rpc.TChanNode) +// WithConnectionFn is a callback for a connection to a host. +type WithConnectionFn func(client rpc.TChanNode, ch PooledChannel) type connectionPool interface { // Open starts the connection pool connecting and health checking. @@ -710,7 +745,7 @@ type connectionPool interface { ConnectionCount() int // NextClient gets the next client for use by the connection pool. - NextClient() (rpc.TChanNode, error) + NextClient() (rpc.TChanNode, PooledChannel, error) // Close the connection pool. Close() @@ -718,7 +753,7 @@ type connectionPool interface { type peerSource interface { // BorrowConnection will borrow a connection and execute a user function. - BorrowConnection(hostID string, fn withConnectionFn) error + BorrowConnection(hostID string, fn WithConnectionFn) error } type peer interface { @@ -726,7 +761,7 @@ type peer interface { Host() topology.Host // BorrowConnection will borrow a connection and execute a user function. - BorrowConnection(fn withConnectionFn) error + BorrowConnection(fn WithConnectionFn) error } type status int