Skip to content

Commit

Permalink
core,xds: Metrics recording in WRR LB (#11129)
Browse files Browse the repository at this point in the history
  • Loading branch information
temawi authored Apr 26, 2024
1 parent 795ee0f commit 06df25b
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 7 deletions.
88 changes: 81 additions & 7 deletions xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.Deadline.Ticker;
import io.grpc.DoubleHistogramMetricInstrument;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.LongCounterMetricInstrument;
import io.grpc.MetricInstrumentRegistry;
import io.grpc.NameResolver;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
Expand All @@ -57,12 +61,17 @@
import java.util.logging.Logger;

/**
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
* the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the
* {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
* determined by backend metrics using ORCA.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {

private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER;
private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER;
private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER;
private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM;
private static final Logger log = Logger.getLogger(
WeightedRoundRobinLoadBalancer.class.getName());
private WeightedRoundRobinLoadBalancerConfig config;
Expand All @@ -74,6 +83,31 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private final long infTime;
private final Ticker ticker;

// Register the WRR metric instruments exactly once, when the class is loaded. Every instance of
// this load balancer records through these same shared instruments.
static {
  MetricInstrumentRegistry registry = MetricInstrumentRegistry.getDefaultRegistry();

  // Counts scheduler updates that ended up using plain round robin.
  RR_FALLBACK_COUNTER = registry.registerLongCounter(
      "grpc.lb.wrr.rr_fallback",
      "Number of scheduler updates in which there were not enough endpoints with valid "
          + "weight, which caused the WRR policy to fall back to RR behavior",
      "update",
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);

  // Counts endpoints whose weight cannot be used yet.
  ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = registry.registerLongCounter(
      "grpc.lb.wrr.endpoint_weight_not_yet_usable",
      "Number of endpoints from each scheduler update that don't yet have usable weight "
          + "information",
      "endpoint",
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);

  // Counts endpoints whose last weight report has expired.
  ENDPOINT_WEIGHT_STALE_COUNTER = registry.registerLongCounter(
      "grpc.lb.wrr.endpoint_weight_stale",
      "Number of endpoints from each scheduler update whose latest weight is older than the "
          + "expiration period",
      "endpoint",
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);

  // Distribution of the per-endpoint weights seen on each scheduler update.
  ENDPOINT_WEIGHTS_HISTOGRAM = registry.registerDoubleHistogram(
      "grpc.lb.wrr.endpoint_weights",
      "The histogram buckets will be endpoint weight ranges.",
      "weight",
      Lists.newArrayList(),
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);
}

/**
* Creates a balancer for the given helper and ticker. The helper is first wrapped by
* {@code OrcaOobUtil.newOrcaReportingHelper} and then by a {@code WrrHelper}; construction is
* delegated to the three-argument constructor with a freshly created {@link Random}.
*/
public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random());
}
Expand Down Expand Up @@ -145,7 +179,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
/**
* Creates a {@code WeightedRoundRobinPicker} over an immutable copy of {@code activeList}. The
* picker is configured from the current {@code config} (OOB load reporting flag and error
* utilization penalty) and receives this balancer's {@code sequence} counter and helper; the
* picker uses the helper to reach the metric recorder.
*/
@Override
public SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper());
}

@VisibleForTesting
Expand All @@ -163,16 +197,18 @@ public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Obj
super(key, policyProvider, childConfig, initialPicker);
}

