diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java index a9e7246bd132fc..f726e85d2b4d55 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java +++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java @@ -19,8 +19,6 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import com.google.common.flogger.GoogleLogger; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.devtools.build.lib.actions.ActionInput; @@ -34,17 +32,17 @@ import com.google.devtools.build.lib.profiler.SilentCloseable; import com.google.devtools.build.lib.remote.common.CacheNotFoundException; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.util.AsyncTaskCache; import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.remote.util.RxFutures; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.devtools.build.lib.remote.util.Utils; import com.google.devtools.build.lib.sandbox.SandboxHelpers; import com.google.devtools.build.lib.vfs.Path; +import io.reactivex.rxjava3.core.Completable; import java.io.IOException; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; -import java.util.Set; -import javax.annotation.concurrent.GuardedBy; /** * Stages output files that are stored remotely to the local filesystem. @@ -55,17 +53,10 @@ class RemoteActionInputFetcher implements ActionInputPrefetcher { private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + private final AsyncTaskCache.NoResult downloadCache = AsyncTaskCache.NoResult.create(); private final Object lock = new Object(); - /** Set of successfully downloaded output files. */ - @GuardedBy("lock") - private final Set downloadedPaths = new HashSet<>(); - - @VisibleForTesting - @GuardedBy("lock") - final Map> downloadsInProgress = new HashMap<>(); - private final String buildRequestId; private final String commandId; private final RemoteCache remoteCache; @@ -110,11 +101,8 @@ public void prefetchFiles( Path path = execRoot.getRelative(input.getExecPath()); synchronized (lock) { - if (downloadedPaths.contains(path)) { - continue; - } - ListenableFuture download = downloadFileAsync(path, metadata); - downloadsToWaitFor.putIfAbsent(path, download); + downloadsToWaitFor.computeIfAbsent( + path, key -> RxFutures.toListenableFuture(downloadFileAsync(path, metadata))); } } } @@ -143,65 +131,59 @@ public void prefetchFiles( } ImmutableSet downloadedFiles() { - synchronized (lock) { - return ImmutableSet.copyOf(downloadedPaths); - } + return downloadCache.getFinishedTasks(); + } + + ImmutableSet downloadsInProgress() { + return downloadCache.getInProgressTasks(); + } + + @VisibleForTesting + AsyncTaskCache.NoResult getDownloadCache() { + return downloadCache; } void downloadFile(Path path, FileArtifactValue metadata) throws IOException, InterruptedException { - Utils.getFromFuture(downloadFileAsync(path, metadata)); + Utils.getFromFuture(RxFutures.toListenableFuture(downloadFileAsync(path, metadata))); } - private ListenableFuture downloadFileAsync(Path path, FileArtifactValue metadata) - throws IOException { - synchronized (lock) { - if (downloadedPaths.contains(path)) { - return Futures.immediateFuture(null); - } + private Completable downloadFileAsync(Path path, FileArtifactValue metadata) { + Completable download = + RxFutures.toCompletable( + () -> { + RequestMetadata requestMetadata = + TracingMetadataUtils.buildMetadata( + buildRequestId, commandId, metadata.getActionId()); + RemoteActionExecutionContext context = + RemoteActionExecutionContext.create(requestMetadata); + + Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize()); + + return remoteCache.downloadFile(context, path, digest); + }, + MoreExecutors.directExecutor()) + .doOnComplete(() -> finalizeDownload(path)) + .doOnError(error -> deletePartialDownload(path)) + .doOnDispose(() -> deletePartialDownload(path)); + + return downloadCache.executeIfNot(path, download); + } - ListenableFuture download = downloadsInProgress.get(path); - if (download == null) { - RequestMetadata requestMetadata = - TracingMetadataUtils.buildMetadata(buildRequestId, commandId, metadata.getActionId()); - RemoteActionExecutionContext context = RemoteActionExecutionContext.create(requestMetadata); - - Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize()); - download = remoteCache.downloadFile(context, path, digest); - downloadsInProgress.put(path, download); - Futures.addCallback( - download, - new FutureCallback() { - @Override - public void onSuccess(Void v) { - synchronized (lock) { - downloadsInProgress.remove(path); - downloadedPaths.add(path); - } - - try { - path.chmod(0755); - } catch (IOException e) { - logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path); - } - } - - @Override - public void onFailure(Throwable t) { - synchronized (lock) { - downloadsInProgress.remove(path); - } - try { - path.delete(); - } catch (IOException e) { - logger.atWarning().withCause(e).log( - "Failed to delete output file after incomplete download: %s", path); - } - } - }, - MoreExecutors.directExecutor()); - } - return download; + private void finalizeDownload(Path path) { + try { + path.chmod(0755); + } catch (IOException e) { + logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path); + } + } + + private void deletePartialDownload(Path path) { + try { + path.delete(); + } catch (IOException e) { + logger.atWarning().withCause(e).log( + "Failed to delete output file after incomplete download: %s", path); } } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index de26b8db8b2107..7005364f284b8c 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -13,15 +13,21 @@ // limitations under the License. package com.google.devtools.build.lib.remote.util; -import com.google.common.base.Preconditions; +import static com.google.common.base.Preconditions.checkState; + import com.google.common.collect.ImmutableSet; +import io.reactivex.rxjava3.annotations.NonNull; import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Observable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleObserver; +import io.reactivex.rxjava3.disposables.Disposable; +import io.reactivex.rxjava3.subjects.AsyncSubject; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -42,11 +48,13 @@ */ @ThreadSafe public final class AsyncTaskCache { - @GuardedBy("this") + private final Object lock = new Object(); + + @GuardedBy("lock") private final Map finished; - @GuardedBy("this") - private final Map> inProgress; + @GuardedBy("lock") + private final Map inProgress; public static AsyncTaskCache create() { return new AsyncTaskCache<>(); @@ -59,14 +67,14 @@ private AsyncTaskCache() { /** Returns a set of keys for tasks which is finished. */ public ImmutableSet getFinishedTasks() { - synchronized (this) { + synchronized (lock) { return ImmutableSet.copyOf(finished.keySet()); } } /** Returns a set of keys for tasks which is still executing. */ public ImmutableSet getInProgressTasks() { - synchronized (this) { + synchronized (lock) { return ImmutableSet.copyOf(inProgress.keySet()); } } @@ -82,6 +90,65 @@ public Single executeIfNot(KeyT key, Single task) { return execute(key, task, false); } + private class Execution { + private final Single task; + private final AsyncSubject asyncSubject = AsyncSubject.create(); + private final AtomicInteger subscriberCount = new AtomicInteger(0); + private final AtomicReference taskDisposable = new AtomicReference<>(null); + + Execution(Single task) { + this.task = task; + } + + public Single start() { + if (taskDisposable.get() == null) { + task.subscribe( + new SingleObserver() { + @Override + public void onSubscribe(@NonNull Disposable d) { + taskDisposable.compareAndSet(null, d); + } + + @Override + public void onSuccess(@NonNull ValueT value) { + asyncSubject.onNext(value); + asyncSubject.onComplete(); + } + + @Override + public void onError(@NonNull Throwable e) { + asyncSubject.onError(e); + } + }); + } + + return Single.fromObservable(asyncSubject) + .doOnSubscribe(d -> subscriberCount.incrementAndGet()) + .doOnDispose( + () -> { + if (subscriberCount.decrementAndGet() == 0) { + Disposable d = taskDisposable.get(); + if (d != null) { + d.dispose(); + } + asyncSubject.onError(new CancellationException("disposed")); + } + }); + } + } + + /** Returns count of subscribers for a task. */ + public int getSubscriberCount(KeyT key) { + synchronized (lock) { + Execution execution = inProgress.get(key); + if (execution != null) { + return execution.subscriberCount.get(); + } + } + + return 0; + } + /** * Executes a task. * @@ -93,50 +160,47 @@ public Single executeIfNot(KeyT key, Single task) { public Single execute(KeyT key, Single task, boolean force) { return Single.defer( () -> { - synchronized (this) { + synchronized (lock) { if (!force && finished.containsKey(key)) { return Single.just(finished.get(key)); } finished.remove(key); - Observable execution = + Execution execution = inProgress.computeIfAbsent( key, missingKey -> { AtomicInteger subscribeTimes = new AtomicInteger(0); - return Single.defer( - () -> { - int times = subscribeTimes.incrementAndGet(); - Preconditions.checkState( - times == 1, "Subscribed more than once to the task"); - return task; - }) - .doOnSuccess( - value -> { - synchronized (this) { - finished.put(key, value); - inProgress.remove(key); - } - }) - .doOnError( - error -> { - synchronized (this) { - inProgress.remove(key); - } - }) - .doOnDispose( - () -> { - synchronized (this) { - inProgress.remove(key); - } - }) - .toObservable() - .publish() - .refCount(); + return new Execution( + Single.defer( + () -> { + int times = subscribeTimes.incrementAndGet(); + checkState(times == 1, "Subscribed more than once to the task"); + return task; + }) + .doOnSuccess( + value -> { + synchronized (lock) { + finished.put(key, value); + inProgress.remove(key); + } + }) + .doOnError( + error -> { + synchronized (lock) { + inProgress.remove(key); + } + }) + .doOnDispose( + () -> { + synchronized (lock) { + inProgress.remove(key); + } + })); }); - return Single.fromObservable(execution); + return execution.start(); } }); } @@ -174,5 +238,10 @@ public ImmutableSet getFinishedTasks() { public ImmutableSet getInProgressTasks() { return cache.getInProgressTasks(); } + + /** Returns count of subscribers for a task. */ + public int getSubscriberCount(KeyT key) { + return cache.getSubscriberCount(key); + } } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java index 31e58214132a99..6f22ce6041b87d 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java @@ -107,7 +107,7 @@ public void testFetching() throws Exception { assertThat(a2.getPath().isExecutable()).isTrue(); assertThat(actionInputFetcher.downloadedFiles()).hasSize(2); assertThat(actionInputFetcher.downloadedFiles()).containsAtLeast(a1.getPath(), a2.getPath()); - assertThat(actionInputFetcher.downloadsInProgress).isEmpty(); + assertThat(actionInputFetcher.downloadsInProgress()).isEmpty(); } @Test @@ -127,7 +127,7 @@ public void testStagingVirtualActionInput() throws Exception { assertThat(FileSystemUtils.readContent(p, StandardCharsets.UTF_8)).isEqualTo("hello world"); assertThat(p.isExecutable()).isFalse(); assertThat(actionInputFetcher.downloadedFiles()).isEmpty(); - assertThat(actionInputFetcher.downloadsInProgress).isEmpty(); + assertThat(actionInputFetcher.downloadsInProgress()).isEmpty(); } @Test @@ -144,7 +144,7 @@ public void testStagingEmptyVirtualActionInput() throws Exception { // assert that nothing happened assertThat(actionInputFetcher.downloadedFiles()).isEmpty(); - assertThat(actionInputFetcher.downloadsInProgress).isEmpty(); + assertThat(actionInputFetcher.downloadsInProgress()).isEmpty(); } @Test @@ -167,7 +167,7 @@ public void testFileNotFound() throws Exception { // assert assertThat(actionInputFetcher.downloadedFiles()).isEmpty(); - assertThat(actionInputFetcher.downloadsInProgress).isEmpty(); + assertThat(actionInputFetcher.downloadsInProgress()).isEmpty(); } @Test @@ -189,7 +189,7 @@ public void testIgnoreNoneRemoteFiles() throws Exception { // assert assertThat(actionInputFetcher.downloadedFiles()).isEmpty(); - assertThat(actionInputFetcher.downloadsInProgress).isEmpty(); + assertThat(actionInputFetcher.downloadsInProgress()).isEmpty(); } @Test @@ -261,6 +261,113 @@ public void testDownloadFile_onInterrupt_deletePartialDownloadedFile() throws Ex assertThat(a1.getPath().exists()).isFalse(); } + @Test + public void testPrefetchFiles_multipleThreads_downloadIsNotCancelledByOtherThreads() + throws Exception { + // Test multiple threads can share downloads, but do not cancel each other when interrupted + + // arrange + Map metadata = new HashMap<>(); + Map cacheEntries = new HashMap<>(); + Artifact artifact = createRemoteArtifact("file1", "hello world", metadata, cacheEntries); + MetadataProvider metadataProvider = new StaticMetadataProvider(metadata); + SettableFuture download = SettableFuture.create(); + RemoteCache remoteCache = mock(RemoteCache.class); + when(remoteCache.downloadFile(any(), any(), any())).thenAnswer(invocation -> download); + RemoteActionInputFetcher actionInputFetcher = + new RemoteActionInputFetcher("none", "none", remoteCache, execRoot); + Thread cancelledThread = + new Thread( + () -> { + try { + actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider); + } catch (IOException | InterruptedException ignored) { + // do nothing + } + }); + + AtomicBoolean successful = new AtomicBoolean(false); + Thread successfulThread = + new Thread( + () -> { + try { + actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider); + successful.set(true); + } catch (IOException | InterruptedException ignored) { + // do nothing + } + }); + cancelledThread.start(); + successfulThread.start(); + while (true) { + if (actionInputFetcher + .getDownloadCache() + .getSubscriberCount(execRoot.getRelative(artifact.getExecPath())) + == 2) { + break; + } + } + + // act + cancelledThread.interrupt(); + cancelledThread.join(); + // simulate the download finishing + assertThat(download.isCancelled()).isFalse(); + download.set(null); + successfulThread.join(); + + // assert + assertThat(successful.get()).isTrue(); + } + + @Test + public void testPrefetchFiles_multipleThreads_downloadIsCancelled() throws Exception { + // Test shared downloads are cancelled if all threads/callers are interrupted + + // arrange + Map metadata = new HashMap<>(); + Map cacheEntries = new HashMap<>(); + Artifact artifact = createRemoteArtifact("file1", "hello world", metadata, cacheEntries); + MetadataProvider metadataProvider = new StaticMetadataProvider(metadata); + + SettableFuture download = SettableFuture.create(); + RemoteCache remoteCache = mock(RemoteCache.class); + when(remoteCache.downloadFile(any(), any(), any())).thenAnswer(invocation -> download); + RemoteActionInputFetcher actionInputFetcher = + new RemoteActionInputFetcher("none", "none", remoteCache, execRoot); + + Thread cancelledThread1 = + new Thread( + () -> { + try { + actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider); + } catch (IOException | InterruptedException ignored) { + // do nothing + } + }); + + Thread cancelledThread2 = + new Thread( + () -> { + try { + actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider); + } catch (IOException | InterruptedException ignored) { + // do nothing + } + }); + + // act + cancelledThread1.start(); + cancelledThread2.start(); + cancelledThread1.interrupt(); + cancelledThread2.interrupt(); + cancelledThread1.join(); + cancelledThread2.join(); + + // assert + assertThat(download.isCancelled()).isTrue(); + } + private Artifact createRemoteArtifact( String pathFragment, String contents,