Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix matching of pubkeys to key algorithms #607

Merged
merged 9 commits into from
Jul 28, 2020
6 changes: 6 additions & 0 deletions src/main/java/com/hierynomus/sshj/key/KeyAlgorithms.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@
import net.schmizz.sshj.signature.SignatureECDSA;
import net.schmizz.sshj.signature.SignatureRSA;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class KeyAlgorithms {

public static List<String> SSH_RSA_SHA2_ALGORITHMS = Arrays.asList("rsa-sha2-512", "rsa-sha2-256");

public static Factory SSHRSA() { return new Factory("ssh-rsa", new SignatureRSA.FactorySSHRSA(), KeyType.RSA); }
public static Factory SSHRSACertV01() { return new Factory("ssh-rsa-cert-v01@openssh.com", new SignatureRSA.FactoryCERT(), KeyType.RSA_CERT); }
public static Factory RSASHA256() { return new Factory("rsa-sha2-256", new SignatureRSA.FactoryRSASHA256(), KeyType.RSA); }
Expand Down
9 changes: 3 additions & 6 deletions src/main/java/net/schmizz/sshj/transport/KeyExchanger.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,9 @@ private void gotKexInit(SSHPacket buf)
}
kex = Factory.Named.Util.create(transport.getConfig().getKeyExchangeFactories(),
negotiatedAlgs.getKeyExchangeAlgorithm());

List<KeyAlgorithm> keyAlgorithms = new ArrayList<KeyAlgorithm>();
for (String signatureAlgorithm : negotiatedAlgs.getSignatureAlgorithms()) {
keyAlgorithms.add(Factory.Named.Util.create(transport.getConfig().getKeyAlgorithms(), signatureAlgorithm));
}
transport.setKeyAlgorithms(keyAlgorithms);
transport.setHostKeyAlgorithm(Factory.Named.Util.create(transport.getConfig().getKeyAlgorithms(),
negotiatedAlgs.getSignatureAlgorithm()));
transport.setRSASHA2Support(negotiatedAlgs.getRSASHA2Support());

