Skip to content

Commit

Permalink
PrivateKeyCredential: support avoiding prompts for keys that server w…
Browse files Browse the repository at this point in the history
…on't accept.
  • Loading branch information
tmds committed Dec 25, 2024
1 parent aa301e2 commit 9428c15
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 83 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -489,18 +489,19 @@ abstract class Credential
class PrivateKeyCredential : Credential
{
PrivateKeyCredential(string path, string? password = null, string? identifier ??= path);
PrivateKeyCredential(string path, Func<string?> passwordPrompt, string? identifier ??= path);
PrivateKeyCredential(string path, Func<string?> passwordPrompt, bool queryKey = true, string? identifier ??= path);

PrivateKeyCredential(char[] rawKey, string? password = null, string identifier = "[raw key]");
PrivateKeyCredential(char[] rawKey, Func<string?> passwordPrompt, string identifier = "[raw key]");
PrivateKeyCredential(char[] rawKey, Func<string?> passwordPrompt, bool queryKey = true, string identifier = "[raw key]");

// Enable derived classes to use private keys from other sources.
protected PrivateKeyCredential(Func<CancellationToken, ValueTask<Key>> loadKey, string identifier);
protected struct Key
{
Key(RSA rsa);
Key(ECDsa ecdsa);
Key(ReadOnlyMemory<char> rawKey, Func<string?>? passwordPrompt = null);
Key(ReadOnlyMemory<char> rawKey, string? password = null);
Key(ReadOnlyMemory<char> rawKey, Func<string?> passwordPrompt, bool queryKey = true);
}
}
class PasswordCredential : Credential
Expand Down
38 changes: 36 additions & 2 deletions src/Tmds.Ssh/AlgorithmNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ static class AlgorithmNames // TODO: rename to KnownNames
private static readonly byte[] EcdhSha2Nistp521Bytes = "ecdh-sha2-nistp521"u8.ToArray();
public static Name EcdhSha2Nistp521 => new Name(EcdhSha2Nistp521Bytes);

// Host key algorithms.
// Host key algorithms: key types and signature algorithms.
private static readonly byte[] SshRsaBytes = "ssh-rsa"u8.ToArray();
public static Name SshRsa => new Name(SshRsaBytes);
private static readonly byte[] RsaSshSha2_256Bytes = "rsa-sha2-256"u8.ToArray();
Expand All @@ -33,6 +33,41 @@ static class AlgorithmNames // TODO: rename to KnownNames
public static Name EcdsaSha2Nistp521 => new Name(EcdsaSha2Nistp521Bytes);
private static readonly byte[] SshEd25519Bytes = "ssh-ed25519"u8.ToArray();
public static Name SshEd25519 => new Name(SshEd25519Bytes);
// Key type to signature algorithms.
public static readonly Name[] SshRsaAlgorithms = [ RsaSshSha2_512, RsaSshSha2_256 ];
public static readonly Name[] EcdsaSha2Nistp256Algorithms = [ EcdsaSha2Nistp256 ];
public static readonly Name[] EcdsaSha2Nistp384Algorithms = [ EcdsaSha2Nistp384 ];
public static readonly Name[] EcdsaSha2Nistp521Algorithms = [ EcdsaSha2Nistp521 ];
public static readonly Name[] SshEd25519Algorithms = [ SshEd25519 ];

public static Name[] GetAlgorithmsForKeyType(Name keyType)
{
if (keyType == SshRsa)
{
return SshRsaAlgorithms;
}
else if (keyType == EcdsaSha2Nistp256)
{
return EcdsaSha2Nistp256Algorithms;
}
else if (keyType == EcdsaSha2Nistp384)
{
return EcdsaSha2Nistp384Algorithms;
}
else if (keyType == EcdsaSha2Nistp521)
{
return EcdsaSha2Nistp521Algorithms;
}
else if (keyType == SshEd25519)
{
return SshEd25519Algorithms;
}
else
{
// Unknown key types.
return [ keyType ];
}
}

// Encryption algorithms.
private static readonly byte[] Aes128CbcBytes = "aes128-cbc"u8.ToArray();
Expand Down Expand Up @@ -72,7 +107,6 @@ static class AlgorithmNames // TODO: rename to KnownNames

// These fields are initialized in order, so these list must be created after the names.
// Algorithms are in **order of preference**.
public static readonly Name[] SshRsaAlgorithms = [ RsaSshSha2_512, RsaSshSha2_256 ];

// Authentications
private static readonly byte[] GssApiWithMicBytes = "gssapi-with-mic"u8.ToArray();
Expand Down
6 changes: 3 additions & 3 deletions src/Tmds.Ssh/ECDsaPrivateKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ sealed class ECDsaPrivateKey : PrivateKey
private readonly HashAlgorithmName _hashAlgorithm;

