Skip to content

Commit

Permalink
Only warn when connect was called
Browse files Browse the repository at this point in the history
  • Loading branch information
dwoz committed Nov 20, 2023
1 parent 0d997fd commit 67b43d3
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 7 deletions.
30 changes: 25 additions & 5 deletions salt/transport/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,38 @@ def publish_client(opts, io_loop):
raise Exception("Transport type not found: {}".format(ttype))


class TransportWarning(Warning):
"""
Transport warning.
"""


class Transport:
def __init__(self, *args, **kwargs):
self._trace = "\n".join(traceback.format_stack()[:-1])
if not hasattr(self, "_closing"):
self._closing = False
if not hasattr(self, "_connect_called"):
self._connect_called = False

def connect(self, *args, **kwargs):
self._connect_called = True

# pylint: disable=W1701
def __del__(self):
if not self._closing:
"""
Warn the user if the transport's close method was never called.
If the _closing attribute is missing we won't raise a warning. This
prevents issues when class's dunder init method is called with improper
arguments, and is later getting garbage collected. Users of this class
should take care to call super() and validate the functionality with a
test.
"""
if getattr(self, "_connect_called") and not getattr(self, "_closing", True):
warnings.warn(
f"Unclosed transport {self!r} \n{self._trace}",
ResourceWarning,
f"Unclosed transport! {self!r} \n{self._trace}",
TransportWarning,
source=self,
)

Expand Down Expand Up @@ -137,7 +157,7 @@ def close(self):
"""
raise NotImplementedError

def connect(self):
def connect(self): # pylint: disable=W0221
"""
Connect to the server / broker.
"""
Expand Down Expand Up @@ -233,7 +253,7 @@ def on_recv(self, callback):
raise NotImplementedError

@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
def connect(self, publish_port, connect_callback=None, disconnect_callback=None): # pylint: disable=W0221
"""
Create a network connection to the the PublishServer or broker.
"""
Expand Down
2 changes: 2 additions & 0 deletions salt/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def close(self):

@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
self._connect_called = True
self.publish_port = publish_port
self.message_client = MessageClient(
self.opts,
Expand Down Expand Up @@ -1054,6 +1055,7 @@ def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231

@salt.ext.tornado.gen.coroutine
def connect(self):
self._connect_called = True
yield self.message_client.connect()

@salt.ext.tornado.gen.coroutine
Expand Down
9 changes: 7 additions & 2 deletions salt/transport/zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
# TODO: this is the time to see if we are connected, maybe use the req channel to guess?
@salt.ext.tornado.gen.coroutine
def connect(self, publish_port, connect_callback=None, disconnect_callback=None):
self._connect_called = True
self.publish_port = publish_port
log.debug(
"Connecting the Minion to the Master publish port, using the URI: %s",
self.master_pub,
)
log.debug("%r connecting to %s", self, self.master_pub)
self._socket.connect(self.master_pub)
connect_callback(True)
if connect_callback is not None:
connect_callback(True)

@property
def master_pub(self):
Expand Down Expand Up @@ -886,13 +888,16 @@ def __init__(self, opts, io_loop): # pylint: disable=W0231
io_loop=io_loop,
)
self._closing = False
self._connect_called = False

@salt.ext.tornado.gen.coroutine
def connect(self):
self._connect_called = True
self.message_client.connect()

@salt.ext.tornado.gen.coroutine
def send(self, load, timeout=60):
self.connect()
yield self.connect()
ret = yield self.message_client.send(load, timeout=timeout)
raise salt.ext.tornado.gen.Return(ret)

Expand Down
21 changes: 21 additions & 0 deletions tests/pytests/unit/transport/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Unit tests for salt.transport.base.
"""
import pytest

import salt.transport.base

pytestmark = [
pytest.mark.core_test,
]


def test_unclosed_warning():

transport = salt.transport.base.Transport()
assert transport._closing is False
assert transport._connect_called is False
transport.connect()
assert transport._connect_called is True
with pytest.warns(salt.transport.base.TransportWarning):
del transport
28 changes: 28 additions & 0 deletions tests/pytests/unit/transport/test_zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,3 +1498,31 @@ def test_pub_client_init(minion_opts, io_loop):
client = salt.transport.zeromq.PublishClient(minion_opts, io_loop)
client.send(b"asf")
client.close()


async def test_unclosed_request_client(minion_opts, io_loop):
minion_opts["master_uri"] = "tcp://127.0.0.1:4506"
client = salt.transport.zeromq.RequestClient(minion_opts, io_loop)
await client.connect()
try:
assert client._closing is False
with pytest.warns(salt.transport.base.TransportWarning):
client.__del__()
finally:
client.close()


async def test_unclosed_publish_client(minion_opts, io_loop):
minion_opts["id"] = "minion"
minion_opts["__role"] = "minion"
minion_opts["master_ip"] = "127.0.0.1"
minion_opts["zmq_filtering"] = True
minion_opts["zmq_monitor"] = True
client = salt.transport.zeromq.PublishClient(minion_opts, io_loop)
await client.connect(2121)
try:
assert client._closing is False
with pytest.warns(salt.transport.base.TransportWarning):
client.__del__()
finally:
client.close()

0 comments on commit 67b43d3

Please sign in to comment.