Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try all public key algorithms available for a specific key type in SSH_MSG_USERAUTH_REQUEST #763

Merged
merged 1 commit into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/Transport.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.TimeUnit;

/** Transport layer of the SSH protocol. */
Expand Down Expand Up @@ -224,5 +225,5 @@ long write(SSHPacket payload)
void die(Exception e);

KeyAlgorithm getHostKeyAlgorithm();
KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException;
List<KeyAlgorithm> getClientKeyAlgorithms(KeyType keyType) throws TransportException;
}
12 changes: 8 additions & 4 deletions src/main/java/net/schmizz/sshj/transport/TransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
Expand Down Expand Up @@ -637,15 +638,18 @@ public KeyAlgorithm getHostKeyAlgorithm() {
}

@Override
public KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException {
public List<KeyAlgorithm> getClientKeyAlgorithms(KeyType keyType) throws TransportException {
List<Factory.Named<KeyAlgorithm>> factories = getConfig().getKeyAlgorithms();
List<KeyAlgorithm> available = new ArrayList<>();
if (factories != null)
for (Factory.Named<KeyAlgorithm> f : factories)
if (
f instanceof KeyAlgorithms.Factory && ((KeyAlgorithms.Factory) f).getKeyType().equals(keyType)
|| !(f instanceof KeyAlgorithms.Factory) && f.getName().equals(keyType.toString())
|| !(f instanceof KeyAlgorithms.Factory) && f.getName().equals(keyType.toString())
)
return f.create();
throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
available.add(f.create());
if (available.isEmpty())
throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
return available;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,36 @@
import java.io.IOException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.LinkedList;
import java.util.Queue;

public abstract class KeyedAuthMethod
extends AbstractAuthMethod {

protected final KeyProvider kProv;
private Queue<KeyAlgorithm> available;

public KeyedAuthMethod(String name, KeyProvider kProv) {
super(name);
this.kProv = kProv;
}

private KeyAlgorithm getPublicKeyAlgorithm(KeyType keyType) throws TransportException {
if (available == null) {
available = new LinkedList<>(params.getTransport().getClientKeyAlgorithms(keyType));
}
return available.peek();
}

@Override
public boolean shouldRetry() {
if (available != null) {
available.poll();
return !available.isEmpty();
}
return false;
}

protected SSHPacket putPubKey(SSHPacket reqBuf)
throws UserAuthException {
PublicKey key;
Expand All @@ -50,7 +69,7 @@ protected SSHPacket putPubKey(SSHPacket reqBuf)
// public key as 2 strings: [ key type | key blob ]
KeyType keyType = KeyType.fromKey(key);
try {
KeyAlgorithm ka = params.getTransport().getClientKeyAlgorithm(keyType);
KeyAlgorithm ka = getPublicKeyAlgorithm(keyType);
if (ka != null) {
reqBuf.putString(ka.getKeyAlgorithm())
.putString(new Buffer.PlainBuffer().putPublicKey(key).getCompactData());
Expand All @@ -74,7 +93,7 @@ protected SSHPacket putSig(SSHPacket reqBuf)
final KeyType kt = KeyType.fromKey(key);
Signature signature;
try {
signature = params.getTransport().getClientKeyAlgorithm(kt).newSignature();
signature = getPublicKeyAlgorithm(kt).newSignature();
} catch (TransportException e) {
throw new UserAuthException("No KeyAlgorithm configured for key " + kt);
}
Expand Down