public ECDsaPrivateKey(ECDsa ecdsa, Name algorithm, Name curveName, HashAlgorithmName hashAlgorithm, SshKey sshPublicKey) :
base([algorithm], sshPublicKey)
base(AlgorithmNames.GetAlgorithmsForKeyType(algorithm), sshPublicKey)
{
_ecdsa = ecdsa ?? throw new ArgumentNullException(nameof(ecdsa));
_algorithm = algorithm;
Expand All @@ -41,7 +41,7 @@ public static SshKey DeterminePublicSshKey(ECDsa ecdsa, Name algorithm, Name cur
return new SshKey(algorithm, writer.ToArray());
}

public override ValueTask<byte[]?> TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken)
public override ValueTask<byte[]> SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken)
{
if (algorithm != _algorithm)
{
Expand All @@ -67,6 +67,6 @@ public static SshKey DeterminePublicSshKey(ECDsa ecdsa, Name algorithm, Name cur
innerWriter.WriteString(algorithm);
innerWriter.WriteString(ecdsaSigWriter.ToArray());

return ValueTask.FromResult((byte[]?)innerWriter.ToArray());
return ValueTask.FromResult(innerWriter.ToArray());
}
}
6 changes: 3 additions & 3 deletions src/Tmds.Ssh/Ed25519PrivateKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ sealed class Ed25519PrivateKey : PrivateKey
private readonly byte[] _publicKey;

public Ed25519PrivateKey(byte[] privateKey, byte[] publicKey, SshKey sshPublicKey) :
base([AlgorithmNames.SshEd25519], sshPublicKey)
base(AlgorithmNames.SshEd25519Algorithms, sshPublicKey)
{
_privateKey = privateKey;
_publicKey = publicKey;
Expand All @@ -31,7 +31,7 @@ public static SshKey DeterminePublicSshKey(byte[] privateKey, byte[] publicKey)
return new SshKey(AlgorithmNames.SshEd25519, writer.ToArray());
}

public override ValueTask<byte[]?> TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken)
public override ValueTask<byte[]> SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken)
{
if (algorithm != Algorithms[0])
{
Expand All @@ -55,6 +55,6 @@ public static SshKey DeterminePublicSshKey(byte[] privateKey, byte[] publicKey)
innerWriter.WriteString(algorithm);
innerWriter.WriteString(signature);

return ValueTask.FromResult((byte[]?)innerWriter.ToArray());
return ValueTask.FromResult(innerWriter.ToArray());
}
}
4 changes: 1 addition & 3 deletions src/Tmds.Ssh/PrivateKey.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

using System.Buffers;

namespace Tmds.Ssh;

abstract class PrivateKey : IDisposable
Expand All @@ -18,5 +16,5 @@ private protected PrivateKey(Name[] algorithms, SshKey publicKey)

public abstract void Dispose();

public abstract ValueTask<byte[]?> TrySignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken);
public abstract ValueTask<byte[]> SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken);
}
78 changes: 61 additions & 17 deletions src/Tmds.Ssh/PrivateKeyCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ public class PrivateKeyCredential : Credential
private Func<CancellationToken, ValueTask<Key>> LoadKey { get; }

public PrivateKeyCredential(string path, string? password = null, string? identifier = null) :
this(path, () => password, identifier)
this(path, () => password, queryKey: false, identifier)
{ }

public PrivateKeyCredential(string path, Func<string?> passwordPrompt, string? identifier = null) :
this(LoadKeyFromFile(path ?? throw new ArgumentNullException(nameof(path)), passwordPrompt), identifier ?? path)
public PrivateKeyCredential(string path, Func<string?> passwordPrompt, bool queryKey = true, string? identifier = null) :
this(LoadKeyFromFile(path ?? throw new ArgumentNullException(nameof(path)), passwordPrompt, queryKey), identifier ?? path)
{ }

public PrivateKeyCredential(char[] rawKey, string? password = null, string identifier = "[raw key]") :
this(rawKey, () => password, identifier)
this(rawKey, () => password, queryKey: false, identifier)
{ }

public PrivateKeyCredential(char[] rawKey, Func<string?> passwordPrompt, string identifier = "[raw key]") :
this(LoadRawKey(ValidateRawKeyArgument(rawKey), passwordPrompt), identifier)
public PrivateKeyCredential(char[] rawKey, Func<string?> passwordPrompt, bool queryKey = true, string identifier = "[raw key]") :
this(LoadRawKey(ValidateRawKeyArgument(rawKey), passwordPrompt, queryKey), identifier)
{ }