private double getWeight() {
private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) {
if (config == null) {
return 0;
}
long now = ticker.nanoTime();
if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
nonEmptySince = infTime;
staleEndpoints.incrementAndGet();
return 0;
} else if (now - nonEmptySince < config.blackoutPeriodNanos
&& config.blackoutPeriodNanos > 0) {
notYetUsableEndpoints.incrementAndGet();
return 0;
} else {
return weight;
Expand Down Expand Up @@ -336,10 +372,11 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker {
private final float errorUtilizationPenalty;
private final AtomicInteger sequence;
private final int hashCode;
private final LoadBalancer.Helper helper;
private volatile StaticStrideScheduler scheduler;

WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
float errorUtilizationPenalty, AtomicInteger sequence) {
float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper) {
checkNotNull(children, "children");
Preconditions.checkArgument(!children.isEmpty(), "empty child list");
this.children = children;
Expand All @@ -353,6 +390,7 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker {
this.enableOobLoadReport = enableOobLoadReport;
this.errorUtilizationPenalty = errorUtilizationPenalty;
this.sequence = checkNotNull(sequence, "sequence");
this.helper = helper;

// For equality we treat children as a set; use hash code as defined by Set
int sum = 0;
Expand Down Expand Up @@ -387,11 +425,37 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {

/**
* Recomputes the weight of every child and installs a new {@code StaticStrideScheduler} built
* from those weights. Along the way it records the WRR metrics: one endpoint-weight histogram
* sample per child, the number of stale and of not-yet-usable endpoints (each only when
* non-zero), and a single rr_fallback count when the new scheduler reports that it uses round
* robin.
*/
private void updateWeight() {
float[] newWeights = new float[children.size()];
// Mutable tallies that getWeight() increments for each child it classifies as stale or as not
// yet usable.
AtomicInteger staleEndpoints = new AtomicInteger();
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
for (int i = 0; i < children.size(); i++) {
double newWeight = ((WeightedChildLbState)children.get(i)).getWeight();
double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
notYetUsableEndpoints);
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, ImmutableList.of(""),
ImmutableList.of(""));
// Anything that is not strictly positive is handed to the scheduler as 0.
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
if (staleEndpoints.get() > 0) {
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
ImmutableList.of(""),
ImmutableList.of(""));
}
if (notYetUsableEndpoints.get() > 0) {
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
ImmutableList.of(""), ImmutableList.of(""));
}

this.scheduler = new StaticStrideScheduler(newWeights, sequence);
if (this.scheduler.usesRoundRobin()) {
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(""), ImmutableList.of(""));
}
}

@Override
Expand Down Expand Up @@ -454,6 +518,7 @@ public boolean equals(Object o) {
static final class StaticStrideScheduler {
private final short[] scaledWeights;
private final AtomicInteger sequence;
private final boolean usesRoundRobin;
private static final int K_MAX_WEIGHT = 0xFFFF;

// Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
Expand Down Expand Up @@ -494,8 +559,10 @@ static final class StaticStrideScheduler {
if (numWeightedChannels > 0) {
unscaledMeanWeight = sumWeight / numWeightedChannels;
unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
usesRoundRobin = false;
} else {
// Fall back to round robin if all values are non-positives
usesRoundRobin = true;
unscaledMeanWeight = 1;
unscaledMaxWeight = 1;
}
Expand All @@ -521,7 +588,14 @@ static final class StaticStrideScheduler {
this.sequence = sequence;
}

/** Returns the next sequence number and atomically increases sequence with wraparound. */
/**
* Returns whether this scheduler fell back to plain round robin. The flag is set by the
* constructor when there are no positively weighted channels to schedule by weight.
*/
boolean usesRoundRobin() {
return usesRoundRobin;
}

/**
 * Atomically post-increments the sequence counter and returns the value it held before the
 * increment, reinterpreted as an unsigned 32-bit number so the result stays non-negative when
 * the counter wraps around.
 */
private long nextSequence() {
  int current = sequence.getAndIncrement();
  return Integer.toUnsignedLong(current);
}
Expand Down
137 changes: 137 additions & 0 deletions xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
Expand All @@ -40,6 +41,7 @@
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.DoubleHistogramMetricInstrument;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
Expand All @@ -49,6 +51,8 @@
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.LongCounterMetricInstrument;
import io.grpc.MetricRecorder;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
Expand Down Expand Up @@ -82,6 +86,7 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
Expand Down Expand Up @@ -120,6 +125,9 @@ public class WeightedRoundRobinLoadBalancerTest {

private final FakeClock fakeClock = new FakeClock();

@Mock
private MetricRecorder mockMetricRecorder;

private WeightedRoundRobinLoadBalancerConfig weightedConfig =
WeightedRoundRobinLoadBalancerConfig.newBuilder().build();

Expand Down Expand Up @@ -1121,6 +1129,130 @@ public void removingAddressShutsdownSubchannel() {
inOrder.verify(subchannel2).shutdown();
}


// Walks the WRR balancer through three phases (no weights yet, valid weights, stale weights)
// and verifies the metric events recorded on the mock MetricRecorder in each phase.
@Test
public void metrics() {
// Give WRR some valid addresses to work with.
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));

// Flip the three subchannels to READY state to initiate the WRR logic.
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));

// WRR creates a picker that updates the weights for each of the child subchannels. This should
// give us three "rr_fallback" metric events as we don't yet have any weights to do weighted
// round-robin.
verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 3, 1);

