From 8a8ab7fc02d0678fbfbc516e422ac17f98cf27fb Mon Sep 17 00:00:00 2001 From: Jason Ginchereau Date: Fri, 12 Jul 2024 16:12:03 -1000 Subject: [PATCH] Fix race condition with protocol extensions --- src/cs/Ssh/IO/SshProtocol.cs | 2 +- src/cs/Ssh/SshSession.cs | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/cs/Ssh/IO/SshProtocol.cs b/src/cs/Ssh/IO/SshProtocol.cs index 457df73..2e0bd2a 100644 --- a/src/cs/Ssh/IO/SshProtocol.cs +++ b/src/cs/Ssh/IO/SshProtocol.cs @@ -91,7 +91,7 @@ public SshProtocol( internal bool TraceChannelData { get; set; } - internal Dictionary? Extensions { get; set; } + internal IReadOnlyDictionary? Extensions { get; set; } internal KeyExchangeService? KeyExchangeService { get; set; } diff --git a/src/cs/Ssh/SshSession.cs b/src/cs/Ssh/SshSession.cs index d18f341..2c373fd 100644 --- a/src/cs/Ssh/SshSession.cs +++ b/src/cs/Ssh/SshSession.cs @@ -1475,23 +1475,27 @@ internal async Task HandleMessageAsync( return; } - Protocol.Extensions = new Dictionary(); - var proposedExtensions = message.ExtensionInfo; if (proposedExtensions == null) { + Protocol.Extensions = new Dictionary(); return; } + // Fill the extensions dictionary with only the extensions that are enabled. + // Assign to the Protocol.Extensions property only after it is filled to avoid a race. + Dictionary extensions = new Dictionary(); foreach (var extensionName in Config.ProtocolExtensions) { if (proposedExtensions.TryGetValue(extensionName, out var value)) { - Protocol.Extensions.Add(extensionName, value); + extensions.Add(extensionName, value); } } - if (Protocol.Extensions.ContainsKey(SshProtocolExtensionNames.SessionReconnect)) + Protocol.Extensions = extensions; + + if (extensions.ContainsKey(SshProtocolExtensionNames.SessionReconnect)) { // Reconnect is not enabled until each side sends a special request message. await EnableReconnectAsync(cancellation).ConfigureAwait(false);