// Allows the user to implement derived classes that represent a private key.
Expand All @@ -43,15 +43,15 @@ private static char[] ValidateRawKeyArgument(char[] rawKey)
return rawKey;
}

private static Func<CancellationToken, ValueTask<Key>> LoadRawKey(char[] rawKey, Func<string?>? passwordPrompt)
private static Func<CancellationToken, ValueTask<Key>> LoadRawKey(char[] rawKey, Func<string?> passwordPrompt, bool queryKey)
=> (CancellationToken cancellationToken) =>
{
Key key = new Key(rawKey.AsMemory(), passwordPrompt);
Key key = new Key(rawKey.AsMemory(), passwordPrompt, queryKey);

return ValueTask.FromResult(key);
};

private static Func<CancellationToken, ValueTask<Key>> LoadKeyFromFile(string path, Func<string?> passwordPrompt)
private static Func<CancellationToken, ValueTask<Key>> LoadKeyFromFile(string path, Func<string?> passwordPrompt, bool queryKey)
=> (CancellationToken cancellationToken) =>
{
string rawKey;
Expand All @@ -65,26 +65,37 @@ private static Func<CancellationToken, ValueTask<Key>> LoadKeyFromFile(string pa
return ValueTask.FromResult(default(Key)); // not found.
}

Key key = new Key(rawKey.AsMemory(), passwordPrompt);
Key key = new Key(rawKey.AsMemory(), passwordPrompt, queryKey);

return ValueTask.FromResult(key);
};

// This is a type we expose to our derive types to avoid having to expose PrivateKey and a bunch of other internals.
protected readonly struct Key
internal protected readonly struct Key : IDisposable
{
internal PrivateKey? PrivateKey { get; }

internal bool QueryKey { get; }

public Key(RSA rsa)
{
PrivateKey = new RsaPrivateKey(rsa, RsaPrivateKey.DeterminePublicSshKey(rsa));
}

public Key(ReadOnlyMemory<char> rawKey, Func<string?>? passwordPrompt = null)
public Key(ReadOnlyMemory<char> rawKey, string? password = null)
{
QueryKey = false;

PrivateKey = PrivateKeyParser.ParsePrivateKey(rawKey, passwordPrompt: delegate { return password; });
}

public Key(ReadOnlyMemory<char> rawKey, Func<string?> passwordPrompt, bool queryKey = true)
{
passwordPrompt ??= delegate { return null; };
ArgumentNullException.ThrowIfNull(passwordPrompt);

QueryKey = queryKey;

PrivateKey = PrivateKeyParser.ParsePrivateKey(rawKey, passwordPrompt);
PrivateKey = queryKey ? ParsedPrivateKey.Create(rawKey, passwordPrompt) : PrivateKeyParser.ParsePrivateKey(rawKey, passwordPrompt);
}

public Key(ECDsa ecdsa)
Expand Down Expand Up @@ -122,11 +133,44 @@ internal Key(PrivateKey key)

private static bool OidEquals(Oid oidA, Oid oidB)
=> oidA.Value is not null && oidB.Value is not null && oidA.Value == oidB.Value;

public void Dispose()
{
PrivateKey?.Dispose();
}
}

internal async ValueTask<Key> LoadKeyAsync(CancellationToken cancellationToken)
{
return await LoadKey(cancellationToken);
}

internal async ValueTask<PrivateKey?> LoadKeyAsync(CancellationToken cancellationToken)
sealed class ParsedPrivateKey : PrivateKey
{
Key key = await LoadKey(cancellationToken);
return key.PrivateKey;
private ReadOnlyMemory<char> _rawKey;
private Func<string?> _passwordPrompt;

public static PrivateKey Create(ReadOnlyMemory<char> rawKey, Func<string?> passwordPrompt)
{
SshKey sshKey = PrivateKeyParser.ParsePublicKey(rawKey);
Name[] algorithms = AlgorithmNames.GetAlgorithmsForKeyType(sshKey.Type);
return new ParsedPrivateKey(algorithms, sshKey, rawKey, passwordPrompt);
}

private ParsedPrivateKey(Name[] algorithms, SshKey publicKey, ReadOnlyMemory<char> rawKey, Func<string?> passwordPrompt)
: base(algorithms, publicKey)
{
_rawKey = rawKey;
_passwordPrompt = passwordPrompt;
}

public override void Dispose()
{ }

public override ValueTask<byte[]> SignAsync(Name algorithm, byte[] data, CancellationToken cancellationToken)
{
using PrivateKey pk = PrivateKeyParser.ParsePrivateKey(_rawKey, _passwordPrompt);
return pk.SignAsync(algorithm, data, cancellationToken);
}
}
}
32 changes: 19 additions & 13 deletions src/Tmds.Ssh/PrivateKeyParser.OpenSsh.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// See file LICENSE for full license details.

