Skip to content

Commit

Permalink
Fix race condition (#785)
Browse files Browse the repository at this point in the history
* Fix a race condition in the way message pairs are queued

* add a test
  • Loading branch information
thelonelyvulpes authored Mar 6, 2024
1 parent f01703a commit a1670b7
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,24 @@ await con.EnqueueAsync(
await con.EnqueueAsync(PullAllMessage.Instance, NoOpResponseHandler.Instance);
pipeline.Verify(h => h.Enqueue(NoOpResponseHandler.Instance), Times.Exactly(2));
}

[Fact]
public async Task ShouldEnqueueBoth()
{
var pipeline = new Mock<IResponsePipeline>();
var con = NewSocketConnection(pipeline: pipeline.Object);

var m1 = new Mock<IRequestMessage>();
var h1 = new Mock<IResponseHandler>();
var m2 = new Mock<IRequestMessage>();
var h2 = new Mock<IResponseHandler>();

await con.EnqueueAsync(m1.Object, h1.Object, m2.Object, h2.Object);

con.Messages[0].Should().Be(m1.Object);
con.Messages[1].Should().Be(m2.Object);
pipeline.Verify(x => x.Enqueue(It.IsAny<IResponseHandler>()), Times.Exactly(2));
}
}

public class ResetMethod
Expand Down Expand Up @@ -347,4 +365,4 @@ public static TheoryData<Exception> GenerateObjectDisposedExceptions()
};
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,12 @@ public async Task ShouldUseQueryToFetchRoutingTableForBoltVersionLessThan43(int

mockConn.Verify(x => x.ConfigureMode(AccessMode.Read), Times.Once);
mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));
x => x.EnqueueAsync(
It.IsAny<RunWithMetadataMessage>(),
It.IsAny<RunResponseHandler>(),
It.IsAny<PullMessage>(),
It.IsAny<PullResponseHandler>()),
Times.Exactly(1));

mockConn.Verify(x => x.SendAsync(), Times.Once);
}
Expand Down Expand Up @@ -761,16 +765,12 @@ public async Task ShouldSendPullMessageWhenNotReactive()
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<RunWithMetadataMessage>(), It.IsAny<RunResponseHandler>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<PullMessage>(), It.IsAny<PullResponseHandler>()),
Times.Once);
x => x.EnqueueAsync(
It.IsNotNull<RunWithMetadataMessage>(),
It.IsNotNull<RunResponseHandler>(),
It.IsNotNull<PullMessage>(),
It.IsNotNull<PullResponseHandler>()),
Times.Exactly(1));

mockConn.Verify(x => x.SendAsync(), Times.Once);
mockConn.Verify(x => x.SyncAsync(), Times.Never);
Expand Down Expand Up @@ -1189,15 +1189,11 @@ await protocol.RunInExplicitTransactionAsync(
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<RunWithMetadataMessage>(), It.IsAny<RunResponseHandler>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<PullMessage>(), It.IsAny<PullResponseHandler>()),
x => x.EnqueueAsync(
It.IsAny<RunWithMetadataMessage>(),
It.IsAny<RunResponseHandler>(),
It.IsAny<PullMessage>(),
It.IsAny<PullResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,12 @@ public async Task ShouldSendRunWithMetadataMessageToGetRoutingTable()

mockConn.Verify(x => x.ConfigureMode(AccessMode.Read), Times.Once);
mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));
x => x.EnqueueAsync(
It.IsAny<IRequestMessage>(),
It.IsAny<IResponseHandler>(),
It.IsAny<IRequestMessage>(),
It.IsAny<IResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
}
Expand Down Expand Up @@ -391,15 +395,11 @@ public async Task ShouldSendMessages()
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<IRequestMessage>(), It.IsNotNull<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<RunWithMetadataMessage>(), It.IsNotNull<RunResponseHandlerV3>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(PullAllMessage.Instance, It.IsNotNull<PullAllResponseHandler>()),
x => x.EnqueueAsync(
It.IsNotNull<RunWithMetadataMessage>(),
It.IsNotNull<RunResponseHandlerV3>(),
PullAllMessage.Instance,
It.IsNotNull<PullAllResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
Expand Down Expand Up @@ -672,15 +672,11 @@ public async Task ShouldSendMessages()
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<IRequestMessage>(), It.IsNotNull<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<RunWithMetadataMessage>(), It.IsNotNull<RunResponseHandlerV3>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(PullAllMessage.Instance, It.IsNotNull<PullAllResponseHandler>()),
x => x.EnqueueAsync(
It.IsNotNull<RunWithMetadataMessage>(),
It.IsNotNull<RunResponseHandlerV3>(),
PullAllMessage.Instance,
It.IsNotNull<PullAllResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ public async Task EnqueueAsync(IRequestMessage message, IResponseHandler handler
}
}

public async ValueTask EnqueueAsync(
IRequestMessage message1,
IResponseHandler handler1,
IRequestMessage message2,
IResponseHandler handler2)
{
try
{
await Delegate.EnqueueAsync(message1, handler1, message2, handler2).ConfigureAwait(false);
}
catch (Exception e)
{
await OnErrorAsync(e).ConfigureAwait(false);
}
}

public virtual bool IsOpen => Delegate.IsOpen;

public IServerInfo Server => Delegate.Server;
Expand Down Expand Up @@ -190,7 +206,7 @@ public ValueTask ValidateCredsAsync()
return Delegate.ValidateCredsAsync();
}

