From 7795dcdb0ab0258aab3e5ed9fe971cdfcb65b34e Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Mon, 29 Apr 2024 21:14:41 +0200 Subject: [PATCH] Fix publish() a bytearray payload --- src/paho/mqtt/client.py | 6 ++--- tests/paho_test.py | 3 ++- tests/test_client.py | 53 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 4dfd2f76..4ccc8696 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -465,7 +465,7 @@ def _force_bytes(s: str | bytes) -> bytes: return s -def _encode_payload(payload: str | bytes | bytearray | int | float | None) -> bytes: +def _encode_payload(payload: str | bytes | bytearray | int | float | None) -> bytes|bytearray: if isinstance(payload, str): return payload.encode("utf-8") @@ -3368,7 +3368,7 @@ def _send_publish( self, mid: int, topic: bytes, - payload: bytes = b"", + payload: bytes|bytearray = b"", qos: int = 0, retain: bool = False, dup: bool = False, @@ -3378,7 +3378,7 @@ def _send_publish( # we assume that topic and payload are already properly encoded if not isinstance(topic, bytes): raise TypeError('topic must be bytes, not str') - if payload and not isinstance(payload, bytes): + if payload and not isinstance(payload, (bytes, bytearray)): raise TypeError('payload must be bytes if set') if self._sock is None: diff --git a/tests/paho_test.py b/tests/paho_test.py index d0de2d88..40e950a4 100644 --- a/tests/paho_test.py +++ b/tests/paho_test.py @@ -228,7 +228,8 @@ def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ pack_format = pack_format + "%ds"%(len(properties)) if payload is not None: - payload = payload.encode("utf-8") + if isinstance(payload, str): + payload = payload.encode("utf-8") rl = rl + len(payload) pack_format = pack_format + str(len(payload)) + "s" else: diff --git a/tests/test_client.py b/tests/test_client.py index 0f637ff7..09e46066 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -427,6 +427,59 @@ def on_connect(mqttc, obj, flags, rc): packet_in = fake_broker.receive_packet(1) assert not packet_in # Check connection is closed + @pytest.mark.parametrize("user_payload,sent_payload", [ + ("string", b"string"), + (b"byte", b"byte"), + (bytearray(b"bytearray"), b"bytearray"), + (42, b"42"), + (4.2, b"4.2"), + (None, b""), + ]) + def test_publish_various_payload(self, user_payload: client.PayloadType, sent_payload: bytes, fake_broker: FakeBroker) -> None: + mqttc = client.Client( + CallbackAPIVersion.VERSION2, + "test_publish_various_payload", + transport=fake_broker.transport, + ) + + mqttc.connect("localhost", fake_broker.port) + mqttc.loop_start() + mqttc.enable_logger() + + try: + fake_broker.start() + + connect_packet = paho_test.gen_connect( + "test_publish_various_payload", keepalive=60, + proto_ver=client.MQTTv311) + fake_broker.expect_packet("connect", connect_packet) + + connack_packet = paho_test.gen_connack(rc=0) + count = fake_broker.send_packet(connack_packet) + assert count # Check connection was not closed + assert count == len(connack_packet) + + mqttc.publish("test", user_payload) + + publish_packet = paho_test.gen_publish( + b"test", payload=sent_payload, qos=0 + ) + fake_broker.expect_packet("publish", publish_packet) + + mqttc.disconnect() + + disconnect_packet = paho_test.gen_disconnect() + packet_in = fake_broker.receive_packet(1000) + assert packet_in # Check connection was not closed + assert packet_in == disconnect_packet + + finally: + mqttc.loop_stop() + + packet_in = fake_broker.receive_packet(1) + assert not packet_in # Check connection is closed + + @pytest.mark.parametrize("callback_version", [ (CallbackAPIVersion.VERSION1), (CallbackAPIVersion.VERSION2),