Skip to content

Commit

Permalink
refactor: extracted PortForwarderWebsocketListener and added tests
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Nuri <marc@marcnuri.com>
  • Loading branch information
manusa committed May 19, 2022
1 parent aa89c52 commit bb20821
Show file tree
Hide file tree
Showing 3 changed files with 452 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.fabric8.kubernetes.client.http.WebSocket;
import io.fabric8.kubernetes.client.utils.URLUtils;
import io.fabric8.kubernetes.client.utils.Utils;
import io.fabric8.kubernetes.client.utils.internal.SerialExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -30,15 +29,12 @@
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -170,157 +166,7 @@ public Collection<Throwable> getServerThrowables() {

@Override
public PortForward forward(URL resourceBaseUrl, int port, final ReadableByteChannel in, final WritableByteChannel out) {
final AtomicBoolean alive = new AtomicBoolean(true);
final AtomicBoolean errorOccurred = new AtomicBoolean(false);
final Collection<Throwable> clientThrowables = Collections.synchronizedCollection(new ArrayList<>());
final Collection<Throwable> serverThrowables = Collections.synchronizedCollection(new ArrayList<>());
final String logPrefix = "FWD";

WebSocket.Listener listener = new WebSocket.Listener() {
private int messagesRead = 0;

private final ExecutorService pumperService = Executors.newSingleThreadExecutor();
private final SerialExecutor serialExecutor = new SerialExecutor(Utils.getCommonExecutorSerive());

@Override
public void onOpen(final WebSocket webSocket) {
LOG.debug("{}: onOpen", logPrefix);

if (in != null) {
pumperService.execute(() -> {
ByteBuffer buffer = ByteBuffer.allocate(4096);
int read;
try {
do {
buffer.clear();
buffer.put((byte) 0); // channel byte
read = in.read(buffer);
if (read > 0) {
buffer.flip();
webSocket.send(buffer);
} else if (read == 0) {
// in is non-blocking, prevent a busy loop
Thread.sleep(50);
}
} while (alive.get() && read >= 0);

} catch (IOException | InterruptedException e) {
if (alive.get()) {
clientThrowables.add(e);
LOG.error("Error while writing client data");
closeBothWays(webSocket, 1001, "Client error");
}
}
});
}
}

@Override
public void onMessage(WebSocket webSocket, String text) {
LOG.debug("{}: onMessage(String)", logPrefix);
onMessage(webSocket, ByteBuffer.wrap(text.getBytes(StandardCharsets.UTF_8)));
}

@Override
public void onMessage(WebSocket webSocket, ByteBuffer buffer) {
messagesRead++;
if (messagesRead <= 2) {
// skip the first two messages, containing the ports used internally
webSocket.request();
return;
}

if (!buffer.hasRemaining()) {
errorOccurred.set(true);
LOG.error("Received an empty message");
closeBothWays(webSocket, 1002, "Protocol error");
}

byte channel = buffer.get();
if (channel < 0 || channel > 1) {
errorOccurred.set(true);
LOG.error("Received a wrong channel from the remote socket: {}", channel);
closeBothWays(webSocket, 1002, "Protocol error");
} else if (channel == 1) {
// Error channel
errorOccurred.set(true);
LOG.error("Received an error from the remote socket");
closeForwarder();
} else {
// Data
if (out != null) {
serialExecutor.execute(() -> {
try {
while (buffer.hasRemaining()) {
int written = out.write(buffer); // channel byte already skipped
if (written == 0) {
// out is non-blocking, prevent a busy loop
Thread.sleep(50);
}
}
webSocket.request();
} catch (IOException | InterruptedException e) {
if (alive.get()) {
clientThrowables.add(e);
LOG.error("Error while forwarding data to the client", e);
closeBothWays(webSocket, 1002, "Protocol error");
}
}
});
}
}
}

@Override
public void onClose(WebSocket webSocket, int code, String reason) {
LOG.debug("{}: onClose. Code={}, Reason={}", logPrefix, code, reason);
if (alive.get()) {
closeForwarder();
}
}

@Override
public void onError(WebSocket webSocket, Throwable t) {
LOG.debug("{}: onFailure", logPrefix);
if (alive.get()) {
serverThrowables.add(t);
LOG.error("{}: Throwable received from websocket", logPrefix, t);
closeForwarder();
}
}

private void closeBothWays(WebSocket webSocket, int code, String message) {
LOG.debug("{}: Closing with code {} and reason: {}", logPrefix, code, message);
alive.set(false);
try {
webSocket.sendClose(code, message);
} catch (Exception e) {
serverThrowables.add(e);
LOG.error("Error while closing the websocket", e);
}
closeForwarder();
}

private void closeForwarder() {
alive.set(false);
if (in != null) {
try {
in.close();
} catch (IOException e) {
LOG.error("{}: Error while closing the client input channel", logPrefix, e);
}
}
if (out != null && out != in) {
try {
out.close();
} catch (IOException e) {
LOG.error("{}: Error while closing the client output channel", logPrefix, e);
}
}
pumperService.shutdownNow();
serialExecutor.shutdownNow();
}
};
final PortForwarderWebsocketListener listener = new PortForwarderWebsocketListener(in, out);
CompletableFuture<WebSocket> socket = client
.newWebSocketBuilder()
.uri(URI.create(URLUtils.join(resourceBaseUrl.toString(), "portforward?ports=" + port)))
Expand All @@ -334,7 +180,7 @@ private void closeForwarder() {

return new PortForward() {
@Override
public void close() throws IOException {
public void close() {
socket.cancel(true);
socket.whenComplete((w, t) -> {
if (w != null) {
Expand All @@ -345,22 +191,22 @@ public void close() throws IOException {

@Override
public boolean isAlive() {
return alive.get();
return listener.isAlive();
}

@Override
public boolean errorOccurred() {
return errorOccurred.get() || !clientThrowables.isEmpty() || !serverThrowables.isEmpty();
return listener.errorOccurred();
}

@Override
public Collection<Throwable> getClientThrowables() {
return clientThrowables;
return listener.getClientThrowables();
}

@Override
public Collection<Throwable> getServerThrowables() {
return serverThrowables;
return listener.getServerThrowables();
}
};
}
Expand Down
Loading

0 comments on commit bb20821

Please sign in to comment.