diff --git a/examples/web_ui_auth/basic.py b/examples/web_ui_auth/basic.py index 89fe665c4b..0c945523ab 100644 --- a/examples/web_ui_auth/basic.py +++ b/examples/web_ui_auth/basic.py @@ -29,9 +29,6 @@ def get_id(self): return self.username -auth_blueprint = Blueprint("auth", "web_ui_auth") - - def load_user(username): return AuthUser(username) @@ -39,12 +36,14 @@ def load_user(username): @events.init.add_listener def locust_init(environment, **_kwargs): if environment.web_ui: + auth_blueprint = Blueprint("auth", "web_ui_auth", url_prefix=environment.parsed_options.web_base_path) + environment.web_ui.login_manager.user_loader(load_user) environment.web_ui.app.config["SECRET_KEY"] = os.getenv("FLASK_SECRET_KEY") environment.web_ui.auth_args = { - "username_password_callback": "/login_submit", + "username_password_callback": f"{environment.parsed_options.web_base_path}/login_submit", "auth_providers": [ { "label": "Github", @@ -61,7 +60,7 @@ def google_login(): session["username"] = username login_user(AuthUser("username")) - return redirect(url_for("index")) + return redirect(url_for("locust.index")) @auth_blueprint.route("/login_submit", methods=["POST"]) def login_submit(): @@ -72,10 +71,10 @@ def login_submit(): if password: login_user(AuthUser(username)) - return redirect(url_for("index")) + return redirect(url_for("locust.index")) session["auth_error"] = "Invalid username or password" - return redirect(url_for("login")) + return redirect(url_for("locust.login")) environment.web_ui.app.register_blueprint(auth_blueprint) diff --git a/examples/web_ui_auth/custom_form.py b/examples/web_ui_auth/custom_form.py index 134fe6d63d..2b29accdfd 100644 --- a/examples/web_ui_auth/custom_form.py +++ b/examples/web_ui_auth/custom_form.py @@ -34,9 +34,6 @@ def get_id(self): return self.username -auth_blueprint = Blueprint("auth", "web_ui_auth") - - def load_user(user_id): return AuthUser(user_id) @@ -44,6 +41,8 @@ def load_user(user_id): @events.init.add_listener def locust_init(environment, **_kwargs): if environment.web_ui: + auth_blueprint = Blueprint("auth", "web_ui_auth", url_prefix=environment.parsed_options.web_base_path) + environment.web_ui.login_manager.user_loader(load_user) environment.web_ui.app.config["SECRET_KEY"] = os.getenv("FLASK_SECRET_KEY") @@ -70,7 +69,7 @@ def locust_init(environment, **_kwargs): "is_secret": True, }, ], - "callback_url": "/login_submit", + "callback_url": f"{environment.parsed_options.web_base_path}/login_submit", "submit_button_text": "Submit", }, } @@ -86,7 +85,7 @@ def login_submit(): if password != confirm_password: session["auth_error"] = "Passwords do not match!" - return redirect(url_for("login")) + return redirect(url_for("locust.login")) # Implement real password verification here if password: @@ -98,10 +97,10 @@ def login_submit(): login_user(AuthUser(username)) - return redirect(url_for("index")) + return redirect(url_for("locust.index")) session["auth_error"] = "Invalid username or password" - return redirect(url_for("login")) + return redirect(url_for("locust.login")) environment.web_ui.app.register_blueprint(auth_blueprint) diff --git a/locust/argument_parser.py b/locust/argument_parser.py index c0e490ba9b..6dc7d3c3f4 100644 --- a/locust/argument_parser.py +++ b/locust/argument_parser.py @@ -612,6 +612,14 @@ def setup_parser_arguments(parser): env_var="LOCUST_MASTER_NODE_PORT", ) + web_ui_group.add_argument( + "--web-base-path", + type=str, + default="", + help="Base path for the web interface (e.g., '/locust'). Default is empty (root path).", + env_var="LOCUST_web_base_path", + ) + tag_group = parser.add_argument_group( "Tag options", "Locust tasks can be tagged using the @tag decorator. These options let specify which tasks to include or exclude during a test.", diff --git a/locust/env.py b/locust/env.py index 9ba161c684..a711fcde1b 100644 --- a/locust/env.py +++ b/locust/env.py @@ -165,6 +165,7 @@ def create_web_ui( self, host="", port=8089, + web_base_path: str | None = None, web_login: bool = False, tls_cert: str | None = None, tls_key: str | None = None, @@ -199,6 +200,7 @@ def create_web_ui( delayed_start=delayed_start, userclass_picker_is_active=userclass_picker_is_active, build_path=build_path, + web_base_path=web_base_path, ) return self.web_ui diff --git a/locust/main.py b/locust/main.py index c48d679c4c..9ba96daa21 100644 --- a/locust/main.py +++ b/locust/main.py @@ -454,7 +454,9 @@ def ensure_user_class_name(config): elif options.worker: try: runner = environment.create_worker_runner(options.master_host, options.master_port) - logger.debug("Connected to locust master: %s:%s", options.master_host, options.master_port) + logger.debug( + "Connected to locust master: %s:%s%s", options.master_host, options.master_port, options.web_base_path + ) except OSError as e: logger.error("Failed to connect to the Locust master: %s", e) sys.exit(-1) @@ -490,26 +492,32 @@ def ensure_user_class_name(config): if not options.headless and not options.worker: protocol = "https" if options.tls_cert and options.tls_key else "http" + if options.web_base_path and options.web_base_path[0] != "/": + logger.error( + f"Invalid format for --web-base-path argument ({options.web_base_path}): the url path must start with a slash." + ) + sys.exit(1) if options.web_host == "*": # special check for "*" so that we're consistent with --master-bind-host web_host = "" else: web_host = options.web_host if web_host: - logger.info(f"Starting web interface at {protocol}://{web_host}:{options.web_port}") + logger.info(f"Starting web interface at {protocol}://{web_host}:{options.web_port}{options.web_base_path}") if options.web_host_display_name: logger.info(f"Starting web interface at {options.web_host_display_name}") else: if os.name == "nt": logger.info( - f"Starting web interface at {protocol}://localhost:{options.web_port} (accepting connections from all network interfaces)" + f"Starting web interface at {protocol}://localhost:{options.web_port}{options.web_base_path} (accepting connections from all network interfaces)" ) else: - logger.info(f"Starting web interface at {protocol}://0.0.0.0:{options.web_port}") + logger.info(f"Starting web interface at {protocol}://0.0.0.0:{options.web_port}{options.web_base_path}") web_ui = environment.create_web_ui( host=web_host, port=options.web_port, + web_base_path=options.web_base_path, web_login=options.web_login, tls_cert=options.tls_cert, tls_key=options.tls_key, diff --git a/locust/runners.py b/locust/runners.py index 8dfa78b8b1..a15154dc06 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -1216,6 +1216,7 @@ def __init__(self, environment: Environment, master_host: str, master_port: int) self.client_id = socket.gethostname() + "_" + uuid4().hex self.master_host = master_host self.master_port = master_port + self.web_base_path = environment.parsed_options.web_base_path if environment.parsed_options else "" self.logs: list[str] = [] self.worker_cpu_warning_emitted = False self._users_dispatcher: UsersDispatcher | None = None @@ -1475,11 +1476,11 @@ def connect_to_master(self): if not success: if self.retry < 3: logger.debug( - f"Failed to connect to master {self.master_host}:{self.master_port}, retry {self.retry}/{CONNECT_RETRY_COUNT}." + f"Failed to connect to master {self.master_host}:{self.master_port}{self.web_base_path}, retry {self.retry}/{CONNECT_RETRY_COUNT}." ) else: logger.warning( - f"Failed to connect to master {self.master_host}:{self.master_port}, retry {self.retry}/{CONNECT_RETRY_COUNT}." + f"Failed to connect to master {self.master_host}:{self.master_port}{self.web_base_path}, retry {self.retry}/{CONNECT_RETRY_COUNT}." ) if self.retry > CONNECT_RETRY_COUNT: raise ConnectionError() diff --git a/locust/test/test_runners.py b/locust/test/test_runners.py index 62afe0abef..d1e7a5a6b0 100644 --- a/locust/test/test_runners.py +++ b/locust/test/test_runners.py @@ -73,7 +73,6 @@ def send(self, message): self.outbox.append(message) def send_to_client(self, message): - print(message) self.outbox.append(message) @classmethod diff --git a/locust/test/test_web.py b/locust/test/test_web.py index 029aa5d45f..67b1be3fed 100644 --- a/locust/test/test_web.py +++ b/locust/test/test_web.py @@ -55,7 +55,7 @@ def setUp(self): self.stats = self.environment.stats self.web_ui = self.environment.create_web_ui("127.0.0.1", 0) - self.web_ui.app.view_functions["request_stats"].clear_cache() + self.web_ui.app.view_functions["locust.request_stats"].clear_cache() gevent.sleep(0.01) self.web_port = self.web_ui.server.server_port @@ -148,7 +148,7 @@ def test_stats_cache(self): data = json.loads(requests.get("http://127.0.0.1:%i/stats/requests" % self.web_port).text) self.assertEqual(2, len(data["stats"])) # old value should be cached now - self.web_ui.app.view_functions["request_stats"].clear_cache() + self.web_ui.app.view_functions["locust.request_stats"].clear_cache() data = json.loads(requests.get("http://127.0.0.1:%i/stats/requests" % self.web_port).text) self.assertEqual(3, len(data["stats"])) # this should no longer be cached @@ -1208,7 +1208,7 @@ def setUp(self): self.environment, stats.PERCENTILES_TO_REPORT, self.STATS_BASE_NAME, full_history=True ) self.web_ui = self.environment.create_web_ui("127.0.0.1", 0, stats_csv_writer=self.stats_csv_writer) - self.web_ui.app.view_functions["request_stats"].clear_cache() + self.web_ui.app.view_functions["locust.request_stats"].clear_cache() gevent.sleep(0.01) self.web_port = self.web_ui.server.server_port diff --git a/locust/web.py b/locust/web.py index 4dc58a0cb8..2705cf8a1e 100644 --- a/locust/web.py +++ b/locust/web.py @@ -15,6 +15,7 @@ import gevent from flask import ( + Blueprint, Flask, Response, jsonify, @@ -120,6 +121,7 @@ def __init__( environment: Environment, host: str, port: int, + web_base_path: str | None = None, web_login: bool = False, tls_cert: str | None = None, tls_key: str | None = None, @@ -161,20 +163,21 @@ def __init__( self.auth_args = {} self.app.template_folder = build_path or DEFAULT_BUILD_PATH self.app.static_url_path = "/assets/" + + app_blueprint = Blueprint("locust", __name__, url_prefix=web_base_path) # ensures static js files work on Windows mimetypes.add_type("application/javascript", ".js") - if self.web_login: self._login_manager = LoginManager() self._login_manager.init_app(self.app) - self._login_manager.login_view = "login" + self._login_manager.login_view = "locust.login" if environment.runner: self.update_template_args() if not delayed_start: self.start() - @app.errorhandler(Exception) + @app_blueprint.errorhandler(Exception) def handle_exception(error): error_message = str(error) error_code = getattr(error, "code", 500) @@ -184,7 +187,7 @@ def handle_exception(error): ) return make_response(error_message, error_code) - @app.route("/assets/") + @app_blueprint.route("/assets/") def send_assets(path): directory = ( os.path.join(self.app.template_folder, "assets") @@ -194,7 +197,7 @@ def send_assets(path): return send_from_directory(directory, path) - @app.route("/") + @app_blueprint.route("/") @self.auth_required_if_enabled def index() -> str | Response: if not environment.runner: @@ -203,7 +206,7 @@ def index() -> str | Response: return render_template("index.html", template_args=self.template_args) - @app.route("/swarm", methods=["POST"]) + @app_blueprint.route("/swarm", methods=["POST"]) @self.auth_required_if_enabled def swarm() -> Response: assert request.method == "POST" @@ -317,7 +320,7 @@ def swarm() -> Response: else: return jsonify({"success": False, "message": "No runner", "host": environment.host}) - @app.route("/stop") + @app_blueprint.route("/stop") @self.auth_required_if_enabled def stop() -> Response: if self._swarm_greenlet is not None: @@ -327,7 +330,7 @@ def stop() -> Response: environment.runner.stop() return jsonify({"success": True, "message": "Test stopped"}) - @app.route("/stats/reset") + @app_blueprint.route("/stats/reset") @self.auth_required_if_enabled def reset_stats() -> str: environment.events.reset_stats.fire() @@ -336,7 +339,7 @@ def reset_stats() -> str: environment.runner.exceptions = {} return "ok" - @app.route("/stats/report") + @app_blueprint.route("/stats/report") @self.auth_required_if_enabled def stats_report() -> Response: theme = request.args.get("theme", "") @@ -382,7 +385,7 @@ def _download_csv_response(csv_data: str, filename_prefix: str) -> Response: ) return response - @app.route("/stats/requests/csv") + @app_blueprint.route("/stats/requests/csv") @self.auth_required_if_enabled def request_stats_csv() -> Response: data = StringIO() @@ -390,7 +393,7 @@ def request_stats_csv() -> Response: self.stats_csv_writer.requests_csv(writer) return _download_csv_response(data.getvalue(), "requests") - @app.route("/stats/requests_full_history/csv") + @app_blueprint.route("/stats/requests_full_history/csv") @self.auth_required_if_enabled def request_stats_full_history_csv() -> Response: options = self.environment.parsed_options @@ -408,7 +411,7 @@ def request_stats_full_history_csv() -> Response: return make_response("Error: Server was not started with option to generate full history.", 404) - @app.route("/stats/failures/csv") + @app_blueprint.route("/stats/failures/csv") @self.auth_required_if_enabled def failures_stats_csv() -> Response: data = StringIO() @@ -416,7 +419,7 @@ def failures_stats_csv() -> Response: self.stats_csv_writer.failures_csv(writer) return _download_csv_response(data.getvalue(), "failures") - @app.route("/stats/requests") + @app_blueprint.route("/stats/requests") @self.auth_required_if_enabled @memoize(timeout=DEFAULT_CACHE_TIME, dynamic_timeout=True) def request_stats() -> Response: @@ -486,7 +489,7 @@ def request_stats() -> Response: return jsonify(report) - @app.route("/exceptions") + @app_blueprint.route("/exceptions") @self.auth_required_if_enabled def exceptions() -> Response: return jsonify( @@ -503,7 +506,7 @@ def exceptions() -> Response: } ) - @app.route("/exceptions/csv") + @app_blueprint.route("/exceptions/csv") @self.auth_required_if_enabled def exceptions_csv() -> Response: data = StringIO() @@ -511,7 +514,7 @@ def exceptions_csv() -> Response: self.stats_csv_writer.exceptions_csv(writer) return _download_csv_response(data.getvalue(), "exceptions") - @app.route("/tasks") + @app_blueprint.route("/tasks") @self.auth_required_if_enabled def tasks() -> dict[str, dict[str, dict[str, float]]]: runner = self.environment.runner @@ -531,15 +534,15 @@ def tasks() -> dict[str, dict[str, dict[str, float]]]: } return task_data - @app.route("/logs") + @app_blueprint.route("/logs") @self.auth_required_if_enabled def logs(): return jsonify({"master": get_logs(), "workers": self.environment.worker_logs}) - @app.route("/login") + @app_blueprint.route("/login") def login(): if not self.web_login: - return redirect(url_for("index")) + return redirect(url_for("locust.index")) self.auth_args["error"] = session.get("auth_error", None) self.auth_args["info"] = session.get("auth_info", None) @@ -549,7 +552,7 @@ def login(): auth_args=self.auth_args, ) - @app.route("/user", methods=["POST"]) + @app_blueprint.route("/user", methods=["POST"]) def update_user(): assert request.method == "POST" @@ -558,6 +561,8 @@ def update_user(): return {}, 201 + app.register_blueprint(app_blueprint) + @property def login_manager(self): if self.web_login: