From e05345dec405a0333c48ff188d2d9a1be90dad53 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Mon, 2 Jan 2023 01:10:57 -0800 Subject: [PATCH] Add support for wrapping system streams in WorkRequestHandler There are often [places](https://github.com/bazelbuild/bazel/blob/ea19c17075478092eb77580e6d3825d480126d3a/src/tools/android/java/com/google/devtools/build/android/ResourceProcessorBusyBox.java#L188) where persistent workers need to swap out the standard system streams to avoid tools poisoning the worker communication streams by writing logs/exceptions to it. This pull request extracts that pattern into an optional WorkerIO wrapper can be used to swap in and out the standard streams without the added boilerplate. Closes #14201. PiperOrigin-RevId: 498983983 Change-Id: Iefb956d38a5887d9e5bbf0821551eb0efa14fce9 --- .../build/lib/worker/WorkRequestHandler.java | 164 ++++++++++++++++-- .../build/lib/worker/ExampleWorker.java | 16 +- .../lib/worker/WorkRequestHandlerTest.java | 83 ++++++++- 3 files changed, 239 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java b/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java index fcd8896e0479a0..997ea80c802ed4 100644 --- a/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java +++ b/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java @@ -18,11 +18,15 @@ import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.sun.management.OperatingSystemMXBean; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.PrintStream; import java.io.PrintWriter; import java.io.StringWriter; import java.lang.management.ManagementFactory; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.List; import java.util.Optional; @@ -317,8 +321,17 @@ public WorkRequestHandler build() { * then writing the corresponding {@link WorkResponse} to {@code out}. If there is an error * reading or writing the requests or responses, it writes an error message on {@code err} and * returns. If {@code in} reaches EOF, it also returns. + * + *

This function also wraps the system streams in a {@link WorkerIO} instance that prevents the + * underlying tool from writing to {@link System#out} or reading from {@link System#in}, which + * would corrupt the worker worker protocol. When the while loop exits, the original system + * streams will be swapped back into {@link System}. */ public void processRequests() throws IOException { + // Wrap the system streams into a WorkerIO instance to prevent unexpected reads and writes on + // stdin/stdout. + WorkerIO workerIO = WorkerIO.capture(); + try { while (!shutdownWorker.get()) { WorkRequest request = messageProcessor.readWorkRequest(); @@ -328,31 +341,39 @@ public void processRequests() throws IOException { if (request.getCancel()) { respondToCancelRequest(request); } else { - startResponseThread(request); + startResponseThread(workerIO, request); } } } catch (IOException e) { stderr.println("Error reading next WorkRequest: " + e); e.printStackTrace(stderr); - } - // TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response, - // but also try to kill stuck threads. For now, we just interrupt the remaining threads. - // We considered doing System.exit here, but that is hard to test and would deny the callers - // of this method a chance to clean up. Instead, we initiate the cleanup of our resources here - // and the caller can decide whether to wait for an orderly shutdown or now. - for (RequestInfo ri : activeRequests.values()) { - if (ri.thread.isAlive()) { - try { - ri.thread.interrupt(); - } catch (RuntimeException e) { - // If we can't interrupt, we can't do much else. + } finally { + // TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response, + // but also try to kill stuck threads. For now, we just interrupt the remaining threads. + // We considered doing System.exit here, but that is hard to test and would deny the callers + // of this method a chance to clean up. Instead, we initiate the cleanup of our resources here + // and the caller can decide whether to wait for an orderly shutdown or now. + for (RequestInfo ri : activeRequests.values()) { + if (ri.thread.isAlive()) { + try { + ri.thread.interrupt(); + } catch (RuntimeException e) { + // If we can't interrupt, we can't do much else. + } } } + + try { + // Unwrap the system streams placing the original streams back + workerIO.close(); + } catch (Exception e) { + stderr.println(e.getMessage()); + } } } /** Starts a thread for the given request. */ - void startResponseThread(WorkRequest request) { + void startResponseThread(WorkerIO workerIO, WorkRequest request) { Thread currentThread = Thread.currentThread(); String threadName = request.getRequestId() > 0 @@ -381,7 +402,7 @@ void startResponseThread(WorkRequest request) { return; } try { - respondToRequest(request, requestInfo); + respondToRequest(workerIO, request, requestInfo); } catch (IOException e) { // IOExceptions here means a problem talking to the server, so we must shut down. if (!shutdownWorker.compareAndSet(false, true)) { @@ -419,7 +440,8 @@ void startResponseThread(WorkRequest request) { * #callback} are reported with exit code 1. */ @VisibleForTesting - void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOException { + void respondToRequest(WorkerIO workerIO, WorkRequest request, RequestInfo requestInfo) + throws IOException { int exitCode; StringWriter sw = new StringWriter(); try (PrintWriter pw = new PrintWriter(sw)) { @@ -431,6 +453,16 @@ void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOExc e.printStackTrace(pw); exitCode = 1; } + + try { + // Read out the captured string for the final WorkResponse output + String captured = workerIO.readCapturedAsUtf8String().trim(); + if (!captured.isEmpty()) { + pw.write(captured); + } + } catch (IOException e) { + stderr.println(e.getMessage()); + } } Optional optBuilder = requestInfo.takeBuilder(); if (optBuilder.isPresent()) { @@ -541,4 +573,104 @@ private void maybePerformGc() { } } } + + /** + * Class that wraps the standard {@link System#in}, {@link System#out}, and {@link System#err} + * with our own ByteArrayOutputStream that allows {@link WorkRequestHandler} to safely capture + * outputs that can't be directly captured by the PrintStream associated with the work request. + * + *

This is most useful when integrating JVM tools that write exceptions and logs directly to + * {@link System#out} and {@link System#err}, which would corrupt the persistent worker protocol. + * We also redirect {@link System#in}, just in case a tool should attempt to read it. + * + *

WorkerIO implements {@link AutoCloseable} and will swap the original streams back into + * {@link System} once close has been called. + */ + public static class WorkerIO implements AutoCloseable { + private final InputStream originalInputStream; + private final PrintStream originalOutputStream; + private final PrintStream originalErrorStream; + private final ByteArrayOutputStream capturedStream; + private final AutoCloseable restore; + + /** + * Creates a new {@link WorkerIO} that allows {@link WorkRequestHandler} to capture standard + * output and error streams that can't be directly captured by the PrintStream associated with + * the work request. + */ + @VisibleForTesting + WorkerIO( + InputStream originalInputStream, + PrintStream originalOutputStream, + PrintStream originalErrorStream, + ByteArrayOutputStream capturedStream, + AutoCloseable restore) { + this.originalInputStream = originalInputStream; + this.originalOutputStream = originalOutputStream; + this.originalErrorStream = originalErrorStream; + this.capturedStream = capturedStream; + this.restore = restore; + } + + /** Wraps the standard System streams and WorkerIO instance */ + public static WorkerIO capture() { + // Save the original streams + InputStream originalInputStream = System.in; + PrintStream originalOutputStream = System.out; + PrintStream originalErrorStream = System.err; + + // Replace the original streams with our own instances + ByteArrayOutputStream capturedStream = new ByteArrayOutputStream(); + PrintStream outputBuffer = new PrintStream(capturedStream, true); + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(new byte[0]); + System.setIn(byteArrayInputStream); + System.setOut(outputBuffer); + System.setErr(outputBuffer); + + return new WorkerIO( + originalInputStream, + originalOutputStream, + originalErrorStream, + capturedStream, + () -> { + System.setIn(originalInputStream); + System.setOut(originalOutputStream); + System.setErr(originalErrorStream); + outputBuffer.close(); + byteArrayInputStream.close(); + }); + } + + /** Returns the original input stream most commonly provided by {@link System#in} */ + @VisibleForTesting + InputStream getOriginalInputStream() { + return originalInputStream; + } + + /** Returns the original output stream most commonly provided by {@link System#out} */ + @VisibleForTesting + PrintStream getOriginalOutputStream() { + return originalOutputStream; + } + + /** Returns the original error stream most commonly provided by {@link System#err} */ + @VisibleForTesting + PrintStream getOriginalErrorStream() { + return originalErrorStream; + } + + /** Returns the captured outputs as a UTF-8 string */ + @VisibleForTesting + String readCapturedAsUtf8String() throws IOException { + capturedStream.flush(); + String captureOutput = capturedStream.toString(StandardCharsets.UTF_8); + capturedStream.reset(); + return captureOutput; + } + + @Override + public void close() throws Exception { + restore.close(); + } + } } diff --git a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java index c58c732d87113e..2edd557cfc8f5c 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java +++ b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java @@ -54,7 +54,7 @@ public final class ExampleWorker { static final Pattern FLAG_FILE_PATTERN = Pattern.compile("(?:@|--?flagfile=)(.+)"); // A UUID that uniquely identifies this running worker process. - static final UUID workerUuid = UUID.randomUUID(); + static final UUID WORKER_UUID = UUID.randomUUID(); // A counter that increases with each work unit processed. static int workUnitCounter = 1; @@ -83,6 +83,9 @@ private static class InterruptableWorkRequestHandler extends WorkRequestHandler @Override @SuppressWarnings("SystemExitOutsideMain") public void processRequests() throws IOException { + ByteArrayOutputStream captured = new ByteArrayOutputStream(); + WorkerIO workerIO = new WorkerIO(System.in, System.out, System.err, captured, captured); + while (true) { WorkRequest request = messageProcessor.readWorkRequest(); if (request == null) { @@ -100,12 +103,19 @@ public void processRequests() throws IOException { if (request.getCancel()) { respondToCancelRequest(request); } else { - startResponseThread(request); + startResponseThread(workerIO, request); } if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) { System.exit(0); } } + + try { + // Unwrap the system streams placing the original streams back + workerIO.close(); + } catch (Exception e) { + workerIO.getOriginalErrorStream().println(e.getMessage()); + } } } @@ -241,7 +251,7 @@ private static void parseOptionsAndLog(List args) throws Exception { List outputs = new ArrayList<>(); if (options.writeUUID) { - outputs.add("UUID " + workerUuid); + outputs.add("UUID " + WORKER_UUID); } if (options.writeCounter) { diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java index 5e279f1bbed4e3..9c1ea29fab7c8a 100644 --- a/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java +++ b/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java @@ -25,6 +25,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.InterruptedIOException; import java.io.PipedInputStream; import java.io.PipedOutputStream; @@ -34,6 +35,7 @@ import java.util.List; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -44,11 +46,18 @@ @RunWith(JUnit4.class) public class WorkRequestHandlerTest { + private final WorkRequestHandler.WorkerIO testWorkerIO = createTestWorkerIO(); + @Before public void init() { MockitoAnnotations.initMocks(this); } + @After + public void after() throws Exception { + testWorkerIO.close(); + } + @Test public void testNormalWorkRequest() throws IOException { ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -60,7 +69,7 @@ public void testNormalWorkRequest() throws IOException { List args = Arrays.asList("--sources", "A.java"); WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build(); - handler.respondToRequest(request, new RequestInfo(null)); + handler.respondToRequest(testWorkerIO, request, new RequestInfo(null)); WorkResponse response = WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray())); @@ -80,7 +89,7 @@ public void testMultiplexWorkRequest() throws IOException { List args = Arrays.asList("--sources", "A.java"); WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).setRequestId(42).build(); - handler.respondToRequest(request, new RequestInfo(null)); + handler.respondToRequest(testWorkerIO, request, new RequestInfo(null)); WorkResponse response = WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray())); @@ -236,7 +245,7 @@ public void testOutput() throws IOException { List args = Arrays.asList("--sources", "A.java"); WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build(); - handler.respondToRequest(request, new RequestInfo(null)); + handler.respondToRequest(testWorkerIO, request, new RequestInfo(null)); WorkResponse response = WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray())); @@ -258,7 +267,7 @@ public void testException() throws IOException { List args = Arrays.asList("--sources", "A.java"); WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build(); - handler.respondToRequest(request, new RequestInfo(null)); + handler.respondToRequest(testWorkerIO, request, new RequestInfo(null)); WorkResponse response = WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray())); @@ -518,7 +527,7 @@ public void testWorkRequestHandler_withWorkRequestCallback() throws IOException List args = Arrays.asList("--sources", "B.java"); WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build(); - handler.respondToRequest(request, new RequestInfo(null)); + handler.respondToRequest(testWorkerIO, request, new RequestInfo(null)); WorkResponse response = WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray())); @@ -553,6 +562,70 @@ private void runRequestHandlerThread( .start(); } + @Test + public void testWorkerIO_doesWrapSystemStreams() throws Exception { + // Save the original streams + InputStream originalInputStream = System.in; + PrintStream originalOutputStream = System.out; + PrintStream originalErrorStream = System.err; + + // Swap in the test streams to assert against + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(new byte[0]); + System.setIn(byteArrayInputStream); + PrintStream outputBuffer = new PrintStream(new ByteArrayOutputStream(), true); + System.setOut(outputBuffer); + System.setErr(outputBuffer); + + try (outputBuffer; + byteArrayInputStream; + WorkRequestHandler.WorkerIO io = WorkRequestHandler.WorkerIO.capture()) { + // Assert that the WorkerIO returns the correct wrapped streams and the new System instance + // has been swapped out with the wrapped one + assertThat(io.getOriginalInputStream()).isSameInstanceAs(byteArrayInputStream); + assertThat(System.in).isNotSameInstanceAs(byteArrayInputStream); + + assertThat(io.getOriginalOutputStream()).isSameInstanceAs(outputBuffer); + assertThat(System.out).isNotSameInstanceAs(outputBuffer); + + assertThat(io.getOriginalErrorStream()).isSameInstanceAs(outputBuffer); + assertThat(System.err).isNotSameInstanceAs(outputBuffer); + } finally { + // Swap back in the original streams + System.setIn(originalInputStream); + System.setOut(originalOutputStream); + System.setErr(originalErrorStream); + } + } + + @Test + public void testWorkerIO_doesCaptureStandardOutAndErrorStreams() throws Exception { + try (WorkRequestHandler.WorkerIO io = WorkRequestHandler.WorkerIO.capture()) { + // Assert that nothing has been captured in the new instance + assertThat(io.readCapturedAsUtf8String()).isEmpty(); + + // Assert that the standard out/error stream redirect to our own streams + System.out.print("This is a standard out message!"); + System.err.print("This is a standard error message!"); + assertThat(io.readCapturedAsUtf8String()) + .isEqualTo("This is a standard out message!This is a standard error message!"); + + // Assert that readCapturedAsUtf8String calls reset on the captured stream after a read + assertThat(io.readCapturedAsUtf8String()).isEmpty(); + + System.out.print("out 1"); + System.err.print("err 1"); + System.out.print("out 2"); + System.err.print("err 2"); + assertThat(io.readCapturedAsUtf8String()).isEqualTo("out 1err 1out 2err 2"); + assertThat(io.readCapturedAsUtf8String()).isEmpty(); + } + } + + private WorkRequestHandler.WorkerIO createTestWorkerIO() { + ByteArrayOutputStream captured = new ByteArrayOutputStream(); + return new WorkRequestHandler.WorkerIO(System.in, System.out, System.err, captured, captured); + } + /** A wrapper around a WorkerMessageProcessor that can be stopped by calling {@code #stop()}. */ private static class StoppableWorkerMessageProcessor implements WorkerMessageProcessor { private final WorkerMessageProcessor delegate;