diff --git a/src/main/java/net/schmizz/sshj/transport/Converter.java b/src/main/java/net/schmizz/sshj/transport/Converter.java index 9d532f96..6f3431c3 100644 --- a/src/main/java/net/schmizz/sshj/transport/Converter.java +++ b/src/main/java/net/schmizz/sshj/transport/Converter.java @@ -51,6 +51,14 @@ long getSequenceNumber() { return seq; } + void resetSequenceNumber() { + seq = -1; + } + + boolean isSequenceNumberAtMax() { + return seq == 0xffffffffL; + } + void setAlgorithms(Cipher cipher, MAC mac, Compression compression) { this.cipher = cipher; this.mac = mac; diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index b8979f7b..5f0b3b1f 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -60,6 +60,10 @@ private static enum Expected { private final AtomicBoolean kexOngoing = new AtomicBoolean(); + private final AtomicBoolean initialKex = new AtomicBoolean(); + + private final AtomicBoolean strictKex = new AtomicBoolean(); + /** What we are expecting from the next packet */ private Expected expected = Expected.KEXINIT; @@ -123,6 +127,19 @@ boolean isKexOngoing() { return kexOngoing.get(); } + void markKexAsInitial() { + initialKex.set(true); + strictKex.set(false); + } + + boolean isStrictKex() { + return strictKex.get(); + } + + boolean isInitialKex() { + return initialKex.get(); + } + /** * Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily * after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if @@ -183,7 +200,7 @@ private void sendKexInit() throws TransportException { log.debug("Sending SSH_MSG_KEXINIT"); List knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort()); - clientProposal = new Proposal(transport.getConfig(), knownHostAlgs); + clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get()); transport.write(clientProposal.getPacket()); kexInitSent.set(); } @@ -202,6 +219,9 @@ private void sendNewKeys() throws TransportException { log.debug("Sending SSH_MSG_NEWKEYS"); transport.write(new SSHPacket(Message.NEWKEYS)); + if (strictKex.get()) { + transport.getEncoder().resetSequenceNumber(); + } } /** @@ -234,6 +254,10 @@ private synchronized void verifyHost(PublicKey key) private void setKexDone() { kexOngoing.set(false); + initialKex.set(false); + if (strictKex.get()) { + transport.getDecoder().resetSequenceNumber(); + } kexInitSent.clear(); done.set(); } @@ -242,6 +266,7 @@ private void gotKexInit(SSHPacket buf) throws TransportException { buf.rpos(buf.rpos() - 1); final Proposal serverProposal = new Proposal(buf); + gotStrictKexInfo(serverProposal); negotiatedAlgs = clientProposal.negotiate(serverProposal); log.debug("Negotiated algorithms: {}", negotiatedAlgs); for(AlgorithmsVerifier v: algorithmVerifiers) { @@ -265,6 +290,18 @@ private void gotKexInit(SSHPacket buf) } } + private void gotStrictKexInfo(Proposal serverProposal) throws TransportException { + if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) { + strictKex.set(true); + log.debug("Enabling strict key exchange extension"); + if (transport.getDecoder().getSequenceNumber() != 0) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "SSH_MSG_KEXINIT was not first package during strict key exchange" + ); + } + } + } + /** * Private method used while putting new keys into use that will resize the key used to initialize the cipher to the * needed length. diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index 5f5f8a1f..3a4102dd 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -37,8 +37,11 @@ class Proposal { private final List s2cComp; private final SSHPacket packet; - public Proposal(Config config, List knownHostAlgs) { + public Proposal(Config config, List knownHostAlgs, boolean initialKex) { kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories()); + if (initialKex) { + kex.add("kex-strict-c-v00@openssh.com"); + } sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs); c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories()); c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories()); @@ -91,6 +94,10 @@ public List getKeyExchangeAlgorithms() { return kex; } + public boolean isStrictKeyExchangeSupportedByServer() { + return kex.contains("kex-strict-s-v00@openssh.com"); + } + public List getHostKeyAlgorithms() { return sig; } diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 58107c5b..8b1d83f6 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -166,6 +166,7 @@ public void init(String remoteHost, int remotePort, InputStream in, OutputStream } ThreadNameProvider.setThreadName(reader, this); + kexer.markKexAsInitial(); reader.start(); } @@ -426,7 +427,7 @@ public long write(SSHPacket payload) assert m != Message.KEXINIT; kexer.waitForDone(); } - } else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet + } else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet kexer.startKex(true); final long seq = encoder.encode(payload); @@ -479,9 +480,20 @@ public void handle(Message msg, SSHPacket buf) log.trace("Received packet {}", msg); + if (kexer.isInitialKex()) { + if (decoder.isSequenceNumberAtMax()) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Sequence number of decoder is about to wrap during initial key exchange"); + } + if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Unexpected packet type during initial strict key exchange"); + } + } + if (msg.geq(50)) { // not a transport layer packet service.handle(msg, buf); - } else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet + } else if (isKexerPacket(msg)) { kexer.handle(msg, buf); } else { switch (msg) { @@ -513,6 +525,10 @@ public void handle(Message msg, SSHPacket buf) } } + private static boolean isKexerPacket(Message msg) { + return msg.in(20, 21) || msg.in(30, 49); + } + private void gotDebug(SSHPacket buf) throws TransportException { try { diff --git a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java index c1f8655a..e6160e70 100644 --- a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java +++ b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java @@ -112,7 +112,7 @@ private void performAndCheckKeyExchange() throws TransportException { } private SSHPacket getKexinitPacket() { - SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket(); + SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), false).getPacket(); kexinitPacket.rpos(kexinitPacket.rpos() + 1); return kexinitPacket; }