From f6376a267e4814508fc9494779f67143093c80ef Mon Sep 17 00:00:00 2001 From: Gerke Geurts Date: Wed, 5 Sep 2018 00:28:43 +0200 Subject: [PATCH] Solve race conditions in Connection response processing --- lib/PuppeteerSharp/Browser.cs | 1 - lib/PuppeteerSharp/CDPSession.cs | 65 +++++----- lib/PuppeteerSharp/Connection.cs | 137 +++++++++----------- lib/PuppeteerSharp/TargetClosedException.cs | 13 +- 4 files changed, 103 insertions(+), 113 deletions(-) diff --git a/lib/PuppeteerSharp/Browser.cs b/lib/PuppeteerSharp/Browser.cs index 9420b9ebb..c22b065ea 100644 --- a/lib/PuppeteerSharp/Browser.cs +++ b/lib/PuppeteerSharp/Browser.cs @@ -239,7 +239,6 @@ public async Task GetUserAgentAsync() private async Task CloseCoreAsync() { - Connection.StopReading(); try { try diff --git a/lib/PuppeteerSharp/CDPSession.cs b/lib/PuppeteerSharp/CDPSession.cs index 684c5f890..9c028d153 100755 --- a/lib/PuppeteerSharp/CDPSession.cs +++ b/lib/PuppeteerSharp/CDPSession.cs @@ -1,10 +1,10 @@ using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; using Newtonsoft.Json; -using System.Collections.Generic; using Newtonsoft.Json.Linq; -using Microsoft.Extensions.Logging; -using PuppeteerSharp.Helpers; namespace PuppeteerSharp { @@ -43,16 +43,14 @@ internal CDPSession(IConnection connection, TargetType targetType, string sessio TargetType = targetType; SessionId = sessionId; - _callbacks = new Dictionary(); _logger = Connection.LoggerFactory.CreateLogger(); - _sessions = new Dictionary(); } #region Private Members private int _lastId; - private readonly Dictionary _callbacks; + private readonly ConcurrentDictionary _callbacks = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); private readonly ILogger _logger; - private readonly Dictionary _sessions; #endregion #region Properties @@ -79,11 +77,12 @@ internal CDPSession(IConnection connection, TargetType targetType, string sessio /// Occurs when tracing is completed. /// public event EventHandler TracingComplete; + /// /// Gets or sets a value indicating whether this is closed. /// /// true if is closed; otherwise, false. - public bool IsClosed { get; internal set; } + public bool IsClosed => Connection == null; /// /// Gets the logger factory. @@ -111,6 +110,7 @@ internal async Task SendAsync(string method, bool rawContent, dynamic a { throw new Exception($"Protocol error ({method}): Session closed. Most likely the {TargetType} has been closed."); } + var id = ++_lastId; var message = JsonConvert.SerializeObject(new Dictionary { @@ -126,8 +126,7 @@ internal async Task SendAsync(string method, bool rawContent, dynamic a Method = method, RawContent = rawContent }; - - _callbacks[id] = callback; + _callbacks.TryAdd(id, callback); try { @@ -139,10 +138,9 @@ internal async Task SendAsync(string method, bool rawContent, dynamic a } catch (Exception ex) { - if (_callbacks.ContainsKey(id)) + if (_callbacks.TryRemove(id, out _)) { - _callbacks.Remove(id); - callback.TaskWrapper.SetException(new MessageException(ex.Message, ex)); + callback.TaskWrapper.TrySetException(new MessageException(ex.Message, ex)); } } @@ -163,18 +161,16 @@ public Task DetachAsync() internal void OnMessage(string message) { dynamic obj = JsonConvert.DeserializeObject(message); - var objAsJObject = obj as JObject; + var objAsJObject = (JObject)obj; _logger.LogTrace("◀ Receive {Message}", message); - if (objAsJObject["id"] != null && _callbacks.ContainsKey((int)obj.id)) + var id = (int?)objAsJObject["id"]; + if (id.HasValue && _callbacks.TryRemove(id.Value, out var callback)) { - var callback = _callbacks[(int)obj.id]; - _callbacks.Remove((int)obj.id); - if (objAsJObject["error"] != null) { - callback.TaskWrapper.SetException(new MessageException( + callback.TaskWrapper.TrySetException(new MessageException( $"Protocol error ({ callback.Method }): {obj.error.message} {obj.error.data}" )); } @@ -182,11 +178,11 @@ internal void OnMessage(string message) { if (callback.RawContent) { - callback.TaskWrapper.SetResult(JsonConvert.SerializeObject(obj.result)); + callback.TaskWrapper.TrySetResult(JsonConvert.SerializeObject(obj.result)); } else { - callback.TaskWrapper.SetResult(obj.result); + callback.TaskWrapper.TrySetResult(obj.result); } } } @@ -201,19 +197,16 @@ internal void OnMessage(string message) } else if (obj.method == "Target.receivedMessageFromTarget") { - var session = _sessions.GetValueOrDefault(objAsJObject["params"]["sessionId"].ToString()); - if (session != null) + if (_sessions.TryGetValue(objAsJObject["params"]["sessionId"].ToString(), out var session)) { session.OnMessage(objAsJObject["params"]["message"].ToString()); } } else if (obj.method == "Target.detachedFromTarget") { - var session = _sessions.GetValueOrDefault(objAsJObject["params"]["sessionId"].ToString()); - if (!(session?.IsClosed ?? true)) + if (_sessions.TryRemove(objAsJObject["params"]["sessionId"].ToString(), out var session) && !session.IsClosed) { session.OnClosed(); - _sessions.Remove(objAsJObject["params"]["sessionId"].ToString()); } } @@ -227,24 +220,30 @@ internal void OnMessage(string message) internal void OnClosed() { - IsClosed = true; - foreach (var callback in _callbacks.Values) + if (Connection == null) { - callback.TaskWrapper.SetException(new TargetClosedException( - $"Protocol error({callback.Method}): Target closed." - )); + return; } - _callbacks.Clear(); Connection = null; + + foreach (var entry in _callbacks) + { + if (_callbacks.TryRemove(entry.Key, out _)) + { + entry.Value.TaskWrapper.TrySetException( + new TargetClosedException($"Protocol error({entry.Value.Method}): Target closed.")); + } + } } internal CDPSession CreateSession(TargetType targetType, string sessionId) { var session = new CDPSession(this, targetType, sessionId); - _sessions[sessionId] = session; + _sessions.TryAdd(sessionId, session); return session; } #endregion + #region IConnection ILoggerFactory IConnection.LoggerFactory => LoggerFactory; bool IConnection.IsClosed => IsClosed; diff --git a/lib/PuppeteerSharp/Connection.cs b/lib/PuppeteerSharp/Connection.cs index c807e1f13..5622c5726 100755 --- a/lib/PuppeteerSharp/Connection.cs +++ b/lib/PuppeteerSharp/Connection.cs @@ -1,6 +1,6 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; using System.Net.WebSockets; using System.Text; using System.Threading; @@ -27,22 +27,16 @@ internal Connection(string url, int delay, WebSocket ws, ILoggerFactory loggerFa WebSocket = ws; _logger = LoggerFactory.CreateLogger(); - _socketQueue = new TaskQueue(); - _responses = new Dictionary(); - _sessions = new Dictionary(); - _websocketReaderCancellationSource = new CancellationTokenSource(); Task.Factory.StartNew(GetResponseAsync); } #region Private Members private int _lastId; - private Dictionary _responses; - private Dictionary _sessions; - private TaskQueue _socketQueue; - private const string CloseMessage = "Browser.close"; - private bool _stopReading; - private CancellationTokenSource _websocketReaderCancellationSource; + private readonly ConcurrentDictionary _responses = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _sessions = new ConcurrentDictionary(); + private readonly TaskQueue _socketQueue = new TaskQueue(); + private readonly CancellationTokenSource _websocketReaderCancellationSource = new CancellationTokenSource(); #endregion #region Properties @@ -50,17 +44,17 @@ internal Connection(string url, int delay, WebSocket ws, ILoggerFactory loggerFa /// Gets the WebSocket URL. /// /// The URL. - public string Url { get; private set; } + public string Url { get; } /// /// Gets the sleep time when a message is received. /// /// The delay. - public int Delay { get; private set; } + public int Delay { get; } /// /// Gets the WebSocket. /// /// The web socket. - public WebSocket WebSocket { get; private set; } + public WebSocket WebSocket { get; } /// /// Occurs when the connection is closed. /// @@ -102,23 +96,28 @@ internal async Task SendAsync(string method, dynamic args = null) _logger.LogTrace("Send ► {Id} Method {Method} Params {@Params}", id, method, (object)args); - var taskWrapper = new TaskCompletionSource(); - _responses[id] = new MessageTask + var callback = new MessageTask { - TaskWrapper = taskWrapper, + TaskWrapper = new TaskCompletionSource(), Method = method }; + _responses.TryAdd(id, callback); - var encoded = Encoding.UTF8.GetBytes(message); - var buffer = new ArraySegment(encoded, 0, encoded.Length); - await _socketQueue.Enqueue(() => WebSocket.SendAsync(buffer, WebSocketMessageType.Text, true, default)).ConfigureAwait(false); - - if (method == CloseMessage) + try + { + var encoded = Encoding.UTF8.GetBytes(message); + var buffer = new ArraySegment(encoded, 0, encoded.Length); + await _socketQueue.Enqueue(() => WebSocket.SendAsync(buffer, WebSocketMessageType.Text, true, default)).ConfigureAwait(false); + } + catch (Exception ex) { - StopReading(); + if (_responses.TryRemove(id, out _)) + { + callback.TaskWrapper.TrySetException(ex); + } } - return await taskWrapper.Task.ConfigureAwait(false); + return await callback.TaskWrapper.Task.ConfigureAwait(false); } internal async Task SendAsync(string method, dynamic args = null) @@ -134,41 +133,40 @@ internal async Task CreateSessionAsync(TargetInfo targetInfo) targetId = targetInfo.TargetId }).ConfigureAwait(false)).sessionId; var session = new CDPSession(this, targetInfo.Type, sessionId); - _sessions.Add(sessionId, session); + _sessions.TryAdd(sessionId, session); return session; } #endregion - private void OnClose() + private void OnClose(Exception ex) { if (IsClosed) { return; } - IsClosed = true; _websocketReaderCancellationSource.Cancel(); Closed?.Invoke(this, new EventArgs()); - foreach (var session in _sessions.Values) + foreach (var entry in _sessions) { - session.OnClosed(); + if (_sessions.TryRemove(entry.Key, out _)) + { + entry.Value.OnClosed(); + } } - foreach (var response in _responses.Values.Where(r => !r.TaskWrapper.Task.IsCompleted)) + foreach (var entry in _responses) { - response.TaskWrapper.SetException(new TargetClosedException( - $"Protocol error({response.Method}): Target closed." - )); + if (_responses.TryRemove(entry.Key, out _)) + { + entry.Value.TaskWrapper.TrySetException( + new TargetClosedException($"Protocol error({entry.Value.Method}): Target closed.", ex)); + } } - - _responses.Clear(); - _sessions.Clear(); } - internal void StopReading() => _stopReading = true; - #region Private Methods /// @@ -180,105 +178,86 @@ private async Task GetResponseAsync() var buffer = new byte[2048]; //If it's not in the list we wait for it - while (true) + while (!IsClosed) { - if (IsClosed) - { - OnClose(); - return null; - } - var endOfMessage = false; - var response = string.Empty; + var response = new StringBuilder(); while (!endOfMessage) { - WebSocketReceiveResult result = null; + WebSocketReceiveResult result; try { result = await WebSocket.ReceiveAsync( new ArraySegment(buffer), _websocketReaderCancellationSource.Token).ConfigureAwait(false); } - catch (Exception) when (_stopReading) - { - return null; - } catch (OperationCanceledException) { return null; } - catch (Exception) + catch (Exception ex) { - if (!IsClosed) - { - OnClose(); - return null; - } + OnClose(ex); + return null; } endOfMessage = result.EndOfMessage; if (result.MessageType == WebSocketMessageType.Text) { - response += Encoding.UTF8.GetString(buffer, 0, result.Count); + response.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); } else if (result.MessageType == WebSocketMessageType.Close) { - OnClose(); + OnClose(null); return null; } } - if (!string.IsNullOrEmpty(response)) + if (response.Length > 0) { if (Delay > 0) { await Task.Delay(Delay).ConfigureAwait(false); } - ProcessResponse(response); + ProcessResponse(response.ToString()); } } + + return null; } private void ProcessResponse(string response) { dynamic obj = JsonConvert.DeserializeObject(response); - var objAsJObject = obj as JObject; + var objAsJObject = (JObject)obj; _logger.LogTrace("◀ Receive {Message}", response); - if (objAsJObject["id"] != null) + var id = (int?)objAsJObject["id"]; + if (id.HasValue) { - var id = (int)objAsJObject["id"]; - - //If we get the object we are waiting for we return if - //if not we add this to the list, sooner or later some one will come for it - if (!_responses.ContainsKey(id)) + if (_responses.TryRemove(id.Value, out var callback)) { - _responses[id] = new MessageTask { TaskWrapper = new TaskCompletionSource() }; + callback.TaskWrapper.TrySetResult(obj.result); } - - _responses[id].TaskWrapper.SetResult(obj.result); } else { if (obj.method == "Target.receivedMessageFromTarget") { - var session = _sessions.GetValueOrDefault(objAsJObject["params"]["sessionId"].ToString()); - if (session != null) + if (_sessions.TryGetValue(objAsJObject["params"]["sessionId"].ToString(), out var session)) { session.OnMessage(objAsJObject["params"]["message"].ToString()); } } else if (obj.method == "Target.detachedFromTarget") { - var session = _sessions.GetValueOrDefault(objAsJObject["params"]["sessionId"].ToString()); - if (!(session?.IsClosed ?? true)) + if (_sessions.TryRemove(objAsJObject["params"]["sessionId"].ToString(), out var session) && !session.IsClosed) { session.OnClosed(); - _sessions.Remove(objAsJObject["params"]["sessionId"].ToString()); } } else @@ -286,12 +265,14 @@ private void ProcessResponse(string response) MessageReceived?.Invoke(this, new MessageEventArgs { MessageID = obj.method, - MessageData = objAsJObject["params"] as dynamic + MessageData = objAsJObject["params"] }); } } } + #endregion + #region Static Methods /// @@ -322,7 +303,7 @@ internal static async Task Create(string url, IConnectionOptions con /// was occupying. public void Dispose() { - OnClose(); + OnClose(new ObjectDisposedException($"Connection({Url})")); WebSocket.Dispose(); } #endregion diff --git a/lib/PuppeteerSharp/TargetClosedException.cs b/lib/PuppeteerSharp/TargetClosedException.cs index 10811b0d8..23a319c58 100644 --- a/lib/PuppeteerSharp/TargetClosedException.cs +++ b/lib/PuppeteerSharp/TargetClosedException.cs @@ -1,4 +1,6 @@ -namespace PuppeteerSharp +using System; + +namespace PuppeteerSharp { /// /// Exception thrown by the when it detects that the target was closed. @@ -12,5 +14,14 @@ public class TargetClosedException : PuppeteerException public TargetClosedException(string message) : base(message) { } + + /// + /// Initializes a new instance of the class. + /// + /// Message. + /// Inner exception. + public TargetClosedException(string message, Exception innerException) : base(message, innerException) + { + } } } \ No newline at end of file