diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index c3d2c25267f50d..f6f70a643764c4 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -15,20 +15,19 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.reactivex.rxjava3.annotations.NonNull; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.disposables.Disposable; -import io.reactivex.rxjava3.subjects.AsyncSubject; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -55,7 +54,7 @@ public final class AsyncTaskCache { private final Map finished; @GuardedBy("lock") - private final Map> inProgress; + private final Map inProgress; public static AsyncTaskCache create() { return new AsyncTaskCache<>(); @@ -91,79 +90,136 @@ public Single executeIfNot(KeyT key, Single task) { return execute(key, task, false); } - private static class Execution { - private final AtomicBoolean isTaskDisposed = new AtomicBoolean(false); - private final Single task; - private final AsyncSubject asyncSubject = AsyncSubject.create(); - private final AtomicInteger referenceCount = new AtomicInteger(0); - private final AtomicReference taskDisposable = new AtomicReference<>(null); + /** Returns count of subscribers for a task. */ + public int getSubscriberCount(KeyT key) { + synchronized (lock) { + Execution task = inProgress.get(key); + if (task != null) { + return task.getSubscriberCount(); + } + } + + return 0; + } + + class Execution extends Single implements SingleObserver { + private final KeyT key; + private final Single upstream; + + @GuardedBy("lock") + private boolean terminated = false; + + @GuardedBy("lock") + private Disposable upstreamDisposable; - Execution(Single task) { - this.task = task; + @GuardedBy("lock") + private final List> observers = new ArrayList<>(); + + Execution(KeyT key, Single upstream) { + this.key = key; + this.upstream = upstream; } - Single executeIfNot() { - checkState(!isTaskDisposed(), "disposed"); - - int subscribed = referenceCount.getAndIncrement(); - if (taskDisposable.get() == null && subscribed == 0) { - task.subscribe( - new SingleObserver() { - @Override - public void onSubscribe(@NonNull Disposable d) { - taskDisposable.compareAndSet(null, d); - } - - @Override - public void onSuccess(@NonNull ValueT value) { - asyncSubject.onNext(value); - asyncSubject.onComplete(); - } - - @Override - public void onError(@NonNull Throwable e) { - asyncSubject.onError(e); - } - }); + int getSubscriberCount() { + synchronized (lock) { + return observers.size(); } + } + + @Override + protected void subscribeActual(@NonNull SingleObserver observer) { + synchronized (lock) { + checkState(!terminated, "terminated"); + + boolean shouldSubscribe = observers.isEmpty(); + + observers.add(observer); + + observer.onSubscribe(new ExecutionDisposable(this, observer)); - return Single.fromObservable(asyncSubject); + if (shouldSubscribe) { + upstream.subscribe(this); + } + } } - boolean isTaskTerminated() { - return asyncSubject.hasComplete() || asyncSubject.hasThrowable(); + @Override + public void onSubscribe(@NonNull Disposable d) { + synchronized (lock) { + upstreamDisposable = d; + + if (terminated) { + d.dispose(); + } + } } - boolean isTaskDisposed() { - return isTaskDisposed.get(); + @Override + public void onSuccess(@NonNull ValueT value) { + synchronized (lock) { + if (!terminated) { + inProgress.remove(key); + finished.put(key, value); + terminated = true; + + for (SingleObserver observer : ImmutableList.copyOf(observers)) { + observer.onSuccess(value); + } + } + } } - void tryDisposeTask() { - checkState(!isTaskDisposed(), "disposed"); - checkState(!isTaskTerminated(), "terminated"); + @Override + public void onError(@NonNull Throwable error) { + synchronized (lock) { + if (!terminated) { + inProgress.remove(key); + terminated = true; - if (referenceCount.decrementAndGet() == 0) { - isTaskDisposed.set(true); - asyncSubject.onError(new CancellationException("disposed")); + for (SingleObserver observer : ImmutableList.copyOf(observers)) { + observer.onError(error); + } + } + } + } - Disposable d = taskDisposable.get(); - if (d != null) { - d.dispose(); + void remove(SingleObserver observer) { + synchronized (lock) { + observers.remove(observer); + + if (observers.isEmpty() && !terminated) { + inProgress.remove(key); + terminated = true; + + if (upstreamDisposable != null) { + upstreamDisposable.dispose(); + } } } } } - /** Returns count of subscribers for a task. */ - public int getSubscriberCount(KeyT key) { - synchronized (lock) { - Execution execution = inProgress.get(key); - if (execution != null) { - return execution.referenceCount.get(); + class ExecutionDisposable implements Disposable { + final Execution execution; + final SingleObserver observer; + AtomicBoolean isDisposed = new AtomicBoolean(false); + + ExecutionDisposable(Execution execution, SingleObserver observer) { + this.execution = execution; + this.observer = observer; + } + + @Override + public void dispose() { + if (isDisposed.compareAndSet(false, true)) { + execution.remove(observer); } } - return 0; + @Override + public boolean isDisposed() { + return isDisposed.get(); + } } /** @@ -185,62 +241,34 @@ public Single execute(KeyT key, Single task, boolean force) { finished.remove(key); - Execution execution = - inProgress.computeIfAbsent( - key, - ignoredKey -> { - AtomicInteger subscribeTimes = new AtomicInteger(0); - return new Execution<>( - Single.defer( - () -> { - int times = subscribeTimes.incrementAndGet(); - checkState(times == 1, "Subscribed more than once to the task"); - return task; - })); - }); - - execution - .executeIfNot() - .subscribe( - new SingleObserver() { - @Override - public void onSubscribe(@NonNull Disposable d) { - emitter.setCancellable( - () -> { - d.dispose(); - - if (!execution.isTaskTerminated()) { - synchronized (lock) { - execution.tryDisposeTask(); - if (execution.isTaskDisposed()) { - inProgress.remove(key); - } - } - } - }); - } - - @Override - public void onSuccess(@NonNull ValueT value) { - synchronized (lock) { - finished.put(key, value); - inProgress.remove(key); - } - - emitter.onSuccess(value); - } - - @Override - public void onError(@NonNull Throwable e) { - synchronized (lock) { - inProgress.remove(key); - } - - if (!emitter.isDisposed()) { - emitter.onError(e); - } - } - }); + Execution execution = + inProgress.computeIfAbsent(key, ignoredKey -> new Execution(key, task)); + + // We must subscribe the execution within the scope of lock to avoid race condition + // that: + // 1. Two callers get the same execution instance + // 2. One decides to dispose the execution, since no more observers, the execution + // will change to the terminate state + // 3. Another one try to subscribe, will get "terminated" error. + execution.subscribe( + new SingleObserver() { + @Override + public void onSubscribe(@NonNull Disposable d) { + emitter.setDisposable(d); + } + + @Override + public void onSuccess(@NonNull ValueT valueT) { + emitter.onSuccess(valueT); + } + + @Override + public void onError(@NonNull Throwable e) { + if (!emitter.isDisposed()) { + emitter.onError(e); + } + } + }); } }); } diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java index 8e3ee28b2cee30..279c342baa7f94 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java @@ -14,11 +14,18 @@ package com.google.devtools.build.lib.remote.util; import static com.google.common.truth.Truth.assertThat; -import static java.util.concurrent.TimeUnit.SECONDS; +import com.google.common.util.concurrent.SettableFuture; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleEmitter; import io.reactivex.rxjava3.observers.TestObserver; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -282,43 +289,101 @@ public void execute_multipleTasks_completeOne() { assertThat(cache.getFinishedTasks()).containsExactly("key1"); } + private Completable newTask(ExecutorService executorService) { + return RxFutures.toCompletable( + () -> { + SettableFuture future = SettableFuture.create(); + executorService.execute( + () -> { + try { + Thread.sleep((long) (Math.random() * 1000)); + future.set(null); + } catch (InterruptedException e) { + future.setException(new IOException(e)); + } + }); + return future; + }, + executorService); + } + @Test - public void execute_executeAndDisposeLoop_noErrors() throws InterruptedException { - AsyncTaskCache cache = AsyncTaskCache.create(); - Single task = Single.timer(1, SECONDS); + public void execute_executeAndDisposeLoop_noErrors() throws Throwable { + int taskCount = 1000; + int maxKey = 20; + Random random = new Random(); + ExecutorService executorService = Executors.newFixedThreadPool(taskCount); + AsyncTaskCache.NoResult cache = AsyncTaskCache.NoResult.create(); AtomicReference error = new AtomicReference<>(null); - AtomicInteger errorCount = new AtomicInteger(0); - int executionCount = 100; - Runnable runnable = - () -> { - try { - for (int i = 0; i < executionCount; ++i) { - TestObserver observer = cache.execute("key1", task, true).test(); + Semaphore semaphore = new Semaphore(0); + + for (int i = 0; i < taskCount; ++i) { + executorService.execute( + () -> { + try { + Completable task = + cache.execute("key" + random.nextInt(maxKey), newTask(executorService), true); + TestObserver observer = task.test(); observer.assertNoErrors(); - observer.dispose(); + if (random.nextBoolean()) { + observer.dispose(); + } else { + observer.await(); + observer.assertNoErrors(); + } + } catch (Throwable e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + error.set(e); + } finally { + semaphore.release(); } - } catch (Throwable t) { - errorCount.incrementAndGet(); - error.set(t); - } - }; - int threadCount = 10; - Thread[] threads = new Thread[threadCount]; - for (int i = 0; i < threadCount; ++i) { - Thread thread = new Thread(runnable); - threads[i] = thread; + }); } + semaphore.acquire(taskCount); - for (Thread thread : threads) { - thread.start(); + if (error.get() != null) { + throw error.get(); } - for (Thread thread : threads) { - thread.join(); + } + + @Test + public void execute_executeWithFutureAndCancelLoop_noErrors() throws Throwable { + int taskCount = 1000; + int maxKey = 20; + Random random = new Random(); + ExecutorService executorService = Executors.newFixedThreadPool(taskCount); + AsyncTaskCache.NoResult cache = AsyncTaskCache.NoResult.create(); + AtomicReference error = new AtomicReference<>(null); + Semaphore semaphore = new Semaphore(0); + + for (int i = 0; i < taskCount; ++i) { + executorService.execute( + () -> { + try { + Completable download = + cache.execute("key" + random.nextInt(maxKey), newTask(executorService), true); + Future future = RxFutures.toListenableFuture(download); + if (!future.isDone() && random.nextBoolean()) { + future.cancel(true); + } else { + future.get(); + } + } catch (Throwable e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + error.set(e); + } finally { + semaphore.release(); + } + }); } + semaphore.acquire(taskCount); if (error.get() != null) { - throw new IllegalStateException( - String.format("%s/%s errors", errorCount.get(), threadCount), error.get()); + throw error.get(); } } }