Skip to content

Commit

Permalink
core: remove channel reference from ManagedChannelWrapper (#5283)
Browse files Browse the repository at this point in the history
This avoids a memory leak when the channel itself participates in a
reference cycle (e.g., when an interceptor retains a reference to an
Android app's context). With the current implementation, the static
`ManagedChannelOrphanWrapper.refs` map will keep the channel reachable
and prevent the ref cycle from being GCed.
  • Loading branch information
ericgribkoff authored Jan 26, 2019
1 parent f973bbc commit ce2ae1f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.lang.ref.WeakReference;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;
Expand Down Expand Up @@ -56,24 +55,17 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel {
@Override
public ManagedChannel shutdown() {
phantom.shutdown = true;
phantom.clear();
return super.shutdown();
}

@Override
public ManagedChannel shutdownNow() {
phantom.shutdownNow = true;
phantom.shutdown = true;
phantom.clear();
return super.shutdownNow();
}

@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
boolean ret = super.awaitTermination(timeout, unit);
if (ret) {
phantom.clear();
}
return ret;
}

@VisibleForTesting
static final class ManagedChannelReference extends WeakReference<ManagedChannelOrphanWrapper> {

Expand All @@ -87,10 +79,9 @@ static final class ManagedChannelReference extends WeakReference<ManagedChannelO
private final ReferenceQueue<ManagedChannelOrphanWrapper> refqueue;
private final ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs;

private final ManagedChannel channel;
private final String channelStr;
private final Reference<RuntimeException> allocationSite;
private volatile boolean shutdown;
private volatile boolean shutdownNow;

ManagedChannelReference(
ManagedChannelOrphanWrapper orphanable,
Expand All @@ -102,7 +93,7 @@ static final class ManagedChannelReference extends WeakReference<ManagedChannelO
ENABLE_ALLOCATION_TRACKING
? new RuntimeException("ManagedChannel allocation site")
: missingCallSite);
this.channel = channel;
this.channelStr = channel.toString();
this.refqueue = refqueue;
this.refs = refs;
this.refs.put(this, this);
Expand Down Expand Up @@ -144,21 +135,18 @@ static int cleanQueue(ReferenceQueue<ManagedChannelOrphanWrapper> refqueue) {
while ((ref = (ManagedChannelReference) refqueue.poll()) != null) {
RuntimeException maybeAllocationSite = ref.allocationSite.get();
ref.clearInternal(); // technically the reference is gone already.
if (!(ref.shutdown && ref.channel.isTerminated())) {
if (!ref.shutdown) {
orphanedChannels++;
Level level = ref.shutdownNow ? Level.FINE : Level.SEVERE;
Level level = Level.SEVERE;
if (logger.isLoggable(level)) {
String fmt =
"*~*~*~ Channel {0} was not "
// Prefer to complain about shutdown if neither has been called.
+ (!ref.shutdown ? "shutdown" : "terminated")
+ " properly!!! ~*~*~*"
+ System.getProperty("line.separator")
+ " Make sure to call shutdown()/shutdownNow() and wait "
+ "until awaitTermination() returns true.";
"*~*~*~ Channel {0} was not shutdown properly!!! ~*~*~*"
+ System.getProperty("line.separator")
+ " Make sure to call shutdown()/shutdownNow() and wait "
+ "until awaitTermination() returns true.";
LogRecord lr = new LogRecord(level, fmt);
lr.setLoggerName(logger.getName());
lr.setParameters(new Object[]{ref.channel.toString()});
lr.setParameters(new Object[] {ref.channelStr});
lr.setThrown(maybeAllocationSite);
logger.log(lr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

import com.google.common.testing.GcFinalization;
import com.google.common.testing.GcFinalization.FinalizationPredicate;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.internal.ManagedChannelOrphanWrapper.ManagedChannelReference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Filter;
import java.util.logging.Level;
import java.util.logging.LogRecord;
Expand All @@ -43,10 +45,10 @@
@RunWith(JUnit4.class)
public final class ManagedChannelOrphanWrapperTest {
@Test
public void orphanedChannelsAreLogged() throws Exception {
ManagedChannel mc = mock(ManagedChannel.class);
public void orphanedChannelsAreLogged() {
ManagedChannel mc = new TestManagedChannel();
String channelString = mc.toString();
ReferenceQueue<ManagedChannelOrphanWrapper> refqueue =
final ReferenceQueue<ManagedChannelOrphanWrapper> refqueue =
new ReferenceQueue<ManagedChannelOrphanWrapper>();
ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs =
new ConcurrentHashMap<ManagedChannelReference, ManagedChannelReference>();
Expand All @@ -71,22 +73,18 @@ public boolean isLoggable(LogRecord record) {
}
});

// TODO(carl-mastrangelo): consider using com.google.common.testing.GcFinalization instead.
try {
channel = null;
boolean success = false;
for (int retry = 0; retry < 3; retry++) {
System.gc();
System.runFinalization();
int orphans = ManagedChannelReference.cleanQueue(refqueue);
if (orphans == 1) {
success = true;
break;
}
assertEquals("unexpected extra orphans", 0, orphans);
Thread.sleep(100L * (1L << retry));
}
assertTrue("Channel was not garbage collected", success);
final AtomicInteger numOrphans = new AtomicInteger();
GcFinalization.awaitDone(
new FinalizationPredicate() {
@Override
public boolean isDone() {
numOrphans.getAndAdd(ManagedChannelReference.cleanQueue(refqueue));
return numOrphans.get() > 0;
}
});
assertEquals("unexpected extra orphans", 1, numOrphans.get());

LogRecord lr;
synchronized (records) {
Expand All @@ -102,7 +100,32 @@ public boolean isLoggable(LogRecord record) {
}
}

private static final class TestManagedChannel extends ManagedChannel {
@Test
public void refCycleIsGCed() {
ReferenceQueue<ManagedChannelOrphanWrapper> refqueue =
new ReferenceQueue<ManagedChannelOrphanWrapper>();
ConcurrentMap<ManagedChannelReference, ManagedChannelReference> refs =
new ConcurrentHashMap<ManagedChannelReference, ManagedChannelReference>();
ApplicationWithChannelRef app = new ApplicationWithChannelRef();
ChannelWithApplicationRef channelImpl = new ChannelWithApplicationRef();
ManagedChannelOrphanWrapper channel =
new ManagedChannelOrphanWrapper(channelImpl, refqueue, refs);
app.channel = channel;
channelImpl.application = app;
WeakReference<ApplicationWithChannelRef> appWeakRef =
new WeakReference<ApplicationWithChannelRef>(app);

// Simulate the application and channel going out of scope. A ref cycle between app and
// channel remains, so ensure that our tracking of orphaned channels does not prevent this
// reference cycle from being GCed.
channel = null;
app = null;
channelImpl = null;

GcFinalization.awaitClear(appWeakRef);
}

private static class TestManagedChannel extends ManagedChannel {
@Override
public ManagedChannel shutdown() {
return null;
Expand Down Expand Up @@ -139,4 +162,12 @@ public String authority() {
return null;
}
}

private static final class ApplicationWithChannelRef {
private ManagedChannel channel;
}

private static final class ChannelWithApplicationRef extends TestManagedChannel {
private ApplicationWithChannelRef application;
}
}

0 comments on commit ce2ae1f

Please sign in to comment.