Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for wrapping system streams in WorkRequestHandler #17583

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.sun.management.OperatingSystemMXBean;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.management.ManagementFactory;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -317,8 +321,17 @@ public WorkRequestHandler build() {
* then writing the corresponding {@link WorkResponse} to {@code out}. If there is an error
* reading or writing the requests or responses, it writes an error message on {@code err} and
* returns. If {@code in} reaches EOF, it also returns.
*
* <p>This function also wraps the system streams in a {@link WorkerIO} instance that prevents the
* underlying tool from writing to {@link System#out} or reading from {@link System#in}, which
* would corrupt the worker worker protocol. When the while loop exits, the original system
* streams will be swapped back into {@link System}.
*/
public void processRequests() throws IOException {
// Wrap the system streams into a WorkerIO instance to prevent unexpected reads and writes on
// stdin/stdout.
WorkerIO workerIO = WorkerIO.capture();

try {
while (!shutdownWorker.get()) {
WorkRequest request = messageProcessor.readWorkRequest();
Expand All @@ -328,31 +341,39 @@ public void processRequests() throws IOException {
if (request.getCancel()) {
respondToCancelRequest(request);
} else {
startResponseThread(request);
startResponseThread(workerIO, request);
}
}
} catch (IOException e) {
stderr.println("Error reading next WorkRequest: " + e);
e.printStackTrace(stderr);
}
// TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response,
// but also try to kill stuck threads. For now, we just interrupt the remaining threads.
// We considered doing System.exit here, but that is hard to test and would deny the callers
// of this method a chance to clean up. Instead, we initiate the cleanup of our resources here
// and the caller can decide whether to wait for an orderly shutdown or now.
for (RequestInfo ri : activeRequests.values()) {
if (ri.thread.isAlive()) {
try {
ri.thread.interrupt();
} catch (RuntimeException e) {
// If we can't interrupt, we can't do much else.
} finally {
// TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response,
// but also try to kill stuck threads. For now, we just interrupt the remaining threads.
// We considered doing System.exit here, but that is hard to test and would deny the callers
// of this method a chance to clean up. Instead, we initiate the cleanup of our resources here
// and the caller can decide whether to wait for an orderly shutdown or now.
for (RequestInfo ri : activeRequests.values()) {
if (ri.thread.isAlive()) {
try {
ri.thread.interrupt();
} catch (RuntimeException e) {
// If we can't interrupt, we can't do much else.
}
}
}

try {
// Unwrap the system streams placing the original streams back
workerIO.close();
} catch (Exception e) {
stderr.println(e.getMessage());
}
}
}

/** Starts a thread for the given request. */
void startResponseThread(WorkRequest request) {
void startResponseThread(WorkerIO workerIO, WorkRequest request) {
Thread currentThread = Thread.currentThread();
String threadName =
request.getRequestId() > 0
Expand Down Expand Up @@ -381,7 +402,7 @@ void startResponseThread(WorkRequest request) {
return;
}
try {
respondToRequest(request, requestInfo);
respondToRequest(workerIO, request, requestInfo);
} catch (IOException e) {
// IOExceptions here means a problem talking to the server, so we must shut down.
if (!shutdownWorker.compareAndSet(false, true)) {
Expand Down Expand Up @@ -419,7 +440,8 @@ void startResponseThread(WorkRequest request) {
* #callback} are reported with exit code 1.
*/
@VisibleForTesting
void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOException {
void respondToRequest(WorkerIO workerIO, WorkRequest request, RequestInfo requestInfo)
throws IOException {
int exitCode;
StringWriter sw = new StringWriter();
try (PrintWriter pw = new PrintWriter(sw)) {
Expand All @@ -431,6 +453,16 @@ void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOExc
e.printStackTrace(pw);
exitCode = 1;
}

try {
// Read out the captured string for the final WorkResponse output
String captured = workerIO.readCapturedAsUtf8String().trim();
if (!captured.isEmpty()) {
pw.write(captured);
}
} catch (IOException e) {
stderr.println(e.getMessage());
}
}
Optional<WorkResponse.Builder> optBuilder = requestInfo.takeBuilder();
if (optBuilder.isPresent()) {
Expand Down Expand Up @@ -541,4 +573,104 @@ private void maybePerformGc() {
}
}
}

/**
* Class that wraps the standard {@link System#in}, {@link System#out}, and {@link System#err}
* with our own ByteArrayOutputStream that allows {@link WorkRequestHandler} to safely capture
* outputs that can't be directly captured by the PrintStream associated with the work request.
*
* <p>This is most useful when integrating JVM tools that write exceptions and logs directly to
* {@link System#out} and {@link System#err}, which would corrupt the persistent worker protocol.
* We also redirect {@link System#in}, just in case a tool should attempt to read it.
*
* <p>WorkerIO implements {@link AutoCloseable} and will swap the original streams back into
* {@link System} once close has been called.
*/
public static class WorkerIO implements AutoCloseable {
private final InputStream originalInputStream;
private final PrintStream originalOutputStream;
private final PrintStream originalErrorStream;
private final ByteArrayOutputStream capturedStream;
private final AutoCloseable restore;

/**
* Creates a new {@link WorkerIO} that allows {@link WorkRequestHandler} to capture standard
* output and error streams that can't be directly captured by the PrintStream associated with
* the work request.
*/
@VisibleForTesting
WorkerIO(
InputStream originalInputStream,
PrintStream originalOutputStream,
PrintStream originalErrorStream,
ByteArrayOutputStream capturedStream,
AutoCloseable restore) {
this.originalInputStream = originalInputStream;
this.originalOutputStream = originalOutputStream;
this.originalErrorStream = originalErrorStream;
this.capturedStream = capturedStream;
this.restore = restore;
}

/** Wraps the standard System streams and WorkerIO instance */
public static WorkerIO capture() {
// Save the original streams
InputStream originalInputStream = System.in;
PrintStream originalOutputStream = System.out;
PrintStream originalErrorStream = System.err;

// Replace the original streams with our own instances
ByteArrayOutputStream capturedStream = new ByteArrayOutputStream();
PrintStream outputBuffer = new PrintStream(capturedStream, true);
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(new byte[0]);
System.setIn(byteArrayInputStream);
System.setOut(outputBuffer);
System.setErr(outputBuffer);

return new WorkerIO(
originalInputStream,
originalOutputStream,
originalErrorStream,
capturedStream,
() -> {
System.setIn(originalInputStream);
System.setOut(originalOutputStream);
System.setErr(originalErrorStream);
outputBuffer.close();
byteArrayInputStream.close();
});
}

/** Returns the original input stream most commonly provided by {@link System#in} */
@VisibleForTesting
InputStream getOriginalInputStream() {
return originalInputStream;
}

/** Returns the original output stream most commonly provided by {@link System#out} */
@VisibleForTesting
PrintStream getOriginalOutputStream() {
return originalOutputStream;
}

/** Returns the original error stream most commonly provided by {@link System#err} */
@VisibleForTesting
PrintStream getOriginalErrorStream() {
return originalErrorStream;
}

/** Returns the captured outputs as a UTF-8 string */
@VisibleForTesting
String readCapturedAsUtf8String() throws IOException {
capturedStream.flush();
String captureOutput = capturedStream.toString(StandardCharsets.UTF_8);
capturedStream.reset();
return captureOutput;
}

@Override
public void close() throws Exception {
restore.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public final class ExampleWorker {
static final Pattern FLAG_FILE_PATTERN = Pattern.compile("(?:@|--?flagfile=)(.+)");

// A UUID that uniquely identifies this running worker process.
static final UUID workerUuid = UUID.randomUUID();
static final UUID WORKER_UUID = UUID.randomUUID();

// A counter that increases with each work unit processed.
static int workUnitCounter = 1;
Expand Down Expand Up @@ -83,6 +83,9 @@ private static class InterruptableWorkRequestHandler extends WorkRequestHandler
@Override
@SuppressWarnings("SystemExitOutsideMain")
public void processRequests() throws IOException {
ByteArrayOutputStream captured = new ByteArrayOutputStream();
WorkerIO workerIO = new WorkerIO(System.in, System.out, System.err, captured, captured);

while (true) {
WorkRequest request = messageProcessor.readWorkRequest();
if (request == null) {
Expand All @@ -100,12 +103,19 @@ public void processRequests() throws IOException {
if (request.getCancel()) {
respondToCancelRequest(request);
} else {
startResponseThread(request);
startResponseThread(workerIO, request);
}
if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) {
System.exit(0);
}
}

try {
// Unwrap the system streams placing the original streams back
workerIO.close();
} catch (Exception e) {
workerIO.getOriginalErrorStream().println(e.getMessage());
}
}
}

Expand Down Expand Up @@ -241,7 +251,7 @@ private static void parseOptionsAndLog(List<String> args) throws Exception {
List<String> outputs = new ArrayList<>();

if (options.writeUUID) {
outputs.add("UUID " + workerUuid);
outputs.add("UUID " + WORKER_UUID);
}

if (options.writeCounter) {
Expand Down
Loading