From 662cf54de7a103db30e04ebae2d2b919437c4846 Mon Sep 17 00:00:00 2001 From: chiwang Date: Thu, 25 Mar 2021 18:57:09 -0700 Subject: [PATCH] Remote: Fix an issue that a failed action could lead to RuntimeException caused by InterruptedException thrown when acquiring gRPC connections. https://github.com/bazelbuild/bazel/issues/13239 When --keep_going is not enabled, Bazel will cancel other executing actions if an action failed. An action which is executing remotely could in the state of waiting for a lock available to acquire the gRPC connection. SharedConnectionFactory uses ReentrantLock#lockInterruptibly to acquire the lock and will throw InterruptedException when the thread is interrupted which happens when the action is cancelled by Bazel. However, this InterruptedException is wrapped inside a RuntimeException results in a build error. ReentrantLock was choosen initially to implement a hand-over-hand locking algorithem but it's no longer necessary after a few iterations. This change replaces ReentrantLock with `synchronized` keyword so we won't throw InterruptedException when acquiring gRPC connections. Call sites can still throw InterruptedException to cancel an action execution. PiperOrigin-RevId: 365170212 --- .../remote/grpc/SharedConnectionFactory.java | 22 ++---- .../devtools/build/lib/remote/grpc/BUILD | 1 + .../grpc/SharedConnectionFactoryTest.java | 70 ++++++++----------- 3 files changed, 37 insertions(+), 56 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java index d731574bbf44a8..a606e3d964f643 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java @@ -25,7 +25,6 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.locks.ReentrantLock; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -44,10 +43,9 @@ public class SharedConnectionFactory implements ConnectionPool { private final ConnectionFactory factory; @Nullable - @GuardedBy("connectionLock") + @GuardedBy("this") private AsyncSubject connectionAsyncSubject = null; - private final ReentrantLock connectionLock = new ReentrantLock(); private final AtomicReference connectionCreationDisposable = new AtomicReference<>(null); @@ -70,9 +68,7 @@ public void close() throws IOException { d.dispose(); } - try { - connectionLock.lockInterruptibly(); - + synchronized (this) { if (connectionAsyncSubject != null) { Connection connection = connectionAsyncSubject.getValue(); if (connection != null) { @@ -83,16 +79,11 @@ public void close() throws IOException { connectionAsyncSubject.onError(new IllegalStateException("closed")); } } - } catch (InterruptedException e) { - throw new IOException(e); - } finally { - connectionLock.unlock(); } } - private AsyncSubject createUnderlyingConnectionIfNot() throws InterruptedException { - connectionLock.lockInterruptibly(); - try { + private AsyncSubject createUnderlyingConnectionIfNot() { + synchronized (this) { if (connectionAsyncSubject == null || connectionAsyncSubject.hasThrowable()) { connectionAsyncSubject = factory @@ -103,14 +94,11 @@ private AsyncSubject createUnderlyingConnectionIfNot() throws Interr } return connectionAsyncSubject; - } finally { - connectionLock.unlock(); } } private Single acquireConnection() { - return Single.fromCallable(this::createUnderlyingConnectionIfNot) - .flatMap(Single::fromObservable); + return Single.fromObservable(createUnderlyingConnectionIfNot()); } /** diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD index 8df04632cb3a8d..ddf9c27ad51d8e 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD @@ -22,6 +22,7 @@ java_test( deps = [ "//src/main/java/com/google/devtools/build/lib/remote/grpc", "//src/test/java/com/google/devtools/build/lib:test_runner", + "//src/test/java/com/google/devtools/build/lib/remote/util", "//third_party:guava", "//third_party:junit4", "//third_party:mockito", diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java index 124b79e4a2488d..ad3f3c73f1999d 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java @@ -20,15 +20,14 @@ import static org.mockito.Mockito.when; import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; +import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.observers.TestObserver; -import io.reactivex.rxjava3.plugins.RxJavaPlugins; import java.io.IOException; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -42,28 +41,16 @@ @RunWith(JUnit4.class) public class SharedConnectionFactoryTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - - private final AtomicReference rxGlobalThrowable = new AtomicReference<>(null); + @Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule(); @Mock private Connection connection; @Mock private ConnectionFactory connectionFactory; @Before public void setUp() { - RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set); - when(connectionFactory.create()).thenAnswer(invocation -> Single.just(connection)); } - @After - public void tearDown() throws Throwable { - // Make sure rxjava didn't receive global errors - Throwable t = rxGlobalThrowable.getAndSet(null); - if (t != null) { - throw t; - } - } - @Test public void create_smoke() { SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); @@ -125,32 +112,37 @@ public void create_belowMaxConcurrency_shareConnections() { @Test public void create_concurrentCreate_shareConnections() throws InterruptedException { - SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2); - Semaphore semaphore = new Semaphore(0); - AtomicBoolean finished = new AtomicBoolean(false); - Thread t = - new Thread( - () -> { - factory - .create() - .doOnSuccess( - conn -> { - assertThat(conn.getUnderlyingConnection()).isEqualTo(connection); - semaphore.release(); - Thread.sleep(Integer.MAX_VALUE); - finished.set(true); - }) - .blockingSubscribe(); - - finished.set(true); - }); - t.start(); - semaphore.acquire(); + int maxConcurrency = 10; + SharedConnectionFactory factory = + new SharedConnectionFactory(connectionFactory, maxConcurrency); + AtomicReference error = new AtomicReference<>(null); + Runnable runnable = + () -> { + try { + TestObserver observer = factory.create().test(); + + observer + .assertNoErrors() + .assertValue(conn -> conn.getUnderlyingConnection() == connection) + .assertComplete(); + } catch (Throwable e) { + error.set(e); + } + }; + Thread[] threads = new Thread[maxConcurrency]; + for (int i = 0; i < threads.length; ++i) { + threads[i] = new Thread(runnable); + } - TestObserver observer = factory.create().test(); + for (Thread thread : threads) { + thread.start(); + } + for (Thread thread : threads) { + thread.join(); + } - observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); - assertThat(finished.get()).isFalse(); + assertThat(error.get()).isNull(); + verify(connectionFactory, times(1)).create(); } @Test