// We should also see six records of endpoint weights. They should all be for 0 as we don't yet
// have valid weights.
verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 6, 0);

// We should not yet be seeing any "endpoint_weight_stale" events since we don't even have
// valid weights yet.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1);

// Each time weights are updated, WRR will see if each subchannel weight is usable. As we have
// no weights yet, we should see three "endpoint_weight_not_yet_usable" metric events with the
// value increasing by one each time as all the endpoints come online.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 1);
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 2);
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 3);

// Send each child LB state an ORCA update with some valid utilization/qps data so that weights
// can be calculated.
Iterator<ChildLbState> childLbStates = wrr.getChildLbStates().iterator();
((WeightedChildLbState)childLbStates.next()).new OrcaReportListener(
weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(),
new HashMap<>(), new HashMap<>()));
((WeightedChildLbState)childLbStates.next()).new OrcaReportListener(
weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(),
new HashMap<>(), new HashMap<>()));
((WeightedChildLbState)childLbStates.next()).new OrcaReportListener(
weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(),
new HashMap<>(), new HashMap<>()));

// Reset the mock MetricRecorder so that it's easier to verify what happened after the weights
// were updated.
reset(mockMetricRecorder);

// We go forward in time past the default 10s blackout period before weights can be considered
// for WRR. The weights get updated along the way, as the default update interval is 1s.
fakeClock.forwardTime(11, TimeUnit.SECONDS);

// Since we have weights on all the child LB states, the weight update should not result in
// further rr_fallback metric entries.
verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 0, 1);

// We should not see an increase to the earlier count of "endpoint_weight_not_yet_usable".
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1);

// No endpoints should have gotten stale yet either.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1);

// Now that the weights are valid, the endpoint weights histogram should have received three
// samples with the value 10.
verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 10);

reset(mockMetricRecorder);

// Weights become stale in three minutes. Let's move ahead in time by 3 minutes and make sure
// we get a single stale event whose value covers all three endpoints.
fakeClock.forwardTime(3, TimeUnit.MINUTES);

verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 1, 3);

// With the weights stale, each of the three endpoints should report a weight of 0.
verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 0);

// Since the weights are now stale the update should have triggered an additional rr_fallback
// event.
verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 1, 1);

// No further not-yet-usable events should occur, since we have received weights and are out
// of the blackout period.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1);

// All metric events should be accounted for.
verifyNoMoreInteractions(mockMetricRecorder);
}

// Verifies that the MetricRecorder's addLongCounter was invoked the given number of times for
// the instrument with the given name, with the given counter value and with a single
// empty-string entry in each of the two label lists.
private void verifyLongCounterRecord(String name, int times, long value) {
  // Match the instrument by name only; the matcher is still applied inline below so Mockito
  // sees the argument matchers in parameter order.
  ArgumentMatcher<LongCounterMetricInstrument> hasExpectedName =
      instrument -> instrument.getName().equals(name);
  verify(mockMetricRecorder, times(times))
      .addLongCounter(
          argThat(hasExpectedName),
          eq(value),
          eq(Lists.newArrayList("")),
          eq(Lists.newArrayList("")));
}

// Verifies that the MetricRecorder's recordDoubleHistogram was invoked the given number of
// times for the instrument with the given name, with the given sample value and with a single
// empty-string entry in each of the two label lists.
private void verifyDoubleHistogramRecord(String name, int times, double value) {
  // Match the instrument by name only; the matcher is still applied inline below so Mockito
  // sees the argument matchers in parameter order.
  ArgumentMatcher<DoubleHistogramMetricInstrument> hasExpectedName =
      instrument -> instrument.getName().equals(name);
  verify(mockMetricRecorder, times(times))
      .recordDoubleHistogram(
          argThat(hasExpectedName),
          eq(value),
          eq(Lists.newArrayList("")),
          eq(Lists.newArrayList("")));
}

// Returns the pending-task count that AbstractTestHelper computes for this test's fake clock.
// Which tasks the helper filters out is defined there and is not visible here.
private int getNumFilteredPendingTasks() {
return AbstractTestHelper.getNumFilteredPendingTasks(fakeClock);
}
Expand Down Expand Up @@ -1189,5 +1321,10 @@ public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() {
public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() {
return subchannelStateListeners;
}

// Returns the test's mock recorder so the metric calls made through this helper can be verified.
@Override
public MetricRecorder getMetricRecorder() {
return mockMetricRecorder;
}
}
}

0 comments on commit 06df25b

Please sign in to comment.