def get_channel(self, channel_id: int, **kwargs) -> Channel:
    """Return a single Channel looked up by its channelz id.

    Args:
        channel_id: The channelz id of the channel to load.
        **kwargs: Forwarded unchanged to call_unary_with_deadline().

    Returns:
        The `channel` field of the GetChannel response.

    Raises:
        framework.rpc.grpc.GrpcApp.NotFound: The RPC failed with the
            NOT_FOUND status code.
        grpc.RpcError: The RPC failed for any other reason; the original
            error is re-raised untouched.
    """
    # Keep the try body limited to the RPC itself, so only its errors are
    # subject to the translation below.
    try:
        response: channelz_pb2.GetChannelResponse = (
            self.call_unary_with_deadline(
                rpc="GetChannel",
                req=channelz_pb2.GetChannelRequest(channel_id=channel_id),
                **kwargs,
            )
        )
    except grpc.RpcError as err:
        # The status code is only inspected when the error is also a
        # grpc.Call; anything else falls through to the bare re-raise.
        if (
            isinstance(err, grpc.Call)
            and err.code() is grpc.StatusCode.NOT_FOUND
        ):
            # Translate NOT_FOUND into GrpcApp.NotFound, keeping the RPC
            # error attached as the explicit cause.
            raise framework.rpc.grpc.GrpcApp.NotFound(
                f"Channel with channel_id {channel_id} not found",
            ) from err

        raise

    return response.channel
def find_channels(
    self,
    target: str,
    **rpc_params,
) -> Iterable[_ChannelzChannel]:
    """Look up the channelz channels for the given target.

    Thin wrapper: forwards `target` and any extra keyword arguments to
    `self.channelz.find_channels_for_target()` and returns its result
    unchanged.
    """
    channelz_client = self.channelz
    matching_channels = channelz_client.find_channels_for_target(
        target, **rpc_params
    )
    return matching_channels
def check_channel_in_flight_calls(
    self,
    channel: _ChannelzChannel,
    *,
    wait_between_checks: Optional[_timedelta] = None,
    **rpc_params,
) -> Optional[_ChannelzChannel]:
    """Checks if the channel has calls that started, but didn't complete.

    We consider the channel active if it is in the READY state and has at
    least one call in flight, i.e. calls_started is greater than
    calls_succeeded + calls_failed (see calc_calls_in_flight).

    This method addresses a race where a call to the xDS control plane
    server has just started and a channelz request comes in before the call
    has had a chance to fail.

    With channels to the xDS control plane, the channel can be READY but the
    calls could be failing to initialize, e.g. due to a failure to fetch
    an OAUTH2 token. To increase the confidence that we have a valid channel
    with working OAUTH2 tokens, we check whether the channel is in a READY
    state with active calls twice, pausing for wait_between_checks between
    the two attempts. If the OAUTH2 token is not valid, the call would fail
    and be caught in either the first attempt, or the second attempt. It is
    possible that between the two attempts a call fails and a new call is
    started, so we also test for equality between the started calls of the
    two channelz results.

    There still exists a possibility that a call fails on fetching the
    OAUTH2 token after the wait (maybe because there is a slowdown in the
    system). If such a case is observed, consider increasing the default
    interval from 2 seconds to 5 seconds.

    Args:
        channel: The channelz channel captured by the first check.
        wait_between_checks: Pause between the two checks. Defaults to
            2 seconds when None. An explicit zero timedelta is honored.
        **rpc_params: Forwarded to self.channelz.get_channel().

    Returns:
        The reloaded channel on success, or None when either check finds no
        in-flight calls or calls_started changed between the two checks.

    Raises:
        Any error raised by self.channelz.get_channel() propagates.
    """
    # NOTE(review): callers must handle the None result. The caller
    # find_active_xds_channel appears to pass the result straight to
    # channel_repr() without a None check - confirm and guard there.
    if not self.calc_calls_in_flight(channel):
        return None

    # Compare with None explicitly: a zero timedelta is falsy, and an
    # explicitly requested zero wait must not be replaced by the default.
    if wait_between_checks is None:
        wait_between_checks = _timedelta(seconds=2)

    # Load the channel a second time after the pause.
    time.sleep(wait_between_checks.total_seconds())
    channel_upd: _ChannelzChannel = self.channelz.get_channel(
        channel.ref.channel_id, **rpc_params
    )
    if (
        not self.calc_calls_in_flight(channel_upd)
        or channel.data.calls_started != channel_upd.data.calls_started
    ):
        return None
    return channel_upd

@classmethod
def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int:
    """Return the number of calls that started but have not completed.

    Returns 0 when the channel is not in the READY state; otherwise
    calls_started - calls_succeeded - calls_failed.
    """
    cdata: _ChannelzChannelData = channel.data
    # Protobuf enum values are plain ints, so compare by value. An identity
    # check would only work thanks to CPython's small-int caching.
    if cdata.state.state != _ChannelzChannelState.READY:
        return 0

    return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed