From 9428c150fa0ee37a2b62393c742af9baa2a13d67 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Wed, 25 Dec 2024 21:31:01 +0100 Subject: [PATCH] PrivateKeyCredential: support avoiding prompts for keys that server won't accept. --- README.md | 7 +- src/Tmds.Ssh/AlgorithmNames.cs | 38 ++++++++- src/Tmds.Ssh/ECDsaPrivateKey.cs | 6 +- src/Tmds.Ssh/Ed25519PrivateKey.cs | 6 +- src/Tmds.Ssh/PrivateKey.cs | 4 +- src/Tmds.Ssh/PrivateKeyCredential.cs | 78 +++++++++++++++---- src/Tmds.Ssh/PrivateKeyParser.OpenSsh.cs | 32 ++++---- src/Tmds.Ssh/PrivateKeyParser.cs | 8 +- src/Tmds.Ssh/RsaPrivateKey.cs | 4 +- src/Tmds.Ssh/SshAgentPrivateKey.cs | 27 +++---- src/Tmds.Ssh/SshClientLogger.cs | 2 +- .../UserAuthentication.PublicKeyAuth.cs | 26 +++---- .../PrivateKeyCredentialTests.cs | 38 +++++++-- 13 files changed, 193 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 9eb5498..33e5d40 100644 --- a/README.md +++ b/README.md @@ -489,10 +489,10 @@ abstract class Credential class PrivateKeyCredential : Credential { PrivateKeyCredential(string path, string? password = null, string? identifier ??= path); - PrivateKeyCredential(string path, Func passwordPrompt, string? identifier ??= path); + PrivateKeyCredential(string path, Func passwordPrompt, bool queryKey = true, string? identifier ??= path); PrivateKeyCredential(char[] rawKey, string? password = null, string identifier = "[raw key]"); - PrivateKeyCredential(char[] rawKey, Func passwordPrompt, string identifier = "[raw key]"); + PrivateKeyCredential(char[] rawKey, Func passwordPrompt, bool queryKey = true, string identifier = "[raw key]"); // Enable derived classes to use private keys from other sources. protected PrivateKeyCredential(Func> loadKey, string identifier); @@ -500,7 +500,8 @@ class PrivateKeyCredential : Credential { Key(RSA rsa); Key(ECDsa ecdsa); - Key(ReadOnlyMemory rawKey, Func? passwordPrompt = null); + Key(ReadOnlyMemory rawKey, string? password = null); + Key(ReadOnlyMemory rawKey, Func passwordPrompt, bool queryKey = true); } } class PasswordCredential : Credential diff --git a/src/Tmds.Ssh/AlgorithmNames.cs b/src/Tmds.Ssh/AlgorithmNames.cs index 65cedae..0ed94d4 100644 --- a/src/Tmds.Ssh/AlgorithmNames.cs +++ b/src/Tmds.Ssh/AlgorithmNames.cs @@ -18,7 +18,7 @@ static class AlgorithmNames // TODO: rename to KnownNames private static readonly byte[] EcdhSha2Nistp521Bytes = "ecdh-sha2-nistp521"u8.ToArray(); public static Name EcdhSha2Nistp521 => new Name(EcdhSha2Nistp521Bytes); - // Host key algorithms. + // Host key algorithms: key types and signature algorithms. private static readonly byte[] SshRsaBytes = "ssh-rsa"u8.ToArray(); public static Name SshRsa => new Name(SshRsaBytes); private static readonly byte[] RsaSshSha2_256Bytes = "rsa-sha2-256"u8.ToArray(); @@ -33,6 +33,41 @@ static class AlgorithmNames // TODO: rename to KnownNames public static Name EcdsaSha2Nistp521 => new Name(EcdsaSha2Nistp521Bytes); private static readonly byte[] SshEd25519Bytes = "ssh-ed25519"u8.ToArray(); public static Name SshEd25519 => new Name(SshEd25519Bytes); + // Key type to signature algorithms. + public static readonly Name[] SshRsaAlgorithms = [ RsaSshSha2_512, RsaSshSha2_256 ]; + public static readonly Name[] EcdsaSha2Nistp256Algorithms = [ EcdsaSha2Nistp256 ]; + public static readonly Name[] EcdsaSha2Nistp384Algorithms = [ EcdsaSha2Nistp384 ]; + public static readonly Name[] EcdsaSha2Nistp521Algorithms = [ EcdsaSha2Nistp521 ]; + public static readonly Name[] SshEd25519Algorithms = [ SshEd25519 ]; + + public static Name[] GetAlgorithmsForKeyType(Name keyType) + { + if (keyType == SshRsa) + { + return SshRsaAlgorithms; + } + else if (keyType == EcdsaSha2Nistp256) + { + return EcdsaSha2Nistp256Algorithms; + } + else if (keyType == EcdsaSha2Nistp384) + { + return EcdsaSha2Nistp384Algorithms; + } + else if (keyType == EcdsaSha2Nistp521) + { + return EcdsaSha2Nistp521Algorithms; + } + else if (keyType == SshEd25519) + { + return SshEd25519Algorithms; + } + else + { + // Unknown key types. + return [ keyType ]; + } + } // Encryption algorithms. private static readonly byte[] Aes128CbcBytes = "aes128-cbc"u8.ToArray(); @@ -72,7 +107,6 @@ static class AlgorithmNames // TODO: rename to KnownNames // These fields are initialized in order, so these list must be created after the names. // Algorithms are in **order of preference**. - public static readonly Name[] SshRsaAlgorithms = [ RsaSshSha2_512, RsaSshSha2_256 ]; // Authentications private static readonly byte[] GssApiWithMicBytes = "gssapi-with-mic"u8.ToArray(); diff --git a/src/Tmds.Ssh/ECDsaPrivateKey.cs b/src/Tmds.Ssh/ECDsaPrivateKey.cs index 9ffdf65..59be0ac 100644 --- a/src/Tmds.Ssh/ECDsaPrivateKey.cs +++ b/src/Tmds.Ssh/ECDsaPrivateKey.cs @@ -16,7 +16,7 @@ sealed class ECDsaPrivateKey : PrivateKey private readonly HashAlgorithmName _hashAlgorithm; public ECDsaPrivateKey(ECDsa ecdsa, Name algorithm, Name curveName, HashAlgorithmName hashAlgorithm, SshKey sshPublicKey) : - base([algorithm], sshPublicKey) + base(AlgorithmNames.GetAlgorithmsForKeyType(algorithm), sshPublicKey) { _ecdsa = ecdsa ?? throw new ArgumentNullException(nameof(ecdsa)); _algorithm = algorithm; @@ -41,7 +41,7 @@ public static SshKey DeterminePublicSshKey(ECDsa ecdsa, Name algorithm, Name cur return new SshKey(algorithm, writer.ToArray()); } - public override ValueTask TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) + public override ValueTask SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) { if (algorithm != _algorithm) { @@ -67,6 +67,6 @@ public static SshKey DeterminePublicSshKey(ECDsa ecdsa, Name algorithm, Name cur innerWriter.WriteString(algorithm); innerWriter.WriteString(ecdsaSigWriter.ToArray()); - return ValueTask.FromResult((byte[]?)innerWriter.ToArray()); + return ValueTask.FromResult(innerWriter.ToArray()); } } diff --git a/src/Tmds.Ssh/Ed25519PrivateKey.cs b/src/Tmds.Ssh/Ed25519PrivateKey.cs index 4c6bc90..c71bc48 100644 --- a/src/Tmds.Ssh/Ed25519PrivateKey.cs +++ b/src/Tmds.Ssh/Ed25519PrivateKey.cs @@ -13,7 +13,7 @@ sealed class Ed25519PrivateKey : PrivateKey private readonly byte[] _publicKey; public Ed25519PrivateKey(byte[] privateKey, byte[] publicKey, SshKey sshPublicKey) : - base([AlgorithmNames.SshEd25519], sshPublicKey) + base(AlgorithmNames.SshEd25519Algorithms, sshPublicKey) { _privateKey = privateKey; _publicKey = publicKey; @@ -31,7 +31,7 @@ public static SshKey DeterminePublicSshKey(byte[] privateKey, byte[] publicKey) return new SshKey(AlgorithmNames.SshEd25519, writer.ToArray()); } - public override ValueTask TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) + public override ValueTask SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) { if (algorithm != Algorithms[0]) { @@ -55,6 +55,6 @@ public static SshKey DeterminePublicSshKey(byte[] privateKey, byte[] publicKey) innerWriter.WriteString(algorithm); innerWriter.WriteString(signature); - return ValueTask.FromResult((byte[]?)innerWriter.ToArray()); + return ValueTask.FromResult(innerWriter.ToArray()); } } diff --git a/src/Tmds.Ssh/PrivateKey.cs b/src/Tmds.Ssh/PrivateKey.cs index 5c697ca..2a25487 100644 --- a/src/Tmds.Ssh/PrivateKey.cs +++ b/src/Tmds.Ssh/PrivateKey.cs @@ -1,8 +1,6 @@ // This file is part of Tmds.Ssh which is released under MIT. // See file LICENSE for full license details. -using System.Buffers; - namespace Tmds.Ssh; abstract class PrivateKey : IDisposable @@ -18,5 +16,5 @@ private protected PrivateKey(Name[] algorithms, SshKey publicKey) public abstract void Dispose(); - public abstract ValueTask TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken); + public abstract ValueTask SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken); } diff --git a/src/Tmds.Ssh/PrivateKeyCredential.cs b/src/Tmds.Ssh/PrivateKeyCredential.cs index ec6cca9..2573dab 100644 --- a/src/Tmds.Ssh/PrivateKeyCredential.cs +++ b/src/Tmds.Ssh/PrivateKeyCredential.cs @@ -12,19 +12,19 @@ public class PrivateKeyCredential : Credential private Func> LoadKey { get; } public PrivateKeyCredential(string path, string? password = null, string? identifier = null) : - this(path, () => password, identifier) + this(path, () => password, queryKey: false, identifier) { } - public PrivateKeyCredential(string path, Func passwordPrompt, string? identifier = null) : - this(LoadKeyFromFile(path ?? throw new ArgumentNullException(nameof(path)), passwordPrompt), identifier ?? path) + public PrivateKeyCredential(string path, Func passwordPrompt, bool queryKey = true, string? identifier = null) : + this(LoadKeyFromFile(path ?? throw new ArgumentNullException(nameof(path)), passwordPrompt, queryKey), identifier ?? path) { } public PrivateKeyCredential(char[] rawKey, string? password = null, string identifier = "[raw key]") : - this(rawKey, () => password, identifier) + this(rawKey, () => password, queryKey: false, identifier) { } - public PrivateKeyCredential(char[] rawKey, Func passwordPrompt, string identifier = "[raw key]") : - this(LoadRawKey(ValidateRawKeyArgument(rawKey), passwordPrompt), identifier) + public PrivateKeyCredential(char[] rawKey, Func passwordPrompt, bool queryKey = true, string identifier = "[raw key]") : + this(LoadRawKey(ValidateRawKeyArgument(rawKey), passwordPrompt, queryKey), identifier) { } // Allows the user to implement derived classes that represent a private key. @@ -43,15 +43,15 @@ private static char[] ValidateRawKeyArgument(char[] rawKey) return rawKey; } - private static Func> LoadRawKey(char[] rawKey, Func? passwordPrompt) + private static Func> LoadRawKey(char[] rawKey, Func passwordPrompt, bool queryKey) => (CancellationToken cancellationToken) => { - Key key = new Key(rawKey.AsMemory(), passwordPrompt); + Key key = new Key(rawKey.AsMemory(), passwordPrompt, queryKey); return ValueTask.FromResult(key); }; - private static Func> LoadKeyFromFile(string path, Func passwordPrompt) + private static Func> LoadKeyFromFile(string path, Func passwordPrompt, bool queryKey) => (CancellationToken cancellationToken) => { string rawKey; @@ -65,26 +65,37 @@ private static Func> LoadKeyFromFile(string pa return ValueTask.FromResult(default(Key)); // not found. } - Key key = new Key(rawKey.AsMemory(), passwordPrompt); + Key key = new Key(rawKey.AsMemory(), passwordPrompt, queryKey); return ValueTask.FromResult(key); }; // This is a type we expose to our derive types to avoid having to expose PrivateKey and a bunch of other internals. - protected readonly struct Key + internal protected readonly struct Key : IDisposable { internal PrivateKey? PrivateKey { get; } + internal bool QueryKey { get; } + public Key(RSA rsa) { PrivateKey = new RsaPrivateKey(rsa, RsaPrivateKey.DeterminePublicSshKey(rsa)); } - public Key(ReadOnlyMemory rawKey, Func? passwordPrompt = null) + public Key(ReadOnlyMemory rawKey, string? password = null) + { + QueryKey = false; + + PrivateKey = PrivateKeyParser.ParsePrivateKey(rawKey, passwordPrompt: delegate { return password; }); + } + + public Key(ReadOnlyMemory rawKey, Func passwordPrompt, bool queryKey = true) { - passwordPrompt ??= delegate { return null; }; + ArgumentNullException.ThrowIfNull(passwordPrompt); + + QueryKey = queryKey; - PrivateKey = PrivateKeyParser.ParsePrivateKey(rawKey, passwordPrompt); + PrivateKey = queryKey ? ParsedPrivateKey.Create(rawKey, passwordPrompt) : PrivateKeyParser.ParsePrivateKey(rawKey, passwordPrompt); } public Key(ECDsa ecdsa) @@ -122,11 +133,44 @@ internal Key(PrivateKey key) private static bool OidEquals(Oid oidA, Oid oidB) => oidA.Value is not null && oidB.Value is not null && oidA.Value == oidB.Value; + + public void Dispose() + { + PrivateKey?.Dispose(); + } + } + + internal async ValueTask LoadKeyAsync(CancellationToken cancellationToken) + { + return await LoadKey(cancellationToken); } - internal async ValueTask LoadKeyAsync(CancellationToken cancellationToken) + sealed class ParsedPrivateKey : PrivateKey { - Key key = await LoadKey(cancellationToken); - return key.PrivateKey; + private ReadOnlyMemory _rawKey; + private Func _passwordPrompt; + + public static PrivateKey Create(ReadOnlyMemory rawKey, Func passwordPrompt) + { + SshKey sshKey = PrivateKeyParser.ParsePublicKey(rawKey); + Name[] algorithms = AlgorithmNames.GetAlgorithmsForKeyType(sshKey.Type); + return new ParsedPrivateKey(algorithms, sshKey, rawKey, passwordPrompt); + } + + private ParsedPrivateKey(Name[] algorithms, SshKey publicKey, ReadOnlyMemory rawKey, Func passwordPrompt) + : base(algorithms, publicKey) + { + _rawKey = rawKey; + _passwordPrompt = passwordPrompt; + } + + public override void Dispose() + { } + + public override ValueTask SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) + { + using PrivateKey pk = PrivateKeyParser.ParsePrivateKey(_rawKey, _passwordPrompt); + return pk.SignAsync(algorithm, data, cancellationToken); + } } } diff --git a/src/Tmds.Ssh/PrivateKeyParser.OpenSsh.cs b/src/Tmds.Ssh/PrivateKeyParser.OpenSsh.cs index 9637f5d..be32af1 100644 --- a/src/Tmds.Ssh/PrivateKeyParser.OpenSsh.cs +++ b/src/Tmds.Ssh/PrivateKeyParser.OpenSsh.cs @@ -2,7 +2,6 @@ // See file LICENSE for full license details. using System.Buffers; -using System.Diagnostics.CodeAnalysis; using System.Numerics; using System.Security.Cryptography; using System.Text; @@ -15,9 +14,10 @@ partial class PrivateKeyParser /// Parses an OpenSSH PEM formatted key. This is a new key format used by /// OpenSSH for private keys. /// - internal static PrivateKey ParseOpenSshKey( + internal static (SshKey PublicKey, PrivateKey? PrivateKey) ParseOpenSshKey( byte[] keyData, - Func passwordPrompt) + Func passwordPrompt, + bool parsePrivate) { // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key /* @@ -48,7 +48,13 @@ string publickeyN { throw new FormatException($"The data contains multiple keys."); } - byte[] publicKey = reader.ReadStringAsByteArray(); + + SshKey publicKey = reader.ReadSshKey(); + if (!parsePrivate) + { + return (publicKey, null); + } + ReadOnlySequence privateKeyList; if (cipherName == AlgorithmNames.None) { @@ -94,15 +100,15 @@ byte padlen % 255 Name keyType = reader.ReadName(); if (keyType == AlgorithmNames.SshRsa) { - return ParseOpenSshRsaKey(publicKey, reader); + return (publicKey, ParseOpenSshRsaKey(publicKey, reader)); } else if (keyType.ToString().StartsWith("ecdsa-sha2-")) { - return ParseOpenSshEcdsaKey(publicKey, keyType, reader); + return (publicKey, ParseOpenSshEcdsaKey(publicKey, keyType, reader)); } else if (keyType == AlgorithmNames.SshEd25519) { - return ParseOpenSshEd25519Key(publicKey, reader); + return (publicKey, ParseOpenSshEd25519Key(publicKey, reader)); } else { @@ -166,7 +172,7 @@ uint32 rounds } } - private static PrivateKey ParseOpenSshRsaKey(byte[] publicKey, SequenceReader reader) + private static PrivateKey ParseOpenSshRsaKey(SshKey publicKey, SequenceReader reader) { // .NET RSA's class has some length expectations: // D must have the same length as Modulus. @@ -198,7 +204,7 @@ private static PrivateKey ParseOpenSshRsaKey(byte[] publicKey, SequenceReader re try { rsa.ImportParameters(parameters); - return new RsaPrivateKey(rsa, new SshKey(AlgorithmNames.SshRsa, publicKey)); + return new RsaPrivateKey(rsa, publicKey); } catch (Exception ex) { @@ -207,7 +213,7 @@ private static PrivateKey ParseOpenSshRsaKey(byte[] publicKey, SequenceReader re } } - private static PrivateKey ParseOpenSshEcdsaKey(byte[] publicKey, Name keyType, SequenceReader reader) + private static PrivateKey ParseOpenSshEcdsaKey(SshKey publicKey, Name keyType, SequenceReader reader) { Name curveName = reader.ReadName(); @@ -247,7 +253,7 @@ private static PrivateKey ParseOpenSshEcdsaKey(byte[] publicKey, Name keyType, S }; ecdsa.ImportParameters(parameters); - return new ECDsaPrivateKey(ecdsa, keyType, curveName, hashAlgorithm, new SshKey(keyType, publicKey)); + return new ECDsaPrivateKey(ecdsa, keyType, curveName, hashAlgorithm, publicKey); } catch (Exception ex) { @@ -256,7 +262,7 @@ private static PrivateKey ParseOpenSshEcdsaKey(byte[] publicKey, Name keyType, S } } - private static PrivateKey ParseOpenSshEd25519Key(byte[] sshPublicKey, SequenceReader reader) + private static PrivateKey ParseOpenSshEd25519Key(SshKey sshPublicKey, SequenceReader reader) { // https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-14#section-3.2.3 /* @@ -276,7 +282,7 @@ concatenation of the private key k and the public ENC(A) key. Why it is return new Ed25519PrivateKey( keyData.Slice(0, keyData.Length - publicKey.Length).ToArray(), publicKey.ToArray(), - new SshKey(AlgorithmNames.SshEd25519, sshPublicKey)); + sshPublicKey); } catch (Exception ex) { diff --git a/src/Tmds.Ssh/PrivateKeyParser.cs b/src/Tmds.Ssh/PrivateKeyParser.cs index efe00ac..1954e65 100644 --- a/src/Tmds.Ssh/PrivateKeyParser.cs +++ b/src/Tmds.Ssh/PrivateKeyParser.cs @@ -8,6 +8,12 @@ namespace Tmds.Ssh; partial class PrivateKeyParser { internal static PrivateKey ParsePrivateKey(ReadOnlyMemory rawKey, Func passwordPrompt) + => ParseKey(rawKey, passwordPrompt, parsePrivate: true).PrivateKey!; + + internal static SshKey ParsePublicKey(ReadOnlyMemory rawKey) + => ParseKey(rawKey, passwordPrompt: delegate { return null; }, parsePrivate: false).PublicKey; + + private static (SshKey PublicKey, PrivateKey? PrivateKey) ParseKey(ReadOnlyMemory rawKey, Func passwordPrompt, bool parsePrivate) { ReadOnlySpan content = rawKey.Span; @@ -74,7 +80,7 @@ internal static PrivateKey ParsePrivateKey(ReadOnlyMemory rawKey, Func TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) + public override ValueTask SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) { HashAlgorithmName hashAlgorithmName; if (algorithm == AlgorithmNames.RsaSshSha2_256) @@ -62,6 +62,6 @@ public static SshKey DeterminePublicSshKey(RSA rsa) } innerWriter.WriteString(signature); - return ValueTask.FromResult((byte[]?)innerWriter.ToArray()); + return ValueTask.FromResult(innerWriter.ToArray()); } } diff --git a/src/Tmds.Ssh/SshAgentPrivateKey.cs b/src/Tmds.Ssh/SshAgentPrivateKey.cs index 524908c..a73ef58 100644 --- a/src/Tmds.Ssh/SshAgentPrivateKey.cs +++ b/src/Tmds.Ssh/SshAgentPrivateKey.cs @@ -1,5 +1,7 @@ // This file is part of Tmds.Ssh which is released under MIT. // See file LICENSE for full license details. +using System.Security.Cryptography; + namespace Tmds.Ssh; sealed class SshAgentPrivateKey : PrivateKey @@ -7,33 +9,28 @@ sealed class SshAgentPrivateKey : PrivateKey private readonly SshAgent _sshAgent; public SshAgentPrivateKey(SshAgent sshAgent, SshKey publicKey) : - base(GetAlgorithmsForKeyType(publicKey.Type), publicKey) + base(AlgorithmNames.GetAlgorithmsForKeyType(publicKey.Type), publicKey) { _sshAgent = sshAgent; } - private static Name[] GetAlgorithmsForKeyType(Name keyType) - { - if (keyType == AlgorithmNames.SshRsa) - { - return AlgorithmNames.SshRsaAlgorithms; - } - else - { - return [ keyType ]; - } - } - public override void Dispose() { } - public override async ValueTask TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) + public override async ValueTask SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken) { if (Array.IndexOf(Algorithms, algorithm) == -1) { ThrowHelper.ThrowProtocolUnexpectedValue(); } - return await _sshAgent.TrySignAsync(algorithm, PublicKey.Data, data, cancellationToken).ConfigureAwait(false); + byte[]? signature = await _sshAgent.TrySignAsync(algorithm, PublicKey.Data, data, cancellationToken).ConfigureAwait(false); + + if (signature is null) + { + throw new CryptographicException("SSH Agent failed to sign."); + } + + return signature; } } diff --git a/src/Tmds.Ssh/SshClientLogger.cs b/src/Tmds.Ssh/SshClientLogger.cs index 7208935..35e048c 100644 --- a/src/Tmds.Ssh/SshClientLogger.cs +++ b/src/Tmds.Ssh/SshClientLogger.cs @@ -272,7 +272,7 @@ public static void PacketSend(this ILogger logger, ReadOnlyPacket pac EventId = 31, Level = LogLevel.Error, Message = "Private key '{KeyIdentifier}' failed to sign with {Algorithm}")] - public static partial void PrivateKeyFailedToSign(this ILogger logger, string keyIdentifier, Name algorithm); + public static partial void PrivateKeyFailedToSign(this ILogger logger, string keyIdentifier, Name algorithm, Exception exception); [LoggerMessage( EventId = 32, diff --git a/src/Tmds.Ssh/UserAuthentication.PublicKeyAuth.cs b/src/Tmds.Ssh/UserAuthentication.PublicKeyAuth.cs index 434355c..efc63f2 100644 --- a/src/Tmds.Ssh/UserAuthentication.PublicKeyAuth.cs +++ b/src/Tmds.Ssh/UserAuthentication.PublicKeyAuth.cs @@ -17,11 +17,11 @@ public sealed class PublicKeyAuth public static async Task TryAuthenticate(PrivateKeyCredential keyCredential, UserAuthContext context, SshConnectionInfo connectionInfo, ILogger logger, CancellationToken ct) { string keyIdentifier = keyCredential.Identifier; - PrivateKey? pk; + PrivateKeyCredential.Key key; try { - pk = await keyCredential.LoadKeyAsync(ct); - if (pk is null) + key = await keyCredential.LoadKeyAsync(ct); + if (key.PrivateKey is null) { logger.PrivateKeyNotFound(keyIdentifier); return AuthResult.Skipped; @@ -35,9 +35,9 @@ public static async Task TryAuthenticate(PrivateKeyCredential keyCre AuthResult result; - using (pk) + using (key.PrivateKey) { - result = await DoAuthAsync(keyIdentifier, pk, queryKey: false, context, context.SupportedAcceptedPublicKeyAlgorithms, connectionInfo, logger, ct).ConfigureAwait(false); + result = await DoAuthAsync(keyIdentifier, key.PrivateKey, key.QueryKey, context, context.SupportedAcceptedPublicKeyAlgorithms, connectionInfo, logger, ct).ConfigureAwait(false); } return result; @@ -49,15 +49,11 @@ private static bool MeetsMinimumRSAKeySize(PrivateKey privateKey, int minimumRSA { return rsaKey.KeySize >= minimumRSAKeySize; } - else if (privateKey is SshAgentPrivateKey) + else { RsaPublicKey publicKey = RsaPublicKey.CreateFromSshKey(privateKey.PublicKey.Data); return publicKey.KeySize >= minimumRSAKeySize; } - else - { - throw new NotSupportedException($"Unexpected PrivateKey type: {privateKey.GetType().FullName}"); - } } public static async ValueTask DoAuthAsync(string keyIdentifier, PrivateKey pk, bool queryKey, UserAuthContext context, IReadOnlyCollection? acceptedAlgorithms, SshConnectionInfo connectionInfo, ILogger logger, CancellationToken ct) @@ -118,10 +114,14 @@ public static async ValueTask DoAuthAsync(string keyIdentifier, Priv } byte[] data = CreateDataForSigning(signAlgorithm, context.UserName, connectionInfo.SessionId!, pk.PublicKey.Data); - byte[]? signature = await pk.TrySignAsync(signAlgorithm, data, ct); - if (signature is null) + byte[] signature; + try + { + signature = await pk.SignAsync(signAlgorithm, data, ct); + } + catch (Exception ex) { - logger.PrivateKeyFailedToSign(keyIdentifier, signAlgorithm); + logger.PrivateKeyFailedToSign(keyIdentifier, signAlgorithm, ex); return AuthResult.Failure; } diff --git a/test/Tmds.Ssh.Tests/PrivateKeyCredentialTests.cs b/test/Tmds.Ssh.Tests/PrivateKeyCredentialTests.cs index 6647545..3259a02 100644 --- a/test/Tmds.Ssh.Tests/PrivateKeyCredentialTests.cs +++ b/test/Tmds.Ssh.Tests/PrivateKeyCredentialTests.cs @@ -54,6 +54,7 @@ public ECDsaKeyCredential(ECCurve curve) dm+XIBiEGmR1nKHD1y/Axu9QHE7vLyULTt8NOgWS26oSLv5VfMBXghhm/XqG/+omFlBNq5 aT1PXN88b9LAsAAAALdG1kc0BmZWRvcmE= -----END OPENSSH PRIVATE KEY----- + """; private const string TestPassword = "Cafés"; @@ -69,8 +70,8 @@ public PrivateKeyCredentialTests(SshServer sshServer) public async Task CtorWithRawKey() { PrivateKeyCredential credential = new PrivateKeyCredential(PrivateRsaKey.ToArray()); - using var privateKey = await credential.LoadKeyAsync(default); - Assert.NotNull(privateKey); + using var key = await credential.LoadKeyAsync(default); + Assert.NotNull(key.PrivateKey); } [Theory] @@ -143,7 +144,8 @@ public async Task Ecdsa256InMemoryKey() Name expectedalgorithm = new Name("ecdsa-sha2-nistp256"u8.ToArray()); var credential = new ECDsaKeyCredential(ECCurve.NamedCurves.nistP256); - using var privateKey = await credential.LoadKeyAsync(default); + using var key = await credential.LoadKeyAsync(default); + var privateKey = key.PrivateKey; Assert.NotNull(privateKey); Assert.Single(privateKey.Algorithms); Assert.Equal(expectedalgorithm, privateKey.Algorithms[0]); @@ -155,7 +157,8 @@ public async Task Ecdsa384InMemoryKey() Name expectedalgorithm = new Name("ecdsa-sha2-nistp384"u8.ToArray()); var credential = new ECDsaKeyCredential(ECCurve.NamedCurves.nistP384); - using var privateKey = await credential.LoadKeyAsync(default); + using var key = await credential.LoadKeyAsync(default); + var privateKey = key.PrivateKey; Assert.NotNull(privateKey); Assert.Single(privateKey.Algorithms); Assert.Equal(expectedalgorithm, privateKey.Algorithms[0]); @@ -167,7 +170,8 @@ public async Task Ecdsa521InMemoryKey() Name expectedalgorithm = new Name("ecdsa-sha2-nistp521"u8.ToArray()); var credential = new ECDsaKeyCredential(ECCurve.NamedCurves.nistP521); - using var privateKey = await credential.LoadKeyAsync(default); + using var key = await credential.LoadKeyAsync(default); + var privateKey = key.PrivateKey; Assert.NotNull(privateKey); Assert.Single(privateKey.Algorithms); Assert.Equal(expectedalgorithm, privateKey.Algorithms[0]); @@ -216,6 +220,26 @@ await RunWithKeyConversion(_sshServer.TestUserIdentityFile, async (string localK }, async (c) => await c.ConnectAsync()); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MismatchedKeyNoPromptWhenQueryKey(bool queryKey) + { + using TempFile mismatchedKeyFile = new TempFile(Path.Combine(Path.GetTempPath(), Path.GetRandomFileName())); + File.WriteAllText(mismatchedKeyFile.Path, PrivateRsaKey); + + bool promptCalled = false; + + await Assert.ThrowsAsync(() => + RunWithKeyConversion(mismatchedKeyFile.Path, async (string localKey) => + { + await EncryptSshKey(localKey, "RFC4716", TestPassword, "aes256-ctr"); + return new PrivateKeyCredential(localKey, () => { promptCalled = true; return TestPassword; }, queryKey); + }, async (c) => await c.ConnectAsync())); + + Assert.Equal(!queryKey, promptCalled); + } + [Fact] public async Task OpenSshKeyWithWhitespacePassword() { @@ -268,8 +292,8 @@ private async Task RunWithKeyConversion(string keyFile, Func