diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD index 34f7a7863ecf3b..40a3c64b00e995 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD @@ -18,6 +18,7 @@ java_library( deps = [ "//src/main/java/com/google/devtools/build/lib/concurrent", "//third_party:guava", + "//third_party:jsr305", "//third_party:rxjava3", "//third_party/grpc:grpc-jar", ], diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java index ee513847873e24..2326e3189b379c 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/ConnectionPool.java @@ -13,6 +13,7 @@ // limitations under the License. package com.google.devtools.build.lib.remote.grpc; +import io.reactivex.rxjava3.core.Single; import java.io.Closeable; import java.io.IOException; @@ -24,6 +25,13 @@ *

Connections must be closed with {@link Connection#close()} in order to be reused later. */ public interface ConnectionPool extends ConnectionFactory, Closeable { + /** + * Reuses a {@link Connection} in the pool and will potentially create a new connection depends on + * implementation. + */ + @Override + Single create(); + /** Closes the connection pool and closes all the underlying connections */ @Override void close() throws IOException; diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java new file mode 100644 index 00000000000000..d731574bbf44a8 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java @@ -0,0 +1,170 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.grpc; + +import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.MethodDescriptor; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.disposables.Disposable; +import io.reactivex.rxjava3.functions.Action; +import io.reactivex.rxjava3.subjects.AsyncSubject; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +/** + * A {@link ConnectionPool} that creates one connection using provided {@link ConnectionFactory} and + * shares the connection upto {@code maxConcurrency}. + * + *

This is useful if underlying connection maintains a connection pool internally. (such as + * {@code Channel} in gRPC) + * + *

Connections must be closed with {@link Connection#close()} in order to be reused later. + */ +@ThreadSafe +public class SharedConnectionFactory implements ConnectionPool { + private final TokenBucket tokenBucket; + private final ConnectionFactory factory; + + @Nullable + @GuardedBy("connectionLock") + private AsyncSubject connectionAsyncSubject = null; + + private final ReentrantLock connectionLock = new ReentrantLock(); + private final AtomicReference connectionCreationDisposable = + new AtomicReference<>(null); + + public SharedConnectionFactory(ConnectionFactory factory, int maxConcurrency) { + this.factory = factory; + + List initialTokens = new ArrayList<>(maxConcurrency); + for (int i = 0; i < maxConcurrency; ++i) { + initialTokens.add(i); + } + this.tokenBucket = new TokenBucket<>(initialTokens); + } + + @Override + public void close() throws IOException { + tokenBucket.close(); + + Disposable d = connectionCreationDisposable.getAndSet(null); + if (d != null && !d.isDisposed()) { + d.dispose(); + } + + try { + connectionLock.lockInterruptibly(); + + if (connectionAsyncSubject != null) { + Connection connection = connectionAsyncSubject.getValue(); + if (connection != null) { + connection.close(); + } + + if (!connectionAsyncSubject.hasComplete()) { + connectionAsyncSubject.onError(new IllegalStateException("closed")); + } + } + } catch (InterruptedException e) { + throw new IOException(e); + } finally { + connectionLock.unlock(); + } + } + + private AsyncSubject createUnderlyingConnectionIfNot() throws InterruptedException { + connectionLock.lockInterruptibly(); + try { + if (connectionAsyncSubject == null || connectionAsyncSubject.hasThrowable()) { + connectionAsyncSubject = + factory + .create() + .doOnSubscribe(connectionCreationDisposable::set) + .toObservable() + .subscribeWith(AsyncSubject.create()); + } + + return connectionAsyncSubject; + } finally { + connectionLock.unlock(); + } + } + + private Single acquireConnection() { + return Single.fromCallable(this::createUnderlyingConnectionIfNot) + .flatMap(Single::fromObservable); + } + + /** + * Reuses the underlying {@link Connection} and wait for it to be released if is exceeding {@code + * maxConcurrency}. + */ + @Override + public Single create() { + return tokenBucket + .acquireToken() + .flatMap( + token -> + acquireConnection() + .doOnError(ignored -> tokenBucket.addToken(token)) + .doOnDispose(() -> tokenBucket.addToken(token)) + .map( + conn -> + new SharedConnection( + conn, /* onClose= */ () -> tokenBucket.addToken(token)))); + } + + /** Returns current number of available connections. */ + public int numAvailableConnections() { + return tokenBucket.size(); + } + + /** A {@link Connection} which wraps an underlying connection and is shared between consumers. */ + public static class SharedConnection implements Connection { + private final Connection connection; + private final Action onClose; + + public SharedConnection(Connection connection, Action onClose) { + this.connection = connection; + this.onClose = onClose; + } + + @Override + public ClientCall call( + MethodDescriptor method, CallOptions options) { + return connection.call(method, options); + } + + @Override + public void close() throws IOException { + try { + onClose.run(); + } catch (Throwable t) { + throw new IOException(t); + } + } + + /** Returns the underlying connection this shared connection built on */ + public Connection getUnderlyingConnection() { + return connection; + } + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java new file mode 100644 index 00000000000000..124b79e4a2488d --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java @@ -0,0 +1,354 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; +import io.reactivex.rxjava3.plugins.RxJavaPlugins; +import java.io.IOException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Tests for {@link SharedConnectionFactory}. */ +@RunWith(JUnit4.class) +public class SharedConnectionFactoryTest { + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + + private final AtomicReference rxGlobalThrowable = new AtomicReference<>(null); + + @Mock private Connection connection; + @Mock private ConnectionFactory connectionFactory; + + @Before + public void setUp() { + RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set); + + when(connectionFactory.create()).thenAnswer(invocation -> Single.just(connection)); + } + + @After + public void tearDown() throws Throwable { + // Make sure rxjava didn't receive global errors + Throwable t = rxGlobalThrowable.getAndSet(null); + if (t != null) { + throw t; + } + } + + @Test + public void create_smoke() { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + + TestObserver observer = factory.create().test(); + + observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + verify(connectionFactory, times(1)).create(); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + } + + @Test + public void create_noConnectionCreationBeforeSubscription() { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + + factory.create(); + + verify(connectionFactory, times(0)).create(); + } + + @Test + public void create_exceedingMaxConcurrency_waiting() { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + TestObserver observer1 = factory.create().test(); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + + TestObserver observer2 = factory.create().test(); + observer2.assertEmpty(); + } + + @Test + public void create_afterConnectionClosed_shareConnections() throws IOException { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + TestObserver observer1 = factory.create().test(); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + TestObserver observer2 = factory.create().test(); + + observer1.values().get(0).close(); + + observer2.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + } + + @Test + public void create_belowMaxConcurrency_shareConnections() { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2); + + TestObserver observer1 = factory.create().test(); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + + TestObserver observer2 = factory.create().test(); + observer2.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + } + + @Test + public void create_concurrentCreate_shareConnections() throws InterruptedException { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2); + Semaphore semaphore = new Semaphore(0); + AtomicBoolean finished = new AtomicBoolean(false); + Thread t = + new Thread( + () -> { + factory + .create() + .doOnSuccess( + conn -> { + assertThat(conn.getUnderlyingConnection()).isEqualTo(connection); + semaphore.release(); + Thread.sleep(Integer.MAX_VALUE); + finished.set(true); + }) + .blockingSubscribe(); + + finished.set(true); + }); + t.start(); + semaphore.acquire(); + + TestObserver observer = factory.create().test(); + + observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); + assertThat(finished.get()).isFalse(); + } + + @Test + public void create_afterLastFailed_success() { + AtomicInteger times = new AtomicInteger(0); + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + when(connectionFactory.create()) + .thenAnswer( + invocation -> { + if (times.getAndIncrement() == 0) { + return Single.error(new IllegalStateException("error")); + } + + return Single.just(connection); + }); + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + Single connectionSingle = factory.create(); + + connectionSingle + .test() + .assertError(IllegalStateException.class) + .assertError(e -> e.getMessage().contains("error")); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + connectionSingle + .test() + .assertValue(conn -> conn.getUnderlyingConnection() == connection) + .assertComplete(); + + assertThat(times.get()).isEqualTo(2); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + } + + @Test + public void create_disposeWhenWaitingForConnectionCreation_doNotCancelCreation() + throws InterruptedException { + AtomicBoolean canceled = new AtomicBoolean(false); + AtomicBoolean finished = new AtomicBoolean(false); + Semaphore disposed = new Semaphore(0); + Semaphore terminated = new Semaphore(0); + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + when(connectionFactory.create()) + .thenAnswer( + invocation -> + Single.create( + emitter -> + new Thread( + () -> { + try { + disposed.acquire(); + finished.set(true); + emitter.onSuccess(connection); + } catch (InterruptedException e) { + emitter.onError(e); + } + terminated.release(); + }) + .start()) + .doOnDispose(() -> canceled.set(true))); + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + TestObserver observer = factory.create().test(); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + + observer.assertEmpty().dispose(); + disposed.release(); + + terminated.acquire(); + assertThat(canceled.get()).isFalse(); + assertThat(finished.get()).isTrue(); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + } + + @Test + public void create_interrupt_terminate() throws InterruptedException { + AtomicBoolean finished = new AtomicBoolean(false); + AtomicBoolean interrupted = new AtomicBoolean(true); + Semaphore threadTerminatedSemaphore = new Semaphore(0); + Semaphore connectionCreationSemaphore = new Semaphore(0); + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + when(connectionFactory.create()) + .thenAnswer( + invocation -> + Single.create( + emitter -> + new Thread( + () -> { + try { + Thread.sleep(Integer.MAX_VALUE); + finished.set(true); + emitter.onSuccess(connectionFactory); + } catch (InterruptedException e) { + emitter.onError(e); + } + }) + .start())); + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2); + factory.create().test().assertEmpty(); + Thread t = + new Thread( + () -> { + try { + TestObserver observer = factory.create().test(); + connectionCreationSemaphore.release(); + observer.await(); + } catch (InterruptedException e) { + interrupted.set(true); + } + + threadTerminatedSemaphore.release(); + }); + t.start(); + + connectionCreationSemaphore.acquire(); + t.interrupt(); + threadTerminatedSemaphore.acquire(); + + assertThat(finished.get()).isFalse(); + assertThat(interrupted.get()).isTrue(); + } + + @Test + public void closeConnection_connectionBecomeAvailable() throws IOException { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + TestObserver observer = factory.create().test(); + observer.assertComplete(); + SharedConnection conn = observer.values().get(0); + assertThat(factory.numAvailableConnections()).isEqualTo(0); + + conn.close(); + + assertThat(factory.numAvailableConnections()).isEqualTo(1); + verify(connection, times(0)).close(); + } + + @Test + public void closeFactory_closeUnderlyingConnection() throws IOException { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + TestObserver observer = factory.create().test(); + observer.assertComplete(); + + factory.close(); + + verify(connection, times(1)).close(); + } + + @Test + public void closeFactory_noNewConnectionAllowed() throws IOException { + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + factory.close(); + + TestObserver observer = factory.create().test(); + + observer + .assertError(IllegalStateException.class) + .assertError(e -> e.getMessage().contains("closed")); + } + + @Test + public void closeFactory_pendingConnectionCreation_closedError() + throws IOException, InterruptedException { + AtomicBoolean canceled = new AtomicBoolean(false); + AtomicBoolean finished = new AtomicBoolean(false); + Semaphore terminated = new Semaphore(0); + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + when(connectionFactory.create()) + .thenAnswer( + invocation -> + Single.create( + emitter -> { + Thread t = + new Thread( + () -> { + try { + Thread.sleep(Integer.MAX_VALUE); + finished.set(true); + emitter.onSuccess(connection); + } catch (InterruptedException ignored) { + /* no-op */ + } + + terminated.release(); + }); + t.start(); + + emitter.setCancellable(t::interrupt); + }) + .doOnDispose(() -> canceled.set(true))); + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1); + TestObserver observer = factory.create().test(); + observer.assertEmpty(); + + assertThat(canceled.get()).isFalse(); + factory.close(); + + terminated.acquire(); + observer + .assertError(IllegalStateException.class) + .assertError(e -> e.getMessage().contains("closed")); + assertThat(canceled.get()).isTrue(); + assertThat(finished.get()).isFalse(); + } +}