Skip to content

Commit

Permalink
Implement OpenSSH strict key exchange extension
Browse files Browse the repository at this point in the history
  • Loading branch information
hpoettker committed Dec 20, 2023
1 parent 50c753d commit 92bdefb
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/Converter.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ long getSequenceNumber() {
return seq;
}

void resetSequenceNumber() {
seq = -1;
}

boolean isSequenceNumberAtMax() {
return seq == 0xffffffffL;
}

void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
this.cipher = cipher;
this.mac = mac;
Expand Down
39 changes: 38 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/KeyExchanger.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ private static enum Expected {

private final AtomicBoolean kexOngoing = new AtomicBoolean();

private final AtomicBoolean initialKex = new AtomicBoolean();

private final AtomicBoolean strictKex = new AtomicBoolean();

/** What we are expecting from the next packet */
private Expected expected = Expected.KEXINIT;

Expand Down Expand Up @@ -123,6 +127,19 @@ boolean isKexOngoing() {
return kexOngoing.get();
}

void markKexAsInitial() {
initialKex.set(true);
strictKex.set(false);
}

boolean isStrictKex() {
return strictKex.get();
}

boolean isInitialKex() {
return initialKex.get();
}

/**
* Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily
* after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if
Expand Down Expand Up @@ -183,7 +200,7 @@ private void sendKexInit()
throws TransportException {
log.debug("Sending SSH_MSG_KEXINIT");
List<String> knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort());
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs);
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get());
transport.write(clientProposal.getPacket());
kexInitSent.set();
}
Expand All @@ -202,6 +219,9 @@ private void sendNewKeys()
throws TransportException {
log.debug("Sending SSH_MSG_NEWKEYS");
transport.write(new SSHPacket(Message.NEWKEYS));
if (strictKex.get()) {
transport.getEncoder().resetSequenceNumber();
}
}

/**
Expand Down Expand Up @@ -234,6 +254,10 @@ private synchronized void verifyHost(PublicKey key)

private void setKexDone() {
kexOngoing.set(false);
initialKex.set(false);
if (strictKex.get()) {
transport.getDecoder().resetSequenceNumber();
}
kexInitSent.clear();
done.set();
}
Expand All @@ -242,6 +266,7 @@ private void gotKexInit(SSHPacket buf)
throws TransportException {
buf.rpos(buf.rpos() - 1);
final Proposal serverProposal = new Proposal(buf);
gotStrictKexInfo(serverProposal);
negotiatedAlgs = clientProposal.negotiate(serverProposal);
log.debug("Negotiated algorithms: {}", negotiatedAlgs);
for(AlgorithmsVerifier v: algorithmVerifiers) {
Expand All @@ -265,6 +290,18 @@ private void gotKexInit(SSHPacket buf)
}
}

private void gotStrictKexInfo(Proposal serverProposal) throws TransportException {
if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) {
strictKex.set(true);
log.debug("Enabling strict key exchange extension");
if (transport.getDecoder().getSequenceNumber() != 0) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"SSH_MSG_KEXINIT was not first package during strict key exchange"
);
}
}
}

/**
* Private method used while putting new keys into use that will resize the key used to initialize the cipher to the
* needed length.
Expand Down
9 changes: 8 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/Proposal.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ class Proposal {
private final List<String> s2cComp;
private final SSHPacket packet;

public Proposal(Config config, List<String> knownHostAlgs) {
public Proposal(Config config, List<String> knownHostAlgs, boolean initialKex) {
kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories());
if (initialKex) {
kex.add("kex-strict-c-v00@openssh.com");
}
sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs);
c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories());
c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories());
Expand Down Expand Up @@ -91,6 +94,10 @@ public List<String> getKeyExchangeAlgorithms() {
return kex;
}

public boolean isStrictKeyExchangeSupportedByServer() {
return kex.contains("kex-strict-s-v00@openssh.com");
}

public List<String> getHostKeyAlgorithms() {
return sig;
}
Expand Down
20 changes: 18 additions & 2 deletions src/main/java/net/schmizz/sshj/transport/TransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public void init(String remoteHost, int remotePort, InputStream in, OutputStream
}

ThreadNameProvider.setThreadName(reader, this);
kexer.markKexAsInitial();
reader.start();
}

Expand Down Expand Up @@ -426,7 +427,7 @@ public long write(SSHPacket payload)
assert m != Message.KEXINIT;
kexer.waitForDone();
}
} else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet
} else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet
kexer.startKex(true);

final long seq = encoder.encode(payload);
Expand Down Expand Up @@ -479,9 +480,20 @@ public void handle(Message msg, SSHPacket buf)

log.trace("Received packet {}", msg);

if (kexer.isInitialKex()) {
if (decoder.isSequenceNumberAtMax()) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"Sequence number of decoder is about to wrap during initial key exchange");
}
if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"Unexpected packet type during initial strict key exchange");
}
}

if (msg.geq(50)) { // not a transport layer packet
service.handle(msg, buf);
} else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet
} else if (isKexerPacket(msg)) {
kexer.handle(msg, buf);
} else {
switch (msg) {
Expand Down Expand Up @@ -513,6 +525,10 @@ public void handle(Message msg, SSHPacket buf)
}
}

private static boolean isKexerPacket(Message msg) {
return msg.in(20, 21) || msg.in(30, 49);
}

private void gotDebug(SSHPacket buf)
throws TransportException {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ private void performAndCheckKeyExchange() throws TransportException {
}

private SSHPacket getKexinitPacket() {
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket();
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), false).getPacket();
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
return kexinitPacket;
}
Expand Down

0 comments on commit 92bdefb

Please sign in to comment.