Skip to content

Commit

Permalink
Remote: gRPC load balancing. (Part 4)
Browse files Browse the repository at this point in the history
Implement DynamicConnectionPool which is built on top of SharedConnectionFactory. It creates connections on demands, applies rate limiting on the underying connection and uses Round-Robin algorithm to load balancing across multiple connections.

PiperOrigin-RevId: 358116905
  • Loading branch information
Googler authored and copybara-github committed Feb 18, 2021
1 parent 6667ad7 commit 7c081eb
Show file tree
Hide file tree
Showing 2 changed files with 327 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2021 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;

import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
import io.reactivex.rxjava3.core.Single;
import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.concurrent.GuardedBy;

/**
* A {@link ConnectionPool} that creates new connection with given {@link ConnectionFactory} on
* demand and applies rate limiting w.r.t {@code maxConcurrencyPerConnection} for one underlying
* connection. It also uses Round-Robin algorithm to load balancing between underlying connections.
*
* <p>Connections must be closed with {@link Connection#close()} in order to be reused later.
*/
public class DynamicConnectionPool implements ConnectionPool {
private final ConnectionFactory connectionFactory;
private final int maxConcurrencyPerConnection;
private final AtomicBoolean closed = new AtomicBoolean(false);

@GuardedBy("this")
private final ArrayList<SharedConnectionFactory> factories;

@GuardedBy("this")
private int indexTicker = 0;

public DynamicConnectionPool(
ConnectionFactory connectionFactory, int maxConcurrencyPerConnection) {
this.connectionFactory = connectionFactory;
this.maxConcurrencyPerConnection = maxConcurrencyPerConnection;
this.factories = new ArrayList<>();
}

@Override
public void close() throws IOException {
if (closed.compareAndSet(false, true)) {
synchronized (this) {
for (SharedConnectionFactory factory : factories) {
factory.close();
}
factories.clear();
}
}
}

/**
* Performs a simple round robin on the list of {@link SharedConnectionFactory} and return one
* having available connections at this moment.
*
* <p>If no factory has available connections, it will create a new {@link
* SharedConnectionFactory}.
*/
private SharedConnectionFactory nextAvailableFactory() {
if (closed.get()) {
throw new IllegalStateException("closed");
}

synchronized (this) {
for (int times = 0; times < factories.size(); ++times) {
int index = Math.abs(indexTicker % factories.size());
indexTicker += 1;

SharedConnectionFactory factory = factories.get(index);
if (factory.numAvailableConnections() > 0) {
return factory;
}
}

SharedConnectionFactory factory =
new SharedConnectionFactory(connectionFactory, maxConcurrencyPerConnection);
factories.add(factory);
return factory;
}
}

@Override
public Single<SharedConnection> create() {
return Single.defer(
() -> {
SharedConnectionFactory factory = nextAvailableFactory();
return factory.create();
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
// Copyright 2021 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.observers.TestObserver;
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
import java.io.IOException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/** Tests for {@link DynamicConnectionPool}. */
@RunWith(JUnit4.class)
public class DynamicConnectionPoolTest {
@Rule public final MockitoRule mockito = MockitoJUnit.rule();
private final AtomicReference<Throwable> rxGlobalThrowable = new AtomicReference<>(null);

@Mock private Connection connection0;
@Mock private Connection connection1;
@Mock private ConnectionFactory connectionFactory;
private final AtomicInteger connectionFactoryCreateTimes = new AtomicInteger(0);

@Before
public void setUp() {
RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set);

when(connectionFactory.create())
.thenAnswer(
invocation -> {
int times = connectionFactoryCreateTimes.getAndIncrement();
if (times == 0) {
return Single.just(connection0);
} else {
return Single.just(connection1);
}
});
}

@After
public void tearDown() throws Throwable {
// Make sure rxjava didn't receive global errors
Throwable t = rxGlobalThrowable.getAndSet(null);
if (t != null) {
throw t;
}
}

@Test
public void create_smoke() {
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);

TestObserver<SharedConnection> observer = pool.create().test();

observer.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1);
}

@Test
public void create_exceedingMaxConcurrent_createNewConnection() {
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);

TestObserver<SharedConnection> observer0 = pool.create().test();
TestObserver<SharedConnection> observer1 = pool.create().test();

observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection1).assertComplete();
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(2);
}

@Test
public void create_pendingConnectionCreationAndExceedingMaxConcurrent_createNewConnection() {
AtomicBoolean terminated = new AtomicBoolean(false);
ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
when(connectionFactory.create())
.thenAnswer(
invocation -> {
if (connectionFactoryCreateTimes.getAndIncrement() == 0) {
return Single.create(
emitter -> {
Thread t =
new Thread(
() -> {
try {
Thread.sleep(Integer.MAX_VALUE);
emitter.onSuccess(connection0);
} catch (InterruptedException e) {
emitter.onError(e);
}
terminated.set(true);
});
t.start();
});
} else {
return Single.just(connection1);
}
});
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);

TestObserver<SharedConnection> observer0 = pool.create().test();
TestObserver<SharedConnection> observer1 = pool.create().test();

assertThat(terminated.get()).isFalse();
observer0.assertEmpty();
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection1).assertComplete();
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(2);
}

@Test
public void create_belowMaxConcurrency_shareConnections() {
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 2);

TestObserver<SharedConnection> observer0 = pool.create().test();
TestObserver<SharedConnection> observer1 = pool.create().test();

observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1);
}

@Test
public void create_afterConnectionClosed_shareConnections() throws IOException {
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
TestObserver<SharedConnection> observer0 = pool.create().test();
observer0.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
observer0.values().get(0).close();

TestObserver<SharedConnection> observer1 = pool.create().test();

observer1.assertValue(conn -> conn.getUnderlyingConnection() == connection0).assertComplete();
assertThat(connectionFactoryCreateTimes.get()).isEqualTo(1);
}

@Test
public void closePool_noNewConnectionAllowed() throws IOException {
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
pool.close();

TestObserver<SharedConnection> observer = pool.create().test();

observer
.assertError(IllegalStateException.class)
.assertError(e -> e.getMessage().contains("closed"));
}

@Test
public void closePool_closeUnderlyingConnection() throws IOException {
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
TestObserver<SharedConnection> observer = pool.create().test();
observer.assertComplete();

pool.close();

verify(connection0, times(1)).close();
}

@Test
public void closePool_pendingConnectionCreation_closedError()
throws IOException, InterruptedException {
AtomicBoolean canceled = new AtomicBoolean(false);
AtomicBoolean finished = new AtomicBoolean(false);
Semaphore terminated = new Semaphore(0);
ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
when(connectionFactory.create())
.thenAnswer(
invocation ->
Single.create(
emitter -> {
Thread t =
new Thread(
() -> {
try {
Thread.sleep(Integer.MAX_VALUE);
finished.set(true);
emitter.onSuccess(connection0);
} catch (InterruptedException ignored) {
/* no-op */
}

terminated.release();
});
t.start();

emitter.setCancellable(t::interrupt);
})
.doOnDispose(() -> canceled.set(true)));
DynamicConnectionPool pool = new DynamicConnectionPool(connectionFactory, 1);
TestObserver<SharedConnection> observer = pool.create().test();
observer.assertEmpty();

assertThat(canceled.get()).isFalse();
pool.close();

terminated.acquire();
observer
.assertError(IllegalStateException.class)
.assertError(e -> e.getMessage().contains("closed"));
assertThat(canceled.get()).isTrue();
assertThat(finished.get()).isFalse();
}
}

0 comments on commit 7c081eb

Please sign in to comment.