/// <inheritdoc />
/// <inheritdoc/>
public bool TelemetryEnabled
{
get => Delegate.TelemetryEnabled;
Expand Down Expand Up @@ -232,7 +248,10 @@ public Task BeginTransactionAsync(BeginTransactionParams beginTransactionParams)
return BoltProtocol.BeginTransactionAsync(this, beginTransactionParams);
}

public Task<IResultCursor> RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize,
public Task<IResultCursor> RunInExplicitTransactionAsync(
Query query,
bool reactive,
long fetchSize,
IInternalAsyncTransaction transaction)
{
return BoltProtocol.RunInExplicitTransactionAsync(this, query, reactive, fetchSize, transaction);
Expand Down
15 changes: 13 additions & 2 deletions Neo4j.Driver/Neo4j.Driver/Internal/Connector/IConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ internal interface IConnection : IConnectionDetails, IConnectionRunner
IAuthTokenManager AuthTokenManager { get; }

public SessionConfig SessionConfig { get; set; }
bool TelemetryEnabled { get; set; }

void ConfigureMode(AccessMode? mode);
void Configure(string database, AccessMode? mode);
Expand All @@ -83,6 +84,13 @@ Task InitAsync(

Task EnqueueAsync(IRequestMessage message, IResponseHandler handler);

ValueTask EnqueueAsync(
IRequestMessage message1,
IResponseHandler handler1,
IRequestMessage message2,
IResponseHandler handler2
);

// Enqueue a reset message
Task ResetAsync();

Expand All @@ -100,7 +108,6 @@ Task InitAsync(

void SetUseUtcEncodedDateTime();
ValueTask ValidateCredsAsync();
bool TelemetryEnabled { get; set; }
}

internal interface IConnectionRunner
Expand All @@ -123,8 +130,12 @@ Task<IResultCursor> RunInAutoCommitTransactionAsync(

Task BeginTransactionAsync(BeginTransactionParams beginParams);

Task<IResultCursor> RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize,
Task<IResultCursor> RunInExplicitTransactionAsync(
Query query,
bool reactive,
long fetchSize,
IInternalAsyncTransaction transaction);

Task CommitTransactionAsync(IBookmarksTracker bookmarksTracker);
Task RollbackTransactionAsync();
}
Expand Down
38 changes: 33 additions & 5 deletions Neo4j.Driver/Neo4j.Driver/Internal/Connector/SocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ internal SocketConnection(
ServerInfo server,
IResponsePipeline responsePipeline = null,
IAuthTokenManager authTokenManager = null,
IBoltProtocolFactory protocolFactory = null,
IBoltProtocolFactory protocolFactory = null,
DriverContext context = null)
{
_client = socketClient ?? throw new ArgumentNullException(nameof(socketClient));
Expand All @@ -88,6 +88,7 @@ internal SocketConnection(
}

internal IReadOnlyList<IRequestMessage> Messages => _messages.ToList();
public DriverContext Context { get; }

public AccessMode? Mode { get; private set; }

Expand Down Expand Up @@ -138,7 +139,11 @@ public async Task InitAsync(

try
{
await BoltProtocol.AuthenticateAsync(this, Context.Config.UserAgent, authToken, Context.Config.NotificationsConfig)
await BoltProtocol.AuthenticateAsync(
this,
Context.Config.UserAgent,
authToken,
Context.Config.NotificationsConfig)
.ConfigureAwait(false);
}
catch (Exception ex)
Expand Down Expand Up @@ -242,6 +247,27 @@ public async Task ReceiveOneAsync()
}
}

public async ValueTask EnqueueAsync(
IRequestMessage message1,
IResponseHandler handler1,
IRequestMessage message2,
IResponseHandler handler2)
{
await _sendLock.WaitAsync().ConfigureAwait(false);

try
{
_messages.Enqueue(message1);
_messages.Enqueue(message2);
_responsePipeline.Enqueue(handler1);
_responsePipeline.Enqueue(handler2);
}
finally
{
_sendLock.Release();
}
}

public Task ResetAsync()
{
return BoltProtocol.ResetAsync(this);
Expand All @@ -251,7 +277,6 @@ public Task ResetAsync()
public IServerInfo Server => _serverInfo;

public bool UtcEncodedDateTime { get; private set; }
public DriverContext Context { get; }
public IAuthToken AuthToken { get; private set; }
public bool TelemetryEnabled { get; set; }

Expand All @@ -263,7 +288,7 @@ public void UpdateId(string newConnId)
newConnId);

_id = newConnId;

if (_logger is PrefixLogger logger)
{
logger.Prefix = FormatPrefix(_id);
Expand Down Expand Up @@ -417,7 +442,10 @@ public Task BeginTransactionAsync(BeginTransactionParams beginParams)
return BoltProtocol.BeginTransactionAsync(this, beginParams);
}

public Task<IResultCursor> RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize,
public Task<IResultCursor> RunInExplicitTransactionAsync(
Query query,
bool reactive,
long fetchSize,
IInternalAsyncTransaction transaction)
{
return BoltProtocol.RunInExplicitTransactionAsync(this, query, reactive, fetchSize, transaction);
Expand Down
Loading

0 comments on commit a1670b7

Please sign in to comment.