using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
using System.Security.Cryptography;
using System.Text;
Expand All @@ -15,9 +14,10 @@ partial class PrivateKeyParser
/// Parses an OpenSSH PEM formatted key. This is a new key format used by
/// OpenSSH for private keys.
/// </summary>
internal static PrivateKey ParseOpenSshKey(
internal static (SshKey PublicKey, PrivateKey? PrivateKey) ParseOpenSshKey(
byte[] keyData,
Func<string?> passwordPrompt)
Func<string?> passwordPrompt,
bool parsePrivate)
{
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
/*
Expand Down Expand Up @@ -48,7 +48,13 @@ string publickeyN
{
throw new FormatException($"The data contains multiple keys.");
}
byte[] publicKey = reader.ReadStringAsByteArray();

SshKey publicKey = reader.ReadSshKey();
if (!parsePrivate)
{
return (publicKey, null);
}

ReadOnlySequence<byte> privateKeyList;
if (cipherName == AlgorithmNames.None)
{
Expand Down Expand Up @@ -94,15 +100,15 @@ byte padlen % 255
Name keyType = reader.ReadName();
if (keyType == AlgorithmNames.SshRsa)
{
return ParseOpenSshRsaKey(publicKey, reader);
return (publicKey, ParseOpenSshRsaKey(publicKey, reader));
}
else if (keyType.ToString().StartsWith("ecdsa-sha2-"))
{
return ParseOpenSshEcdsaKey(publicKey, keyType, reader);
return (publicKey, ParseOpenSshEcdsaKey(publicKey, keyType, reader));
}
else if (keyType == AlgorithmNames.SshEd25519)
{
return ParseOpenSshEd25519Key(publicKey, reader);
return (publicKey, ParseOpenSshEd25519Key(publicKey, reader));
}
else
{
Expand Down Expand Up @@ -166,7 +172,7 @@ uint32 rounds
}
}

private static PrivateKey ParseOpenSshRsaKey(byte[] publicKey, SequenceReader reader)
private static PrivateKey ParseOpenSshRsaKey(SshKey publicKey, SequenceReader reader)
{
// .NET RSA's class has some length expectations:
// D must have the same length as Modulus.
Expand Down Expand Up @@ -198,7 +204,7 @@ private static PrivateKey ParseOpenSshRsaKey(byte[] publicKey, SequenceReader re
try
{
rsa.ImportParameters(parameters);
return new RsaPrivateKey(rsa, new SshKey(AlgorithmNames.SshRsa, publicKey));
return new RsaPrivateKey(rsa, publicKey);
}
catch (Exception ex)
{
Expand All @@ -207,7 +213,7 @@ private static PrivateKey ParseOpenSshRsaKey(byte[] publicKey, SequenceReader re
}
}

private static PrivateKey ParseOpenSshEcdsaKey(byte[] publicKey, Name keyType, SequenceReader reader)
private static PrivateKey ParseOpenSshEcdsaKey(SshKey publicKey, Name keyType, SequenceReader reader)
{
Name curveName = reader.ReadName();

Expand Down Expand Up @@ -247,7 +253,7 @@ private static PrivateKey ParseOpenSshEcdsaKey(byte[] publicKey, Name keyType, S
};

ecdsa.ImportParameters(parameters);
return new ECDsaPrivateKey(ecdsa, keyType, curveName, hashAlgorithm, new SshKey(keyType, publicKey));
return new ECDsaPrivateKey(ecdsa, keyType, curveName, hashAlgorithm, publicKey);
}
catch (Exception ex)
{
Expand All @@ -256,7 +262,7 @@ private static PrivateKey ParseOpenSshEcdsaKey(byte[] publicKey, Name keyType, S
}
}

private static PrivateKey ParseOpenSshEd25519Key(byte[] sshPublicKey, SequenceReader reader)
private static PrivateKey ParseOpenSshEd25519Key(SshKey sshPublicKey, SequenceReader reader)
{
// https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-14#section-3.2.3
/*
Expand All @@ -276,7 +282,7 @@ concatenation of the private key k and the public ENC(A) key. Why it is
return new Ed25519PrivateKey(
keyData.Slice(0, keyData.Length - publicKey.Length).ToArray(),
publicKey.ToArray(),
new SshKey(AlgorithmNames.SshEd25519, sshPublicKey));
sshPublicKey);
}
catch (Exception ex)
{
Expand Down
Loading

0 comments on commit 9428c15

Please sign in to comment.