diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index bc53a732..ce2853be 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -170,11 +170,22 @@ private static void ensureReceivedMatchesExpected(Message got, Message expected) private void sendKexInit() throws TransportException { log.debug("Sending SSH_MSG_KEXINIT"); - clientProposal = new Proposal(transport.getConfig()); + String knowHostAlg = findKnowHostAlg(transport.getRemoteHost(), transport.getRemotePort()); + clientProposal = new Proposal(transport.getConfig(), knowHostAlg); transport.write(clientProposal.getPacket()); kexInitSent.set(); } + private String findKnowHostAlg(String hostname, int port) { + for (HostKeyVerifier hkv : hostVerifiers) { + String keyType = hkv.findExistingAlgorithm(hostname, port); + if (keyType != null) { + return keyType; + } + } + return null; + } + private void sendNewKeys() throws TransportException { log.debug("Sending SSH_MSG_NEWKEYS"); diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index ccc5a528..3a1cfb82 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -38,9 +38,21 @@ class Proposal { private final List s2cComp; private final SSHPacket packet; - public Proposal(Config config) { + public Proposal(Config config, String knowHostAlg) { kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories()); - sig = Factory.Named.Util.getNames(config.getKeyAlgorithms()); + if (knowHostAlg != null) { + List pSig = Factory.Named.Util.getNames(config.getKeyAlgorithms()); + if (pSig.contains(knowHostAlg)) { + ArrayList sigT = new ArrayList<>(); + sigT.add(knowHostAlg); + sigT.addAll(pSig); + sig = sigT; + } else { + sig = pSig; + } + } else { + sig = Factory.Named.Util.getNames(config.getKeyAlgorithms()); + } c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories()); c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories()); c2sComp = s2cComp = Factory.Named.Util.getNames(config.getCompressionFactories()); diff --git a/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java index 58656bf5..356af850 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/FingerprintVerifier.java @@ -74,6 +74,11 @@ public static HostKeyVerifier getInstance(String fingerprint) { public boolean verify(String h, int p, PublicKey k) { return SecurityUtils.getFingerprint(k).equals(md5); } + + @Override + public String findExistingAlgorithm(String hostname, int port) { + return null; + } }); } catch (SSHRuntimeException e) { throw e; @@ -120,6 +125,11 @@ public boolean verify(String hostname, int port, PublicKey key) { return Arrays.equals(fingerprintData, digestData); } + @Override + public String findExistingAlgorithm(String hostname, int port) { + return null; + } + @Override public String toString() { return "FingerprintVerifier{digestAlgorithm='" + digestAlgorithm + "'}"; diff --git a/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java index bcfc54d7..4d2ea54c 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/HostKeyVerifier.java @@ -35,4 +35,12 @@ public interface HostKeyVerifier { */ boolean verify(String hostname, int port, PublicKey key); + /** + * It is necessary to connect with the type of algorithm that matches an existing know_host entry. + * This will allow a match when we later verify with the negotiated key {@code HostKeyVerifier.verify} + * @param hostname remote hostname + * @param port remote port + * @return existing key type or null if not key matches hostname + */ + String findExistingAlgorithm(String hostname, int port); } diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index 8f38472a..ec1a25b4 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -90,6 +90,10 @@ private void readEntries(BufferedReader br) throws IOException { } } + private String adjustHostname(final String hostname, final int port) { + String lowerHN = hostname.toLowerCase(); + return (port != 22) ? "[" + lowerHN + "]:" + port : lowerHN; + } public File getFile() { return khFile; @@ -103,7 +107,7 @@ public boolean verify(final String hostname, final int port, final PublicKey key return false; } - final String adjustedHostname = (port != 22) ? "[" + hostname + "]:" + port : hostname; + final String adjustedHostname = adjustHostname(hostname, port); boolean foundApplicableHostEntry = false; for (KnownHostEntry e : entries) { @@ -127,6 +131,19 @@ public boolean verify(final String hostname, final int port, final PublicKey key return hostKeyUnverifiableAction(adjustedHostname, key); } + @Override + public String findExistingAlgorithm(String hostname, int port) { + final String adjustedHostname = adjustHostname(hostname, port); + for (KnownHostEntry e : entries) { + try { + if (e.appliesTo(adjustedHostname)) { + return e.getType().toString(); + } + } catch (IOException ioe) {} + } + return null; + } + protected boolean hostKeyUnverifiableAction(String hostname, PublicKey key) { return false; } diff --git a/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java index a95cc41f..bb202567 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/PromiscuousVerifier.java @@ -25,4 +25,9 @@ public boolean verify(String hostname, int port, PublicKey key) { return true; } + @Override + public String findExistingAlgorithm(String hostname, int port) { + return null; + } + }