From 7c081eb020186bfb16d4ef1c3832a8e946e99da1 Mon Sep 17 00:00:00 2001 From: Googler Date: Thu, 18 Feb 2021 00:01:19 -0800 Subject: [PATCH] Remote: gRPC load balancing. (Part 4) Implement DynamicConnectionPool which is built on top of SharedConnectionFactory. It creates connections on demands, applies rate limiting on the underying connection and uses Round-Robin algorithm to load balancing across multiple connections. PiperOrigin-RevId: 358116905 --- .../remote/grpc/DynamicConnectionPool.java | 98 ++++++++ .../grpc/DynamicConnectionPoolTest.java | 229 ++++++++++++++++++ 2 files changed, 327 insertions(+) create mode 100644 src/main/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPool.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPoolTest.java diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPool.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPool.java new file mode 100644 index 00000000000000..480ed66241f5c1 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPool.java @@ -0,0 +1,98 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.grpc; + +import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; +import io.reactivex.rxjava3.core.Single; +import java.io.IOException; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.concurrent.GuardedBy; + +/** + * A {@link ConnectionPool} that creates new connection with given {@link ConnectionFactory} on + * demand and applies rate limiting w.r.t {@code maxConcurrencyPerConnection} for one underlying + * connection. It also uses Round-Robin algorithm to load balancing between underlying connections. + * + *

Connections must be closed with {@link Connection#close()} in order to be reused later. + */ +public class DynamicConnectionPool implements ConnectionPool { + private final ConnectionFactory connectionFactory; + private final int maxConcurrencyPerConnection; + private final AtomicBoolean closed = new AtomicBoolean(false); + + @GuardedBy("this") + private final ArrayList factories; + + @GuardedBy("this") + private int indexTicker = 0; + + public DynamicConnectionPool( + ConnectionFactory connectionFactory, int maxConcurrencyPerConnection) { + this.connectionFactory = connectionFactory; + this.maxConcurrencyPerConnection = maxConcurrencyPerConnection; + this.factories = new ArrayList<>(); + } + + @Override + public void close() throws IOException { + if (closed.compareAndSet(false, true)) { + synchronized (this) { + for (SharedConnectionFactory factory : factories) { + factory.close(); + } + factories.clear(); + } + } + } + + /** + * Performs a simple round robin on the list of {@link SharedConnectionFactory} and return one + * having available connections at this moment. + * + *

If no factory has available connections, it will create a new {@link + * SharedConnectionFactory}. + */ + private SharedConnectionFactory nextAvailableFactory() { + if (closed.get()) { + throw new IllegalStateException("closed"); + } + + synchronized (this) { + for (int times = 0; times < factories.size(); ++times) { + int index = Math.abs(indexTicker % factories.size()); + indexTicker += 1; + + SharedConnectionFactory factory = factories.get(index); + if (factory.numAvailableConnections() > 0) { + return factory; + } + } + + SharedConnectionFactory factory = + new SharedConnectionFactory(connectionFactory, maxConcurrencyPerConnection); + factories.add(factory); + return factory; + } + } + + @Override + public Single create() { + return Single.defer( + () -> { + SharedConnectionFactory factory = nextAvailableFactory(); + return factory.create(); + }); + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPoolTest.java b/src/test/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPoolTest.java new file mode 100644 index 00000000000000..89d6a99a059ca2 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/DynamicConnectionPoolTest.java @@ -0,0 +1,229 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; +import io.reactivex.rxjava3.plugins.RxJavaPlugins; +import java.io.IOException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Tests for {@link DynamicConnectionPool}. */ +@RunWith(JUnit4.class) +public class DynamicConnectionPoolTest { + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + private final AtomicReference rxGlobalThrowable = new AtomicReference<>(null); + + @Mock private Connection connection0; + @Mock private Connection connection1; + @Mock private ConnectionFactory connectionFactory; + private final AtomicInteger connectionFactoryCreateTimes = new AtomicInteger(0); + + @Before + public void setUp() { + RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set); + + when(connectionFactory.create()) + .thenAnswer( + invocation -> { + int times = connectionFactoryCreateTimes.getAndIncrement(); + if (times == 0) { + return Single.just(connection0); + } else { + return Single.just(connection1); + } + }); + } + + @After + public void tearDown() throws Throwable { + // Make sure rxjava didn't receive global errors + Throwable t = rxGlobalThrowable.getAndSet(null); + if (t != null) { + throw t; + } + } + + @Test + public void create_smoke() { + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + + TestObserver observer = pool.create().test(); + + observer.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete(); + assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1); + } + + @Test + public void create_exceedingMaxConcurrent_createNewConnection() { + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + + TestObserver observer0 = pool.create().test(); + TestObserver observer1 = pool.create().test(); + + observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete(); + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection1).assertComplete(); + assertThat(connectionFactoryCreateTimes.get()).isEqualTo(2); + } + + @Test + public void create_pendingConnectionCreationAndExceedingMaxConcurrent_createNewConnection() { + AtomicBoolean terminated = new AtomicBoolean(false); + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + when(connectionFactory.create()) + .thenAnswer( + invocation -> { + if (connectionFactoryCreateTimes.getAndIncrement() == 0) { + return Single.create( + emitter -> { + Thread t = + new Thread( + () -> { + try { + Thread.sleep(Integer.MAX_VALUE); + emitter.onSuccess(connection0); + } catch (InterruptedException e) { + emitter.onError(e); + } + terminated.set(true); + }); + t.start(); + }); + } else { + return Single.just(connection1); + } + }); + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + + TestObserver observer0 = pool.create().test(); + TestObserver observer1 = pool.create().test(); + + assertThat(terminated.get()).isFalse(); + observer0.assertEmpty(); + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection1).assertComplete(); + assertThat(connectionFactoryCreateTimes.get()).isEqualTo(2); + } + + @Test + public void create_belowMaxConcurrency_shareConnections() { + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 2); + + TestObserver observer0 = pool.create().test(); + TestObserver observer1 = pool.create().test(); + + observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete(); + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete(); + assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1); + } + + @Test + public void create_afterConnectionClosed_shareConnections() throws IOException { + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + TestObserver observer0 = pool.create().test(); + observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete(); + observer0.values().get(0).close(); + + TestObserver observer1 = pool.create().test(); + + observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete(); + assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1); + } + + @Test + public void closePool_noNewConnectionAllowed() throws IOException { + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + pool.close(); + + TestObserver observer = pool.create().test(); + + observer + .assertError(IllegalStateException.class) + .assertError(e -> e.getMessage().contains("closed")); + } + + @Test + public void closePool_closeUnderlyingConnection() throws IOException { + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + TestObserver observer = pool.create().test(); + observer.assertComplete(); + + pool.close(); + + verify(connection0, times(1)).close(); + } + + @Test + public void closePool_pendingConnectionCreation_closedError() + throws IOException, InterruptedException { + AtomicBoolean canceled = new AtomicBoolean(false); + AtomicBoolean finished = new AtomicBoolean(false); + Semaphore terminated = new Semaphore(0); + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + when(connectionFactory.create()) + .thenAnswer( + invocation -> + Single.create( + emitter -> { + Thread t = + new Thread( + () -> { + try { + Thread.sleep(Integer.MAX_VALUE); + finished.set(true); + emitter.onSuccess(connection0); + } catch (InterruptedException ignored) { + /* no-op */ + } + + terminated.release(); + }); + t.start(); + + emitter.setCancellable(t::interrupt); + }) + .doOnDispose(() -> canceled.set(true))); + DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1); + TestObserver observer = pool.create().test(); + observer.assertEmpty(); + + assertThat(canceled.get()).isFalse(); + pool.close(); + + terminated.acquire(); + observer + .assertError(IllegalStateException.class) + .assertError(e -> e.getMessage().contains("closed")); + assertThat(canceled.get()).isTrue(); + assertThat(finished.get()).isFalse(); + } +}