From ac46e58e3ae19edc8adbb562621918cb31859ef7 Mon Sep 17 00:00:00 2001 From: Grant Lodge <6323995+thelonelyvulpes@users.noreply.github.com> Date: Tue, 5 Mar 2024 20:57:45 +0000 Subject: [PATCH 1/2] Fix a race condition in the way message pairs are queued --- .../Internal/Protocol/BoltProtocolTests.cs | 38 +++++++++---------- .../Internal/Protocol/BoltProtocolV3Tests.cs | 36 ++++++++---------- .../Internal/Connector/DelegatedConnection.cs | 23 ++++++++++- .../Internal/Connector/IConnection.cs | 15 +++++++- .../Internal/Connector/SocketConnection.cs | 38 ++++++++++++++++--- .../Internal/Protocol/BoltProtocol.cs | 37 ++++++++++-------- .../Internal/Protocol/BoltProtocolV3.cs | 9 +++-- 7 files changed, 126 insertions(+), 70 deletions(-) diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolTests.cs index 0bf0f2ff8..c4a426fe9 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolTests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolTests.cs @@ -448,8 +448,12 @@ public async Task ShouldUseQueryToFetchRoutingTableForBoltVersionLessThan43(int mockConn.Verify(x => x.ConfigureMode(AccessMode.Read), Times.Once); mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Exactly(2)); + x => x.EnqueueAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Exactly(1)); mockConn.Verify(x => x.SendAsync(), Times.Once); } @@ -761,16 +765,12 @@ public async Task ShouldSendPullMessageWhenNotReactive() Times.Once); mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Exactly(2)); - - mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Once); - - mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Once); + x => x.EnqueueAsync( + It.IsNotNull(), + It.IsNotNull(), + It.IsNotNull(), + It.IsNotNull()), + Times.Exactly(1)); mockConn.Verify(x => x.SendAsync(), Times.Once); mockConn.Verify(x => x.SyncAsync(), Times.Never); @@ -1189,15 +1189,11 @@ await protocol.RunInExplicitTransactionAsync( Times.Once); mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Exactly(2)); - - mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Once); - - mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), + x => x.EnqueueAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny()), Times.Once); mockConn.Verify(x => x.SendAsync(), Times.Once); diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolV3Tests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolV3Tests.cs index bd12a6e45..1c9dc30f2 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolV3Tests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Internal/Protocol/BoltProtocolV3Tests.cs @@ -229,8 +229,12 @@ public async Task ShouldSendRunWithMetadataMessageToGetRoutingTable() mockConn.Verify(x => x.ConfigureMode(AccessMode.Read), Times.Once); mockConn.Verify( - x => x.EnqueueAsync(It.IsAny(), It.IsAny()), - Times.Exactly(2)); + x => x.EnqueueAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Once); mockConn.Verify(x => x.SendAsync(), Times.Once); } @@ -391,15 +395,11 @@ public async Task ShouldSendMessages() Times.Once); mockConn.Verify( - x => x.EnqueueAsync(It.IsNotNull(), It.IsNotNull()), - Times.Exactly(2)); - - mockConn.Verify( - x => x.EnqueueAsync(It.IsNotNull(), It.IsNotNull()), - Times.Once); - - mockConn.Verify( - x => x.EnqueueAsync(PullAllMessage.Instance, It.IsNotNull()), + x => x.EnqueueAsync( + It.IsNotNull(), + It.IsNotNull(), + PullAllMessage.Instance, + It.IsNotNull()), Times.Once); mockConn.Verify(x => x.SendAsync(), Times.Once); @@ -672,15 +672,11 @@ public async Task ShouldSendMessages() Times.Once); mockConn.Verify( - x => x.EnqueueAsync(It.IsNotNull(), It.IsNotNull()), - Times.Exactly(2)); - - mockConn.Verify( - x => x.EnqueueAsync(It.IsNotNull(), It.IsNotNull()), - Times.Once); - - mockConn.Verify( - x => x.EnqueueAsync(PullAllMessage.Instance, It.IsNotNull()), + x => x.EnqueueAsync( + It.IsNotNull(), + It.IsNotNull(), + PullAllMessage.Instance, + It.IsNotNull()), Times.Once); mockConn.Verify(x => x.SendAsync(), Times.Once); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs index 93d85c86b..a33d90133 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/DelegatedConnection.cs @@ -134,6 +134,22 @@ public async Task EnqueueAsync(IRequestMessage message, IResponseHandler handler } } + public async ValueTask EnqueueAsync( + IRequestMessage message1, + IResponseHandler handler1, + IRequestMessage message2, + IResponseHandler handler2) + { + try + { + await Delegate.EnqueueAsync(message1, handler1, message2, handler2).ConfigureAwait(false); + } + catch (Exception e) + { + await OnErrorAsync(e).ConfigureAwait(false); + } + } + public virtual bool IsOpen => Delegate.IsOpen; public IServerInfo Server => Delegate.Server; @@ -190,7 +206,7 @@ public ValueTask ValidateCredsAsync() return Delegate.ValidateCredsAsync(); } - /// + /// public bool TelemetryEnabled { get => Delegate.TelemetryEnabled; @@ -232,7 +248,10 @@ public Task BeginTransactionAsync(BeginTransactionParams beginTransactionParams) return BoltProtocol.BeginTransactionAsync(this, beginTransactionParams); } - public Task RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize, + public Task RunInExplicitTransactionAsync( + Query query, + bool reactive, + long fetchSize, IInternalAsyncTransaction transaction) { return BoltProtocol.RunInExplicitTransactionAsync(this, query, reactive, fetchSize, transaction); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/IConnection.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/IConnection.cs index 8b563b053..256082f77 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/IConnection.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/IConnection.cs @@ -60,6 +60,7 @@ internal interface IConnection : IConnectionDetails, IConnectionRunner IAuthTokenManager AuthTokenManager { get; } public SessionConfig SessionConfig { get; set; } + bool TelemetryEnabled { get; set; } void ConfigureMode(AccessMode? mode); void Configure(string database, AccessMode? mode); @@ -83,6 +84,13 @@ Task InitAsync( Task EnqueueAsync(IRequestMessage message, IResponseHandler handler); + ValueTask EnqueueAsync( + IRequestMessage message1, + IResponseHandler handler1, + IRequestMessage message2, + IResponseHandler handler2 + ); + // Enqueue a reset message Task ResetAsync(); @@ -100,7 +108,6 @@ Task InitAsync( void SetUseUtcEncodedDateTime(); ValueTask ValidateCredsAsync(); - bool TelemetryEnabled { get; set; } } internal interface IConnectionRunner @@ -123,8 +130,12 @@ Task RunInAutoCommitTransactionAsync( Task BeginTransactionAsync(BeginTransactionParams beginParams); - Task RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize, + Task RunInExplicitTransactionAsync( + Query query, + bool reactive, + long fetchSize, IInternalAsyncTransaction transaction); + Task CommitTransactionAsync(IBookmarksTracker bookmarksTracker); Task RollbackTransactionAsync(); } diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/SocketConnection.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/SocketConnection.cs index 8bb271174..e775bdb41 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Connector/SocketConnection.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Connector/SocketConnection.cs @@ -73,7 +73,7 @@ internal SocketConnection( ServerInfo server, IResponsePipeline responsePipeline = null, IAuthTokenManager authTokenManager = null, - IBoltProtocolFactory protocolFactory = null, + IBoltProtocolFactory protocolFactory = null, DriverContext context = null) { _client = socketClient ?? throw new ArgumentNullException(nameof(socketClient)); @@ -88,6 +88,7 @@ internal SocketConnection( } internal IReadOnlyList Messages => _messages.ToList(); + public DriverContext Context { get; } public AccessMode? Mode { get; private set; } @@ -138,7 +139,11 @@ public async Task InitAsync( try { - await BoltProtocol.AuthenticateAsync(this, Context.Config.UserAgent, authToken, Context.Config.NotificationsConfig) + await BoltProtocol.AuthenticateAsync( + this, + Context.Config.UserAgent, + authToken, + Context.Config.NotificationsConfig) .ConfigureAwait(false); } catch (Exception ex) @@ -242,6 +247,27 @@ public async Task ReceiveOneAsync() } } + public async ValueTask EnqueueAsync( + IRequestMessage message1, + IResponseHandler handler1, + IRequestMessage message2, + IResponseHandler handler2) + { + await _sendLock.WaitAsync().ConfigureAwait(false); + + try + { + _messages.Enqueue(message1); + _messages.Enqueue(message2); + _responsePipeline.Enqueue(handler1); + _responsePipeline.Enqueue(handler2); + } + finally + { + _sendLock.Release(); + } + } + public Task ResetAsync() { return BoltProtocol.ResetAsync(this); @@ -251,7 +277,6 @@ public Task ResetAsync() public IServerInfo Server => _serverInfo; public bool UtcEncodedDateTime { get; private set; } - public DriverContext Context { get; } public IAuthToken AuthToken { get; private set; } public bool TelemetryEnabled { get; set; } @@ -263,7 +288,7 @@ public void UpdateId(string newConnId) newConnId); _id = newConnId; - + if (_logger is PrefixLogger logger) { logger.Prefix = FormatPrefix(_id); @@ -417,7 +442,10 @@ public Task BeginTransactionAsync(BeginTransactionParams beginParams) return BoltProtocol.BeginTransactionAsync(this, beginParams); } - public Task RunInExplicitTransactionAsync(Query query, bool reactive, long fetchSize, + public Task RunInExplicitTransactionAsync( + Query query, + bool reactive, + long fetchSize, IInternalAsyncTransaction transaction) { return BoltProtocol.RunInExplicitTransactionAsync(this, query, reactive, fetchSize, transaction); diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocol.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocol.cs index bd3ed9abd..019272206 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocol.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocol.cs @@ -85,7 +85,7 @@ public Task> GetRoutingTableAsync( ? GetRoutingTableWithRouteMessageAsync(connection, database, sessionConfig?.ImpersonatedUser, bookmarks) : GetRoutingTableWithQueryAsync(connection, database, bookmarks); } - + public async Task RunInAutoCommitTransactionAsync( IConnection connection, AutoCommitParams autoCommitParams, @@ -115,7 +115,6 @@ public async Task RunInAutoCommitTransactionAsync( var runHandler = _protocolHandlerFactory.NewRunResponseHandler(streamBuilder, summaryBuilder); await AddTelemetryAsync(connection, autoCommitParams.TransactionInfo).ConfigureAwait(false); - await connection.EnqueueAsync(runMessage, runHandler).ConfigureAwait(false); if (!autoCommitParams.Reactive) { @@ -125,7 +124,11 @@ public async Task RunInAutoCommitTransactionAsync( streamBuilder, summaryBuilder); - await connection.EnqueueAsync(pullMessage, pullHandler).ConfigureAwait(false); + await connection.EnqueueAsync(runMessage, runHandler, pullMessage, pullHandler).ConfigureAwait(false); + } + else + { + await connection.EnqueueAsync(runMessage, runHandler).ConfigureAwait(false); } await connection.SendAsync().ConfigureAwait(false); @@ -164,19 +167,31 @@ public async Task RunInExplicitTransactionAsync( var runMessage = _protocolMessageFactory.NewRunWithMetadataMessage(connection, query, null); var runHandler = _protocolHandlerFactory.NewRunResponseHandler(streamBuilder, summaryBuilder); - await connection.EnqueueAsync(runMessage, runHandler).ConfigureAwait(false); - if (!reactive) { var pullMessage = _protocolMessageFactory.NewPullMessage(fetchSize); var pullHandler = _protocolHandlerFactory.NewPullResponseHandler(null, streamBuilder, summaryBuilder); - await connection.EnqueueAsync(pullMessage, pullHandler).ConfigureAwait(false); + await connection.EnqueueAsync(runMessage, runHandler, pullMessage, pullHandler).ConfigureAwait(false); + } + else + { + await connection.EnqueueAsync(runMessage, runHandler).ConfigureAwait(false); } await connection.SendAsync().ConfigureAwait(false); return streamBuilder.CreateCursor(); } + public Task CommitTransactionAsync(IConnection connection, IBookmarksTracker bookmarksTracker) + { + return _boltProtocolV3.CommitTransactionAsync(connection, bookmarksTracker); + } + + public Task RollbackTransactionAsync(IConnection connection) + { + return _boltProtocolV3.RollbackTransactionAsync(connection); + } + private Task AddTelemetryAsync(IConnection connection, TransactionInfo info) { if (!(info?.TelemetryEnabled ?? false) || !connection.TelemetryEnabled) @@ -189,16 +204,6 @@ private Task AddTelemetryAsync(IConnection connection, TransactionInfo info) return connection.EnqueueAsync(message, handler); } - public Task CommitTransactionAsync(IConnection connection, IBookmarksTracker bookmarksTracker) - { - return _boltProtocolV3.CommitTransactionAsync(connection, bookmarksTracker); - } - - public Task RollbackTransactionAsync(IConnection connection) - { - return _boltProtocolV3.RollbackTransactionAsync(connection); - } - private async Task AuthenticateWithLogonAsync( IConnection connection, string userAgent, diff --git a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolV3.cs b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolV3.cs index 868881b18..d9d416349 100644 --- a/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolV3.cs +++ b/Neo4j.Driver/Neo4j.Driver/Internal/Protocol/BoltProtocolV3.cs @@ -152,8 +152,8 @@ public async Task RunInAutoCommitTransactionAsync( autoCommitParams, notificationsConfig); - await connection.EnqueueAsync(autoCommitMessage, runHandler).ConfigureAwait(false); - await connection.EnqueueAsync(PullAllMessage.Instance, pullAllHandler).ConfigureAwait(false); + await connection.EnqueueAsync(autoCommitMessage, runHandler, PullAllMessage.Instance, pullAllHandler) + .ConfigureAwait(false); await connection.SendAsync().ConfigureAwait(false); return streamBuilder.CreateCursor(); @@ -208,8 +208,9 @@ public async Task RunInExplicitTransactionAsync( var message = _protocolMessageFactory.NewRunWithMetadataMessage(connection, query, null); - await connection.EnqueueAsync(message, runHandler).ConfigureAwait(false); - await connection.EnqueueAsync(PullAllMessage.Instance, pullAllHandler).ConfigureAwait(false); + await connection.EnqueueAsync(message, runHandler, PullAllMessage.Instance, pullAllHandler) + .ConfigureAwait(false); + await connection.SendAsync().ConfigureAwait(false); return streamBuilder.CreateCursor(); From aac9df9f8a5cf79270dff77a39078530c3d4f303 Mon Sep 17 00:00:00 2001 From: Grant Lodge <6323995+thelonelyvulpes@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:12:20 +0000 Subject: [PATCH 2/2] add a test --- .../Connector/SocketConnectionTests.cs | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/Neo4j.Driver/Neo4j.Driver.Tests/Connector/SocketConnectionTests.cs b/Neo4j.Driver/Neo4j.Driver.Tests/Connector/SocketConnectionTests.cs index d1f16ab99..4766c3cf8 100644 --- a/Neo4j.Driver/Neo4j.Driver.Tests/Connector/SocketConnectionTests.cs +++ b/Neo4j.Driver/Neo4j.Driver.Tests/Connector/SocketConnectionTests.cs @@ -216,6 +216,24 @@ await con.EnqueueAsync( await con.EnqueueAsync(PullAllMessage.Instance, NoOpResponseHandler.Instance); pipeline.Verify(h => h.Enqueue(NoOpResponseHandler.Instance), Times.Exactly(2)); } + + [Fact] + public async Task ShouldEnqueueBoth() + { + var pipeline = new Mock(); + var con = NewSocketConnection(pipeline: pipeline.Object); + + var m1 = new Mock(); + var h1 = new Mock(); + var m2 = new Mock(); + var h2 = new Mock(); + + await con.EnqueueAsync(m1.Object, h1.Object, m2.Object, h2.Object); + + con.Messages[0].Should().Be(m1.Object); + con.Messages[1].Should().Be(m2.Object); + pipeline.Verify(x => x.Enqueue(It.IsAny()), Times.Exactly(2)); + } } public class ResetMethod @@ -347,4 +365,4 @@ public static TheoryData GenerateObjectDisposedExceptions() }; } } -} \ No newline at end of file +}