Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Fix health check after unsubscribe #1207

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGES/1207.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix #1206 health check message after unsubscribing
83 changes: 77 additions & 6 deletions aioredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
cast,
)

import async_timeout

from aioredis.compat import Protocol, TypedDict
from aioredis.connection import (
Connection,
Expand Down Expand Up @@ -3934,24 +3936,25 @@ def __init__(
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection: Optional[Connection] = None
self.subscribed_event = asyncio.Event()
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response: Iterable[Union[str, bytes]] = [
"pong",
self.HEALTH_CHECK_MESSAGE,
]
else:
self.health_check_response = [
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
self.health_check_response = [b"pong", self.health_check_response_b]
self.channels: Dict[ChannelT, PubSubHandler] = {}
self.pending_unsubscribe_channels: Set[ChannelT] = set()
self.patterns: Dict[ChannelT, PubSubHandler] = {}
self.pending_unsubscribe_patterns: Set[ChannelT] = set()
self._lock = asyncio.Lock()
self.health_check_response_counter = 0
self.subscribed_event.clear()

async def __aenter__(self):
return self
Expand All @@ -3971,9 +3974,11 @@ async def reset(self):
await self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
self.health_check_response_counter = 0
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()

def close(self) -> Awaitable[NoReturn]:
return self.reset()
Expand All @@ -3999,7 +4004,7 @@ async def on_connect(self, connection: Connection):
@property
def subscribed(self):
"""Indicates if there are subscriptions to any channels or patterns"""
return bool(self.channels or self.patterns)
return self.subscribed_event.is_set()

async def execute_command(self, *args: EncodableT):
"""Execute a publish/subscribe command"""
Expand All @@ -4017,8 +4022,30 @@ async def execute_command(self, *args: EncodableT):
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
await self.clean_health_check_responses()
await self._execute(connection, connection.send_command, *args, **kwargs)

async def clean_health_check_responses(self):
"""
If any health check responses are present, clean them
"""
ttl = 10
conn = self.connection
if not conn:
return
while self.health_check_response_counter > 0 and ttl > 0:
if await self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
response = await self._execute(conn, conn.read_response)
if self.is_health_check_response(response):
self.health_check_response_counter -= 1
else:
raise PubSubError(
"A non health check response was cleaned by "
"execute_command: {}".format(response)
)
ttl -= 1

async def _execute(self, connection, command, *args, **kwargs):
try:
return await command(*args, **kwargs)
Expand Down Expand Up @@ -4049,11 +4076,25 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
return None
response = await self._execute(conn, conn.read_response)

if conn.health_check_interval and response == self.health_check_response:
# The response depends on whether there were any subscriptions
# active at the time the PING was issued.
if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
self.health_check_response_counter -= 1
return None
return response

def is_health_check_response(self, response):
"""
Check if the response is a health check response.
If there are no subscriptions redis responds to PING command with a
bulk response, instead of a multi-bulk with "pong" and the response.
"""
return response in [
self.health_check_response, # If there was a subscription
self.health_check_response_b, # If there wasn't
]

async def check_health(self):
conn = self.connection
if conn is None:
Expand All @@ -4066,6 +4107,7 @@ async def check_health(self):
conn.health_check_interval
and asyncio.get_event_loop().time() > conn.next_health_check
):
self.health_check_response_counter += 1
await conn.send_command(
"PING", self.HEALTH_CHECK_MESSAGE, check_health=False
)
Expand Down Expand Up @@ -4098,6 +4140,11 @@ async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
# for the reconnection.
new_patterns = self._normalize_keys(new_patterns)
self.patterns.update(new_patterns)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
self.pending_unsubscribe_patterns.difference_update(new_patterns)
return ret_val

Expand Down Expand Up @@ -4134,6 +4181,11 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable):
# for the reconnection.
new_channels = self._normalize_keys(new_channels)
self.channels.update(new_channels)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
self.pending_unsubscribe_channels.difference_update(new_channels)
return ret_val

Expand Down Expand Up @@ -4168,6 +4220,21 @@ async def get_message(
before returning. Timeout should be specified as a floating point
number.
"""
if not self.subscribed:
# Wait for subscription
start_time = asyncio.get_event_loop().time()

async with async_timeout.timeout(timeout):
if await self.subscribed_event.wait() is True:
# The connection was subscribed during the timeout time frame.
# The timeout should be adjusted based on the time spent
# waiting for the subscription
time_spent = asyncio.get_event_loop().time() - start_time
timeout = max(0.0, timeout - time_spent)
else:
# The connection isn't subscribed to any channels or patterns,
# so no messages are available
return None
response = await self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
Expand Down Expand Up @@ -4221,6 +4288,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
if not self.channels and not self.patterns:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()

if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
Expand Down
43 changes: 32 additions & 11 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import threading
import time
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -343,15 +342,6 @@ async def test_unicode_pattern_message_handler(self, r):
"pmessage", channel, "test message", pattern=pattern
)

async def test_get_message_without_subscribe(self, r):
p = r.pubsub()
with pytest.raises(RuntimeError) as info:
await p.get_message()
expect = (
"connection not set: " "did you forget to call subscribe() or psubscribe()?"
)
assert expect in info.exconly()


class TestPubSubAutoDecoding:
"""These tests only validate that we get unicode values back"""
Expand Down Expand Up @@ -553,6 +543,37 @@ async def test_get_message_with_timeout_returns_none(self, r):
assert await wait_for_message(p) == make_message("subscribe", "foo", 1)
assert await p.get_message(timeout=0.01) is None

async def test_get_message_not_subscribed_return_none(self, r):
p = r.pubsub()
assert p.subscribed is False
assert await p.get_message() is None
assert await p.get_message(timeout=0.1) is None
with patch.object(asyncio.Event, "wait") as mock:
mock.return_value = False
assert await p.get_message(timeout=0.01) is None
assert mock.called

async def test_get_message_subscribe_during_waiting(self, r):
p = r.pubsub()

async def poll(ps, expected_res):
assert await ps.get_message() is None
message = await ps.get_message(timeout=1)
assert message == expected_res

subscribe_response = make_message("subscribe", "foo", 1)
asyncio.create_task(poll(p, subscribe_response))
await asyncio.sleep(0.2)
await p.subscribe("foo")

async def test_get_message_wait_for_subscription_not_being_called(self, r):
p = r.pubsub()
await p.subscribe("foo")
with patch.object(asyncio.Event, "wait") as mock:
assert p.subscribed is True
assert await wait_for_message(p) == make_message("subscribe", "foo", 1)
assert mock.called is False


class TestPubSubRun:
async def _subscribe(self, p, *args, **kwargs):
Expand Down