From 4b9557e73a52d9c0967385525b431e80bad96f4f Mon Sep 17 00:00:00 2001 From: exceptionfactory Date: Mon, 20 Dec 2021 22:15:54 -0600 Subject: [PATCH] Updated SSHClient to interrupt KeepAlive Thread when disconnecting (#506) - Changed KeepAlive.setKeepAliveInterval() to avoid starting Thread - Updated SSHClient.onConnect() to start KeepAlive Thread when enabled - Updated SSHClient.disconnect() to interrupt KeepAlive Thread - Updated KeepAliveThreadTerminationTest to verify state of KeepAlive Thread --- .../java/net/schmizz/keepalive/KeepAlive.java | 32 +++----- src/main/java/net/schmizz/sshj/SSHClient.java | 6 ++ .../KeepAliveThreadTerminationTest.java | 74 +++++++++++-------- 3 files changed, 61 insertions(+), 51 deletions(-) diff --git a/src/main/java/net/schmizz/keepalive/KeepAlive.java b/src/main/java/net/schmizz/keepalive/KeepAlive.java index bfaef564c..35e97c2b2 100644 --- a/src/main/java/net/schmizz/keepalive/KeepAlive.java +++ b/src/main/java/net/schmizz/keepalive/KeepAlive.java @@ -20,6 +20,8 @@ import net.schmizz.sshj.transport.TransportException; import org.slf4j.Logger; +import java.util.concurrent.TimeUnit; + public abstract class KeepAlive extends Thread { protected final Logger log; protected final ConnectionImpl conn; @@ -33,40 +35,32 @@ protected KeepAlive(ConnectionImpl conn, String name) { setDaemon(true); } + public boolean isEnabled() { + return keepAliveInterval > 0; + } + public synchronized int getKeepAliveInterval() { return keepAliveInterval; } public synchronized void setKeepAliveInterval(int keepAliveInterval) { this.keepAliveInterval = keepAliveInterval; - if (keepAliveInterval > 0 && getState() == State.NEW) { - start(); - } - notify(); - } - - synchronized protected int getPositiveInterval() - throws InterruptedException { - while (keepAliveInterval <= 0) { - wait(); - } - return keepAliveInterval; } @Override public void run() { - log.debug("Starting {}, sending keep-alive every {} seconds", getClass().getSimpleName(), keepAliveInterval); + log.debug("{} Started with interval [{} seconds]", getClass().getSimpleName(), keepAliveInterval); try { while (!isInterrupted()) { - final int hi = getPositiveInterval(); + final int interval = getKeepAliveInterval(); if (conn.getTransport().isRunning()) { - log.debug("Sending keep-alive since {} seconds elapsed", hi); + log.debug("{} Sending after interval [{} seconds]", getClass().getSimpleName(), interval); doKeepAlive(); } - Thread.sleep(hi * 1000); + TimeUnit.SECONDS.sleep(interval); } } catch (InterruptedException e) { - // Interrupt signal may be catched when sleeping. + log.trace("{} Interrupted while sleeping", getClass().getSimpleName()); } catch (Exception e) { // If we weren't interrupted, kill the transport, then this exception was unexpected. // Else we're in shutdown-mode already, so don't forcibly kill the transport. @@ -74,9 +68,7 @@ public void run() { conn.getTransport().die(e); } } - - log.debug("Stopping {}", getClass().getSimpleName()); - + log.debug("{} Stopped", getClass().getSimpleName()); } protected abstract void doKeepAlive() throws TransportException, ConnectionException; diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index 6ecec30e6..78ad0779f 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj; +import net.schmizz.keepalive.KeepAlive; import net.schmizz.sshj.common.*; import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.ConnectionException; @@ -424,6 +425,7 @@ public void authGssApiWithMic(String username, LoginContext context, Oid support @Override public void disconnect() throws IOException { + conn.getKeepAlive().interrupt(); for (LocalPortForwarder forwarder : forwarders) { try { forwarder.close(); @@ -791,6 +793,10 @@ protected void onConnect() throws IOException { super.onConnect(); trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream()); + final KeepAlive keepAliveThread = conn.getKeepAlive(); + if (keepAliveThread.isEnabled()) { + keepAliveThread.start(); + } doKex(); } diff --git a/src/test/java/com/hierynomus/sshj/keepalive/KeepAliveThreadTerminationTest.java b/src/test/java/com/hierynomus/sshj/keepalive/KeepAliveThreadTerminationTest.java index fe9cc4f23..6bdc2423c 100644 --- a/src/test/java/com/hierynomus/sshj/keepalive/KeepAliveThreadTerminationTest.java +++ b/src/test/java/com/hierynomus/sshj/keepalive/KeepAliveThreadTerminationTest.java @@ -15,54 +15,66 @@ */ package com.hierynomus.sshj.keepalive; -import com.hierynomus.sshj.test.KnownFailingTests; -import com.hierynomus.sshj.test.SlowTests; import com.hierynomus.sshj.test.SshFixture; +import net.schmizz.keepalive.KeepAlive; import net.schmizz.keepalive.KeepAliveProvider; import net.schmizz.sshj.DefaultConfig; import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.userauth.UserAuthException; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import java.io.IOException; -import java.lang.management.ManagementFactory; -import java.lang.management.ThreadInfo; -import java.lang.management.ThreadMXBean; -import static org.junit.Assert.fail; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; public class KeepAliveThreadTerminationTest { + private static final int KEEP_ALIVE_SECONDS = 1; + + private static final long STOP_SLEEP = 1500; + @Rule public SshFixture fixture = new SshFixture(); @Test - @Category({SlowTests.class, KnownFailingTests.class}) - public void shouldCorrectlyTerminateThreadOnDisconnect() throws IOException, InterruptedException { - DefaultConfig defaultConfig = new DefaultConfig(); + public void shouldNotStartThreadOnSetKeepAliveInterval() { + final SSHClient sshClient = setupClient(); + + final KeepAlive keepAlive = sshClient.getConnection().getKeepAlive(); + assertTrue(keepAlive.isDaemon()); + assertFalse(keepAlive.isAlive()); + assertEquals(Thread.State.NEW, keepAlive.getState()); + } + + @Test + public void shouldStartThreadOnConnectAndInterruptOnDisconnect() throws IOException, InterruptedException { + final SSHClient sshClient = setupClient(); + + final KeepAlive keepAlive = sshClient.getConnection().getKeepAlive(); + assertTrue(keepAlive.isDaemon()); + assertEquals(Thread.State.NEW, keepAlive.getState()); + + fixture.connectClient(sshClient); + assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState()); + + assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials")); + + fixture.stopClient(); + Thread.sleep(STOP_SLEEP); + + assertFalse(keepAlive.isAlive()); + assertEquals(Thread.State.TERMINATED, keepAlive.getState()); + } + + private SSHClient setupClient() { + final DefaultConfig defaultConfig = new DefaultConfig(); defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE); - for (int i = 0; i < 10; i++) { - SSHClient sshClient = fixture.setupClient(defaultConfig); - fixture.connectClient(sshClient); - sshClient.getConnection().getKeepAlive().setKeepAliveInterval(1); - try { - sshClient.authPassword("bad", "credentials"); - fail("Should not auth."); - } catch (UserAuthException e) { - // OK - } - fixture.stopClient(); - Thread.sleep(2000); - } - - ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); - for (long l : threadMXBean.getAllThreadIds()) { - ThreadInfo threadInfo = threadMXBean.getThreadInfo(l); - if (threadInfo.getThreadName().equals("keep-alive") && threadInfo.getThreadState() != Thread.State.TERMINATED) { - fail("Found alive keep-alive thread in state " + threadInfo.getThreadState()); - } - } + final SSHClient sshClient = fixture.setupClient(defaultConfig); + sshClient.getConnection().getKeepAlive().setKeepAliveInterval(KEEP_ALIVE_SECONDS); + return sshClient; } }