Skip to content

Commit

Permalink
Merge pull request #187 from ionutab/master
Browse files Browse the repository at this point in the history
MQTT v5.0 support
  • Loading branch information
cyberw authored Jul 11, 2024
2 parents ac4e47d + 9f9b920 commit 260e6cc
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 8 deletions.
1 change: 0 additions & 1 deletion examples/mqtt_custom_client_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

# extend the MqttClient class with your own custom implementation
class MyMqttClient(MqttClient):

# you can override the event name with your custom implementation
def _generate_event_name(self, event_type: str, qos: int, topic: str):
return f"mqtt:{event_type}:{qos}"
Expand Down
4 changes: 3 additions & 1 deletion examples/mqtt_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from locust.user.wait_time import between
from locust_plugins.users.mqtt import MqttUser


tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS)
tls_context.load_verify_locations(os.environ["LOCUST_MQTT_CAFILE"])

Expand All @@ -23,6 +22,9 @@ class MyUser(MqttUser):
# 10-100 messages per second.
wait_time = between(0.01, 0.1)

# Uncomment below if you need to set MQTTv5
# protocol = paho.mqtt.client.MQTTv5

@task
class MyTasks(TaskSet):
# Sleep for a while to allow the client time to connect.
Expand Down
60 changes: 54 additions & 6 deletions locust_plugins/users/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
missing_extra("paho", "mqtt")

if typing.TYPE_CHECKING:
from paho.mqtt.enums import MQTTProtocolVersion
from paho.mqtt.client import MQTTMessageInfo
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCode
from paho.mqtt.subscribeoptions import SubscribeOptions


Expand Down Expand Up @@ -79,6 +81,7 @@ def __init__(
*args,
environment: Environment,
client_id: typing.Optional[str] = None,
protocol: MQTTProtocolVersion = mqtt.MQTTv311,
**kwargs,
):
"""Initializes a paho.mqtt.Client for use in Locust swarms.
Expand All @@ -90,6 +93,8 @@ def __init__(
environment: the Locust environment with which to associate events.
client_id: the MQTT Client ID to use in connecting to the broker.
If not set, one will be randomly generated.
protocol: the MQTT protocol version.
defaults to MQTT v3.11.
"""
# If a client ID is not provided, this class will randomly generate an ID
# of the form: `locust-[0-9a-zA-Z]{16}` (i.e., `locust-` followed by 16
Expand All @@ -107,12 +112,18 @@ def __init__(
else:
self.client_id = client_id

super().__init__(*args, client_id=self.client_id, **kwargs)
super().__init__(*args, client_id=self.client_id, protocol=protocol, **kwargs)
self.environment = environment

self.on_publish = self._on_publish_cb
self.on_subscribe = self._on_subscribe_cb
self.on_disconnect = self._on_disconnect_cb
self.on_connect = self._on_connect_cb

if self.protocol == mqtt.MQTTv5:
self.on_disconnect = self._on_disconnect_cb_v5
self.on_connect = self._on_connect_cb_v5
else:
self.on_disconnect = self._on_disconnect_cb_v3x
self.on_connect = self._on_connect_cb_v3x

self._publish_requests: dict[int, PublishedMessageContext] = {}
self._subscribe_requests: dict[int, SubscribeContext] = {}
Expand Down Expand Up @@ -235,6 +246,24 @@ def _on_disconnect_cb(
},
)

def _on_disconnect_cb_v3x(
self,
client: mqtt.Client,
userdata: typing.Any,
rc: int,
):
return self._on_disconnect_cb(client, userdata, rc)

# pylint: disable=unused-argument
def _on_disconnect_cb_v5(
self,
client: mqtt.Client,
userdata: typing.Any,
reasoncode: ReasonCode,
properties: Properties,
):
return self._on_disconnect_cb(client, userdata, reasoncode)

def _on_connect_cb(
self,
client: mqtt.Client,
Expand Down Expand Up @@ -265,6 +294,26 @@ def _on_connect_cb(
},
)

def _on_connect_cb_v3x(
self,
client: mqtt.Client,
userdata: typing.Any,
flags: dict[str, int],
rc: int,
):
return self._on_connect_cb(client, userdata, flags, rc)

# pylint: disable=unused-argument
def _on_connect_cb_v5(
self,
client: mqtt.Client,
userdata: typing.Any,
flags: dict[str, int],
reasoncode: ReasonCode,
properties: Properties,
):
return self._on_connect_cb(client, userdata, flags, reasoncode)

def publish(
self,
topic: str,
Expand Down Expand Up @@ -355,13 +404,12 @@ class MqttUser(User):
client_id = None
username = None
password = None
protocol = mqtt.MQTTv311

def __init__(self, environment: Environment):
super().__init__(environment)
self.client: MqttClient = self.client_cls(
environment=self.environment,
transport=self.transport,
client_id=self.client_id,
environment=self.environment, transport=self.transport, client_id=self.client_id, protocol=self.protocol
)

if self.tls_context:
Expand Down

0 comments on commit 260e6cc

Please sign in to comment.