diff --git a/src/main/java/net/schmizz/sshj/transport/Transport.java b/src/main/java/net/schmizz/sshj/transport/Transport.java index d187c49d..5ae55968 100644 --- a/src/main/java/net/schmizz/sshj/transport/Transport.java +++ b/src/main/java/net/schmizz/sshj/transport/Transport.java @@ -28,6 +28,7 @@ import java.io.InputStream; import java.io.OutputStream; +import java.util.List; import java.util.concurrent.TimeUnit; /** Transport layer of the SSH protocol. */ @@ -224,5 +225,5 @@ long write(SSHPacket payload) void die(Exception e); KeyAlgorithm getHostKeyAlgorithm(); - KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException; + List getClientKeyAlgorithms(KeyType keyType) throws TransportException; } diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 137a9fc1..58107c5b 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -33,6 +33,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.ReentrantLock; @@ -637,15 +638,18 @@ public KeyAlgorithm getHostKeyAlgorithm() { } @Override - public KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException { + public List getClientKeyAlgorithms(KeyType keyType) throws TransportException { List> factories = getConfig().getKeyAlgorithms(); + List available = new ArrayList<>(); if (factories != null) for (Factory.Named f : factories) if ( f instanceof KeyAlgorithms.Factory && ((KeyAlgorithms.Factory) f).getKeyType().equals(keyType) - || !(f instanceof KeyAlgorithms.Factory) && f.getName().equals(keyType.toString()) + || !(f instanceof KeyAlgorithms.Factory) && f.getName().equals(keyType.toString()) ) - return f.create(); - throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType); + available.add(f.create()); + if (available.isEmpty()) + throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType); + return available; } } diff --git a/src/main/java/net/schmizz/sshj/userauth/method/KeyedAuthMethod.java b/src/main/java/net/schmizz/sshj/userauth/method/KeyedAuthMethod.java index 3ec23e76..719e2f33 100644 --- a/src/main/java/net/schmizz/sshj/userauth/method/KeyedAuthMethod.java +++ b/src/main/java/net/schmizz/sshj/userauth/method/KeyedAuthMethod.java @@ -27,17 +27,36 @@ import java.io.IOException; import java.security.PrivateKey; import java.security.PublicKey; +import java.util.LinkedList; +import java.util.Queue; public abstract class KeyedAuthMethod extends AbstractAuthMethod { protected final KeyProvider kProv; + private Queue available; public KeyedAuthMethod(String name, KeyProvider kProv) { super(name); this.kProv = kProv; } + private KeyAlgorithm getPublicKeyAlgorithm(KeyType keyType) throws TransportException { + if (available == null) { + available = new LinkedList<>(params.getTransport().getClientKeyAlgorithms(keyType)); + } + return available.peek(); + } + + @Override + public boolean shouldRetry() { + if (available != null) { + available.poll(); + return !available.isEmpty(); + } + return false; + } + protected SSHPacket putPubKey(SSHPacket reqBuf) throws UserAuthException { PublicKey key; @@ -50,7 +69,7 @@ protected SSHPacket putPubKey(SSHPacket reqBuf) // public key as 2 strings: [ key type | key blob ] KeyType keyType = KeyType.fromKey(key); try { - KeyAlgorithm ka = params.getTransport().getClientKeyAlgorithm(keyType); + KeyAlgorithm ka = getPublicKeyAlgorithm(keyType); if (ka != null) { reqBuf.putString(ka.getKeyAlgorithm()) .putString(new Buffer.PlainBuffer().putPublicKey(key).getCompactData()); @@ -74,7 +93,7 @@ protected SSHPacket putSig(SSHPacket reqBuf) final KeyType kt = KeyType.fromKey(key); Signature signature; try { - signature = params.getTransport().getClientKeyAlgorithm(kt).newSignature(); + signature = getPublicKeyAlgorithm(kt).newSignature(); } catch (TransportException e) { throw new UserAuthException("No KeyAlgorithm configured for key " + kt); }