Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ack for worker connection #2077

Merged
merged 9 commits into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions locust/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import gevent
import greenlet
import psutil
from gevent.event import Event
from gevent.pool import Group

from . import User
Expand Down Expand Up @@ -61,6 +62,8 @@
HEARTBEAT_LIVENESS = 3
HEARTBEAT_DEAD_INTERNAL = -60
FALLBACK_INTERVAL = 5
CONNECTION_TIMEOUT = 5
CONNECTION_RETRY_COUNT = 2


greenlet_exception_handler = greenlet_exception_logger(logger)
Expand Down Expand Up @@ -920,6 +923,7 @@ def client_listener(self):
logger.warning(
f"A worker ({client_id}) running a different version ({msg.data}) connected, master version is {__version__}"
)
self.send_message('ack', client_id=client_id)
worker_node_id = msg.node_id
self.clients[worker_node_id] = WorkerNode(worker_node_id, heartbeat_liveness=HEARTBEAT_LIVENESS)
if self._users_dispatcher is not None:
Expand Down Expand Up @@ -1045,16 +1049,19 @@ def __init__(self, environment, master_host, master_port):
:param master_port: Port to use for connecting to the master
"""
super().__init__(environment)
self.retry = 0
self.connected = False
self.connection_event = Event()
self.worker_state = STATE_INIT
self.client_id = socket.gethostname() + "_" + uuid4().hex
self.master_host = master_host
self.master_port = master_port
self.worker_cpu_warning_emitted = False
self._users_dispatcher = None
self.client = rpc.Client(master_host, master_port, self.client_id)
self.greenlet.spawn(self.heartbeat).link_exception(greenlet_exception_handler)
self.greenlet.spawn(self.worker).link_exception(greenlet_exception_handler)
self.client.send(Message("client_ready", __version__, self.client_id))
self.connect_to_master()
self.greenlet.spawn(self.heartbeat).link_exception(greenlet_exception_handler)
self.greenlet.spawn(self.stats_reporter).link_exception(greenlet_exception_handler)

# register listener for when all users have spawned, and report it to the master node
Expand Down Expand Up @@ -1167,7 +1174,9 @@ def worker(self):
except RPCError as e:
logger.error(f"RPCError found when receiving from master: {e}")
continue
if msg.type == "spawn":
if msg.type == "ack":
self.connection_event.set()
elif msg.type == "spawn":
self.client.send(Message("spawning", None, self.client_id))
job = msg.data
if job["timestamp"] <= last_received_spawn_timestamp:
Expand Down Expand Up @@ -1241,6 +1250,16 @@ def _send_stats(self):
self.environment.events.report_to_master.fire(client_id=self.client_id, data=data)
self.client.send(Message("stats", data, self.client_id))

def connect_to_master(self):
self.retry += 1
self.client.send(Message("client_ready", __version__, self.client_id))
success = self.connection_event.wait(timeout=CONNECTION_TIMEOUT)
if not success:
if self.retry > CONNECTION_RETRY_COUNT:
raise ConnectionError()
self.connect_to_master()
self.connected = True


def _format_user_classes_count_for_log(user_classes_count: Dict[str, int]) -> str:
return "{} ({} total users)".format(
Expand Down
57 changes: 49 additions & 8 deletions locust/test/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2308,8 +2308,9 @@ def my_task(self):
)

master.start(100, 20)
self.assertEqual(5, len(server.outbox))
for i, (_, msg) in enumerate(server.outbox.copy()):
self.assertEqual(6, len(server.outbox))
# First element of the outbox list is ack msg. That is why it is skipped in for loop
for i, (_, msg) in enumerate(server.outbox[1:].copy()):
self.assertDictEqual({"TestUser": int((i + 1) * 20)}, msg.data["user_classes_count"])
server.outbox.pop()

Expand All @@ -2321,7 +2322,7 @@ def my_task(self):
server.mocked_send(Message("client_ready", __version__, "zeh_fake_client2"))
self.assertEqual(2, len(master.clients))
sleep(0.1) # give time for messages to be sent to clients
self.assertEqual(2, len(server.outbox))
self.assertEqual(4, len(server.outbox))
cyberw marked this conversation as resolved.
Show resolved Hide resolved
client_id, msg = server.outbox.pop()
self.assertEqual({"TestUser": 50}, msg.data["user_classes_count"])
client_id, msg = server.outbox.pop()
Expand Down Expand Up @@ -2370,7 +2371,7 @@ def on_test_start(*a, **kw):
server.mocked_send(Message("client_ready", __version__, "fake_client%i" % i))

master.start(7, 7)
self.assertEqual(5, len(server.outbox))
self.assertEqual(10, len(server.outbox))
self.assertEqual(1, run_count[0])

# change number of users and check that test_start isn't fired again
Expand Down Expand Up @@ -2409,7 +2410,7 @@ def on_test_stop(*_, **__):
server.mocked_send(Message("client_ready", __version__, "fake_client%i" % i))

master.start(7, 7)
self.assertEqual(5, len(server.outbox))
self.assertEqual(10, len(server.outbox))
master.stop()
self.assertTrue(self.runner_stopping)
self.assertTrue(self.runner_stopped)
Expand Down Expand Up @@ -2448,7 +2449,7 @@ def on_test_stop(*_, **__):
server.mocked_send(Message("client_ready", __version__, "fake_client%i" % i))

master.start(7, 7)
self.assertEqual(5, len(server.outbox))
self.assertEqual(10, len(server.outbox))
master.quit()
self.assertTrue(self.runner_stopping)
self.assertTrue(self.runner_stopped)
Expand Down Expand Up @@ -2495,7 +2496,7 @@ def my_task(self):
server.mocked_send(Message("client_ready", __version__, "fake_client%i" % i))

master.start(7, 7)
self.assertEqual(5, len(server.outbox))
self.assertEqual(10, len(server.outbox))

num_users = sum(sum(msg.data["user_classes_count"].values()) for _, msg in server.outbox if msg.data)

Expand All @@ -2514,7 +2515,7 @@ def my_task(self):
server.mocked_send(Message("client_ready", __version__, "fake_client%i" % i))

master.start(2, 2)
self.assertEqual(5, len(server.outbox))
self.assertEqual(10, len(server.outbox))

num_users = sum(sum(msg.data["user_classes_count"].values()) for _, msg in server.outbox if msg.data)

Expand Down Expand Up @@ -2842,6 +2843,7 @@ def the_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyTestUser])
self.assertEqual(1, len(client.outbox))
self.assertEqual("client_ready", client.outbox[0].type)
Expand Down Expand Up @@ -2883,6 +2885,7 @@ def the_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment(stop_timeout=None)
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyTestUser])
self.assertEqual(1, len(client.outbox))
self.assertEqual("client_ready", client.outbox[0].type)
Expand Down Expand Up @@ -2929,6 +2932,7 @@ def my_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])

client.mocked_send(
Expand Down Expand Up @@ -3025,6 +3029,7 @@ def my_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])

client.mocked_send(
Expand Down Expand Up @@ -3077,6 +3082,7 @@ def my_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])

t0 = time.perf_counter()
Expand Down Expand Up @@ -3108,6 +3114,7 @@ def my_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])

client.mocked_send(
Expand Down Expand Up @@ -3160,6 +3167,7 @@ def my_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser1, MyUser2])

client.mocked_send(
Expand Down Expand Up @@ -3210,6 +3218,7 @@ def my_task(self):

with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
environment = Environment()
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])
client.outbox.clear()
worker.send_message("test_custom_msg", {"test_data": 123})
Expand All @@ -3234,6 +3243,7 @@ def on_custom_msg(msg, **kw):
test_custom_msg[0] = True
test_custom_msg_data[0] = msg.data

client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])
worker.register_message("test_custom_msg", on_custom_msg)

Expand All @@ -3259,6 +3269,7 @@ def my_task(self):
def on_custom_msg(msg, **kw):
test_custom_msg[0] = True

client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyUser])
worker.register_message("test_custom_msg", on_custom_msg)

Expand Down Expand Up @@ -3287,6 +3298,7 @@ def the_task(self):
def on_test_start(*args, **kw):
run_count[0] += 1

client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyTestUser])
self.assertEqual(1, len(client.outbox))
self.assertEqual("client_ready", client.outbox[0].type)
Expand Down Expand Up @@ -3373,6 +3385,7 @@ def the_task(self):
def on_test_stop(*args, **kw):
run_count[0] += 1

client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=environment, user_classes=[MyTestUser])
self.assertEqual(1, len(client.outbox))
self.assertEqual("client_ready", client.outbox[0].type)
Expand Down Expand Up @@ -3431,6 +3444,34 @@ def on_test_stop(*args, **kw):
gevent.sleep(0.01)
self.assertEqual(2, run_count[0])

def test_worker_connect_success(self):
class MyTestUser(User):
@task
def the_task(self):
pass

with mock.patch("locust.runners.CONNECTION_TIMEOUT", new=1):
with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
client.mocked_send(Message('ack', {}, "dummy_client_id"))
worker = self.get_runner(environment=Environment(), user_classes=[MyTestUser])

self.assertEqual('client_ready', client.outbox[0].type)
self.assertEqual(1, len(client.outbox))
self.assertTrue(worker.connected)

def test_worker_connect_failure(self):
class MyTestUser(User):
@task
def the_task(self):
pass

with mock.patch("locust.runners.CONNECTION_TIMEOUT", new=0.01):
with mock.patch("locust.runners.CONNECTION_RETRY_COUNT", new=1):
with mock.patch("locust.rpc.rpc.Client", mocked_rpc()) as client:
with self.assertRaises(ConnectionError):
self.get_runner(environment=Environment(), user_classes=[MyTestUser])
self.assertEqual(2, len(client.outbox))


class TestMessageSerializing(unittest.TestCase):
def test_message_serialize(self):
Expand Down