Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch Exception and not BaseException in the Connection #2104

Merged
merged 3 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,6 @@ async def read_from_socket(
# data was read from the socket and added to the buffer.
# return True to indicate that data was read.
return True
except asyncio.CancelledError:
raise
except (socket.timeout, asyncio.TimeoutError):
if raise_on_timeout:
raise TimeoutError("Timeout reading from socket") from None
Expand Down Expand Up @@ -721,7 +719,7 @@ async def connect(self):
lambda: self._connect(), lambda error: self.disconnect()
)
except asyncio.CancelledError:
raise
raise # in 3.7 and earlier, this is an Exception, not BaseException
except (socket.timeout, asyncio.TimeoutError):
raise TimeoutError("Timeout connecting to server")
except OSError as e:
Expand Down Expand Up @@ -916,7 +914,7 @@ async def send_packed_command(
raise ConnectionError(
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except BaseException:
except Exception:
await self.disconnect()
raise

Expand Down Expand Up @@ -958,7 +956,7 @@ async def read_response(self, disable_decoding: bool = False):
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except BaseException:
except Exception:
await self.disconnect()
raise

Expand Down
4 changes: 2 additions & 2 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def send_packed_command(self, command, check_health=True):
errno = e.args[0]
errmsg = e.args[1]
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
except BaseException:
except Exception:
self.disconnect()
raise

Expand Down Expand Up @@ -804,7 +804,7 @@ def read_response(self, disable_decoding=False):
except OSError as e:
self.disconnect()
raise ConnectionError(f"Error while reading from {hosterr}" f" : {e.args}")
except BaseException:
except Exception:
self.disconnect()
raise

Expand Down
74 changes: 74 additions & 0 deletions tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import functools
import socket
import sys
from typing import Optional
from unittest.mock import patch

import async_timeout
import pytest
Expand Down Expand Up @@ -914,3 +916,75 @@ async def loop_step_listen(self):
return True
except asyncio.TimeoutError:
return False


@pytest.mark.onlynoncluster
class TestBaseException:
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="requires python 3.8 or higher"
)
async def test_outer_timeout(self, r: redis.Redis):
"""
Using asyncio_timeout manually outside the inner method timeouts works.
This works on Python versions 3.8 and greater, at which time asyncio.
CancelledError became a BaseException instead of an Exception before.
"""
pubsub = r.pubsub()
await pubsub.subscribe("foo")
assert pubsub.connection.is_connected

async def get_msg_or_timeout(timeout=0.1):
async with async_timeout.timeout(timeout):
# blocking method to return messages
while True:
response = await pubsub.parse_response(block=True)
message = await pubsub.handle_message(
response, ignore_subscribe_messages=False
)
if message is not None:
return message

# get subscribe message
msg = await get_msg_or_timeout(10)
assert msg is not None
# timeout waiting for another message which never arrives
assert pubsub.connection.is_connected
with pytest.raises(asyncio.TimeoutError):
await get_msg_or_timeout()
# the timeout on the read should not cause disconnect
assert pubsub.connection.is_connected

async def test_base_exception(self, r: redis.Redis):
"""
Manually trigger a BaseException inside the parser's .read_response method
and verify that it isn't caught
"""
pubsub = r.pubsub()
await pubsub.subscribe("foo")
assert pubsub.connection.is_connected

async def get_msg():
# blocking method to return messages
while True:
response = await pubsub.parse_response(block=True)
message = await pubsub.handle_message(
response, ignore_subscribe_messages=False
)
if message is not None:
return message

# get subscribe message
msg = await get_msg()
assert msg is not None
# timeout waiting for another message which never arrives
assert pubsub.connection.is_connected
with patch("redis.asyncio.connection.PythonParser.read_response") as mock1:
mock1.side_effect = BaseException("boom")
with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2:
mock2.side_effect = BaseException("boom")

with pytest.raises(BaseException):
await get_msg()

# the timeout on the read should not cause disconnect
assert pubsub.connection.is_connected
42 changes: 42 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,45 @@ def loop_step_listen(self):
for message in self.pubsub.listen():
self.messages.put(message)
return True


@pytest.mark.onlynoncluster
class TestBaseException:
def test_base_exception(self, r: redis.Redis):
"""
Manually trigger a BaseException inside the parser's .read_response method
and verify that it isn't caught
"""
pubsub = r.pubsub()
pubsub.subscribe("foo")

def is_connected():
return pubsub.connection._sock is not None

assert is_connected()

def get_msg():
# blocking method to return messages
while True:
response = pubsub.parse_response(block=True)
message = pubsub.handle_message(
response, ignore_subscribe_messages=False
)
if message is not None:
return message

# get subscribe message
msg = get_msg()
assert msg is not None
# timeout waiting for another message which never arrives
assert is_connected()
with patch("redis.connection.PythonParser.read_response") as mock1:
mock1.side_effect = BaseException("boom")
with patch("redis.connection.HiredisParser.read_response") as mock2:
mock2.side_effect = BaseException("boom")

with pytest.raises(BaseException):
get_msg()

# the timeout on the read should not cause disconnect
assert is_connected()