Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition #785

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,24 @@ await con.EnqueueAsync(
await con.EnqueueAsync(PullAllMessage.Instance, NoOpResponseHandler.Instance);
pipeline.Verify(h => h.Enqueue(NoOpResponseHandler.Instance), Times.Exactly(2));
}

[Fact]
public async Task ShouldEnqueueBoth()
{
var pipeline = new Mock<IResponsePipeline>();
var con = NewSocketConnection(pipeline: pipeline.Object);

var m1 = new Mock<IRequestMessage>();
var h1 = new Mock<IResponseHandler>();
var m2 = new Mock<IRequestMessage>();
var h2 = new Mock<IResponseHandler>();

await con.EnqueueAsync(m1.Object, h1.Object, m2.Object, h2.Object);

con.Messages[0].Should().Be(m1.Object);
con.Messages[1].Should().Be(m2.Object);
pipeline.Verify(x => x.Enqueue(It.IsAny<IResponseHandler>()), Times.Exactly(2));
}
}

public class ResetMethod
Expand Down Expand Up @@ -347,4 +365,4 @@ public static TheoryData<Exception> GenerateObjectDisposedExceptions()
};
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,12 @@ public async Task ShouldUseQueryToFetchRoutingTableForBoltVersionLessThan43(int

mockConn.Verify(x => x.ConfigureMode(AccessMode.Read), Times.Once);
mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));
x => x.EnqueueAsync(
It.IsAny<RunWithMetadataMessage>(),
It.IsAny<RunResponseHandler>(),
It.IsAny<PullMessage>(),
It.IsAny<PullResponseHandler>()),
Times.Exactly(1));

mockConn.Verify(x => x.SendAsync(), Times.Once);
}
Expand Down Expand Up @@ -761,16 +765,12 @@ public async Task ShouldSendPullMessageWhenNotReactive()
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<RunWithMetadataMessage>(), It.IsAny<RunResponseHandler>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<PullMessage>(), It.IsAny<PullResponseHandler>()),
Times.Once);
x => x.EnqueueAsync(
It.IsNotNull<RunWithMetadataMessage>(),
It.IsNotNull<RunResponseHandler>(),
It.IsNotNull<PullMessage>(),
It.IsNotNull<PullResponseHandler>()),
Times.Exactly(1));

mockConn.Verify(x => x.SendAsync(), Times.Once);
mockConn.Verify(x => x.SyncAsync(), Times.Never);
Expand Down Expand Up @@ -1189,15 +1189,11 @@ await protocol.RunInExplicitTransactionAsync(
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<RunWithMetadataMessage>(), It.IsAny<RunResponseHandler>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<PullMessage>(), It.IsAny<PullResponseHandler>()),
x => x.EnqueueAsync(
It.IsAny<RunWithMetadataMessage>(),
It.IsAny<RunResponseHandler>(),
It.IsAny<PullMessage>(),
It.IsAny<PullResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,12 @@ public async Task ShouldSendRunWithMetadataMessageToGetRoutingTable()

mockConn.Verify(x => x.ConfigureMode(AccessMode.Read), Times.Once);
mockConn.Verify(
x => x.EnqueueAsync(It.IsAny<IRequestMessage>(), It.IsAny<IResponseHandler>()),
Times.Exactly(2));
x => x.EnqueueAsync(
It.IsAny<IRequestMessage>(),
It.IsAny<IResponseHandler>(),
It.IsAny<IRequestMessage>(),
It.IsAny<IResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
}
Expand Down Expand Up @@ -391,15 +395,11 @@ public async Task ShouldSendMessages()
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<IRequestMessage>(), It.IsNotNull<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<RunWithMetadataMessage>(), It.IsNotNull<RunResponseHandlerV3>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(PullAllMessage.Instance, It.IsNotNull<PullAllResponseHandler>()),
x => x.EnqueueAsync(
It.IsNotNull<RunWithMetadataMessage>(),
It.IsNotNull<RunResponseHandlerV3>(),
PullAllMessage.Instance,
It.IsNotNull<PullAllResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
Expand Down Expand Up @@ -672,15 +672,11 @@ public async Task ShouldSendMessages()
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<IRequestMessage>(), It.IsNotNull<IResponseHandler>()),
Times.Exactly(2));

mockConn.Verify(
x => x.EnqueueAsync(It.IsNotNull<RunWithMetadataMessage>(), It.IsNotNull<RunResponseHandlerV3>()),
Times.Once);

mockConn.Verify(
x => x.EnqueueAsync(PullAllMessage.Instance, It.IsNotNull<PullAllResponseHandler>()),
x => x.EnqueueAsync(
It.IsNotNull<RunWithMetadataMessage>(),
It.IsNotNull<RunResponseHandlerV3>(),
PullAllMessage.Instance,
It.IsNotNull<PullAllResponseHandler>()),
Times.Once);

mockConn.Verify(x => x.SendAsync(), Times.Once);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ public async Task EnqueueAsync(IRequestMessage message, IResponseHandler handler
}
}

public async ValueTask EnqueueAsync(
IRequestMessage message1,
IResponseHandler handler1,
IRequestMessage message2,
IResponseHandler handler2)
{
try
{
await Delegate.EnqueueAsync(message1, handler1, message2, handler2).ConfigureAwait(false);
}
catch (Exception e)
{
await OnErrorAsync(e).ConfigureAwait(false);
}
}

