Skip to content

Commit

Permalink
Remote: gRPC load balancing. (Part 3)
Browse files Browse the repository at this point in the history
Implement SharedConnectionFactory which applys rate limiting on top of one connection.

PiperOrigin-RevId: 358084865
  • Loading branch information
Googler authored and philwo committed Mar 15, 2021
1 parent 6ad192b commit c5fedad
Show file tree
Hide file tree
Showing 4 changed files with 533 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ java_library(
deps = [
"//src/main/java/com/google/devtools/build/lib/concurrent",
"//third_party:guava",
"//third_party:jsr305",
"//third_party:rxjava3",
"//third_party/grpc:grpc-jar",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;

import io.reactivex.rxjava3.core.Single;
import java.io.Closeable;
import java.io.IOException;

Expand All @@ -24,6 +25,13 @@
* <p>Connections must be closed with {@link Connection#close()} in order to be reused later.
*/
public interface ConnectionPool extends ConnectionFactory, Closeable {
/**
* Reuses a {@link Connection} in the pool and will potentially create a new connection depends on
* implementation.
*/
@Override
Single<? extends Connection> create();

/** Closes the connection pool and closes all the underlying connections */
@Override
void close() throws IOException;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright 2021 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;

import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.MethodDescriptor;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.functions.Action;
import io.reactivex.rxjava3.subjects.AsyncSubject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

/**
* A {@link ConnectionPool} that creates one connection using provided {@link ConnectionFactory} and
* shares the connection upto {@code maxConcurrency}.
*
* <p>This is useful if underlying connection maintains a connection pool internally. (such as
* {@code Channel} in gRPC)
*
* <p>Connections must be closed with {@link Connection#close()} in order to be reused later.
*/
@ThreadSafe
public class SharedConnectionFactory implements ConnectionPool {
private final TokenBucket<Integer> tokenBucket;
private final ConnectionFactory factory;

@Nullable
@GuardedBy("connectionLock")
private AsyncSubject<Connection> connectionAsyncSubject = null;

private final ReentrantLock connectionLock = new ReentrantLock();
private final AtomicReference<Disposable> connectionCreationDisposable =
new AtomicReference<>(null);

public SharedConnectionFactory(ConnectionFactory factory, int maxConcurrency) {
this.factory = factory;

List<Integer> initialTokens = new ArrayList<>(maxConcurrency);
for (int i = 0; i < maxConcurrency; ++i) {
initialTokens.add(i);
}
this.tokenBucket = new TokenBucket<>(initialTokens);
}

@Override
public void close() throws IOException {
tokenBucket.close();

Disposable d = connectionCreationDisposable.getAndSet(null);
if (d != null && !d.isDisposed()) {
d.dispose();
}

try {
connectionLock.lockInterruptibly();

if (connectionAsyncSubject != null) {
Connection connection = connectionAsyncSubject.getValue();
if (connection != null) {
connection.close();
}

if (!connectionAsyncSubject.hasComplete()) {
connectionAsyncSubject.onError(new IllegalStateException("closed"));
}
}
} catch (InterruptedException e) {
throw new IOException(e);
} finally {
connectionLock.unlock();
}
}

private AsyncSubject<Connection> createUnderlyingConnectionIfNot() throws InterruptedException {
connectionLock.lockInterruptibly();
try {
if (connectionAsyncSubject == null || connectionAsyncSubject.hasThrowable()) {
connectionAsyncSubject =
factory
.create()
.doOnSubscribe(connectionCreationDisposable::set)
.toObservable()
.subscribeWith(AsyncSubject.create());
}

return connectionAsyncSubject;
} finally {
connectionLock.unlock();
}
}

private Single<? extends Connection> acquireConnection() {
return Single.fromCallable(this::createUnderlyingConnectionIfNot)
.flatMap(Single::fromObservable);
}

/**
* Reuses the underlying {@link Connection} and wait for it to be released if is exceeding {@code
* maxConcurrency}.
*/
@Override
public Single<SharedConnection> create() {
return tokenBucket
.acquireToken()
.flatMap(
token ->
acquireConnection()
.doOnError(ignored -> tokenBucket.addToken(token))
.doOnDispose(() -> tokenBucket.addToken(token))
.map(
conn ->
new SharedConnection(
conn, /* onClose= */ () -> tokenBucket.addToken(token))));
}

/** Returns current number of available connections. */
public int numAvailableConnections() {
return tokenBucket.size();
}

/** A {@link Connection} which wraps an underlying connection and is shared between consumers. */
public static class SharedConnection implements Connection {
private final Connection connection;
private final Action onClose;

public SharedConnection(Connection connection, Action onClose) {
this.connection = connection;
this.onClose = onClose;
}

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> call(
MethodDescriptor<ReqT, RespT> method, CallOptions options) {
return connection.call(method, options);
}

@Override
public void close() throws IOException {
try {
onClose.run();
} catch (Throwable t) {
throw new IOException(t);
}
}

/** Returns the underlying connection this shared connection built on */
public Connection getUnderlyingConnection() {
return connection;
}
}
}
Loading

0 comments on commit c5fedad

Please sign in to comment.