Skip to content

Commit

Permalink
Improve blocking borrow logic in pool
Browse files Browse the repository at this point in the history
 * After waiting for room in the pool, borrow also considers creating a new
   connection (if possible)
 * After waiting for room in the pool, borrow reads (but doesn't update) the
   routing table to get more up-to-date routing info
 * Connections are always fully returned to and then borrowed from the pool.
   This means the health check and other initializers are guaranteed to be run.

Signed-off-by: Florent Biville <florent.biville@neo4j.com>
  • Loading branch information
robsdedude authored and fbiville committed May 10, 2023
1 parent 4d3f387 commit cb75c31
Show file tree
Hide file tree
Showing 15 changed files with 395 additions and 230 deletions.
12 changes: 10 additions & 2 deletions neo4j/directrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,19 @@ func (r *directRouter) InvalidateReader(context.Context, string, string) error {
return nil
}

func (r *directRouter) Readers(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
// GetOrUpdateReaders ignores every argument (context, bookmark supplier,
// database name, auth token and logger) and answers with a one-element
// slice holding r.address. The returned error is always nil.
func (r *directRouter) GetOrUpdateReaders(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
	readers := []string{r.address}
	return readers, nil
}

func (r *directRouter) Writers(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
// Readers ignores the context and the database name and answers with a
// one-element slice holding r.address. The returned error is always nil.
func (r *directRouter) Readers(context.Context, string) ([]string, error) {
	readers := []string{r.address}
	return readers, nil
}

// GetOrUpdateWriters ignores every argument (context, bookmark supplier,
// database name, auth token and logger) and answers with a one-element
// slice holding r.address. The returned error is always nil.
func (r *directRouter) GetOrUpdateWriters(context.Context, func(context.Context) ([]string, error), string, *db.ReAuthToken, log.BoltLogger) ([]string, error) {
	writers := []string{r.address}
	return writers, nil
}

// Writers ignores the context and the database name and answers with a
// one-element slice holding r.address. The returned error is always nil.
func (r *directRouter) Writers(context.Context, string) ([]string, error) {
	writers := []string{r.address}
	return writers, nil
}

Expand Down
12 changes: 8 additions & 4 deletions neo4j/driver_with_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,18 @@ func routingContextFromUrl(useRouting bool, u *url.URL) (map[string]string, erro
}

type sessionRouter interface {
// Readers returns the list of servers that can serve reads on the requested database.
// GetOrUpdateReaders returns the list of servers that can serve reads on the requested database.
// note: bookmarks are lazily supplied, only when a new routing table needs to be fetched
// this is needed because custom bookmark managers may provide bookmarks from external systems
// they should not be called when it is not needed (e.g. when a routing table is cached)
Readers(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
// Writers returns the list of servers that can serve writes on the requested database.
GetOrUpdateReaders(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
// Readers returns the list of servers that can serve reads on the requested database.
Readers(ctx context.Context, database string) ([]string, error)
// GetOrUpdateWriters returns the list of servers that can serve writes on the requested database.
// note: bookmarks are lazily supplied, see GetOrUpdateReaders documentation to learn why
Writers(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
GetOrUpdateWriters(ctx context.Context, bookmarks func(context.Context) ([]string, error), database string, auth *idb.ReAuthToken, boltLogger log.BoltLogger) ([]string, error)
// Writers returns the list of servers that can serve writes on the requested database.
Writers(ctx context.Context, database string) ([]string, error)
// GetNameOfDefaultDatabase returns the name of the default database for the specified user.
// The correct database name is needed when requesting readers or writers.
// the bookmarks are eagerly provided since this method always fetches a new routing table
Expand Down
26 changes: 26 additions & 0 deletions neo4j/internal/errorutil/pool.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [https://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package errorutil

import "fmt"
Expand Down Expand Up @@ -25,3 +44,10 @@ type PoolClosed struct {
// Error implements the error interface. The message is the fixed text
// "Pool closed" and does not depend on the receiver's state.
func (*PoolClosed) Error() string {
	return "Pool closed"
}

// PoolOutOfServers is the error value used when there is no server the pool
// could connect to. It carries no fields; all instances are equivalent.
type PoolOutOfServers struct{}

// Error implements the error interface and always yields the same fixed
// message, independent of the receiver.
func (*PoolOutOfServers) Error() string {
	const message = "Pool could not find any servers to connect to"
	return message
}
213 changes: 90 additions & 123 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ const DefaultLivenessCheckThreshold = math.MaxInt64
type Connect func(context.Context, string, *idb.ReAuthToken, bolt.Neo4jErrorCallback, log.BoltLogger) (idb.Connection, error)

type qitem struct {
servers []string
wakeup chan bool
conn idb.Connection
wakeup chan bool
}

type Pool struct {
Expand Down Expand Up @@ -109,22 +107,6 @@ func (p *Pool) Close(ctx context.Context) error {
return nil
}

// anyExistingConnectionsOnServers reports whether at least one of the named
// servers is known to the pool (present in p.servers) with a size greater
// than zero.
//
// The servers lock is taken via TryLock(ctx); if that fails, the function
// returns false together with an error instead of inspecting p.servers.
// Names that have no entry in p.servers are skipped.
func (p *Pool) anyExistingConnectionsOnServers(ctx context.Context, serverNames []string) (bool, error) {
	if locked := p.serversMut.TryLock(ctx); !locked {
		return false, fmt.Errorf("could not acquire server lock in time when checking server connection")
	}
	defer p.serversMut.Unlock()

	for _, name := range serverNames {
		// A missing entry yields nil; only non-empty known servers count.
		if srv := p.servers[name]; srv != nil && srv.size() > 0 {
			return true, nil
		}
	}
	return false, nil
}

// For testing
func (p *Pool) queueSize(ctx context.Context) (int, error) {
if !p.queueMut.TryLock(ctx) {
Expand Down Expand Up @@ -230,100 +212,94 @@ serverLoop:
return nil, nil
}

func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) {
func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) {
if p.closed {
return nil, &errorutil.PoolClosed{}
}
p.log.Debugf(log.Pool, p.logId, "Trying to borrow connection from %s", serverNames)

// Retrieve penalty for each server
penalties, err := p.getPenaltiesForServers(ctx, serverNames)
if err != nil {
return nil, err
}
// Sort server penalties by lowest penalty
sort.Slice(penalties, func(i, j int) bool {
return penalties[i].penalty < penalties[j].penalty
})

var conn idb.Connection
for _, s := range penalties {
conn, err = p.tryBorrow(ctx, s.name, boltLogger, idlenessThreshold, auth)
if err == nil {
return conn, nil
for {
serverNames, err := getServerNames(ctx)
if err != nil {
return nil, err
}

if errorutil.IsTimeoutError(err) {
p.log.Warnf(log.Pool, p.logId, "Borrow time-out")
return nil, &errorutil.PoolTimeout{Servers: serverNames, Err: err}
if len(serverNames) == 0 {
return nil, &errorutil.PoolOutOfServers{}
}
if errorutil.IsFatalDuringDiscovery(err) {
p.log.Debugf(log.Pool, p.logId, "Trying to borrow connection from %s", serverNames)
// Retrieve penalty for each server
penalties, err := p.getPenaltiesForServers(ctx, serverNames)
if err != nil {
return nil, err
}
}
// Sort server penalties by lowest penalty
sort.Slice(penalties, func(i, j int) bool {
return penalties[i].penalty < penalties[j].penalty
})

anyConnection, anyConnectionErr := p.anyExistingConnectionsOnServers(ctx, serverNames)
if anyConnectionErr != nil {
return nil, err
}
// If there are no connections for any of the servers, there is no point in waiting for anything
// to be returned.
if !anyConnection {
p.log.Warnf(log.Pool, p.logId, "No server connection available to any of %v", serverNames)
if err == nil {
err = fmt.Errorf("no server connection available to any of %v", serverNames)
var conn idb.Connection
for _, s := range penalties {
conn, err = p.tryBorrow(ctx, s.name, boltLogger, idlenessThreshold, auth)
if conn != nil {
return conn, nil
}

if errorutil.IsTimeoutError(err) {
p.log.Warnf(log.Pool, p.logId, "Borrow time-out")
return nil, &errorutil.PoolTimeout{Servers: serverNames, Err: err}
}
if errorutil.IsFatalDuringDiscovery(err) {
return nil, err
}
}
// Intentionally return last error from last connection attempt to make it easier to
// see connection errors for users.
return nil, err
}

if !wait {
return nil, &errorutil.PoolFull{Servers: serverNames}
}
if err != nil {
// Intentionally return last error from last connection attempt to make it easier to
// see connection errors for users.
return nil, err
}

// Wait for a matching connection to be returned from another thread.
if !p.queueMut.TryLock(ctx) {
return nil, racing.LockTimeoutError("could not acquire lock in time when trying to get an idle connection")
}
// Ok, now that we own the queue we can add the item there but between getting the lock
// and above check for an existing connection another thread might have returned a connection
// so check again to avoid potentially starving this thread.
conn, err = p.tryAnyIdle(ctx, serverNames, idlenessThreshold, auth, boltLogger)
if err != nil {
p.queueMut.Unlock()
return nil, err
}
if conn != nil {
p.queueMut.Unlock()
return conn, nil
}
// Add a waiting request to the queue and unlock the queue to let other threads that return
// their connections access the queue.
q := &qitem{
servers: serverNames,
wakeup: make(chan bool),
}
e := p.queue.PushBack(q)
p.queueMut.Unlock()
if !wait {
return nil, &errorutil.PoolFull{Servers: serverNames}
}

p.log.Warnf(log.Pool, p.logId, "Borrow queued")
// Wait for either a wake-up signal that indicates that we got a connection or a timeout.
select {
case <-q.wakeup:
return q.conn, nil
case <-ctx.Done():
// TODO: provided ctx has reached deadline already - set some hardcoded timeout instead?
if !p.queueMut.TryLock(context.Background()) {
return nil, racing.LockTimeoutError("could not acquire lock in time when removing server wait request")
// Wait for a matching connection to be returned from another thread.
if !p.queueMut.TryLock(ctx) {
return nil, racing.LockTimeoutError("could not acquire lock in time when trying to get an idle connection")
}
p.queue.Remove(e)
// Ok, now that we own the queue we can add the item there but between getting the lock
// and above check for an existing connection another thread might have returned a connection
// so check again to avoid potentially starving this thread.
conn, err = p.tryAnyIdle(ctx, serverNames, idlenessThreshold, auth, boltLogger)
if err != nil {
p.queueMut.Unlock()
return nil, err
}
if conn != nil {
p.queueMut.Unlock()
return conn, nil
}
// Add a waiting request to the queue and unlock the queue to let other threads that return
// their connections access the queue.
q := &qitem{
wakeup: make(chan bool, 1),
}
e := p.queue.PushBack(q)
p.queueMut.Unlock()
if q.conn != nil {
return q.conn, nil

p.log.Warnf(log.Pool, p.logId, "Borrow queued")
// Wait for either a wake-up signal that indicates that we got a connection or a timeout.
select {
case <-q.wakeup:
continue
case <-ctx.Done():
if !p.queueMut.TryLock(context.Background()) {
return nil, racing.LockTimeoutError("could not acquire lock in time when removing server wait request")
}
p.queue.Remove(e)
p.queueMut.Unlock()
p.log.Warnf(log.Pool, p.logId, "Borrow time-out")
return nil, &errorutil.PoolTimeout{Err: ctx.Err(), Servers: serverNames}
}
p.log.Warnf(log.Pool, p.logId, "Borrow time-out")
return nil, &errorutil.PoolTimeout{Err: ctx.Err(), Servers: serverNames}
}
}

Expand All @@ -343,7 +319,7 @@ func (p *Pool) tryBorrow(ctx context.Context, serverName string, boltLogger log.
connection := srv.getIdle()
if connection == nil {
if srv.size() >= p.config.MaxConnectionPoolSize {
return nil, &errorutil.PoolFull{Servers: []string{serverName}}
return nil, nil
}
break
}
Expand Down Expand Up @@ -483,10 +459,20 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error {
return err
}
p.log.Infof(log.Pool, p.logId, "Unregistering dead or too old connection to %s", serverName)
// Returning here could cause a waiting thread to wait until it times out, to do it
// properly we could wake up threads that waits on the server and wake them up if there
// are no more connections to wait for.
return nil
}

if isAlive {
// Just put it back in the list of idle connections for this server
if !p.serversMut.TryLock(ctx) {
return racing.LockTimeoutError("could not acquire server lock when putting connection back to idle")
}
server := p.servers[serverName]
if server != nil { // Strange when server not found
server.returnBusy(c)
} else {
p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName)
}
p.serversMut.Unlock()
}

// Check if there is anyone in the queue waiting for a connection to this server.
Expand All @@ -495,30 +481,11 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) error {
}
for e := p.queue.Front(); e != nil; e = e.Next() {
queuedRequest := e.Value.(*qitem)
// Check requested servers
for _, rserver := range queuedRequest.servers {
if rserver == serverName {
queuedRequest.conn = c
p.queue.Remove(e)
p.queueMut.Unlock()
queuedRequest.wakeup <- true
return nil
}
}
p.queue.Remove(e)
queuedRequest.wakeup <- true
}
p.queueMut.Unlock()

// Just put it back in the list of idle connections for this server
if !p.serversMut.TryLock(ctx) {
return racing.LockTimeoutError("could not acquire server lock when putting connection back to idle")
}
defer p.serversMut.Unlock()
server := p.servers[serverName]
if server != nil { // Strange when server not found
server.returnBusy(c)
} else {
p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName)
}
return nil
}

Expand Down
Loading

0 comments on commit cb75c31

Please sign in to comment.