public virtual bool IsOpen => Delegate.IsOpen;

public IServerInfo Server => Delegate.Server;
Expand Down Expand Up @@ -190,7 +206,7 @@ public ValueTask ValidateCredsAsync()
return Delegate.ValidateCredsAsync();
}

/// <inheritdoc />
/// <inheritdoc/>
public bool TelemetryEnabled
{
get => Delegate.TelemetryEnabled;
Expand Down Expand Up @@ -232,7 +248,10 @@ public Task BeginTransactionAsync(BeginTransactionParams beginTransactionParams)
return BoltProtocol.BeginTransactionAsync(this, beginTransactionParams);
}

public Task<IResultCursor> RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize,
public Task<IResultCursor> RunInExplicitTransactionAsync(
Query query,
bool reactive,
long fetchSize,
IInternalAsyncTransaction transaction)
{
return BoltProtocol.RunInExplicitTransactionAsync(this, query, reactive, fetchSize, transaction);
Expand Down
15 changes: 13 additions & 2 deletions Neo4j.Driver/Neo4j.Driver/Internal/Connector/IConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ internal interface IConnection : IConnectionDetails, IConnectionRunner
IAuthTokenManager AuthTokenManager { get; }

public SessionConfig SessionConfig { get; set; }
bool TelemetryEnabled { get; set; }

void ConfigureMode(AccessMode? mode);
void Configure(string database, AccessMode? mode);
Expand All @@ -83,6 +84,13 @@ Task InitAsync(

Task EnqueueAsync(IRequestMessage message, IResponseHandler handler);

ValueTask EnqueueAsync(
IRequestMessage message1,
IResponseHandler handler1,
IRequestMessage message2,
IResponseHandler handler2
);

// Enqueue a reset message
Task ResetAsync();

Expand All @@ -100,7 +108,6 @@ Task InitAsync(

void SetUseUtcEncodedDateTime();
ValueTask ValidateCredsAsync();
bool TelemetryEnabled { get; set; }
}

internal interface IConnectionRunner
Expand All @@ -123,8 +130,12 @@ Task<IResultCursor> RunInAutoCommitTransactionAsync(

Task BeginTransactionAsync(BeginTransactionParams beginParams);

Task<IResultCursor> RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize,
Task<IResultCursor> RunInExplicitTransactionAsync(
Query query,
bool reactive,
long fetchSize,
IInternalAsyncTransaction transaction);

Task CommitTransactionAsync(IBookmarksTracker bookmarksTracker);
Task RollbackTransactionAsync();
}
Expand Down
38 changes: 33 additions & 5 deletions Neo4j.Driver/Neo4j.Driver/Internal/Connector/SocketConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ internal SocketConnection(
ServerInfo server,
IResponsePipeline responsePipeline = null,
IAuthTokenManager authTokenManager = null,
IBoltProtocolFactory protocolFactory = null,
IBoltProtocolFactory protocolFactory = null,
DriverContext context = null)
{
_client = socketClient ?? throw new ArgumentNullException(nameof(socketClient));
Expand All @@ -88,6 +88,7 @@ internal SocketConnection(
}

internal IReadOnlyList<IRequestMessage> Messages => _messages.ToList();
public DriverContext Context { get; }

public AccessMode? Mode { get; private set; }

Expand Down Expand Up @@ -138,7 +139,11 @@ public async Task InitAsync(

try
{
await BoltProtocol.AuthenticateAsync(this, Context.Config.UserAgent, authToken, Context.Config.NotificationsConfig)
await BoltProtocol.AuthenticateAsync(
this,
Context.Config.UserAgent,
authToken,
Context.Config.NotificationsConfig)
.ConfigureAwait(false);
}
catch (Exception ex)
Expand Down Expand Up @@ -242,6 +247,27 @@ public async Task ReceiveOneAsync()
}
}

public async ValueTask EnqueueAsync(
IRequestMessage message1,
IResponseHandler handler1,
IRequestMessage message2,
IResponseHandler handler2)
{
await _sendLock.WaitAsync().ConfigureAwait(false);

try
{
_messages.Enqueue(message1);
_messages.Enqueue(message2);
_responsePipeline.Enqueue(handler1);
_responsePipeline.Enqueue(handler2);
}
finally
{
_sendLock.Release();
}
}

public Task ResetAsync()
{
return BoltProtocol.ResetAsync(this);
Expand All @@ -251,7 +277,6 @@ public Task ResetAsync()
public IServerInfo Server => _serverInfo;

public bool UtcEncodedDateTime { get; private set; }
public DriverContext Context { get; }
public IAuthToken AuthToken { get; private set; }
public bool TelemetryEnabled { get; set; }

Expand All @@ -263,7 +288,7 @@ public void UpdateId(string newConnId)
newConnId);

_id = newConnId;

if (_logger is PrefixLogger logger)
{
logger.Prefix = FormatPrefix(_id);
Expand Down Expand Up @@ -417,7 +442,10 @@ public Task BeginTransactionAsync(BeginTransactionParams beginParams)
return BoltProtocol.BeginTransactionAsync(this, beginParams);
}

public Task<IResultCursor> RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize,
public Task<IResultCursor> RunInExplicitTransactionAsync(
Query query,
bool reactive,
long fetchSize,
IInternalAsyncTransaction transaction)
{
return BoltProtocol.RunInExplicitTransactionAsync(this, query, reactive, fetchSize, transaction);
Expand Down
Loading