try {
kex.init(transport,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,38 @@
*/
package net.schmizz.sshj.transport;

import java.util.List;

public final class NegotiatedAlgorithms {

private final String kex;
private final List<String> availableSigs;
private final String sig;
private final String c2sCipher;
private final String s2cCipher;
private final String c2sMAC;
private final String s2cMAC;
private final String c2sComp;
private final String s2cComp;

NegotiatedAlgorithms(String kex, List<String> availableSigs, String c2sCipher, String s2cCipher, String c2sMAC, String s2cMAC,
String c2sComp, String s2cComp) {
private final Boolean rsaSHA2Support;
fmeum marked this conversation as resolved.
Show resolved Hide resolved

NegotiatedAlgorithms(String kex, String sig, String c2sCipher, String s2cCipher, String c2sMAC, String s2cMAC,
String c2sComp, String s2cComp, Boolean rsaSHA2Support) {
this.kex = kex;
this.availableSigs = availableSigs;
this.sig = sig;
this.c2sCipher = c2sCipher;
this.s2cCipher = s2cCipher;
this.c2sMAC = c2sMAC;
this.s2cMAC = s2cMAC;
this.c2sComp = c2sComp;
this.s2cComp = s2cComp;
this.rsaSHA2Support = rsaSHA2Support;
}

public String getKeyExchangeAlgorithm() {
return kex;
}

public List<String> getSignatureAlgorithms() {
return availableSigs;
public String getSignatureAlgorithm() {
return sig;
}

public String getClient2ServerCipherAlgorithm() {
Expand All @@ -72,17 +73,22 @@ public String getServer2ClientCompressionAlgorithm() {
return s2cComp;
}

public Boolean getRSASHA2Support() {
return rsaSHA2Support;
}

@Override
public String toString() {
return ("[ " +
"kex=" + kex + "; " +
"availableSigs=" + availableSigs + "; " +
"sig=" + sig + "; " +
"c2sCipher=" + c2sCipher + "; " +
"s2cCipher=" + s2cCipher + "; " +
"c2sMAC=" + c2sMAC + "; " +
"s2cMAC=" + s2cMAC + "; " +
"c2sComp=" + c2sComp + "; " +
"s2cComp=" + s2cComp +
"s2cComp=" + s2cComp + "; " +
"rsaSHA2Support=" + rsaSHA2Support +
" ]");
}

Expand Down
8 changes: 5 additions & 3 deletions src/main/java/net/schmizz/sshj/transport/Proposal.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package net.schmizz.sshj.transport;

import com.hierynomus.sshj.key.KeyAlgorithms;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.Factory;
Expand Down Expand Up @@ -91,7 +92,7 @@ public List<String> getKeyExchangeAlgorithms() {
return kex;
}

public List<String> getSignatureAlgorithms() {
public List<String> getHostKeyAlgorithms() {
return sig;
}

Expand Down Expand Up @@ -127,13 +128,14 @@ public NegotiatedAlgorithms negotiate(Proposal other)
throws TransportException {
return new NegotiatedAlgorithms(
firstMatch(this.getKeyExchangeAlgorithms(), other.getKeyExchangeAlgorithms()),
allMatch(this.getSignatureAlgorithms(), other.getSignatureAlgorithms()),
firstMatch(this.getHostKeyAlgorithms(), other.getHostKeyAlgorithms()),
fmeum marked this conversation as resolved.
Show resolved Hide resolved
firstMatch(this.getClient2ServerCipherAlgorithms(), other.getClient2ServerCipherAlgorithms()),
firstMatch(this.getServer2ClientCipherAlgorithms(), other.getServer2ClientCipherAlgorithms()),
firstMatch(this.getClient2ServerMACAlgorithms(), other.getClient2ServerMACAlgorithms()),
firstMatch(this.getServer2ClientMACAlgorithms(), other.getServer2ClientMACAlgorithms()),
firstMatch(this.getClient2ServerCompressionAlgorithms(), other.getClient2ServerCompressionAlgorithms()),
firstMatch(this.getServer2ClientCompressionAlgorithms(), other.getServer2ClientCompressionAlgorithms())
firstMatch(this.getServer2ClientCompressionAlgorithms(), other.getServer2ClientCompressionAlgorithms()),
other.getHostKeyAlgorithms().containsAll(KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS)
);
}

Expand Down
5 changes: 4 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/Transport.java
Original file line number Diff line number Diff line change
Expand Up @@ -238,5 +238,8 @@ long write(SSHPacket payload)
*/
void die(Exception e);

KeyAlgorithm getKeyAlgorithm(KeyType keyType) throws TransportException;
void setHostKeyAlgorithm(KeyAlgorithm keyAlgorithm);
fmeum marked this conversation as resolved.
Show resolved Hide resolved
KeyAlgorithm getHostKeyAlgorithm();
void setRSASHA2Support(Boolean rsaSHA2Support);
KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException;
}
37 changes: 27 additions & 10 deletions src/main/java/net/schmizz/sshj/transport/TransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package net.schmizz.sshj.transport;

import com.hierynomus.sshj.key.KeyAlgorithm;
import com.hierynomus.sshj.key.KeyAlgorithms;
import com.hierynomus.sshj.transport.IdentificationStringParser;
import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Event;
Expand Down Expand Up @@ -89,7 +90,9 @@ static final class ConnInfo {

private final Decoder decoder;

private List<KeyAlgorithm> keyAlgorithms;
private KeyAlgorithm hostKeyAlgorithm;

private Boolean rsaSHA2Support;

private final Event<TransportException> serviceAccept;

Expand Down Expand Up @@ -658,17 +661,31 @@ ConnInfo getConnInfo() {
}

@Override
public KeyAlgorithm getKeyAlgorithm(KeyType keyType) throws TransportException {
for (KeyAlgorithm ka : keyAlgorithms) {
if (ka.getKeyFormat().equals(keyType)) {
return ka;
}
}
public void setHostKeyAlgorithm(KeyAlgorithm keyAlgorithm) {
this.hostKeyAlgorithm = keyAlgorithm;
}

throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
@Override
public KeyAlgorithm getHostKeyAlgorithm() {
return this.hostKeyAlgorithm;
}

public void setKeyAlgorithms(List<KeyAlgorithm> keyAlgorithms) {
this.keyAlgorithms = keyAlgorithms;
@Override
public void setRSASHA2Support(Boolean rsaSHA2Support) {
this.rsaSHA2Support = rsaSHA2Support;
}

@Override
public KeyAlgorithm getClientKeyAlgorithm(KeyType keyType) throws TransportException {
fmeum marked this conversation as resolved.
Show resolved Hide resolved
if (keyType != KeyType.RSA || !rsaSHA2Support) {
return Factory.Named.Util.create(getConfig().getKeyAlgorithms(), keyType.toString());
}

List<Factory.Named<KeyAlgorithm>> factories = getConfig().getKeyAlgorithms();
if (factories != null)
for (Factory.Named<KeyAlgorithm> f : factories)
if (f.getName().equals("ssh-rsa") || KeyAlgorithms.SSH_RSA_SHA2_ALGORITHMS.contains(f.getName()))
return f.create();
throw new TransportException("Cannot find an available KeyAlgorithm for type " + keyType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public boolean next(Message msg, SSHPacket packet)
H = digest.digest();


Signature signature = trans.getKeyAlgorithm(KeyType.fromKey(hostKey)).newSignature();
Signature signature = trans.getHostKeyAlgorithm().newSignature();
signature.initVerify(hostKey);
signature.update(H, 0, H.length);
if (!signature.verify(sig))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private boolean parseGexReply(SSHPacket buffer) throws Buffer.BufferException, G
.putMPInt(k);
digest.update(buf.array(), buf.rpos(), buf.available());
H = digest.digest();
KeyAlgorithm keyAlgorithm = trans.getKeyAlgorithm(KeyType.fromKey(hostKey));
KeyAlgorithm keyAlgorithm = trans.getHostKeyAlgorithm();
Signature signature = keyAlgorithm.newSignature();
signature.initVerify(hostKey);
signature.update(H, 0, H.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ protected SSHPacket putPubKey(SSHPacket reqBuf)
// public key as 2 strings: [ key type | key blob ]
KeyType keyType = KeyType.fromKey(key);
try {
KeyAlgorithm ka = params.getTransport().getKeyAlgorithm(keyType);
KeyAlgorithm ka = params.getTransport().getClientKeyAlgorithm(keyType);
reqBuf.putString(ka.getKeyAlgorithm())
.putString(new Buffer.PlainBuffer().putPublicKey(key).getCompactData());
return reqBuf;
Expand All @@ -71,7 +71,7 @@ protected SSHPacket putSig(SSHPacket reqBuf)
final KeyType kt = KeyType.fromKey(key);
Signature signature;
try {
signature = params.getTransport().getKeyAlgorithm(kt).newSignature();
signature = params.getTransport().getClientKeyAlgorithm(kt).newSignature();
} catch (TransportException e) {
throw new UserAuthException("No KeyAlgorithm configured for key " + kt);
}
Expand Down